1use async_trait::async_trait;
4
5use crate::error::{GraphError, TerminalError};
6use crate::node::GraphNode;
7use crate::node::NextStep;
8use crate::state::State;
9
10pub struct ToolNode {
21 pub name: String,
22 executor: lellm_agent::ToolExecutor,
23 messages_key: String,
24}
25
26impl ToolNode {
27 pub fn all(executor: lellm_agent::ToolExecutor) -> Self {
29 Self {
30 name: "tools".into(),
31 executor,
32 messages_key: "messages".into(),
33 }
34 }
35
36 pub fn new(name: impl Into<String>, executor: lellm_agent::ToolExecutor) -> Self {
38 Self {
39 name: name.into(),
40 executor,
41 messages_key: "messages".into(),
42 }
43 }
44
45 pub fn with_messages_key(mut self, key: impl Into<String>) -> Self {
47 self.messages_key = key.into();
48 self
49 }
50}
51
52#[async_trait]
53impl GraphNode for ToolNode {
54 async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
55 let messages = state
56 .get(&self.messages_key)
57 .and_then(|v| serde_json::from_value::<Vec<lellm_core::Message>>(v.clone()).ok())
58 .unwrap_or_default();
59
60 if messages.is_empty() {
61 return Ok(NextStep::GoToNext);
62 }
63
64 let last_msg = messages.last().ok_or(GraphError::Terminal(TerminalError::StateError(
66 "no messages to extract tool_calls from".into(),
67 )))?;
68
69 let tool_calls = match last_msg {
70 lellm_core::Message::Assistant { content } => content
71 .iter()
72 .filter_map(|b| match b {
73 lellm_core::ContentBlock::ToolCall(tc) => Some(tc.clone()),
74 _ => None,
75 })
76 .collect::<Vec<_>>(),
77 _ => Vec::new(),
78 };
79
80 if tool_calls.is_empty() {
81 return Ok(NextStep::GoToNext);
82 }
83
84 let mut result_messages = messages;
86 let snapshot = self.executor.snapshot().await;
87
88 for tc in &tool_calls {
89 let tool_result: lellm_agent::ToolResult =
90 self.executor.execute_with_snapshot(tc, &snapshot).await;
91
92 let tool_result_msg = lellm_core::Message::ToolResult {
93 tool_call_id: tc.id.clone(),
94 is_error: tool_result.is_err(),
95 content: lellm_core::text_block(match &tool_result {
96 Ok(v) => v.to_string(),
97 Err(e) => e.to_string(),
98 }),
99 };
100 result_messages.push(tool_result_msg);
101 }
102
103 state.insert(
104 self.messages_key.clone(),
105 serde_json::to_value(&result_messages).map_err(|e| {
106 GraphError::Terminal(TerminalError::StateError(format!("failed to serialize messages: {e}")))
107 })?,
108 );
109
110 Ok(NextStep::GoToNext)
111 }
112}