use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use synaptic_core::{Message, RuntimeAwareTool, Store, SynapticError, ToolRuntime};
use synaptic_middleware::{MiddlewareChain, ToolCallRequest, ToolCaller};
use synaptic_tools::SerialToolExecutor;
use crate::command::NodeOutput;
use crate::node::Node;
use crate::state::MessageState;
/// Innermost [`ToolCaller`] that `ToolNode` hands to `MiddlewareChain::call_tool`.
///
/// It ignores any middleware concerns and simply forwards the request's tool
/// call to the wrapped executor.
struct BaseToolCaller {
    /// Executor that actually runs the named tool.
    executor: SerialToolExecutor,
}

#[async_trait]
impl ToolCaller for BaseToolCaller {
    /// Executes `request.call` by name, passing a copy of its arguments.
    ///
    /// Any error from the executor is returned unchanged.
    async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
        let tool_call = &request.call;
        let args = tool_call.arguments.clone();
        self.executor.execute(&tool_call.name, args).await
    }
}
/// Graph node that executes the tool calls requested by the most recent
/// message and appends one tool-result message per call to the state.
///
/// Build it with [`ToolNode::new`] or [`ToolNode::with_middleware`], then
/// optionally chain [`ToolNode::with_parallel`], [`ToolNode::with_store`] and
/// [`ToolNode::with_runtime_tool`].
pub struct ToolNode {
    /// Executor used for ordinary (non-runtime-aware) tools.
    executor: SerialToolExecutor,
    /// Optional middleware chain wrapped around ordinary tool calls.
    middleware: Option<Arc<MiddlewareChain>>,
    /// Optional store handed to runtime-aware tools through `ToolRuntime`.
    store: Option<Arc<dyn Store>>,
    /// Runtime-aware tools, keyed by the value of their `name()`.
    runtime_tools: HashMap<String, Arc<dyn RuntimeAwareTool>>,
    /// When `true`, a batch of more than one tool call runs concurrently.
    parallel: bool,
}

impl ToolNode {
    /// Creates a node that calls tools directly through `executor`, one at a
    /// time, with no middleware, no store and no runtime-aware tools.
    pub fn new(executor: SerialToolExecutor) -> Self {
        Self {
            executor,
            middleware: None,
            store: None,
            runtime_tools: HashMap::new(),
            parallel: false,
        }
    }

    /// Same as [`ToolNode::new`], but routes ordinary tool calls through
    /// `middleware`.
    pub fn with_middleware(executor: SerialToolExecutor, middleware: Arc<MiddlewareChain>) -> Self {
        Self {
            middleware: Some(middleware),
            ..Self::new(executor)
        }
    }

    /// Turns concurrent execution of multi-call batches on or off.
    pub fn with_parallel(self, parallel: bool) -> Self {
        Self { parallel, ..self }
    }

    /// Sets the store that runtime-aware tools receive in their `ToolRuntime`.
    pub fn with_store(self, store: Arc<dyn Store>) -> Self {
        Self {
            store: Some(store),
            ..self
        }
    }

    /// Registers a runtime-aware tool under its `name()`.
    ///
    /// A tool registered under a name already present replaces the earlier one.
    pub fn with_runtime_tool(mut self, tool: Arc<dyn RuntimeAwareTool>) -> Self {
        let key = tool.name().to_string();
        self.runtime_tools.insert(key, tool);
        self
    }
}
#[async_trait]
impl Node<MessageState> for ToolNode {
    /// Executes every tool call requested by the last message in `state` and
    /// appends one `Message::tool` per call, in the original call order.
    ///
    /// Dispatch precedence for each call:
    /// 1. a registered runtime-aware tool with the call's name, which receives
    ///    a `ToolRuntime` carrying the store and a JSON snapshot of the state;
    /// 2. otherwise the middleware chain, if one is configured;
    /// 3. otherwise the executor directly.
    ///
    /// Each tool result (a `serde_json::Value`) is rendered with `to_string()`,
    /// which produces compact JSON text, before it is stored in the message.
    ///
    /// When `parallel` is enabled and the batch has more than one call, all
    /// calls are started together and awaited with `join_all`. Otherwise they
    /// run one after another.
    ///
    /// The state snapshot given to runtime-aware tools is taken once, before
    /// any call runs. It therefore never contains tool messages produced
    /// earlier in the same batch.
    ///
    /// # Errors
    ///
    /// Returns `SynapticError::Graph` if `state` has no messages. Otherwise
    /// returns the first tool error in call order. In sequential mode, the
    /// calls after a failing one are not run. In parallel mode, all calls have
    /// already finished when the error is returned.
    async fn process(
        &self,
        mut state: MessageState,
    ) -> Result<NodeOutput<MessageState>, SynapticError> {
        let last = state
            .last_message()
            .ok_or_else(|| SynapticError::Graph("no messages in state".to_string()))?;
        // Owned copy so `state` can be mutated while results are appended.
        let tool_calls = last.tool_calls().to_vec();
        if tool_calls.is_empty() {
            return Ok(state.into());
        }

        // Only runtime-aware tools read the serialized snapshot. Serializing
        // the whole state costs time proportional to the conversation length,
        // so skip it when no call in this batch targets such a tool.
        // A serialization failure is tolerated: the runtime then sees `None`.
        let needs_state_snapshot = tool_calls
            .iter()
            .any(|call| self.runtime_tools.contains_key(&call.name));
        let state_value = if needs_state_snapshot {
            serde_json::to_value(&state).ok()
        } else {
            None
        };

        if self.parallel && tool_calls.len() > 1 {
            // Each future owns clones of everything it touches, so the futures
            // are independent and can be polled concurrently by `join_all`.
            let futs: Vec<_> = tool_calls
                .iter()
                .map(|call| {
                    let executor = self.executor.clone();
                    let middleware = self.middleware.clone();
                    let rt_tool = self.runtime_tools.get(&call.name).cloned();
                    let store = self.store.clone();
                    let sv = state_value.clone();
                    let call = call.clone();
                    async move {
                        if let Some(rt) = rt_tool {
                            let runtime = ToolRuntime {
                                store,
                                stream_writer: None,
                                state: sv,
                                tool_call_id: call.id.clone(),
                                config: None,
                            };
                            rt.call_with_runtime(call.arguments.clone(), runtime).await
                        } else if let Some(ref chain) = middleware {
                            let request = ToolCallRequest { call: call.clone() };
                            let base = BaseToolCaller { executor };
                            chain.call_tool(request, &base).await
                        } else {
                            executor.execute(&call.name, call.arguments.clone()).await
                        }
                    }
                })
                .collect();
            // `join_all` returns results in the same order as `futs`, so zipping
            // with `tool_calls` pairs every result with the call that made it.
            let results = futures::future::join_all(futs).await;
            for (call, result) in tool_calls.iter().zip(results) {
                state
                    .messages
                    .push(Message::tool(result?.to_string(), &call.id));
            }
        } else {
            for call in &tool_calls {
                let result = if let Some(rt_tool) = self.runtime_tools.get(&call.name) {
                    let runtime = ToolRuntime {
                        store: self.store.clone(),
                        stream_writer: None,
                        state: state_value.clone(),
                        tool_call_id: call.id.clone(),
                        config: None,
                    };
                    rt_tool
                        .call_with_runtime(call.arguments.clone(), runtime)
                        .await?
                } else if let Some(ref chain) = self.middleware {
                    let request = ToolCallRequest { call: call.clone() };
                    let base = BaseToolCaller {
                        executor: self.executor.clone(),
                    };
                    chain.call_tool(request, &base).await?
                } else {
                    self.executor
                        .execute(&call.name, call.arguments.clone())
                        .await?
                };
                state
                    .messages
                    .push(Message::tool(result.to_string(), &call.id));
            }
        }
        Ok(state.into())
    }
}
/// Edge-routing helper for graphs that contain a `ToolNode`.
///
/// Returns `"tools"` when the last message in `state` carries at least one
/// tool call. Returns `crate::END` when there are no messages or the last
/// message has no tool calls.
pub fn tools_condition(state: &MessageState) -> String {
    match state.last_message() {
        Some(msg) if !msg.tool_calls().is_empty() => "tools".to_string(),
        _ => crate::END.to_string(),
    }
}