synaptic_graph/
tool_node.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use synaptic_core::{Message, RuntimeAwareTool, Store, SynapticError, ToolContext, ToolRuntime};
7use synaptic_middleware::{MiddlewareChain, ToolCallRequest, ToolCaller};
8use synaptic_tools::SerialToolExecutor;
9
10use crate::command::NodeOutput;
11use crate::node::Node;
12use crate::state::MessageState;
13
14struct BaseToolCaller {
16 executor: SerialToolExecutor,
17 #[expect(dead_code)]
18 tool_context: ToolContext,
19}
20
21#[async_trait]
22impl ToolCaller for BaseToolCaller {
23 async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
24 self.executor
25 .execute(&request.call.name, request.call.arguments.clone())
26 .await
27 }
28}
29
30pub struct ToolNode {
36 executor: SerialToolExecutor,
37 middleware: Option<Arc<MiddlewareChain>>,
38 store: Option<Arc<dyn Store>>,
40 runtime_tools: HashMap<String, Arc<dyn RuntimeAwareTool>>,
42}
43
44impl ToolNode {
45 pub fn new(executor: SerialToolExecutor) -> Self {
46 Self {
47 executor,
48 middleware: None,
49 store: None,
50 runtime_tools: HashMap::new(),
51 }
52 }
53
54 pub fn with_middleware(executor: SerialToolExecutor, middleware: Arc<MiddlewareChain>) -> Self {
56 Self {
57 executor,
58 middleware: Some(middleware),
59 store: None,
60 runtime_tools: HashMap::new(),
61 }
62 }
63
64 pub fn with_store(mut self, store: Arc<dyn Store>) -> Self {
66 self.store = Some(store);
67 self
68 }
69
70 pub fn with_runtime_tool(mut self, tool: Arc<dyn RuntimeAwareTool>) -> Self {
76 self.runtime_tools.insert(tool.name().to_string(), tool);
77 self
78 }
79}
80
81#[async_trait]
82impl Node<MessageState> for ToolNode {
83 async fn process(
84 &self,
85 mut state: MessageState,
86 ) -> Result<NodeOutput<MessageState>, SynapticError> {
87 let last = state
88 .last_message()
89 .ok_or_else(|| SynapticError::Graph("no messages in state".to_string()))?;
90
91 let tool_calls = last.tool_calls().to_vec();
92 if tool_calls.is_empty() {
93 return Ok(state.into());
94 }
95
96 let state_value = serde_json::to_value(&state).ok();
98
99 for call in &tool_calls {
100 let result = if let Some(rt_tool) = self.runtime_tools.get(&call.name) {
102 let runtime = ToolRuntime {
103 store: self.store.clone(),
104 stream_writer: None,
105 state: state_value.clone(),
106 tool_call_id: call.id.clone(),
107 config: None,
108 };
109 rt_tool
110 .call_with_runtime(call.arguments.clone(), runtime)
111 .await?
112 } else {
113 let tool_ctx = ToolContext {
115 state: state_value.clone(),
116 tool_call_id: call.id.clone(),
117 };
118
119 if let Some(ref chain) = self.middleware {
120 let request = ToolCallRequest { call: call.clone() };
121 let base = BaseToolCaller {
122 executor: self.executor.clone(),
123 tool_context: tool_ctx,
124 };
125 chain.call_tool(request, &base).await?
126 } else {
127 self.executor
128 .execute(&call.name, call.arguments.clone())
129 .await?
130 }
131 };
132 state
133 .messages
134 .push(Message::tool(result.to_string(), &call.id));
135 }
136
137 Ok(state.into())
138 }
139}
140
141pub fn tools_condition(state: &MessageState) -> String {
146 if let Some(last) = state.last_message() {
147 if !last.tool_calls().is_empty() {
148 return "tools".to_string();
149 }
150 }
151 crate::END.to_string()
152}