use crate::config::Config;
use crate::errors::{Error, Result};
use crate::llm::ToolInfo;
use crate::nodes::Node;
use crate::state::{Message, MessagesState, ToolCall};
use async_trait::async_trait;
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::sync::Arc;
/// Boxed, pinned, sendable future produced by a single tool invocation.
type ToolFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Result<JsonValue>> + Send>>;

/// Shared, type-erased async callback: takes the call arguments as JSON and
/// returns a [`ToolFuture`] resolving to the JSON result or an error.
type ToolCallback = Arc<dyn Fn(JsonValue) -> ToolFuture + Send + Sync>;

/// A named, described async callable that takes and returns JSON.
pub struct Tool {
    /// Name of the tool; `ToolNode` uses it as the lookup key for tool calls.
    pub name: String,
    /// Free-text description, copied into `ToolInfo` and the schema output.
    pub description: String,
    /// Optional JSON description of the tool's parameters. When `None`,
    /// `to_tool_info` and `to_schema` substitute an empty object schema.
    pub schema: Option<JsonValue>,
    /// The callback run by `invoke`. Stored behind an `Arc` as a trait object
    /// so tools with different closure types share one field type.
    pub func: ToolCallback,
}
impl Tool {
    /// Creates a tool from a name, a description and an async callback.
    ///
    /// `func` receives the call arguments as JSON and returns a future that
    /// resolves to the JSON result. The returned future is boxed and pinned so
    /// that it fits the type-erased `func` field. The new tool has no
    /// parameter schema; attach one with [`Tool::with_schema`].
    pub fn new<F, Fut>(name: impl Into<String>, description: impl Into<String>, func: F) -> Self
    where
        F: Fn(JsonValue) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<JsonValue>> + Send + 'static,
    {
        Self {
            name: name.into(),
            description: description.into(),
            schema: None,
            func: Arc::new(move |args| Box::pin(func(args))),
        }
    }

    /// Builder-style setter: stores `schema` as the parameter schema and
    /// returns the updated tool.
    pub fn with_schema(mut self, schema: JsonValue) -> Self {
        self.schema = Some(schema);
        self
    }

    /// Runs the tool's callback with `args` and awaits its result.
    ///
    /// # Errors
    ///
    /// Returns whatever error the callback's future resolves to.
    pub async fn invoke(&self, args: JsonValue) -> Result<JsonValue> {
        (self.func)(args).await
    }

    /// Builds a [`ToolInfo`] from the tool's name, description and parameter
    /// schema (falling back to an empty object schema when none was set).
    pub fn to_tool_info(&self) -> ToolInfo {
        ToolInfo {
            name: self.name.clone(),
            description: self.description.clone(),
            parameters: self.parameters_or_default(),
        }
    }

    /// Renders the tool as a JSON object of the form
    /// `{"type": "function", "function": {name, description, parameters}}`.
    ///
    /// `parameters` is the stored schema, or an empty object schema when none
    /// was set.
    pub fn to_schema(&self) -> JsonValue {
        serde_json::json!({
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters_or_default()
            }
        })
    }

    /// Returns a copy of the stored parameter schema, or the fallback schema
    /// (an object with no properties and no required fields) when none is set.
    ///
    /// Single source of truth for the fallback so that `to_tool_info` and
    /// `to_schema` cannot drift apart.
    fn parameters_or_default(&self) -> JsonValue {
        self.schema.clone().unwrap_or_else(|| {
            serde_json::json!({
                "type": "object",
                "properties": {},
                "required": []
            })
        })
    }
}
/// Node that executes the tool calls attached to the most recent message.
///
/// Tools are registered once in `ToolNode::new` and looked up by name when a
/// tool call is executed.
pub struct ToolNode {
// Registered tools keyed by `Tool::name`.
tools: HashMap<String, Tool>,
}
impl ToolNode {
pub fn new(tools: Vec<Tool>) -> Self {
let tools = tools
.into_iter()
.map(|tool| (tool.name.clone(), tool))
.collect();
Self { tools }
}
async fn execute_tool_call(&self, tool_call: &ToolCall) -> Message {
match self.tools.get(&tool_call.name) {
Some(tool) => {
match tool.invoke(tool_call.arguments.clone()).await {
Ok(result) => Message::tool(
serde_json::to_string(&result).unwrap_or_else(|_| "{}".to_string()),
&tool_call.id,
),
Err(e) => Message::tool(
format!("Error: {}", e),
&tool_call.id,
),
}
}
None => Message::tool(
format!("Tool '{}' not found", tool_call.name),
&tool_call.id,
),
}
}
}
#[async_trait]
impl Node<MessagesState> for ToolNode {
    /// Runs every tool call attached to the last message in `state`.
    ///
    /// Calls are executed one after another, in order, and each produces one
    /// tool message. The returned state contains only those new messages; it
    /// is empty when there is no last message or it carries no tool calls.
    /// `_config` is not used.
    async fn invoke(&self, state: MessagesState, _config: &Config) -> Result<MessagesState> {
        // Only the most recent message is inspected for pending tool calls.
        let pending = state
            .messages
            .last()
            .and_then(|msg| msg.tool_calls.as_ref());

        let mut replies = Vec::new();
        if let Some(calls) = pending {
            for call in calls {
                replies.push(self.execute_tool_call(call).await);
            }
        }

        Ok(MessagesState { messages: replies })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a state whose only message is an empty assistant message
    /// requesting a single tool call with id `call-1`.
    fn state_requesting(tool_name: &'static str, arguments: JsonValue) -> MessagesState {
        let request = Message::assistant("")
            .with_tool_calls(vec![ToolCall::new("call-1", tool_name, arguments)]);
        MessagesState {
            messages: vec![request],
        }
    }

    /// A tool keeps its name and description and runs its callback on invoke.
    #[tokio::test]
    async fn test_tool_creation() {
        let tool = Tool::new("test_tool", "A test tool", |args: JsonValue| async move {
            Ok(serde_json::json!({"result": args["input"]}))
        });

        assert_eq!(tool.name, "test_tool");
        assert_eq!(tool.description, "A test tool");

        let output = tool
            .invoke(serde_json::json!({"input": "hello"}))
            .await
            .unwrap();
        assert_eq!(output["result"], "hello");
    }

    /// A registered tool yields exactly one tool message tied to the call id.
    #[tokio::test]
    async fn test_tool_node() {
        let echo = Tool::new("echo", "Echo the input", |args: JsonValue| async move {
            Ok(args)
        });
        let node = ToolNode::new(vec![echo]);
        let state = state_requesting("echo", serde_json::json!({"msg": "test"}));

        let result = node.invoke(state, &Config::default()).await.unwrap();

        assert_eq!(result.messages.len(), 1);
        let reply = &result.messages[0];
        assert_eq!(reply.role, "tool");
        assert_eq!(reply.tool_call_id.as_deref(), Some("call-1"));
    }

    /// Calling an unregistered tool still yields a message, reporting the miss.
    #[tokio::test]
    async fn test_tool_node_unknown_tool() {
        let node = ToolNode::new(vec![]);
        let state = state_requesting("unknown", serde_json::json!({}));

        let result = node.invoke(state, &Config::default()).await.unwrap();

        assert_eq!(result.messages.len(), 1);
        assert!(result.messages[0].content.contains("not found"));
    }
}