// lellm_graph/parallel_node.rs
//! `ParallelNode` — runs several branches concurrently and combines their
//! resulting states through a `MergeStrategy`.
//!
//! Execution model:
//! ```text
//! State
//!  ↓
//! fork (ParallelNode)
//!  ↓
//! Branch A     Branch B     Branch C
//!  ↓            ↓            ↓
//! State<S>     State<S>     State<S>
//!  ↓            ↓            ↓
//! MergeStrategy<S>::merge(branches)
//!  ↓
//! Merged State → replace parent state
//! ```
//!
//! Every branch receives the same snapshot of the state and produces its
//! changes independently (through Effects). When all branches succeed, their
//! changes are merged back into the state by the `MergeStrategy`.
use std::sync::Arc;
use std::time::Instant;

use crate::error::GraphError;
use crate::event::FlowEvent;
use crate::execution_engine::{ExecutionEngine, ExecutorState, OwnedExecutionEngine};
use crate::ids::SpanId;
use crate::node::FlowNode;
use crate::state::{State, StateMerge};
use crate::workflow_state::{MergeStrategy, WorkflowState};
32/// 并行节点 — 同时执行多个分支,通过 MergeStrategy 合并 State。
33///
34/// 每个分支接收相同的 State 快照,独立产生变更。
35/// 所有分支完成后,变更通过 MergeStrategy 合并。
36///
37/// # 泛型参数
38///
39/// - `S` — 类型化状态
40/// - `M` — 合并策略(默认为 [`StateMerge`])
41///
42/// # 示例
43///
44/// ```rust,ignore
45/// let parallel = ParallelNode::builder()
46///     .branch("search", Arc::new(SearchNode::new()))
47///     .branch("analyze", Arc::new(AnalyzeNode::new()))
48///     .build();
49///
50/// graph.node("research", NodeKind::Parallel(parallel));
51/// ```
52pub struct ParallelNode<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
53    label: Option<String>,
54    branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
55    error_strategy: ParallelErrorStrategy,
56    /// Phantom — M 通过 `M::merge()` 静态调用,不需要实例。
57    _merge_strategy: std::marker::PhantomData<M>,
58}
59
60impl<S: WorkflowState, M: MergeStrategy<S>> Clone for ParallelNode<S, M> {
61    fn clone(&self) -> Self {
62        Self {
63            label: self.label.clone(),
64            branches: self.branches.clone(),
65            error_strategy: self.error_strategy,
66            _merge_strategy: std::marker::PhantomData,
67        }
68    }
69}
70
/// How a [`ParallelNode`] reacts when one or more of its branches fail.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum ParallelErrorStrategy {
    /// The default. Any branch failure fails the whole parallel node with a
    /// failing branch's error, and no branch changes are merged into the parent
    /// state.
    #[default]
    FailFast,
    /// Wait for every branch to finish. If at least one failed, an error is
    /// still returned, but the changes made by the successful branches are kept.
    CollectAll,
}
81impl ParallelNode {
82    /// 创建默认构建器(`State` + `StateMerge`)。
83    pub fn builder() -> ParallelNodeBuilder {
84        ParallelNodeBuilder::new()
85    }
86}
87
88impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNode<S, M> {
89    pub fn with_label(mut self, label: impl Into<String>) -> Self {
90        self.label = Some(label.into());
91        self
92    }
93
94    pub fn branch_count(&self) -> usize {
95        self.branches.len()
96    }
97
98    pub fn branch_names(&self) -> Vec<&str> {
99        self.branches
100            .iter()
101            .map(|(name, _)| name.as_str())
102            .collect()
103    }
104
105    pub fn branches_iter(&self) -> impl Iterator<Item = (&str, &Arc<dyn FlowNode<S>>)> {
106        self.branches
107            .iter()
108            .map(|(name, node)| (name.as_str(), node))
109    }
110
111    pub fn error_strategy(&self) -> ParallelErrorStrategy {
112        self.error_strategy
113    }
114
115    pub fn label(&self) -> Option<&str> {
116        self.label.as_deref()
117    }
118
119    fn display_name(&self) -> String {
120        self.label.clone().unwrap_or_else(|| "parallel".to_string())
121    }
122}
123
124/// ParallelNode 构建器。
125pub struct ParallelNodeBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
126    label: Option<String>,
127    branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
128    error_strategy: ParallelErrorStrategy,
129    _phantom: std::marker::PhantomData<M>,
130}
131
132impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNodeBuilder<S, M> {
133    fn new() -> Self {
134        Self {
135            label: None,
136            branches: Vec::new(),
137            error_strategy: ParallelErrorStrategy::default(),
138            _phantom: std::marker::PhantomData,
139        }
140    }
141
142    pub fn label(mut self, label: impl Into<String>) -> Self {
143        self.label = Some(label.into());
144        self
145    }
146
147    pub fn branch(mut self, name: impl Into<String>, node: Arc<dyn FlowNode<S>>) -> Self {
148        self.branches.push((name.into(), node));
149        self
150    }
151
152    pub fn error_strategy(mut self, strategy: ParallelErrorStrategy) -> Self {
153        self.error_strategy = strategy;
154        self
155    }
156
157    pub fn build(self) -> ParallelNode<S, M> {
158        if self.branches.is_empty() {
159            panic!("ParallelNode must have at least one branch");
160        }
161        ParallelNode {
162            label: self.label,
163            branches: self.branches,
164            error_strategy: self.error_strategy,
165            _merge_strategy: std::marker::PhantomData,
166        }
167    }
168
169    /// 替换合并策略,返回新类型的构建器。
170    pub fn merge_strategy<NM>(self) -> ParallelNodeBuilder<S, NM>
171    where
172        NM: MergeStrategy<S>,
173    {
174        ParallelNodeBuilder {
175            label: self.label,
176            branches: self.branches,
177            error_strategy: self.error_strategy,
178            _phantom: std::marker::PhantomData,
179        }
180    }
181}
182
183impl<S: WorkflowState, M: MergeStrategy<S>> std::fmt::Debug for ParallelNode<S, M> {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        f.debug_struct("ParallelNode")
186            .field("label", &self.label)
187            .field(
188                "branches",
189                &self
190                    .branches
191                    .iter()
192                    .map(|(n, _)| n.as_str())
193                    .collect::<Vec<_>>(),
194            )
195            .field("error_strategy", &self.error_strategy)
196            .finish()
197    }
198}
199
200impl<S: WorkflowState + Clone + Send + Sync, M: MergeStrategy<S>> ParallelNode<S, M> {
201    /// 执行并行分支 — 创建独立的 OwnedExecutionEngine 给每个分支。
202    pub async fn execute(&self, engine: &mut ExecutionEngine<'_, S>) -> Result<(), GraphError> {
203        let start_time = Instant::now();
204        let span_id = SpanId::new();
205        let branch_count = self.branches.len();
206        let display_name = self.display_name();
207
208        engine.emit_flow_event(FlowEvent::ParallelStarted {
209            node_id: display_name.clone(),
210            branch_count,
211            span_id,
212        });
213
214        // Clone typed state for each branch — each branch works on its own copy
215        let base_state = engine.clone_state();
216
217        // Inherit parent's cancel token and stream (fan-out via Arc clone)
218        let parent_cancel = engine.cancel_token().clone();
219        let parent_stream = engine.stream_sink();
220
221        // Clone branch data so async blocks own everything they need
222        let branches: Vec<(String, Arc<dyn crate::node::FlowNode<S>>)> = self
223            .branches
224            .iter()
225            .map(|(n, nd)| (n.clone(), nd.clone()))
226            .collect();
227
228        // Create a future for each branch — no spawn, no 'static required
229        let branch_futures: Vec<_> = branches
230            .into_iter()
231            .map(|(branch_name, node)| {
232                let state = base_state.clone();
233                let child_cancel = parent_cancel.child_token();
234                let child_stream = parent_stream.clone();
235                async move {
236                    let branch_start = Instant::now();
237
238                    // Each branch gets its own OwnedExecutionEngine (child engine)
239                    let mut child_engine =
240                        OwnedExecutionEngine::new(state, child_stream, child_cancel);
241
242                    let mut branch_ctx = child_engine.build_node_context();
243                    let ok = node.execute(&mut branch_ctx).await.is_ok();
244                    drop(branch_ctx);
245
246                    if !ok {
247                        return (branch_name, Err("branch execution failed".into()));
248                    }
249
250                    // Commit mutations to child engine
251                    child_engine.commit();
252
253                    let duration = branch_start.elapsed();
254
255                    (branch_name, Ok((child_engine.into_state(), duration)))
256                }
257            })
258            .collect();
259
260        // Execute all branches concurrently (no spawn, just concurrent polling)
261        let raw_results: Vec<(String, Result<(S, std::time::Duration), String>)> =
262            futures::future::join_all(branch_futures).await;
263
264        // Process results in branch order
265        let mut branch_states: Vec<S> = Vec::with_capacity(branch_count);
266        let mut errors: Vec<(String, String)> = Vec::new();
267
268        for (branch_name, result) in raw_results {
269            match result {
270                Ok((state, branch_duration)) => {
271                    engine.emit_flow_event(FlowEvent::BranchCompleted {
272                        branch_name,
273                        node_id: display_name.clone(),
274                        span_id: SpanId::new(),
275                        success: true,
276                        duration: branch_duration,
277                    });
278                    branch_states.push(state);
279                }
280                Err(reason) => {
281                    errors.push((branch_name, reason));
282                }
283            }
284        }
285
286        // Error handling based on strategy
287        if !errors.is_empty() {
288            match self.error_strategy {
289                ParallelErrorStrategy::FailFast => {
290                    let (name, reason) = &errors[0];
291                    return Err(GraphError::Terminal(
292                        crate::error::TerminalError::NodeExecutionFailed {
293                            node: format!("{}/{}", display_name, name),
294                            source: reason.clone().into(),
295                        },
296                    ));
297                }
298                ParallelErrorStrategy::CollectAll => {
299                    if !branch_states.is_empty() {
300                        for (name, reason) in &errors {
301                            tracing::warn!(
302                                parallel = %display_name,
303                                branch = %name,
304                                error = %reason,
305                                "branch failed (CollectAll strategy)"
306                            );
307                        }
308                    }
309                    let (name, reason) = &errors[0];
310                    return Err(GraphError::Terminal(
311                        crate::error::TerminalError::NodeExecutionFailed {
312                            node: format!("{}/{}", display_name, name),
313                            source: reason.clone().into(),
314                        },
315                    ));
316                }
317            }
318        }
319
320        // Merge all branch states using MergeStrategy — Graph 层并行语义
321        let merged = M::merge(branch_states).map_err(|e| {
322            GraphError::Terminal(crate::error::TerminalError::StateError(format!(
323                "parallel merge conflict: {e}",
324            )))
325        })?;
326
327        // Replace parent state with merged result
328        engine.replace_state(merged);
329
330        engine.emit_flow_event(FlowEvent::ParallelCompleted {
331            node_id: display_name,
332            span_id,
333            duration: start_time.elapsed(),
334        });
335
336        Ok(())
337    }
338}