1use stepflow_base::ObjectStore;
2use stepflow_step::{Step, StepId};
3use super::{Error};
4
/// The kind of move the depth-first search will attempt on its next step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DFSDirection {
    /// Descend into the first substep of the current step.
    Down,
    /// Move to the current step's next sibling, or pop up to its parent if there is none.
    SiblingOrUp,
    /// The traversal has finished; no further moves are attempted.
    Done,
}
11
12#[derive(Debug)]
13enum DFSStep {
14 DownTo(StepId),
15 SiblingTo(StepId),
16 CannotGoto(Error),
17 CannotLeaveForSibling(Error),
18 NothingMoreDown,
19 NothingMoreInStack,
20 PoppedUp,
21}
22
23#[derive(Debug)]
24pub struct DepthFirstSearch {
25 stack: Vec<StepId>,
26 next_direction: DFSDirection,
27}
28
29impl DepthFirstSearch {
30 pub fn new(root: StepId) -> Self {
31 DepthFirstSearch {
32 stack: vec![root],
33 next_direction: DFSDirection::Down,
34 }
35 }
36
37 pub fn current(&self) -> Option<&StepId> {
38 self.stack.last()
39 }
40
41 fn next_sibling_of_current<'store>(&self, step_store: &'store ObjectStore<Step, StepId>) -> Option<&'store StepId> {
42 let stack_len = self.stack.len();
43 if stack_len < 2 {
44 return None;
45 }
46 let current_id = self.stack.get(stack_len - 1).unwrap();
47 let parent_id = self.stack.get(stack_len - 2).unwrap();
48 let parent_step = step_store.get(parent_id)?;
49 parent_step.next_substep(current_id)
50 }
51
52 fn first_child_of<'stateid, 'store>(&self, step_id: &'stateid StepId, step_store: &'store ObjectStore<Step, StepId>) -> Option<&'store StepId> {
53 let step = step_store.get(step_id)?;
54 step.first_substep()
55 }
56
57 fn go_down<FnCanEnter>(&mut self, mut can_enter: FnCanEnter, step_store: &ObjectStore<Step, StepId>) -> DFSStep
58 where FnCanEnter: FnMut(&StepId) -> Result<(), Error>
59 {
60 let step_id_option = self.stack.last();
62 if step_id_option.is_none() {
63 return DFSStep::NothingMoreInStack;
64 }
65 let step_id = step_id_option.unwrap();
66
67 match self.first_child_of(step_id, step_store) {
69 Some(first_child) => {
70 if let Err(e) = can_enter(&first_child) {
71 return DFSStep::CannotGoto(e);
72 }
73 self.stack.push(first_child.clone());
74 DFSStep::DownTo(first_child.clone())
75 },
76 None => DFSStep::NothingMoreDown,
77 }
78 }
79
80 fn go_sibling_or_up<FnCanEnter, FnCanExit>(&mut self, can_enter: &mut FnCanEnter, mut can_exit: FnCanExit, step_store: &ObjectStore<Step, StepId>) -> DFSStep
81 where FnCanEnter: FnMut(&StepId) -> Result<(), Error>,
82 FnCanExit: FnMut(&StepId) -> Result<(), Error>
83 {
84 let top_stack = self.stack.last();
86 if top_stack.is_none() {
87 return DFSStep::NothingMoreInStack;
88 }
89
90 if let Err(e) = can_exit(top_stack.as_ref().unwrap()) {
92 return DFSStep::CannotLeaveForSibling(e);
93 }
94
95 match self.next_sibling_of_current(step_store) {
96 Some(next_sibling) => {
97 if let Err(e) = can_enter(next_sibling) {
98 return DFSStep::CannotGoto(e);
99 }
100 self.stack.pop();
101 self.stack.push(next_sibling.clone());
102 DFSStep::SiblingTo(next_sibling.clone())
103 },
104 None => {
105 self.stack.pop();
106 DFSStep::PoppedUp
107 }
108 }
109 }
110
111 pub fn next<FnCanEnter, FnCanExit>(&mut self, mut can_enter: FnCanEnter, mut can_exit: FnCanExit, step_store: &ObjectStore<Step, StepId>)
112 -> Result<Option<StepId>, Error>
113 where FnCanEnter: FnMut(&StepId) -> Result<(), Error>,
114 FnCanExit: FnMut(&StepId) -> Result<(), Error>
115 {
116 let mut next_direction = self.next_direction.clone();
117 let mut err: Option<Error> = None;
118 while err == None {
119 let step_result = match next_direction {
120 DFSDirection::Down => self.go_down(&mut can_enter, step_store),
121 DFSDirection::SiblingOrUp => self.go_sibling_or_up(&mut can_enter, &mut can_exit, step_store),
122 DFSDirection::Done => DFSStep::NothingMoreInStack,
123 };
124
125 next_direction = match step_result {
126 DFSStep::DownTo(_to_step_id) => DFSDirection::Down, DFSStep::SiblingTo(_to_sibling) => DFSDirection::Down, DFSStep::NothingMoreDown => {
132 next_direction = DFSDirection::SiblingOrUp;
133 break;
134 },
135
136 DFSStep::PoppedUp => DFSDirection::SiblingOrUp,
138
139 DFSStep::CannotGoto(step_err) |
141 DFSStep::CannotLeaveForSibling(step_err) => {
142 err = Some(step_err);
144 next_direction
145 },
146 DFSStep::NothingMoreInStack => {
147 next_direction = DFSDirection::Done;
148 break;
149 },
150 }
151 }
152 self.next_direction = next_direction;
153 if let Some(e) = err {
154 Err(e)
155 } else if self.next_direction == DFSDirection::Done {
156 Ok(None)
157 } else {
158 self.stack.last().map(|stack_id| Some(stack_id.clone())).ok_or(Error::NoStateToEval)
159 }
160 }
161}
162
163#[cfg(test)]
164mod tests {
165 use stepflow_base::ObjectStore;
166 use stepflow_step::{Step, StepId};
167 use super::{DepthFirstSearch, Error};
168
169 fn check_fail(fail: Option<&StepId>, step_id_check: &StepId, has_failed: &mut bool) -> Result<(), Error> {
170 if *has_failed {
171 return Ok(());
172 }
173 if let Some(step_id_fail) = fail {
174 if step_id_fail == step_id_check {
175 *has_failed = true;
176 return Err(Error::InvalidStateDataError)
177 }
178 }
179 Ok(())
180 }
181
182 fn assert_dfs_order(root: StepId, step_store: &ObjectStore<Step, StepId>, expected_children: &Vec<StepId>, fail_on_enter: Option<&StepId>, fail_on_exit: Option<&StepId>) {
183 let mut dfs = DepthFirstSearch::new(root);
184 let mut count_matches = 0;
185 let mut failed_enter = false;
186 let mut failed_exit = false;
187 let mut expected_iter = expected_children.iter();
188 let mut expected_child_opt = expected_iter.next();
189 loop {
190 if expected_child_opt.is_none() {
192 break;
193 }
194 let expected_child = expected_child_opt.unwrap();
195
196 let next = dfs.next(|step_id: &StepId| {
198 check_fail(fail_on_enter, step_id, &mut failed_enter)
199 },
200 |step_id: &StepId| {
201 check_fail(fail_on_exit, step_id, &mut failed_exit)
202 },
203 step_store);
204
205 match next {
207 Ok(step_id_opt) => {
208 if let Some(step_id) = step_id_opt {
209 if step_id != *expected_child {
210 break;
211 } else {
212 count_matches = count_matches + 1;
213 expected_child_opt = expected_iter.next();
214 }
215 } else {
216 break;
218 }
219 },
220 Err(err) => {
221 assert_eq!(err, Error::InvalidStateDataError);
223 }
224 }
225 }
226 assert_eq!(count_matches, expected_children.len());
227
228 for pass in 0..1 {
231 let final_next = dfs.next(|step_id: &StepId| {
232 check_fail(fail_on_enter, step_id, &mut failed_enter)
233 },
234 |step_id: &StepId| {
235 check_fail(fail_on_exit, step_id, &mut failed_exit)
236 },
237 step_store);
238
239 match final_next {
240 Ok(step_id_opt) => assert_eq!(step_id_opt, None),
241 Err(err) => {
242 assert_eq!(pass, 0); assert_eq!(err, Error::InvalidStateDataError);
244 }
245 }
246 }
247
248 if fail_on_enter.is_some() {
250 assert_eq!(failed_enter, true);
251 }
252 if fail_on_exit.is_some() {
253 assert_eq!(failed_exit, true);
254 }
255 }
256
257 fn assert_dfs_order_with_failures(root: StepId, step_store: &ObjectStore<Step, StepId>, expected_children: &Vec<StepId>) {
258 assert_dfs_order(root.clone(), step_store, expected_children, None, None);
259 for ienter in 0..expected_children.len() {
260 for iexit in 0..expected_children.len() {
261 assert_dfs_order(root.clone(), step_store, expected_children, Some(&expected_children[ienter]), Some(&expected_children[iexit]));
262 }
263 }
264 }
265
266 fn add_substeps(num: usize, parent_id: &StepId, step_store: &mut ObjectStore<Step, StepId>) -> Vec<StepId> {
267 let mut result = Vec::new();
268 for _ in 0..num {
269 let substep_id = step_store.insert_new(|id| Ok(Step::new(id, None, vec![]))).unwrap();
270 let parent_step = step_store.get_mut(parent_id).unwrap();
271 parent_step.push_substep(substep_id.clone());
272 result.push(substep_id);
273 }
274 result
275 }
276
277 #[test]
278 fn one_deep() {
279 let mut step_store: ObjectStore<Step, StepId> = ObjectStore::new();
280 let root = step_store.insert_new(|id| Ok(Step::new(id, None, vec![]))).unwrap();
281 let child_ids = add_substeps(2, &root, &mut step_store);
282 assert_dfs_order_with_failures(root, &step_store, &child_ids);
283 }
284
285 #[test]
286 fn two_deep() {
287 let mut step_store: ObjectStore<Step, StepId> = ObjectStore::new();
288 let root = step_store.insert_new(|id| Ok(Step::new(id, None, vec![]))).unwrap();
289 let root_children = add_substeps(2, &root, &mut step_store);
290 let children_1 = add_substeps(3, &root_children[0], &mut step_store);
291 let children_2 = add_substeps(3, &root_children[1], &mut step_store);
292
293 let mut expected_children = Vec::new();
294 expected_children.extend(children_1);
295 expected_children.extend(children_2);
296 assert_dfs_order_with_failures(root, &step_store, &expected_children);
297 }
298
299 #[test]
300 fn mixed_depth() {
301 let mut step_store: ObjectStore<Step, StepId> = ObjectStore::new();
302 let root = step_store.insert_new(|id| Ok(Step::new(id, None, vec![]))).unwrap();
303 let root_children = add_substeps(3, &root, &mut step_store);
304 let children1 = add_substeps(1, &root_children[0].clone(), &mut step_store);
305 let children3 = add_substeps(3, &root_children[2].clone(), &mut step_store);
306 let children3_children2 = add_substeps(3, &children3[1].clone(), &mut step_store);
307
308 let mut expected_children = Vec::new();
309 expected_children.extend(children1);
310 expected_children.push(root_children[1].clone());
311 expected_children.push(children3[0].clone());
312 expected_children.extend(children3_children2);
313 expected_children.push(children3[2].clone());
314
315 assert_dfs_order_with_failures(root, &step_store, &expected_children);
316 }
317}