1use std::collections::{HashMap, HashSet};
2use std::fmt::Debug;
3use std::sync::Arc;
4
5use tracing::{debug, error, info, warn};
6
7use crate::action::{ActionType, DefaultAction};
8use crate::error::FloxideError;
9use crate::node::{Node, NodeId, NodeOutcome};
10
/// Errors reported by [`Workflow`] while it is being executed.
#[derive(Debug, thiserror::Error)]
pub enum WorkflowError {
    /// The workflow's initial node could not be found.
    ///
    /// NOTE(review): nothing in this file constructs this variant; confirm
    /// whether it is produced elsewhere in the crate.
    #[error("Initial node not found: {0}")]
    InitialNodeNotFound(NodeId),

    /// Execution was routed to a node id that is not registered in the
    /// workflow's node map.
    #[error("Node not found: {0}")]
    NodeNotFound(NodeId),

    /// A node produced an outcome for which no route exists: either a skipped
    /// node without a default route, or an action with no matching edge and
    /// no default route. The payload is a human-readable description.
    #[error("Action not handled: {0}")]
    ActionNotHandled(String),

    /// A node's `process` call failed, or a cycle was detected during
    /// execution (reported as `FloxideError::WorkflowCycleDetected`).
    #[error("Node execution error: {0}")]
    NodeExecution(#[from] FloxideError),
}
30
/// A directed graph of nodes that is executed one node at a time, starting
/// from `start_node` and following edges selected by each node's outcome.
///
/// Type parameters:
/// * `Context` - the state handed to every node's `process` call.
/// * `A` - the action type used to label edges (defaults to `DefaultAction`).
/// * `Output` - the value produced by nodes and returned from execution.
pub struct Workflow<Context, A = DefaultAction, Output = ()>
where
    A: ActionType,
{
    /// Id of the node at which execution begins.
    start_node: NodeId,

    /// All registered nodes, keyed by their id and shared behind `Arc`.
    pub(crate) nodes: HashMap<NodeId, Arc<dyn Node<Context, A, Output = Output>>>,

    /// Action-labelled edges: `(source node, action) -> target node`.
    edges: HashMap<(NodeId, A), NodeId>,

    /// Per-node fallback successor, used after `Success`/`Skipped` outcomes
    /// and when no edge matches a routed action.
    default_routes: HashMap<NodeId, NodeId>,
}
48
49impl<Context, A, Output> Workflow<Context, A, Output>
50where
51 Context: Send + Sync + 'static,
52 A: ActionType + Debug + Default + Clone + Send + Sync + 'static,
53 Output: Send + Sync + 'static + std::fmt::Debug,
54{
55 pub fn new<N>(start_node: N) -> Self
57 where
58 N: Node<Context, A, Output = Output> + 'static,
59 {
60 let id = start_node.id();
61 let mut nodes = HashMap::new();
62 nodes.insert(
63 id.clone(),
64 Arc::new(start_node) as Arc<dyn Node<Context, A, Output = Output>>,
65 );
66
67 Self {
68 start_node: id,
69 nodes,
70 edges: HashMap::new(),
71 default_routes: HashMap::new(),
72 }
73 }
74
75 pub fn add_node<N>(&mut self, node: N) -> &mut Self
77 where
78 N: Node<Context, A, Output = Output> + 'static,
79 {
80 let id = node.id();
81 self.nodes.insert(
82 id,
83 Arc::new(node) as Arc<dyn Node<Context, A, Output = Output>>,
84 );
85 self
86 }
87
88 pub fn connect(&mut self, from: &NodeId, action: A, to: &NodeId) -> &mut Self {
90 self.edges.insert((from.clone(), action), to.clone());
91 self
92 }
93
94 pub fn set_default_route(&mut self, from: &NodeId, to: &NodeId) -> &mut Self {
96 self.default_routes.insert(from.clone(), to.clone());
97 self
98 }
99
100 pub fn get_node(&self, id: NodeId) -> Option<&dyn Node<Context, A, Output = Output>> {
102 self.nodes.get(&id).map(|node| node.as_ref())
103 }
104
105 pub async fn execute(&self, ctx: &mut Context) -> Result<Output, WorkflowError> {
107 let mut current_node_id = self.start_node.clone();
108 let mut visited = HashSet::new();
109
110 info!(start_node = %current_node_id, "Starting workflow execution");
111 eprintln!("Starting workflow execution from node: {}", current_node_id);
112
113 eprintln!("Node connections:");
115 for ((from, action), to) in &self.edges {
116 eprintln!(" {} -[{:?}]-> {}", from, action, to);
117 }
118
119 eprintln!("Default routes:");
120 for (from, to) in &self.default_routes {
121 eprintln!(" {} -> {}", from, to);
122 }
123
124 while !visited.contains(¤t_node_id) {
125 let node = self.nodes.get(¤t_node_id).ok_or_else(|| {
126 error!(node_id = %current_node_id, "Node not found in workflow");
127 WorkflowError::NodeNotFound(current_node_id.clone())
128 })?;
129
130 visited.insert(current_node_id.clone());
131 debug!(node_id = %current_node_id, "Executing node");
132
133 let outcome = node
134 .process(ctx)
135 .await
136 .map_err(WorkflowError::NodeExecution)?;
137
138 match &outcome {
139 NodeOutcome::Success(_) => {
140 info!(node_id = %current_node_id, "Node completed successfully with Success outcome");
141 eprintln!("Node {} completed with Success outcome", current_node_id);
142 }
143 NodeOutcome::Skipped => {
144 info!(node_id = %current_node_id, "Node completed with Skipped outcome");
145 eprintln!("Node {} completed with Skipped outcome", current_node_id);
146 }
147 NodeOutcome::RouteToAction(action) => {
148 info!(node_id = %current_node_id, action = %action.name(), "Node completed with RouteToAction outcome");
149 eprintln!(
150 "Node {} completed with RouteToAction({:?}) outcome",
151 current_node_id, action
152 );
153 }
154 }
155
156 match outcome {
157 NodeOutcome::Success(output) => {
158 info!(node_id = %current_node_id, "Node completed successfully");
159 if let Some(next) = self.default_routes.get(¤t_node_id) {
161 debug!(
162 node_id = %current_node_id,
163 next_node = %next,
164 "Following default route"
165 );
166 current_node_id = next.clone();
167 } else {
168 debug!(node_id = %current_node_id, "Workflow execution completed");
169 return Ok(output);
170 }
171 }
172 NodeOutcome::Skipped => {
173 warn!(node_id = %current_node_id, "Node was skipped");
174 if let Some(next) = self.default_routes.get(¤t_node_id) {
176 debug!(
177 node_id = %current_node_id,
178 next_node = %next,
179 "Following default route after skip"
180 );
181 current_node_id = next.clone();
182 } else {
183 warn!(node_id = %current_node_id, "Node was skipped but no default route exists");
184 return Err(WorkflowError::ActionNotHandled(
185 "Skipped node without default route".into(),
186 ));
187 }
188 }
189 NodeOutcome::RouteToAction(action) => {
190 debug!(
191 node_id = %current_node_id,
192 action = ?action,
193 "Node routed to action"
194 );
195
196 if let Some(next) = self.edges.get(&(current_node_id.clone(), action.clone())) {
198 debug!(
199 node_id = %current_node_id,
200 action = ?action,
201 next_node = %next,
202 "Following edge for action"
203 );
204 current_node_id = next.clone();
205 }
206 else if action != A::default() {
208 if let Some(next) = self.edges.get(&(current_node_id.clone(), A::default()))
209 {
210 debug!(
211 node_id = %current_node_id,
212 next_node = %next,
213 "No edge for action, following default action"
214 );
215 current_node_id = next.clone();
216 } else if let Some(next) = self.default_routes.get(¤t_node_id) {
217 debug!(
218 node_id = %current_node_id,
219 next_node = %next,
220 "No edge for action or default action, following default route"
221 );
222 current_node_id = next.clone();
223 } else {
224 error!(
225 node_id = %current_node_id,
226 action = ?action,
227 "No edge found for action and no default route"
228 );
229
230 error!(
232 "Available edges: {:?}",
233 self.edges
234 .iter()
235 .map(|((from, act), to)| format!(
236 "{} -[{:?}]-> {}",
237 from, act, to
238 ))
239 .collect::<Vec<_>>()
240 );
241 error!(
242 "Default routes: {:?}",
243 self.default_routes
244 .iter()
245 .map(|(from, to)| format!("{} -> {}", from, to))
246 .collect::<Vec<_>>()
247 );
248
249 return Err(WorkflowError::ActionNotHandled(format!("{:?}", action)));
250 }
251 } else if let Some(next) = self.default_routes.get(¤t_node_id) {
252 debug!(
253 node_id = %current_node_id,
254 next_node = %next,
255 "No edge for default action, following default route"
256 );
257 current_node_id = next.clone();
258 } else {
259 error!(
260 node_id = %current_node_id,
261 action = ?action,
262 "No edge found for default action and no default route"
263 );
264
265 error!(
267 "Available edges: {:?}",
268 self.edges
269 .iter()
270 .map(|((from, act), to)| format!("{} -[{:?}]-> {}", from, act, to))
271 .collect::<Vec<_>>()
272 );
273 error!(
274 "Default routes: {:?}",
275 self.default_routes
276 .iter()
277 .map(|(from, to)| format!("{} -> {}", from, to))
278 .collect::<Vec<_>>()
279 );
280
281 return Err(WorkflowError::ActionNotHandled(
282 "Default action not handled".into(),
283 ));
284 }
285 }
286 }
287 }
288
289 error!(
291 node_id = %current_node_id,
292 "Cycle detected in workflow execution"
293 );
294 Err(WorkflowError::NodeExecution(
295 FloxideError::WorkflowCycleDetected,
296 ))
297 }
298}
299
300impl<Context, A, Output> Clone for Workflow<Context, A, Output>
302where
303 Context: Send + Sync + 'static,
304 A: ActionType + Clone + Send + Sync + 'static,
305 Output: Send + Sync + 'static,
306{
307 fn clone(&self) -> Self {
308 Self {
309 start_node: self.start_node.clone(),
310 nodes: self.nodes.clone(), edges: self.edges.clone(),
312 default_routes: self.default_routes.clone(),
313 }
314 }
315}
316
317#[cfg(test)]
318mod tests {
319 use super::*;
320 use crate::node::closure;
321
    /// Context passed through the test workflows; nodes mutate it so the
    /// tests can observe which nodes ran and in what order.
    #[derive(Debug, Clone)]
    struct TestContext {
        /// Number that the test nodes modify arithmetically.
        value: i32,
        /// Labels pushed by the test nodes as they run.
        visited: Vec<String>,
    }
327
328 #[tokio::test]
329 async fn test_simple_linear_workflow() {
330 let start_node = closure::node(|mut ctx: TestContext| async move {
332 ctx.value += 1;
333 ctx.visited.push("start".to_string());
334 Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
335 });
336
337 let middle_node = closure::node(|mut ctx: TestContext| async move {
338 ctx.value *= 2;
339 ctx.visited.push("middle".to_string());
340 Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
341 });
342
343 let end_node = closure::node(|mut ctx: TestContext| async move {
344 ctx.value -= 3;
345 ctx.visited.push("end".to_string());
346 Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
347 });
348
349 let mut workflow = Workflow::new(start_node);
351 let start_id = workflow.start_node.clone();
352 let middle_id = middle_node.id();
353 let end_id = end_node.id();
354
355 workflow
356 .add_node(middle_node)
357 .add_node(end_node)
358 .set_default_route(&start_id, &middle_id)
359 .set_default_route(&middle_id, &end_id);
360
361 let mut ctx = TestContext {
363 value: 10,
364 visited: vec![],
365 };
366
367 let result = workflow.execute(&mut ctx).await;
368 assert!(result.is_ok());
369
370 assert_eq!(ctx.value, 19); assert_eq!(ctx.visited, vec!["start", "middle", "end"]);
373 }
374
375 #[tokio::test]
376 async fn test_workflow_with_routing() {
377 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
379 enum TestAction {
380 Default,
381 Route1,
382 Route2,
383 }
384
385 impl Default for TestAction {
386 fn default() -> Self {
387 Self::Default
388 }
389 }
390
391 impl ActionType for TestAction {
392 fn name(&self) -> &str {
393 match self {
394 Self::Default => "default",
395 Self::Route1 => "route1",
396 Self::Route2 => "route2",
397 }
398 }
399 }
400
401 let start_node = closure::node(|mut ctx: TestContext| async move {
403 ctx.visited.push("start".to_string());
404 if ctx.value > 5 {
406 Ok((
407 ctx,
408 NodeOutcome::<(), TestAction>::RouteToAction(TestAction::Route1),
409 ))
410 } else {
411 Ok((
412 ctx,
413 NodeOutcome::<(), TestAction>::RouteToAction(TestAction::Route2),
414 ))
415 }
416 });
417
418 let path1_node = closure::node(|mut ctx: TestContext| async move {
419 ctx.value += 100;
420 ctx.visited.push("path1".to_string());
421 Ok((ctx, NodeOutcome::<(), TestAction>::Success(())))
422 });
423
424 let path2_node = closure::node(|mut ctx: TestContext| async move {
425 ctx.value *= 10;
426 ctx.visited.push("path2".to_string());
427 Ok((ctx, NodeOutcome::<(), TestAction>::Success(())))
428 });
429
430 let mut workflow = Workflow::<_, TestAction, _>::new(start_node);
432 let start_id = workflow.start_node.clone();
433 let path1_id = path1_node.id();
434 let path2_id = path2_node.id();
435
436 workflow
437 .add_node(path1_node)
438 .add_node(path2_node)
439 .connect(&start_id, TestAction::Route1, &path1_id)
440 .connect(&start_id, TestAction::Route2, &path2_id);
441
442 let mut ctx1 = TestContext {
444 value: 10,
445 visited: vec![],
446 };
447
448 let result1 = workflow.execute(&mut ctx1).await;
449 assert!(result1.is_ok());
450 assert_eq!(ctx1.value, 110); assert_eq!(ctx1.visited, vec!["start", "path1"]);
452
453 let mut ctx2 = TestContext {
455 value: 3,
456 visited: vec![],
457 };
458
459 let result2 = workflow.execute(&mut ctx2).await;
460 assert!(result2.is_ok());
461 assert_eq!(ctx2.value, 30); assert_eq!(ctx2.visited, vec!["start", "path2"]);
463 }
464
465 #[tokio::test]
466 async fn test_workflow_with_skipped_node() {
467 let start_node = closure::node(|mut ctx: TestContext| async move {
468 ctx.visited.push("start".to_string());
469 Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
470 });
471
472 let skip_node = closure::node(|mut ctx: TestContext| async move {
473 ctx.visited.push("skip_check".to_string());
474 if ctx.value > 5 {
475 Ok((ctx, NodeOutcome::<(), DefaultAction>::Skipped))
477 } else {
478 ctx.value *= 2;
479 Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
480 }
481 });
482
483 let end_node = closure::node(|mut ctx: TestContext| async move {
484 ctx.visited.push("end".to_string());
485 ctx.value += 5;
486 Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
487 });
488
489 let mut workflow = Workflow::new(start_node);
491 let start_id = workflow.start_node.clone();
492 let skip_id = skip_node.id();
493 let end_id = end_node.id();
494
495 workflow
496 .add_node(skip_node)
497 .add_node(end_node)
498 .set_default_route(&start_id, &skip_id)
499 .set_default_route(&skip_id, &end_id);
500
501 let mut ctx1 = TestContext {
503 value: 10,
504 visited: vec![],
505 };
506
507 let result1 = workflow.execute(&mut ctx1).await;
508 assert!(result1.is_ok());
509 assert_eq!(ctx1.value, 15); assert_eq!(ctx1.visited, vec!["start", "skip_check", "end"]);
511
512 let mut ctx2 = TestContext {
514 value: 3,
515 visited: vec![],
516 };
517
518 let result2 = workflow.execute(&mut ctx2).await;
519 assert!(result2.is_ok());
520 assert_eq!(ctx2.value, 11); assert_eq!(ctx2.visited, vec!["start", "skip_check", "end"]);
522 }
523}