use std::sync::Arc;
use super::base_tool::BaseTool;
use crate::{MaybeSend, MaybeSync};
/// A collection of tools that callers can list and, when finished with it, close.
///
/// Method futures are `Send` on every target except WASI preview 1
/// (`target_os = "wasi"`, `target_env = "p1"`), where `async_trait(?Send)` is
/// used instead. Implementors must satisfy the crate's `MaybeSend + MaybeSync`
/// bounds.
#[cfg_attr(all(target_os = "wasi", target_env = "p1"), async_trait::async_trait(?Send))]
#[cfg_attr(
not(all(target_os = "wasi", target_env = "p1")),
async_trait::async_trait
)]
pub trait BaseToolset: MaybeSend + MaybeSync {
/// Returns borrowed references to this toolset's tools.
///
/// The returned references borrow from `&self`, so they cannot outlive the toolset.
async fn get_tools(&self) -> Vec<&dyn BaseTool>;
/// Called when the toolset is no longer needed.
///
/// What (if anything) this releases is implementation-defined; presumably it is
/// meant for freeing resources such as connections — confirm per implementor.
async fn close(&self);
}
/// A toolset that owns an ordered list of boxed tools.
///
/// `Default` yields an empty toolset. Tools can be supplied up front via `new`,
/// appended with `add_tool`/`add_tools`, or chained with `with_tool`/`with_tools`.
#[derive(Default)]
pub struct SimpleToolset {
/// Owned tools, kept in insertion order.
tools: Vec<Box<dyn BaseTool>>,
}
impl SimpleToolset {
    /// Creates a toolset holding `tools`, in the order the iterator yields them.
    pub fn new<T>(tools: T) -> Self
    where
        T: IntoIterator<Item = Box<dyn BaseTool>>,
    {
        let owned: Vec<Box<dyn BaseTool>> = tools.into_iter().collect();
        Self { tools: owned }
    }

    /// Appends a single, already-boxed tool after the existing ones.
    pub fn add_tool(&mut self, tool: Box<dyn BaseTool>) {
        self.add_tools(std::iter::once(tool));
    }

    /// Appends every already-boxed tool from `tools`, preserving iteration order.
    pub fn add_tools<T>(&mut self, tools: T)
    where
        T: IntoIterator<Item = Box<dyn BaseTool>>,
    {
        self.tools.extend(tools);
    }

    /// Builder-style counterpart of [`add_tool`](Self::add_tool) that boxes a
    /// concrete tool and returns the updated toolset.
    #[must_use]
    pub fn with_tool<U>(mut self, tool: U) -> Self
    where
        U: BaseTool + 'static,
    {
        let boxed: Box<dyn BaseTool> = Box::new(tool);
        self.add_tool(boxed);
        self
    }

    /// Builder-style counterpart of [`add_tools`](Self::add_tools) that boxes each
    /// concrete tool (in iteration order) and returns the updated toolset.
    #[must_use]
    pub fn with_tools<I, U>(mut self, tools: I) -> Self
    where
        I: IntoIterator<Item = U>,
        U: BaseTool + 'static,
    {
        let boxed = tools
            .into_iter()
            .map(|tool| Box::new(tool) as Box<dyn BaseTool>);
        self.add_tools(boxed);
        self
    }
}
#[cfg_attr(all(target_os = "wasi", target_env = "p1"), async_trait::async_trait(?Send))]
#[cfg_attr(
    not(all(target_os = "wasi", target_env = "p1")),
    async_trait::async_trait
)]
impl BaseToolset for SimpleToolset {
    /// Borrows every owned tool, in insertion order.
    async fn get_tools(&self) -> Vec<&dyn BaseTool> {
        let mut borrowed: Vec<&dyn BaseTool> = Vec::with_capacity(self.tools.len());
        for tool in &self.tools {
            // Deref through the `Box` to lend out the tool itself.
            borrowed.push(&**tool);
        }
        borrowed
    }

    /// No-op: this toolset does not forward `close` to the tools it owns.
    async fn close(&self) {}
}
/// A toolset that concatenates the tools of two shared toolsets.
///
/// Tools from `left` are listed before tools from `right`. No deduplication is
/// performed, so tools with the same name may appear more than once.
pub struct CombinedToolset {
/// Toolset whose tools are listed first.
left: Arc<dyn BaseToolset>,
/// Toolset whose tools are listed after `left`'s.
right: Arc<dyn BaseToolset>,
}
impl CombinedToolset {
pub fn new(left: Arc<dyn BaseToolset>, right: Arc<dyn BaseToolset>) -> Self {
Self { left, right }
}
}
#[cfg_attr(all(target_os = "wasi", target_env = "p1"), async_trait::async_trait(?Send))]
#[cfg_attr(
    not(all(target_os = "wasi", target_env = "p1")),
    async_trait::async_trait
)]
impl BaseToolset for CombinedToolset {
    /// Returns `left`'s tools followed by `right`'s, without deduplication.
    ///
    /// The two sides are queried one after another (left first), not concurrently.
    async fn get_tools(&self) -> Vec<&dyn BaseTool> {
        let from_left = self.left.get_tools().await;
        let from_right = self.right.get_tools().await;
        from_left.into_iter().chain(from_right).collect()
    }

    /// Closes `left`, then `right`, awaiting each in turn.
    async fn close(&self) {
        for toolset in [&self.left, &self.right] {
            toolset.close().await;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::function_tool::FunctionTool;
    use serde_json::json;

    /// Builds a boxed `FunctionTool` named `name` whose handler returns a null success.
    fn build_tool(name: &str) -> Box<dyn BaseTool> {
        Box::new(FunctionTool::new(name, "test tool", |_, _| {
            Box::pin(async { crate::tools::ToolResult::success(json!(null)) })
        }))
    }

    /// Collects the names of the tools exposed by `toolset`, in listing order.
    async fn tool_names(toolset: &dyn BaseToolset) -> Vec<String> {
        toolset
            .get_tools()
            .await
            .iter()
            .map(|tool| tool.name().to_string())
            .collect()
    }

    #[tokio::test(flavor = "current_thread")]
    async fn simple_toolset_returns_tools() {
        let toolset = SimpleToolset::new(vec![build_tool("alpha")]);
        let tools = toolset.get_tools().await;
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name(), "alpha");
        toolset.close().await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn default_toolset_is_empty() {
        let toolset = SimpleToolset::default();
        assert!(toolset.get_tools().await.is_empty());
        toolset.close().await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn add_methods_preserve_insertion_order() {
        let mut toolset = SimpleToolset::default();
        toolset.add_tool(build_tool("first"));
        toolset.add_tools(vec![build_tool("second"), build_tool("third")]);
        assert_eq!(
            tool_names(&toolset).await,
            vec!["first", "second", "third"]
        );
        toolset.close().await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn combined_toolset_aggregates_tools() {
        let left = Arc::new(SimpleToolset::new(vec![build_tool("left")]));
        let right = Arc::new(SimpleToolset::new(vec![build_tool("right")]));
        let combined = CombinedToolset::new(left, right);
        assert_eq!(tool_names(&combined).await, vec!["left", "right"]);
        combined.close().await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn combined_toolset_handles_tool_name_clashes() {
        let left = Arc::new(SimpleToolset::new(vec![build_tool("clash")]));
        let right = Arc::new(SimpleToolset::new(vec![build_tool("clash")]));
        let combined = CombinedToolset::new(left, right);
        // Clashing names are not deduplicated: both tools must be listed, and
        // both must carry the clashing name (a bare `contains` check would pass
        // even if only one did).
        assert_eq!(tool_names(&combined).await, vec!["clash", "clash"]);
        combined.close().await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn nested_combined_toolsets_keep_left_to_right_order() {
        let inner = CombinedToolset::new(
            Arc::new(SimpleToolset::new(vec![build_tool("a")])),
            Arc::new(SimpleToolset::new(vec![build_tool("b")])),
        );
        let outer = CombinedToolset::new(
            Arc::new(inner),
            Arc::new(SimpleToolset::new(vec![build_tool("c")])),
        );
        assert_eq!(tool_names(&outer).await, vec!["a", "b", "c"]);
        outer.close().await;
    }
}