Skip to main content

lellm_graph/
parallel_node.rs

1//! ParallelNode — 并行执行多个分支,通过 MergeStrategy 合并 State。
2//!
3//! 执行模型:
4//! ```text
5//! State
6//!  ↓
7//! fork (ParallelNode)
8//!  ↓
9//! Branch A     Branch B     Branch C
10//!  ↓            ↓            ↓
11//! State<S>     State<S>     State<S>
12//!  ↓            ↓            ↓
13//! MergeStrategy<S>::merge(branches)
14//!  ↓
15//! Merged State → replace parent state
16//! ```
17//!
18//! 每个分支接收相同的 State 快照,独立产生变更(通过 Effects)。
19//! 所有分支完成后,变更通过 MergeStrategy 合并到 State。
20
21use std::sync::Arc;
22use std::time::Instant;
23
24use crate::error::GraphError;
25use crate::event::FlowEvent;
26use crate::execution_engine::{ExecutionEngine, ExecutorState};
27use crate::ids::SpanId;
28use crate::node::{ExecutorOperation, FlowNode};
29use crate::state::{State, StateMerge};
30use crate::workflow_state::{MergeStrategy, WorkflowState};
31
32/// 并行节点 — 同时执行多个分支,通过 MergeStrategy 合并 State。
33///
34/// 每个分支接收相同的 State 快照,独立产生变更。
35/// 所有分支完成后,变更通过 MergeStrategy 合并。
36///
37/// # 泛型参数
38///
39/// - `S` — 类型化状态
40/// - `M` — 合并策略(默认为 [`StateMerge`])
41///
42/// # 示例
43///
44/// ```rust,ignore
45/// let parallel = ParallelNode::builder()
46///     .branch("search", Arc::new(SearchNode::new()))
47///     .branch("analyze", Arc::new(AnalyzeNode::new()))
48///     .build();
49///
50/// graph.node("research", NodeKind::Parallel(parallel));
51/// ```
52pub struct ParallelNode<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
53    label: Option<String>,
54    branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
55    error_strategy: ParallelErrorStrategy,
56    /// Phantom — M 通过 `M::merge()` 静态调用,不需要实例。
57    _merge_strategy: std::marker::PhantomData<M>,
58}
59
60impl<S: WorkflowState, M: MergeStrategy<S>> Clone for ParallelNode<S, M> {
61    fn clone(&self) -> Self {
62        Self {
63            label: self.label.clone(),
64            branches: self.branches.clone(),
65            error_strategy: self.error_strategy,
66            _merge_strategy: std::marker::PhantomData,
67        }
68    }
69}
70
/// Error-handling strategy for a parallel node.
///
/// The descriptions below reflect what `ParallelNode::execute` in this file
/// actually does: with either strategy, all branches are awaited together, no
/// branch state is merged once any branch has failed, and the returned error
/// names the first failed branch in registration order.
///
/// NOTE(review): earlier documentation described `FailFast` as returning
/// immediately and `CollectAll` as keeping the changes of successful
/// branches. Neither is implemented; confirm the intended semantics before
/// relying on a difference between the two variants.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ParallelErrorStrategy {
    /// Report the first failed branch as the node's error. This is the
    /// default strategy.
    #[default]
    FailFast,
    /// Same returned error as [`ParallelErrorStrategy::FailFast`]; failed
    /// branches may additionally be logged as warnings.
    CollectAll,
}
80
81impl ParallelNode {
82    /// 创建默认构建器(`State` + `StateMerge`)。
83    pub fn builder() -> ParallelNodeBuilder {
84        ParallelNodeBuilder::new()
85    }
86}
87
88impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNode<S, M> {
89    pub fn with_label(mut self, label: impl Into<String>) -> Self {
90        self.label = Some(label.into());
91        self
92    }
93
94    pub fn branch_count(&self) -> usize {
95        self.branches.len()
96    }
97
98    pub fn branch_names(&self) -> Vec<&str> {
99        self.branches
100            .iter()
101            .map(|(name, _)| name.as_str())
102            .collect()
103    }
104
105    pub fn branches_iter(&self) -> impl Iterator<Item = (&str, &Arc<dyn FlowNode<S>>)> {
106        self.branches
107            .iter()
108            .map(|(name, node)| (name.as_str(), node))
109    }
110
111    pub fn error_strategy(&self) -> ParallelErrorStrategy {
112        self.error_strategy
113    }
114
115    pub fn label(&self) -> Option<&str> {
116        self.label.as_deref()
117    }
118
119    fn display_name(&self) -> String {
120        self.label.clone().unwrap_or_else(|| "parallel".to_string())
121    }
122}
123
124/// ParallelNode 构建器。
125pub struct ParallelNodeBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
126    label: Option<String>,
127    branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
128    error_strategy: ParallelErrorStrategy,
129    _phantom: std::marker::PhantomData<M>,
130}
131
132impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNodeBuilder<S, M> {
133    fn new() -> Self {
134        Self {
135            label: None,
136            branches: Vec::new(),
137            error_strategy: ParallelErrorStrategy::default(),
138            _phantom: std::marker::PhantomData,
139        }
140    }
141
142    pub fn label(mut self, label: impl Into<String>) -> Self {
143        self.label = Some(label.into());
144        self
145    }
146
147    pub fn branch(mut self, name: impl Into<String>, node: Arc<dyn FlowNode<S>>) -> Self {
148        self.branches.push((name.into(), node));
149        self
150    }
151
152    pub fn error_strategy(mut self, strategy: ParallelErrorStrategy) -> Self {
153        self.error_strategy = strategy;
154        self
155    }
156
157    pub fn build(self) -> ParallelNode<S, M> {
158        if self.branches.is_empty() {
159            panic!("ParallelNode must have at least one branch");
160        }
161        ParallelNode {
162            label: self.label,
163            branches: self.branches,
164            error_strategy: self.error_strategy,
165            _merge_strategy: std::marker::PhantomData,
166        }
167    }
168
169    /// 替换合并策略,返回新类型的构建器。
170    pub fn merge_strategy<NM>(self) -> ParallelNodeBuilder<S, NM>
171    where
172        NM: MergeStrategy<S>,
173    {
174        ParallelNodeBuilder {
175            label: self.label,
176            branches: self.branches,
177            error_strategy: self.error_strategy,
178            _phantom: std::marker::PhantomData,
179        }
180    }
181}
182
183impl<S: WorkflowState, M: MergeStrategy<S>> std::fmt::Debug for ParallelNode<S, M> {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        f.debug_struct("ParallelNode")
186            .field("label", &self.label)
187            .field(
188                "branches",
189                &self
190                    .branches
191                    .iter()
192                    .map(|(n, _)| n.as_str())
193                    .collect::<Vec<_>>(),
194            )
195            .field("error_strategy", &self.error_strategy)
196            .finish()
197    }
198}
199
200#[async_trait::async_trait]
201impl<S: WorkflowState + Clone + Send + Sync, M: MergeStrategy<S>> ExecutorOperation<S>
202    for ParallelNode<S, M>
203{
204    async fn execute(&self, engine: &mut ExecutionEngine<S>) -> Result<(), GraphError> {
205        let start_time = Instant::now();
206        let span_id = SpanId::new();
207        let branch_count = self.branches.len();
208        let display_name = self.display_name();
209
210        engine.emit_flow_event(FlowEvent::ParallelStarted {
211            node_id: display_name.clone(),
212            branch_count,
213            span_id,
214        });
215
216        // Clone typed state for each branch — each branch works on its own copy
217        let base_state = engine.clone_state();
218
219        // Inherit parent's cancel token and stream (fan-out via Arc clone)
220        let parent_cancel = engine.cancel_token().clone();
221        let parent_stream = engine.stream_sink();
222
223        // Clone branch data so async blocks own everything they need
224        let branches: Vec<(String, Arc<dyn crate::node::FlowNode<S>>)> = self
225            .branches
226            .iter()
227            .map(|(n, nd)| (n.clone(), nd.clone()))
228            .collect();
229
230        // Create a future for each branch — no spawn, no 'static required
231        let branch_futures: Vec<_> = branches
232            .into_iter()
233            .map(|(branch_name, node)| {
234                let state = base_state.clone();
235                let child_cancel = parent_cancel.child_token();
236                let child_stream = parent_stream.clone();
237                async move {
238                    let branch_start = Instant::now();
239
240                    // Each branch gets its own ExecutionEngine (child engine)
241                    let mut child_engine = ExecutionEngine::new(state, child_stream, child_cancel);
242
243                    let mut branch_ctx = child_engine.build_node_context();
244                    let ok = node.execute(&mut branch_ctx).await.is_ok();
245                    drop(branch_ctx);
246
247                    if !ok {
248                        return (branch_name, Err("branch execution failed".into()));
249                    }
250
251                    // Commit mutations to child engine
252                    child_engine.commit();
253
254                    let duration = branch_start.elapsed();
255
256                    (branch_name, Ok((child_engine.into_state(), duration)))
257                }
258            })
259            .collect();
260
261        // Execute all branches concurrently (no spawn, just concurrent polling)
262        let raw_results: Vec<(String, Result<(S, std::time::Duration), String>)> =
263            futures::future::join_all(branch_futures).await;
264
265        // Process results in branch order
266        let mut branch_states: Vec<S> = Vec::with_capacity(branch_count);
267        let mut errors: Vec<(String, String)> = Vec::new();
268
269        for (branch_name, result) in raw_results {
270            match result {
271                Ok((state, branch_duration)) => {
272                    engine.emit_flow_event(FlowEvent::BranchCompleted {
273                        branch_name,
274                        node_id: display_name.clone(),
275                        span_id: SpanId::new(),
276                        success: true,
277                        duration: branch_duration,
278                    });
279                    branch_states.push(state);
280                }
281                Err(reason) => {
282                    errors.push((branch_name, reason));
283                }
284            }
285        }
286
287        // Error handling based on strategy
288        if !errors.is_empty() {
289            match self.error_strategy {
290                ParallelErrorStrategy::FailFast => {
291                    let (name, reason) = &errors[0];
292                    return Err(GraphError::Terminal(
293                        crate::error::TerminalError::NodeExecutionFailed {
294                            node: format!("{}/{}", display_name, name),
295                            source: reason.clone().into(),
296                        },
297                    ));
298                }
299                ParallelErrorStrategy::CollectAll => {
300                    if !branch_states.is_empty() {
301                        for (name, reason) in &errors {
302                            tracing::warn!(
303                                parallel = %display_name,
304                                branch = %name,
305                                error = %reason,
306                                "branch failed (CollectAll strategy)"
307                            );
308                        }
309                    }
310                    let (name, reason) = &errors[0];
311                    return Err(GraphError::Terminal(
312                        crate::error::TerminalError::NodeExecutionFailed {
313                            node: format!("{}/{}", display_name, name),
314                            source: reason.clone().into(),
315                        },
316                    ));
317                }
318            }
319        }
320
321        // Merge all branch states using MergeStrategy — Graph 层并行语义
322        let merged = M::merge(branch_states).map_err(|e| {
323            GraphError::Terminal(crate::error::TerminalError::StateError(format!(
324                "parallel merge conflict: {e}",
325            )))
326        })?;
327
328        // Replace parent state with merged result
329        engine.replace_state(merged);
330
331        engine.emit_flow_event(FlowEvent::ParallelCompleted {
332            node_id: display_name,
333            span_id,
334            duration: start_time.elapsed(),
335        });
336
337        Ok(())
338    }
339}