Skip to main content

lellm_graph/
node.rs

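//! Core node types and modules.
//!
//! - `FlowNode<S>` trait — trait-based nodes; the graph does not know the concrete node types
//! - `NextStep` enum — where execution goes after a node runs
//! - `NodeKind<S, M>` node-type enum (Task, Condition, Barrier, Parallel, External)
//! - `TaskNode<S>`, `ConditionNode<S>`
//! - Legacy (pre-v0.4) `NodeOutput` / `StreamNodeResult`, kept for backward compatibility
//!
//! v0.4+: every node type is generic over `S: WorkflowState`, defaulting to `S = State`
//! (backward compatible).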
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13
14use crate::error::{GraphError, ObservedError};
15use crate::event::BarrierId;
16use crate::ids::SpanId;
17use crate::node_context::NodeContext;
18use crate::state::{State, StateMerge};
19use crate::workflow_state::{MergeStrategy, WorkflowState};
20
21// ─── 子模块重新导出 ────────────────────────────────────────────
22
23pub use crate::barrier_node::{BarrierDefaultAction, BarrierNode};
24pub use crate::parallel_node::{
25    ParallelErrorStrategy, ParallelNode, ParallelNodeBuilder, ParallelNodeBuilderWithMerge,
26};
27
28// ─── 核心类型 ──────────────────────────────────────────────────
29
/// Routing decision taken once a node has finished executing.
///
/// Carried by the legacy [`NodeOutput`] and [`StreamNodeResult`] types.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NextStep {
    /// Jump to the node with the given name.
    Goto(String),
    /// Continue with the next node in topological order.
    GoToNext,
    /// Stop execution.
    End,
}
40
41/// 节点执行输出 — 修改意图 + 下一步。
42///
43/// @deprecated v0.4+ 使用 NodeContext 替代。保留向后兼容。
44#[derive(Debug)]
45pub struct NodeOutput {
46    /// 状态增量(节点对 State 的修改意图)
47    pub deltas: Vec<crate::delta::StateDelta>,
48    /// 下一步路由
49    pub next: NextStep,
50    /// 节点元数据(可选)
51    pub metadata: Option<crate::node_context::NodeMetadata>,
52}
53
54impl NodeOutput {
55    pub fn new(next: NextStep) -> Self {
56        Self {
57            deltas: Vec::new(),
58            next,
59            metadata: None,
60        }
61    }
62
63    pub fn with_delta(mut self, delta: crate::delta::StateDelta) -> Self {
64        self.deltas.push(delta);
65        self
66    }
67
68    pub fn with_deltas(mut self, deltas: Vec<crate::delta::StateDelta>) -> Self {
69        self.deltas.extend(deltas);
70        self
71    }
72
73    pub fn with_metadata(mut self, metadata: crate::node_context::NodeMetadata) -> Self {
74        self.metadata = Some(metadata);
75        self
76    }
77
78    pub fn with_token_cost(mut self, cost: f64) -> Self {
79        self.metadata
80            .get_or_insert_with(Default::default)
81            .token_cost = cost;
82        self
83    }
84
85    pub fn with_side_effects(mut self) -> Self {
86        self.metadata
87            .get_or_insert_with(Default::default)
88            .has_side_effects = true;
89        self
90    }
91}
92
93/// 节点执行元数据。
94pub use crate::node_context::NodeMetadata;
95
96/// 节点流式执行结果。
97///
98/// @deprecated v0.4+ 使用 NodeContext 替代。保留向后兼容。
99#[derive(Debug)]
100pub enum StreamNodeResult {
101    Continue {
102        deltas: Vec<crate::delta::StateDelta>,
103        next: NextStep,
104        span_id: SpanId,
105        observed: Option<ObservedError>,
106        metadata: Option<NodeMetadata>,
107    },
108    Pause {
109        deltas: Vec<crate::delta::StateDelta>,
110        barrier_id: BarrierId,
111        node_name: String,
112        span_id: SpanId,
113        timeout: Option<std::time::Duration>,
114        default_action: BarrierDefaultAction,
115    },
116    Fallback {
117        deltas: Vec<crate::delta::StateDelta>,
118        reason: String,
119        node_name: String,
120    },
121}
122
123// ─── v04 FlowNode Trait ──────────────────────────────────────
124
125/// v04 节点执行 trait — Context 驱动一切。
126///
127/// 统一原则 — 节点不返回业务数据,只返回 `Result<(), GraphError>`:
128/// - State      → ctx.state() / ctx.state_mut()
129/// - Effects    → ctx.emit_effect()
130/// - Stream     → ctx.emit()
131/// - Metadata   → ctx.set_token_cost()
132/// - Control    → ctx.goto() / ctx.end() / ctx.pause()
133///
134/// # 泛型参数
135///
136/// - `S` — 类型化状态(默认 `State` = HashMap,向后兼容)
137#[async_trait]
138pub trait FlowNode<S: WorkflowState = State>: Send + Sync {
139    /// 执行节点逻辑。
140    async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError>;
141}
142
143/// 节点类型枚举。
144///
145/// # 泛型参数
146///
147/// - `S` — 类型化状态(默认 `State` = HashMap,向后兼容)
148/// - `M` — 并行合并策略(仅 `Parallel` 变体使用,默认 [`StateMerge`])
149pub enum NodeKind<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
150    /// 自定义逻辑
151    Task(TaskNode<S>),
152    /// 条件分支
153    Condition(ConditionNode<S>),
154    /// Human-in-the-loop 审批屏障
155    Barrier(BarrierNode<S>),
156    /// 并行执行多个分支
157    Parallel(ParallelNode<S, M>),
158    /// 外部节点(由 lellm-agent 等 crate 提供)
159    External(Arc<dyn FlowNode<S>>),
160}
161
162impl<S: WorkflowState, M: MergeStrategy<S>> Clone for NodeKind<S, M> {
163    fn clone(&self) -> Self {
164        match self {
165            Self::Task(n) => Self::Task(n.clone()),
166            Self::Condition(n) => Self::Condition(n.clone()),
167            Self::Barrier(n) => Self::Barrier(n.clone()),
168            Self::Parallel(n) => Self::Parallel(n.clone()),
169            Self::External(n) => Self::External(n.clone()),
170        }
171    }
172}
173
174// ─── TaskNode ────────────────────────────────────────────────
175
176/// Task 节点回调类型别名。
177pub type TaskFn<S> = Arc<dyn Fn(&mut NodeContext<'_, S>) -> Result<(), GraphError> + Send + Sync>;
178
179/// 自定义逻辑节点。
180#[derive(Clone)]
181pub struct TaskNode<S: WorkflowState = State> {
182    pub name: String,
183    pub func: TaskFn<S>,
184}
185
186impl<S: WorkflowState> TaskNode<S> {
187    pub fn new(
188        name: impl Into<String>,
189        func: impl Fn(&mut NodeContext<'_, S>) -> Result<(), GraphError> + Send + Sync + 'static,
190    ) -> Self {
191        Self {
192            name: name.into(),
193            func: Arc::new(func),
194        }
195    }
196}
197
198#[async_trait]
199impl<S: WorkflowState> FlowNode<S> for TaskNode<S> {
200    async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
201        (self.func)(ctx)
202    }
203}
204
205// ─── ConditionNode ───────────────────────────────────────────
206
207/// 条件分支回调类型别名。
208pub type BranchCondition<S> = Arc<dyn Fn(&S) -> bool + Send + Sync>;
209
210/// 条件分支节点。
211#[derive(Clone)]
212pub struct ConditionNode<S: WorkflowState = State> {
213    pub name: String,
214    pub branches: Vec<(String, BranchCondition<S>)>,
215}
216
217impl<S: WorkflowState> ConditionNode<S> {
218    pub fn builder(name: impl Into<String>) -> ConditionNodeBuilder<S> {
219        ConditionNodeBuilder {
220            name: name.into(),
221            branches: Vec::new(),
222        }
223    }
224}
225
226/// ConditionNode 构建器。
227pub struct ConditionNodeBuilder<S: WorkflowState = State> {
228    name: String,
229    branches: Vec<(String, BranchCondition<S>)>,
230}
231
232impl<S: WorkflowState> ConditionNodeBuilder<S> {
233    pub fn branch(
234        mut self,
235        target: impl Into<String>,
236        condition: impl Fn(&S) -> bool + Send + Sync + 'static,
237    ) -> Self {
238        self.branches.push((target.into(), Arc::new(condition)));
239        self
240    }
241
242    pub fn build(self) -> ConditionNode<S> {
243        ConditionNode {
244            name: self.name,
245            branches: self.branches,
246        }
247    }
248}
249
250#[async_trait]
251impl<S: WorkflowState> FlowNode<S> for ConditionNode<S> {
252    async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
253        let state = ctx.state();
254        for (target, condition) in &self.branches {
255            if condition(state) {
256                ctx.goto(target);
257                return Ok(());
258            }
259        }
260        Ok(())
261    }
262}
263
264// ─── NodeKind FlowNode impl ──────────────────────────────────
265
266#[async_trait]
267impl<S: WorkflowState, M: MergeStrategy<S>> FlowNode<S> for NodeKind<S, M> {
268    async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
269        match self {
270            Self::Task(n) => n.execute(ctx).await,
271            Self::Condition(n) => n.execute(ctx).await,
272            Self::Barrier(n) => n.execute(ctx).await,
273            Self::Parallel(n) => n.execute(ctx).await,
274            Self::External(n) => n.execute(ctx).await,
275        }
276    }
277}
278
279// ─── Backward Compatibility Alias ─────────────────────────────
280
281/// 向后兼容别名 — `GraphNode` → `FlowNode`。
282pub type GraphNode<S> = dyn FlowNode<S>;