Skip to main content

lellm_graph/
parallel_node.rs

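//! ParallelNode — runs several branches from one shared starting state and
//! combines their results into a single State through a MergeStrategy.
//!
//! Execution model:
//! ```text
//! State
//!  ↓
//! fork (ParallelNode)
//!  ↓
//! Branch A     Branch B     Branch C
//!  ↓            ↓            ↓
//! State<S>     State<S>     State<S>
//!  ↓            ↓            ↓
//! MergeStrategy<S>::merge(branches)
//!  ↓
//! Merged State → replace parent state
//! ```
//!
//! Every branch receives the same State snapshot and produces its changes
//! independently (through Effects). Once all branches have finished, the
//! per-branch states are combined by the MergeStrategy and the result replaces
//! the parent State.
//!
//! Note: the current implementation awaits the branches one after another
//! (a serial fallback); each branch still works on its own copy of the state.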
20
21use std::sync::Arc;
22use std::time::Instant;
23
24use crate::error::GraphError;
25use crate::event::FlowEvent;
26use crate::ids::SpanId;
27use crate::node::FlowNode;
28use crate::node_context::NodeContext;
29use crate::state::{State, StateMerge};
30use crate::workflow_state::{MergeStrategy, WorkflowState};
31
32/// 并行节点 — 同时执行多个分支,通过 MergeStrategy 合并 State。
33///
34/// 每个分支接收相同的 State 快照,独立产生变更。
35/// 所有分支完成后,变更通过 MergeStrategy 合并。
36///
37/// # 泛型参数
38///
39/// - `S` — 类型化状态
40/// - `M` — 合并策略(默认为 [`StateMerge`])
41///
42/// # 示例
43///
44/// ```rust,ignore
45/// let parallel = ParallelNode::builder()
46///     .branch("search", Arc::new(SearchNode::new()))
47///     .branch("analyze", Arc::new(AnalyzeNode::new()))
48///     .build();
49///
50/// graph.node("research", NodeKind::Parallel(parallel));
51/// ```
52pub struct ParallelNode<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
53    label: Option<String>,
54    branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
55    error_strategy: ParallelErrorStrategy,
56    /// Phantom — M 通过 `M::merge()` 静态调用,不需要实例。
57    _merge_strategy: std::marker::PhantomData<M>,
58}
59
60impl<S: WorkflowState, M: MergeStrategy<S>> Clone for ParallelNode<S, M> {
61    fn clone(&self) -> Self {
62        Self {
63            label: self.label.clone(),
64            branches: self.branches.clone(),
65            error_strategy: self.error_strategy,
66            _merge_strategy: std::marker::PhantomData,
67        }
68    }
69}
70
/// How a parallel node is meant to react when one of its branches fails.
///
/// The default is [`ParallelErrorStrategy::FailFast`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ParallelErrorStrategy {
    /// Any branch failure makes the node return an error immediately; branches
    /// that are still running may continue, but their results are ignored.
    #[default]
    FailFast,
    /// Wait for every branch to finish; if at least one failed, return an
    /// error while still including the changes of the branches that succeeded.
    CollectAll,
}
80
81impl ParallelNode {
82    /// 创建默认构建器(`State` + `StateMerge`)。
83    pub fn builder() -> ParallelNodeBuilder {
84        ParallelNodeBuilder::new()
85    }
86}
87
88impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNode<S, M> {
89    pub fn with_label(mut self, label: impl Into<String>) -> Self {
90        self.label = Some(label.into());
91        self
92    }
93
94    pub fn branch_count(&self) -> usize {
95        self.branches.len()
96    }
97
98    pub fn branch_names(&self) -> Vec<&str> {
99        self.branches
100            .iter()
101            .map(|(name, _)| name.as_str())
102            .collect()
103    }
104
105    pub fn branches_iter(&self) -> impl Iterator<Item = (&str, &Arc<dyn FlowNode<S>>)> {
106        self.branches
107            .iter()
108            .map(|(name, node)| (name.as_str(), node))
109    }
110
111    pub fn error_strategy(&self) -> ParallelErrorStrategy {
112        self.error_strategy
113    }
114
115    pub fn label(&self) -> Option<&str> {
116        self.label.as_deref()
117    }
118
119    fn display_name(&self) -> String {
120        self.label.clone().unwrap_or_else(|| "parallel".to_string())
121    }
122}
123
124/// ParallelNode 构建器。
125pub struct ParallelNodeBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
126    label: Option<String>,
127    branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
128    error_strategy: ParallelErrorStrategy,
129    _phantom: std::marker::PhantomData<M>,
130}
131
132impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNodeBuilder<S, M> {
133    fn new() -> Self {
134        Self {
135            label: None,
136            branches: Vec::new(),
137            error_strategy: ParallelErrorStrategy::default(),
138            _phantom: std::marker::PhantomData,
139        }
140    }
141
142    pub fn label(mut self, label: impl Into<String>) -> Self {
143        self.label = Some(label.into());
144        self
145    }
146
147    pub fn branch(mut self, name: impl Into<String>, node: Arc<dyn FlowNode<S>>) -> Self {
148        self.branches.push((name.into(), node));
149        self
150    }
151
152    pub fn error_strategy(mut self, strategy: ParallelErrorStrategy) -> Self {
153        self.error_strategy = strategy;
154        self
155    }
156
157    pub fn build(self) -> ParallelNode<S, M> {
158        if self.branches.is_empty() {
159            panic!("ParallelNode must have at least one branch");
160        }
161        ParallelNode {
162            label: self.label,
163            branches: self.branches,
164            error_strategy: self.error_strategy,
165            _merge_strategy: std::marker::PhantomData,
166        }
167    }
168
169    /// 替换合并策略,返回新类型的构建器。
170    pub fn merge_strategy<NM>(self) -> ParallelNodeBuilder<S, NM>
171    where
172        NM: MergeStrategy<S>,
173    {
174        ParallelNodeBuilder {
175            label: self.label,
176            branches: self.branches,
177            error_strategy: self.error_strategy,
178            _phantom: std::marker::PhantomData,
179        }
180    }
181}
182
183/// 带 MergeStrategy 的构建器 — 由 ParallelNode::builder() 返回。
184pub struct ParallelNodeBuilderWithMerge<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge>(
185    pub ParallelNodeBuilder<S, M>,
186);
187
188impl<S: WorkflowState, M: MergeStrategy<S>> std::fmt::Debug for ParallelNode<S, M> {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        f.debug_struct("ParallelNode")
191            .field("label", &self.label)
192            .field(
193                "branches",
194                &self
195                    .branches
196                    .iter()
197                    .map(|(n, _)| n.as_str())
198                    .collect::<Vec<_>>(),
199            )
200            .field("error_strategy", &self.error_strategy)
201            .finish()
202    }
203}
204
205#[async_trait::async_trait]
206impl<S: WorkflowState, M: MergeStrategy<S>> FlowNode<S> for ParallelNode<S, M> {
207    async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
208        let start_time = Instant::now();
209        let span_id = SpanId::new();
210        let branch_count = self.branches.len();
211
212        ctx.emit_flow_event(FlowEvent::ParallelStarted {
213            node_id: self.display_name(),
214            branch_count,
215            span_id,
216        });
217
218        // Clone typed state for each branch — each branch works on its own copy
219        let base_state = ctx.state().clone();
220        let mut branch_results: Vec<S> = Vec::with_capacity(self.branches.len());
221
222        // Execute branches sequentially (serial fallback)
223        for (name, node) in &self.branches {
224            let branch_start = Instant::now();
225            let branch_span = SpanId::new();
226
227            // Each branch gets its own typed state clone + a forked BranchState
228            let mut branch_state = base_state.clone();
229            let mut branch_bs = ctx.branch().fork();
230            let mut branch_ctx = NodeContext::new(&mut branch_state, &mut branch_bs, None);
231
232            let result = node.execute(&mut branch_ctx).await.map_err(|e| {
233                GraphError::Terminal(crate::error::TerminalError::NodeExecutionFailed {
234                    node: format!("{}/{}", self.display_name(), name),
235                    source: e.into(),
236                })
237            });
238
239            // Consume effects → apply to branch's typed state(零序列化)
240            let effects = branch_ctx.consume_effects();
241            branch_state.apply_batch(effects);
242
243            let branch_duration = branch_start.elapsed();
244            let success = result.is_ok();
245
246            ctx.emit_flow_event(FlowEvent::BranchCompleted {
247                branch_name: name.clone(),
248                node_id: self.display_name(),
249                span_id: branch_span,
250                success,
251                duration: branch_duration,
252            });
253
254            if !success {
255                return result;
256            }
257
258            branch_results.push(branch_state);
259        }
260
261        // Merge all branch states using MergeStrategy — Graph 层并行语义
262        let merged = M::merge(branch_results).map_err(|e| {
263            GraphError::Terminal(crate::error::TerminalError::StateError(format!(
264                "parallel merge conflict: {e}",
265            )))
266        })?;
267
268        // Replace parent state with merged result
269        *ctx.state_mut() = merged;
270
271        ctx.emit_flow_event(FlowEvent::ParallelCompleted {
272            node_id: self.display_name(),
273            span_id,
274            duration: start_time.elapsed(),
275        });
276
277        Ok(())
278    }
279}