use std::{collections::HashMap, fmt, future::Future, pin::Pin, sync::Arc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::errors::ToolExecutionError;
/// Kind of tool. Serialized in lowercase via `rename_all`, so `Function` is `"function"`.
///
/// The enum carries no data, so it is also `Copy`, `Eq` and `Hash`. It can be passed by value
/// and used as a map/set key.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A callable function tool (JSON value `"function"`).
    Function,
}
/// Type-erased async implementation of a tool.
///
/// The closure receives the call's JSON arguments. It returns a boxed `Send` future that
/// resolves to the tool's `String` output or a `ToolExecutionError`. The closure must be
/// `Send + Sync`. It is held in an `Arc` so clones of a [`Tool`] share one implementation.
pub type AsyncToolFn = Arc<
    dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<String, ToolExecutionError>> + Send>>
        + Sync
        + Send,
>;
/// Placeholder executor used by serde when deserializing a [`Tool`].
///
/// The real closure cannot be deserialized, because the `executor` field is `#[serde(skip)]`.
/// This stand-in lets `Tool` still implement `Deserialize`. Calling the closure only builds a
/// future; nothing fails until that future is polled.
///
/// # Panics
///
/// The returned future panics as soon as it is polled. A deserialized tool must be given its
/// real executor ("rehydrated") before it is executed.
// NOTE(review): returning a `ToolExecutionError` would be friendlier than a panic, but the
// error's variants are not visible from here.
fn default_executor() -> AsyncToolFn {
    Arc::new(|_args| {
        Box::pin(async move {
            panic!("Called a default, non-functional tool executor. The tool was not rehydrated after deserialization.");
        })
    })
}
/// A tool definition paired with the async closure that implements it.
///
/// Only `tool_type` (under the JSON key `"type"`) and `function` are (de)serialized.
/// Serde skips `executor`. On deserialization it is filled with `default_executor`, whose
/// future panics when polled, so a deserialized tool needs its real executor assigned before
/// it is executed.
#[derive(Clone, Deserialize, Serialize)]
pub struct Tool {
    /// Kind of tool, serialized as `"type"`.
    #[serde(rename = "type")]
    pub tool_type: ToolType,
    /// Name, description and parameter schema of the tool.
    pub function: Function,
    /// Implementation run by `Tool::execute`. It is not part of the serialized form.
    #[serde(skip, default = "default_executor")]
    pub executor: AsyncToolFn,
}
impl fmt::Debug for Tool {
    /// Prints the tool's metadata fields. The executor is shown as the placeholder string
    /// `"<async_fn>"` because a `dyn Fn` closure has no `Debug` implementation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Tool");
        builder.field("tool_type", &self.tool_type);
        builder.field("function", &self.function);
        // The closure itself cannot be formatted; print a fixed marker instead.
        builder.field("executor", &"<async_fn>");
        builder.finish()
    }
}
impl Tool {
    /// Invokes this tool's executor with the given JSON arguments and awaits its result.
    ///
    /// # Errors
    ///
    /// Propagates any `ToolExecutionError` produced by the executor.
    ///
    /// # Panics
    ///
    /// Panics if the tool still holds the placeholder executor installed by deserialization.
    pub async fn execute(&self, args: Value) -> Result<String, ToolExecutionError> {
        let run = &self.executor;
        run(args).await
    }

    /// Returns the tool's name, which is the name of its function definition.
    pub fn name(&self) -> &str {
        self.function.name.as_str()
    }
}
/// Definition of a function tool: its name, a description, and its parameter schema.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Function {
    /// Identifier of the function. `Tool::name` returns this value.
    pub name: String,
    /// Free-text description of what the function does.
    pub description: String,
    /// Schema of the arguments the function accepts.
    pub parameters: FunctionParameters,
}
/// Schema of a function's arguments, in the shape of a JSON-Schema object.
///
/// The fields are `type` (stored in `param_type`), per-argument `properties`, and the list of
/// `required` argument names. `properties` and `required` are optional on input and default
/// to empty. Serialized output always contains all three keys.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FunctionParameters {
    /// Schema type, serialized as `"type"`. This is presumably `"object"`; the value is not
    /// validated here.
    #[serde(rename = "type")]
    pub param_type: String,
    /// Per-argument schemas. Keys are presumably argument names.
    // A schema for a function without arguments may omit `properties`; treat that as empty
    // instead of failing deserialization.
    #[serde(default)]
    pub properties: HashMap<String, Property>,
    /// Names of arguments that must be supplied.
    // JSON Schema makes `required` optional; a missing key means "nothing is required".
    #[serde(default)]
    pub required: Vec<String>,
}
/// Schema of a single function argument.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Property {
    /// Value type, serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub property_type: String,
    /// Free-text description of the argument.
    pub description: String,
}
/// A request to invoke a tool.
///
/// `id` and `type` may be absent on input and are left out of output when they are
/// `None` or the default type. A call carrying only `function` therefore serializes back to
/// just `function`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCall {
    /// Optional call identifier. It defaults to `None` and is not serialized when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Tool kind under the JSON key `"type"`. It defaults to `default_tool_call_type()` when
    /// missing and is omitted from output when it equals that default.
    #[serde(
        rename = "type",
        default = "default_tool_call_type",
        skip_serializing_if = "is_default_tool_call_type"
    )]
    pub tool_type: ToolType,
    /// Name and arguments of the function being called.
    pub function: ToolCallFunction,
}
/// Tool kind assumed for a [`ToolCall`] whose input has no `"type"` key.
///
/// It is also the value that `is_default_tool_call_type` compares against.
fn default_tool_call_type() -> ToolType {
    ToolType::Function
}
/// Reports whether `tool_type` equals `default_tool_call_type()`. Serde uses it to omit the
/// `"type"` key of a [`ToolCall`] when it carries the default value.
// serde's `skip_serializing_if` must be callable as `fn(&T) -> bool`, so the by-reference
// parameter is mandatory even where clippy would suggest passing a small type by value.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_tool_call_type(tool_type: &ToolType) -> bool {
    default_tool_call_type() == *tool_type
}
/// Function name and arguments carried by a [`ToolCall`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCallFunction {
    /// Name of the function to invoke.
    pub name: String,
    /// Call arguments, kept as a parsed JSON value rather than a JSON string.
    pub arguments: serde_json::Value,
}