1use std::collections::HashSet;
2
3use crate::{HashKind, Kind, StateId};
4
/// Request to add a new node with id `id` as a child of the node `parent_id`.
#[derive(Debug)]
pub struct Insert<Id> {
    /// Id of the existing node that becomes the parent of the new node.
    /// `Node::into_insert` requires this to be `Some` and panics otherwise.
    pub parent_id: Option<Id>,
    /// Id of the node to create.
    pub id: Id,
}
10
/// Request to replace the state of the node with id `id` by `state`.
///
/// Derives `Debug` (like `Insert` and `Node`) so update requests can be logged.
#[derive(Debug)]
pub struct Update<Id, S> {
    /// Id of the node whose state is replaced.
    pub id: Id,
    /// The new state for that node.
    pub state: S,
}
15
/// A node in a tree of states, identified by `id`.
#[derive(Debug)]
pub struct Node<Id, S> {
    /// This node's id.
    pub id: Id,
    /// This node's current state.
    pub state: S,
    /// Ids of every node strictly below this one (not including `id` itself).
    /// Lets lookups pick the one child subtree that contains a given id instead of
    /// searching the whole tree.
    descendant_keys: HashSet<Id>,
    /// Direct children of this node.
    pub children: Vec<Node<Id, S>>,
}
23
24impl<K> Node<StateId<K>, K::State>
25where
26 K: Kind + HashKind,
27{
28 #[must_use]
29 pub fn new(id: StateId<K>) -> Self {
30 Self {
31 state: id.new_state(),
32 id,
33 descendant_keys: HashSet::new(),
34 children: Vec::new(),
35 }
36 }
37
38 #[must_use]
39 pub const fn zipper(self) -> Zipper<StateId<K>, K::State> {
40 Zipper {
41 node: self,
42 parent: None,
43 self_idx: 0,
44 }
45 }
46
47 #[must_use]
48 pub fn get(&self, id: StateId<K>) -> Option<&Self> {
49 if self.id == id {
50 return Some(self);
51 }
52 if !self.descendant_keys.contains(&id) {
53 return None;
54 }
55
56 let mut node = self;
57 while node.descendant_keys.contains(&id) {
58 node = node.child(id).unwrap();
59 }
60 Some(node)
61 }
62
63 #[must_use]
64 pub fn get_state(&self, id: StateId<K>) -> Option<&K::State> {
65 self.get(id).map(|n| &n.state)
66 }
67
68 #[must_use]
69 pub fn child(&self, id: StateId<K>) -> Option<&Self> {
70 self.children
71 .iter()
72 .find(|node| node.id == id || node.descendant_keys.contains(&id))
73 }
74
75 #[must_use]
77 pub fn child_idx(&self, id: StateId<K>) -> Option<usize> {
78 self.children
79 .iter()
80 .enumerate()
81 .find(|(_idx, node)| node.id == id || node.descendant_keys.contains(&id))
82 .map(|(idx, _)| idx)
83 }
84
85 pub fn insert(&mut self, insert: Insert<StateId<K>>) {
86 let mut swap_node = Self::new(self.id);
92 std::mem::swap(&mut swap_node, self);
93
94 swap_node = swap_node.into_insert(insert);
95
96 std::mem::swap(&mut swap_node, self);
97 }
98
99 #[must_use]
101 pub fn into_insert(self, Insert { parent_id, id }: Insert<StateId<K>>) -> Self {
102 let parent_id = parent_id.unwrap();
106
107 self.zipper()
108 .by_id(parent_id)
109 .insert_child(id)
110 .finish_insert(id)
111 }
112
113 #[must_use]
114 pub fn get_parent_id(&self, id: StateId<K>) -> Option<StateId<K>> {
115 if !self.descendant_keys.contains(&id) {
117 return None;
118 }
119
120 let mut node = self;
121 while node.descendant_keys.contains(&id) {
122 let child_node = node.child(id).unwrap();
123 if child_node.id == id {
124 return Some(node.id);
125 }
126 node = child_node;
127 }
128
129 None
130 }
131
132 pub fn update(&mut self, update: Update<StateId<K>, K::State>) {
133 let mut swap_node = Self::new(self.id);
135 std::mem::swap(&mut swap_node, self);
136
137 swap_node = swap_node.into_update(update);
138
139 std::mem::swap(&mut swap_node, self);
140 }
141
142 pub fn update_and_get_parent_id(
144 &mut self,
145 Update { id, state }: Update<StateId<K>, K::State>,
146 ) -> Option<StateId<K>> {
147 let mut swap_node = Self::new(self.id);
149 std::mem::swap(&mut swap_node, self);
150
151 let (parent_id, mut swap_node) = swap_node
152 .zipper()
153 .by_id(id)
154 .set_state(state)
155 .finish_update_parent_id();
156
157 std::mem::swap(&mut swap_node, self);
158
159 parent_id
160 }
161
162 pub fn update_all_fn<F>(&mut self, f: F)
164 where
165 F: Fn(Zipper<StateId<K>, K::State>) -> Self + Clone,
166 {
167 let mut swap_node = Self::new(self.id);
169 std::mem::swap(&mut swap_node, self);
170
171 swap_node = swap_node.zipper().finish_update_fn(f);
172
173 std::mem::swap(&mut swap_node, self);
174 }
175
176 #[must_use]
177 pub fn into_update(self, Update { id, state }: Update<StateId<K>, K::State>) -> Self {
178 self.zipper().by_id(id).set_state(state).finish_update()
179 }
180}
181
/// A cursor that owns a [`Node`] tree while it is being navigated and edited.
///
/// Moving to a child detaches that child from its parent's `children` and keeps the
/// rest of the tree in `parent`; moving back up reattaches it at its original index.
pub struct Zipper<Id, S> {
    /// The node currently in focus, detached from its parent's `children`.
    pub node: Node<Id, S>,
    /// Zipper for the focused node's parent, or `None` when the focus is the root.
    pub parent: Option<Box<Zipper<Id, S>>>,
    /// Index the focused node occupied in its parent's `children` before it was
    /// detached (unused at the root).
    self_idx: usize,
}
198
/// Shorthand for the concrete [`Node`] type that a `Zipper` over `StateId<K>` ids
/// produces and consumes.
type ZipperNode<K> = Node<StateId<K>, <K as Kind>::State>;
200
201impl<K> Zipper<StateId<K>, K::State>
202where
203 K: Kind + HashKind,
204{
205 fn by_id(mut self, id: StateId<K>) -> Self {
206 let mut contains_id = self.node.descendant_keys.contains(&id);
207 while contains_id {
208 let idx = self.node.child_idx(id).unwrap();
209 self = self.child(idx);
210 contains_id = self.node.descendant_keys.contains(&id);
211 }
212 assert!(
213 !(self.node.id != id),
214 "id[{id}] should be in the node, this is a bug"
215 );
216 self
217 }
218
219 fn child(mut self, idx: usize) -> Self {
220 let child = self.node.children.swap_remove(idx);
225
226 Self {
228 node: child,
229 parent: Some(Box::new(self)),
230 self_idx: idx,
231 }
232 }
233
234 const fn set_state(mut self, state: K::State) -> Self {
235 self.node.state = state;
236 self
237 }
238
239 fn insert_child(mut self, id: StateId<K>) -> Self {
240 self.node.children.push(Node::new(id));
241 self
242 }
243
244 fn parent(self) -> Self {
245 let Self {
248 node,
249 parent,
250 self_idx,
251 } = self;
252
253 let mut parent = *parent.unwrap();
255
256 parent.node.children.push(node);
260 let last_idx = parent.node.children.len() - 1;
261 parent.node.children.swap(self_idx, last_idx);
262
263 Self {
265 node: parent.node,
266 parent: parent.parent,
267 self_idx: parent.self_idx,
268 }
269 }
270
271 fn finish_insert(mut self, id: StateId<K>) -> ZipperNode<K> {
273 self.node.descendant_keys.insert(id);
274 while self.parent.is_some() {
275 self = self.parent();
276 self.node.descendant_keys.insert(id);
277 }
278
279 self.node
280 }
281
282 #[must_use]
283 pub fn finish_update(mut self) -> ZipperNode<K> {
284 while self.parent.is_some() {
285 self = self.parent();
286 }
287
288 self.node
289 }
290
291 fn finish_update_parent_id(self) -> (Option<StateId<K>>, ZipperNode<K>) {
293 let parent_id = self.parent.as_ref().map(|z| z.node.id);
294 (parent_id, self.finish_update())
295 }
296
297 fn finish_update_fn<F>(mut self, f: F) -> ZipperNode<K>
299 where
300 F: Fn(Self) -> ZipperNode<K> + Clone,
301 {
302 self.node.children = self
303 .node
304 .children
305 .into_iter()
306 .map(|n| n.zipper().finish_update_fn(f.clone()))
307 .collect();
308 f(self)
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{node_state, Kind, State, StateId};

    // Presumably declares the `NodeKind`, `NodeState`, and per-node state types
    // (`Alice` .. `Eve`) used below; confirm against the `node_state!` definition.
    node_state!(Alice, Bob, Charlie, Dave, Eve);

    /// Builds a four-level tree, updates individual node states one at a time, then
    /// runs a bottom-up bulk update and checks every node's final state.
    #[test]
    fn insert_child_state() {
        let alice_id = StateId::new_rand(NodeKind::Alice);
        let bob_id = StateId::new_rand(NodeKind::Bob);
        let charlie_id = StateId::new_rand(NodeKind::Charlie);
        let dave_id = StateId::new_rand(NodeKind::Dave);
        let eve_id = StateId::new_rand(NodeKind::Eve);

        let mut tree = Node::new(alice_id);

        // Resulting shape: Alice -> {Bob, Charlie}, Charlie -> Dave, Dave -> Eve.
        tree.insert(Insert {
            parent_id: Some(alice_id),
            id: bob_id,
        });
        tree.insert(Insert {
            parent_id: Some(alice_id),
            id: charlie_id,
        });
        tree.insert(Insert {
            parent_id: Some(charlie_id),
            id: dave_id,
        });
        tree.insert(Insert {
            parent_id: Some(dave_id),
            id: eve_id,
        });

        // Each freshly inserted node starts in its `New` state; update each one and
        // read it back to confirm the change landed on the right node.
        let mut bob = tree.get_state(bob_id).unwrap();
        assert_eq!(bob, &NodeState::Bob(Bob::New));
        tree = tree.into_update(Update {
            id: bob_id,
            state: NodeState::Bob(Bob::Awaiting),
        });
        bob = tree.get_state(bob_id).unwrap();
        assert_eq!(bob, &NodeState::Bob(Bob::Awaiting));
        let mut charlie = tree.get_state(charlie_id).unwrap();
        assert_eq!(charlie, &NodeState::Charlie(Charlie::New));
        tree = tree.into_update(Update {
            id: charlie_id,
            state: NodeState::Charlie(Charlie::Awaiting),
        });
        charlie = tree.get_state(charlie_id).unwrap();
        assert_eq!(charlie, &NodeState::Charlie(Charlie::Awaiting));
        let mut dave = tree.get_state(dave_id).unwrap();
        assert_eq!(dave, &NodeState::Dave(Dave::New));
        tree = tree.into_update(Update {
            id: dave_id,
            state: NodeState::Dave(Dave::Completed),
        });
        dave = tree.get_state(dave_id).unwrap();
        assert_eq!(dave, &NodeState::Dave(Dave::Completed));
        let mut eve = tree.get_state(eve_id).unwrap();
        assert_eq!(eve, &NodeState::Eve(Eve::New));
        tree = tree.into_update(Update {
            id: eve_id,
            state: NodeState::Eve(Eve::Failed),
        });
        eve = tree.get_state(eve_id).unwrap();
        assert_eq!(eve, &NodeState::Eve(Eve::Failed));

        // Bulk update: every node not already in its kind's completed state is moved
        // to its kind's failed state. `finish_update_fn` visits children before
        // their parent.
        tree = tree.zipper().finish_update_fn(|mut z| {
            let kind: NodeKind = *z.node.state.as_ref();
            if !(z.node.state == kind.completed_state()) {
                z.node.state = kind.failed_state();
            }
            z.finish_update()
        });

        // Only Dave was Completed, so it keeps its state; every other node is Failed.
        assert_eq!(&tree.state, &NodeState::Alice(Alice::Failed));
        assert_eq!(
            tree.get_state(bob_id).unwrap(),
            &NodeState::Bob(Bob::Failed)
        );
        assert_eq!(
            tree.get_state(charlie_id).unwrap(),
            &NodeState::Charlie(Charlie::Failed)
        );
        assert_eq!(
            tree.get_state(dave_id).unwrap(),
            &NodeState::Dave(Dave::Completed)
        );
        assert_eq!(
            tree.get_state(eve_id).unwrap(),
            &NodeState::Eve(Eve::Failed)
        );
    }
}