use async_trait::async_trait;
use serde_json::Value as JsonValue;
use serdes_ai_tools::{RunContext, ToolError, ToolReturn};
use std::collections::HashMap;
use crate::{AbstractToolset, BoxedToolset, ToolsetTool};
/// A toolset that exposes the union of the tools provided by several inner toolsets.
///
/// Inner toolsets are stored in the order they were added. That order is used when
/// entering them, reversed when exiting them, and followed when routing `call_tool`.
pub struct CombinedToolset<Deps = ()> {
    /// Optional identifier reported through `AbstractToolset::id`.
    id: Option<String>,
    /// The wrapped toolsets, in insertion order.
    toolsets: Vec<BoxedToolset<Deps>>,
}
impl<Deps> CombinedToolset<Deps> {
    /// Creates an empty combined toolset without an id.
    #[must_use]
    pub fn new() -> Self {
        Self {
            toolsets: Vec::new(),
            id: None,
        }
    }

    /// Creates an empty combined toolset identified by `id`.
    #[must_use]
    pub fn with_id(id: impl Into<String>) -> Self {
        Self {
            id: Some(id.into()),
            ..Self::new()
        }
    }

    /// Boxes `toolset`, appends it after the toolsets already present, and returns `self`.
    #[must_use]
    pub fn with_toolset<T: AbstractToolset<Deps> + 'static>(self, toolset: T) -> Self {
        self.add_boxed(Box::new(toolset))
    }

    /// Appends an already-boxed toolset and returns `self`.
    #[must_use]
    pub fn add_boxed(mut self, toolset: BoxedToolset<Deps>) -> Self {
        self.toolsets.push(toolset);
        self
    }

    /// Appends every toolset yielded by `toolsets`, in iteration order, and returns `self`.
    #[must_use]
    pub fn toolsets<I, T>(self, toolsets: I) -> Self
    where
        I: IntoIterator<Item = T>,
        T: AbstractToolset<Deps> + 'static,
    {
        toolsets
            .into_iter()
            .fold(self, |builder, toolset| builder.with_toolset(toolset))
    }

    /// Number of inner toolsets that have been added.
    #[must_use]
    pub fn toolset_count(&self) -> usize {
        self.toolsets.len()
    }

    /// `true` when no inner toolset has been added yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.toolsets.is_empty()
    }
}
impl<Deps> Default for CombinedToolset<Deps> {
fn default() -> Self {
Self::new()
}
}
/// A tool collected by `CombinedToolset::get_tools`, remembered together with the
/// index of the inner toolset that supplied it.
///
/// The index is used to label the owning toolset in conflict error messages.
#[derive(Clone)]
struct ToolOwnership {
    /// Position of the owning toolset inside `CombinedToolset::toolsets`.
    toolset_index: usize,
    /// The tool as reported by the owning toolset's `get_tools`.
    tool: ToolsetTool,
}
#[async_trait]
impl<Deps: Send + Sync + 'static> AbstractToolset<Deps> for CombinedToolset<Deps> {
    fn id(&self) -> Option<&str> {
        self.id.as_deref()
    }

    fn type_name(&self) -> &'static str {
        "CombinedToolset"
    }

    fn tool_name_conflict_hint(&self) -> String {
        "Use PrefixedToolset to add prefixes to tool names from different toolsets.".to_string()
    }

    /// Merges the tools of every inner toolset into one map keyed by tool name.
    ///
    /// # Errors
    ///
    /// - Propagates the first error returned by an inner toolset's `get_tools`.
    /// - Returns an execution-failed error listing every name that is provided by more
    ///   than one inner toolset, sorted so the message is deterministic.
    async fn get_tools(
        &self,
        ctx: &RunContext<Deps>,
    ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
        let mut all_tools: HashMap<String, ToolOwnership> = HashMap::new();
        // (tool name, label of the toolset that registered it first, label of the newcomer)
        let mut conflicts: Vec<(String, String, String)> = Vec::new();

        for (idx, toolset) in self.toolsets.iter().enumerate() {
            let tools = toolset.get_tools(ctx).await?;
            for (name, tool) in tools {
                // One hash lookup per tool instead of `get` followed by `insert`.
                match all_tools.entry(name) {
                    std::collections::hash_map::Entry::Occupied(slot) => {
                        let existing_label = self.toolsets[slot.get().toolset_index].label();
                        let new_label = toolset.label();
                        conflicts.push((slot.key().clone(), existing_label, new_label));
                    }
                    std::collections::hash_map::Entry::Vacant(slot) => {
                        slot.insert(ToolOwnership {
                            toolset_index: idx,
                            tool,
                        });
                    }
                }
            }
        }

        if !conflicts.is_empty() {
            // Inner tool maps are iterated in HashMap order; sort for a stable message.
            conflicts.sort();
            let conflict_msgs: Vec<String> = conflicts
                .iter()
                .map(|(name, t1, t2)| format!(" - '{}' exists in {} and {}", name, t1, t2))
                .collect();
            return Err(ToolError::execution_failed(format!(
                "Tool name conflicts in {}:\n{}\n\nHint: {}",
                self.label(),
                conflict_msgs.join("\n"),
                self.tool_name_conflict_hint()
            )));
        }

        Ok(all_tools
            .into_iter()
            .map(|(name, ownership)| (name, ownership.tool))
            .collect())
    }

    /// Forwards the call to the first inner toolset whose `get_tools` reports `name`.
    ///
    /// Inner toolsets are queried in insertion order. An error from an earlier toolset's
    /// `get_tools` aborts the call, even if a later toolset owns the tool.
    ///
    /// # Errors
    ///
    /// Returns a not-found error when no inner toolset reports `name`. Otherwise returns
    /// whatever the owning toolset's `call_tool` returns.
    async fn call_tool(
        &self,
        name: &str,
        args: JsonValue,
        ctx: &RunContext<Deps>,
        tool: &ToolsetTool,
    ) -> Result<ToolReturn, ToolError> {
        for toolset in &self.toolsets {
            let tools = toolset.get_tools(ctx).await?;
            if tools.contains_key(name) {
                return toolset.call_tool(name, args, ctx, tool).await;
            }
        }
        Err(ToolError::not_found(format!(
            "Tool '{}' not found in {}",
            name,
            self.label()
        )))
    }

    /// Enters every inner toolset in insertion order.
    ///
    /// If one of them fails to enter, the toolsets already entered are exited again in
    /// reverse order before the error is returned, so no inner toolset is left entered.
    /// Errors raised while unwinding are dropped; the original enter error takes
    /// precedence.
    ///
    /// # Errors
    ///
    /// Returns the error of the first inner toolset that fails to enter.
    async fn enter(&self) -> Result<(), ToolError> {
        for (entered, toolset) in self.toolsets.iter().enumerate() {
            if let Err(err) = toolset.enter().await {
                for previous in self.toolsets[..entered].iter().rev() {
                    let _ = previous.exit().await;
                }
                return Err(err);
            }
        }
        Ok(())
    }

    /// Exits every inner toolset in reverse insertion order.
    ///
    /// Every toolset is given a chance to exit even if an earlier exit fails, so one
    /// failing toolset cannot leak the resources of the others.
    ///
    /// # Errors
    ///
    /// Returns the first exit error encountered, after all toolsets have been exited.
    async fn exit(&self) -> Result<(), ToolError> {
        let mut first_error: Option<ToolError> = None;
        for toolset in self.toolsets.iter().rev() {
            if let Err(err) = toolset.exit().await {
                first_error.get_or_insert(err);
            }
        }
        first_error.map_or(Ok(()), Err)
    }
}
/// Debug output lists the id and how many inner toolsets are wrapped. The inner
/// toolsets themselves are not printed.
impl<Deps> std::fmt::Debug for CombinedToolset<Deps> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let toolset_count = self.toolsets.len();
        let mut builder = formatter.debug_struct("CombinedToolset");
        builder.field("id", &self.id);
        builder.field("toolset_count", &toolset_count);
        builder.finish()
    }
}
// Unit tests for `CombinedToolset`: construction, merging of tools from several
// inner toolsets, routing of `call_tool`, tool-name conflict detection, and the
// fan-out of `enter`/`exit` to every inner toolset.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::{Tool, ToolDefinition};

    // Minimal tool named "tool_a" that always returns the text "A".
    struct ToolA;
    #[async_trait]
    impl Tool<()> for ToolA {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_a", "Tool A")
        }
        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("A"))
        }
    }

    // Minimal tool named "tool_b" that always returns the text "B".
    struct ToolB;
    #[async_trait]
    impl Tool<()> for ToolB {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_b", "Tool B")
        }
        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("B"))
        }
    }

    // Deliberately reuses the name "tool_a" so that combining it with `ToolA`
    // produces a name conflict.
    struct ConflictingTool;
    #[async_trait]
    impl Tool<()> for ConflictingTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_a", "Conflicting Tool A")
        }
        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("Conflict"))
        }
    }

    // A freshly constructed toolset wraps nothing.
    #[test]
    fn test_combined_toolset_new() {
        let toolset = CombinedToolset::<()>::new();
        assert!(toolset.is_empty());
        assert_eq!(toolset.toolset_count(), 0);
    }

    // `with_id` surfaces the id through `AbstractToolset::id`.
    #[test]
    fn test_combined_toolset_with_id() {
        let toolset = CombinedToolset::<()>::with_id("combined");
        assert_eq!(toolset.id(), Some("combined"));
    }

    // Tools with distinct names from two inner toolsets all appear in the merged map.
    #[tokio::test]
    async fn test_combined_toolset_merges_tools() {
        let ts1 = FunctionToolset::new().with_id("ts1").tool(ToolA);
        let ts2 = FunctionToolset::new().with_id("ts2").tool(ToolB);
        let combined = CombinedToolset::new().with_toolset(ts1).with_toolset(ts2);
        let ctx = RunContext::minimal("test");
        let tools = combined.get_tools(&ctx).await.unwrap();
        assert_eq!(tools.len(), 2);
        assert!(tools.contains_key("tool_a"));
        assert!(tools.contains_key("tool_b"));
    }

    // `call_tool` is routed to the inner toolset that owns the name ("tool_a" -> ToolA).
    #[tokio::test]
    async fn test_combined_toolset_call_tool() {
        let ts1 = FunctionToolset::new().tool(ToolA);
        let ts2 = FunctionToolset::new().tool(ToolB);
        let combined = CombinedToolset::new().with_toolset(ts1).with_toolset(ts2);
        let ctx = RunContext::minimal("test");
        let tools = combined.get_tools(&ctx).await.unwrap();
        let tool_a = tools.get("tool_a").unwrap();
        let result = combined
            .call_tool("tool_a", serde_json::json!({}), &ctx, tool_a)
            .await
            .unwrap();
        assert_eq!(result.as_text(), Some("A"));
    }

    // Two inner toolsets exposing the same name make `get_tools` fail with an error
    // message that mentions the conflict and the offending tool name.
    #[tokio::test]
    async fn test_combined_toolset_conflict_detection() {
        let ts1 = FunctionToolset::new().with_id("ts1").tool(ToolA);
        let ts2 = FunctionToolset::new().with_id("ts2").tool(ConflictingTool);
        let combined = CombinedToolset::new().with_toolset(ts1).with_toolset(ts2);
        let ctx = RunContext::minimal("test");
        let result = combined.get_tools(&ctx).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message().contains("conflict"));
        assert!(err.message().contains("tool_a"));
    }

    // `enter` and `exit` on the combined toolset reach every inner toolset exactly once.
    #[tokio::test]
    async fn test_combined_toolset_enter_exit() {
        use std::sync::atomic::{AtomicU32, Ordering};
        use std::sync::Arc;
        // Counters shared by both inner toolsets.
        let enter_count = Arc::new(AtomicU32::new(0));
        let exit_count = Arc::new(AtomicU32::new(0));
        // Toolset with no tools that only records how often it is entered and exited.
        struct TrackedToolset {
            enter_count: Arc<AtomicU32>,
            exit_count: Arc<AtomicU32>,
        }
        #[async_trait]
        impl AbstractToolset<()> for TrackedToolset {
            fn id(&self) -> Option<&str> {
                None
            }
            async fn get_tools(
                &self,
                _ctx: &RunContext<()>,
            ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
                Ok(HashMap::new())
            }
            async fn call_tool(
                &self,
                _name: &str,
                _args: JsonValue,
                _ctx: &RunContext<()>,
                _tool: &ToolsetTool,
            ) -> Result<ToolReturn, ToolError> {
                Ok(ToolReturn::empty())
            }
            async fn enter(&self) -> Result<(), ToolError> {
                self.enter_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
            async fn exit(&self) -> Result<(), ToolError> {
                self.exit_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        }
        let ts1 = TrackedToolset {
            enter_count: enter_count.clone(),
            exit_count: exit_count.clone(),
        };
        let ts2 = TrackedToolset {
            enter_count: enter_count.clone(),
            exit_count: exit_count.clone(),
        };
        let combined = CombinedToolset::new().with_toolset(ts1).with_toolset(ts2);
        combined.enter().await.unwrap();
        assert_eq!(enter_count.load(Ordering::SeqCst), 2);
        combined.exit().await.unwrap();
        assert_eq!(exit_count.load(Ordering::SeqCst), 2);
    }
}