floxide_core/
workflow.rs

1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::sync::Arc;
4
5use tracing::{debug, error, info, warn};
6
7use crate::action::{ActionType, DefaultAction};
8use crate::error::FloxideError;
9use crate::node::{Node, NodeId, NodeOutcome};
10
11/// Workflow execution error
12#[derive(Debug, thiserror::Error)]
13pub enum WorkflowError {
14    /// Initial node not found
15    #[error("Initial node not found: {0}")]
16    InitialNodeNotFound(NodeId),
17
18    /// Node not found
19    #[error("Node not found: {0}")]
20    NodeNotFound(NodeId),
21
22    /// Action not handled
23    #[error("Action not handled: {0}")]
24    ActionNotHandled(String),
25
26    /// Node execution error
27    #[error("Node execution error: {0}")]
28    NodeExecution(#[from] FloxideError),
29}
30
31/// A workflow composed of connected nodes
32pub struct Workflow<Context, A = DefaultAction, Output = ()>
33where
34    A: ActionType,
35{
36    /// Initial node to start workflow execution from
37    start_node: NodeId,
38
39    /// All nodes in the workflow
40    pub(crate) nodes: HashMap<NodeId, Arc<dyn Node<Context, A, Output = Output>>>,
41
42    /// Connections between nodes and actions
43    edges: HashMap<(NodeId, A), NodeId>,
44
45    /// Default fallback routes for nodes
46    default_routes: HashMap<NodeId, NodeId>,
47
48    /// Whether to allow cycles in the workflow
49    allow_cycles: bool,
50
51    /// Maximum number of times a node can be visited (0 means no limit)
52    cycle_limit: usize,
53}
54
55impl<Context, A, Output> Workflow<Context, A, Output>
56where
57    Context: Send + Sync + 'static,
58    A: ActionType + Debug + Default + Clone + Send + Sync + 'static,
59    Output: Send + Sync + 'static + std::fmt::Debug,
60{
61    /// Create a new workflow with the given start node
62    pub fn new<N>(start_node: N) -> Self
63    where
64        N: Node<Context, A, Output = Output> + 'static,
65    {
66        let id = start_node.id();
67        let mut nodes = HashMap::new();
68        nodes.insert(
69            id.clone(),
70            Arc::new(start_node) as Arc<dyn Node<Context, A, Output = Output>>,
71        );
72
73        Self {
74            start_node: id,
75            nodes,
76            edges: HashMap::new(),
77            default_routes: HashMap::new(),
78            allow_cycles: false,
79            cycle_limit: 0,
80        }
81    }
82
83    /// Add a node to the workflow
84    pub fn add_node<N>(&mut self, node: N) -> &mut Self
85    where
86        N: Node<Context, A, Output = Output> + 'static,
87    {
88        let id = node.id();
89        self.nodes.insert(
90            id,
91            Arc::new(node) as Arc<dyn Node<Context, A, Output = Output>>,
92        );
93        self
94    }
95
96    /// Connect a node to another node with an action
97    pub fn connect(&mut self, from: &NodeId, action: A, to: &NodeId) -> &mut Self {
98        self.edges.insert((from.clone(), action), to.clone());
99        self
100    }
101
102    /// Set a default route from one node to another (used when no specific action matches)
103    pub fn set_default_route(&mut self, from: &NodeId, to: &NodeId) -> &mut Self {
104        self.default_routes.insert(from.clone(), to.clone());
105        self
106    }
107
108    /// Get a node by its ID
109    pub fn get_node(&self, id: NodeId) -> Option<&dyn Node<Context, A, Output = Output>> {
110        self.nodes.get(&id).map(|node| node.as_ref())
111    }
112
113    /// Configure whether to allow cycles in the workflow
114    ///
115    /// By default, cycles are not allowed and will result in a WorkflowCycleDetected error.
116    /// When cycles are allowed, the workflow will continue execution even when revisiting nodes.
117    ///
118    /// # Arguments
119    /// * `allow` - Whether to allow cycles in the workflow
120    pub fn allow_cycles(&mut self, allow: bool) -> &mut Self {
121        self.allow_cycles = allow;
122        self
123    }
124
125    /// Set a limit on the number of times a node can be visited
126    ///
127    /// This is only relevant when cycles are allowed. If a node is visited more than
128    /// the specified limit, a WorkflowCycleDetected error will be returned.
129    ///
130    /// # Arguments
131    /// * `limit` - Maximum number of times a node can be visited (0 means no limit)
132    pub fn set_cycle_limit(&mut self, limit: usize) -> &mut Self {
133        self.cycle_limit = limit;
134        self
135    }
136
137    /// Execute the workflow with the given context
138    pub async fn execute(&self, ctx: &mut Context) -> Result<Output, WorkflowError> {
139        let mut current_node_id = self.start_node.clone();
140        let mut visit_counts = HashMap::new();
141
142        info!(start_node = %current_node_id, "Starting workflow execution");
143        debug!(node = %current_node_id, "Starting workflow execution from node");
144
145        // Debug info for connections
146        debug!("Node connections:");
147        for ((from, action), to) in &self.edges {
148            debug!(from = %from, action = ?action, to = %to, "Connection");
149        }
150
151        debug!("Default routes:");
152        for (from, to) in &self.default_routes {
153            debug!(from = %from, to = %to, "Default route");
154        }
155
156        loop {
157            // Check if we've visited this node before
158            let visit_count = visit_counts.entry(current_node_id.clone()).or_insert(0);
159            *visit_count += 1;
160
161            // Check for cycles
162            if !self.allow_cycles && *visit_count > 1 {
163                // Cycles are not allowed and we've already visited this node
164                error!(
165                    node_id = %current_node_id,
166                    "Cycle detected in workflow execution"
167                );
168                return Err(WorkflowError::NodeExecution(
169                    FloxideError::WorkflowCycleDetected,
170                ));
171            }
172
173            // Check cycle limit if specified (0 means no limit)
174            if self.cycle_limit > 0 && *visit_count > self.cycle_limit {
175                error!(
176                    node_id = %current_node_id,
177                    visit_count = %visit_count,
178                    limit = %self.cycle_limit,
179                    "Cycle limit exceeded in workflow execution"
180                );
181                return Err(WorkflowError::NodeExecution(
182                    FloxideError::WorkflowCycleDetected,
183                ));
184            }
185
186            let node = self.nodes.get(&current_node_id).ok_or_else(|| {
187                error!(node_id = %current_node_id, "Node not found in workflow");
188                WorkflowError::NodeNotFound(current_node_id.clone())
189            })?;
190
191            debug!(node_id = %current_node_id, visit_count = %visit_count, "Executing node");
192
193            let outcome = node
194                .process(ctx)
195                .await
196                .map_err(WorkflowError::NodeExecution)?;
197
198            match &outcome {
199                NodeOutcome::Success(_) => {
200                    info!(node_id = %current_node_id, "Node completed successfully with Success outcome");
201                }
202                NodeOutcome::Skipped => {
203                    info!(node_id = %current_node_id, "Node completed with Skipped outcome");
204                }
205                NodeOutcome::RouteToAction(action) => {
206                    info!(node_id = %current_node_id, action = %action.name(), action_debug = ?action, "Node completed with RouteToAction outcome");
207                }
208            }
209
210            match outcome {
211                NodeOutcome::Success(output) => {
212                    info!(node_id = %current_node_id, "Node completed successfully");
213                    // Find the default route if there is one
214                    if let Some(next) = self.default_routes.get(&current_node_id) {
215                        debug!(
216                            node_id = %current_node_id,
217                            next_node = %next,
218                            "Following default route"
219                        );
220                        current_node_id = next.clone();
221                    } else {
222                        debug!(node_id = %current_node_id, "Workflow execution completed");
223                        return Ok(output);
224                    }
225                }
226                NodeOutcome::Skipped => {
227                    warn!(node_id = %current_node_id, "Node was skipped");
228                    // Find the default route if there is one
229                    if let Some(next) = self.default_routes.get(&current_node_id) {
230                        debug!(
231                            node_id = %current_node_id,
232                            next_node = %next,
233                            "Following default route after skip"
234                        );
235                        current_node_id = next.clone();
236                    } else {
237                        warn!(node_id = %current_node_id, "Node was skipped but no default route exists");
238                        return Err(WorkflowError::ActionNotHandled(
239                            "Skipped node without default route".into(),
240                        ));
241                    }
242                }
243                NodeOutcome::RouteToAction(action) => {
244                    debug!(
245                        node_id = %current_node_id,
246                        action = ?action,
247                        "Node routed to action"
248                    );
249
250                    // Look for an edge matching this action
251                    if let Some(next) = self.edges.get(&(current_node_id.clone(), action.clone())) {
252                        debug!(
253                            node_id = %current_node_id,
254                            action = ?action,
255                            next_node = %next,
256                            "Following edge for action"
257                        );
258                        current_node_id = next.clone();
259                    }
260                    // If no matching edge, try the default action if this wasn't already the default action
261                    else if action != A::default() {
262                        if let Some(next) = self.edges.get(&(current_node_id.clone(), A::default()))
263                        {
264                            debug!(
265                                node_id = %current_node_id,
266                                next_node = %next,
267                                "No edge for action, following default action"
268                            );
269                            current_node_id = next.clone();
270                        } else if let Some(next) = self.default_routes.get(&current_node_id) {
271                            debug!(
272                                node_id = %current_node_id,
273                                next_node = %next,
274                                "No edge for action or default action, following default route"
275                            );
276                            current_node_id = next.clone();
277                        } else {
278                            error!(
279                                node_id = %current_node_id,
280                                action = ?action,
281                                "No edge found for action and no default route"
282                            );
283
284                            // Debug info for connections
285                            error!(
286                                "Available edges: {:?}",
287                                self.edges
288                                    .iter()
289                                    .map(|((from, act), to)| format!(
290                                        "{} -[{:?}]-> {}",
291                                        from, act, to
292                                    ))
293                                    .collect::<Vec<_>>()
294                            );
295                            error!(
296                                "Default routes: {:?}",
297                                self.default_routes
298                                    .iter()
299                                    .map(|(from, to)| format!("{} -> {}", from, to))
300                                    .collect::<Vec<_>>()
301                            );
302
303                            return Err(WorkflowError::ActionNotHandled(format!("{:?}", action)));
304                        }
305                    } else if let Some(next) = self.default_routes.get(&current_node_id) {
306                        debug!(
307                            node_id = %current_node_id,
308                            next_node = %next,
309                            "No edge for default action, following default route"
310                        );
311                        current_node_id = next.clone();
312                    } else {
313                        error!(
314                            node_id = %current_node_id,
315                            action = ?action,
316                            "No edge found for default action and no default route"
317                        );
318
319                        // Debug info for connections
320                        error!(
321                            "Available edges: {:?}",
322                            self.edges
323                                .iter()
324                                .map(|((from, act), to)| format!("{} -[{:?}]-> {}", from, act, to))
325                                .collect::<Vec<_>>()
326                        );
327                        error!(
328                            "Default routes: {:?}",
329                            self.default_routes
330                                .iter()
331                                .map(|(from, to)| format!("{} -> {}", from, to))
332                                .collect::<Vec<_>>()
333                        );
334
335                        return Err(WorkflowError::ActionNotHandled(
336                            "Default action not handled".into(),
337                        ));
338                    }
339                }
340            }
341        }
342    }
343}
344
345// Implement Clone for Workflow
346impl<Context, A, Output> Clone for Workflow<Context, A, Output>
347where
348    Context: Send + Sync + 'static,
349    A: ActionType + Clone + Send + Sync + 'static,
350    Output: Send + Sync + 'static,
351{
352    fn clone(&self) -> Self {
353        Self {
354            start_node: self.start_node.clone(),
355            nodes: self.nodes.clone(), // Cloning HashMap<NodeId, Arc<dyn Node>> is now possible
356            edges: self.edges.clone(),
357            default_routes: self.default_routes.clone(),
358            allow_cycles: self.allow_cycles,
359            cycle_limit: self.cycle_limit,
360        }
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::closure;

    #[derive(Debug, Clone)]
    struct TestContext {
        value: i32,
        visited: Vec<String>,
    }

    /// Build a context starting at `value` with an empty visit log.
    fn ctx_with(value: i32) -> TestContext {
        TestContext {
            value,
            visited: vec![],
        }
    }

    /// Action type with a default variant distinct from the routed one, used to
    /// exercise the action-fallback rules in `Workflow::execute`.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum FallbackAction {
        Default,
        Unrouted,
    }

    impl Default for FallbackAction {
        fn default() -> Self {
            Self::Default
        }
    }

    impl ActionType for FallbackAction {
        fn name(&self) -> &str {
            match self {
                Self::Default => "default",
                Self::Unrouted => "unrouted",
            }
        }
    }

    #[tokio::test]
    async fn test_simple_linear_workflow() {
        // Create nodes
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value += 1;
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let middle_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value *= 2;
            ctx.visited.push("middle".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let end_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value -= 3;
            ctx.visited.push("end".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        // Build workflow
        let mut workflow = Workflow::new(start_node);
        let start_id = workflow.start_node.clone();
        let middle_id = middle_node.id();
        let end_id = end_node.id();

        workflow
            .add_node(middle_node)
            .add_node(end_node)
            .set_default_route(&start_id, &middle_id)
            .set_default_route(&middle_id, &end_id);

        // Execute workflow; `expect` surfaces the actual error if it fails.
        let mut ctx = ctx_with(10);
        workflow
            .execute(&mut ctx)
            .await
            .expect("linear workflow should complete");

        // Check final state
        assert_eq!(ctx.value, 19); // 10 + 1 = 11 -> 11 * 2 = 22 -> 22 - 3 = 19
        assert_eq!(ctx.visited, vec!["start", "middle", "end"]);
    }

    #[tokio::test]
    async fn test_workflow_with_routing() {
        // Define custom action
        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
        enum TestAction {
            Default,
            Route1,
            Route2,
        }

        impl Default for TestAction {
            fn default() -> Self {
                Self::Default
            }
        }

        impl ActionType for TestAction {
            fn name(&self) -> &str {
                match self {
                    Self::Default => "default",
                    Self::Route1 => "route1",
                    Self::Route2 => "route2",
                }
            }
        }

        // Create nodes
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            // Route based on value
            if ctx.value > 5 {
                Ok((
                    ctx,
                    NodeOutcome::<(), TestAction>::RouteToAction(TestAction::Route1),
                ))
            } else {
                Ok((
                    ctx,
                    NodeOutcome::<(), TestAction>::RouteToAction(TestAction::Route2),
                ))
            }
        });

        let path1_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value += 100;
            ctx.visited.push("path1".to_string());
            Ok((ctx, NodeOutcome::<(), TestAction>::Success(())))
        });

        let path2_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value *= 10;
            ctx.visited.push("path2".to_string());
            Ok((ctx, NodeOutcome::<(), TestAction>::Success(())))
        });

        // Build workflow
        let mut workflow = Workflow::<_, TestAction, _>::new(start_node);
        let start_id = workflow.start_node.clone();
        let path1_id = path1_node.id();
        let path2_id = path2_node.id();

        workflow
            .add_node(path1_node)
            .add_node(path2_node)
            .connect(&start_id, TestAction::Route1, &path1_id)
            .connect(&start_id, TestAction::Route2, &path2_id);

        // Execute workflow - should take path 1
        let mut ctx1 = ctx_with(10);
        workflow
            .execute(&mut ctx1)
            .await
            .expect("route1 run should complete");
        assert_eq!(ctx1.value, 110); // 10 -> route1 -> +100 = 110
        assert_eq!(ctx1.visited, vec!["start", "path1"]);

        // Execute workflow - should take path 2
        let mut ctx2 = ctx_with(3);
        workflow
            .execute(&mut ctx2)
            .await
            .expect("route2 run should complete");
        assert_eq!(ctx2.value, 30); // 3 -> route2 -> *10 = 30
        assert_eq!(ctx2.visited, vec!["start", "path2"]);
    }

    #[tokio::test]
    async fn test_workflow_with_skipped_node() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let skip_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("skip_check".to_string());
            if ctx.value > 5 {
                // Skip this node
                Ok((ctx, NodeOutcome::<(), DefaultAction>::Skipped))
            } else {
                ctx.value *= 2;
                Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
            }
        });

        let end_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("end".to_string());
            ctx.value += 5;
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        // Build workflow
        let mut workflow = Workflow::new(start_node);
        let start_id = workflow.start_node.clone();
        let skip_id = skip_node.id();
        let end_id = end_node.id();

        workflow
            .add_node(skip_node)
            .add_node(end_node)
            .set_default_route(&start_id, &skip_id)
            .set_default_route(&skip_id, &end_id);

        // Execute workflow - should skip the middle node
        let mut ctx1 = ctx_with(10);
        workflow
            .execute(&mut ctx1)
            .await
            .expect("run with skip should complete");
        assert_eq!(ctx1.value, 15); // 10 -> skip middle -> +5 = 15
        assert_eq!(ctx1.visited, vec!["start", "skip_check", "end"]);

        // Execute workflow - should not skip the middle node
        let mut ctx2 = ctx_with(3);
        workflow
            .execute(&mut ctx2)
            .await
            .expect("run without skip should complete");
        assert_eq!(ctx2.value, 11); // 3 -> *2 = 6 -> +5 = 11
        assert_eq!(ctx2.visited, vec!["start", "skip_check", "end"]);
    }

    #[tokio::test]
    async fn test_skipped_node_without_default_route_fails() {
        let skip_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("skip".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Skipped))
        });

        let workflow = Workflow::new(skip_node);
        let mut ctx = ctx_with(0);
        let result = workflow.execute(&mut ctx).await;

        // A skip with nowhere to go is an error, not a silent completion.
        assert!(
            matches!(result, Err(WorkflowError::ActionNotHandled(_))),
            "Expected ActionNotHandled error, got {:?}",
            result
        );
        assert_eq!(ctx.visited, vec!["skip"]);
    }

    #[tokio::test]
    async fn test_unmatched_action_falls_back_to_default_action_edge() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            Ok((
                ctx,
                NodeOutcome::<(), FallbackAction>::RouteToAction(FallbackAction::Unrouted),
            ))
        });

        let fallback_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("fallback".to_string());
            Ok((ctx, NodeOutcome::<(), FallbackAction>::Success(())))
        });

        let mut workflow = Workflow::<_, FallbackAction, _>::new(start_node);
        let start_id = workflow.start_node.clone();
        let fallback_id = fallback_node.id();

        // Only the default action is wired; `Unrouted` must fall back to it.
        workflow
            .add_node(fallback_node)
            .connect(&start_id, FallbackAction::Default, &fallback_id);

        let mut ctx = ctx_with(0);
        workflow
            .execute(&mut ctx)
            .await
            .expect("unmatched action should fall back to the default-action edge");
        assert_eq!(ctx.visited, vec!["start", "fallback"]);
    }

    #[tokio::test]
    async fn test_unhandled_action_reports_action() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            Ok((
                ctx,
                NodeOutcome::<(), FallbackAction>::RouteToAction(FallbackAction::Unrouted),
            ))
        });

        let workflow = Workflow::<_, FallbackAction, _>::new(start_node);
        let mut ctx = ctx_with(0);
        let result = workflow.execute(&mut ctx).await;

        match result {
            Err(WorkflowError::ActionNotHandled(message)) => assert_eq!(message, "Unrouted"),
            other => panic!("Expected ActionNotHandled error, got {:?}", other),
        }
        assert_eq!(ctx.visited, vec!["start"]);
    }

    #[tokio::test]
    async fn test_cycle_limit_exceeded() {
        // A node that always routes back to itself.
        let loop_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value += 1;
            ctx.visited.push("loop".to_string());
            Ok((
                ctx,
                NodeOutcome::<(), DefaultAction>::RouteToAction(DefaultAction::Next),
            ))
        });

        let mut workflow = Workflow::new(loop_node);
        let loop_id = workflow.start_node.clone();

        workflow
            .connect(&loop_id, DefaultAction::Next, &loop_id)
            .allow_cycles(true)
            .set_cycle_limit(3);

        let mut ctx = ctx_with(0);
        let result = workflow.execute(&mut ctx).await;

        assert!(
            matches!(
                result,
                Err(WorkflowError::NodeExecution(FloxideError::WorkflowCycleDetected))
            ),
            "Expected WorkflowCycleDetected error, got {:?}",
            result
        );
        // The node runs exactly `limit` times; the 4th visit is rejected before it runs.
        assert_eq!(ctx.value, 3);
        assert_eq!(ctx.visited, vec!["loop", "loop", "loop"]);
    }

    #[tokio::test]
    async fn test_cyclic_workflow() {
        // Create nodes
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value += 1;
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let loop_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value *= 2;
            ctx.visited.push("loop".to_string());

            // Continue looping until value > 100
            if ctx.value <= 100 {
                Ok((
                    ctx,
                    NodeOutcome::<(), DefaultAction>::RouteToAction(DefaultAction::Next),
                ))
            } else {
                Ok((
                    ctx,
                    NodeOutcome::<(), DefaultAction>::RouteToAction(DefaultAction::Error),
                ))
            }
        });

        let end_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value -= 10;
            ctx.visited.push("end".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        // Build workflow with cycles allowed
        let mut workflow = Workflow::new(start_node);
        let start_id = workflow.start_node.clone();
        let loop_id = loop_node.id();
        let end_id = end_node.id();

        workflow
            .add_node(loop_node)
            .add_node(end_node)
            .set_default_route(&start_id, &loop_id)
            .connect(&loop_id, DefaultAction::Next, &loop_id) // Cycle back to loop_node
            .connect(&loop_id, DefaultAction::Error, &end_id)
            .allow_cycles(true) // Enable cycles
            .set_cycle_limit(10); // Set a reasonable limit

        // Execute workflow
        let mut ctx = ctx_with(3);
        workflow
            .execute(&mut ctx)
            .await
            .expect("cyclic workflow should complete within the cycle limit");

        // Check final state
        // Initial value: 3
        // After start: 3 + 1 = 4
        // Loop iterations:
        // 1: 4 * 2 = 8
        // 2: 8 * 2 = 16
        // 3: 16 * 2 = 32
        // 4: 32 * 2 = 64
        // 5: 64 * 2 = 128 (> 100, so route to end)
        // After end: 128 - 10 = 118
        assert_eq!(ctx.value, 118);

        // Should have visited start once, loop 5 times, and end once
        assert_eq!(
            ctx.visited,
            vec!["start", "loop", "loop", "loop", "loop", "loop", "end"]
        );

        // Test with cycles disabled - create new nodes
        let start_node2 = closure::node(|mut ctx: TestContext| async move {
            ctx.value += 1;
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let loop_node2 = closure::node(|mut ctx: TestContext| async move {
            ctx.value *= 2;
            ctx.visited.push("loop".to_string());

            // Continue looping until value > 100
            if ctx.value <= 100 {
                Ok((
                    ctx,
                    NodeOutcome::<(), DefaultAction>::RouteToAction(DefaultAction::Next),
                ))
            } else {
                Ok((
                    ctx,
                    NodeOutcome::<(), DefaultAction>::RouteToAction(DefaultAction::Error),
                ))
            }
        });

        let end_node2 = closure::node(|mut ctx: TestContext| async move {
            ctx.value -= 10;
            ctx.visited.push("end".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        // Build workflow with cycles disabled
        let mut workflow2 = Workflow::new(start_node2);
        let start_id2 = workflow2.start_node.clone();
        let loop_id2 = loop_node2.id();
        let end_id2 = end_node2.id();

        workflow2
            .add_node(loop_node2)
            .add_node(end_node2)
            .set_default_route(&start_id2, &loop_id2)
            .connect(&loop_id2, DefaultAction::Next, &loop_id2) // Cycle back to loop_node
            .connect(&loop_id2, DefaultAction::Error, &end_id2)
            .allow_cycles(false); // Disable cycles

        let mut ctx2 = ctx_with(3);
        let result2 = workflow2.execute(&mut ctx2).await;

        // Should have detected a cycle after the first loop iteration
        assert!(
            matches!(
                result2,
                Err(WorkflowError::NodeExecution(FloxideError::WorkflowCycleDetected))
            ),
            "Expected WorkflowCycleDetected error, got {:?}",
            result2
        );

        // Should have visited start once and loop once
        assert_eq!(ctx2.visited, vec!["start", "loop"]);
    }
}