use super::chat_completion::{FunctionDefinition, Parameters, Tool};
use futures::{FutureExt, future::BoxFuture};
use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
/// A tool definition: a kind tag plus the function it exposes.
///
/// Derives `Serialize`/`Deserialize`; `tool_type` is mapped to the JSON key `type`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionTool {
/// Kind of tool, (de)serialized as `type`. The constructors visible in this
/// file always set it to `"function"`.
#[serde(rename = "type")]
tool_type: String,
/// Name, description and parameter schema of the exposed function.
function: FunctionDef,
}
/// Describes one function: its name, a description, and the parameters it accepts.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDef {
/// Function name; `ToolRegistry` uses it as the lookup key when registering.
name: String,
/// Free-form description of the function.
description: String,
/// Schema-like description of the accepted arguments.
parameters: ParametersDef,
}
/// Schema-like description of a function's parameters.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ParametersDef {
/// Type name of the parameter container, (de)serialized as `type`
/// (the builder in this file always uses `"object"`).
#[serde(rename = "type")]
param_type: String,
/// Per-parameter definitions keyed by parameter name.
properties: HashMap<String, PropertyDef>,
/// Names of the parameters that must be supplied.
required: Vec<String>,
/// Falls back to `false` when the key is absent during deserialization.
// NOTE(review): this field (de)serializes under the key `additional_properties`;
// JSON Schema spells the keyword `additionalProperties` — confirm whether this
// struct is ever serialized directly into an API payload.
#[serde(default)]
additional_properties: bool,
}
/// Definition of a single parameter: its type name and a description.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PropertyDef {
/// Type name of the parameter, (de)serialized as `type`.
#[serde(rename = "type")]
prop_type: String,
/// Free-form description of the parameter.
description: String,
}
impl FunctionTool {
pub fn from_type<T: JsonSchema>(name: &str, description: &str) -> Self {
let schema = schema_for!(T);
let val = serde_json::to_value(schema).unwrap_or_default();
let param_type = val
.get("type")
.and_then(|v| v.as_str())
.unwrap_or("object")
.to_string();
let mut properties = HashMap::new();
if let Some(Value::Object(map)) = val.get("properties") {
for (k, v) in map {
let t = v
.get("type")
.and_then(|v| v.as_str())
.unwrap_or("object")
.to_string();
let desc = v
.get("description")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
properties.insert(
k.clone(),
PropertyDef {
prop_type: t,
description: desc,
},
);
}
}
let mut required = Vec::new();
if let Some(Value::Array(req)) = val.get("required") {
for r in req {
if let Some(s) = r.as_str() {
required.push(s.to_string());
}
}
}
FunctionTool {
tool_type: "function".to_string(),
function: FunctionDef {
name: name.to_string(),
description: description.to_string(),
parameters: ParametersDef {
param_type,
properties,
required,
additional_properties: false,
},
},
}
}
}
/// Parameter-collecting stage of the builder, wrapping a [`FunctionToolBuilder`].
///
/// `F` is carried only as `PhantomData`; the code visible in this file never reads it.
pub struct ParameterBuilder<F>(FunctionToolBuilder, std::marker::PhantomData<F>);
impl<F> ParameterBuilder<F> {
    /// Declares (or re-declares) one parameter of the tool and returns the builder.
    ///
    /// * `name` - parameter name, used as the key in the property map.
    /// * `type_name` - type name stored for the parameter.
    /// * `description` - free-form description of the parameter.
    /// * `required` - whether the name is listed among the required parameters.
    ///
    /// Declaring the same `name` again replaces the earlier definition, including
    /// its required-ness, so the required list never holds duplicate or stale names.
    pub fn param<T: Into<String>>(
        mut self,
        name: T,
        type_name: T,
        description: T,
        required: bool,
    ) -> Self {
        let name = name.into();
        let builder = &mut self.0;
        builder.properties.insert(
            name.clone(),
            PropertyDef {
                prop_type: type_name.into(),
                description: description.into(),
            },
        );
        // Drop any entry left by an earlier declaration of this name before
        // (possibly) re-adding it; `required` must contain unique names.
        builder.required.retain(|listed| *listed != name);
        if required {
            builder.required.push(name);
        }
        self
    }

    /// Finishes the builder and produces the [`FunctionTool`].
    pub fn build(self) -> FunctionTool {
        self.0.build_internal()
    }
}
/// Accumulates the pieces of a [`FunctionTool`] before it is built.
pub struct FunctionToolBuilder {
/// Function name.
name: String,
/// Function description.
description: String,
/// Parameter definitions collected so far, keyed by parameter name.
properties: HashMap<String, PropertyDef>,
/// Names of the parameters marked as required.
required: Vec<String>,
}
impl FunctionToolBuilder {
    /// Starts a builder for a function with the given name and description
    /// and no parameters.
    pub fn new<T: Into<String>>(name: T, description: T) -> Self {
        let (name, description) = (name.into(), description.into());
        Self {
            name,
            description,
            properties: HashMap::new(),
            required: Vec::new(),
        }
    }

    /// Moves on to the parameter-declaration stage.
    pub fn params<F>(self) -> ParameterBuilder<F> {
        let marker = std::marker::PhantomData::<F>;
        ParameterBuilder(self, marker)
    }

    /// Assembles the final [`FunctionTool`]: a `"function"` tool whose parameter
    /// container is an `"object"` with additional properties disabled.
    fn build_internal(self) -> FunctionTool {
        let Self {
            name,
            description,
            properties,
            required,
        } = self;

        let parameters = ParametersDef {
            param_type: String::from("object"),
            properties,
            required,
            additional_properties: false,
        };
        let function = FunctionDef {
            name,
            description,
            parameters,
        };

        FunctionTool {
            tool_type: String::from("function"),
            function,
        }
    }
}
/// Registry mapping tool names to their async handlers and their definitions.
pub struct ToolRegistry {
/// Type-erased handlers keyed by tool name. Each takes the raw JSON arguments
/// and returns a boxed future resolving to a JSON result or an error message.
tools: HashMap<
String,
Box<dyn Fn(Value) -> BoxFuture<'static, Result<Value, String>> + Send + Sync>,
>,
/// Tool definitions keyed by the same tool name.
definitions: HashMap<String, FunctionTool>,
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            tools: HashMap::new(),
            definitions: HashMap::new(),
        }
    }

    /// Registers `tool` together with its async handler `f`.
    ///
    /// The handler is stored under the tool's function name; registering another
    /// tool with the same name replaces both the handler and the definition.
    ///
    /// When the tool is called, the JSON arguments are deserialized into `Args`,
    /// passed to `f`, and the awaited result is serialized back to JSON.
    /// Deserialization and serialization failures are reported as `Err(String)`.
    ///
    /// Returns `&mut Self` so registrations can be chained.
    pub fn register<F, Fut, Args, Ret>(&mut self, tool: FunctionTool, f: F) -> &mut Self
    where
        F: Fn(Args) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = Ret> + Send + 'static,
        Args: for<'de> Deserialize<'de> + Send + 'static,
        Ret: Serialize + 'static,
    {
        let name = tool.function.name.clone();
        self.tools.insert(
            name.clone(),
            Box::new(move |raw_args| {
                // The stored closure is `Fn`, so each call needs its own handle
                // to move into the future.
                let f = f.clone();
                let parsed: Result<Args, _> = serde_json::from_value(raw_args);
                async move {
                    let args =
                        parsed.map_err(|e| format!("Failed to deserialize arguments: {}", e))?;
                    let result = f(args).await;
                    serde_json::to_value(result)
                        .map_err(|e| format!("Failed to serialize result: {}", e))
                }
                .boxed()
            }),
        );
        self.definitions.insert(name, tool);
        self
    }

    /// Like [`ToolRegistry::register`], with optional observation hooks.
    ///
    /// * `on_call` - invoked with the deserialized arguments before `f` runs.
    /// * `on_result` - invoked with the arguments and the result after `f` completes.
    ///
    /// Neither hook runs when argument deserialization fails. The arguments are
    /// cloned only when an `on_result` hook is present, since that is the only
    /// case in which they are needed after being handed to `f`.
    pub fn register_with_callbacks<F, Fut, Args, Ret, C1, C2>(
        &mut self,
        tool: FunctionTool,
        f: F,
        on_call: Option<C1>,
        on_result: Option<C2>,
    ) -> &mut Self
    where
        F: Fn(Args) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = Ret> + Send + 'static,
        Args: for<'de> Deserialize<'de> + Clone + Send + 'static,
        Ret: Serialize + Clone + 'static,
        C1: Fn(&Args) + Send + Sync + Clone + 'static,
        C2: Fn(&Args, &Ret) + Send + Sync + Clone + 'static,
    {
        let name = tool.function.name.clone();
        self.tools.insert(
            name.clone(),
            Box::new(move |raw_args| {
                let f = f.clone();
                let on_call = on_call.clone();
                let on_result = on_result.clone();
                let parsed: Result<Args, _> = serde_json::from_value(raw_args);
                async move {
                    let args =
                        parsed.map_err(|e| format!("Failed to deserialize arguments: {}", e))?;
                    if let Some(notify) = &on_call {
                        notify(&args);
                    }
                    let result = match on_result {
                        Some(notify) => {
                            // Keep a copy of the arguments for the hook.
                            let result = f(args.clone()).await;
                            notify(&args, &result);
                            result
                        }
                        // No hook needs the arguments afterwards: move them.
                        None => f(args).await,
                    };
                    serde_json::to_value(result)
                        .map_err(|e| format!("Failed to serialize result: {}", e))
                }
                .boxed()
            }),
        );
        self.definitions.insert(name, tool);
        self
    }

    /// Returns the registered tool definitions, sorted by function name.
    ///
    /// The definitions live in a `HashMap`, whose iteration order is arbitrary;
    /// sorting keeps the output stable across calls and across runs.
    pub fn get_definitions(&self) -> Vec<&FunctionTool> {
        let mut definitions: Vec<&FunctionTool> = self.definitions.values().collect();
        definitions.sort_by(|a, b| a.function.name.cmp(&b.function.name));
        definitions
    }

    /// Invokes the tool registered under `name` with the JSON `args`.
    ///
    /// Logs the tool name at `info` level and the arguments at `debug` level.
    ///
    /// # Errors
    ///
    /// Returns `Err` when no tool is registered under `name`, or with whatever
    /// error the handler produces (argument deserialization or result
    /// serialization failure).
    pub async fn call(&self, name: &str, args: Value) -> Result<Value, String> {
        tracing::info!("tool {}", name);
        tracing::debug!("args: {}", args);
        match self.tools.get(name) {
            Some(handler) => handler(args).await,
            None => Err(format!("Tool '{}' not found", name)),
        }
    }

    /// Converts every registered definition into the API-level [`Tool`] type,
    /// in the same name-sorted order as [`ToolRegistry::get_definitions`].
    ///
    /// The property map is serialized to a JSON value; if that fails, an empty
    /// JSON object is used instead.
    pub fn to_api_tools(&self) -> Vec<Tool> {
        self.get_definitions()
            .into_iter()
            .map(|tool| {
                let function = &tool.function;
                let parameters = &function.parameters;
                Tool {
                    tool_type: tool.tool_type.clone(),
                    function: FunctionDefinition {
                        name: function.name.clone(),
                        description: function.description.clone(),
                        parameters: Parameters {
                            param_type: parameters.param_type.clone(),
                            properties: serde_json::to_value(&parameters.properties)
                                .unwrap_or_else(|_| serde_json::json!({})),
                            required: parameters.required.clone(),
                            additional_properties: parameters.additional_properties,
                        },
                    },
                }
            })
            .collect()
    }
}