rex/
node.rs

1use std::collections::HashSet;
2
3use crate::{HashKind, Kind, StateId};
4
/// A request to attach a new node `id` underneath the node `parent_id`.
///
/// `Node::into_insert` requires `parent_id` to be `Some`; creating a brand new
/// root is not expressed through an `Insert`.
#[derive(Debug)]
pub struct Insert<Id> {
    /// Id of the node that will receive the new child.
    pub parent_id: Option<Id>,
    /// Id of the node being added.
    pub id: Id,
}
10
/// A request to overwrite the state of the node `id` with `state`.
///
/// Derives `Debug` like the sibling `Insert` and `Node` types so requests can be
/// logged and asserted on.
#[derive(Debug)]
pub struct Update<Id, S> {
    /// Id of the node whose state is replaced.
    pub id: Id,
    /// The state that replaces the node's current one.
    pub state: S,
}
15
/// One node of a state tree: an id, its state, and its owned children.
#[derive(Debug)]
pub struct Node<Id, S> {
    /// Identifier of this node.
    pub id: Id,
    /// Current state of this node.
    pub state: S,
    /// Ids of every node below this one (this node's own id is not included).
    /// Lookups consult it to decide which child subtree holds a given id.
    descendant_keys: HashSet<Id>,
    /// Direct children of this node.
    pub children: Vec<Node<Id, S>>,
}
23
24impl<K> Node<StateId<K>, K::State>
25where
26    K: Kind + HashKind,
27{
28    #[must_use]
29    pub fn new(id: StateId<K>) -> Self {
30        Self {
31            state: id.new_state(),
32            id,
33            descendant_keys: HashSet::new(),
34            children: Vec::new(),
35        }
36    }
37
38    #[must_use]
39    pub const fn zipper(self) -> Zipper<StateId<K>, K::State> {
40        Zipper {
41            node: self,
42            parent: None,
43            self_idx: 0,
44        }
45    }
46
47    #[must_use]
48    pub fn get(&self, id: StateId<K>) -> Option<&Self> {
49        if self.id == id {
50            return Some(self);
51        }
52        if !self.descendant_keys.contains(&id) {
53            return None;
54        }
55
56        let mut node = self;
57        while node.descendant_keys.contains(&id) {
58            node = node.child(id).unwrap();
59        }
60        Some(node)
61    }
62
63    #[must_use]
64    pub fn get_state(&self, id: StateId<K>) -> Option<&K::State> {
65        self.get(id).map(|n| &n.state)
66    }
67
68    #[must_use]
69    pub fn child(&self, id: StateId<K>) -> Option<&Self> {
70        self.children
71            .iter()
72            .find(|node| node.id == id || node.descendant_keys.contains(&id))
73    }
74
75    // get array index by of node with StateId<K> in self.descendant_keys
76    #[must_use]
77    pub fn child_idx(&self, id: StateId<K>) -> Option<usize> {
78        self.children
79            .iter()
80            .enumerate()
81            .find(|(_idx, node)| node.id == id || node.descendant_keys.contains(&id))
82            .map(|(idx, _)| idx)
83    }
84
85    pub fn insert(&mut self, insert: Insert<StateId<K>>) {
86        // temporary allocation to allow a drop in &mut implementation
87        //
88        // this can be optimized later but right now allocation impact
89        // is non existent since Node::new
90        // does not grow its `?Sized` types
91        let mut swap_node = Self::new(self.id);
92        std::mem::swap(&mut swap_node, self);
93
94        swap_node = swap_node.into_insert(insert);
95
96        std::mem::swap(&mut swap_node, self);
97    }
98
99    /// inserts a new node using self by value
100    #[must_use]
101    pub fn into_insert(self, Insert { parent_id, id }: Insert<StateId<K>>) -> Self {
102        // inserts at this point should be guaranteed Some(id)
103        // ince a parent_id.is_none() should be handled by the node
104        // store through a new graph creation
105        let parent_id = parent_id.unwrap();
106
107        self.zipper()
108            .by_id(parent_id)
109            .insert_child(id)
110            .finish_insert(id)
111    }
112
113    #[must_use]
114    pub fn get_parent_id(&self, id: StateId<K>) -> Option<StateId<K>> {
115        // root_node edge case
116        if !self.descendant_keys.contains(&id) {
117            return None;
118        }
119
120        let mut node = self;
121        while node.descendant_keys.contains(&id) {
122            let child_node = node.child(id).unwrap();
123            if child_node.id == id {
124                return Some(node.id);
125            }
126            node = child_node;
127        }
128
129        None
130    }
131
132    pub fn update(&mut self, update: Update<StateId<K>, K::State>) {
133        // see Node::insert
134        let mut swap_node = Self::new(self.id);
135        std::mem::swap(&mut swap_node, self);
136
137        swap_node = swap_node.into_update(update);
138
139        std::mem::swap(&mut swap_node, self);
140    }
141
142    /// update a given node's state and return the parent ID if it exists
143    pub fn update_and_get_parent_id(
144        &mut self,
145        Update { id, state }: Update<StateId<K>, K::State>,
146    ) -> Option<StateId<K>> {
147        // see Node::insert
148        let mut swap_node = Self::new(self.id);
149        std::mem::swap(&mut swap_node, self);
150
151        let (parent_id, mut swap_node) = swap_node
152            .zipper()
153            .by_id(id)
154            .set_state(state)
155            .finish_update_parent_id();
156
157        std::mem::swap(&mut swap_node, self);
158
159        parent_id
160    }
161
162    // apply a closure to all nodes in a tree
163    pub fn update_all_fn<F>(&mut self, f: F)
164    where
165        F: Fn(Zipper<StateId<K>, K::State>) -> Self + Clone,
166    {
167        // see Node::insert
168        let mut swap_node = Self::new(self.id);
169        std::mem::swap(&mut swap_node, self);
170
171        swap_node = swap_node.zipper().finish_update_fn(f);
172
173        std::mem::swap(&mut swap_node, self);
174    }
175
176    #[must_use]
177    pub fn into_update(self, Update { id, state }: Update<StateId<K>, K::State>) -> Self {
178        self.zipper().by_id(id).set_state(state).finish_update()
179    }
180}
181
182/// Example of a [`Zipper`] cursor traversing a [`Vec`],
183/// the *focus* provides a view "Up" and "Down" the data:
184/// ```text
185/// [1, 2, 3, 4, 5] // array with 5 entries
186///  1}[2, 3, 4, 5] // zipper starts with focues at first index
187/// [1] 2}[3, 4, 5] // moving down the array
188/// [2, 1] 3}[4, 5]
189/// [3, 2, 1] 4}[5]
190/// [4, 3, 2, 1]{5  // zipper travels back up the array
191/// ```
192/// See `node/README.md` for further details.
193pub struct Zipper<Id, S> {
194    pub node: Node<Id, S>,
195    pub parent: Option<Box<Zipper<Id, S>>>,
196    self_idx: usize,
197}
198
199type ZipperNode<K> = Node<StateId<K>, <K as Kind>::State>;
200
201impl<K> Zipper<StateId<K>, K::State>
202where
203    K: Kind + HashKind,
204{
205    fn by_id(mut self, id: StateId<K>) -> Self {
206        let mut contains_id = self.node.descendant_keys.contains(&id);
207        while contains_id {
208            let idx = self.node.child_idx(id).unwrap();
209            self = self.child(idx);
210            contains_id = self.node.descendant_keys.contains(&id);
211        }
212        assert!(
213            !(self.node.id != id),
214            "id[{id}] should be in the node, this is a bug"
215        );
216        self
217    }
218
219    fn child(mut self, idx: usize) -> Self {
220        // Remove the specified child from the node's children.
221        //  Zipper should avoid having a parent reference
222        // since parents will be mutated during node refocusing.
223        // Vec::swap_remove() is used for efficiency.
224        let child = self.node.children.swap_remove(idx);
225
226        // Return a new Zipper focused on the specified child.
227        Self {
228            node: child,
229            parent: Some(Box::new(self)),
230            self_idx: idx,
231        }
232    }
233
234    const fn set_state(mut self, state: K::State) -> Self {
235        self.node.state = state;
236        self
237    }
238
239    fn insert_child(mut self, id: StateId<K>) -> Self {
240        self.node.children.push(Node::new(id));
241        self
242    }
243
244    fn parent(self) -> Self {
245        // Destructure this Zipper
246        // https://github.com/rust-lang/rust/issues/16293#issuecomment-185906859
247        let Self {
248            node,
249            parent,
250            self_idx,
251        } = self;
252
253        // Destructure the parent Zipper
254        let mut parent = *parent.unwrap();
255
256        // Insert the node of this Zipper back in its parent.
257        // Since we used swap_remove() to remove the child,
258        // we need to do the opposite of that.
259        parent.node.children.push(node);
260        let last_idx = parent.node.children.len() - 1;
261        parent.node.children.swap(self_idx, last_idx);
262
263        // Return a new Zipper focused on the parent.
264        Self {
265            node: parent.node,
266            parent: parent.parent,
267            self_idx: parent.self_idx,
268        }
269    }
270
271    //  try something like Iterator::fold
272    fn finish_insert(mut self, id: StateId<K>) -> ZipperNode<K> {
273        self.node.descendant_keys.insert(id);
274        while self.parent.is_some() {
275            self = self.parent();
276            self.node.descendant_keys.insert(id);
277        }
278
279        self.node
280    }
281
282    #[must_use]
283    pub fn finish_update(mut self) -> ZipperNode<K> {
284        while self.parent.is_some() {
285            self = self.parent();
286        }
287
288        self.node
289    }
290
291    // only act on parent nodes
292    fn finish_update_parent_id(self) -> (Option<StateId<K>>, ZipperNode<K>) {
293        let parent_id = self.parent.as_ref().map(|z| z.node.id);
294        (parent_id, self.finish_update())
295    }
296
297    // act on all nodes
298    fn finish_update_fn<F>(mut self, f: F) -> ZipperNode<K>
299    where
300        F: Fn(Self) -> ZipperNode<K> + Clone,
301    {
302        self.node.children = self
303            .node
304            .children
305            .into_iter()
306            .map(|n| n.zipper().finish_update_fn(f.clone()))
307            .collect();
308        f(self)
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{node_state, Kind, State, StateId};

    node_state!(Alice, Bob, Charlie, Dave, Eve);

    #[test]
    fn insert_child_state() {
        let alice_id = StateId::new_rand(NodeKind::Alice);
        let bob_id = StateId::new_rand(NodeKind::Bob);
        let charlie_id = StateId::new_rand(NodeKind::Charlie);
        let dave_id = StateId::new_rand(NodeKind::Dave);
        let eve_id = StateId::new_rand(NodeKind::Eve);

        let mut tree = Node::new(alice_id);

        // =================================================
        // Graph should look like this after four insertions:
        //
        //       (Alice)
        //       /      \
        //     (Bob) (Charlie)
        //            /
        //       (Dave)
        //       /
        //  (Eve)
        // =================================================
        tree.insert(Insert {
            parent_id: Some(alice_id),
            id: bob_id,
        });
        tree.insert(Insert {
            parent_id: Some(alice_id),
            id: charlie_id,
        });
        tree.insert(Insert {
            parent_id: Some(charlie_id),
            id: dave_id,
        });
        tree.insert(Insert {
            parent_id: Some(dave_id),
            id: eve_id,
        });
        // =================================================

        // Bob =============================================
        let mut bob = tree.get_state(bob_id).unwrap();
        assert_eq!(bob, &NodeState::Bob(Bob::New));
        tree = tree.into_update(Update {
            id: bob_id,
            state: NodeState::Bob(Bob::Awaiting),
        });
        bob = tree.get_state(bob_id).unwrap();
        assert_eq!(bob, &NodeState::Bob(Bob::Awaiting));
        // =================================================

        // Charlie =========================================
        let mut charlie = tree.get_state(charlie_id).unwrap();
        assert_eq!(charlie, &NodeState::Charlie(Charlie::New));
        tree = tree.into_update(Update {
            id: charlie_id,
            state: NodeState::Charlie(Charlie::Awaiting),
        });
        charlie = tree.get_state(charlie_id).unwrap();
        assert_eq!(charlie, &NodeState::Charlie(Charlie::Awaiting));
        // =================================================

        // Dave ============================================
        let mut dave = tree.get_state(dave_id).unwrap();
        assert_eq!(dave, &NodeState::Dave(Dave::New));
        // Dave finished whatever it was that Dave was doing
        tree = tree.into_update(Update {
            id: dave_id,
            state: NodeState::Dave(Dave::Completed),
        });
        dave = tree.get_state(dave_id).unwrap();
        assert_eq!(dave, &NodeState::Dave(Dave::Completed));
        // =================================================

        // Eve =============================================
        let mut eve = tree.get_state(eve_id).unwrap();
        assert_eq!(eve, &NodeState::Eve(Eve::New));
        // Fail Eve (simulating timeout)
        tree = tree.into_update(Update {
            id: eve_id,
            state: NodeState::Eve(Eve::Failed),
        });
        eve = tree.get_state(eve_id).unwrap();
        assert_eq!(eve, &NodeState::Eve(Eve::Failed));
        // =================================================

        // =================================================
        // Eve failed! Fail everyone!
        // ...except for Dave, he is in "Completed" state
        // =================================================
        tree = tree.zipper().finish_update_fn(|mut z| {
            let kind: NodeKind = *z.node.state.as_ref();
            if z.node.state != kind.completed_state() {
                z.node.state = kind.failed_state();
            }
            z.finish_update()
        });
        assert_eq!(&tree.state, &NodeState::Alice(Alice::Failed));
        assert_eq!(
            tree.get_state(bob_id).unwrap(),
            &NodeState::Bob(Bob::Failed)
        );
        assert_eq!(
            tree.get_state(charlie_id).unwrap(),
            &NodeState::Charlie(Charlie::Failed)
        );
        assert_eq!(
            tree.get_state(dave_id).unwrap(),
            &NodeState::Dave(Dave::Completed)
        );
        assert_eq!(
            tree.get_state(eve_id).unwrap(),
            &NodeState::Eve(Eve::Failed)
        );
    }

    /// Covers the `&mut self` APIs and parent lookups on a simple chain.
    #[test]
    fn parent_ids_and_in_place_updates() {
        let alice_id = StateId::new_rand(NodeKind::Alice);
        let bob_id = StateId::new_rand(NodeKind::Bob);
        let charlie_id = StateId::new_rand(NodeKind::Charlie);
        // Never inserted: used to check lookups of unknown ids.
        let dave_id = StateId::new_rand(NodeKind::Dave);

        // (Alice) -> (Bob) -> (Charlie)
        let mut tree = Node::new(alice_id);
        tree.insert(Insert {
            parent_id: Some(alice_id),
            id: bob_id,
        });
        tree.insert(Insert {
            parent_id: Some(bob_id),
            id: charlie_id,
        });

        // Parent lookups: the root and unknown ids have no parent.
        assert!(tree.get_parent_id(alice_id).is_none());
        assert!(tree.get_parent_id(bob_id) == Some(alice_id));
        assert!(tree.get_parent_id(charlie_id) == Some(bob_id));
        assert!(tree.get_parent_id(dave_id).is_none());
        assert!(tree.get(dave_id).is_none());
        assert!(tree.get_state(dave_id).is_none());

        // update_and_get_parent_id reports the parent and applies the state.
        let parent = tree.update_and_get_parent_id(Update {
            id: charlie_id,
            state: NodeState::Charlie(Charlie::Awaiting),
        });
        assert!(parent == Some(bob_id));
        assert_eq!(
            tree.get_state(charlie_id).unwrap(),
            &NodeState::Charlie(Charlie::Awaiting)
        );

        let root_parent = tree.update_and_get_parent_id(Update {
            id: alice_id,
            state: NodeState::Alice(Alice::Failed),
        });
        assert!(root_parent.is_none());
        assert_eq!(&tree.state, &NodeState::Alice(Alice::Failed));

        // In-place update keeps the rest of the tree intact.
        tree.update(Update {
            id: bob_id,
            state: NodeState::Bob(Bob::Failed),
        });
        assert_eq!(
            tree.get_state(bob_id).unwrap(),
            &NodeState::Bob(Bob::Failed)
        );
        assert!(tree.get_parent_id(charlie_id) == Some(bob_id));

        // update_all_fn visits every node.
        tree.update_all_fn(|mut z| {
            let kind: NodeKind = *z.node.state.as_ref();
            z.node.state = kind.failed_state();
            z.finish_update()
        });
        assert_eq!(
            tree.get_state(charlie_id).unwrap(),
            &NodeState::Charlie(Charlie::Failed)
        );
    }
}