1use std::sync::Arc;
22use std::time::Instant;
23
24use super::FlowNode;
25use crate::error::GraphError;
26use crate::exec::execution_engine::{ExecutionEngine, ExecutorState, OwnedExecutionEngine};
27use crate::state::workflow_state::{MergeStrategy, WorkflowState};
28use crate::state::{State, StateMerge};
29
30pub struct ParallelNode<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
51 label: Option<String>,
52 branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
53 error_strategy: ParallelErrorStrategy,
54 _merge_strategy: std::marker::PhantomData<M>,
56}
57
58impl<S: WorkflowState, M: MergeStrategy<S>> Clone for ParallelNode<S, M> {
59 fn clone(&self) -> Self {
60 Self {
61 label: self.label.clone(),
62 branches: self.branches.clone(),
63 error_strategy: self.error_strategy,
64 _merge_strategy: std::marker::PhantomData,
65 }
66 }
67}
68
/// Selects how a parallel node reports failed branches.
///
/// Under either strategy every branch runs to completion before failures are examined,
/// and any failed branch makes the whole parallel node fail.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ParallelErrorStrategy {
    /// Report the first failing branch in declaration order. This is the default.
    #[default]
    FailFast,
    /// Like `FailFast`, but also emit a warning for each failed branch when at least one
    /// other branch succeeded.
    CollectAll,
}
78
79impl ParallelNode {
80 pub fn builder() -> ParallelNodeBuilder {
82 ParallelNodeBuilder::new()
83 }
84}
85
86impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNode<S, M> {
87 pub fn with_label(mut self, label: impl Into<String>) -> Self {
88 self.label = Some(label.into());
89 self
90 }
91
92 pub fn branch_count(&self) -> usize {
93 self.branches.len()
94 }
95
96 pub fn branch_names(&self) -> Vec<&str> {
97 self.branches
98 .iter()
99 .map(|(name, _)| name.as_str())
100 .collect()
101 }
102
103 pub fn branches_iter(&self) -> impl Iterator<Item = (&str, &Arc<dyn FlowNode<S>>)> {
104 self.branches
105 .iter()
106 .map(|(name, node)| (name.as_str(), node))
107 }
108
109 pub fn error_strategy(&self) -> ParallelErrorStrategy {
110 self.error_strategy
111 }
112
113 pub fn label(&self) -> Option<&str> {
114 self.label.as_deref()
115 }
116
117 fn display_name(&self) -> String {
118 self.label.clone().unwrap_or_else(|| "parallel".to_string())
119 }
120}
121
122pub struct ParallelNodeBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
124 label: Option<String>,
125 branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
126 error_strategy: ParallelErrorStrategy,
127 _phantom: std::marker::PhantomData<M>,
128}
129
130impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNodeBuilder<S, M> {
131 fn new() -> Self {
132 Self {
133 label: None,
134 branches: Vec::new(),
135 error_strategy: ParallelErrorStrategy::default(),
136 _phantom: std::marker::PhantomData,
137 }
138 }
139
140 pub fn label(mut self, label: impl Into<String>) -> Self {
141 self.label = Some(label.into());
142 self
143 }
144
145 pub fn branch(mut self, name: impl Into<String>, node: Arc<dyn FlowNode<S>>) -> Self {
146 self.branches.push((name.into(), node));
147 self
148 }
149
150 pub fn error_strategy(mut self, strategy: ParallelErrorStrategy) -> Self {
151 self.error_strategy = strategy;
152 self
153 }
154
155 pub fn build(self) -> ParallelNode<S, M> {
156 if self.branches.is_empty() {
157 panic!("ParallelNode must have at least one branch");
158 }
159 ParallelNode {
160 label: self.label,
161 branches: self.branches,
162 error_strategy: self.error_strategy,
163 _merge_strategy: std::marker::PhantomData,
164 }
165 }
166
167 pub fn merge_strategy<NM>(self) -> ParallelNodeBuilder<S, NM>
169 where
170 NM: MergeStrategy<S>,
171 {
172 ParallelNodeBuilder {
173 label: self.label,
174 branches: self.branches,
175 error_strategy: self.error_strategy,
176 _phantom: std::marker::PhantomData,
177 }
178 }
179}
180
181impl<S: WorkflowState, M: MergeStrategy<S>> std::fmt::Debug for ParallelNode<S, M> {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("ParallelNode")
184 .field("label", &self.label)
185 .field(
186 "branches",
187 &self
188 .branches
189 .iter()
190 .map(|(n, _)| n.as_str())
191 .collect::<Vec<_>>(),
192 )
193 .field("error_strategy", &self.error_strategy)
194 .finish()
195 }
196}
197
198impl<S: WorkflowState + Clone + Send + Sync, M: MergeStrategy<S>> ParallelNode<S, M> {
199 pub async fn execute(&self, engine: &mut ExecutionEngine<'_, S>) -> Result<(), GraphError> {
201 let start_time = Instant::now();
202 let branch_count = self.branches.len();
203 let display_name = self.display_name();
204
205 tracing::debug!(
206 parallel = %display_name,
207 branches = branch_count,
208 "parallel node started"
209 );
210
211 let base_state = engine.clone_state();
213
214 let parent_cancel = engine.cancel_token().clone();
216 let parent_stream = engine.stream_sink();
217
218 let branches: Vec<(String, Arc<dyn super::FlowNode<S>>)> = self
220 .branches
221 .iter()
222 .map(|(n, nd)| (n.clone(), nd.clone()))
223 .collect();
224
225 let branch_futures: Vec<_> = branches
227 .into_iter()
228 .map(|(branch_name, node)| {
229 let state = base_state.clone();
230 let child_cancel = parent_cancel.child_token();
231 let child_stream = parent_stream.clone();
232 async move {
233 let branch_start = Instant::now();
234
235 let mut child_engine =
237 OwnedExecutionEngine::new(state, child_stream, child_cancel);
238
239 let mut branch_ctx = child_engine.build_node_context();
240 let ok = node.execute(&mut branch_ctx).await.is_ok();
241 drop(branch_ctx);
242
243 if !ok {
244 return (branch_name, Err("branch execution failed".into()));
245 }
246
247 child_engine.commit();
249
250 let duration = branch_start.elapsed();
251
252 (branch_name, Ok((child_engine.into_state(), duration)))
253 }
254 })
255 .collect();
256
257 let raw_results: Vec<(String, Result<(S, std::time::Duration), String>)> =
259 futures::future::join_all(branch_futures).await;
260
261 let mut branch_states: Vec<S> = Vec::with_capacity(branch_count);
263 let mut errors: Vec<(String, String)> = Vec::new();
264
265 for (branch_name, result) in raw_results {
266 match result {
267 Ok((state, branch_duration)) => {
268 tracing::debug!(
269 parallel = %display_name,
270 branch = %branch_name,
271 duration_ms = branch_duration.as_millis(),
272 "branch completed"
273 );
274 branch_states.push(state);
275 }
276 Err(reason) => {
277 errors.push((branch_name, reason));
278 }
279 }
280 }
281
282 if !errors.is_empty() {
284 match self.error_strategy {
285 ParallelErrorStrategy::FailFast => {
286 let (name, reason) = &errors[0];
287 return Err(GraphError::Terminal(
288 crate::error::TerminalError::NodeExecutionFailed {
289 node: format!("{}/{}", display_name, name),
290 source: reason.clone().into(),
291 },
292 ));
293 }
294 ParallelErrorStrategy::CollectAll => {
295 if !branch_states.is_empty() {
296 for (name, reason) in &errors {
297 tracing::warn!(
298 parallel = %display_name,
299 branch = %name,
300 error = %reason,
301 "branch failed (CollectAll strategy)"
302 );
303 }
304 }
305 let (name, reason) = &errors[0];
306 return Err(GraphError::Terminal(
307 crate::error::TerminalError::NodeExecutionFailed {
308 node: format!("{}/{}", display_name, name),
309 source: reason.clone().into(),
310 },
311 ));
312 }
313 }
314 }
315
316 let merged = M::merge(branch_states).map_err(|e| {
318 GraphError::Terminal(crate::error::TerminalError::StateError(format!(
319 "parallel merge conflict: {e}",
320 )))
321 })?;
322
323 engine.replace_state(merged);
325
326 tracing::debug!(
327 parallel = %display_name,
328 duration_ms = start_time.elapsed().as_millis(),
329 "parallel node completed"
330 );
331
332 Ok(())
333 }
334}