use super::base_tool::BaseTool;
use crate::{
compat::MaybeSendBoxFuture,
tools::ToolContext,
tools::{FunctionDeclaration, ToolResult},
MaybeSend, MaybeSync,
};
use serde_json::json;
use serde_json::Value;
use std::collections::HashMap;
/// Future type returned by a tool callable: it resolves to a [`ToolResult`]
/// and may borrow data for the lifetime `'a`.
///
/// NOTE(review): `MaybeSendBoxFuture` is defined in `crate::compat`; presumably
/// a pinned, boxed future whose `Send` bound depends on the target — confirm there.
type ToolFuture<'a> = MaybeSendBoxFuture<'a, ToolResult>;
/// Marker trait naming the callable signature a function-backed tool accepts.
///
/// An implementor is callable, for every lifetime `'a`, with an owned argument
/// map and a `&'a ToolContext<'a>`, and returns a `ToolFuture<'a>`. It must also
/// satisfy the crate's `MaybeSend` and `MaybeSync` bounds.
///
/// The trait declares no methods of its own; it exists so the whole bound can
/// be written as a single name (for example as `dyn ToolFn`).
pub trait ToolFn:
for<'a> Fn(HashMap<String, Value>, &'a ToolContext<'a>) -> ToolFuture<'a> + MaybeSend + MaybeSync
{
}
/// Blanket implementation: every `'static` type that already satisfies the
/// callable signature and the `MaybeSend`/`MaybeSync` bounds is a `ToolFn`,
/// so no manual `impl` is ever needed.
// Note that the `'static` requirement appears only here, not on the trait itself.
impl<T> ToolFn for T where
T: for<'a> Fn(HashMap<String, Value>, &'a ToolContext<'a>) -> ToolFuture<'a>
+ MaybeSend
+ MaybeSync
+ 'static
{
}
/// Unsized trait-object form of a tool callable.
type AsyncToolFunctionInner = dyn ToolFn;
/// Owned, type-erased tool callable: a heap-allocated `dyn ToolFn`.
pub type AsyncToolFunction = Box<AsyncToolFunctionInner>;
/// A tool whose behavior is supplied by a caller-provided asynchronous callable.
///
/// Stores the tool's name, its description, the boxed callable, and the JSON
/// value used as its parameter schema.
pub struct FunctionTool {
    /// Name supplied at construction.
    name: String,
    /// Description supplied at construction.
    description: String,
    /// Type-erased callable stored by `new`.
    function: AsyncToolFunction,
    /// JSON value describing the tool's parameters; an empty object by default.
    parameters_schema: Value,
}

impl FunctionTool {
    /// Creates a tool from a name, a description, and the callable to run.
    ///
    /// The parameter schema starts out as an empty JSON object; use
    /// [`FunctionTool::with_parameters_schema`] to replace it.
    ///
    /// The bound on `F` mirrors the blanket `ToolFn` implementation: a
    /// `'static` callable taking the argument map and a borrowed
    /// `ToolContext`, returning a `ToolFuture` tied to that borrow.
    pub fn new<F>(name: impl Into<String>, description: impl Into<String>, function: F) -> Self
    where
        F: for<'a> Fn(HashMap<String, Value>, &'a ToolContext<'a>) -> ToolFuture<'a>
            + MaybeSend
            + MaybeSync
            + 'static,
    {
        Self {
            name: name.into(),
            description: description.into(),
            function: Box::new(function),
            parameters_schema: json!({}),
        }
    }

    /// Replaces the parameter schema and returns the updated tool.
    ///
    /// Consumes `self` builder-style; the previous schema is dropped.
    #[must_use]
    pub fn with_parameters_schema(mut self, schema: Value) -> Self {
        self.parameters_schema = schema;
        self
    }

    /// Returns the parameter schema currently stored on the tool.
    #[must_use]
    pub const fn parameters_schema(&self) -> &Value {
        &self.parameters_schema
    }
}
#[cfg_attr(
    all(target_os = "wasi", target_env = "p1"),
    async_trait::async_trait(?Send)
)]
#[cfg_attr(
    not(all(target_os = "wasi", target_env = "p1")),
    async_trait::async_trait
)]
impl BaseTool for FunctionTool {
    /// Returns the tool's name.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the tool's description.
    fn description(&self) -> &str {
        self.description.as_str()
    }

    /// Builds a fresh declaration from clones of the name, description, and
    /// parameter schema.
    fn declaration(&self) -> FunctionDeclaration {
        let tool_name = self.name.clone();
        let tool_description = self.description.clone();
        let schema = self.parameters_schema.clone();
        FunctionDeclaration::new(tool_name, tool_description, schema)
    }

    /// Invokes the stored callable with `args` and `context` and awaits the
    /// future it returns.
    async fn run_async(
        &self,
        args: HashMap<String, Value>,
        context: &ToolContext<'_>,
    ) -> ToolResult {
        let pending = (self.function)(args, context);
        pending.await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::execution_state::DefaultExecutionState;
    use serde_json::json;

    /// `run_async` must call the closure handed to `FunctionTool::new` and
    /// return its result.
    #[tokio::test(flavor = "current_thread")]
    async fn function_tool_executes_provided_closure() {
        let increment = FunctionTool::new("increment", "Increment value", |args, _ctx| {
            Box::pin(async move {
                // A missing or non-integer "value" counts as zero.
                let current = args.get("value").and_then(Value::as_i64).unwrap_or(0);
                ToolResult::success(json!({ "value": current + 1 }))
            })
        });

        let execution_state = DefaultExecutionState::new();
        let context = ToolContext::builder()
            .with_state(&execution_state)
            .build()
            .expect("context");

        let mut call_args = HashMap::new();
        call_args.insert("value".to_string(), Value::from(5));

        let outcome = increment.run_async(call_args, &context).await;

        assert!(outcome.is_success());
        let expected = json!({ "value": 6 });
        assert_eq!(outcome.data(), &expected);
    }

    /// A schema supplied through `with_parameters_schema` must be the one
    /// reported by the tool's declaration.
    #[test]
    fn function_tool_allows_schema_override() {
        let custom_schema = json!({
            "type": "object",
            "properties": { "count": { "type": "integer" } }
        });

        let noop = FunctionTool::new("noop", "Does nothing", |_, _| {
            Box::pin(async { ToolResult::success(Value::Null) })
        })
        .with_parameters_schema(custom_schema.clone());

        let declared = noop.declaration();
        assert_eq!(declared.parameters(), &custom_schema);
    }
}