use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
use serde_json::Value;
use crate::ToolExecutionError;
use super::tool::{AsyncToolFn, Function, FunctionParameters, Property, Tool, ToolType};
/// Reasons `ToolBuilder::build` can refuse to produce a `Tool`.
///
/// Each variant corresponds to one mandatory builder field that was still
/// unset when `build` was called.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolBuilderError {
/// No function name was supplied via `function_name`.
MissingFunctionName,
/// No function description was supplied via `function_description`.
MissingFunctionDescription,
/// No executor was supplied via `executor` or `executor_fn`.
MissingExecutor,
}
impl std::fmt::Display for ToolBuilderError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ToolBuilderError::MissingFunctionName => write!(f, "Function name is required."),
ToolBuilderError::MissingFunctionDescription => write!(f, "Function description is required."),
ToolBuilderError::MissingExecutor => write!(f, "Executor function is required for the tool."),
}
}
}
// Empty impl: opts `ToolBuilderError` into the standard error trait using
// only the trait's default methods (no underlying source error is exposed).
impl std::error::Error for ToolBuilderError {}
/// Shared handle to an asynchronous tool executor.
///
/// The wrapped callable takes a JSON [`Value`] and returns a pinned, boxed,
/// `Send` future that resolves to `Result<String, ToolExecutionError>`. The
/// callable itself is `Send + Sync` and held behind an [`Arc`].
///
/// NOTE(review): `ToolBuilder` stores `AsyncToolFn` rather than this alias.
/// Judging by how `executor_fn` constructs one, the two look structurally
/// identical; confirm against `super::tool` and consider consolidating.
pub type ExecutorFn = Arc<
dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<String, ToolExecutionError>> + Send>>
+ Send
+ Sync,
>;
/// Fluent, consuming builder that assembles a `Tool` from a function name,
/// a description, a set of parameter properties and an executor.
///
/// The derived `Default` leaves every field empty, including `tool_type`;
/// `ToolBuilder::new` differs only in presetting `tool_type` to
/// `ToolType::Function`.
#[derive(Default)]
pub struct ToolBuilder {
/// Kind of tool being built; `build` falls back to `ToolType::Function`
/// when this is `None`.
tool_type: Option<ToolType>,
/// Function name; `build` fails with `MissingFunctionName` when `None`.
function_name: Option<String>,
/// Function description; `build` fails with `MissingFunctionDescription`
/// when `None`.
function_description: Option<String>,
/// Declared parameters, keyed by parameter name.
function_properties: HashMap<String, Property>,
/// Names of the parameters reported as required in the built
/// `FunctionParameters`.
function_required: Vec<String>,
/// Async callback stored on the built `Tool`; `build` fails with
/// `MissingExecutor` when `None`.
executor: Option<AsyncToolFn>,
}
/// Hand-written `Debug` so the executor can be shown as a placeholder.
impl std::fmt::Debug for ToolBuilder {
    /// Renders every field verbatim except `executor`, which is displayed as
    /// `Some("<async_fn>")` when set and `None` otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the presence of the executor is reported, never its contents
        // (presumably because the stored callable has no `Debug` impl).
        let executor_marker = if self.executor.is_some() {
            Some("<async_fn>")
        } else {
            None
        };

        let mut out = f.debug_struct("ToolBuilder");
        out.field("tool_type", &self.tool_type);
        out.field("function_name", &self.function_name);
        out.field("function_description", &self.function_description);
        out.field("function_properties", &self.function_properties);
        out.field("function_required", &self.function_required);
        out.field("executor", &executor_marker);
        out.finish()
    }
}
impl ToolBuilder {
pub fn new() -> Self {
ToolBuilder {
tool_type: Some(ToolType::Function),
function_properties: HashMap::new(),
function_required: Vec::new(),
..Default::default()
}
}
pub fn tool_type(mut self, tool_type: ToolType) -> Self {
self.tool_type = Some(tool_type);
self
}
pub fn function_name(mut self, name: impl Into<String>) -> Self {
self.function_name = Some(name.into());
self
}
pub fn function_description<T>(mut self, description: T) -> Self where T: Into<String> {
self.function_description = Some(description.into());
self
}
pub fn add_property(
mut self,
name: impl Into<String>,
property_type: impl Into<String>,
description: impl Into<String>,
) -> Self {
self.function_properties.insert(
name.into(),
Property {
property_type: property_type.into(),
description: description.into(),
},
);
self
}
pub fn add_required_property(
mut self,
name: impl Into<String>,
property_type: impl Into<String>,
description: impl Into<String>,
) -> Self {
let name = name.into();
self.function_properties.insert(
name.clone(),
Property {
property_type: property_type.into(),
description: description.into(),
},
);
self.function_required.push(name);
self
}
pub fn executor(mut self, exec: AsyncToolFn) -> Self {
self.executor = Some(exec);
self
}
pub fn executor_fn<F, Fut>(mut self, f: F) -> Self
where
F: Fn(Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<String, crate::ToolExecutionError>> + Send + 'static,
{
let exec: AsyncToolFn = Arc::new(move |v: Value| Box::pin(f(v)));
self.executor = Some(exec);
self
}
pub fn build(self) -> Result<Tool, ToolBuilderError> {
let function_name = self.function_name.ok_or(ToolBuilderError::MissingFunctionName)?;
let function_description = self.function_description.ok_or(ToolBuilderError::MissingFunctionDescription)?;
let executor = self.executor.ok_or(ToolBuilderError::MissingExecutor)?;
let parameters = FunctionParameters {
param_type:"object".to_string(),
properties: self.function_properties,
required: self.function_required,
};
let function = Function {
name: function_name,
description: function_description,
parameters,
};
Ok(Tool {
tool_type: self.tool_type.unwrap_or(ToolType::Function),
function,
executor,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AsyncToolFn;
    use serde_json::Value;
    use std::sync::Arc;

    /// Executor stub that ignores its arguments and always succeeds.
    fn create_dummy_executor() -> AsyncToolFn {
        Arc::new(|_ignored: Value| Box::pin(async { Ok(String::from("dummy execution")) }))
    }

    /// A fully specified builder yields a tool carrying the supplied data.
    #[test]
    fn tool_builder_valid_tool() {
        let builder = ToolBuilder::new()
            .function_name("test_tool")
            .function_description("A tool for testing")
            .add_required_property("param1", "string", "A string parameter")
            .executor(create_dummy_executor());

        // `unwrap` fails the test if the build was rejected.
        let tool = builder.build().unwrap();
        let parameters = &tool.function.parameters;

        assert_eq!(tool.function.name, "test_tool");
        assert_eq!(parameters.properties["param1"].property_type, "string");
        assert!(parameters
            .required
            .iter()
            .any(|name| name.as_str() == "param1"));
    }

    /// Omitting the name is reported as `MissingFunctionName`.
    #[test]
    fn tool_builder_missing_name_fails() {
        let outcome = ToolBuilder::new()
            .function_description("A tool missing a name")
            .executor(create_dummy_executor())
            .build();

        assert_eq!(outcome.err(), Some(ToolBuilderError::MissingFunctionName));
    }

    /// Omitting the description is reported as `MissingFunctionDescription`.
    #[test]
    fn tool_builder_missing_description_fails() {
        let outcome = ToolBuilder::new()
            .function_name("test_tool_no_desc")
            .executor(create_dummy_executor())
            .build();

        assert_eq!(
            outcome.err(),
            Some(ToolBuilderError::MissingFunctionDescription)
        );
    }

    /// Omitting the executor is reported as `MissingExecutor`.
    #[test]
    fn tool_builder_missing_executor_fails() {
        let outcome = ToolBuilder::new()
            .function_name("test_tool_no_exec")
            .function_description("A tool missing an executor")
            .build();

        assert_eq!(outcome.err(), Some(ToolBuilderError::MissingExecutor));
    }
}