Skip to main content

lellm_graph/node/
parallel_node.rs

1//! ParallelNode — 并行执行多个分支,通过 MergeStrategy 合并 State。
2//!
3//! 执行模型:
4//! ```text
5//! State
6//!  ↓
7//! fork (ParallelNode)
8//!  ↓
9//! Branch A     Branch B     Branch C
10//!  ↓            ↓            ↓
11//! State<S>     State<S>     State<S>
12//!  ↓            ↓            ↓
13//! MergeStrategy<S>::merge(branches)
14//!  ↓
15//! Merged State → replace parent state
16//! ```
17//!
18//! 每个分支接收相同的 State 快照,独立产生变更(通过 Effects)。
19//! 所有分支完成后,变更通过 MergeStrategy 合并到 State。
20
21use std::sync::Arc;
22use std::time::Instant;
23
24use super::FlowNode;
25use crate::error::GraphError;
26use crate::exec::execution_engine::{ExecutionEngine, ExecutorState, OwnedExecutionEngine};
27use crate::state::workflow_state::{MergeStrategy, WorkflowState};
28use crate::state::{State, StateMerge};
29
30/// 并行节点 — 同时执行多个分支,通过 MergeStrategy 合并 State。
31///
32/// 每个分支接收相同的 State 快照,独立产生变更。
33/// 所有分支完成后,变更通过 MergeStrategy 合并。
34///
35/// # 泛型参数
36///
37/// - `S` — 类型化状态
38/// - `M` — 合并策略(默认为 [`StateMerge`])
39///
40/// # 示例
41///
42/// ```rust,ignore
43/// let parallel = ParallelNode::builder()
44///     .branch("search", Arc::new(SearchNode::new()))
45///     .branch("analyze", Arc::new(AnalyzeNode::new()))
46///     .build();
47///
48/// graph.node("research", NodeKind::Parallel(parallel));
49/// ```
50pub struct ParallelNode<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
51    label: Option<String>,
52    branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
53    error_strategy: ParallelErrorStrategy,
54    /// Phantom — M 通过 `M::merge()` 静态调用,不需要实例。
55    _merge_strategy: std::marker::PhantomData<M>,
56}
57
58impl<S: WorkflowState, M: MergeStrategy<S>> Clone for ParallelNode<S, M> {
59    fn clone(&self) -> Self {
60        Self {
61            label: self.label.clone(),
62            branches: self.branches.clone(),
63            error_strategy: self.error_strategy,
64            _merge_strategy: std::marker::PhantomData,
65        }
66    }
67}
68
/// Policy deciding how a [`ParallelNode`] reacts to failing branches.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ParallelErrorStrategy {
    /// Any failing branch makes the node return an error (this is the default);
    /// no branch's changes are applied to the parent state.
    #[default]
    FailFast,
    /// Let every branch finish; if at least one failed, return an error while still
    /// keeping the changes produced by the successful branches.
    CollectAll,
}
78
79impl ParallelNode {
80    /// 创建默认构建器(`State` + `StateMerge`)。
81    pub fn builder() -> ParallelNodeBuilder {
82        ParallelNodeBuilder::new()
83    }
84}
85
86impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNode<S, M> {
87    pub fn with_label(mut self, label: impl Into<String>) -> Self {
88        self.label = Some(label.into());
89        self
90    }
91
92    pub fn branch_count(&self) -> usize {
93        self.branches.len()
94    }
95
96    pub fn branch_names(&self) -> Vec<&str> {
97        self.branches
98            .iter()
99            .map(|(name, _)| name.as_str())
100            .collect()
101    }
102
103    pub fn branches_iter(&self) -> impl Iterator<Item = (&str, &Arc<dyn FlowNode<S>>)> {
104        self.branches
105            .iter()
106            .map(|(name, node)| (name.as_str(), node))
107    }
108
109    pub fn error_strategy(&self) -> ParallelErrorStrategy {
110        self.error_strategy
111    }
112
113    pub fn label(&self) -> Option<&str> {
114        self.label.as_deref()
115    }
116
117    fn display_name(&self) -> String {
118        self.label.clone().unwrap_or_else(|| "parallel".to_string())
119    }
120}
121
122/// ParallelNode 构建器。
123pub struct ParallelNodeBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
124    label: Option<String>,
125    branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
126    error_strategy: ParallelErrorStrategy,
127    _phantom: std::marker::PhantomData<M>,
128}
129
130impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNodeBuilder<S, M> {
131    fn new() -> Self {
132        Self {
133            label: None,
134            branches: Vec::new(),
135            error_strategy: ParallelErrorStrategy::default(),
136            _phantom: std::marker::PhantomData,
137        }
138    }
139
140    pub fn label(mut self, label: impl Into<String>) -> Self {
141        self.label = Some(label.into());
142        self
143    }
144
145    pub fn branch(mut self, name: impl Into<String>, node: Arc<dyn FlowNode<S>>) -> Self {
146        self.branches.push((name.into(), node));
147        self
148    }
149
150    pub fn error_strategy(mut self, strategy: ParallelErrorStrategy) -> Self {
151        self.error_strategy = strategy;
152        self
153    }
154
155    pub fn build(self) -> ParallelNode<S, M> {
156        if self.branches.is_empty() {
157            panic!("ParallelNode must have at least one branch");
158        }
159        ParallelNode {
160            label: self.label,
161            branches: self.branches,
162            error_strategy: self.error_strategy,
163            _merge_strategy: std::marker::PhantomData,
164        }
165    }
166
167    /// 替换合并策略,返回新类型的构建器。
168    pub fn merge_strategy<NM>(self) -> ParallelNodeBuilder<S, NM>
169    where
170        NM: MergeStrategy<S>,
171    {
172        ParallelNodeBuilder {
173            label: self.label,
174            branches: self.branches,
175            error_strategy: self.error_strategy,
176            _phantom: std::marker::PhantomData,
177        }
178    }
179}
180
181impl<S: WorkflowState, M: MergeStrategy<S>> std::fmt::Debug for ParallelNode<S, M> {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        f.debug_struct("ParallelNode")
184            .field("label", &self.label)
185            .field(
186                "branches",
187                &self
188                    .branches
189                    .iter()
190                    .map(|(n, _)| n.as_str())
191                    .collect::<Vec<_>>(),
192            )
193            .field("error_strategy", &self.error_strategy)
194            .finish()
195    }
196}
197
198impl<S: WorkflowState + Clone + Send + Sync, M: MergeStrategy<S>> ParallelNode<S, M> {
199    /// 执行并行分支 — 创建独立的 OwnedExecutionEngine 给每个分支。
200    pub async fn execute(&self, engine: &mut ExecutionEngine<'_, S>) -> Result<(), GraphError> {
201        let start_time = Instant::now();
202        let branch_count = self.branches.len();
203        let display_name = self.display_name();
204
205        tracing::debug!(
206            parallel = %display_name,
207            branches = branch_count,
208            "parallel node started"
209        );
210
211        // Clone typed state for each branch — each branch works on its own copy
212        let base_state = engine.clone_state();
213
214        // Inherit parent's cancel token and stream (fan-out via Arc clone)
215        let parent_cancel = engine.cancel_token().clone();
216        let parent_stream = engine.stream_sink();
217
218        // Clone branch data so async blocks own everything they need
219        let branches: Vec<(String, Arc<dyn super::FlowNode<S>>)> = self
220            .branches
221            .iter()
222            .map(|(n, nd)| (n.clone(), nd.clone()))
223            .collect();
224
225        // Create a future for each branch — no spawn, no 'static required
226        let branch_futures: Vec<_> = branches
227            .into_iter()
228            .map(|(branch_name, node)| {
229                let state = base_state.clone();
230                let child_cancel = parent_cancel.child_token();
231                let child_stream = parent_stream.clone();
232                async move {
233                    let branch_start = Instant::now();
234
235                    // Each branch gets its own OwnedExecutionEngine (child engine)
236                    let mut child_engine =
237                        OwnedExecutionEngine::new(state, child_stream, child_cancel);
238
239                    let mut branch_ctx = child_engine.build_node_context();
240                    let ok = node.execute(&mut branch_ctx).await.is_ok();
241                    drop(branch_ctx);
242
243                    if !ok {
244                        return (branch_name, Err("branch execution failed".into()));
245                    }
246
247                    // Commit mutations to child engine
248                    child_engine.commit();
249
250                    let duration = branch_start.elapsed();
251
252                    (branch_name, Ok((child_engine.into_state(), duration)))
253                }
254            })
255            .collect();
256
257        // Execute all branches concurrently (no spawn, just concurrent polling)
258        let raw_results: Vec<(String, Result<(S, std::time::Duration), String>)> =
259            futures::future::join_all(branch_futures).await;
260
261        // Process results in branch order
262        let mut branch_states: Vec<S> = Vec::with_capacity(branch_count);
263        let mut errors: Vec<(String, String)> = Vec::new();
264
265        for (branch_name, result) in raw_results {
266            match result {
267                Ok((state, branch_duration)) => {
268                    tracing::debug!(
269                        parallel = %display_name,
270                        branch = %branch_name,
271                        duration_ms = branch_duration.as_millis(),
272                        "branch completed"
273                    );
274                    branch_states.push(state);
275                }
276                Err(reason) => {
277                    errors.push((branch_name, reason));
278                }
279            }
280        }
281
282        // Error handling based on strategy
283        if !errors.is_empty() {
284            match self.error_strategy {
285                ParallelErrorStrategy::FailFast => {
286                    let (name, reason) = &errors[0];
287                    return Err(GraphError::Terminal(
288                        crate::error::TerminalError::NodeExecutionFailed {
289                            node: format!("{}/{}", display_name, name),
290                            source: reason.clone().into(),
291                        },
292                    ));
293                }
294                ParallelErrorStrategy::CollectAll => {
295                    if !branch_states.is_empty() {
296                        for (name, reason) in &errors {
297                            tracing::warn!(
298                                parallel = %display_name,
299                                branch = %name,
300                                error = %reason,
301                                "branch failed (CollectAll strategy)"
302                            );
303                        }
304                    }
305                    let (name, reason) = &errors[0];
306                    return Err(GraphError::Terminal(
307                        crate::error::TerminalError::NodeExecutionFailed {
308                            node: format!("{}/{}", display_name, name),
309                            source: reason.clone().into(),
310                        },
311                    ));
312                }
313            }
314        }
315
316        // Merge all branch states using MergeStrategy — Graph 层并行语义
317        let merged = M::merge(branch_states).map_err(|e| {
318            GraphError::Terminal(crate::error::TerminalError::StateError(format!(
319                "parallel merge conflict: {e}",
320            )))
321        })?;
322
323        // Replace parent state with merged result
324        engine.replace_state(merged);
325
326        tracing::debug!(
327            parallel = %display_name,
328            duration_ms = start_time.elapsed().as_millis(),
329            "parallel node completed"
330        );
331
332        Ok(())
333    }
334}