Skip to main content

lellm_graph/
tool_node.rs

1//! 工具执行节点。
2
3use async_trait::async_trait;
4
5use crate::error::{GraphError, TerminalError};
6use crate::node::GraphNode;
7use crate::node::NextStep;
8use crate::state::State;
9
10/// 工具执行节点。
11///
12/// 读取 State 中最后一条 Assistant 消息的 `tool_calls`,
13/// 执行所有工具调用,将 `ToolResult` 消息追加到消息列表。
14///
15/// ⚠️ **警告:** 此节点是 `LLMNode` 的配套组件,用于手动构建 ReAct 循环。
16/// 与 [`AgentNode`](crate::AgentNode) 不同,**不提供** `ParallelSafety` 并发执行、
17/// `RetryPolicy` 自动重试、`FallbackStrategy` 容错等保护。
18///
19/// 除非你有明确理由需要手动控制每轮 LLM 调用,否则请使用 [`AgentNode`](crate::AgentNode)。
20pub struct ToolNode {
21    pub name: String,
22    executor: lellm_agent::ToolExecutor,
23    messages_key: String,
24}
25
26impl ToolNode {
27    /// 创建包含所有注册工具的 ToolNode。
28    pub fn all(executor: lellm_agent::ToolExecutor) -> Self {
29        Self {
30            name: "tools".into(),
31            executor,
32            messages_key: "messages".into(),
33        }
34    }
35
36    /// 创建指定名称的 ToolNode。
37    pub fn new(name: impl Into<String>, executor: lellm_agent::ToolExecutor) -> Self {
38        Self {
39            name: name.into(),
40            executor,
41            messages_key: "messages".into(),
42        }
43    }
44
45    /// 设置 State 中消息的 key(默认 "messages")。
46    pub fn with_messages_key(mut self, key: impl Into<String>) -> Self {
47        self.messages_key = key.into();
48        self
49    }
50}
51
52#[async_trait]
53impl GraphNode for ToolNode {
54    async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
55        let messages = state
56            .get(&self.messages_key)
57            .and_then(|v| serde_json::from_value::<Vec<lellm_core::Message>>(v.clone()).ok())
58            .unwrap_or_default();
59
60        if messages.is_empty() {
61            return Ok(NextStep::GoToNext);
62        }
63
64        // 获取最后一条消息的 tool_calls
65        let last_msg = messages.last().ok_or(GraphError::Terminal(TerminalError::StateError(
66            "no messages to extract tool_calls from".into(),
67        )))?;
68
69        let tool_calls = match last_msg {
70            lellm_core::Message::Assistant { content } => content
71                .iter()
72                .filter_map(|b| match b {
73                    lellm_core::ContentBlock::ToolCall(tc) => Some(tc.clone()),
74                    _ => None,
75                })
76                .collect::<Vec<_>>(),
77            _ => Vec::new(),
78        };
79
80        if tool_calls.is_empty() {
81            return Ok(NextStep::GoToNext);
82        }
83
84        // 执行所有工具调用
85        let mut result_messages = messages;
86        let snapshot = self.executor.snapshot().await;
87
88        for tc in &tool_calls {
89            let tool_result: lellm_agent::ToolResult =
90                self.executor.execute_with_snapshot(tc, &snapshot).await;
91
92            let tool_result_msg = lellm_core::Message::ToolResult {
93                tool_call_id: tc.id.clone(),
94                is_error: tool_result.is_err(),
95                content: lellm_core::text_block(match &tool_result {
96                    Ok(v) => v.to_string(),
97                    Err(e) => e.to_string(),
98                }),
99            };
100            result_messages.push(tool_result_msg);
101        }
102
103        state.insert(
104            self.messages_key.clone(),
105            serde_json::to_value(&result_messages).map_err(|e| {
106                GraphError::Terminal(TerminalError::StateError(format!("failed to serialize messages: {e}")))
107            })?,
108        );
109
110        Ok(NextStep::GoToNext)
111    }
112}