Skip to main content

synaptic_graph/
tool_node.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use synaptic_core::{Message, RuntimeAwareTool, Store, SynapticError, ToolRuntime};
7use synaptic_middleware::{MiddlewareChain, ToolCallRequest, ToolCaller};
8use synaptic_tools::SerialToolExecutor;
9
10use crate::command::NodeOutput;
11use crate::node::Node;
12use crate::state::MessageState;
13
14/// Wraps a `SerialToolExecutor` into a `ToolCaller` for the middleware chain.
15struct BaseToolCaller {
16    executor: SerialToolExecutor,
17}
18
19#[async_trait]
20impl ToolCaller for BaseToolCaller {
21    async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
22        self.executor
23            .execute(&request.call.name, request.call.arguments.clone())
24            .await
25    }
26}
27
28/// Prebuilt node that executes tool calls from the last AI message in state.
29///
30/// Supports both regular `Tool` and `RuntimeAwareTool` instances.
31/// When a runtime-aware tool is registered, it receives the current graph
32/// state, store reference, and tool call ID via [`ToolRuntime`].
33pub struct ToolNode {
34    executor: SerialToolExecutor,
35    middleware: Option<Arc<MiddlewareChain>>,
36    /// Optional store reference injected into RuntimeAwareTool calls.
37    store: Option<Arc<dyn Store>>,
38    /// Runtime-aware tools keyed by tool name.
39    runtime_tools: HashMap<String, Arc<dyn RuntimeAwareTool>>,
40}
41
42impl ToolNode {
43    pub fn new(executor: SerialToolExecutor) -> Self {
44        Self {
45            executor,
46            middleware: None,
47            store: None,
48            runtime_tools: HashMap::new(),
49        }
50    }
51
52    /// Create a ToolNode with middleware support.
53    pub fn with_middleware(executor: SerialToolExecutor, middleware: Arc<MiddlewareChain>) -> Self {
54        Self {
55            executor,
56            middleware: Some(middleware),
57            store: None,
58            runtime_tools: HashMap::new(),
59        }
60    }
61
62    /// Set the store reference for runtime-aware tool injection.
63    pub fn with_store(mut self, store: Arc<dyn Store>) -> Self {
64        self.store = Some(store);
65        self
66    }
67
68    /// Register a runtime-aware tool.
69    ///
70    /// When a tool call matches a registered runtime-aware tool by name,
71    /// it will be called with a [`ToolRuntime`] containing the current
72    /// graph state, store, and tool call ID.
73    pub fn with_runtime_tool(mut self, tool: Arc<dyn RuntimeAwareTool>) -> Self {
74        self.runtime_tools.insert(tool.name().to_string(), tool);
75        self
76    }
77}
78
79#[async_trait]
80impl Node<MessageState> for ToolNode {
81    async fn process(
82        &self,
83        mut state: MessageState,
84    ) -> Result<NodeOutput<MessageState>, SynapticError> {
85        let last = state
86            .last_message()
87            .ok_or_else(|| SynapticError::Graph("no messages in state".to_string()))?;
88
89        let tool_calls = last.tool_calls().to_vec();
90        if tool_calls.is_empty() {
91            return Ok(state.into());
92        }
93
94        // Serialize current state for context injection
95        let state_value = serde_json::to_value(&state).ok();
96
97        for call in &tool_calls {
98            // Check if this is a runtime-aware tool
99            let result = if let Some(rt_tool) = self.runtime_tools.get(&call.name) {
100                let runtime = ToolRuntime {
101                    store: self.store.clone(),
102                    stream_writer: None,
103                    state: state_value.clone(),
104                    tool_call_id: call.id.clone(),
105                    config: None,
106                };
107                rt_tool
108                    .call_with_runtime(call.arguments.clone(), runtime)
109                    .await?
110            } else {
111                // Regular tool execution
112                if let Some(ref chain) = self.middleware {
113                    let request = ToolCallRequest { call: call.clone() };
114                    let base = BaseToolCaller {
115                        executor: self.executor.clone(),
116                    };
117                    chain.call_tool(request, &base).await?
118                } else {
119                    self.executor
120                        .execute(&call.name, call.arguments.clone())
121                        .await?
122                }
123            };
124            state
125                .messages
126                .push(Message::tool(result.to_string(), &call.id));
127        }
128
129        Ok(state.into())
130    }
131}
132
133/// Standard routing function: returns "tools" if last message has tool_calls, else END.
134///
135/// This is the standard condition function used with `add_conditional_edges`
136/// to route between an agent node and a tools node.
137pub fn tools_condition(state: &MessageState) -> String {
138    if let Some(last) = state.last_message() {
139        if !last.tool_calls().is_empty() {
140            return "tools".to_string();
141        }
142    }
143    crate::END.to_string()
144}