1use std::sync::Arc;
22use std::time::Instant;
23
24use crate::error::GraphError;
25use crate::event::FlowEvent;
26use crate::execution_engine::{ExecutionEngine, ExecutorState};
27use crate::ids::SpanId;
28use crate::node::{ExecutorOperation, FlowNode};
29use crate::state::{State, StateMerge};
30use crate::workflow_state::{MergeStrategy, WorkflowState};
31
32pub struct ParallelNode<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
53 label: Option<String>,
54 branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
55 error_strategy: ParallelErrorStrategy,
56 _merge_strategy: std::marker::PhantomData<M>,
58}
59
60impl<S: WorkflowState, M: MergeStrategy<S>> Clone for ParallelNode<S, M> {
61 fn clone(&self) -> Self {
62 Self {
63 label: self.label.clone(),
64 branches: self.branches.clone(),
65 error_strategy: self.error_strategy,
66 _merge_strategy: std::marker::PhantomData,
67 }
68 }
69}
70
/// How a [`ParallelNode`] reacts when one or more of its branches fail.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ParallelErrorStrategy {
    /// Default: the node fails with the error of the first failing branch, in branch
    /// declaration order.
    #[default]
    FailFast,
    /// All branch failures are collected before the node fails.
    CollectAll,
}
80
81impl ParallelNode {
82 pub fn builder() -> ParallelNodeBuilder {
84 ParallelNodeBuilder::new()
85 }
86}
87
88impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNode<S, M> {
89 pub fn with_label(mut self, label: impl Into<String>) -> Self {
90 self.label = Some(label.into());
91 self
92 }
93
94 pub fn branch_count(&self) -> usize {
95 self.branches.len()
96 }
97
98 pub fn branch_names(&self) -> Vec<&str> {
99 self.branches
100 .iter()
101 .map(|(name, _)| name.as_str())
102 .collect()
103 }
104
105 pub fn branches_iter(&self) -> impl Iterator<Item = (&str, &Arc<dyn FlowNode<S>>)> {
106 self.branches
107 .iter()
108 .map(|(name, node)| (name.as_str(), node))
109 }
110
111 pub fn error_strategy(&self) -> ParallelErrorStrategy {
112 self.error_strategy
113 }
114
115 pub fn label(&self) -> Option<&str> {
116 self.label.as_deref()
117 }
118
119 fn display_name(&self) -> String {
120 self.label.clone().unwrap_or_else(|| "parallel".to_string())
121 }
122}
123
124pub struct ParallelNodeBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
126 label: Option<String>,
127 branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
128 error_strategy: ParallelErrorStrategy,
129 _phantom: std::marker::PhantomData<M>,
130}
131
132impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNodeBuilder<S, M> {
133 fn new() -> Self {
134 Self {
135 label: None,
136 branches: Vec::new(),
137 error_strategy: ParallelErrorStrategy::default(),
138 _phantom: std::marker::PhantomData,
139 }
140 }
141
142 pub fn label(mut self, label: impl Into<String>) -> Self {
143 self.label = Some(label.into());
144 self
145 }
146
147 pub fn branch(mut self, name: impl Into<String>, node: Arc<dyn FlowNode<S>>) -> Self {
148 self.branches.push((name.into(), node));
149 self
150 }
151
152 pub fn error_strategy(mut self, strategy: ParallelErrorStrategy) -> Self {
153 self.error_strategy = strategy;
154 self
155 }
156
157 pub fn build(self) -> ParallelNode<S, M> {
158 if self.branches.is_empty() {
159 panic!("ParallelNode must have at least one branch");
160 }
161 ParallelNode {
162 label: self.label,
163 branches: self.branches,
164 error_strategy: self.error_strategy,
165 _merge_strategy: std::marker::PhantomData,
166 }
167 }
168
169 pub fn merge_strategy<NM>(self) -> ParallelNodeBuilder<S, NM>
171 where
172 NM: MergeStrategy<S>,
173 {
174 ParallelNodeBuilder {
175 label: self.label,
176 branches: self.branches,
177 error_strategy: self.error_strategy,
178 _phantom: std::marker::PhantomData,
179 }
180 }
181}
182
183impl<S: WorkflowState, M: MergeStrategy<S>> std::fmt::Debug for ParallelNode<S, M> {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 f.debug_struct("ParallelNode")
186 .field("label", &self.label)
187 .field(
188 "branches",
189 &self
190 .branches
191 .iter()
192 .map(|(n, _)| n.as_str())
193 .collect::<Vec<_>>(),
194 )
195 .field("error_strategy", &self.error_strategy)
196 .finish()
197 }
198}
199
200#[async_trait::async_trait]
201impl<S: WorkflowState + Clone + Send + Sync, M: MergeStrategy<S>> ExecutorOperation<S>
202 for ParallelNode<S, M>
203{
204 async fn execute(&self, engine: &mut ExecutionEngine<S>) -> Result<(), GraphError> {
205 let start_time = Instant::now();
206 let span_id = SpanId::new();
207 let branch_count = self.branches.len();
208 let display_name = self.display_name();
209
210 engine.emit_flow_event(FlowEvent::ParallelStarted {
211 node_id: display_name.clone(),
212 branch_count,
213 span_id,
214 });
215
216 let base_state = engine.clone_state();
218
219 let parent_cancel = engine.cancel_token().clone();
221 let parent_stream = engine.stream_sink();
222
223 let branches: Vec<(String, Arc<dyn crate::node::FlowNode<S>>)> = self
225 .branches
226 .iter()
227 .map(|(n, nd)| (n.clone(), nd.clone()))
228 .collect();
229
230 let branch_futures: Vec<_> = branches
232 .into_iter()
233 .map(|(branch_name, node)| {
234 let state = base_state.clone();
235 let child_cancel = parent_cancel.child_token();
236 let child_stream = parent_stream.clone();
237 async move {
238 let branch_start = Instant::now();
239
240 let mut child_engine = ExecutionEngine::new(state, child_stream, child_cancel);
242
243 let mut branch_ctx = child_engine.build_node_context();
244 let ok = node.execute(&mut branch_ctx).await.is_ok();
245 drop(branch_ctx);
246
247 if !ok {
248 return (branch_name, Err("branch execution failed".into()));
249 }
250
251 child_engine.commit();
253
254 let duration = branch_start.elapsed();
255
256 (branch_name, Ok((child_engine.into_state(), duration)))
257 }
258 })
259 .collect();
260
261 let raw_results: Vec<(String, Result<(S, std::time::Duration), String>)> =
263 futures::future::join_all(branch_futures).await;
264
265 let mut branch_states: Vec<S> = Vec::with_capacity(branch_count);
267 let mut errors: Vec<(String, String)> = Vec::new();
268
269 for (branch_name, result) in raw_results {
270 match result {
271 Ok((state, branch_duration)) => {
272 engine.emit_flow_event(FlowEvent::BranchCompleted {
273 branch_name,
274 node_id: display_name.clone(),
275 span_id: SpanId::new(),
276 success: true,
277 duration: branch_duration,
278 });
279 branch_states.push(state);
280 }
281 Err(reason) => {
282 errors.push((branch_name, reason));
283 }
284 }
285 }
286
287 if !errors.is_empty() {
289 match self.error_strategy {
290 ParallelErrorStrategy::FailFast => {
291 let (name, reason) = &errors[0];
292 return Err(GraphError::Terminal(
293 crate::error::TerminalError::NodeExecutionFailed {
294 node: format!("{}/{}", display_name, name),
295 source: reason.clone().into(),
296 },
297 ));
298 }
299 ParallelErrorStrategy::CollectAll => {
300 if !branch_states.is_empty() {
301 for (name, reason) in &errors {
302 tracing::warn!(
303 parallel = %display_name,
304 branch = %name,
305 error = %reason,
306 "branch failed (CollectAll strategy)"
307 );
308 }
309 }
310 let (name, reason) = &errors[0];
311 return Err(GraphError::Terminal(
312 crate::error::TerminalError::NodeExecutionFailed {
313 node: format!("{}/{}", display_name, name),
314 source: reason.clone().into(),
315 },
316 ));
317 }
318 }
319 }
320
321 let merged = M::merge(branch_states).map_err(|e| {
323 GraphError::Terminal(crate::error::TerminalError::StateError(format!(
324 "parallel merge conflict: {e}",
325 )))
326 })?;
327
328 engine.replace_state(merged);
330
331 engine.emit_flow_event(FlowEvent::ParallelCompleted {
332 node_id: display_name,
333 span_id,
334 duration: start_time.elapsed(),
335 });
336
337 Ok(())
338 }
339}