// floxide_core/workflow.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt::Debug;
3use std::sync::Arc;
4
5use tracing::{debug, error, info, warn};
6
7use crate::action::{ActionType, DefaultAction};
8use crate::error::FloxideError;
9use crate::node::{Node, NodeId, NodeOutcome};
10
11/// Workflow execution error
12#[derive(Debug, thiserror::Error)]
13pub enum WorkflowError {
14    /// Initial node not found
15    #[error("Initial node not found: {0}")]
16    InitialNodeNotFound(NodeId),
17
18    /// Node not found
19    #[error("Node not found: {0}")]
20    NodeNotFound(NodeId),
21
22    /// Action not handled
23    #[error("Action not handled: {0}")]
24    ActionNotHandled(String),
25
26    /// Node execution error
27    #[error("Node execution error: {0}")]
28    NodeExecution(#[from] FloxideError),
29}
30
31/// A workflow composed of connected nodes
32pub struct Workflow<Context, A = DefaultAction, Output = ()>
33where
34    A: ActionType,
35{
36    /// Initial node to start workflow execution from
37    start_node: NodeId,
38
39    /// All nodes in the workflow
40    pub(crate) nodes: HashMap<NodeId, Arc<dyn Node<Context, A, Output = Output>>>,
41
42    /// Connections between nodes and actions
43    edges: HashMap<(NodeId, A), NodeId>,
44
45    /// Default fallback routes for nodes
46    default_routes: HashMap<NodeId, NodeId>,
47}
48
49impl<Context, A, Output> Workflow<Context, A, Output>
50where
51    Context: Send + Sync + 'static,
52    A: ActionType + Debug + Default + Clone + Send + Sync + 'static,
53    Output: Send + Sync + 'static + std::fmt::Debug,
54{
55    /// Create a new workflow with the given start node
56    pub fn new<N>(start_node: N) -> Self
57    where
58        N: Node<Context, A, Output = Output> + 'static,
59    {
60        let id = start_node.id();
61        let mut nodes = HashMap::new();
62        nodes.insert(
63            id.clone(),
64            Arc::new(start_node) as Arc<dyn Node<Context, A, Output = Output>>,
65        );
66
67        Self {
68            start_node: id,
69            nodes,
70            edges: HashMap::new(),
71            default_routes: HashMap::new(),
72        }
73    }
74
75    /// Add a node to the workflow
76    pub fn add_node<N>(&mut self, node: N) -> &mut Self
77    where
78        N: Node<Context, A, Output = Output> + 'static,
79    {
80        let id = node.id();
81        self.nodes.insert(
82            id,
83            Arc::new(node) as Arc<dyn Node<Context, A, Output = Output>>,
84        );
85        self
86    }
87
88    /// Connect a node to another node with an action
89    pub fn connect(&mut self, from: &NodeId, action: A, to: &NodeId) -> &mut Self {
90        self.edges.insert((from.clone(), action), to.clone());
91        self
92    }
93
94    /// Set a default route from one node to another (used when no specific action matches)
95    pub fn set_default_route(&mut self, from: &NodeId, to: &NodeId) -> &mut Self {
96        self.default_routes.insert(from.clone(), to.clone());
97        self
98    }
99
100    /// Get a node by its ID
101    pub fn get_node(&self, id: NodeId) -> Option<&dyn Node<Context, A, Output = Output>> {
102        self.nodes.get(&id).map(|node| node.as_ref())
103    }
104
105    /// Execute the workflow with the given context
106    pub async fn execute(&self, ctx: &mut Context) -> Result<Output, WorkflowError> {
107        let mut current_node_id = self.start_node.clone();
108        let mut visited = HashSet::new();
109
110        info!(start_node = %current_node_id, "Starting workflow execution");
111        eprintln!("Starting workflow execution from node: {}", current_node_id);
112
113        // Debug info for connections
114        eprintln!("Node connections:");
115        for ((from, action), to) in &self.edges {
116            eprintln!("  {} -[{:?}]-> {}", from, action, to);
117        }
118
119        eprintln!("Default routes:");
120        for (from, to) in &self.default_routes {
121            eprintln!("  {} -> {}", from, to);
122        }
123
124        while !visited.contains(&current_node_id) {
125            let node = self.nodes.get(&current_node_id).ok_or_else(|| {
126                error!(node_id = %current_node_id, "Node not found in workflow");
127                WorkflowError::NodeNotFound(current_node_id.clone())
128            })?;
129
130            visited.insert(current_node_id.clone());
131            debug!(node_id = %current_node_id, "Executing node");
132
133            let outcome = node
134                .process(ctx)
135                .await
136                .map_err(WorkflowError::NodeExecution)?;
137
138            match &outcome {
139                NodeOutcome::Success(_) => {
140                    info!(node_id = %current_node_id, "Node completed successfully with Success outcome");
141                    eprintln!("Node {} completed with Success outcome", current_node_id);
142                }
143                NodeOutcome::Skipped => {
144                    info!(node_id = %current_node_id, "Node completed with Skipped outcome");
145                    eprintln!("Node {} completed with Skipped outcome", current_node_id);
146                }
147                NodeOutcome::RouteToAction(action) => {
148                    info!(node_id = %current_node_id, action = %action.name(), "Node completed with RouteToAction outcome");
149                    eprintln!(
150                        "Node {} completed with RouteToAction({:?}) outcome",
151                        current_node_id, action
152                    );
153                }
154            }
155
156            match outcome {
157                NodeOutcome::Success(output) => {
158                    info!(node_id = %current_node_id, "Node completed successfully");
159                    // Find the default route if there is one
160                    if let Some(next) = self.default_routes.get(&current_node_id) {
161                        debug!(
162                            node_id = %current_node_id,
163                            next_node = %next,
164                            "Following default route"
165                        );
166                        current_node_id = next.clone();
167                    } else {
168                        debug!(node_id = %current_node_id, "Workflow execution completed");
169                        return Ok(output);
170                    }
171                }
172                NodeOutcome::Skipped => {
173                    warn!(node_id = %current_node_id, "Node was skipped");
174                    // Find the default route if there is one
175                    if let Some(next) = self.default_routes.get(&current_node_id) {
176                        debug!(
177                            node_id = %current_node_id,
178                            next_node = %next,
179                            "Following default route after skip"
180                        );
181                        current_node_id = next.clone();
182                    } else {
183                        warn!(node_id = %current_node_id, "Node was skipped but no default route exists");
184                        return Err(WorkflowError::ActionNotHandled(
185                            "Skipped node without default route".into(),
186                        ));
187                    }
188                }
189                NodeOutcome::RouteToAction(action) => {
190                    debug!(
191                        node_id = %current_node_id,
192                        action = ?action,
193                        "Node routed to action"
194                    );
195
196                    // Look for an edge matching this action
197                    if let Some(next) = self.edges.get(&(current_node_id.clone(), action.clone())) {
198                        debug!(
199                            node_id = %current_node_id,
200                            action = ?action,
201                            next_node = %next,
202                            "Following edge for action"
203                        );
204                        current_node_id = next.clone();
205                    }
206                    // If no matching edge, try the default action if this wasn't already the default action
207                    else if action != A::default() {
208                        if let Some(next) = self.edges.get(&(current_node_id.clone(), A::default()))
209                        {
210                            debug!(
211                                node_id = %current_node_id,
212                                next_node = %next,
213                                "No edge for action, following default action"
214                            );
215                            current_node_id = next.clone();
216                        } else if let Some(next) = self.default_routes.get(&current_node_id) {
217                            debug!(
218                                node_id = %current_node_id,
219                                next_node = %next,
220                                "No edge for action or default action, following default route"
221                            );
222                            current_node_id = next.clone();
223                        } else {
224                            error!(
225                                node_id = %current_node_id,
226                                action = ?action,
227                                "No edge found for action and no default route"
228                            );
229
230                            // Debug info for connections
231                            error!(
232                                "Available edges: {:?}",
233                                self.edges
234                                    .iter()
235                                    .map(|((from, act), to)| format!(
236                                        "{} -[{:?}]-> {}",
237                                        from, act, to
238                                    ))
239                                    .collect::<Vec<_>>()
240                            );
241                            error!(
242                                "Default routes: {:?}",
243                                self.default_routes
244                                    .iter()
245                                    .map(|(from, to)| format!("{} -> {}", from, to))
246                                    .collect::<Vec<_>>()
247                            );
248
249                            return Err(WorkflowError::ActionNotHandled(format!("{:?}", action)));
250                        }
251                    } else if let Some(next) = self.default_routes.get(&current_node_id) {
252                        debug!(
253                            node_id = %current_node_id,
254                            next_node = %next,
255                            "No edge for default action, following default route"
256                        );
257                        current_node_id = next.clone();
258                    } else {
259                        error!(
260                            node_id = %current_node_id,
261                            action = ?action,
262                            "No edge found for default action and no default route"
263                        );
264
265                        // Debug info for connections
266                        error!(
267                            "Available edges: {:?}",
268                            self.edges
269                                .iter()
270                                .map(|((from, act), to)| format!("{} -[{:?}]-> {}", from, act, to))
271                                .collect::<Vec<_>>()
272                        );
273                        error!(
274                            "Default routes: {:?}",
275                            self.default_routes
276                                .iter()
277                                .map(|(from, to)| format!("{} -> {}", from, to))
278                                .collect::<Vec<_>>()
279                        );
280
281                        return Err(WorkflowError::ActionNotHandled(
282                            "Default action not handled".into(),
283                        ));
284                    }
285                }
286            }
287        }
288
289        // If we get here, we've detected a cycle
290        error!(
291            node_id = %current_node_id,
292            "Cycle detected in workflow execution"
293        );
294        Err(WorkflowError::NodeExecution(
295            FloxideError::WorkflowCycleDetected,
296        ))
297    }
298}
299
300// Implement Clone for Workflow
301impl<Context, A, Output> Clone for Workflow<Context, A, Output>
302where
303    Context: Send + Sync + 'static,
304    A: ActionType + Clone + Send + Sync + 'static,
305    Output: Send + Sync + 'static,
306{
307    fn clone(&self) -> Self {
308        Self {
309            start_node: self.start_node.clone(),
310            nodes: self.nodes.clone(), // Cloning HashMap<NodeId, Arc<dyn Node>> is now possible
311            edges: self.edges.clone(),
312            default_routes: self.default_routes.clone(),
313        }
314    }
315}
316
#[cfg(test)]
mod tests {
    //! Unit tests for `Workflow` routing: happy paths plus the error paths
    //! (cycles, unroutable outcomes, unregistered nodes).

    use super::*;
    use crate::node::closure;

    #[derive(Debug, Clone)]
    struct TestContext {
        value: i32,
        visited: Vec<String>,
    }

    #[tokio::test]
    async fn test_simple_linear_workflow() {
        // Create nodes
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value += 1;
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let middle_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value *= 2;
            ctx.visited.push("middle".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let end_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value -= 3;
            ctx.visited.push("end".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        // Build workflow
        let mut workflow = Workflow::new(start_node);
        let start_id = workflow.start_node.clone();
        let middle_id = middle_node.id();
        let end_id = end_node.id();

        workflow
            .add_node(middle_node)
            .add_node(end_node)
            .set_default_route(&start_id, &middle_id)
            .set_default_route(&middle_id, &end_id);

        // Execute workflow
        let mut ctx = TestContext {
            value: 10,
            visited: vec![],
        };

        let result = workflow.execute(&mut ctx).await;
        assert!(result.is_ok());

        // Check final state
        assert_eq!(ctx.value, 19); // 10 + 1 = 11 -> 11 * 2 = 22 -> 22 - 3 = 19
        assert_eq!(ctx.visited, vec!["start", "middle", "end"]);
    }

    #[tokio::test]
    async fn test_workflow_with_routing() {
        // Define custom action
        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
        enum TestAction {
            Default,
            Route1,
            Route2,
        }

        impl Default for TestAction {
            fn default() -> Self {
                Self::Default
            }
        }

        impl ActionType for TestAction {
            fn name(&self) -> &str {
                match self {
                    Self::Default => "default",
                    Self::Route1 => "route1",
                    Self::Route2 => "route2",
                }
            }
        }

        // Create nodes
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            // Route based on value
            if ctx.value > 5 {
                Ok((
                    ctx,
                    NodeOutcome::<(), TestAction>::RouteToAction(TestAction::Route1),
                ))
            } else {
                Ok((
                    ctx,
                    NodeOutcome::<(), TestAction>::RouteToAction(TestAction::Route2),
                ))
            }
        });

        let path1_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value += 100;
            ctx.visited.push("path1".to_string());
            Ok((ctx, NodeOutcome::<(), TestAction>::Success(())))
        });

        let path2_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value *= 10;
            ctx.visited.push("path2".to_string());
            Ok((ctx, NodeOutcome::<(), TestAction>::Success(())))
        });

        // Build workflow
        let mut workflow = Workflow::<_, TestAction, _>::new(start_node);
        let start_id = workflow.start_node.clone();
        let path1_id = path1_node.id();
        let path2_id = path2_node.id();

        workflow
            .add_node(path1_node)
            .add_node(path2_node)
            .connect(&start_id, TestAction::Route1, &path1_id)
            .connect(&start_id, TestAction::Route2, &path2_id);

        // Execute workflow - should take path 1
        let mut ctx1 = TestContext {
            value: 10,
            visited: vec![],
        };

        let result1 = workflow.execute(&mut ctx1).await;
        assert!(result1.is_ok());
        assert_eq!(ctx1.value, 110); // 10 -> route1 -> +100 = 110
        assert_eq!(ctx1.visited, vec!["start", "path1"]);

        // Execute workflow - should take path 2
        let mut ctx2 = TestContext {
            value: 3,
            visited: vec![],
        };

        let result2 = workflow.execute(&mut ctx2).await;
        assert!(result2.is_ok());
        assert_eq!(ctx2.value, 30); // 3 -> route2 -> *10 = 30
        assert_eq!(ctx2.visited, vec!["start", "path2"]);
    }

    #[tokio::test]
    async fn test_workflow_with_skipped_node() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let skip_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("skip_check".to_string());
            if ctx.value > 5 {
                // Skip this node
                Ok((ctx, NodeOutcome::<(), DefaultAction>::Skipped))
            } else {
                ctx.value *= 2;
                Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
            }
        });

        let end_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("end".to_string());
            ctx.value += 5;
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        // Build workflow
        let mut workflow = Workflow::new(start_node);
        let start_id = workflow.start_node.clone();
        let skip_id = skip_node.id();
        let end_id = end_node.id();

        workflow
            .add_node(skip_node)
            .add_node(end_node)
            .set_default_route(&start_id, &skip_id)
            .set_default_route(&skip_id, &end_id);

        // Execute workflow - should skip the middle node
        let mut ctx1 = TestContext {
            value: 10,
            visited: vec![],
        };

        let result1 = workflow.execute(&mut ctx1).await;
        assert!(result1.is_ok());
        assert_eq!(ctx1.value, 15); // 10 -> skip middle -> +5 = 15
        assert_eq!(ctx1.visited, vec!["start", "skip_check", "end"]);

        // Execute workflow - should not skip the middle node
        let mut ctx2 = TestContext {
            value: 3,
            visited: vec![],
        };

        let result2 = workflow.execute(&mut ctx2).await;
        assert!(result2.is_ok());
        assert_eq!(ctx2.value, 11); // 3 -> *2 = 6 -> +5 = 11
        assert_eq!(ctx2.visited, vec!["start", "skip_check", "end"]);
    }

    #[tokio::test]
    async fn test_cycle_is_detected() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let loop_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("loop".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        // start -> loop -> start
        let mut workflow = Workflow::new(start_node);
        let start_id = workflow.start_node.clone();
        let loop_id = loop_node.id();

        workflow
            .add_node(loop_node)
            .set_default_route(&start_id, &loop_id)
            .set_default_route(&loop_id, &start_id);

        let mut ctx = TestContext {
            value: 0,
            visited: vec![],
        };

        let result = workflow.execute(&mut ctx).await;
        assert!(matches!(
            result,
            Err(WorkflowError::NodeExecution(
                FloxideError::WorkflowCycleDetected
            ))
        ));
        // Each node runs exactly once before the revisit is rejected.
        assert_eq!(ctx.visited, vec!["start", "loop"]);
    }

    #[tokio::test]
    async fn test_skipped_node_without_default_route_is_an_error() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Skipped))
        });

        let workflow = Workflow::new(start_node);
        let mut ctx = TestContext {
            value: 0,
            visited: vec![],
        };

        let result = workflow.execute(&mut ctx).await;
        assert!(matches!(
            &result,
            Err(WorkflowError::ActionNotHandled(msg))
                if msg == "Skipped node without default route"
        ));
    }

    #[tokio::test]
    async fn test_unrouted_default_action_is_an_error() {
        let start_node = closure::node(|ctx: TestContext| async move {
            Ok((
                ctx,
                NodeOutcome::<(), DefaultAction>::RouteToAction(DefaultAction::default()),
            ))
        });

        // No edges and no default route: the action has nowhere to go.
        let workflow = Workflow::new(start_node);
        let mut ctx = TestContext {
            value: 0,
            visited: vec![],
        };

        let result = workflow.execute(&mut ctx).await;
        assert!(matches!(
            &result,
            Err(WorkflowError::ActionNotHandled(msg)) if msg == "Default action not handled"
        ));
    }

    #[tokio::test]
    async fn test_route_to_unregistered_node_is_an_error() {
        let start_node = closure::node(|ctx: TestContext| async move {
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        // Only used for its ID; deliberately never added to the workflow.
        let ghost_node = closure::node(|ctx: TestContext| async move {
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let mut workflow = Workflow::new(start_node);
        let start_id = workflow.start_node.clone();
        let ghost_id = ghost_node.id();
        workflow.set_default_route(&start_id, &ghost_id);

        let mut ctx = TestContext {
            value: 0,
            visited: vec![],
        };

        let result = workflow.execute(&mut ctx).await;
        assert!(matches!(result, Err(WorkflowError::NodeNotFound(_))));
    }
}