Skip to main content

lellm_graph/
parallel_node.rs

1//! ParallelNode — 并行执行多个分支,合并 StateDelta。
2//!
3//! 执行模型:
4//! ```text
5//! State
6//!  ↓
7//! fork (ParallelNode)
8//!  ↓
9//! Branch A     Branch B     Branch C
10//!  ↓            ↓            ↓
11//! StateDelta   StateDelta   StateDelta
12//!  ↓            ↓            ↓
13//! ReducerRegistry.merge_deltas()
14//!  ↓
15//! Merged Deltas → apply to State
16//! ```
17//!
18//! 每个分支接收相同的 State 快照,独立产生 StateDelta。
19//! 所有 Delta 收集后通过 `ReducerRegistry::merge_deltas()` 合并。
20//! 未注册 Reducer 的 key 发生多 writer → `StateConflict` 错误。
21
22use std::sync::Arc;
23
24use crate::error::GraphError;
25use crate::node::{FlowNode, NextStep, NodeOutput};
26use crate::state::State;
27
28/// 并行节点 — 同时执行多个分支,合并 StateDelta。
29///
30/// 每个分支接收相同的 State 快照,独立产生 StateDelta。
31/// 所有分支完成后,Delta 通过 `ReducerRegistry::merge_deltas()` 合并到 State。
32///
33/// # 示例
34///
35/// ```rust,ignore
36/// let parallel = ParallelNode::builder()
37///     .branch("search", Arc::new(SearchNode::new()))
38///     .branch("analyze", Arc::new(AnalyzeNode::new()))
39///     .build();
40///
41/// graph.node("research", NodeKind::Parallel(parallel));
42/// ```
43#[derive(Clone)]
44pub struct ParallelNode {
45    /// 调试标签(可选)
46    label: Option<String>,
47    /// 并行分支 — (名称, 节点)
48    branches: Vec<(String, Arc<dyn FlowNode>)>,
49    /// 错误处理策略
50    error_strategy: ParallelErrorStrategy,
51}
52
53/// 并行执行错误处理策略。
54#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
55pub enum ParallelErrorStrategy {
56    /// 任一分支失败 → 立即返回错误(其余分支继续执行但结果被忽略)
57    #[default]
58    FailFast,
59    /// 等待所有分支完成,至少一个失败 → 返回错误但包含成功分支的 Delta
60    CollectAll,
61}
62
63impl ParallelNode {
64    /// 创建构建器。
65    pub fn builder() -> ParallelNodeBuilder {
66        ParallelNodeBuilder::new()
67    }
68
69    /// 设置调试标签。
70    pub fn with_label(mut self, label: impl Into<String>) -> Self {
71        self.label = Some(label.into());
72        self
73    }
74
75    /// 获取分支数量。
76    pub fn branch_count(&self) -> usize {
77        self.branches.len()
78    }
79
80    /// 获取分支名称列表。
81    pub fn branch_names(&self) -> Vec<&str> {
82        self.branches
83            .iter()
84            .map(|(name, _)| name.as_str())
85            .collect()
86    }
87
88    /// 迭代所有分支(名称, 节点)引用。
89    pub fn branches_iter(&self) -> impl Iterator<Item = (&str, &Arc<dyn FlowNode>)> {
90        self.branches
91            .iter()
92            .map(|(name, node)| (name.as_str(), node))
93    }
94
95    /// 获取错误处理策略。
96    pub fn error_strategy(&self) -> ParallelErrorStrategy {
97        self.error_strategy
98    }
99
100    /// 获取标签。
101    pub fn label(&self) -> Option<&str> {
102        self.label.as_deref()
103    }
104
105    /// 串行执行所有分支(用于阻塞模式 fallback)。
106    ///
107    /// ⚠️ 此方法顺序执行各分支,不发挥并行优势。
108    /// 真正的并行执行由 `Executor::handle_parallel()` 完成。
109    pub async fn execute_sequential(&self, state: &State) -> Result<NodeOutput, GraphError> {
110        let mut all_deltas = Vec::new();
111
112        for (name, branch) in &self.branches {
113            let output = branch.execute(state).await.map_err(|e| {
114                GraphError::Terminal(crate::error::TerminalError::NodeExecutionFailed {
115                    node: format!("{}/{}", self.display_name(), name),
116                    source: e.into(),
117                })
118            })?;
119            all_deltas.extend(output.deltas);
120        }
121
122        Ok(NodeOutput {
123            deltas: all_deltas,
124            next: NextStep::GoToNext,
125            metadata: None,
126        })
127    }
128
129    fn display_name(&self) -> String {
130        self.label.clone().unwrap_or_else(|| "parallel".to_string())
131    }
132}
133
134/// ParallelNode 构建器。
135pub struct ParallelNodeBuilder {
136    label: Option<String>,
137    branches: Vec<(String, Arc<dyn FlowNode>)>,
138    error_strategy: ParallelErrorStrategy,
139}
140
141impl ParallelNodeBuilder {
142    fn new() -> Self {
143        Self {
144            label: None,
145            branches: Vec::new(),
146            error_strategy: ParallelErrorStrategy::default(),
147        }
148    }
149
150    /// 设置调试标签。
151    pub fn label(mut self, label: impl Into<String>) -> Self {
152        self.label = Some(label.into());
153        self
154    }
155
156    /// 添加并行分支。
157    ///
158    /// - `name` — 分支名称(用于调试和事件标识)
159    /// - `node` — 分支执行的节点
160    pub fn branch(mut self, name: impl Into<String>, node: Arc<dyn FlowNode>) -> Self {
161        self.branches.push((name.into(), node));
162        self
163    }
164
165    /// 设置错误处理策略。
166    pub fn error_strategy(mut self, strategy: ParallelErrorStrategy) -> Self {
167        self.error_strategy = strategy;
168        self
169    }
170
171    /// 构建 ParallelNode。
172    ///
173    /// # Panics
174    ///
175    /// 如果没有添加任何分支,则 panic。
176    pub fn build(self) -> ParallelNode {
177        if self.branches.is_empty() {
178            panic!("ParallelNode must have at least one branch");
179        }
180        ParallelNode {
181            label: self.label,
182            branches: self.branches,
183            error_strategy: self.error_strategy,
184        }
185    }
186}
187
188impl std::fmt::Debug for ParallelNode {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        f.debug_struct("ParallelNode")
191            .field("label", &self.label)
192            .field(
193                "branches",
194                &self
195                    .branches
196                    .iter()
197                    .map(|(n, _)| n.as_str())
198                    .collect::<Vec<_>>(),
199            )
200            .field("error_strategy", &self.error_strategy)
201            .finish()
202    }
203}