use crate::error::Result;
use std::future::Future;
use std::pin::Pin;
/// An asynchronous tool identified by a name and a human-readable description.
///
/// Each implementor picks the concrete future returned by [`Tool::execute`] through the
/// [`Tool::ExecuteFut`] generic associated type. For example, a tool whose result is known
/// immediately can use `std::future::Ready` instead of a boxed future. Because of that
/// associated type, `Tool` cannot be used as a trait object; see [`DynTool`] for the
/// type-erased counterpart.
pub trait Tool: Send + Sync {
    /// Future produced by [`Tool::execute`]; it may borrow from both `self` and the input.
    type ExecuteFut<'a>: Future<Output = Result<String>> + Send + 'a
    where
        Self: 'a;

    /// Returns the tool's name.
    fn name(&self) -> &str;

    /// Returns a human-readable description of what the tool does.
    fn description(&self) -> &str;

    /// Runs the tool on `input`, returning a future that resolves to the tool's output.
    fn execute<'a>(&'a self, input: &'a str) -> Self::ExecuteFut<'a>;
}
/// Object-safe counterpart of [`Tool`], usable as `dyn DynTool`.
///
/// [`Tool`] exposes its future as a generic associated type, which rules out trait objects.
/// `DynTool` erases that type behind a pinned, boxed `dyn Future` so heterogeneous tools
/// can be stored and called through one pointer type.
pub trait DynTool: Send + Sync {
    /// Returns the tool's name.
    fn name(&self) -> &str;

    /// Returns a human-readable description of what the tool does.
    fn description(&self) -> &str;

    /// Runs the tool on `input`, returning a type-erased future of its output.
    fn execute_dyn<'a>(
        &'a self,
        input: &'a str,
    ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>>;
}
impl<T: Tool> DynTool for T {
fn name(&self) -> &str {
Tool::name(self)
}
fn description(&self) -> &str {
Tool::description(self)
}
fn execute_dyn<'a>(
&'a self,
input: &'a str,
) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>> {
Box::pin(Tool::execute(self, input))
}
}
/// Starts building a closure-backed tool called `name`.
///
/// Shorthand for [`ToolBuilder::new`]. Chain [`ToolBuilder::description`] if needed, then
/// finish with [`ToolBuilder::execute`] or [`ToolBuilder::execute_async`].
pub fn tool(name: &'static str) -> ToolBuilder {
    ToolBuilder::new(name)
}
/// Builder for closure-backed tools; create one with [`tool`] or [`ToolBuilder::new`].
///
/// Finish it with [`ToolBuilder::execute`] (synchronous closure) or
/// [`ToolBuilder::execute_async`] (closure returning a future).
#[must_use = "a ToolBuilder does nothing until `execute` or `execute_async` is called"]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ToolBuilder {
    /// Name given to the resulting tool.
    name: &'static str,
    /// Description given to the resulting tool; empty unless explicitly set.
    description: &'static str,
}
impl ToolBuilder {
    /// Starts building a tool called `name` with an empty description.
    pub const fn new(name: &'static str) -> Self {
        Self {
            name,
            description: "",
        }
    }

    /// Sets the tool's description, replacing any previously set one.
    pub const fn description(mut self, desc: &'static str) -> Self {
        self.description = desc;
        self
    }

    /// Finishes the builder with a synchronous closure.
    ///
    /// The closure receives the input as `&str` and its `Result` becomes the tool's output.
    pub fn execute<F>(self, f: F) -> FnTool<F>
    where
        F: Fn(&str) -> Result<String> + Send + Sync,
    {
        FnTool {
            name: self.name,
            description: self.description,
            executor: f,
        }
    }

    /// Finishes the builder with a closure that returns a future.
    ///
    /// The closure receives an owned copy of the input. The future it returns must be
    /// `'static`, meaning it cannot borrow from the tool or the input.
    pub fn execute_async<F, Fut>(self, f: F) -> AsyncFnTool<F>
    where
        F: Fn(String) -> Fut + Send + Sync,
        // `'static` mirrors the bound on `AsyncFnTool`'s `Tool` impl. Without it, a
        // non-`'static` future is accepted here but the result is not a `Tool`, and the
        // error only shows up later at the first use site.
        Fut: Future<Output = Result<String>> + Send + 'static,
    {
        AsyncFnTool {
            name: self.name,
            description: self.description,
            executor: f,
        }
    }
}
/// A [`Tool`] backed by a synchronous closure; build one with [`ToolBuilder::execute`].
#[derive(Clone)]
pub struct FnTool<F> {
    /// Name reported by the tool.
    name: &'static str,
    /// Description reported by the tool.
    description: &'static str,
    /// Closure invoked with the raw input on each execution.
    executor: F,
}

// Manual impl: closures are not `Debug`, so only the metadata fields are printed.
impl<F> std::fmt::Debug for FnTool<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FnTool")
            .field("name", &self.name)
            .field("description", &self.description)
            .finish_non_exhaustive()
    }
}
/// Exposes a synchronous closure as a [`Tool`].
///
/// The closure is called as soon as [`Tool::execute`] is invoked, not when the future is
/// polled. The returned `Ready` future just hands back that already-computed result.
impl<F> Tool for FnTool<F>
where
    F: Fn(&str) -> Result<String> + Send + Sync,
{
    type ExecuteFut<'a>
        = std::future::Ready<Result<String>>
    where
        Self: 'a;

    fn name(&self) -> &str {
        self.name
    }

    fn description(&self) -> &str {
        self.description
    }

    fn execute<'a>(&'a self, input: &'a str) -> Self::ExecuteFut<'a> {
        let outcome = (self.executor)(input);
        std::future::ready(outcome)
    }
}
/// A [`Tool`] backed by a closure that returns a future; build one with
/// [`ToolBuilder::execute_async`].
#[derive(Clone)]
pub struct AsyncFnTool<F> {
    /// Name reported by the tool.
    name: &'static str,
    /// Description reported by the tool.
    description: &'static str,
    /// Closure invoked with an owned copy of the input on each execution.
    executor: F,
}

// Manual impl: closures are not `Debug`, so only the metadata fields are printed.
impl<F> std::fmt::Debug for AsyncFnTool<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncFnTool")
            .field("name", &self.name)
            .field("description", &self.description)
            .finish_non_exhaustive()
    }
}
/// Exposes an async closure as a [`Tool`].
///
/// On each call the input is copied into an owned `String`, the closure is invoked, and
/// its future is boxed. Requiring `Fut: 'static` guarantees the future does not borrow from
/// `self` or the input, so the box trivially satisfies the `'a` bound of
/// [`Tool::ExecuteFut`].
impl<F, Fut> Tool for AsyncFnTool<F>
where
    F: Fn(String) -> Fut + Send + Sync,
    Fut: Future<Output = Result<String>> + Send + 'static,
{
    type ExecuteFut<'a>
        = Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>>
    where
        Self: 'a;

    fn name(&self) -> &str {
        self.name
    }

    fn description(&self) -> &str {
        self.description
    }

    fn execute<'a>(&'a self, input: &'a str) -> Self::ExecuteFut<'a> {
        Box::pin((self.executor)(String::from(input)))
    }
}
/// A [`Tool`] for tests that ignores its input and always returns a fixed response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MockTool {
    /// Name reported by the tool.
    name: &'static str,
    /// Description reported by the tool.
    description: &'static str,
    /// Canned output returned from every execution.
    response: &'static str,
}
impl MockTool {
    /// Creates a mock tool named `name` that answers every call with `response`.
    ///
    /// This is a `const fn`, so mocks can also be declared as `const` or `static` items.
    pub const fn new(
        name: &'static str,
        description: &'static str,
        response: &'static str,
    ) -> Self {
        Self {
            name,
            description,
            response,
        }
    }
}
/// Ignores the input and resolves immediately to `Ok` with a copy of the canned response.
impl Tool for MockTool {
    type ExecuteFut<'a>
        = std::future::Ready<Result<String>>
    where
        Self: 'a;

    fn name(&self) -> &str {
        self.name
    }

    fn description(&self) -> &str {
        self.description
    }

    fn execute<'a>(&'a self, _input: &'a str) -> Self::ExecuteFut<'a> {
        let reply = String::from(self.response);
        std::future::ready(Ok(reply))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // `Tool` and `DynTool` both define `name`/`description`, and every `Tool` is also a
    // `DynTool`, so the concrete-type tests call the methods through `Tool::` explicitly.

    #[tokio::test]
    async fn test_fn_tool() {
        let calc = tool("calculator")
            .description("Perform calculations")
            .execute(|input| Ok(format!("Result: {}", input)));
        assert_eq!(Tool::name(&calc), "calculator");
        assert_eq!(Tool::description(&calc), "Perform calculations");
        let result = Tool::execute(&calc, "2 + 2").await.unwrap();
        assert_eq!(result, "Result: 2 + 2");
    }

    #[tokio::test]
    async fn test_mock_tool() {
        let t = MockTool::new("test", "A test tool", "mock response");
        assert_eq!(Tool::name(&t), "test");
        assert_eq!(Tool::description(&t), "A test tool");
        let result = Tool::execute(&t, "any input").await.unwrap();
        assert_eq!(result, "mock response");
    }

    #[tokio::test]
    async fn test_async_fn_tool() {
        let search = tool("search")
            .description("Search for information")
            .execute_async(|query: String| async move { Ok(format!("Found: {}", query)) });
        assert_eq!(Tool::name(&search), "search");
        assert_eq!(Tool::description(&search), "Search for information");
        let result = Tool::execute(&search, "rust async").await.unwrap();
        assert_eq!(result, "Found: rust async");
    }

    #[test]
    fn test_dyn_tool_is_object_safe() {
        let mock = MockTool::new("test", "desc", "response");
        let _: &dyn DynTool = &mock;
    }

    /// Exercises the blanket `DynTool` impl end to end with all three tool kinds.
    #[tokio::test]
    async fn test_dyn_tool_dispatch() {
        let tools: Vec<Box<dyn DynTool>> = vec![
            Box::new(
                tool("upper")
                    .description("Uppercase the input")
                    .execute(|input| Ok(input.to_uppercase())),
            ),
            Box::new(
                tool("search").execute_async(|query: String| async move {
                    Ok(format!("Found: {}", query))
                }),
            ),
            Box::new(MockTool::new("mock", "A mock tool", "fixed")),
        ];

        let mut seen = Vec::new();
        for t in &tools {
            let output = t.execute_dyn("abc").await.unwrap();
            seen.push((t.name(), output));
        }

        assert_eq!(
            seen,
            vec![
                ("upper", "ABC".to_string()),
                ("search", "Found: abc".to_string()),
                ("mock", "fixed".to_string()),
            ]
        );
        assert_eq!(tools[0].description(), "Uppercase the input");
        assert_eq!(tools[1].description(), "");
        assert_eq!(tools[2].description(), "A mock tool");
    }
}