portmatching/automaton/modify.rs

1use itertools::izip;
2use portgraph::{LinkMut, PortMut, PortView};
3
4use crate::{
5    predicate::{EdgePredicate, Symbol},
6    EdgeProperty, HashSet, PatternID,
7};
8
9use super::{ScopeAutomaton, State, StateID, Transition};
10
11impl<PNode: Clone, PEdge: EdgeProperty> ScopeAutomaton<PNode, PEdge> {
12    pub(super) fn set_children<I>(
13        &mut self,
14        state: StateID,
15        preds: impl IntoIterator<IntoIter = I>,
16        next_states: &[Option<StateID>],
17        next_scopes: Vec<HashSet<Symbol>>,
18    ) -> Vec<Option<StateID>>
19    where
20        I: Iterator<Item = EdgePredicate<PNode, PEdge, PEdge::OffsetID>> + ExactSizeIterator,
21    {
22        let preds = preds.into_iter();
23        if self.graph.num_outputs(state.0) != 0 {
24            panic!("State already has outgoing ports");
25        }
26        // Allocate new ports
27        self.add_ports(state, 0, preds.len());
28
29        // Build the children
30        izip!(preds, next_states, next_scopes)
31            .enumerate()
32            .map(|(i, (pred, &next_state, next_scope))| {
33                self.add_child(state, i, pred.into(), next_state, Some(next_scope))
34            })
35            .collect()
36    }
37
38    fn add_child(
39        &mut self,
40        parent: StateID,
41        offset: usize,
42        pedge: Transition<PNode, PEdge, PEdge::OffsetID>,
43        new_state: Option<StateID>,
44        new_scope: Option<HashSet<Symbol>>,
45    ) -> Option<StateID> {
46        let mut added_state = false;
47        let (new_state_id, new_offset) = if let Some(new_state) = new_state {
48            let in_offset = self.graph.num_inputs(new_state.0);
49            self.add_ports(new_state, 1, 0);
50            (new_state, in_offset)
51        } else {
52            added_state = true;
53            (self.graph.add_node(1, 0).into(), 0)
54        };
55        self.graph
56            .link_nodes(parent.0, offset, new_state_id.0, new_offset)
57            .expect("Could not add child at offset p");
58        let new_scope = new_scope.unwrap_or_else(|| {
59            // By default, take scope of parent and add symbol if necessary
60            let mut old_scope = self.weights[parent.0]
61                .clone()
62                .expect("invalid parent")
63                .scope;
64            if let EdgePredicate::LinkNewNode { new_node, .. } = pedge.clone().into() {
65                old_scope.insert(new_node);
66            }
67            old_scope
68        });
69        let new_state = if let Some(mut new_state) = self.weights[new_state_id.0].take() {
70            new_state.scope.retain(|k| new_scope.contains(k));
71            new_state
72        } else {
73            State {
74                matches: Vec::new(),
75                scope: new_scope,
76                deterministic: true,
77            }
78        };
79        self.weights.nodes[new_state_id.0] = Some(new_state);
80        self.weights[self.graph.output(parent.0, offset).unwrap()] = Some(pedge);
81        added_state.then_some(new_state_id)
82    }
83
84    fn add_ports(&mut self, state: StateID, incoming: usize, outgoing: usize) {
85        let incoming = incoming + self.graph.num_inputs(state.0);
86        let outgoing = outgoing + self.graph.num_outputs(state.0);
87        self.graph
88            .set_num_ports(state.0, incoming, outgoing, |old, new| {
89                self.weights.ports.rekey(old, new.new_index());
90            });
91    }
92
93    pub(crate) fn add_match(&mut self, state: StateID, pattern: PatternID) {
94        self.weights[state.0]
95            .as_mut()
96            .expect("invalid state")
97            .matches
98            .push(pattern);
99    }
100
101    pub(crate) fn make_non_det(&mut self, state: StateID) {
102        if self.graph.num_outputs(state.0) > 0 {
103            panic!("Cannot make state non-deterministic: has outgoing ports");
104        }
105        self.weights[state.0]
106            .as_mut()
107            .expect("invalid state")
108            .deterministic = false;
109    }
110}
111
#[cfg(test)]
mod tests {
    use crate::{patterns::IterationStatus, predicate::Symbol};

    use super::*;

    /// Reaching the same child along two different transitions must shrink its
    /// scope to the symbols that are common to both paths.
    #[test]
    fn intersect_scope() {
        let s_root = Symbol::root();
        let s1 = Symbol::new(IterationStatus::Finished, 0);
        let s2 = Symbol::new(IterationStatus::Finished, 1);
        // Transition from the root symbol that binds `new_node` as a fresh symbol.
        let link_new_node = |new_node: Symbol| -> Transition<(), (), ()> {
            EdgePredicate::LinkNewNode {
                node: s_root,
                property: (),
                new_node,
            }
            .into()
        };

        let mut automaton = ScopeAutomaton::new();
        let root = automaton.root();
        automaton.add_ports(root, 0, 2);

        // First path creates a brand new child that also knows about `s1`.
        let child = automaton
            .add_child(root, 0, link_new_node(s1), None, None)
            .unwrap();
        assert_eq!(automaton.scope(child), &[s_root, s1].into_iter().collect());

        // Second path into the same child binds `s2` instead, so only the root
        // symbol survives the intersection.
        automaton.add_child(root, 1, link_new_node(s2), Some(child), None);
        assert_eq!(automaton.scope(child), &[s_root].into_iter().collect());
    }
}
144}