use async_trait::async_trait;
use serde_json::Value as JsonValue;
use serdes_ai_tools::{RunContext, ToolError, ToolReturn};
use std::collections::HashMap;
use std::marker::PhantomData;
use crate::{AbstractToolset, ToolsetTool};
/// Wrapper around another toolset that exposes some of its tools under
/// different names.
///
/// Renames are stored in `name_map`, keyed by the exposed (new) name with the
/// original name as the value; names that have no entry are passed through
/// unchanged.
pub struct RenamedToolset<T, Deps = ()> {
/// The wrapped toolset that listing, calls and lifecycle hooks are delegated to.
inner: T,
/// Exposed (new) tool name -> original tool name in `inner`.
name_map: HashMap<String, String>,
/// Zero-sized marker that ties the otherwise unused `Deps` parameter to the type.
_phantom: PhantomData<fn() -> Deps>,
}
impl<T, Deps> RenamedToolset<T, Deps>
where
    T: AbstractToolset<Deps>,
{
    /// Wraps `inner` without renaming anything; add renames with [`Self::rename`].
    #[must_use]
    pub fn new(inner: T) -> Self {
        Self::with_map(inner, HashMap::new())
    }

    /// Wraps `inner` using a ready-made rename table.
    ///
    /// `name_map` is keyed by the exposed (new) name and holds the original
    /// name as the value. The map is stored as given, so it may contain
    /// several new names for one original; see [`Self::rename`] for the
    /// builder that keeps the table one-to-one.
    #[must_use]
    pub fn with_map(inner: T, name_map: HashMap<String, String>) -> Self {
        Self {
            inner,
            name_map,
            _phantom: PhantomData,
        }
    }

    /// Exposes the inner tool called `from` under the name `to`.
    ///
    /// Renaming the same `from` again replaces the earlier alias (last call
    /// wins). Without that, two aliases for one tool would stay in the table
    /// and the name picked for the tool listing would depend on `HashMap`
    /// iteration order.
    #[must_use]
    pub fn rename(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
        let (from, to) = (from.into(), to.into());
        // Retire any previous alias of `from` so each original has one exposed name.
        self.name_map.retain(|_, original| *original != from);
        self.name_map.insert(to, from);
        self
    }

    /// Returns a reference to the wrapped toolset.
    #[must_use]
    pub fn inner(&self) -> &T {
        &self.inner
    }

    /// Returns the rename table (exposed name -> original name).
    #[must_use]
    pub fn name_map(&self) -> &HashMap<String, String> {
        &self.name_map
    }

    /// Translates an exposed name back to the original name.
    ///
    /// Names without an entry in the rename table are returned unchanged.
    fn original_name<'a>(&'a self, new_name: &'a str) -> &'a str {
        self.name_map
            .get(new_name)
            .map(String::as_str)
            .unwrap_or(new_name)
    }

    /// Translates an original name to the name it is exposed under.
    ///
    /// Names that were never renamed are returned unchanged. If the table
    /// holds several new names for `original` (possible via
    /// [`Self::with_map`]), the lexicographically smallest one is returned so
    /// the result does not depend on `HashMap` iteration order.
    fn new_name(&self, original: &str) -> String {
        self.name_map
            .iter()
            .filter(|(_, old)| old.as_str() == original)
            .map(|(new, _)| new.as_str())
            .min()
            .unwrap_or(original)
            .to_string()
    }
}
#[async_trait]
impl<T, Deps> AbstractToolset<Deps> for RenamedToolset<T, Deps>
where
    T: AbstractToolset<Deps>,
    Deps: Send + Sync,
{
    /// Returns the wrapped toolset's id unchanged.
    fn id(&self) -> Option<&str> {
        self.inner.id()
    }

    /// Returns the fixed type name `"RenamedToolset"`.
    fn type_name(&self) -> &'static str {
        "RenamedToolset"
    }

    /// Returns the wrapped toolset's label wrapped in `RenamedToolset(...)`.
    fn label(&self) -> String {
        format!("RenamedToolset({})", self.inner.label())
    }

    /// Lists the inner toolset's tools keyed by their exposed names.
    ///
    /// Each returned tool's `tool_def.name` is rewritten to the exposed name.
    /// A tool whose exposed name is claimed by another tool's alias is left
    /// out: `call_tool` resolves that name to the alias target, so listing the
    /// shadowed tool would advertise something that cannot be reached, and
    /// letting both land on the same map key would keep an arbitrary one.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner toolset's `get_tools`.
    async fn get_tools(
        &self,
        ctx: &RunContext<Deps>,
    ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
        let inner_tools = self.inner.get_tools(ctx).await?;
        Ok(inner_tools
            .into_iter()
            .filter_map(|(original_name, mut tool)| {
                let exposed = self.new_name(&original_name);
                // Keep only names that round-trip, i.e. that `call_tool` would
                // route back to this very tool.
                if self.original_name(&exposed) != original_name.as_str() {
                    return None;
                }
                tool.tool_def.name = exposed.clone();
                Some((exposed, tool))
            })
            .collect())
    }

    /// Calls a tool by its exposed name.
    ///
    /// The name is translated back to the original one, and the inner toolset
    /// receives a copy of `tool` whose definition carries that original name.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner toolset's `call_tool`.
    async fn call_tool(
        &self,
        name: &str,
        args: JsonValue,
        ctx: &RunContext<Deps>,
        tool: &ToolsetTool,
    ) -> Result<ToolReturn, ToolError> {
        let original_name = self.original_name(name);
        // Undo the renaming done in `get_tools` on a private copy of the tool.
        let mut original_tool = tool.clone();
        original_tool.tool_def.name = original_name.to_string();
        self.inner
            .call_tool(original_name, args, ctx, &original_tool)
            .await
    }

    /// Delegates to the wrapped toolset's `enter`.
    async fn enter(&self) -> Result<(), ToolError> {
        self.inner.enter().await
    }

    /// Delegates to the wrapped toolset's `exit`.
    async fn exit(&self) -> Result<(), ToolError> {
        self.inner.exit().await
    }
}
/// Debug output shows the wrapped toolset and the rename table; the
/// `_phantom` marker is left out.
impl<T: std::fmt::Debug, Deps> std::fmt::Debug for RenamedToolset<T, Deps> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("RenamedToolset");
        builder.field("inner", &self.inner);
        builder.field("name_map", &self.name_map);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::{Tool, ToolDefinition};

    /// Stub tool defined under the name `search`; always answers "search result".
    struct SearchTool;

    #[async_trait]
    impl Tool<()> for SearchTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("search", "Search for items")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("search result"))
        }
    }

    /// Stub tool defined under the name `query`; always answers "query result".
    struct QueryTool;

    #[async_trait]
    impl Tool<()> for QueryTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("query", "Query the database")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("query result"))
        }
    }

    /// An exposed name resolves to its original; unknown names pass through.
    #[test]
    fn test_original_name() {
        let wrapped =
            RenamedToolset::new(FunctionToolset::new().tool(SearchTool)).rename("search", "find");

        assert_eq!(wrapped.original_name("find"), "search");
        assert_eq!(wrapped.original_name("other"), "other");
    }

    /// An original name resolves to its alias; unknown names pass through.
    #[test]
    fn test_new_name() {
        let wrapped =
            RenamedToolset::new(FunctionToolset::new().tool(SearchTool)).rename("search", "find");

        assert_eq!(wrapped.new_name("search"), "find");
        assert_eq!(wrapped.new_name("other"), "other");
    }

    /// The listing uses the alias for the renamed tool and leaves the other alone.
    #[tokio::test]
    async fn test_renamed_toolset_get_tools() {
        let base = FunctionToolset::new().tool(SearchTool).tool(QueryTool);
        let wrapped = RenamedToolset::new(base).rename("search", "find_items");
        let ctx = RunContext::minimal("test");

        let listed = wrapped.get_tools(&ctx).await.unwrap();

        assert_eq!(listed.len(), 2);
        assert!(listed.contains_key("find_items"));
        assert!(listed.contains_key("query"));
        assert!(!listed.contains_key("search"));
    }

    /// Calling through the alias reaches the original tool.
    #[tokio::test]
    async fn test_renamed_toolset_call_tool() {
        let base = FunctionToolset::new().tool(SearchTool);
        let wrapped = RenamedToolset::new(base).rename("search", "find_items");
        let ctx = RunContext::minimal("test");

        let listed = wrapped.get_tools(&ctx).await.unwrap();
        let aliased = listed.get("find_items").unwrap();
        let outcome = wrapped
            .call_tool("find_items", serde_json::json!({}), &ctx, aliased)
            .await
            .unwrap();

        assert_eq!(outcome.as_text(), Some("search result"));
    }

    /// Several renames can be chained on one wrapper.
    #[tokio::test]
    async fn test_renamed_toolset_multiple_renames() {
        let base = FunctionToolset::new().tool(SearchTool).tool(QueryTool);
        let wrapped = RenamedToolset::new(base)
            .rename("search", "find")
            .rename("query", "lookup");
        let ctx = RunContext::minimal("test");

        let listed = wrapped.get_tools(&ctx).await.unwrap();

        assert!(listed.contains_key("find"));
        assert!(listed.contains_key("lookup"));
    }
}