use std::borrow::Cow;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::ToolResult;
/// Parses typed tool arguments out of a `serde_json::Value`.
///
/// A blanket implementation in this module covers every type implementing
/// `for<'de> serde::Deserialize<'de>`, delegating to `serde_json::from_value`.
pub trait ToolArgParser: Sized {
    /// Converts `value` into `Self`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json::Error` produced when `value` does not match `Self`.
    fn parse(value: serde_json::Value) -> Result<Self, serde_json::Error>;
}
impl<T> ToolArgParser for T
where
T: for<'de> serde::Deserialize<'de>,
{
fn parse(value: serde_json::Value) -> Result<Self, serde_json::Error> {
serde_json::from_value(value)
}
}
/// Static metadata for a typed tool-argument type.
///
/// Implementors supply the tool's name, description and parameter schema. The
/// provided [`ToolArgs::tool_definition`] assembles them into a
/// [`ToolDefinition`] with `cache_control` left unset.
pub trait ToolArgs: ToolArgParser {
    /// Tool name, copied verbatim into the definition's `name`.
    const NAME: &'static str;
    /// Tool description, copied verbatim into the definition's `description`.
    const DESCRIPTION: &'static str;
    /// JSON schema for the tool's parameters.
    ///
    /// NOTE(review): the double-underscore name suggests this is meant to be
    /// generated (e.g. by a macro) rather than written by hand — confirm.
    fn __schema() -> serde_json::Value;

    /// Builds the [`ToolDefinition`] for this tool from `NAME`, `DESCRIPTION`
    /// and `__schema()`.
    fn tool_definition() -> ToolDefinition {
        let name = String::from(Self::NAME);
        let description = String::from(Self::DESCRIPTION);
        let parameters = Self::__schema();
        ToolDefinition {
            name,
            description,
            parameters,
            cache_control: None,
        }
    }
}
/// Scheduling policy for a tool relative to other tool calls.
///
/// Only the policy is declared here. The code that enforces it is not in this
/// file, so the per-variant notes below are read off the variant names and the
/// `ExecutableTool` constructors that assign them — NOTE(review): confirm
/// against the scheduler.
///
/// Being a fieldless enum, it is `Copy` (cheap to pass by value) and `Hash`
/// (usable as a map/set key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParallelSafety {
    /// Assigned by `ExecutableTool::safe`; presumably may run alongside any other tool.
    Safe,
    /// Assigned by `ExecutableTool::category_exclusive` together with a
    /// [`ToolCategory`]; presumably excludes concurrent tools of the same category.
    CategoryExclusive,
    /// Assigned by `ExecutableTool::exclusive`; presumably runs with no other tool in flight.
    Exclusive,
}
/// Names a group of tools, e.g. the group a category-exclusive tool belongs to.
///
/// The built-in categories wrap borrowed `'static` strings. [`ToolCategory::custom`]
/// accepts anything convertible into a `Cow<'static, str>`, such as a
/// `&'static str` or an owned `String`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ToolCategory(pub Cow<'static, str>);

impl ToolCategory {
    /// The `"file_io"` category.
    pub const FILE_IO: Self = ToolCategory(Cow::Borrowed("file_io"));
    /// The `"network"` category.
    pub const NETWORK: Self = ToolCategory(Cow::Borrowed("network"));
    /// The `"database"` category.
    pub const DATABASE: Self = ToolCategory(Cow::Borrowed("database"));

    /// Creates a category with an arbitrary name.
    ///
    /// Equality and hashing compare the name text, so `custom("network")`
    /// equals [`ToolCategory::NETWORK`].
    pub fn custom(name: impl Into<Cow<'static, str>>) -> Self {
        let label: Cow<'static, str> = name.into();
        ToolCategory(label)
    }
}
/// A tool's definition: name, description, JSON parameter schema and an
/// optional cache-control hint.
///
/// Serialized with serde; `cache_control` is omitted from the output when it is `None`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Description of what the tool does.
    pub description: String,
    /// JSON schema for the tool's arguments (see `ToolDefinition::compute_and_clean_schema`).
    pub parameters: serde_json::Value,
    /// Optional cache-control marker, set via `ToolDefinition::with_cache`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<crate::message::CacheControl>,
}
impl ToolDefinition {
    /// Returns this definition with `cache_control` set to `Some(cache)`,
    /// replacing any previous value.
    pub fn with_cache(mut self, cache: crate::message::CacheControl) -> Self {
        self.cache_control = Some(cache);
        self
    }

    /// Generates the JSON schema for `S` with `schemars` and strips the
    /// top-level metadata keys `$schema`, `$id`, `title` and `description`.
    ///
    /// # Panics
    ///
    /// Panics if the generated schema cannot be converted to a
    /// `serde_json::Value`; the panic message treats that as a `schemars` bug.
    pub fn compute_and_clean_schema<S: schemars::JsonSchema>() -> serde_json::Value {
        let generated = serde_json::to_value(&schemars::schema_for!(S))
            .expect("Failed to serialize JsonSchema; this is a bug in schemars");
        Self::clean_schema(generated)
    }

    /// Removes the top-level schema metadata keys. Nested schemas and
    /// non-object values are returned untouched.
    fn clean_schema(mut value: serde_json::Value) -> serde_json::Value {
        if let serde_json::Value::Object(map) = &mut value {
            // NOTE(review): with serde_json's `preserve_order` feature,
            // `Map::remove` may reorder the remaining keys, so the removal
            // order is kept fixed to keep the output stable.
            for key in ["$schema", "$id", "title", "description"] {
                map.remove(key);
            }
        }
        value
    }
}
/// Type-erased tool executor.
///
/// Takes the call's JSON arguments by reference and returns a boxed, pinned,
/// `Send` future resolving to a [`ToolResult`]. The boxed future carries the
/// default `'static` object bound, so it cannot borrow the argument value.
/// `Arc` makes cloning an executor a reference-count bump, and `Send + Sync`
/// lets it be shared across threads.
pub type ToolFn = Arc<
    dyn Fn(&serde_json::Value) -> Pin<Box<dyn Future<Output = ToolResult> + Send>> + Send + Sync,
>;
/// Boxes and pins a tool future into the type-erased form returned by a `ToolFn`.
///
/// NOTE(review): hidden from docs with a double-underscore name; presumably a
/// helper for macro-generated code — confirm.
#[doc(hidden)]
pub fn __tool_box<F>(f: F) -> Pin<Box<dyn Future<Output = ToolResult> + Send>>
where
    F: Future<Output = ToolResult> + Send + 'static,
{
    let erased: Box<dyn Future<Output = ToolResult> + Send> = Box::new(f);
    Pin::from(erased)
}
/// A tool definition bundled with its scheduling policy, optional category and
/// a type-erased executor.
///
/// Cloning copies `definition` and `category` and bumps the executor's `Arc`
/// reference count.
#[derive(Clone)]
pub struct ExecutableTool {
    /// Definition advertised for this tool.
    pub definition: ToolDefinition,
    /// How this tool may be scheduled relative to other tool calls.
    pub safety: ParallelSafety,
    /// Category for category-exclusive scheduling, if any.
    pub category: Option<ToolCategory>,
    /// Type-erased implementation; private, so outside this module it is only
    /// reachable through `ExecutableTool::execute`.
    executor: ToolFn,
}

// `ToolFn` is a trait object without a `Debug` impl, so `Debug` cannot be
// derived. Print every other field and mark the output as non-exhaustive.
impl std::fmt::Debug for ExecutableTool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ExecutableTool")
            .field("definition", &self.definition)
            .field("safety", &self.safety)
            .field("category", &self.category)
            .finish_non_exhaustive()
    }
}
impl ExecutableTool {
    /// Returns the definition advertised for this tool.
    pub fn definition(&self) -> &ToolDefinition {
        &self.definition
    }

    /// Returns the tool's scheduling policy.
    pub fn safety(&self) -> &ParallelSafety {
        &self.safety
    }

    /// Returns the tool's category, if it has one.
    pub fn category(&self) -> Option<&ToolCategory> {
        self.category.as_ref()
    }

    /// Runs the tool's executor on raw JSON arguments.
    ///
    /// `args` is only borrowed for this call: the returned boxed future is
    /// `'static`, so it cannot hold on to `args`.
    pub fn execute(
        &self,
        args: &serde_json::Value,
    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send>> {
        (self.executor)(args)
    }

    /// Builds a tool from an already type-erased executor and an explicit
    /// scheduling policy.
    ///
    /// NOTE(review): nothing here ties `safety` to `category` (for example,
    /// `CategoryExclusive` with `None` is accepted). The other constructors
    /// always pair `CategoryExclusive` with `Some(category)` and the other
    /// policies with `None`.
    pub fn from_fn(
        def: ToolDefinition,
        safety: ParallelSafety,
        category: Option<ToolCategory>,
        f: ToolFn,
    ) -> Self {
        Self {
            definition: def,
            safety,
            category,
            executor: f,
        }
    }

    /// Creates a [`ParallelSafety::Safe`] tool with no category from an async
    /// closure over the raw JSON arguments.
    pub fn safe<F, Fut>(def: ToolDefinition, f: F) -> Self
    where
        F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ToolResult> + Send + 'static,
    {
        Self::from_fn(def, ParallelSafety::Safe, None, Self::erase_executor(f))
    }

    /// Creates a [`ParallelSafety::CategoryExclusive`] tool in `category` from
    /// an async closure over the raw JSON arguments.
    pub fn category_exclusive<F, Fut>(def: ToolDefinition, category: ToolCategory, f: F) -> Self
    where
        F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ToolResult> + Send + 'static,
    {
        Self::from_fn(
            def,
            ParallelSafety::CategoryExclusive,
            Some(category),
            Self::erase_executor(f),
        )
    }

    /// Creates a [`ParallelSafety::Exclusive`] tool with no category from an
    /// async closure over the raw JSON arguments.
    pub fn exclusive<F, Fut>(def: ToolDefinition, f: F) -> Self
    where
        F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ToolResult> + Send + 'static,
    {
        Self::from_fn(def, ParallelSafety::Exclusive, None, Self::erase_executor(f))
    }

    /// Like [`ExecutableTool::safe`], but `f` receives arguments already parsed
    /// into `T`.
    ///
    /// A parse failure resolves to `ToolError::invalid_input` with the message
    /// `"invalid tool arguments: <serde error>"`, and `f` is not called.
    pub fn safe_fn<T, F, Fut>(def: ToolDefinition, f: F) -> Self
    where
        T: ToolArgParser + Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ToolResult> + Send + 'static,
    {
        Self::from_fn(
            def,
            ParallelSafety::Safe,
            None,
            Self::typed_executor::<T, F, Fut>(f),
        )
    }

    /// Like [`ExecutableTool::category_exclusive`], but `f` receives arguments
    /// already parsed into `T`.
    ///
    /// Parse failures are handled as in [`ExecutableTool::safe_fn`].
    pub fn category_exclusive_fn<T, F, Fut>(
        def: ToolDefinition,
        category: ToolCategory,
        f: F,
    ) -> Self
    where
        T: ToolArgParser + Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ToolResult> + Send + 'static,
    {
        Self::from_fn(
            def,
            ParallelSafety::CategoryExclusive,
            Some(category),
            Self::typed_executor::<T, F, Fut>(f),
        )
    }

    /// Like [`ExecutableTool::exclusive`], but `f` receives arguments already
    /// parsed into `T`.
    ///
    /// Parse failures are handled as in [`ExecutableTool::safe_fn`].
    pub fn exclusive_fn<T, F, Fut>(def: ToolDefinition, f: F) -> Self
    where
        T: ToolArgParser + Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ToolResult> + Send + 'static,
    {
        Self::from_fn(
            def,
            ParallelSafety::Exclusive,
            None,
            Self::typed_executor::<T, F, Fut>(f),
        )
    }

    /// Type-erases an async closure over raw JSON into a [`ToolFn`], boxing
    /// each future it returns.
    fn erase_executor<F, Fut>(f: F) -> ToolFn
    where
        F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ToolResult> + Send + 'static,
    {
        Arc::new(
            move |args: &serde_json::Value| -> Pin<Box<dyn Future<Output = ToolResult> + Send>> {
                Box::pin(f(args))
            },
        )
    }

    /// Wraps a handler over parsed arguments into a [`ToolFn`] over raw JSON.
    ///
    /// Parsing happens synchronously when the executor is called, on a clone
    /// of the JSON value. The returned future then awaits `f`, or resolves to
    /// an `invalid_input` error if parsing failed.
    fn typed_executor<T, F, Fut>(f: F) -> ToolFn
    where
        T: ToolArgParser + Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ToolResult> + Send + 'static,
    {
        // Every returned future needs its own handle to `f`, and `F` is not
        // required to be `Clone`.
        let f = Arc::new(f);
        Self::erase_executor(move |value| {
            let f = Arc::clone(&f);
            // `ToolArgParser::parse` takes ownership, but the executor only
            // borrows the arguments, so they are cloned here.
            let parsed = T::parse(value.clone());
            async move {
                match parsed {
                    Ok(args) => f(args).await,
                    Err(e) => Err(crate::ToolError::invalid_input(format!(
                        "invalid tool arguments: {e}"
                    ))),
                }
            }
        })
    }
}
/// Former name of [`ExecutableTool`], kept as a deprecated alias for backward compatibility.
#[deprecated(since = "0.5.0", note = "Use `ExecutableTool` instead")]
pub type ToolRegistration = ExecutableTool;