1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::sync::Arc;
4
5use tracing::{debug, error, info, warn};
6
7use crate::action::{ActionType, DefaultAction};
8use crate::error::FloxideError;
9use crate::node::{Node, NodeId, NodeOutcome};
10
/// Errors returned while building or running a [`Workflow`].
#[derive(Debug, thiserror::Error)]
pub enum WorkflowError {
    /// The configured start node id has no registered node.
    ///
    /// NOTE(review): `Workflow::execute` in this file never constructs this
    /// variant; an unregistered start node surfaces as `NodeNotFound` instead.
    /// Confirm whether code elsewhere relies on it.
    #[error("Initial node not found: {0}")]
    InitialNodeNotFound(NodeId),

    /// Execution reached a node id with no registered node, e.g. an edge or
    /// default route whose target was never added to the workflow.
    #[error("Node not found: {0}")]
    NodeNotFound(NodeId),

    /// A node's outcome could not be routed. Either it was skipped without a
    /// default route, or it routed to an action that has no matching edge and
    /// no default route. The payload describes what was not handled.
    #[error("Action not handled: {0}")]
    ActionNotHandled(String),

    /// A node returned an error, or execution was aborted by cycle detection.
    /// Built automatically from a [`FloxideError`] via the derived `From` impl.
    #[error("Node execution error: {0}")]
    NodeExecution(#[from] FloxideError),
}
30
/// A directed graph of nodes that is run by `Workflow::execute`.
///
/// After each node runs, execution moves on by following an explicit action
/// edge or the node's default route.
///
/// - `Context` is the mutable state handed to every node.
/// - `A` is the action type used for routing.
/// - `Output` is the value a node produces on success.
pub struct Workflow<Context, A = DefaultAction, Output = ()>
where
    A: ActionType,
{
    /// Id of the node where execution begins.
    start_node: NodeId,

    /// Every registered node, keyed by its id.
    ///
    /// Nodes are stored behind `Arc`, so cloning the workflow shares node
    /// instances instead of duplicating them.
    pub(crate) nodes: HashMap<NodeId, Arc<dyn Node<Context, A, Output = Output>>>,

    /// Action edges `(from, action) -> to`, followed when `from` returns
    /// `NodeOutcome::RouteToAction(action)`.
    edges: HashMap<(NodeId, A), NodeId>,

    /// Fallback successor for a node. It is used after success, after a skip,
    /// or when a routed action has no matching edge.
    default_routes: HashMap<NodeId, NodeId>,

    /// Whether any node may be visited more than once in a single execution.
    allow_cycles: bool,

    /// Maximum number of visits per node in a single execution; `0` means no limit.
    cycle_limit: usize,
}
54
55impl<Context, A, Output> Workflow<Context, A, Output>
56where
57 Context: Send + Sync + 'static,
58 A: ActionType + Debug + Default + Clone + Send + Sync + 'static,
59 Output: Send + Sync + 'static + std::fmt::Debug,
60{
61 pub fn new<N>(start_node: N) -> Self
63 where
64 N: Node<Context, A, Output = Output> + 'static,
65 {
66 let id = start_node.id();
67 let mut nodes = HashMap::new();
68 nodes.insert(
69 id.clone(),
70 Arc::new(start_node) as Arc<dyn Node<Context, A, Output = Output>>,
71 );
72
73 Self {
74 start_node: id,
75 nodes,
76 edges: HashMap::new(),
77 default_routes: HashMap::new(),
78 allow_cycles: false,
79 cycle_limit: 0,
80 }
81 }
82
83 pub fn add_node<N>(&mut self, node: N) -> &mut Self
85 where
86 N: Node<Context, A, Output = Output> + 'static,
87 {
88 let id = node.id();
89 self.nodes.insert(
90 id,
91 Arc::new(node) as Arc<dyn Node<Context, A, Output = Output>>,
92 );
93 self
94 }
95
96 pub fn connect(&mut self, from: &NodeId, action: A, to: &NodeId) -> &mut Self {
98 self.edges.insert((from.clone(), action), to.clone());
99 self
100 }
101
102 pub fn set_default_route(&mut self, from: &NodeId, to: &NodeId) -> &mut Self {
104 self.default_routes.insert(from.clone(), to.clone());
105 self
106 }
107
108 pub fn get_node(&self, id: NodeId) -> Option<&dyn Node<Context, A, Output = Output>> {
110 self.nodes.get(&id).map(|node| node.as_ref())
111 }
112
113 pub fn allow_cycles(&mut self, allow: bool) -> &mut Self {
121 self.allow_cycles = allow;
122 self
123 }
124
125 pub fn set_cycle_limit(&mut self, limit: usize) -> &mut Self {
133 self.cycle_limit = limit;
134 self
135 }
136
137 pub async fn execute(&self, ctx: &mut Context) -> Result<Output, WorkflowError> {
139 let mut current_node_id = self.start_node.clone();
140 let mut visit_counts = HashMap::new();
141
142 info!(start_node = %current_node_id, "Starting workflow execution");
143 debug!(node = %current_node_id, "Starting workflow execution from node");
144
145 debug!("Node connections:");
147 for ((from, action), to) in &self.edges {
148 debug!(from = %from, action = ?action, to = %to, "Connection");
149 }
150
151 debug!("Default routes:");
152 for (from, to) in &self.default_routes {
153 debug!(from = %from, to = %to, "Default route");
154 }
155
156 loop {
157 let visit_count = visit_counts.entry(current_node_id.clone()).or_insert(0);
159 *visit_count += 1;
160
161 if !self.allow_cycles && *visit_count > 1 {
163 error!(
165 node_id = %current_node_id,
166 "Cycle detected in workflow execution"
167 );
168 return Err(WorkflowError::NodeExecution(
169 FloxideError::WorkflowCycleDetected,
170 ));
171 }
172
173 if self.cycle_limit > 0 && *visit_count > self.cycle_limit {
175 error!(
176 node_id = %current_node_id,
177 visit_count = %visit_count,
178 limit = %self.cycle_limit,
179 "Cycle limit exceeded in workflow execution"
180 );
181 return Err(WorkflowError::NodeExecution(
182 FloxideError::WorkflowCycleDetected,
183 ));
184 }
185
186 let node = self.nodes.get(¤t_node_id).ok_or_else(|| {
187 error!(node_id = %current_node_id, "Node not found in workflow");
188 WorkflowError::NodeNotFound(current_node_id.clone())
189 })?;
190
191 debug!(node_id = %current_node_id, visit_count = %visit_count, "Executing node");
192
193 let outcome = node
194 .process(ctx)
195 .await
196 .map_err(WorkflowError::NodeExecution)?;
197
198 match &outcome {
199 NodeOutcome::Success(_) => {
200 info!(node_id = %current_node_id, "Node completed successfully with Success outcome");
201 }
202 NodeOutcome::Skipped => {
203 info!(node_id = %current_node_id, "Node completed with Skipped outcome");
204 }
205 NodeOutcome::RouteToAction(action) => {
206 info!(node_id = %current_node_id, action = %action.name(), action_debug = ?action, "Node completed with RouteToAction outcome");
207 }
208 }
209
210 match outcome {
211 NodeOutcome::Success(output) => {
212 info!(node_id = %current_node_id, "Node completed successfully");
213 if let Some(next) = self.default_routes.get(¤t_node_id) {
215 debug!(
216 node_id = %current_node_id,
217 next_node = %next,
218 "Following default route"
219 );
220 current_node_id = next.clone();
221 } else {
222 debug!(node_id = %current_node_id, "Workflow execution completed");
223 return Ok(output);
224 }
225 }
226 NodeOutcome::Skipped => {
227 warn!(node_id = %current_node_id, "Node was skipped");
228 if let Some(next) = self.default_routes.get(¤t_node_id) {
230 debug!(
231 node_id = %current_node_id,
232 next_node = %next,
233 "Following default route after skip"
234 );
235 current_node_id = next.clone();
236 } else {
237 warn!(node_id = %current_node_id, "Node was skipped but no default route exists");
238 return Err(WorkflowError::ActionNotHandled(
239 "Skipped node without default route".into(),
240 ));
241 }
242 }
243 NodeOutcome::RouteToAction(action) => {
244 debug!(
245 node_id = %current_node_id,
246 action = ?action,
247 "Node routed to action"
248 );
249
250 if let Some(next) = self.edges.get(&(current_node_id.clone(), action.clone())) {
252 debug!(
253 node_id = %current_node_id,
254 action = ?action,
255 next_node = %next,
256 "Following edge for action"
257 );
258 current_node_id = next.clone();
259 }
260 else if action != A::default() {
262 if let Some(next) = self.edges.get(&(current_node_id.clone(), A::default()))
263 {
264 debug!(
265 node_id = %current_node_id,
266 next_node = %next,
267 "No edge for action, following default action"
268 );
269 current_node_id = next.clone();
270 } else if let Some(next) = self.default_routes.get(¤t_node_id) {
271 debug!(
272 node_id = %current_node_id,
273 next_node = %next,
274 "No edge for action or default action, following default route"
275 );
276 current_node_id = next.clone();
277 } else {
278 error!(
279 node_id = %current_node_id,
280 action = ?action,
281 "No edge found for action and no default route"
282 );
283
284 error!(
286 "Available edges: {:?}",
287 self.edges
288 .iter()
289 .map(|((from, act), to)| format!(
290 "{} -[{:?}]-> {}",
291 from, act, to
292 ))
293 .collect::<Vec<_>>()
294 );
295 error!(
296 "Default routes: {:?}",
297 self.default_routes
298 .iter()
299 .map(|(from, to)| format!("{} -> {}", from, to))
300 .collect::<Vec<_>>()
301 );
302
303 return Err(WorkflowError::ActionNotHandled(format!("{:?}", action)));
304 }
305 } else if let Some(next) = self.default_routes.get(¤t_node_id) {
306 debug!(
307 node_id = %current_node_id,
308 next_node = %next,
309 "No edge for default action, following default route"
310 );
311 current_node_id = next.clone();
312 } else {
313 error!(
314 node_id = %current_node_id,
315 action = ?action,
316 "No edge found for default action and no default route"
317 );
318
319 error!(
321 "Available edges: {:?}",
322 self.edges
323 .iter()
324 .map(|((from, act), to)| format!("{} -[{:?}]-> {}", from, act, to))
325 .collect::<Vec<_>>()
326 );
327 error!(
328 "Default routes: {:?}",
329 self.default_routes
330 .iter()
331 .map(|(from, to)| format!("{} -> {}", from, to))
332 .collect::<Vec<_>>()
333 );
334
335 return Err(WorkflowError::ActionNotHandled(
336 "Default action not handled".into(),
337 ));
338 }
339 }
340 }
341 }
342 }
343}
344
345impl<Context, A, Output> Clone for Workflow<Context, A, Output>
347where
348 Context: Send + Sync + 'static,
349 A: ActionType + Clone + Send + Sync + 'static,
350 Output: Send + Sync + 'static,
351{
352 fn clone(&self) -> Self {
353 Self {
354 start_node: self.start_node.clone(),
355 nodes: self.nodes.clone(), edges: self.edges.clone(),
357 default_routes: self.default_routes.clone(),
358 allow_cycles: self.allow_cycles,
359 cycle_limit: self.cycle_limit,
360 }
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::closure;

    #[derive(Debug, Clone)]
    struct TestContext {
        value: i32,
        visited: Vec<String>,
    }

    impl TestContext {
        /// Fresh context with the given starting value and no visits recorded.
        fn new(value: i32) -> Self {
            Self {
                value,
                visited: vec![],
            }
        }
    }

    /// Builds `start -> loop -> end`.
    ///
    /// `loop` doubles `value` and routes `Next` back to itself while
    /// `value <= 100`, then routes `Error` to `end`. Cycle settings are left at
    /// their defaults so each test can configure them.
    fn doubling_loop_workflow() -> Workflow<TestContext, DefaultAction, ()> {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value += 1;
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let loop_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value *= 2;
            ctx.visited.push("loop".to_string());
            let action = if ctx.value <= 100 {
                DefaultAction::Next
            } else {
                DefaultAction::Error
            };
            Ok((ctx, NodeOutcome::<(), DefaultAction>::RouteToAction(action)))
        });

        let end_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value -= 10;
            ctx.visited.push("end".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let mut workflow = Workflow::new(start_node);
        let start_id = workflow.start_node.clone();
        let loop_id = loop_node.id();
        let end_id = end_node.id();

        workflow
            .add_node(loop_node)
            .add_node(end_node)
            .set_default_route(&start_id, &loop_id)
            .connect(&loop_id, DefaultAction::Next, &loop_id)
            .connect(&loop_id, DefaultAction::Error, &end_id);

        workflow
    }

    #[tokio::test]
    async fn test_simple_linear_workflow() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value += 1;
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let middle_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value *= 2;
            ctx.visited.push("middle".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let end_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value -= 3;
            ctx.visited.push("end".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let mut workflow = Workflow::new(start_node);
        let start_id = workflow.start_node.clone();
        let middle_id = middle_node.id();
        let end_id = end_node.id();

        workflow
            .add_node(middle_node)
            .add_node(end_node)
            .set_default_route(&start_id, &middle_id)
            .set_default_route(&middle_id, &end_id);

        let mut ctx = TestContext::new(10);
        let result = workflow.execute(&mut ctx).await;
        assert!(result.is_ok(), "workflow failed: {:?}", result);

        // (10 + 1) * 2 - 3
        assert_eq!(ctx.value, 19);
        assert_eq!(ctx.visited, vec!["start", "middle", "end"]);
    }

    #[tokio::test]
    async fn test_workflow_with_routing() {
        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
        enum TestAction {
            Default,
            Route1,
            Route2,
        }

        impl Default for TestAction {
            fn default() -> Self {
                Self::Default
            }
        }

        impl ActionType for TestAction {
            fn name(&self) -> &str {
                match self {
                    Self::Default => "default",
                    Self::Route1 => "route1",
                    Self::Route2 => "route2",
                }
            }
        }

        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            let action = if ctx.value > 5 {
                TestAction::Route1
            } else {
                TestAction::Route2
            };
            Ok((ctx, NodeOutcome::<(), TestAction>::RouteToAction(action)))
        });

        let path1_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value += 100;
            ctx.visited.push("path1".to_string());
            Ok((ctx, NodeOutcome::<(), TestAction>::Success(())))
        });

        let path2_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value *= 10;
            ctx.visited.push("path2".to_string());
            Ok((ctx, NodeOutcome::<(), TestAction>::Success(())))
        });

        let mut workflow = Workflow::<_, TestAction, _>::new(start_node);
        let start_id = workflow.start_node.clone();
        let path1_id = path1_node.id();
        let path2_id = path2_node.id();

        workflow
            .add_node(path1_node)
            .add_node(path2_node)
            .connect(&start_id, TestAction::Route1, &path1_id)
            .connect(&start_id, TestAction::Route2, &path2_id);

        let mut ctx1 = TestContext::new(10);
        let result1 = workflow.execute(&mut ctx1).await;
        assert!(result1.is_ok(), "workflow failed: {:?}", result1);
        assert_eq!(ctx1.value, 110);
        assert_eq!(ctx1.visited, vec!["start", "path1"]);

        let mut ctx2 = TestContext::new(3);
        let result2 = workflow.execute(&mut ctx2).await;
        assert!(result2.is_ok(), "workflow failed: {:?}", result2);
        assert_eq!(ctx2.value, 30);
        assert_eq!(ctx2.visited, vec!["start", "path2"]);
    }

    #[tokio::test]
    async fn test_workflow_with_skipped_node() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let skip_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("skip_check".to_string());
            if ctx.value > 5 {
                Ok((ctx, NodeOutcome::<(), DefaultAction>::Skipped))
            } else {
                ctx.value *= 2;
                Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
            }
        });

        let end_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("end".to_string());
            ctx.value += 5;
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });

        let mut workflow = Workflow::new(start_node);
        let start_id = workflow.start_node.clone();
        let skip_id = skip_node.id();
        let end_id = end_node.id();

        workflow
            .add_node(skip_node)
            .add_node(end_node)
            .set_default_route(&start_id, &skip_id)
            .set_default_route(&skip_id, &end_id);

        // Skipped: follows the default route without doubling.
        let mut ctx1 = TestContext::new(10);
        let result1 = workflow.execute(&mut ctx1).await;
        assert!(result1.is_ok(), "workflow failed: {:?}", result1);
        assert_eq!(ctx1.value, 15);
        assert_eq!(ctx1.visited, vec!["start", "skip_check", "end"]);

        // Not skipped: 3 * 2 + 5
        let mut ctx2 = TestContext::new(3);
        let result2 = workflow.execute(&mut ctx2).await;
        assert!(result2.is_ok(), "workflow failed: {:?}", result2);
        assert_eq!(ctx2.value, 11);
        assert_eq!(ctx2.visited, vec!["start", "skip_check", "end"]);
    }

    #[tokio::test]
    async fn test_cyclic_workflow() {
        let mut workflow = doubling_loop_workflow();
        workflow.allow_cycles(true).set_cycle_limit(10);

        let mut ctx = TestContext::new(3);
        let result = workflow.execute(&mut ctx).await;
        assert!(result.is_ok(), "workflow failed: {:?}", result);

        // 3 -> start: 4 -> loop: 8, 16, 32, 64, 128 (> 100, routes to end) -> end: 118
        assert_eq!(ctx.value, 118);
        assert_eq!(
            ctx.visited,
            vec!["start", "loop", "loop", "loop", "loop", "loop", "end"]
        );
    }

    #[tokio::test]
    async fn test_cycle_rejected_when_cycles_disallowed() {
        let mut workflow = doubling_loop_workflow();
        workflow.allow_cycles(false);

        let mut ctx = TestContext::new(3);
        let result = workflow.execute(&mut ctx).await;
        assert!(
            matches!(
                result,
                Err(WorkflowError::NodeExecution(
                    FloxideError::WorkflowCycleDetected
                ))
            ),
            "Expected WorkflowCycleDetected error, got {:?}",
            result
        );

        // The second visit to `loop` is rejected before the node runs again.
        assert_eq!(ctx.visited, vec!["start", "loop"]);
    }

    #[tokio::test]
    async fn test_cycle_limit_exceeded() {
        let mut workflow = doubling_loop_workflow();
        workflow.allow_cycles(true).set_cycle_limit(3);

        let mut ctx = TestContext::new(3);
        let result = workflow.execute(&mut ctx).await;
        assert!(
            matches!(
                result,
                Err(WorkflowError::NodeExecution(
                    FloxideError::WorkflowCycleDetected
                ))
            ),
            "Expected WorkflowCycleDetected error, got {:?}",
            result
        );

        // `loop` runs three times (8, 16, 32); the fourth visit exceeds the limit.
        assert_eq!(ctx.value, 32);
        assert_eq!(ctx.visited, vec!["start", "loop", "loop", "loop"]);
    }
}