use std::collections::HashMap;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use parking_lot::RwLock;
use serde_json::Value;
use crate::connections::Connection;
use crate::error::{Error, Result};
use crate::runtime::MaybeSendSync;
use crate::types::{ToolCall, ToolResult};
/// An asynchronous, named operation that can be registered with a [`ToolRunner`].
///
/// On non-wasm targets `async_trait` boxes `execute`'s future as `Send`; on
/// `wasm32` it is expanded with `?Send`, so the returned future need not be `Send`.
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
pub trait Tool: MaybeSendSync {
    /// Name under which `ToolRunner::register` stores the tool and `ToolRunner::execute` finds it.
    fn name(&self) -> &str;
    /// Human-readable description of the tool.
    fn description(&self) -> &str;
    /// Schema describing the accepted `args` (presumably JSON Schema — TODO confirm with consumers).
    fn input_schema(&self) -> Value;
    /// Runs the tool with `args` and the runner's current [`ToolContext`], if one is set.
    ///
    /// # Errors
    /// Implementation-defined. `ToolRunner::process_tool_calls` converts any error
    /// into an error `ToolResult` using its `Display` text.
    async fn execute(&self, args: Value, ctx: Option<Arc<ToolContext>>) -> Result<Value>;
}
/// Context handed to tools during execution.
///
/// Wraps the active [`Connection`] plus a string-keyed map of JSON values that
/// tools can read and write through `get_state` / `set_state`.
pub struct ToolContext {
    /// Connection that `conversation_id`, `is_idle` and `send` delegate to.
    connection: Arc<dyn Connection>,
    /// Scratch state behind a `parking_lot::RwLock`, so it can be mutated through `&self`.
    state: RwLock<HashMap<String, Value>>,
}
impl ToolContext {
    /// Builds a context around `connection` with an empty state map.
    pub fn new(connection: Arc<dyn Connection>) -> Self {
        let state = RwLock::new(HashMap::new());
        Self { connection, state }
    }

    /// Conversation identifier, as reported by the wrapped connection.
    pub fn conversation_id(&self) -> &str {
        self.connection.conversation_id()
    }

    /// Whether the wrapped connection reports itself as idle.
    pub fn is_idle(&self) -> bool {
        self.connection.is_idle()
    }

    /// Forwards `message` to the connection through `send_trigger`.
    ///
    /// # Errors
    /// Propagates whatever error `send_trigger` returns.
    pub async fn send(&self, message: impl Into<String>) -> Result<()> {
        let text: String = message.into();
        self.connection.send_trigger(text).await
    }

    /// Returns a clone of the value stored under `key`, or `None` when absent.
    pub fn get_state(&self, key: &str) -> Option<Value> {
        let map = self.state.read();
        map.get(key).cloned()
    }

    /// Stores `value` under `key`, replacing (and dropping) any previous value.
    pub fn set_state(&self, key: impl Into<String>, value: Value) {
        let mut map = self.state.write();
        map.insert(key.into(), value);
    }
}
/// Registry of tools keyed by name, plus an optional [`ToolContext`] passed to each execution.
///
/// The tool map sits behind an `RwLock` and the context in an `ArcSwapOption`,
/// so registration and context updates only need `&self`.
pub struct ToolRunner {
    /// Registered tools, keyed by `Tool::name()`; re-registering a name replaces the entry.
    tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
    /// Context loaded for every `execute` call; `None` until `set_context`, or after `clear_context`.
    context: ArcSwapOption<ToolContext>,
}
impl Default for ToolRunner {
    /// An empty runner: no tools registered and no context installed.
    fn default() -> Self {
        let tools = RwLock::new(HashMap::new());
        let context = ArcSwapOption::new(None);
        Self { tools, context }
    }
}
impl ToolRunner {
pub fn new() -> Self {
Self::default()
}
pub fn register(&self, tool: Arc<dyn Tool>) {
let name = tool.name().to_string();
self.tools.write().insert(name, tool);
}
pub fn set_context(&self, ctx: Arc<ToolContext>) {
self.context.store(Some(ctx));
}
pub fn clear_context(&self) {
self.context.store(None);
}
pub fn names(&self) -> Vec<String> {
self.tools.read().keys().cloned().collect()
}
pub fn iter_tools(&self) -> Vec<Arc<dyn Tool>> {
self.tools.read().values().cloned().collect()
}
pub async fn execute(&self, name: &str, args: Value) -> Result<Value> {
let tool = self
.tools
.read()
.get(name)
.cloned()
.ok_or_else(|| Error::ToolNotFound {
name: name.to_string(),
})?;
let ctx = self.context.load_full();
tool.execute(args, ctx).await
}
pub async fn process_tool_calls(&self, calls: Vec<ToolCall>) -> Vec<ToolResult> {
let mut results = Vec::with_capacity(calls.len());
for call in calls {
match self.execute(&call.name, call.args.clone()).await {
Ok(value) => results.push(ToolResult::ok(call.name, call.id, value)),
Err(e) => results.push(ToolResult::err(call.name, call.id, e.to_string())),
}
}
results
}
}
/// Boxed, `Send` future produced by a `ClosureTool` handler on non-wasm targets.
#[cfg(not(target_arch = "wasm32"))]
type ToolFuture = futures_util::future::BoxFuture<'static, Result<Value>>;
/// Boxed future produced by a `ClosureTool` handler on `wasm32` (no `Send` requirement).
#[cfg(target_arch = "wasm32")]
type ToolFuture = futures_util::future::LocalBoxFuture<'static, Result<Value>>;
/// Type-erased, shareable handler stored in `ClosureTool`; `Send + Sync` on non-wasm targets.
#[cfg(not(target_arch = "wasm32"))]
type ClosureHandler = Arc<dyn Fn(Value, Option<Arc<ToolContext>>) -> ToolFuture + Send + Sync>;
/// Type-erased handler stored in `ClosureTool` on `wasm32`, without thread-safety bounds.
#[cfg(target_arch = "wasm32")]
type ClosureHandler = Arc<dyn Fn(Value, Option<Arc<ToolContext>>) -> ToolFuture>;
/// A [`Tool`] backed by an async closure instead of a dedicated type.
///
/// Construct with `ClosureTool::new` or `ClosureTool::with_state`.
pub struct ClosureTool {
    /// Returned by `Tool::name`.
    name: String,
    /// Returned by `Tool::description`.
    description: String,
    /// Cloned and returned by `Tool::input_schema`.
    schema: Value,
    /// Type-erased closure invoked by `Tool::execute`.
    handler: ClosureHandler,
}
impl ClosureTool {
#[cfg(not(target_arch = "wasm32"))]
pub fn new<F, Fut>(
name: impl Into<String>,
description: impl Into<String>,
schema: Value,
handler: F,
) -> Arc<Self>
where
F: Fn(Value, Option<Arc<ToolContext>>) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<Value>> + Send + 'static,
{
Arc::new(Self {
name: name.into(),
description: description.into(),
schema,
handler: Arc::new(move |a, c| Box::pin(handler(a, c))),
})
}
#[cfg(target_arch = "wasm32")]
pub fn new<F, Fut>(
name: impl Into<String>,
description: impl Into<String>,
schema: Value,
handler: F,
) -> Arc<Self>
where
F: Fn(Value, Option<Arc<ToolContext>>) -> Fut + 'static,
Fut: std::future::Future<Output = Result<Value>> + 'static,
{
Arc::new(Self {
name: name.into(),
description: description.into(),
schema,
handler: Arc::new(move |a, c| Box::pin(handler(a, c))),
})
}
#[cfg(not(target_arch = "wasm32"))]
pub fn with_state<S, F, Fut>(
name: impl Into<String>,
description: impl Into<String>,
schema: Value,
state: S,
f: F,
) -> Arc<Self>
where
S: Clone + Send + Sync + 'static,
F: Fn(S, Value, Option<Arc<ToolContext>>) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<Value>> + Send + 'static,
{
Self::new(name, description, schema, move |args, ctx| {
f(state.clone(), args, ctx)
})
}
#[cfg(target_arch = "wasm32")]
pub fn with_state<S, F, Fut>(
name: impl Into<String>,
description: impl Into<String>,
schema: Value,
state: S,
f: F,
) -> Arc<Self>
where
S: Clone + 'static,
F: Fn(S, Value, Option<Arc<ToolContext>>) -> Fut + 'static,
Fut: std::future::Future<Output = Result<Value>> + 'static,
{
Self::new(name, description, schema, move |args, ctx| {
f(state.clone(), args, ctx)
})
}
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl Tool for ClosureTool {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn input_schema(&self) -> Value {
self.schema.clone()
}
async fn execute(&self, args: Value, ctx: Option<Arc<ToolContext>>) -> Result<Value> {
(self.handler)(args, ctx).await
}
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Mutex;
    use serde_json::json;

    #[tokio::test]
    async fn with_state_threads_shared_state_across_calls() {
        let counter = Arc::new(AtomicU64::new(0));
        let tool = ClosureTool::with_state(
            "tick",
            "Increment a shared counter and report the new value.",
            json!({ "type": "object", "properties": {} }),
            counter.clone(),
            |counter: Arc<AtomicU64>, _args, _ctx| async move {
                let prev = counter.fetch_add(1, Ordering::SeqCst);
                Ok(json!({ "count": prev + 1 }))
            },
        );
        let runner = ToolRunner::new();
        runner.register(tool);
        let r1 = runner.execute("tick", json!({})).await.unwrap();
        let r2 = runner.execute("tick", json!({})).await.unwrap();
        let r3 = runner.execute("tick", json!({})).await.unwrap();
        assert_eq!(r1["count"], json!(1));
        assert_eq!(r2["count"], json!(2));
        assert_eq!(r3["count"], json!(3));
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// Builds a tool that ignores its input and reports `marker`.
    fn echo_tool(name: &'static str, marker: &'static str) -> Arc<ClosureTool> {
        ClosureTool::new(name, "", json!({ "type": "object" }), move |_a, _c| async move {
            Ok(json!({ "who": marker }))
        })
    }

    #[tokio::test]
    async fn register_overwrites_a_same_name_tool() {
        let runner = ToolRunner::new();
        runner.register(echo_tool("t", "first"));
        runner.register(echo_tool("t", "second"));
        assert_eq!(runner.names(), vec!["t".to_string()]);
        let r = runner.execute("t", json!({})).await.unwrap();
        assert_eq!(r["who"], "second");
    }

    #[tokio::test]
    async fn execute_unknown_tool_is_tool_not_found() {
        let runner = ToolRunner::new();
        let err = runner.execute("ghost", json!({})).await.unwrap_err();
        assert!(matches!(&err, Error::ToolNotFound { name } if name == "ghost"));
    }

    #[tokio::test]
    async fn process_tool_calls_maps_ok_and_err_preserving_ids() {
        let runner = ToolRunner::new();
        runner.register(echo_tool("ok", "v"));
        let calls = vec![
            ToolCall { name: "ok".into(), args: json!({}), id: Some("a".into()), canonical_path: None },
            ToolCall { name: "missing".into(), args: json!({}), id: Some("b".into()), canonical_path: None },
        ];
        let results = runner.process_tool_calls(calls).await;
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].name, "ok");
        assert_eq!(results[0].id.as_deref(), Some("a"));
        assert!(results[0].error.is_none(), "ok call must have no error");
        assert_eq!(results[1].name, "missing");
        assert_eq!(results[1].id.as_deref(), Some("b"));
        assert!(
            results[1].error.as_deref().unwrap().contains("missing"),
            "err result must name the missing tool"
        );
    }

    #[tokio::test]
    async fn process_tool_calls_forwards_each_calls_args() {
        // Records every `args` value the tool receives, in call order.
        let seen: Arc<Mutex<Vec<Value>>> = Arc::default();
        let tool = ClosureTool::with_state(
            "record",
            "",
            json!({ "type": "object" }),
            seen.clone(),
            |seen: Arc<Mutex<Vec<Value>>>, args, _ctx| async move {
                seen.lock().unwrap().push(args);
                Ok(json!(null))
            },
        );
        let runner = ToolRunner::new();
        runner.register(tool);
        let calls = vec![
            ToolCall { name: "record".into(), args: json!({ "x": 1 }), id: Some("a".into()), canonical_path: None },
            ToolCall { name: "record".into(), args: json!([1, 2]), id: None, canonical_path: None },
        ];
        let results = runner.process_tool_calls(calls).await;
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.error.is_none()));
        assert_eq!(results[0].id.as_deref(), Some("a"));
        assert_eq!(results[1].id.as_deref(), None);
        assert_eq!(*seen.lock().unwrap(), vec![json!({ "x": 1 }), json!([1, 2])]);
    }

    #[tokio::test]
    async fn execute_passes_no_context_when_none_is_set() {
        let runner = ToolRunner::new();
        runner.register(ClosureTool::new(
            "probe",
            "",
            json!({ "type": "object" }),
            |_args, ctx: Option<Arc<ToolContext>>| async move {
                Ok(json!({ "has_ctx": ctx.is_some() }))
            },
        ));
        let fresh = runner.execute("probe", json!({})).await.unwrap();
        assert_eq!(fresh["has_ctx"], json!(false));
        runner.clear_context();
        let cleared = runner.execute("probe", json!({})).await.unwrap();
        assert_eq!(cleared["has_ctx"], json!(false));
    }

    #[tokio::test]
    async fn names_and_iter_tools_reflect_registrations() {
        let runner = ToolRunner::new();
        assert!(runner.names().is_empty());
        runner.register(echo_tool("a", "x"));
        runner.register(echo_tool("b", "y"));
        let mut names = runner.names();
        names.sort();
        assert_eq!(names, vec!["a".to_string(), "b".to_string()]);
        assert_eq!(runner.iter_tools().len(), 2);
    }
}