Skip to main content

lellm_graph/
node.rs

1//! 节点核心类型与模块。
2//!
3//! - `FlowNode` trait — trait-based 节点,Graph 不知道具体节点类型
4//! - `NextStep` 枚举,`StreamNodeResult` 枚举
5//! - `NodeKind` 节点类型枚举(Task, Condition, Barrier)
6//! - `TaskNode`, `ConditionNode`
7//!
8//! AgentNode → AgentFlowNode(由 lellm-agent 提供,实现 FlowNode trait)
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13
14use crate::delta::StateDelta;
15use crate::error::{GraphError, ObservedError};
16use crate::event::{BarrierId, GraphEvent};
17use crate::ids::SpanId;
18use crate::state::State;
19
20// ─── 子模块重新导出 ────────────────────────────────────────────
21
22pub use crate::barrier_node::{BarrierDefaultAction, BarrierNode};
23pub use crate::parallel_node::{ParallelErrorStrategy, ParallelNode, ParallelNodeBuilder};
24
25// ─── 核心类型 ──────────────────────────────────────────────────
26
27/// 节点执行后的下一步。
28#[derive(Debug, Clone, PartialEq, Eq)]
29pub enum NextStep {
30    /// 跳转到指定节点
31    Goto(String),
32    /// 跳转到下一个节点(按拓扑顺序)
33    GoToNext,
34    /// 结束执行
35    End,
36}
37
38/// 节点执行输出 — 修改意图 + 下一步。
39///
40/// 节点不再直接修改 State(`&mut State`),而是输出 `Vec<StateDelta>`。
41/// Executor 收集所有 Delta 后统一 apply 到 State。
42#[derive(Debug)]
43pub struct NodeOutput {
44    /// 状态增量(节点对 State 的修改意图)
45    pub deltas: Vec<StateDelta>,
46    /// 下一步路由
47    pub next: NextStep,
48    /// 节点元数据(可选 — 用于 Adaptive Checkpoint 等)
49    pub metadata: Option<NodeMetadata>,
50}
51
52/// 节点执行元数据 — 提供给 Executor 的额外信息。
53#[derive(Debug, Clone, Default)]
54pub struct NodeMetadata {
55    /// Token 消耗成本(0.0 表示无 LLM 调用)
56    pub token_cost: f64,
57    /// 是否有外部副作用(如部署、发送消息)
58    pub has_side_effects: bool,
59}
60
61impl NodeOutput {
62    /// 创建无 Delta 的输出。
63    pub fn new(next: NextStep) -> Self {
64        Self {
65            deltas: Vec::new(),
66            next,
67            metadata: None,
68        }
69    }
70
71    /// 追加一个 Delta。
72    pub fn with_delta(mut self, delta: StateDelta) -> Self {
73        self.deltas.push(delta);
74        self
75    }
76
77    /// 追加多个 Delta。
78    pub fn with_deltas(mut self, deltas: Vec<StateDelta>) -> Self {
79        self.deltas.extend(deltas);
80        self
81    }
82
83    /// 设置节点元数据。
84    pub fn with_metadata(mut self, metadata: NodeMetadata) -> Self {
85        self.metadata = Some(metadata);
86        self
87    }
88
89    /// 设置 token 成本。
90    pub fn with_token_cost(mut self, cost: f64) -> Self {
91        self.metadata
92            .get_or_insert_with(Default::default)
93            .token_cost = cost;
94        self
95    }
96
97    /// 标记有副作用。
98    pub fn with_side_effects(mut self) -> Self {
99        self.metadata
100            .get_or_insert_with(Default::default)
101            .has_side_effects = true;
102        self
103    }
104}
105
106/// 节点流式执行结果。
107#[derive(Debug)]
108pub enum StreamNodeResult {
109    /// 节点正常完成(统一 Done + Observed)
110    Continue {
111        /// 状态增量
112        deltas: Vec<StateDelta>,
113        /// 下一步
114        next: NextStep,
115        /// 执行实例 ID
116        span_id: SpanId,
117        /// 可选的观测错误(不影响 control flow)
118        observed: Option<ObservedError>,
119        /// 节点元数据(可选 — 用于 Adaptive Checkpoint 等)
120        metadata: Option<NodeMetadata>,
121    },
122    /// Barrier 暂停,等待外部决策
123    Pause {
124        /// 状态增量(Barrier 进入等待前的修改)
125        deltas: Vec<StateDelta>,
126        /// Barrier 审批请求 ID
127        barrier_id: BarrierId,
128        /// 节点名称
129        node_name: String,
130        /// 执行实例 ID
131        span_id: SpanId,
132        /// 超时时间(None = 无限等待)
133        timeout: Option<std::time::Duration>,
134        /// 超时默认行为
135        default_action: BarrierDefaultAction,
136    },
137    /// 节点主动声明走备用路径(控制流,非错误)。
138    ///
139    /// 与 `GraphError::Terminal` 不同:Fallback 是节点主动声明的降级策略,
140    /// executor 根据 fallback 边路由到备用节点。
141    Fallback {
142        /// 状态增量(Fallback 前的修改)
143        deltas: Vec<StateDelta>,
144        /// 降级原因
145        reason: String,
146        /// 节点名称
147        node_name: String,
148    },
149}
150
151/// 节点执行 trait — trait-based 设计。
152///
153/// Graph 只知道 `dyn FlowNode`,不知道 `AgentNode`、`ToolNode` 等具体类型。
154/// `AgentFlowNode` 由 `lellm-agent` crate 提供。
155///
156/// **节点不修改 State。** 节点读取 `&State`,输出 `NodeOutput { deltas, next }`。
157/// Executor 收集 Delta 后统一 apply 到 State。
158#[async_trait]
159pub trait FlowNode: Send + Sync {
160    /// 执行节点逻辑(阻塞模式)。
161    ///
162    /// - `state` — 只读访问当前 State
163    /// - 返回 `NodeOutput { deltas, next }` — 修改意图 + 下一步路由
164    async fn execute(&self, state: &State) -> Result<NodeOutput, GraphError>;
165
166    /// 执行节点逻辑(流式模式),将内部事件转发到 channel。
167    ///
168    /// - `state` — 只读访问当前 State
169    /// - `sink` — 事件输出 channel
170    /// - `span_id` — 执行实例 ID(由 executor 生成)
171    ///
172    /// 默认实现直接调用 `execute`,返回 `StreamNodeResult::Continue`。
173    /// BarrierNode 覆写此方法以返回 `StreamNodeResult::Pause`。
174    async fn execute_stream(
175        &self,
176        state: &State,
177        _sink: &tokio::sync::mpsc::Sender<GraphEvent>,
178        span_id: SpanId,
179    ) -> Result<StreamNodeResult, GraphError> {
180        let output = self.execute(state).await?;
181        Ok(StreamNodeResult::Continue {
182            deltas: output.deltas,
183            next: output.next,
184            span_id,
185            observed: None,
186            metadata: output.metadata,
187        })
188    }
189
190    /// 节点元数据提示 — 静态声明节点的执行特征。
191    ///
192    /// 用于 Adaptive Checkpoint 的默认值。
193    /// NodeOutput.metadata 会覆盖此值。
194    ///
195    /// **四层优先级:**
196    /// 1. `NodeOutput.metadata` — 运行时实际值(最高优先级)
197    /// 2. `metadata_hint()` — 节点静态声明
198    /// 3. `NodeKind` 推断 — Executor 根据类型推断
199    /// 4. `NodeMetadata::default()` — 兜底值
200    fn metadata_hint(&self) -> NodeMetadata {
201        NodeMetadata::default()
202    }
203}
204
205/// 节点类型枚举。
206///
207/// 只包含 Graph 内置节点类型。Agent/LLM/Tool 节点由外部 crate 提供。
208///
209/// 注意:External 使用 Arc 以支持 Clone(Graph 需要 Clone 来构建)。
210#[derive(Clone)]
211pub enum NodeKind {
212    /// 自定义逻辑
213    Task(TaskNode),
214    /// 条件分支
215    Condition(ConditionNode),
216    /// Human-in-the-loop 审批屏障(仅流式模式)
217    Barrier(BarrierNode),
218    /// 并行执行多个分支,合并 StateDelta
219    Parallel(ParallelNode),
220    /// 外部节点(由 lellm-agent 等 crate 提供)
221    ///
222    /// 使用 `Arc<dyn FlowNode>` 让 Graph 不知道具体节点类型,同时支持 Clone。
223    External(std::sync::Arc<dyn FlowNode>),
224}
225
226// ─── TaskNode ────────────────────────────────────────────────
227
228/// Task 节点回调类型别名。
229///
230/// 闭包接收只读 `&State`,返回 `Vec<StateDelta>` 作为修改意图。
231/// Arc 包装以支持 Clone。
232pub type TaskFn = Arc<dyn Fn(&State) -> Result<Vec<StateDelta>, GraphError> + Send + Sync>;
233
234/// 条件分支回调类型别名。
235/// Arc 包装以支持 Clone。
236pub type BranchCondition = Arc<dyn Fn(&State) -> bool + Send + Sync>;
237
238/// 自定义逻辑节点。
239#[derive(Clone)]
240pub struct TaskNode {
241    pub name: String,
242    pub func: TaskFn,
243}
244
245impl TaskNode {
246    pub fn new(
247        name: impl Into<String>,
248        func: impl Fn(&State) -> Result<Vec<StateDelta>, GraphError> + Send + Sync + 'static,
249    ) -> Self {
250        Self {
251            name: name.into(),
252            func: Arc::new(func),
253        }
254    }
255}
256
257#[async_trait]
258impl FlowNode for TaskNode {
259    async fn execute(&self, state: &State) -> Result<NodeOutput, GraphError> {
260        let deltas = (self.func)(state)?;
261        Ok(NodeOutput {
262            deltas,
263            next: NextStep::GoToNext,
264            metadata: None,
265        })
266    }
267
268    fn metadata_hint(&self) -> NodeMetadata {
269        // TaskNode 默认轻量级(纯 CPU 计算)
270        NodeMetadata {
271            token_cost: 0.0,
272            has_side_effects: false,
273        }
274    }
275}
276
277// ─── ConditionNode ───────────────────────────────────────────
278
279/// 条件分支节点。
280///
281/// 按声明顺序求值分支条件,返回第一个匹配分支的 `NextStep::Goto(target)`。
282/// 无匹配时返回 `NextStep::GoToNext`,由 Graph 层的 `edge_fallback` 处理兜底路由。
283#[derive(Clone)]
284pub struct ConditionNode {
285    pub name: String,
286    pub branches: Vec<(String, BranchCondition)>,
287}
288
289impl ConditionNode {
290    pub fn builder(name: impl Into<String>) -> ConditionNodeBuilder {
291        ConditionNodeBuilder {
292            name: name.into(),
293            branches: Vec::new(),
294        }
295    }
296}
297
298/// ConditionNode 构建器。
299pub struct ConditionNodeBuilder {
300    name: String,
301    branches: Vec<(String, BranchCondition)>,
302}
303
304impl ConditionNodeBuilder {
305    pub fn branch(
306        mut self,
307        target: impl Into<String>,
308        condition: impl Fn(&State) -> bool + Send + Sync + 'static,
309    ) -> Self {
310        self.branches.push((target.into(), Arc::new(condition)));
311        self
312    }
313
314    pub fn build(self) -> ConditionNode {
315        ConditionNode {
316            name: self.name,
317            branches: self.branches,
318        }
319    }
320}
321
322#[async_trait]
323impl FlowNode for ConditionNode {
324    async fn execute(&self, state: &State) -> Result<NodeOutput, GraphError> {
325        for (target, condition) in &self.branches {
326            if condition(state) {
327                return Ok(NodeOutput::new(NextStep::Goto(target.clone())));
328            }
329        }
330        // 无匹配 → GoToNext,由 Graph 层 edge_fallback 处理兜底
331        Ok(NodeOutput::new(NextStep::GoToNext))
332    }
333
334    fn metadata_hint(&self) -> NodeMetadata {
335        // ConditionNode 是纯逻辑判断,轻量级
336        NodeMetadata {
337            token_cost: 0.0,
338            has_side_effects: false,
339        }
340    }
341}
342
343// ─── NodeKind FlowNode impl ──────────────────────────────────
344
345#[async_trait]
346impl FlowNode for NodeKind {
347    async fn execute(&self, state: &State) -> Result<NodeOutput, GraphError> {
348        match self {
349            Self::Task(n) => n.execute(state).await,
350            Self::Condition(n) => n.execute(state).await,
351            Self::Barrier(n) => n.execute(state).await,
352            Self::Parallel(n) => n.execute_sequential(state).await,
353            Self::External(n) => n.execute(state).await,
354        }
355    }
356
357    async fn execute_stream(
358        &self,
359        state: &State,
360        sink: &tokio::sync::mpsc::Sender<GraphEvent>,
361        span_id: SpanId,
362    ) -> Result<StreamNodeResult, GraphError> {
363        match self {
364            Self::Task(n) => n.execute_stream(state, sink, span_id).await,
365            Self::Condition(n) => n.execute_stream(state, sink, span_id).await,
366            Self::Barrier(n) => n.execute_stream(state, sink, span_id).await,
367            Self::Parallel(_) => {
368                // ⚠️ Parallel 节点应由 Executor::handle_parallel() 特殊处理。
369                // 此处提供串行 fallback,确保直接调用 execute_stream 也能工作。
370                let output = self.execute(state).await?;
371                Ok(StreamNodeResult::Continue {
372                    deltas: output.deltas,
373                    next: output.next,
374                    span_id,
375                    observed: None,
376                    metadata: output.metadata,
377                })
378            }
379            Self::External(n) => n.execute_stream(state, sink, span_id).await,
380        }
381    }
382}
383
384// ─── Backward Compatibility Alias ─────────────────────────────
385
386/// 向后兼容别名 — `GraphNode` → `FlowNode`。
387///
388/// v0.2 代码使用 `GraphNode`,v0.3 统一为 `FlowNode`。
389pub type GraphNode = dyn FlowNode;