use std::collections::HashMap;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use parking_lot::RwLock;
use serde_json::Value;
use crate::connections::Connection;
use crate::error::{Error, Result};
use crate::runtime::MaybeSendSync;
use crate::types::{ToolCall, ToolResult};
/// A named capability that can be invoked with JSON arguments.
///
/// Tools are stored by [`ToolRunner`] under the string returned from
/// [`Tool::name`]. On wasm32 targets the trait is declared with
/// `async_trait(?Send)`, so `execute` futures there are not required to be
/// `Send`. The `MaybeSendSync` supertrait presumably adds `Send + Sync` only on
/// non-wasm targets — NOTE(review): confirm against `crate::runtime`.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait Tool: MaybeSendSync {
    /// Name the tool is registered and looked up under.
    fn name(&self) -> &str;

    /// Human-readable explanation of what the tool does.
    fn description(&self) -> &str;

    /// Description of the accepted arguments, returned as a JSON value
    /// (presumably a JSON Schema object — TODO confirm with consumers).
    fn input_schema(&self) -> Value;

    /// Runs the tool with `args`.
    ///
    /// `ctx` is the runner's currently installed context, or `None` when no
    /// context has been set.
    async fn execute(&self, args: Value, ctx: Option<Arc<ToolContext>>) -> Result<Value>;
}
/// Per-execution environment handed to tools.
///
/// It wraps the active connection and a small key/value store. Tools that
/// receive the same `Arc<ToolContext>` can use the store to share JSON values.
pub struct ToolContext {
    /// Connection whose identity/idle state is exposed and through which
    /// messages are sent.
    connection: Arc<dyn Connection>,
    /// Shared scratch storage keyed by arbitrary strings.
    state: RwLock<HashMap<String, Value>>,
}

impl ToolContext {
    /// Builds a context over `connection`, starting with an empty state store.
    pub fn new(connection: Arc<dyn Connection>) -> Self {
        let state = RwLock::new(HashMap::new());
        Self { connection, state }
    }

    /// Identifier of the conversation, as reported by the underlying connection.
    pub fn conversation_id(&self) -> &str {
        self.connection.conversation_id()
    }

    /// Whether the underlying connection reports itself as idle.
    pub fn is_idle(&self) -> bool {
        self.connection.is_idle()
    }

    /// Forwards `message` to the connection's `send_trigger`.
    ///
    /// # Errors
    /// Propagates any error returned by `send_trigger`.
    pub async fn send(&self, message: impl Into<String>) -> Result<()> {
        let text: String = message.into();
        self.connection.send_trigger(text).await
    }

    /// Returns a copy of the value stored under `key`, or `None` if absent.
    pub fn get_state(&self, key: &str) -> Option<Value> {
        let store = self.state.read();
        store.get(key).cloned()
    }

    /// Stores `value` under `key`, replacing (and discarding) any previous value.
    pub fn set_state(&self, key: impl Into<String>, value: Value) {
        let owned_key: String = key.into();
        self.state.write().insert(owned_key, value);
    }
}
/// Name-keyed registry of [`Tool`]s plus an optional [`ToolContext`] that is
/// passed to every tool the runner executes.
///
/// Every method takes `&self`. The registry sits behind an `RwLock` and the
/// context behind an `ArcSwapOption`, so registering tools and swapping the
/// context only require a shared reference.
pub struct ToolRunner {
    /// Registered tools, keyed by the name each reported when registered.
    tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
    /// Context handed to tools on execution; `None` until one is installed.
    context: ArcSwapOption<ToolContext>,
}

impl Default for ToolRunner {
    /// Builds a runner with an empty registry and no context installed.
    fn default() -> Self {
        let registry: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        let no_context: ArcSwapOption<ToolContext> = ArcSwapOption::from(None);
        Self {
            tools: RwLock::new(registry),
            context: no_context,
        }
    }
}
impl ToolRunner {
    /// Creates an empty runner; equivalent to [`ToolRunner::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `tool` under the name returned by [`Tool::name`].
    ///
    /// A tool previously registered under the same name is replaced, and this
    /// runner's handle to it is dropped.
    pub fn register(&self, tool: Arc<dyn Tool>) {
        let name = tool.name().to_owned();
        self.tools.write().insert(name, tool);
    }

    /// Installs `ctx` as the context for subsequent executions, replacing any
    /// previously installed context.
    pub fn set_context(&self, ctx: Arc<ToolContext>) {
        self.context.store(Some(ctx));
    }

    /// Removes the installed context; subsequent executions receive `None`.
    pub fn clear_context(&self) {
        self.context.store(None);
    }

    /// Returns the names of all registered tools, in unspecified order.
    pub fn names(&self) -> Vec<String> {
        self.tools.read().keys().cloned().collect()
    }

    /// Returns handles to all registered tools, in unspecified order.
    pub fn iter_tools(&self) -> Vec<Arc<dyn Tool>> {
        self.tools.read().values().cloned().collect()
    }

    /// Executes the tool registered under `name` with `args`.
    ///
    /// The registry read lock is released before the tool's future is awaited.
    /// The context is snapshotted once at call time, so a concurrent
    /// `set_context`/`clear_context` does not affect an execution already
    /// under way.
    ///
    /// # Errors
    /// Returns `Error::ToolNotFound` when no tool is registered under `name`.
    /// Otherwise it returns whatever error the tool itself produces.
    pub async fn execute(&self, name: &str, args: Value) -> Result<Value> {
        // The read guard is a temporary of this statement and is dropped at the
        // semicolon, so it is never held across the `.await` below.
        let tool = self
            .tools
            .read()
            .get(name)
            .cloned()
            .ok_or_else(|| Error::ToolNotFound {
                name: name.to_string(),
            })?;
        let ctx = self.context.load_full();
        tool.execute(args, ctx).await
    }

    /// Executes `calls` one after another, in order, and returns one
    /// [`ToolResult`] per call in the same order.
    ///
    /// Failures (including unknown tool names) never abort the batch. They are
    /// reported as error results carrying the error's `Display` text.
    pub async fn process_tool_calls(&self, calls: Vec<ToolCall>) -> Vec<ToolResult> {
        let mut results = Vec::with_capacity(calls.len());
        for call in calls {
            // `call` is owned, so its JSON arguments are moved into the tool
            // rather than deep-cloned; `name`/`id` stay available for the result.
            let result = match self.execute(&call.name, call.args).await {
                Ok(value) => ToolResult::ok(call.name, call.id, value),
                Err(e) => ToolResult::err(call.name, call.id, e.to_string()),
            };
            results.push(result);
        }
        results
    }
}
#[cfg(not(target_arch = "wasm32"))]
type ToolFuture = futures_util::future::BoxFuture<'static, Result<Value>>;
#[cfg(target_arch = "wasm32")]
type ToolFuture = futures_util::future::LocalBoxFuture<'static, Result<Value>>;
#[cfg(not(target_arch = "wasm32"))]
type ClosureHandler = Arc<dyn Fn(Value, Option<Arc<ToolContext>>) -> ToolFuture + Send + Sync>;
#[cfg(target_arch = "wasm32")]
type ClosureHandler = Arc<dyn Fn(Value, Option<Arc<ToolContext>>) -> ToolFuture>;
/// A [`Tool`] whose behavior is supplied by an async closure instead of a
/// dedicated type.
///
/// Construct with `ClosureTool::new`. It returns the tool already wrapped in an
/// `Arc`, ready to be passed to `ToolRunner::register`.
pub struct ClosureTool {
    /// Returned by `Tool::name`; also the key used when registered.
    name: String,
    /// Returned by `Tool::description`.
    description: String,
    /// Returned (cloned) by `Tool::input_schema`.
    schema: Value,
    /// Type-erased closure that `Tool::execute` invokes and awaits.
    handler: ClosureHandler,
}
impl ClosureTool {
#[cfg(not(target_arch = "wasm32"))]
pub fn new<F, Fut>(
name: impl Into<String>,
description: impl Into<String>,
schema: Value,
handler: F,
) -> Arc<Self>
where
F: Fn(Value, Option<Arc<ToolContext>>) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<Value>> + Send + 'static,
{
Arc::new(Self {
name: name.into(),
description: description.into(),
schema,
handler: Arc::new(move |a, c| Box::pin(handler(a, c))),
})
}
#[cfg(target_arch = "wasm32")]
pub fn new<F, Fut>(
name: impl Into<String>,
description: impl Into<String>,
schema: Value,
handler: F,
) -> Arc<Self>
where
F: Fn(Value, Option<Arc<ToolContext>>) -> Fut + 'static,
Fut: std::future::Future<Output = Result<Value>> + 'static,
{
Arc::new(Self {
name: name.into(),
description: description.into(),
schema,
handler: Arc::new(move |a, c| Box::pin(handler(a, c))),
})
}
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl Tool for ClosureTool {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn input_schema(&self) -> Value {
self.schema.clone()
}
async fn execute(&self, args: Value, ctx: Option<Arc<ToolContext>>) -> Result<Value> {
(self.handler)(args, ctx).await
}
}