Skip to main content

lellm_graph/
llm_node.rs

1//! LLM 相关节点 — AgentNode(完整 ReAct 循环)与 LLMNode(单次调用)。
2
3use async_trait::async_trait;
4use tokio::sync::mpsc;
5
6use crate::error::{GraphError, TerminalError};
7use crate::event::{GraphEvent, NodeEvent, SpanId};
8use crate::node::{GraphNode, NextStep, StreamNodeResult};
9use crate::state::State;
10
11// ─── AgentNode ───────────────────────────────────────────────
12
13/// Agent 节点(包装 ToolUseLoop)。
14///
15/// 执行后将以下字段写回 State(默认 key 可通过 builder 自定义):
16/// - `{prefix}.messages` — 完整对话历史(含工具调用与结果)
17/// - `{prefix}.output` — 最终回复纯文本
18/// - `{prefix}.iterations` — LLM 调用轮次
19/// - `{prefix}.tool_calls` — 工具调用总数
20/// - `{prefix}.stop_reason` — 停止原因("Complete" / "MaxIterations" / …)
21pub struct AgentNode {
22    pub name: String,
23    pub agent: lellm_agent::ToolUseLoop,
24    /// State 中的 key 前缀,默认 "agent"
25    pub prefix: String,
26    /// 是否写回完整 messages(默认 true)
27    pub write_messages: bool,
28    /// 是否写回执行统计(默认 true)
29    pub write_stats: bool,
30}
31
32impl AgentNode {
33    pub fn new(name: impl Into<String>, agent: lellm_agent::ToolUseLoop) -> Self {
34        Self {
35            name: name.into(),
36            agent,
37            prefix: "agent".into(),
38            write_messages: true,
39            write_stats: true,
40        }
41    }
42
43    /// 设置 State key 前缀(默认 "agent")。
44    ///
45    /// 写入的 key 为:`{prefix}.messages`、`{prefix}.output` 等。
46    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
47        self.prefix = prefix.into();
48        self
49    }
50
51    /// 控制是否将完整对话历史写回 State(默认 true)。
52    pub fn with_write_messages(mut self, enabled: bool) -> Self {
53        self.write_messages = enabled;
54        self
55    }
56
57    /// 控制是否写入 iterations / tool_calls / stop_reason(默认 true)。
58    pub fn with_write_stats(mut self, enabled: bool) -> Self {
59        self.write_stats = enabled;
60        self
61    }
62}
63
64/// 将 StopReason 序列化为简短字符串。
65fn stop_reason_str(reason: &lellm_agent::StopReason) -> &'static str {
66    match reason {
67        lellm_agent::StopReason::Complete => "Complete",
68        lellm_agent::StopReason::MaxIterationsReached => "MaxIterations",
69        lellm_agent::StopReason::Cancelled => "Cancelled",
70        lellm_agent::StopReason::OutputBudgetExceeded => "OutputBudget",
71        lellm_agent::StopReason::ReasoningBudgetExceeded => "ReasoningBudget",
72    }
73}
74
75/// 从 ToolUseResult 写入 State 的公共逻辑。
76fn write_agent_result(node: &AgentNode, result: &lellm_agent::ToolUseResult, state: &mut State) {
77    // 提取纯文本输出
78    let text: String = result
79        .response
80        .content
81        .iter()
82        .filter_map(|b| match b {
83            lellm_core::ContentBlock::Text(t) => Some(t.text.as_str()),
84            _ => None,
85        })
86        .collect::<Vec<_>>()
87        .join("");
88
89    if !text.is_empty() {
90        state.insert(
91            format!("{}.output", node.prefix),
92            serde_json::Value::String(text),
93        );
94    }
95
96    // 写回完整对话历史
97    if node.write_messages {
98        state.insert(
99            format!("{}.messages", node.prefix),
100            serde_json::to_value(&result.messages).unwrap_or(serde_json::Value::Null),
101        );
102    }
103
104    // 写回执行统计
105    if node.write_stats {
106        state.insert(
107            format!("{}.iterations", node.prefix),
108            serde_json::json!(result.iterations),
109        );
110        state.insert(
111            format!("{}.tool_calls", node.prefix),
112            serde_json::json!(result.tool_calls_executed),
113        );
114        state.insert(
115            format!("{}.stop_reason", node.prefix),
116            serde_json::json!(stop_reason_str(&result.stop_reason)),
117        );
118    }
119}
120
121/// 从 State 读取输入消息。
122fn read_messages(state: &State, prefix: &str) -> Vec<lellm_core::Message> {
123    let input_key = format!("{}.messages", prefix);
124    let messages = state
125        .get(&input_key)
126        .and_then(|v| serde_json::from_value::<Vec<lellm_core::Message>>(v.clone()).ok())
127        .unwrap_or_default();
128
129    // 兼容旧 key "messages"
130    if messages.is_empty() {
131        state
132            .get("messages")
133            .and_then(|v| serde_json::from_value::<Vec<lellm_core::Message>>(v.clone()).ok())
134            .unwrap_or_default()
135    } else {
136        messages
137    }
138}
139
140#[async_trait]
141impl GraphNode for AgentNode {
142    async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
143        let messages = read_messages(state, &self.prefix);
144
145        let result =
146            self.agent
147                .execute(messages)
148                .await
149                .map_err(|e| GraphError::Terminal(TerminalError::NodeExecutionFailed {
150                    node: self.name.clone(),
151                    source: Box::new(e),
152                }))?;
153
154        write_agent_result(self, &result, state);
155        Ok(NextStep::GoToNext)
156    }
157
158    /// 流式执行 — 使用 ToolUseLoop::execute_stream,转发 AgentEvent。
159    async fn execute_stream(
160        &self,
161        state: &mut State,
162        sink: &mpsc::Sender<GraphEvent>,
163        span_id: SpanId,
164    ) -> Result<StreamNodeResult, GraphError> {
165        let messages = read_messages(state, &self.prefix);
166        let node_name = self.name.clone();
167
168        // 使用 ToolUseLoop 的流式执行
169        let mut stream = self.agent.execute_stream(messages);
170
171        /// 从 AgentEvent 中提取终态信息(避免 move 问题)。
172        struct ExtractedResult {
173            write_result: Option<lellm_agent::ToolUseResult>,
174            error_msg: Option<String>,
175        }
176
177        // 转发 Agent 事件,等待 LoopEnd 或 LoopError
178        while let Some(event) = stream.recv().await {
179            let extracted = match &event {
180                lellm_agent::AgentEvent::LoopEnd { result } => ExtractedResult {
181                    write_result: Some(result.clone()),
182                    error_msg: None,
183                },
184                lellm_agent::AgentEvent::LoopError { error, .. } => ExtractedResult {
185                    write_result: None,
186                    error_msg: Some(error.to_string()),
187                },
188                _ => ExtractedResult {
189                    write_result: None,
190                    error_msg: None,
191                },
192            };
193
194            // 转发到 Graph 层(通过 NodeEvent 中间层)
195            let _ = sink
196                .send(GraphEvent::Node {
197                    span_id,
198                    node_name: node_name.clone(),
199                    event: NodeEvent::Agent(event),
200                })
201                .await;
202
203            // 处理终态
204            if let Some(result) = extracted.write_result {
205                write_agent_result(self, &result, state);
206                return Ok(StreamNodeResult::Done {
207                    next: NextStep::GoToNext,
208                    span_id,
209                });
210            }
211            if let Some(err_msg) = extracted.error_msg {
212                return Err(GraphError::Terminal(TerminalError::NodeExecutionFailed {
213                    node: self.name.clone(),
214                    source: err_msg.into(),
215                }));
216            }
217        }
218
219        Err(GraphError::Terminal(TerminalError::NodeExecutionFailed {
220            node: self.name.clone(),
221            source: "agent stream closed without terminal event".into(),
222        }))
223    }
224}
225
226// ─── LLMNode (P3: 细粒度手动模式) ──────────────────────────────
227
228/// 单次 LLM 调用节点。
229///
230/// 与 `AgentNode`(完整 ReAct 循环)不同,`LLMNode` 仅执行一次 LLM 调用,
231/// 将响应写入 State。配合 `ToolNode` + `ConditionNode`,可手动构建 ReAct 循环。
232///
233/// ⚠️ **警告:** 使用 `LLMNode` + `ToolNode` 手动构建循环时,你将**失去**以下保护:
234/// - `ParallelSafety` 并发工具执行
235/// - `RetryPolicy` 自动重试
236/// - `FallbackStrategy` 容错路由
237/// - 输出/推理预算保险丝
238/// - `Context Compaction` 上下文压缩
239///
240/// **适用场景(窄但真实):**
241/// 1. 自定义 Agent Loop — 实现非 ReAct 的交互模式(如 multi-agent debate)
242/// 2. 调试/教学 — 逐步观察每轮 LLM 输入输出
243/// 3. 混合编排 — 多个 AgentNode 之间插入自定义处理逻辑
244///
245/// 除非你有明确理由,否则请使用 [`AgentNode`]。
246///
247/// ```rust,ignore
248/// // 手动 ReAct 循环:
249/// GraphBuilder::new("react")
250///     .start("llm")
251///     .node("llm", NodeKind::Llm(LLMNode::new("llm", model)))
252///     .node("tools", NodeKind::Tool(ToolNode::all(tool_executor)))
253///     .node("route", NodeKind::Condition(
254///         ConditionNode::builder("route")
255///             .branch("tools", |s| has_tool_calls(s))
256///             .branch("end", |_| true)
257///             .build()
258///     ))
259///     .edge("llm", "route")
260///     .edge("tools", "llm")
261///     .end("end")
262///     .build();
263/// ```
264pub struct LLMNode {
265    pub name: String,
266    model: lellm_agent::ResolvedModel,
267    system_prompt: Option<String>,
268    messages_key: String,
269}
270
271impl LLMNode {
272    pub fn new(name: impl Into<String>, model: lellm_agent::ResolvedModel) -> Self {
273        Self {
274            name: name.into(),
275            model,
276            system_prompt: None,
277            messages_key: "messages".into(),
278        }
279    }
280
281    /// 设置系统提示。
282    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
283        self.system_prompt = Some(prompt.into());
284        self
285    }
286
287    /// 设置 State 中消息的 key(默认 "messages")。
288    pub fn with_messages_key(mut self, key: impl Into<String>) -> Self {
289        self.messages_key = key.into();
290        self
291    }
292}
293
294#[async_trait]
295impl GraphNode for LLMNode {
296    async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
297        // 读取消息
298        let mut messages = state
299            .get(&self.messages_key)
300            .and_then(|v| serde_json::from_value::<Vec<lellm_core::Message>>(v.clone()).ok())
301            .unwrap_or_default();
302
303        // 注入系统提示
304        if let Some(ref sys) = self.system_prompt {
305            // 移除已有 system message
306            messages.retain(|m| !matches!(m, lellm_core::Message::System { .. }));
307            messages.insert(
308                0,
309                lellm_core::Message::System {
310                    content: lellm_core::text_block(sys.clone()),
311                },
312            );
313        }
314
315        // 构建请求
316        let request = lellm_core::ChatRequest {
317            model: self.model.model.clone(),
318            messages: messages.clone(),
319            ..Default::default()
320        };
321
322        // 调用 LLM
323        let response = self.model.provider.call(&request).await.map_err(|e| {
324            GraphError::Terminal(TerminalError::NodeExecutionFailed {
325                node: self.name.clone(),
326                source: Box::new(e),
327            })
328        })?;
329
330        // 将响应追加到消息列表
331        let assistant_msg = lellm_core::Message::Assistant {
332            content: response.content,
333        };
334        messages.push(assistant_msg);
335        state.insert(
336            self.messages_key.clone(),
337            serde_json::to_value(&messages).map_err(|e| {
338                GraphError::Terminal(TerminalError::StateError(format!("failed to serialize messages: {e}")))
339            })?,
340        );
341
342        Ok(NextStep::GoToNext)
343    }
344}