1use std::collections::{HashMap, HashSet};
38use std::future::Future;
39use std::pin::Pin;
40use std::sync::Arc;
41
42use anyhow::{Result, anyhow};
43use petgraph::algo::is_cyclic_directed;
44use petgraph::graph::{DiGraph, NodeIndex};
45use serde_json::Value;
46use tokio::sync::RwLock;
47
48#[derive(Clone)]
55pub struct WorkflowContext {
56 state: Arc<RwLock<HashMap<String, Value>>>,
57 results: Arc<RwLock<HashMap<String, Value>>>,
59}
60
61impl WorkflowContext {
62 pub fn new() -> Self {
64 Self {
65 state: Arc::new(RwLock::new(HashMap::new())),
66 results: Arc::new(RwLock::new(HashMap::new())),
67 }
68 }
69
70 pub async fn set(&self, key: impl Into<String>, value: Value) {
72 self.state.write().await.insert(key.into(), value);
73 }
74
75 pub async fn get(&self, key: &str) -> Option<Value> {
77 self.state.read().await.get(key).cloned()
78 }
79
80 pub async fn remove(&self, key: &str) -> Option<Value> {
82 self.state.write().await.remove(key)
83 }
84
85 pub async fn node_result(&self, node_name: &str) -> Option<Value> {
87 self.results.read().await.get(node_name).cloned()
88 }
89
90 async fn store_result(&self, node_name: impl Into<String>, value: Value) {
92 self.results.write().await.insert(node_name.into(), value);
93 }
94
95 pub async fn all_results(&self) -> HashMap<String, Value> {
97 self.results.read().await.clone()
98 }
99}
100
101impl Default for WorkflowContext {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107pub type NodeFn = Box<
111 dyn Fn(WorkflowContext) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>> + Send + Sync,
112>;
113
114pub type ConditionalFn = Box<dyn Fn(&Value) -> Vec<String> + Send + Sync>;
117
118struct WorkflowNode {
121 name: String,
122 handler: NodeFn,
123}
124
125enum EdgeType {
126 Direct { from: String, to: String },
128 Conditional {
130 from: String,
131 evaluator: ConditionalFn,
132 },
133}
134
135pub struct WorkflowBuilder {
142 name: String,
143 nodes: Vec<WorkflowNode>,
144 node_names: HashSet<String>,
145 edges: Vec<EdgeType>,
146}
147
148impl WorkflowBuilder {
149 pub fn new(name: impl Into<String>) -> Self {
151 Self {
152 name: name.into(),
153 nodes: Vec::new(),
154 node_names: HashSet::new(),
155 edges: Vec::new(),
156 }
157 }
158
159 pub fn node<F, Fut>(mut self, name: impl Into<String>, handler: F) -> Self
164 where
165 F: Fn(WorkflowContext) -> Fut + Send + Sync + 'static,
166 Fut: Future<Output = Result<Value>> + Send + 'static,
167 {
168 let name = name.into();
169 self.node_names.insert(name.clone());
170 self.nodes.push(WorkflowNode {
171 name,
172 handler: Box::new(move |ctx| Box::pin(handler(ctx))),
173 });
174 self
175 }
176
177 pub fn edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
183 self.edges.push(EdgeType::Direct {
184 from: from.into(),
185 to: to.into(),
186 });
187 self
188 }
189
190 pub fn conditional<F>(mut self, from: impl Into<String>, evaluator: F) -> Self
197 where
198 F: Fn(&Value) -> Vec<String> + Send + Sync + 'static,
199 {
200 self.edges.push(EdgeType::Conditional {
201 from: from.into(),
202 evaluator: Box::new(evaluator),
203 });
204 self
205 }
206
207 pub fn build(self) -> Result<Workflow> {
214 if self.nodes.is_empty() {
215 return Err(anyhow!("Workflow '{}' has no nodes", self.name));
216 }
217
218 let mut graph = DiGraph::<String, ()>::new();
220 let mut name_to_idx: HashMap<String, NodeIndex> = HashMap::new();
221
222 for node in &self.nodes {
223 let idx = graph.add_node(node.name.clone());
224 name_to_idx.insert(node.name.clone(), idx);
225 }
226
227 let mut direct_edges: Vec<(String, String)> = Vec::new();
229 let mut conditional_edges: Vec<(String, ConditionalFn)> = Vec::new();
230
231 for edge in self.edges {
232 match edge {
233 EdgeType::Direct { from, to } => {
234 if !name_to_idx.contains_key(&from) {
235 return Err(anyhow!("Edge references unknown source node '{}'", from));
236 }
237 if !name_to_idx.contains_key(&to) {
238 return Err(anyhow!("Edge references unknown target node '{}'", to));
239 }
240 graph.add_edge(name_to_idx[&from], name_to_idx[&to], ());
241 direct_edges.push((from, to));
242 }
243 EdgeType::Conditional { from, evaluator } => {
244 if !name_to_idx.contains_key(&from) {
245 return Err(anyhow!(
246 "Conditional edge references unknown source node '{}'",
247 from
248 ));
249 }
250 conditional_edges.push((from, evaluator));
251 }
252 }
253 }
254
255 if is_cyclic_directed(&graph) {
256 return Err(anyhow!("Workflow '{}' contains a cycle", self.name));
257 }
258
259 let targets: HashSet<&str> = direct_edges.iter().map(|(_, t)| t.as_str()).collect();
261 let entry_nodes: Vec<String> = self
262 .nodes
263 .iter()
264 .map(|n| &n.name)
265 .filter(|n| !targets.contains(n.as_str()))
266 .cloned()
267 .collect();
268
269 if entry_nodes.is_empty() {
270 return Err(anyhow!(
271 "Workflow '{}' has no entry nodes (every node has an incoming edge)",
272 self.name
273 ));
274 }
275
276 let mut handlers: HashMap<String, NodeFn> = HashMap::new();
278 for node in self.nodes {
279 handlers.insert(node.name, node.handler);
280 }
281
282 Ok(Workflow {
283 name: self.name,
284 handlers: Arc::new(handlers),
285 direct_edges,
286 conditional_edges: Arc::new(conditional_edges),
287 entry_nodes,
288 all_nodes: self.node_names,
289 })
290 }
291}
292
293pub struct Workflow {
302 name: String,
303 handlers: Arc<HashMap<String, NodeFn>>,
304 direct_edges: Vec<(String, String)>,
305 conditional_edges: Arc<Vec<(String, ConditionalFn)>>,
306 entry_nodes: Vec<String>,
307 all_nodes: HashSet<String>,
308}
309
310impl std::fmt::Debug for Workflow {
311 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312 f.debug_struct("Workflow")
313 .field("name", &self.name)
314 .field("entry_nodes", &self.entry_nodes)
315 .field("all_nodes", &self.all_nodes)
316 .field("direct_edges", &self.direct_edges)
317 .field("handlers", &format!("<{} handlers>", self.handlers.len()))
318 .finish()
319 }
320}
321
322#[derive(Debug, Clone)]
324pub struct WorkflowResult {
325 pub name: String,
327 pub success: bool,
329 pub node_results: HashMap<String, Value>,
331 pub skipped_nodes: Vec<String>,
333 pub failed_nodes: HashMap<String, String>,
335}
336
337impl Workflow {
338 pub async fn run(&self) -> Result<WorkflowResult> {
344 self.run_with_context(WorkflowContext::new()).await
345 }
346
347 pub async fn run_with_context(&self, ctx: WorkflowContext) -> Result<WorkflowResult> {
349 let completed: Arc<RwLock<HashSet<String>>> = Arc::new(RwLock::new(HashSet::new()));
350 let failed: Arc<RwLock<HashMap<String, String>>> = Arc::new(RwLock::new(HashMap::new()));
351 let skipped: Arc<RwLock<HashSet<String>>> = Arc::new(RwLock::new(HashSet::new()));
352
353 let mut deps: HashMap<String, HashSet<String>> = HashMap::new();
355 for node in &self.all_nodes {
356 deps.insert(node.clone(), HashSet::new());
357 }
358 for (from, to) in &self.direct_edges {
359 deps.entry(to.clone()).or_default().insert(from.clone());
360 }
361
362 loop {
363 {
366 let done = completed.read().await;
367 let fail = failed.read().await;
368 let skip = skipped.read().await;
369 let mut to_skip = Vec::new();
370 for (name, predecessors) in &deps {
371 if done.contains(name) || fail.contains_key(name) || skip.contains(name) {
372 continue;
373 }
374 if predecessors.iter().any(|p| fail.contains_key(p)) {
375 to_skip.push(name.clone());
376 }
377 }
378 drop(done);
379 drop(fail);
380 drop(skip);
381 if !to_skip.is_empty() {
382 let mut skip_guard = skipped.write().await;
383 for name in to_skip {
384 skip_guard.insert(name);
385 }
386 }
387 }
388
389 let ready: Vec<String> = {
390 let done = completed.read().await;
391 let fail = failed.read().await;
392 let skip = skipped.read().await;
393 deps.iter()
394 .filter(|(name, predecessors)| {
395 !done.contains(*name)
396 && !fail.contains_key(*name)
397 && !skip.contains(*name)
398 && predecessors
399 .iter()
400 .all(|p| done.contains(p) || skip.contains(p))
401 })
402 .map(|(name, _)| name.clone())
403 .collect()
404 };
405
406 if ready.is_empty() {
407 break;
408 }
409
410 let mut handles = Vec::new();
412 for name in ready {
413 let ctx = ctx.clone();
414 let handlers = Arc::clone(&self.handlers);
415 let completed = Arc::clone(&completed);
416 let failed = Arc::clone(&failed);
417 let conditional_edges = Arc::clone(&self.conditional_edges);
418 let node_name = name.clone();
419
420 let handle = tokio::spawn(async move {
421 if let Some(handler) = handlers.get(&node_name) {
422 match handler(ctx.clone()).await {
423 Ok(result) => {
424 ctx.store_result(&node_name, result.clone()).await;
425
426 for (from, evaluator) in conditional_edges.iter() {
429 if from == &node_name {
430 let activated = evaluator(&result);
431 ctx.set(
432 format!("__conditional_activated_{}", node_name),
433 serde_json::json!(activated),
434 )
435 .await;
436 }
437 }
438
439 completed.write().await.insert(node_name);
440 }
441 Err(e) => {
442 failed.write().await.insert(node_name, e.to_string());
443 }
444 }
445 } else {
446 failed
447 .write()
448 .await
449 .insert(node_name, "Handler not found".to_string());
450 }
451 });
452 handles.push(handle);
453 }
454
455 for handle in handles {
457 let _ = handle.await;
458 }
459
460 {
462 let ctx_state = ctx.state.read().await;
463 let mut skip_guard = skipped.write().await;
464 for (from, _) in self.conditional_edges.iter() {
465 let key = format!("__conditional_activated_{}", from);
466 if let Some(activated_val) = ctx_state.get(&key)
467 && let Some(activated) = activated_val.as_array()
468 {
469 let activated_set: HashSet<String> = activated
470 .iter()
471 .filter_map(|v| v.as_str().map(|s| s.to_string()))
472 .collect();
473 for (edge_from, edge_to) in &self.direct_edges {
475 if edge_from == from && !activated_set.contains(edge_to) {
476 skip_guard.insert(edge_to.clone());
477 }
478 }
479 }
480 }
481 }
482 }
483
484 let node_results = ctx.all_results().await;
485 let failed_map = failed.read().await.clone();
486 let skipped_vec: Vec<String> = skipped.read().await.iter().cloned().collect();
487 let success = failed_map.is_empty();
488
489 Ok(WorkflowResult {
490 name: self.name.clone(),
491 success,
492 node_results,
493 skipped_nodes: skipped_vec,
494 failed_nodes: failed_map,
495 })
496 }
497
498 pub fn name(&self) -> &str {
500 &self.name
501 }
502
503 pub fn entry_nodes(&self) -> &[String] {
505 &self.entry_nodes
506 }
507
508 pub fn node_names(&self) -> &HashSet<String> {
510 &self.all_nodes
511 }
512}
513
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Two nodes in sequence: `a` seeds a counter and `b` reads and increments it.
    #[tokio::test]
    async fn test_simple_linear_workflow() {
        let wf = WorkflowBuilder::new("linear")
            .node("a", |ctx| {
                Box::pin(async move {
                    ctx.set("counter", json!(1)).await;
                    Ok(json!({"step": "a"}))
                })
            })
            .node("b", |ctx| {
                Box::pin(async move {
                    let current = ctx.get("counter").await.unwrap().as_i64().unwrap();
                    ctx.set("counter", json!(current + 1)).await;
                    Ok(json!({"step": "b"}))
                })
            })
            .edge("a", "b")
            .build()
            .unwrap();

        let outcome = wf.run().await.unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.node_results.len(), 2);
        assert!(outcome.failed_nodes.is_empty());
    }

    // Fan-out from `start` to two branches, followed by a join.
    #[tokio::test]
    async fn test_parallel_workflow() {
        let wf = WorkflowBuilder::new("parallel")
            .node("start", |_ctx| Box::pin(async move { Ok(json!("started")) }))
            .node("branch_a", |_ctx| Box::pin(async move { Ok(json!("a_done")) }))
            .node("branch_b", |_ctx| Box::pin(async move { Ok(json!("b_done")) }))
            .node("join", |ctx| {
                Box::pin(async move {
                    let from_a = ctx.node_result("branch_a").await;
                    let from_b = ctx.node_result("branch_b").await;
                    Ok(json!({"a": from_a, "b": from_b}))
                })
            })
            .edge("start", "branch_a")
            .edge("start", "branch_b")
            .edge("branch_a", "join")
            .edge("branch_b", "join")
            .build()
            .unwrap();

        let outcome = wf.run().await.unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.node_results.len(), 4);
    }

    // The join node `d` combines the outputs of both middle nodes.
    #[tokio::test]
    async fn test_diamond_workflow() {
        let wf = WorkflowBuilder::new("diamond")
            .node("a", |_| Box::pin(async { Ok(json!(1)) }))
            .node("b", |_| Box::pin(async { Ok(json!(2)) }))
            .node("c", |_| Box::pin(async { Ok(json!(3)) }))
            .node("d", |ctx| {
                Box::pin(async move {
                    let left = ctx.node_result("b").await.unwrap();
                    let right = ctx.node_result("c").await.unwrap();
                    Ok(json!(left.as_i64().unwrap() + right.as_i64().unwrap()))
                })
            })
            .edge("a", "b")
            .edge("a", "c")
            .edge("b", "d")
            .edge("c", "d")
            .build()
            .unwrap();

        let outcome = wf.run().await.unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.node_results["d"], json!(5));
    }

    // The evaluator on `check` routes to `fast_path`, so `slow_path` must be skipped.
    #[tokio::test]
    async fn test_conditional_workflow() {
        let wf = WorkflowBuilder::new("conditional")
            .node("check", |_| Box::pin(async { Ok(json!({"route": "fast"})) }))
            .node("fast_path", |_| Box::pin(async { Ok(json!("fast_done")) }))
            .node("slow_path", |_| Box::pin(async { Ok(json!("slow_done")) }))
            .edge("check", "fast_path")
            .edge("check", "slow_path")
            .conditional("check", |result| {
                let route = result
                    .get("route")
                    .and_then(|v| v.as_str())
                    .unwrap_or("fast");
                let chosen = if route == "fast" { "fast_path" } else { "slow_path" };
                vec![chosen.to_string()]
            })
            .build()
            .unwrap();

        let outcome = wf.run().await.unwrap();
        assert!(outcome.success);
        assert!(outcome.node_results.contains_key("fast_path"));
        assert!(outcome.skipped_nodes.iter().any(|n| n == "slow_path"));
    }

    #[tokio::test]
    async fn test_cycle_detection() {
        let err = WorkflowBuilder::new("cyclic")
            .node("a", |_| Box::pin(async { Ok(json!(1)) }))
            .node("b", |_| Box::pin(async { Ok(json!(2)) }))
            .edge("a", "b")
            .edge("b", "a")
            .build()
            .expect_err("a cyclic graph must be rejected");

        assert!(err.to_string().contains("cycle"));
    }

    #[tokio::test]
    async fn test_unknown_node_in_edge() {
        let err = WorkflowBuilder::new("bad")
            .node("a", |_| Box::pin(async { Ok(json!(1)) }))
            .edge("a", "nonexistent")
            .build()
            .expect_err("an edge to an unknown node must be rejected");

        assert!(err.to_string().contains("unknown target"));
    }

    #[tokio::test]
    async fn test_empty_workflow() {
        let err = WorkflowBuilder::new("empty")
            .build()
            .expect_err("a workflow without nodes must be rejected");
        assert!(err.to_string().contains("no nodes"));
    }

    #[tokio::test]
    async fn test_single_node_workflow() {
        let wf = WorkflowBuilder::new("single")
            .node("only", |_| Box::pin(async { Ok(json!("done")) }))
            .build()
            .unwrap();

        let outcome = wf.run().await.unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.node_results.len(), 1);
    }

    // A failing node marks the run unsuccessful and prevents its dependent from running.
    #[tokio::test]
    async fn test_node_failure_skips_dependents() {
        let wf = WorkflowBuilder::new("fail")
            .node("a", |_| Box::pin(async { Err(anyhow::anyhow!("boom")) }))
            .node("b", |_| Box::pin(async { Ok(json!("should not run")) }))
            .edge("a", "b")
            .build()
            .unwrap();

        let outcome = wf.run().await.unwrap();
        assert!(!outcome.success);
        assert!(outcome.failed_nodes.contains_key("a"));
        assert!(outcome.skipped_nodes.iter().any(|n| n == "b"));
    }

    // Values placed in the context before the run are visible to handlers.
    #[tokio::test]
    async fn test_pre_populated_context() {
        let seeded = WorkflowContext::new();
        seeded.set("input", json!("hello")).await;

        let wf = WorkflowBuilder::new("with-ctx")
            .node("use_input", |ctx| {
                Box::pin(async move {
                    let input = ctx.get("input").await.unwrap();
                    Ok(json!({"received": input}))
                })
            })
            .build()
            .unwrap();

        let outcome = wf.run_with_context(seeded).await.unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.node_results["use_input"], json!({"received": "hello"}));
    }
}