use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use serdes_ai_tools::{RunContext, ToolDefinition, ToolError, ToolReturn};
use std::collections::HashMap;
/// A tool definition paired with the toolset-level metadata stored alongside it.
#[derive(Clone, Debug)]
pub struct ToolsetTool {
    /// Id of the toolset this tool was registered under, if one was assigned.
    pub toolset_id: Option<String>,
    /// The underlying tool definition (carries at least a name and a description).
    pub tool_def: ToolDefinition,
    /// Retry budget recorded for this tool; how it is enforced is up to the caller.
    pub max_retries: u32,
}
impl ToolsetTool {
    /// Retry budget assigned by [`ToolsetTool::new`] unless overridden with
    /// [`ToolsetTool::with_max_retries`].
    pub const DEFAULT_MAX_RETRIES: u32 = 3;

    /// Wraps `tool_def` with no toolset id and [`Self::DEFAULT_MAX_RETRIES`] retries.
    #[must_use]
    pub fn new(tool_def: ToolDefinition) -> Self {
        Self {
            toolset_id: None,
            tool_def,
            max_retries: Self::DEFAULT_MAX_RETRIES,
        }
    }

    /// Builder: records the id of the toolset that owns this tool.
    #[must_use]
    pub fn with_toolset_id(mut self, id: impl Into<String>) -> Self {
        self.toolset_id = Some(id.into());
        self
    }

    /// Builder: overrides the retry budget.
    #[must_use]
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// The tool's name, borrowed from the wrapped definition.
    #[must_use]
    pub fn name(&self) -> &str {
        &self.tool_def.name
    }

    /// The tool's description, borrowed from the wrapped definition.
    #[must_use]
    pub fn description(&self) -> &str {
        &self.tool_def.description
    }
}
/// Serializable summary of a toolset: its optional id, type name, and tools.
///
/// `tool_count` is stored independently of `tool_names`. Nothing in this type
/// enforces `tool_count == tool_names.len()`, so whoever builds a value is
/// responsible for keeping the two consistent.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolsetInfo {
    /// Optional toolset id; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Name of the toolset's type.
    pub type_name: String,
    /// Number of tools the toolset exposes.
    pub tool_count: usize,
    /// Names of the tools the toolset exposes.
    pub tool_names: Vec<String>,
}
/// A collection of tools that can be listed and invoked.
///
/// `Deps` is the dependency type of the [`RunContext`] handed to
/// [`AbstractToolset::get_tools`] and [`AbstractToolset::call_tool`]; it
/// defaults to `()`. Implementors must be `Send + Sync`.
#[async_trait]
pub trait AbstractToolset<Deps = ()>: Send + Sync {
    /// Optional identifier of this toolset instance.
    fn id(&self) -> Option<&str>;

    /// Human-readable label: the type name, followed by ` '<id>'` when an id is set.
    fn label(&self) -> String {
        let type_name = self.type_name();
        match self.id() {
            Some(id) => format!("{} '{}'", type_name, id),
            None => type_name.to_string(),
        }
    }

    /// Name of the implementing type; the default is `std::any::type_name::<Self>()`.
    fn type_name(&self) -> &'static str {
        std::any::type_name::<Self>()
    }

    /// Message suggesting how to resolve a tool-name clash, mentioning [`Self::label`].
    fn tool_name_conflict_hint(&self) -> String {
        let label = self.label();
        format!("Rename the tool or use PrefixedToolset to avoid conflicts in {}.", label)
    }

    /// Returns the tools this toolset exposes for the given run context.
    ///
    /// NOTE(review): map keys are presumably tool names — confirm with callers.
    ///
    /// # Errors
    /// Returns a [`ToolError`] if the tools cannot be produced.
    async fn get_tools(
        &self,
        ctx: &RunContext<Deps>,
    ) -> Result<HashMap<String, ToolsetTool>, ToolError>;

    /// Invokes the tool `name` with JSON `args`.
    ///
    /// NOTE(review): `tool` is presumably the entry previously returned by
    /// [`Self::get_tools`] for `name` — confirm with callers.
    ///
    /// # Errors
    /// Returns a [`ToolError`] if the call fails.
    async fn call_tool(
        &self,
        name: &str,
        args: JsonValue,
        ctx: &RunContext<Deps>,
        tool: &ToolsetTool,
    ) -> Result<ToolReturn, ToolError>;

    /// Lifecycle hook; the default does nothing and returns `Ok(())`.
    ///
    /// NOTE(review): when this is invoked is decided by the caller, not visible here.
    async fn enter(&self) -> Result<(), ToolError> {
        Ok(())
    }

    /// Lifecycle hook paired with [`Self::enter`]; the default does nothing and
    /// returns `Ok(())`.
    async fn exit(&self) -> Result<(), ToolError> {
        Ok(())
    }
}
/// Owned, dynamically dispatched toolset over dependency type `Deps`.
pub type BoxedToolset<Deps> = Box<dyn AbstractToolset<Deps> + 'static>;
/// `Result` specialised to [`ToolError`] for toolset operations.
pub type ToolsetResult<T> = std::result::Result<T, ToolError>;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_toolset_tool() {
        let def = ToolDefinition::new("test", "Test tool");
        let tool = ToolsetTool::new(def)
            .with_toolset_id("my_toolset")
            .with_max_retries(5);
        assert_eq!(tool.name(), "test");
        assert_eq!(tool.toolset_id, Some("my_toolset".to_string()));
        assert_eq!(tool.max_retries, 5);
    }

    #[test]
    fn test_toolset_tool_defaults() {
        // `new` must leave the toolset id unset and use the default retry budget of 3.
        let tool = ToolsetTool::new(ToolDefinition::new("test", "Test tool"));
        assert_eq!(tool.name(), "test");
        assert_eq!(tool.toolset_id, None);
        assert_eq!(tool.max_retries, 3);
    }

    #[test]
    fn test_toolset_info_serde() {
        let info = ToolsetInfo {
            id: Some("test_id".to_string()),
            type_name: "TestToolset".to_string(),
            tool_count: 3,
            tool_names: vec!["a".to_string(), "b".to_string(), "c".to_string()],
        };
        let json = serde_json::to_string(&info).unwrap();
        let parsed: ToolsetInfo = serde_json::from_str(&json).unwrap();
        // Compare every field so a regression in any of them is caught.
        assert_eq!(info.id, parsed.id);
        assert_eq!(info.type_name, parsed.type_name);
        assert_eq!(info.tool_count, parsed.tool_count);
        assert_eq!(info.tool_names, parsed.tool_names);
    }

    #[test]
    fn test_toolset_info_serde_omits_missing_id() {
        let info = ToolsetInfo {
            id: None,
            type_name: "TestToolset".to_string(),
            tool_count: 0,
            tool_names: Vec::new(),
        };
        let json = serde_json::to_string(&info).unwrap();
        // `skip_serializing_if = "Option::is_none"` must drop the key entirely.
        let value: JsonValue = serde_json::from_str(&json).unwrap();
        assert!(value.get("id").is_none());
        // A missing `id` key must still deserialize, back to `None`.
        let parsed: ToolsetInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, None);
        assert_eq!(parsed.type_name, "TestToolset");
        assert_eq!(parsed.tool_count, 0);
        assert!(parsed.tool_names.is_empty());
    }
}