1use std::sync::Arc;
22use std::time::Instant;
23
24use crate::error::GraphError;
25use crate::event::FlowEvent;
26use crate::execution_engine::{ExecutionEngine, ExecutorState, OwnedExecutionEngine};
27use crate::ids::SpanId;
28use crate::node::FlowNode;
29use crate::state::{State, StateMerge};
30use crate::workflow_state::{MergeStrategy, WorkflowState};
31
32pub struct ParallelNode<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
53 label: Option<String>,
54 branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
55 error_strategy: ParallelErrorStrategy,
56 _merge_strategy: std::marker::PhantomData<M>,
58}
59
60impl<S: WorkflowState, M: MergeStrategy<S>> Clone for ParallelNode<S, M> {
61 fn clone(&self) -> Self {
62 Self {
63 label: self.label.clone(),
64 branches: self.branches.clone(),
65 error_strategy: self.error_strategy,
66 _merge_strategy: std::marker::PhantomData,
67 }
68 }
69}
70
/// Selects how a parallel node reports branch failures.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ParallelErrorStrategy {
    /// Report the first failed branch. This is the default.
    #[default]
    FailFast,
    /// Intended to surface every branch failure. The returned error still names only
    /// the first failed branch.
    CollectAll,
}
80
81impl ParallelNode {
82 pub fn builder() -> ParallelNodeBuilder {
84 ParallelNodeBuilder::new()
85 }
86}
87
88impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNode<S, M> {
89 pub fn with_label(mut self, label: impl Into<String>) -> Self {
90 self.label = Some(label.into());
91 self
92 }
93
94 pub fn branch_count(&self) -> usize {
95 self.branches.len()
96 }
97
98 pub fn branch_names(&self) -> Vec<&str> {
99 self.branches
100 .iter()
101 .map(|(name, _)| name.as_str())
102 .collect()
103 }
104
105 pub fn branches_iter(&self) -> impl Iterator<Item = (&str, &Arc<dyn FlowNode<S>>)> {
106 self.branches
107 .iter()
108 .map(|(name, node)| (name.as_str(), node))
109 }
110
111 pub fn error_strategy(&self) -> ParallelErrorStrategy {
112 self.error_strategy
113 }
114
115 pub fn label(&self) -> Option<&str> {
116 self.label.as_deref()
117 }
118
119 fn display_name(&self) -> String {
120 self.label.clone().unwrap_or_else(|| "parallel".to_string())
121 }
122}
123
124pub struct ParallelNodeBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
126 label: Option<String>,
127 branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
128 error_strategy: ParallelErrorStrategy,
129 _phantom: std::marker::PhantomData<M>,
130}
131
132impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNodeBuilder<S, M> {
133 fn new() -> Self {
134 Self {
135 label: None,
136 branches: Vec::new(),
137 error_strategy: ParallelErrorStrategy::default(),
138 _phantom: std::marker::PhantomData,
139 }
140 }
141
142 pub fn label(mut self, label: impl Into<String>) -> Self {
143 self.label = Some(label.into());
144 self
145 }
146
147 pub fn branch(mut self, name: impl Into<String>, node: Arc<dyn FlowNode<S>>) -> Self {
148 self.branches.push((name.into(), node));
149 self
150 }
151
152 pub fn error_strategy(mut self, strategy: ParallelErrorStrategy) -> Self {
153 self.error_strategy = strategy;
154 self
155 }
156
157 pub fn build(self) -> ParallelNode<S, M> {
158 if self.branches.is_empty() {
159 panic!("ParallelNode must have at least one branch");
160 }
161 ParallelNode {
162 label: self.label,
163 branches: self.branches,
164 error_strategy: self.error_strategy,
165 _merge_strategy: std::marker::PhantomData,
166 }
167 }
168
169 pub fn merge_strategy<NM>(self) -> ParallelNodeBuilder<S, NM>
171 where
172 NM: MergeStrategy<S>,
173 {
174 ParallelNodeBuilder {
175 label: self.label,
176 branches: self.branches,
177 error_strategy: self.error_strategy,
178 _phantom: std::marker::PhantomData,
179 }
180 }
181}
182
183impl<S: WorkflowState, M: MergeStrategy<S>> std::fmt::Debug for ParallelNode<S, M> {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 f.debug_struct("ParallelNode")
186 .field("label", &self.label)
187 .field(
188 "branches",
189 &self
190 .branches
191 .iter()
192 .map(|(n, _)| n.as_str())
193 .collect::<Vec<_>>(),
194 )
195 .field("error_strategy", &self.error_strategy)
196 .finish()
197 }
198}
199
200impl<S: WorkflowState + Clone + Send + Sync, M: MergeStrategy<S>> ParallelNode<S, M> {
201 pub async fn execute(&self, engine: &mut ExecutionEngine<'_, S>) -> Result<(), GraphError> {
203 let start_time = Instant::now();
204 let span_id = SpanId::new();
205 let branch_count = self.branches.len();
206 let display_name = self.display_name();
207
208 engine.emit_flow_event(FlowEvent::ParallelStarted {
209 node_id: display_name.clone(),
210 branch_count,
211 span_id,
212 });
213
214 let base_state = engine.clone_state();
216
217 let parent_cancel = engine.cancel_token().clone();
219 let parent_stream = engine.stream_sink();
220
221 let branches: Vec<(String, Arc<dyn crate::node::FlowNode<S>>)> = self
223 .branches
224 .iter()
225 .map(|(n, nd)| (n.clone(), nd.clone()))
226 .collect();
227
228 let branch_futures: Vec<_> = branches
230 .into_iter()
231 .map(|(branch_name, node)| {
232 let state = base_state.clone();
233 let child_cancel = parent_cancel.child_token();
234 let child_stream = parent_stream.clone();
235 async move {
236 let branch_start = Instant::now();
237
238 let mut child_engine =
240 OwnedExecutionEngine::new(state, child_stream, child_cancel);
241
242 let mut branch_ctx = child_engine.build_node_context();
243 let ok = node.execute(&mut branch_ctx).await.is_ok();
244 drop(branch_ctx);
245
246 if !ok {
247 return (branch_name, Err("branch execution failed".into()));
248 }
249
250 child_engine.commit();
252
253 let duration = branch_start.elapsed();
254
255 (branch_name, Ok((child_engine.into_state(), duration)))
256 }
257 })
258 .collect();
259
260 let raw_results: Vec<(String, Result<(S, std::time::Duration), String>)> =
262 futures::future::join_all(branch_futures).await;
263
264 let mut branch_states: Vec<S> = Vec::with_capacity(branch_count);
266 let mut errors: Vec<(String, String)> = Vec::new();
267
268 for (branch_name, result) in raw_results {
269 match result {
270 Ok((state, branch_duration)) => {
271 engine.emit_flow_event(FlowEvent::BranchCompleted {
272 branch_name,
273 node_id: display_name.clone(),
274 span_id: SpanId::new(),
275 success: true,
276 duration: branch_duration,
277 });
278 branch_states.push(state);
279 }
280 Err(reason) => {
281 errors.push((branch_name, reason));
282 }
283 }
284 }
285
286 if !errors.is_empty() {
288 match self.error_strategy {
289 ParallelErrorStrategy::FailFast => {
290 let (name, reason) = &errors[0];
291 return Err(GraphError::Terminal(
292 crate::error::TerminalError::NodeExecutionFailed {
293 node: format!("{}/{}", display_name, name),
294 source: reason.clone().into(),
295 },
296 ));
297 }
298 ParallelErrorStrategy::CollectAll => {
299 if !branch_states.is_empty() {
300 for (name, reason) in &errors {
301 tracing::warn!(
302 parallel = %display_name,
303 branch = %name,
304 error = %reason,
305 "branch failed (CollectAll strategy)"
306 );
307 }
308 }
309 let (name, reason) = &errors[0];
310 return Err(GraphError::Terminal(
311 crate::error::TerminalError::NodeExecutionFailed {
312 node: format!("{}/{}", display_name, name),
313 source: reason.clone().into(),
314 },
315 ));
316 }
317 }
318 }
319
320 let merged = M::merge(branch_states).map_err(|e| {
322 GraphError::Terminal(crate::error::TerminalError::StateError(format!(
323 "parallel merge conflict: {e}",
324 )))
325 })?;
326
327 engine.replace_state(merged);
329
330 engine.emit_flow_event(FlowEvent::ParallelCompleted {
331 node_id: display_name,
332 span_id,
333 duration: start_time.elapsed(),
334 });
335
336 Ok(())
337 }
338}