Skip to main content

lellm_graph/
node.rs

1//! 节点核心类型与模块。
2//!
3//! - `GraphNode` trait, `NextStep` 枚举
4//! - `NodeKind` 节点类型枚举
5//! - `TaskNode`, `ConditionNode`, `LoopNode`, `SubGraph`, `BarrierNode`
6//! - 重新导出 `llm_node`, `tool_node`, `barrier_node` 模块
7
8use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::error::{GraphError, ObservedError, TerminalError};
13use crate::event::{BarrierId, GraphEvent, SpanId};
14use crate::graph::Edge;
15use crate::state::State;
16
17// ─── 子模块重新导出 ────────────────────────────────────────────
18
19pub use crate::barrier_node::{BarrierDefaultAction, BarrierNode};
20pub use crate::llm_node::{AgentNode, LLMNode};
21pub use crate::tool_node::ToolNode;
22
23// ─── 核心类型 ──────────────────────────────────────────────────
24
25/// 节点执行后的下一步。
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub enum NextStep {
28    /// 跳转到指定节点
29    Goto(String),
30    /// 跳转到下一个节点(按拓扑顺序)
31    GoToNext,
32    /// 结束执行
33    End,
34}
35
36/// 节点流式执行结果。
37#[derive(Debug)]
38pub enum StreamNodeResult {
39    /// 节点正常完成
40    Done {
41        /// 下一步
42        next: NextStep,
43        /// 执行实例 ID(由调用方传入)
44        span_id: SpanId,
45    },
46    /// Barrier 暂停,等待外部决策
47    BarrierPaused {
48        /// Barrier 审批请求 ID(由 executor 生成)
49        barrier_id: BarrierId,
50        /// 节点名称
51        node_name: String,
52        /// 执行实例 ID
53        span_id: SpanId,
54        /// 超时时间(None = 无限等待)
55        timeout: Option<std::time::Duration>,
56        /// 超时默认行为
57        default_action: crate::barrier_node::BarrierDefaultAction,
58    },
59    /// 观测错误 — 仅事件,不影响 control flow。
60    ///
61    /// 节点通过此变体声明式地报告非致命异常,executor 负责:
62    /// 1. 发送 `GraphEvent::ObservedError` 事件
63    /// 2. 按 `next` 继续推进控制流
64    Observed {
65        /// 观测错误
66        error: ObservedError,
67        /// 下一步
68        next: NextStep,
69        /// 执行实例 ID
70        span_id: SpanId,
71    },
72}
73
74/// 节点执行 trait。
75#[async_trait]
76pub trait GraphNode: Send + Sync {
77    /// 执行节点逻辑(阻塞模式)。
78    async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError>;
79
80    /// 执行节点逻辑(流式模式),将内部事件转发到 channel。
81    ///
82    /// - `sink` — 事件输出 channel
83    /// - `span_id` — 执行实例 ID(由 executor 生成)
84    ///
85    /// 默认实现直接调用 `execute`,返回 `StreamNodeResult::Done`。
86    /// AgentNode 覆写此方法以转发 AgentEvent。
87    /// BarrierNode 覆写此方法以返回 `StreamNodeResult::BarrierPaused`。
88    async fn execute_stream(
89        &self,
90        state: &mut State,
91        _sink: &tokio::sync::mpsc::Sender<GraphEvent>,
92        span_id: SpanId,
93    ) -> Result<StreamNodeResult, GraphError> {
94        let next = self.execute(state).await?;
95        Ok(StreamNodeResult::Done { next, span_id })
96    }
97}
98
99/// 节点类型枚举。
100pub enum NodeKind {
101    /// 自定义逻辑
102    Task(TaskNode),
103    /// Agent(包装 ToolUseLoop)
104    Agent(Box<AgentNode>),
105    /// 工具调用
106    Tool(ToolNode),
107    /// 条件分支
108    Condition(ConditionNode),
109    /// 循环容器
110    Loop(Box<LoopNode>),
111    /// Human-in-the-loop 审批屏障(仅流式模式)
112    Barrier(BarrierNode),
113}
114
115// ─── TaskNode ────────────────────────────────────────────────
116
117/// Task 节点回调类型别名。
118/// Arc 包装以支持 Clone。
119pub type TaskFn = Arc<dyn Fn(&mut State) -> Result<(), GraphError> + Send + Sync>;
120
121/// 条件分支回调类型别名。
122/// Arc 包装以支持 Clone。
123pub type BranchCondition = Arc<dyn Fn(&State) -> bool + Send + Sync>;
124
125/// 自定义逻辑节点。
126pub struct TaskNode {
127    pub name: String,
128    pub func: TaskFn,
129}
130
131impl TaskNode {
132    pub fn new(
133        name: impl Into<String>,
134        func: impl Fn(&mut State) -> Result<(), GraphError> + Send + Sync + 'static,
135    ) -> Self {
136        Self {
137            name: name.into(),
138            func: Arc::new(func),
139        }
140    }
141}
142
143#[async_trait]
144impl GraphNode for TaskNode {
145    async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
146        (self.func)(state)?;
147        Ok(NextStep::GoToNext)
148    }
149}
150
151// ─── ConditionNode ───────────────────────────────────────────
152
153/// 条件分支节点。
154pub struct ConditionNode {
155    pub name: String,
156    pub branches: Vec<(String, BranchCondition)>,
157    /// 兜底目标 — 当所有 branch 条件均不匹配时,跳转到此节点。
158    /// 未设置时,无匹配则返回 TerminalError。
159    pub otherwise_target: Option<String>,
160}
161
162impl ConditionNode {
163    pub fn builder(name: impl Into<String>) -> ConditionNodeBuilder {
164        ConditionNodeBuilder {
165            name: name.into(),
166            branches: Vec::new(),
167            otherwise_target: None,
168        }
169    }
170}
171
172/// ConditionNode 构建器。
173pub struct ConditionNodeBuilder {
174    name: String,
175    branches: Vec<(String, BranchCondition)>,
176    otherwise_target: Option<String>,
177}
178
179impl ConditionNodeBuilder {
180    pub fn branch(
181        mut self,
182        target: impl Into<String>,
183        condition: impl Fn(&State) -> bool + Send + Sync + 'static,
184    ) -> Self {
185        self.branches.push((target.into(), Arc::new(condition)));
186        self
187    }
188
189    /// 设置兜底目标 — 当所有 branch 条件均不匹配时,跳转到此节点。
190    ///
191    /// 解决"边有 fallback,节点没有"的概念不一致问题。
192    ///
193    /// ```rust,ignore
194    /// ConditionNode::builder("route")
195    ///     .branch("fast_path", |s| s.get("score").map(|v| v.as_u64().unwrap_or(0) >= 80))
196    ///     .branch("slow_path", |s| s.get("score").map(|v| v.as_u64().unwrap_or(0) >= 50))
197    ///     .otherwise("default")  // 兜底
198    ///     .build()
199    /// ```
200    pub fn otherwise(mut self, target: impl Into<String>) -> Self {
201        self.otherwise_target = Some(target.into());
202        self
203    }
204
205    pub fn build(self) -> ConditionNode {
206        ConditionNode {
207            name: self.name,
208            branches: self.branches,
209            otherwise_target: self.otherwise_target,
210        }
211    }
212}
213
214#[async_trait]
215impl GraphNode for ConditionNode {
216    async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
217        for (target, condition) in &self.branches {
218            if condition(state) {
219                return Ok(NextStep::Goto(target.clone()));
220            }
221        }
222        // 有兜底目标 → 直接跳转
223        if let Some(ref target) = self.otherwise_target {
224            return Ok(NextStep::Goto(target.clone()));
225        }
226        Err(GraphError::Terminal(TerminalError::NodeExecutionFailed {
227            node: self.name.clone(),
228            source: "no matching branch and no otherwise target".into(),
229        }))
230    }
231}
232
233// ─── SubGraph ────────────────────────────────────────────────
234
235/// 子图(LoopNode 的执行单元)。
236///
237/// **注意:** SubGraph 内的节点不支持按名跳转(`NextStep::Goto`),
238/// 因为节点没有名字。需要条件回跳请使用外层 Graph 的 `edge_if`。
239#[derive(Default)]
240pub struct SubGraph {
241    pub nodes: Vec<Arc<dyn GraphNode>>,
242    pub edges: Vec<Edge>,
243}
244
245impl SubGraph {
246    pub fn new() -> Self {
247        Self::default()
248    }
249
250    /// 线性执行子图内所有节点,尊重 `NextStep` 语义。
251    ///
252    /// - `GoToNext` — 继续遍历下一个节点
253    /// - `End` — 提前退出子图(后续节点不再执行)
254    /// - `Goto(target)` — 报错(SubGraph 不支持按名跳转)
255    pub async fn execute(&self, state: &mut State) -> Result<(), GraphError> {
256        for node in &self.nodes {
257            match node.execute(state).await? {
258                NextStep::GoToNext => {
259                    // 继续线性遍历
260                }
261                NextStep::End => {
262                    // 提前退出子图
263                    break;
264                }
265                NextStep::Goto(target) => {
266                    return Err(GraphError::Terminal(TerminalError::InvalidGraph(format!(
267                        "SubGraph does not support Goto(\"{}\"). Use Graph::edge_if for conditional jumps.",
268                        target
269                    ))));
270                }
271            }
272        }
273        Ok(())
274    }
275}
276
277// ─── LoopNode ────────────────────────────────────────────────
278
279/// 循环容器 — 可选的高级语法糖。
280///
281/// **推荐使用 `edge_if` 实现简单回跳。** LoopNode 适用于需要独立迭代计数
282/// 和独立熔断保护的封装场景(例如并行子任务中的局部循环)。
283///
284/// ```rust,ignore
285/// // 推荐:直接用有环图 + edge_if(更直观)
286/// GraphBuilder::new("retry")
287///     .edge_if("check", "agent", |s| !s.satisfied)  // 回跳
288///     .edge("check", "output")                       // 通过
289///
290/// // LoopNode:需要独立 max_iterations 时使用
291/// LoopNode::new("loop", SubGraph { ... }, |s| !s.satisfied, max_iterations: 5)
292/// ```
293pub struct LoopNode {
294    pub name: String,
295    pub body: SubGraph,
296    pub continue_condition: Arc<dyn Fn(&State) -> bool + Send + Sync>,
297    pub max_iterations: usize,
298}
299
300impl LoopNode {
301    pub fn new(
302        name: impl Into<String>,
303        body: SubGraph,
304        continue_condition: impl Fn(&State) -> bool + Send + Sync + 'static,
305        max_iterations: usize,
306    ) -> Self {
307        Self {
308            name: name.into(),
309            body,
310            continue_condition: Arc::new(continue_condition),
311            max_iterations,
312        }
313    }
314}
315
316#[async_trait]
317impl GraphNode for LoopNode {
318    async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
319        for i in 0..self.max_iterations {
320            tracing::debug!(
321                loop_name = %self.name,
322                iteration = i + 1,
323                max = self.max_iterations,
324                "executing loop body"
325            );
326
327            self.body.execute(state).await?;
328
329            if !(self.continue_condition)(state) {
330                tracing::debug!(
331                    loop_name = %self.name,
332                    iterations = i + 1,
333                    "loop condition met, exiting"
334                );
335                return Ok(NextStep::GoToNext);
336            }
337        }
338
339        Err(GraphError::Terminal(TerminalError::LoopLimitExceeded {
340            limit: self.max_iterations,
341        }))
342    }
343}
344
345// ─── NodeKind GraphNode impl ─────────────────────────────────
346
347#[async_trait]
348impl GraphNode for NodeKind {
349    async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
350        match self {
351            Self::Task(n) => n.execute(state).await,
352            Self::Agent(n) => n.execute(state).await,
353            Self::Tool(n) => n.execute(state).await,
354            Self::Condition(n) => n.execute(state).await,
355            Self::Loop(n) => n.execute(state).await,
356            Self::Barrier(n) => n.execute(state).await,
357        }
358    }
359
360    async fn execute_stream(
361        &self,
362        state: &mut State,
363        sink: &tokio::sync::mpsc::Sender<GraphEvent>,
364        span_id: SpanId,
365    ) -> Result<StreamNodeResult, GraphError> {
366        match self {
367            Self::Task(n) => n.execute_stream(state, sink, span_id).await,
368            Self::Agent(n) => n.execute_stream(state, sink, span_id).await,
369            Self::Tool(n) => n.execute_stream(state, sink, span_id).await,
370            Self::Condition(n) => n.execute_stream(state, sink, span_id).await,
371            Self::Loop(n) => n.execute_stream(state, sink, span_id).await,
372            Self::Barrier(n) => n.execute_stream(state, sink, span_id).await,
373        }
374    }
375}