use serde::{Serialize, de::DeserializeOwned};
use std::{collections::BTreeMap, future::Future, marker::PhantomData, sync::Arc};
use crate::{
BoxError, BoxPinFut, Function, Json, Resource, ToolOutput, context::BaseContext,
model::FunctionDefinition, select_resources, validate_function_name,
};
/// A strongly-typed tool that can be invoked with a context of type `C`.
///
/// Implementors describe themselves via [`Tool::name`], [`Tool::description`]
/// and [`Tool::definition`], and implement [`Tool::call`] with typed arguments
/// and output. The provided [`Tool::call_raw`] bridges that typed API to raw
/// JSON arguments and JSON output.
pub trait Tool<C>: Send + Sync
where
    C: BaseContext + Send + Sync,
{
    /// Typed arguments; [`Tool::call_raw`] deserializes them from JSON.
    type Args: DeserializeOwned + Send;
    /// Typed output; [`Tool::call_raw`] serializes it back to JSON.
    type Output: Serialize;

    /// The tool's name. Also used as the prefix of error messages produced by
    /// [`Tool::call_raw`].
    fn name(&self) -> String;

    /// A human-readable description of the tool.
    fn description(&self) -> String;

    /// The function definition describing this tool.
    fn definition(&self) -> FunctionDefinition;

    /// Resource tags this tool accepts. Defaults to none.
    fn supported_resource_tags(&self) -> Vec<String> {
        Vec::new()
    }

    /// Picks the resources this tool supports from `resources`, based on
    /// [`Tool::supported_resource_tags`] and the crate-level `select_resources`
    /// helper.
    ///
    /// NOTE(review): `resources` is passed mutably to that helper, which
    /// presumably removes the selected entries; confirm against its definition.
    fn select_resources(&self, resources: &mut Vec<Resource>) -> Vec<Resource> {
        let supported_tags = self.supported_resource_tags();
        select_resources(resources, &supported_tags)
    }

    /// Initialization hook. The default does nothing and resolves to `Ok(())`.
    fn init(&self, _ctx: C) -> impl Future<Output = Result<(), BoxError>> + Send {
        futures::future::ready(Ok(()))
    }

    /// Invokes the tool with typed `args` and the given `resources`.
    fn call(
        &self,
        ctx: C,
        args: Self::Args,
        resources: Vec<Resource>,
    ) -> impl Future<Output = Result<ToolOutput<Self::Output>, BoxError>> + Send;

    /// Invokes the tool with raw JSON `args` and returns its output as JSON.
    ///
    /// Steps:
    /// 1. Deserialize `args` into [`Tool::Args`].
    /// 2. Await [`Tool::call`].
    /// 3. Serialize the output to JSON.
    /// 4. If the reported `usage.requests` is 0, raise it to 1.
    ///
    /// # Errors
    ///
    /// Fails when the arguments do not deserialize, when `call` fails, or when
    /// the output does not serialize. Every error message is prefixed with
    /// `tool {name}`.
    fn call_raw(
        &self,
        ctx: C,
        args: Json,
        resources: Vec<Resource>,
    ) -> impl Future<Output = Result<ToolOutput<Json>, BoxError>> + Send {
        async move {
            let args: Self::Args = serde_json::from_value(args)
                .map_err(|err| format!("tool {}, invalid args: {}", self.name(), err))?;
            let mut result = self
                .call(ctx, args, resources)
                .await
                .map_err(|err| format!("tool {}, call failed: {}", self.name(), err))?;
            // Tag serialization failures with the tool name as well, consistent
            // with the two failure paths above (previously the bare serde error
            // was propagated without any tool context).
            let output = serde_json::to_value(&result.output)
                .map_err(|err| format!("tool {}, invalid output: {}", self.name(), err))?;
            // A tool that reports zero requests is still counted as one call.
            if result.usage.requests == 0 {
                result.usage.requests = 1;
            }
            Ok(ToolOutput {
                output,
                artifacts: result.artifacts,
                usage: result.usage,
            })
        }
    }
}
/// Object-safe, JSON-level counterpart of [`Tool`].
///
/// Tools with different `Args`/`Output` types can all be stored behind
/// `Arc<dyn DynTool<C>>`. This works because every method here returns either
/// owned data or a `BoxPinFut` (defined elsewhere in the crate, presumably a
/// pinned boxed future) instead of an `impl Future`.
pub trait DynTool<C: BaseContext + Send + Sync>: Send + Sync {
    /// The tool's name.
    fn name(&self) -> String;

    /// The function definition describing this tool.
    fn definition(&self) -> FunctionDefinition;

    /// Resource tags this tool accepts.
    fn supported_resource_tags(&self) -> Vec<String>;

    /// Runs the tool's initialization hook with `ctx`.
    fn init(&self, ctx: C) -> BoxPinFut<Result<(), BoxError>>;

    /// Invokes the tool with raw JSON `args` and the given `resources`,
    /// yielding JSON output.
    fn call(
        &self,
        ctx: C,
        args: Json,
        resources: Vec<Resource>,
    ) -> BoxPinFut<Result<ToolOutput<Json>, BoxError>>;
}
/// Adapter that erases a typed [`Tool`] into a [`DynTool`] trait object.
///
/// The tool is held in an `Arc` so the boxed futures returned by the `DynTool`
/// impl can own a handle to it. No field stores a `C`, so `PhantomData<C>`
/// ties the wrapper to its context type.
///
/// The trait bounds live on the `DynTool` impl rather than on the struct, so
/// naming or constructing the type does not force the bounds to be repeated.
struct ToolWrapper<T, C>(Arc<T>, PhantomData<C>);
impl<T, C> DynTool<C> for ToolWrapper<T, C>
where
T: Tool<C> + 'static,
C: BaseContext + Send + Sync + 'static,
{
fn name(&self) -> String {
self.0.name()
}
fn definition(&self) -> FunctionDefinition {
self.0.definition()
}
fn supported_resource_tags(&self) -> Vec<String> {
self.0.supported_resource_tags()
}
fn init(&self, ctx: C) -> BoxPinFut<Result<(), BoxError>> {
let tool = self.0.clone();
Box::pin(async move { tool.init(ctx).await })
}
fn call(
&self,
ctx: C,
args: Json,
resources: Vec<Resource>,
) -> BoxPinFut<Result<ToolOutput<Json>, BoxError>> {
let tool = self.0.clone();
Box::pin(async move { tool.call_raw(ctx, args, resources).await })
}
}
/// A registry of type-erased tools, keyed by name.
///
/// `ToolSet::add` inserts each tool under its ASCII-lowercased name.
///
/// `Default` is implemented by hand instead of derived. `#[derive(Default)]`
/// would add a spurious `C: Default` bound, making `ToolSet::default()`
/// unusable for context types that do not implement `Default`.
pub struct ToolSet<C: BaseContext> {
    /// Registered tools, ordered by key.
    pub set: BTreeMap<String, Arc<dyn DynTool<C>>>,
}

impl<C: BaseContext> Default for ToolSet<C> {
    /// Creates an empty tool set, without requiring `C: Default`.
    fn default() -> Self {
        Self {
            set: BTreeMap::new(),
        }
    }
}
impl<C> ToolSet<C>
where
    C: BaseContext + Send + Sync + 'static,
{
    /// Creates an empty tool set.
    pub fn new() -> Self {
        Self {
            set: BTreeMap::new(),
        }
    }

    /// Returns `true` if a tool is registered under `name`, compared ASCII
    /// case-insensitively.
    pub fn contains(&self, name: &str) -> bool {
        self.set.contains_key(&name.to_ascii_lowercase())
    }

    /// Like [`ToolSet::contains`], but uses `lowercase_name` as the key
    /// verbatim. The caller must pass an already-lowercased name.
    pub fn contains_lowercase(&self, lowercase_name: &str) -> bool {
        self.set.contains_key(lowercase_name)
    }

    /// Returns the registered tool names (the map keys) in ascending order.
    pub fn names(&self) -> Vec<String> {
        self.set.keys().cloned().collect()
    }

    /// Returns the definition of the tool registered under `name` (ASCII
    /// case-insensitive), or `None` if there is no such tool.
    pub fn definition(&self, name: &str) -> Option<FunctionDefinition> {
        self.set
            .get(&name.to_ascii_lowercase())
            .map(|tool| tool.definition())
    }

    /// Returns the definitions of the selected tools, ordered by tool name.
    ///
    /// `None` selects every tool. `Some(names)` selects only tools whose name
    /// matches an entry of `names`, ASCII case-insensitively. Unknown names are
    /// ignored, and duplicate entries in `names` do not produce duplicate
    /// output.
    pub fn definitions(&self, names: Option<&[String]>) -> Vec<FunctionDefinition> {
        self.filtered(names).map(|tool| tool.definition()).collect()
    }

    /// Uses the same selection as [`ToolSet::definitions`], but returns each
    /// tool's definition together with its supported resource tags.
    pub fn functions(&self, names: Option<&[String]>) -> Vec<Function> {
        self.filtered(names)
            .map(|tool| Function {
                definition: tool.definition(),
                supported_resource_tags: tool.supported_resource_tags(),
            })
            .collect()
    }

    /// Lets the tool registered under `name` (ASCII case-insensitive) pick the
    /// resources it supports from `resources`.
    ///
    /// If no such tool exists, returns an empty `Vec` and leaves `resources`
    /// untouched.
    pub fn select_resources(&self, name: &str, resources: &mut Vec<Resource>) -> Vec<Resource> {
        self.set
            .get(&name.to_ascii_lowercase())
            .map(|tool| {
                let supported_tags = tool.supported_resource_tags();
                select_resources(resources, &supported_tags)
            })
            .unwrap_or_default()
    }

    /// Registers `tool` under its ASCII-lowercased [`Tool::name`].
    ///
    /// # Errors
    ///
    /// Fails if `validate_function_name` rejects the lowercased name, or if a
    /// tool with the same lowercased name is already registered.
    pub fn add<T>(&mut self, tool: Arc<T>) -> Result<(), BoxError>
    where
        T: Tool<C> + Send + Sync + 'static,
    {
        let name = tool.name().to_ascii_lowercase();
        validate_function_name(&name)?;
        if self.set.contains_key(&name) {
            return Err(format!("tool {} already exists", name).into());
        }
        let tool_dyn = ToolWrapper(tool, PhantomData);
        self.set.insert(name, Arc::new(tool_dyn));
        Ok(())
    }

    /// Returns a shared handle to the tool registered under `name` (ASCII
    /// case-insensitive), if any.
    pub fn get(&self, name: &str) -> Option<Arc<dyn DynTool<C>>> {
        self.set.get(&name.to_ascii_lowercase()).cloned()
    }

    /// Like [`ToolSet::get`], but uses `lowercase_name` as the key verbatim.
    pub fn get_lowercase(&self, lowercase_name: &str) -> Option<Arc<dyn DynTool<C>>> {
        self.set.get(lowercase_name).cloned()
    }

    /// Iterates the registered tools in name order.
    ///
    /// With `names == None` every tool is kept. Otherwise only tools whose name
    /// matches an entry of `names`, ASCII case-insensitively, are kept.
    fn filtered<'a>(
        &'a self,
        names: Option<&[String]>,
    ) -> impl Iterator<Item = &'a Arc<dyn DynTool<C>>> + 'a {
        // Lowercase the requested names once and collect them into a set, so
        // each membership test is a lookup rather than a linear scan of `names`.
        let wanted: Option<std::collections::BTreeSet<String>> =
            names.map(|names| names.iter().map(|n| n.to_ascii_lowercase()).collect());
        self.set
            .iter()
            .filter(move |(name, _)| match &wanted {
                Some(wanted) => wanted.contains(name.as_str()),
                None => true,
            })
            .map(|(_, tool)| tool)
    }
}