use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::registry::KernelError;
/// Name of a tool, as carried in [`ToolSchema::name`] and returned by [`Tool::name`].
pub type ToolName = String;

/// Self-describing metadata for a [`Tool`]: its name, a human-readable
/// description, and JSON values describing its arguments and its result.
///
/// NOTE(review): `args_schema` / `result_schema` are presumably JSON Schema
/// documents (the tests use `{"type": "object"}`); nothing in this type
/// validates them — confirm with the registry/consumers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolSchema {
    /// Tool name, e.g. `"test.echo"`.
    pub name: ToolName,
    /// Human-readable summary of what the tool does.
    pub description: String,
    /// Description of the JSON arguments accepted by [`Tool::invoke`].
    pub args_schema: Value,
    /// Description of the JSON value produced by [`Tool::invoke`].
    pub result_schema: Value,
}
#[async_trait]
pub trait Tool: Send + Sync {
fn schema(&self) -> ToolSchema;
fn name(&self) -> ToolName {
self.schema().name
}
async fn invoke(&self, args: Value) -> Result<Value, KernelError>;
}
/// Boxed, pinned, type-erased future produced by a [`LocalTool`] handler.
type BoxedToolFuture = Pin<Box<dyn Future<Output = Result<Value, KernelError>> + Send>>;

/// Shared, type-erased async handler backing a [`LocalTool`].
type LocalToolFn = dyn Fn(Value) -> BoxedToolFuture + Send + Sync;

/// A [`Tool`] whose behavior is supplied by an in-process async closure.
///
/// Cloning is cheap for the handler, which is shared behind an [`Arc`]. Only
/// the schema is deep-copied.
#[derive(Clone)]
pub struct LocalTool {
    schema: ToolSchema,
    f: Arc<LocalToolFn>,
}

// The handler closure has no `Debug` impl, so only the schema is printed.
impl std::fmt::Debug for LocalTool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LocalTool")
            .field("schema", &self.schema)
            .finish_non_exhaustive()
    }
}
impl LocalTool {
pub fn new<F, Fut>(schema: ToolSchema, f: F) -> Self
where
F: Fn(Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Value, KernelError>> + Send + 'static,
{
Self {
schema,
f: Arc::new(move |v| Box::pin(f(v))),
}
}
}
#[async_trait]
impl Tool for LocalTool {
fn schema(&self) -> ToolSchema {
self.schema.clone()
}
fn name(&self) -> ToolName {
self.schema.name.clone()
}
async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
(self.f)(args).await
}
}
#[cfg(test)]
mod tests {
    // `super::*` resolves against this module directly, so the tests do not
    // depend on how (or whether) the crate root re-exports these items.
    use super::*;
    use serde_json::json;

    /// Schema shared by the tests below.
    fn echo_schema() -> ToolSchema {
        ToolSchema {
            name: "test.echo".into(),
            description: "echoes the input".into(),
            args_schema: json!({"type": "object"}),
            result_schema: json!({"type": "object"}),
        }
    }

    #[tokio::test]
    async fn local_tool_roundtrip() {
        let tool = LocalTool::new(echo_schema(), |v| async move { Ok(v) });
        let out = tool.invoke(json!({"hello": "world"})).await.unwrap();
        assert_eq!(out, json!({"hello": "world"}));
        assert_eq!(tool.name(), "test.echo");
    }

    /// Minimal tool that relies on the trait's default `name`.
    struct SchemaOnly(ToolSchema);

    #[async_trait]
    impl Tool for SchemaOnly {
        fn schema(&self) -> ToolSchema {
            self.0.clone()
        }

        async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
            Ok(args)
        }
    }

    #[tokio::test]
    async fn default_name_and_dyn_dispatch() {
        let tool: Arc<dyn Tool> = Arc::new(SchemaOnly(echo_schema()));
        assert_eq!(tool.name(), "test.echo");
        let out = tool.invoke(json!([1, 2, 3])).await.unwrap();
        assert_eq!(out, json!([1, 2, 3]));
    }
}