agentic_workflow/engine/
fanout.rs1use std::collections::HashMap;
2
3use chrono::Utc;
4use uuid::Uuid;
5
6use crate::types::{
7 CompletionPolicy, FanOutBranch, FanOutBranchStatus, FanOutDestination,
8 FanOutStatus, FanOutStep, ResultAggregator,
9 WorkflowError, WorkflowResult,
10};
11
12pub struct FanOutEngine {
14 steps: HashMap<String, FanOutStep>,
15 statuses: HashMap<String, FanOutStatus>,
16}
17
18impl FanOutEngine {
19 pub fn new() -> Self {
20 Self {
21 steps: HashMap::new(),
22 statuses: HashMap::new(),
23 }
24 }
25
26 pub fn create_fanout(
28 &mut self,
29 destinations: Vec<FanOutDestination>,
30 completion_policy: CompletionPolicy,
31 aggregator: ResultAggregator,
32 timeout_ms: Option<u64>,
33 ) -> WorkflowResult<String> {
34 let id = Uuid::new_v4().to_string();
35 let step = FanOutStep {
36 id: id.clone(),
37 destinations,
38 completion_policy,
39 aggregator,
40 partial_success_threshold: None,
41 timeout_ms,
42 };
43
44 self.steps.insert(id.clone(), step);
45 Ok(id)
46 }
47
48 pub fn start_execution(
50 &mut self,
51 fanout_id: &str,
52 execution_id: &str,
53 ) -> WorkflowResult<()> {
54 let step = self
55 .steps
56 .get(fanout_id)
57 .ok_or_else(|| WorkflowError::Internal(format!("FanOut not found: {}", fanout_id)))?;
58
59 let branches: Vec<FanOutBranch> = step
60 .destinations
61 .iter()
62 .map(|d| FanOutBranch {
63 destination_id: d.id.clone(),
64 status: FanOutBranchStatus::Pending,
65 output: None,
66 error: None,
67 duration_ms: None,
68 })
69 .collect();
70
71 let status = FanOutStatus {
72 fanout_id: fanout_id.to_string(),
73 execution_id: execution_id.to_string(),
74 branches,
75 started_at: Utc::now(),
76 completed: false,
77 };
78
79 self.statuses.insert(execution_id.to_string(), status);
80 Ok(())
81 }
82
83 pub fn update_branch(
85 &mut self,
86 execution_id: &str,
87 destination_id: &str,
88 status: FanOutBranchStatus,
89 output: Option<serde_json::Value>,
90 error: Option<String>,
91 duration_ms: Option<u64>,
92 ) -> WorkflowResult<()> {
93 let fanout_status = self
94 .statuses
95 .get_mut(execution_id)
96 .ok_or_else(|| {
97 WorkflowError::ExecutionNotFound(execution_id.to_string())
98 })?;
99
100 if let Some(branch) = fanout_status
101 .branches
102 .iter_mut()
103 .find(|b| b.destination_id == destination_id)
104 {
105 branch.status = status;
106 branch.output = output;
107 branch.error = error;
108 branch.duration_ms = duration_ms;
109 }
110
111 let all_done = fanout_status
113 .branches
114 .iter()
115 .all(|b| matches!(
116 b.status,
117 FanOutBranchStatus::Success
118 | FanOutBranchStatus::Failed
119 | FanOutBranchStatus::TimedOut
120 | FanOutBranchStatus::Cancelled
121 ));
122
123 if all_done {
124 fanout_status.completed = true;
125 }
126
127 Ok(())
128 }
129
130 pub fn get_status(&self, execution_id: &str) -> WorkflowResult<&FanOutStatus> {
132 self.statuses
133 .get(execution_id)
134 .ok_or_else(|| WorkflowError::ExecutionNotFound(execution_id.to_string()))
135 }
136
137 pub fn get_step(&self, fanout_id: &str) -> WorkflowResult<&FanOutStep> {
139 self.steps
140 .get(fanout_id)
141 .ok_or_else(|| WorkflowError::Internal(format!("FanOut not found: {}", fanout_id)))
142 }
143}
144
145impl Default for FanOutEngine {
146 fn default() -> Self {
147 Self::new()
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` destinations with ids `d1`, `d2`, ... and names
    /// `API 1`, `API 2`, ..., each with an empty JSON step config.
    fn destinations(count: usize) -> Vec<FanOutDestination> {
        (1..=count)
            .map(|i| FanOutDestination {
                id: format!("d{}", i),
                name: format!("API {}", i),
                step_config: serde_json::json!({}),
            })
            .collect()
    }

    #[test]
    fn test_fanout_creation() {
        let mut engine = FanOutEngine::new();

        let fid = engine
            .create_fanout(destinations(2), CompletionPolicy::WaitAll, ResultAggregator::Merge, None)
            .unwrap();

        engine.start_execution(&fid, "exec-1").unwrap();
        let status = engine.get_status("exec-1").unwrap();
        assert_eq!(status.branches.len(), 2);
        assert!(!status.completed);
    }

    #[test]
    fn test_completes_once_all_branches_terminal() {
        let mut engine = FanOutEngine::new();
        let fid = engine
            .create_fanout(destinations(2), CompletionPolicy::WaitAll, ResultAggregator::Merge, None)
            .unwrap();
        engine.start_execution(&fid, "exec-1").unwrap();

        engine
            .update_branch(
                "exec-1",
                "d1",
                FanOutBranchStatus::Success,
                Some(serde_json::json!({"ok": true})),
                None,
                Some(10),
            )
            .unwrap();
        let status = engine.get_status("exec-1").unwrap();
        assert!(!status.completed);
        let d1 = status.branches.iter().find(|b| b.destination_id == "d1").unwrap();
        assert!(matches!(d1.status, FanOutBranchStatus::Success));
        assert_eq!(d1.duration_ms, Some(10));

        engine
            .update_branch(
                "exec-1",
                "d2",
                FanOutBranchStatus::Failed,
                None,
                Some("boom".to_string()),
                Some(5),
            )
            .unwrap();
        assert!(engine.get_status("exec-1").unwrap().completed);
    }

    #[test]
    fn test_unknown_ids_are_errors() {
        let mut engine = FanOutEngine::new();

        assert!(engine.start_execution("no-such-fanout", "exec-1").is_err());
        assert!(engine.get_step("no-such-fanout").is_err());
        assert!(matches!(
            engine.get_status("missing"),
            Err(WorkflowError::ExecutionNotFound(_))
        ));
        assert!(matches!(
            engine.update_branch("missing", "d1", FanOutBranchStatus::Success, None, None, None),
            Err(WorkflowError::ExecutionNotFound(_))
        ));
    }
}