synaptic_graph/
tool_node.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use synaptic_core::{Message, RuntimeAwareTool, Store, SynapticError, ToolRuntime};
7use synaptic_middleware::{MiddlewareChain, ToolCallRequest, ToolCaller};
8use synaptic_tools::SerialToolExecutor;
9
10use crate::command::NodeOutput;
11use crate::node::Node;
12use crate::state::MessageState;
13
14struct BaseToolCaller {
16 executor: SerialToolExecutor,
17}
18
19#[async_trait]
20impl ToolCaller for BaseToolCaller {
21 async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
22 self.executor
23 .execute(&request.call.name, request.call.arguments.clone())
24 .await
25 }
26}
27
28pub struct ToolNode {
34 executor: SerialToolExecutor,
35 middleware: Option<Arc<MiddlewareChain>>,
36 store: Option<Arc<dyn Store>>,
38 runtime_tools: HashMap<String, Arc<dyn RuntimeAwareTool>>,
40}
41
42impl ToolNode {
43 pub fn new(executor: SerialToolExecutor) -> Self {
44 Self {
45 executor,
46 middleware: None,
47 store: None,
48 runtime_tools: HashMap::new(),
49 }
50 }
51
52 pub fn with_middleware(executor: SerialToolExecutor, middleware: Arc<MiddlewareChain>) -> Self {
54 Self {
55 executor,
56 middleware: Some(middleware),
57 store: None,
58 runtime_tools: HashMap::new(),
59 }
60 }
61
62 pub fn with_store(mut self, store: Arc<dyn Store>) -> Self {
64 self.store = Some(store);
65 self
66 }
67
68 pub fn with_runtime_tool(mut self, tool: Arc<dyn RuntimeAwareTool>) -> Self {
74 self.runtime_tools.insert(tool.name().to_string(), tool);
75 self
76 }
77}
78
79#[async_trait]
80impl Node<MessageState> for ToolNode {
81 async fn process(
82 &self,
83 mut state: MessageState,
84 ) -> Result<NodeOutput<MessageState>, SynapticError> {
85 let last = state
86 .last_message()
87 .ok_or_else(|| SynapticError::Graph("no messages in state".to_string()))?;
88
89 let tool_calls = last.tool_calls().to_vec();
90 if tool_calls.is_empty() {
91 return Ok(state.into());
92 }
93
94 let state_value = serde_json::to_value(&state).ok();
96
97 for call in &tool_calls {
98 let result = if let Some(rt_tool) = self.runtime_tools.get(&call.name) {
100 let runtime = ToolRuntime {
101 store: self.store.clone(),
102 stream_writer: None,
103 state: state_value.clone(),
104 tool_call_id: call.id.clone(),
105 config: None,
106 };
107 rt_tool
108 .call_with_runtime(call.arguments.clone(), runtime)
109 .await?
110 } else {
111 if let Some(ref chain) = self.middleware {
113 let request = ToolCallRequest { call: call.clone() };
114 let base = BaseToolCaller {
115 executor: self.executor.clone(),
116 };
117 chain.call_tool(request, &base).await?
118 } else {
119 self.executor
120 .execute(&call.name, call.arguments.clone())
121 .await?
122 }
123 };
124 state
125 .messages
126 .push(Message::tool(result.to_string(), &call.id));
127 }
128
129 Ok(state.into())
130 }
131}
132
133pub fn tools_condition(state: &MessageState) -> String {
138 if let Some(last) = state.last_message() {
139 if !last.tool_calls().is_empty() {
140 return "tools".to_string();
141 }
142 }
143 crate::END.to_string()
144}