use std::{
collections::{BTreeMap, BTreeSet, HashSet, VecDeque},
fmt::Debug,
ops::ControlFlow,
};
use machine_check_common::{NodeId, StateId};
use mck::{abstr, concr::FullMachine, misc::MetaWrap};
use partitions::PartitionVec;
use petgraph::{prelude::GraphMap, Directed};
use crate::{AbstrInput, AbstrParam, WrappedInput, WrappedParam};
/// Directed graph over nodes, together with a per-head record of the tail states.
///
/// The graph holds one edge per (head, tail) pair, while the tail partitions keep
/// every tail that was pushed for a head, grouped into sets.
pub struct StateGraph<M: FullMachine> {
// Directed graph; each edge carries a representative input and parameter.
node_graph: GraphMap<NodeId, Edge<WrappedInput<M>, WrappedParam<M>>, Directed>,
// For each head node, the partitioned vector of tail states recorded for it.
tail_partitions: BTreeMap<NodeId, PartitionVec<StateId>>,
}
/// Prints every head node followed by the sets of its tail partition.
impl<M: FullMachine> Debug for StateGraph<M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("StateGraph [")?;
        for (head, partition) in self.tail_partitions.iter() {
            write!(f, "{} -> [", head)?;
            for tail_set in partition.all_sets() {
                // Gather the states of one set into an ordered set before printing.
                let ordered_tails: BTreeSet<_> =
                    tail_set.map(|(_, state_id)| *state_id).collect();
                write!(f, "{:?}, ", ordered_tails)?;
            }
            f.write_str("], ")?;
        }
        f.write_str("]")
    }
}
/// Data attached to an edge: one representative input and one representative parameter.
#[derive(Debug)]
pub struct Edge<AI, AP> {
/// Input kept as the representative for this edge.
pub representative_input: AI,
/// Parameter kept as the representative for this edge.
pub representative_param: AP,
}
impl<M: FullMachine> StateGraph<M> {
/// Creates a graph that contains only the root node and has no tail partitions.
pub fn new() -> Self {
    let mut graph = Self {
        node_graph: GraphMap::new(),
        tail_partitions: BTreeMap::new(),
    };
    graph.node_graph.add_node(NodeId::ROOT);
    graph
}
/// Removes every outgoing edge of `head_id` along with its tail partition.
///
/// Returns the set of states that were direct successors of `head_id` before the
/// call, and the removed tail partition (`None` if the head had none recorded).
pub fn clear_step(
    &mut self,
    head_id: NodeId,
) -> (BTreeSet<StateId>, Option<PartitionVec<StateId>>) {
    // Collect first: the successor iterator borrows the graph that is mutated below.
    let direct_successor_indices: BTreeSet<StateId> =
        self.direct_successor_iter(head_id).collect();
    // Iterate by reference; the ids are `Copy`, so the set does not need to be cloned.
    for &direct_successor_id in &direct_successor_indices {
        self.node_graph
            .remove_edge(head_id, direct_successor_id.into());
    }
    let tail_partition = self.tail_partitions.remove(&head_id);
    (direct_successor_indices, tail_partition)
}
/// Records `next_state` as a tail of `current_node`.
///
/// The state is appended to the tail partition of `current_node`; when `param_id`
/// is given, the new entry is unioned with the entry at that index. An edge with
/// the representative input and parameter is added only if the graph does not
/// already contain an edge between the two nodes.
///
/// Returns the index of the newly appended entry within the tail partition.
pub fn add_step(
    &mut self,
    current_node: NodeId,
    next_state: StateId,
    representative_input: &<M::Abstr as abstr::Machine<M>>::Input,
    representative_param: &<M::Abstr as abstr::Machine<M>>::Param,
    param_id: Option<usize>,
) -> usize {
    let tail_node: NodeId = next_state.into();

    // The new entry lands at the current end of the partition vector.
    let partition = self.tail_partitions.entry(current_node).or_default();
    let new_index = partition.len();
    partition.push(next_state);
    if let Some(existing_index) = param_id {
        partition.union(existing_index, new_index);
    }

    // Keep the representatives of an edge that is already present.
    if !self.node_graph.contains_edge(current_node, tail_node) {
        let edge = Edge {
            representative_input: MetaWrap(representative_input.clone()),
            representative_param: MetaWrap(representative_param.clone()),
        };
        self.node_graph.add_edge(current_node, tail_node, edge);
    }
    new_index
}
/// Removes every state node that cannot be reached from the root.
///
/// The tail partitions of removed states are dropped as well. Returns the set of
/// states that were reached and therefore kept.
///
/// # Panics
///
/// Panics through `assert_left_total` if a node is left without a successor.
pub fn make_compact(&mut self) -> BTreeSet<StateId> {
    let mut reachable = BTreeSet::new();
    let mut pending = vec![NodeId::ROOT];
    while let Some(node_id) = pending.pop() {
        // Nodes that do not convert to a state are expanded without being recorded.
        let expand = match StateId::try_from(node_id) {
            Ok(state_id) => reachable.insert(state_id),
            Err(_) => true,
        };
        if expand {
            pending.extend(
                self.direct_successor_iter(node_id)
                    .map(|successor_id| -> NodeId { successor_id.into() }),
            );
        }
    }

    // Gather the states to drop before mutating the graph.
    let unreachable: BTreeSet<StateId> = self
        .node_graph
        .nodes()
        .filter_map(|node_id| StateId::try_from(node_id).ok())
        .filter(|state_id| !reachable.contains(state_id))
        .collect();
    for state_id in unreachable {
        let node_id: NodeId = state_id.into();
        self.node_graph.remove_node(node_id);
        self.tail_partitions.remove(&node_id);
    }

    self.assert_left_total();
    reachable
}
/// Walks the graph breadth-first from the root, calling `result_fn` on each state.
///
/// `result_fn` is invoked for every visited node that converts to a `StateId`.
/// The first `ControlFlow::Break` value is returned as `Some`; if the whole
/// reachable part is visited without a break, `None` is returned.
pub fn breadth_first_search<T>(
    &self,
    result_fn: impl Fn(StateId) -> ControlFlow<T, ()>,
) -> Option<T> {
    let mut visited = HashSet::<NodeId>::new();
    let mut frontier = VecDeque::<NodeId>::new();
    frontier.push_back(NodeId::ROOT);

    while let Some(node_id) = frontier.pop_front() {
        // A node may have been queued several times; handle it only once.
        if !visited.insert(node_id) {
            continue;
        }
        if let Ok(state_id) = StateId::try_from(node_id) {
            match result_fn(state_id) {
                ControlFlow::Break(found) => return Some(found),
                ControlFlow::Continue(()) => {}
            }
        }
        let unvisited_successors = self
            .direct_successor_iter(node_id)
            .map(|successor_id| -> NodeId { successor_id.into() })
            .filter(|successor_node| !visited.contains(successor_node));
        frontier.extend(unvisited_successors);
    }
    None
}
/// Returns the representative input stored on the edge from `head` to `tail`.
///
/// # Panics
///
/// Panics if the graph has no edge from `head` to `tail`.
pub fn representative_input(&self, head: NodeId, tail: StateId) -> &AbstrInput<M> {
    let edge = self
        .node_graph
        .edge_weight(head, tail.into())
        .expect("Edge should be present in graph");
    &edge.representative_input.0
}
/// Returns the representative parameter stored on the edge from `head` to `tail`.
///
/// # Panics
///
/// Panics if the graph has no edge from `head` to `tail`.
pub fn representative_param(&self, head: NodeId, tail: StateId) -> &AbstrParam<M> {
    let edge = self
        .node_graph
        .edge_weight(head, tail.into())
        .expect("Edge should be present in graph");
    &edge.representative_param.0
}
/// Iterates over the nodes that have an edge leading into `node_id`.
pub fn direct_predecessor_iter(
    &self,
    node_id: NodeId,
) -> impl Iterator<Item = NodeId> + Clone + '_ {
    use petgraph::Direction::Incoming;
    self.node_graph.neighbors_directed(node_id, Incoming)
}
/// Iterates over the states that `node_id` has an outgoing edge to.
///
/// # Panics
///
/// The returned iterator panics when it reaches a successor node that does not
/// convert to a `StateId`.
pub fn direct_successor_iter(
    &self,
    node_id: NodeId,
) -> impl Iterator<Item = StateId> + Clone + '_ {
    let outgoing = self
        .node_graph
        .neighbors_directed(node_id, petgraph::Direction::Outgoing);
    outgoing.map(|successor_id| StateId::try_from(successor_id).unwrap())
}
/// Returns the tail partition recorded for `node_id`, or `None` if there is none.
pub fn direct_successor_param_partition(
&self,
node_id: NodeId,
) -> Option<&PartitionVec<StateId>> {
self.tail_partitions.get(&node_id)
}
/// Tells whether the graph has an edge from `head_id` to the node of `tail_id`.
pub fn contains_edge(&self, head_id: NodeId, tail_id: StateId) -> bool {
    let tail_node: NodeId = tail_id.into();
    self.node_graph.contains_edge(head_id, tail_node)
}
pub fn contains_state(&self, state_id: StateId) -> bool {
self.node_graph.contains_node(state_id.into())
}
/// Iterates over the direct successors of the root node.
pub fn initial_iter(&self) -> impl Iterator<Item = StateId> + Clone + '_ {
    let root = NodeId::ROOT;
    self.direct_successor_iter(root)
}
/// Returns the number of edges in the node graph.
pub fn num_transitions(&self) -> usize {
self.node_graph.edge_count()
}
/// Returns the number of nodes in the node graph.
pub fn node_count(&self) -> usize {
self.node_graph.node_count()
}
/// Iterates over all nodes of the node graph.
pub fn nodes(&self) -> impl Iterator<Item = NodeId> + Clone + '_ {
self.node_graph.nodes()
}
/// Checks that every node of the graph has at least one direct successor.
///
/// # Panics
///
/// Panics, naming the offending node, if a node has no direct successor.
pub fn assert_left_total(&self) {
    for node_id in self.nodes() {
        // Only emptiness matters, so probe the first successor instead of counting them all.
        if self.direct_successor_iter(node_id).next().is_none() {
            panic!(
                "State space should be left-total but node {} has no successor",
                node_id
            );
        }
    }
}
}
impl<M: FullMachine> Default for StateGraph<M> {
fn default() -> Self {
Self::new()
}
}