1use std::sync::Arc;
22use std::time::Instant;
23
24use crate::error::GraphError;
25use crate::event::FlowEvent;
26use crate::ids::SpanId;
27use crate::node::FlowNode;
28use crate::node_context::NodeContext;
29use crate::state::{State, StateMerge};
30use crate::workflow_state::{MergeStrategy, WorkflowState};
31
32pub struct ParallelNode<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
53 label: Option<String>,
54 branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
55 error_strategy: ParallelErrorStrategy,
56 _merge_strategy: std::marker::PhantomData<M>,
58}
59
60impl<S: WorkflowState, M: MergeStrategy<S>> Clone for ParallelNode<S, M> {
61 fn clone(&self) -> Self {
62 Self {
63 label: self.label.clone(),
64 branches: self.branches.clone(),
65 error_strategy: self.error_strategy,
66 _merge_strategy: std::marker::PhantomData,
67 }
68 }
69}
70
/// How a [`ParallelNode`] should react when one of its branches returns an error.
///
/// Derives `Hash` alongside `Eq` so the strategy can be used as a map or set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum ParallelErrorStrategy {
    /// Stop at the first failing branch and return its error; later branches
    /// are not run. This is the default.
    #[default]
    FailFast,
    /// Request that every branch run even after one fails, with the failure
    /// reported only after all branches have finished.
    CollectAll,
}
80
81impl ParallelNode {
82 pub fn builder() -> ParallelNodeBuilder {
84 ParallelNodeBuilder::new()
85 }
86}
87
88impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNode<S, M> {
89 pub fn with_label(mut self, label: impl Into<String>) -> Self {
90 self.label = Some(label.into());
91 self
92 }
93
94 pub fn branch_count(&self) -> usize {
95 self.branches.len()
96 }
97
98 pub fn branch_names(&self) -> Vec<&str> {
99 self.branches
100 .iter()
101 .map(|(name, _)| name.as_str())
102 .collect()
103 }
104
105 pub fn branches_iter(&self) -> impl Iterator<Item = (&str, &Arc<dyn FlowNode<S>>)> {
106 self.branches
107 .iter()
108 .map(|(name, node)| (name.as_str(), node))
109 }
110
111 pub fn error_strategy(&self) -> ParallelErrorStrategy {
112 self.error_strategy
113 }
114
115 pub fn label(&self) -> Option<&str> {
116 self.label.as_deref()
117 }
118
119 fn display_name(&self) -> String {
120 self.label.clone().unwrap_or_else(|| "parallel".to_string())
121 }
122}
123
124pub struct ParallelNodeBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
126 label: Option<String>,
127 branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
128 error_strategy: ParallelErrorStrategy,
129 _phantom: std::marker::PhantomData<M>,
130}
131
132impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNodeBuilder<S, M> {
133 fn new() -> Self {
134 Self {
135 label: None,
136 branches: Vec::new(),
137 error_strategy: ParallelErrorStrategy::default(),
138 _phantom: std::marker::PhantomData,
139 }
140 }
141
142 pub fn label(mut self, label: impl Into<String>) -> Self {
143 self.label = Some(label.into());
144 self
145 }
146
147 pub fn branch(mut self, name: impl Into<String>, node: Arc<dyn FlowNode<S>>) -> Self {
148 self.branches.push((name.into(), node));
149 self
150 }
151
152 pub fn error_strategy(mut self, strategy: ParallelErrorStrategy) -> Self {
153 self.error_strategy = strategy;
154 self
155 }
156
157 pub fn build(self) -> ParallelNode<S, M> {
158 if self.branches.is_empty() {
159 panic!("ParallelNode must have at least one branch");
160 }
161 ParallelNode {
162 label: self.label,
163 branches: self.branches,
164 error_strategy: self.error_strategy,
165 _merge_strategy: std::marker::PhantomData,
166 }
167 }
168
169 pub fn merge_strategy<NM>(self) -> ParallelNodeBuilder<S, NM>
171 where
172 NM: MergeStrategy<S>,
173 {
174 ParallelNodeBuilder {
175 label: self.label,
176 branches: self.branches,
177 error_strategy: self.error_strategy,
178 _phantom: std::marker::PhantomData,
179 }
180 }
181}
182
183pub struct ParallelNodeBuilderWithMerge<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge>(
185 pub ParallelNodeBuilder<S, M>,
186);
187
188impl<S: WorkflowState, M: MergeStrategy<S>> std::fmt::Debug for ParallelNode<S, M> {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 f.debug_struct("ParallelNode")
191 .field("label", &self.label)
192 .field(
193 "branches",
194 &self
195 .branches
196 .iter()
197 .map(|(n, _)| n.as_str())
198 .collect::<Vec<_>>(),
199 )
200 .field("error_strategy", &self.error_strategy)
201 .finish()
202 }
203}
204
205#[async_trait::async_trait]
206impl<S: WorkflowState, M: MergeStrategy<S>> FlowNode<S> for ParallelNode<S, M> {
207 async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
208 let start_time = Instant::now();
209 let span_id = SpanId::new();
210 let branch_count = self.branches.len();
211
212 ctx.emit_flow_event(FlowEvent::ParallelStarted {
213 node_id: self.display_name(),
214 branch_count,
215 span_id,
216 });
217
218 let base_state = ctx.state().clone();
220 let mut branch_results: Vec<S> = Vec::with_capacity(self.branches.len());
221
222 for (name, node) in &self.branches {
224 let branch_start = Instant::now();
225 let branch_span = SpanId::new();
226
227 let mut branch_state = base_state.clone();
229 let mut branch_bs = ctx.branch().fork();
230 let mut branch_ctx = NodeContext::new(&mut branch_state, &mut branch_bs, None);
231
232 let result = node.execute(&mut branch_ctx).await.map_err(|e| {
233 GraphError::Terminal(crate::error::TerminalError::NodeExecutionFailed {
234 node: format!("{}/{}", self.display_name(), name),
235 source: e.into(),
236 })
237 });
238
239 let effects = branch_ctx.consume_effects();
241 branch_state.apply_batch(effects);
242
243 let branch_duration = branch_start.elapsed();
244 let success = result.is_ok();
245
246 ctx.emit_flow_event(FlowEvent::BranchCompleted {
247 branch_name: name.clone(),
248 node_id: self.display_name(),
249 span_id: branch_span,
250 success,
251 duration: branch_duration,
252 });
253
254 if !success {
255 return result;
256 }
257
258 branch_results.push(branch_state);
259 }
260
261 let merged = M::merge(branch_results).map_err(|e| {
263 GraphError::Terminal(crate::error::TerminalError::StateError(format!(
264 "parallel merge conflict: {e}",
265 )))
266 })?;
267
268 *ctx.state_mut() = merged;
270
271 ctx.emit_flow_event(FlowEvent::ParallelCompleted {
272 node_id: self.display_name(),
273 span_id,
274 duration: start_time.elapsed(),
275 });
276
277 Ok(())
278 }
279}