Skip to main content

serdes_ai_toolsets/
renamed.rs

1//! Renamed toolset implementation.
2//!
3//! This module provides `RenamedToolset`, which renames specific tools
4//! from the wrapped toolset.
5
6use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolError, ToolReturn};
9use std::collections::HashMap;
10use std::marker::PhantomData;
11
12use crate::{AbstractToolset, ToolsetTool};
13
/// Exposes selected tools of a wrapped toolset under different names.
///
/// Tools that have not been renamed pass through with their original name;
/// only names registered via [`RenamedToolset::rename`] (or supplied up front
/// to [`RenamedToolset::with_map`]) are rewritten.
///
/// # Example
///
/// ```ignore
/// use serdes_ai_toolsets::{RenamedToolset, FunctionToolset};
///
/// let toolset = FunctionToolset::new().tool(search_tool);
///
/// // The inner "search" tool is offered to callers as "find_items".
/// let renamed = RenamedToolset::new(toolset)
///     .rename("search", "find_items");
/// ```
pub struct RenamedToolset<T, Deps = ()> {
    /// The toolset whose tools are being renamed.
    inner: T,
    /// Maps new_name -> old_name, i.e. exposed name -> name inside `inner`.
    name_map: HashMap<String, String>,
    /// Marks the `Deps` type parameter without storing a value; using
    /// `fn() -> Deps` keeps the struct's auto traits independent of `Deps`.
    _phantom: PhantomData<fn() -> Deps>,
}
35
36impl<T, Deps> RenamedToolset<T, Deps>
37where
38    T: AbstractToolset<Deps>,
39{
40    /// Create a new renamed toolset.
41    pub fn new(inner: T) -> Self {
42        Self {
43            inner,
44            name_map: HashMap::new(),
45            _phantom: PhantomData,
46        }
47    }
48
49    /// Create with a name map.
50    pub fn with_map(inner: T, name_map: HashMap<String, String>) -> Self {
51        Self {
52            inner,
53            name_map,
54            _phantom: PhantomData,
55        }
56    }
57
58    /// Rename a tool.
59    ///
60    /// # Arguments
61    /// - `from`: The original tool name
62    /// - `to`: The new tool name
63    #[must_use]
64    pub fn rename(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
65        // name_map is new_name -> old_name
66        self.name_map.insert(to.into(), from.into());
67        self
68    }
69
70    /// Get the inner toolset.
71    #[must_use]
72    pub fn inner(&self) -> &T {
73        &self.inner
74    }
75
76    /// Get the name map.
77    #[must_use]
78    pub fn name_map(&self) -> &HashMap<String, String> {
79        &self.name_map
80    }
81
82    /// Get the original name for a (possibly renamed) tool.
83    fn original_name<'a>(&'a self, new_name: &'a str) -> &'a str {
84        self.name_map
85            .get(new_name)
86            .map(|s| s.as_str())
87            .unwrap_or(new_name)
88    }
89
90    /// Get the new name for an original tool name.
91    fn new_name(&self, original: &str) -> String {
92        // Reverse lookup in name_map
93        for (new, old) in &self.name_map {
94            if old == original {
95                return new.clone();
96            }
97        }
98        original.to_string()
99    }
100}
101
102#[async_trait]
103impl<T, Deps> AbstractToolset<Deps> for RenamedToolset<T, Deps>
104where
105    T: AbstractToolset<Deps>,
106    Deps: Send + Sync,
107{
108    fn id(&self) -> Option<&str> {
109        self.inner.id()
110    }
111
112    fn type_name(&self) -> &'static str {
113        "RenamedToolset"
114    }
115
116    fn label(&self) -> String {
117        format!("RenamedToolset({})", self.inner.label())
118    }
119
120    async fn get_tools(
121        &self,
122        ctx: &RunContext<Deps>,
123    ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
124        let inner_tools = self.inner.get_tools(ctx).await?;
125
126        Ok(inner_tools
127            .into_iter()
128            .map(|(original_name, mut tool)| {
129                let new_name = self.new_name(&original_name);
130                tool.tool_def.name = new_name.clone();
131                (new_name, tool)
132            })
133            .collect())
134    }
135
136    async fn call_tool(
137        &self,
138        name: &str,
139        args: JsonValue,
140        ctx: &RunContext<Deps>,
141        tool: &ToolsetTool,
142    ) -> Result<ToolReturn, ToolError> {
143        let original_name = self.original_name(name);
144
145        // Create a tool with the original name for the inner toolset
146        let mut original_tool = tool.clone();
147        original_tool.tool_def.name = original_name.to_string();
148
149        self.inner
150            .call_tool(original_name, args, ctx, &original_tool)
151            .await
152    }
153
154    async fn enter(&self) -> Result<(), ToolError> {
155        self.inner.enter().await
156    }
157
158    async fn exit(&self) -> Result<(), ToolError> {
159        self.inner.exit().await
160    }
161}
162
163impl<T: std::fmt::Debug, Deps> std::fmt::Debug for RenamedToolset<T, Deps> {
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        f.debug_struct("RenamedToolset")
166            .field("inner", &self.inner)
167            .field("name_map", &self.name_map)
168            .finish()
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::{Tool, ToolDefinition};

    /// Fixture tool registered under the name "search".
    struct SearchTool;

    #[async_trait]
    impl Tool<()> for SearchTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("search", "Search for items")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("search result"))
        }
    }

    /// Fixture tool registered under the name "query".
    struct QueryTool;

    #[async_trait]
    impl Tool<()> for QueryTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("query", "Query the database")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("query result"))
        }
    }

    #[test]
    fn test_original_name() {
        let inner = FunctionToolset::new().tool(SearchTool);
        let renamed = RenamedToolset::new(inner).rename("search", "find");

        // A renamed tool resolves back to its inner name...
        assert_eq!(renamed.original_name("find"), "search");
        // ...while names without a mapping pass through untouched.
        assert_eq!(renamed.original_name("other"), "other");
    }

    #[test]
    fn test_new_name() {
        let inner = FunctionToolset::new().tool(SearchTool);
        let renamed = RenamedToolset::new(inner).rename("search", "find");

        assert_eq!(renamed.new_name("search"), "find");
        assert_eq!(renamed.new_name("other"), "other");
    }

    #[tokio::test]
    async fn test_renamed_toolset_get_tools() {
        let inner = FunctionToolset::new().tool(SearchTool).tool(QueryTool);
        let renamed = RenamedToolset::new(inner).rename("search", "find_items");

        let ctx = RunContext::minimal("test");
        let tools = renamed.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 2);
        for expected in ["find_items", "query"] {
            assert!(tools.contains_key(expected), "missing tool {}", expected);
        }
        assert!(
            !tools.contains_key("search"),
            "original name should be hidden after renaming"
        );
    }

    #[tokio::test]
    async fn test_renamed_toolset_call_tool() {
        let inner = FunctionToolset::new().tool(SearchTool);
        let renamed = RenamedToolset::new(inner).rename("search", "find_items");

        let ctx = RunContext::minimal("test");
        let tools = renamed.get_tools(&ctx).await.unwrap();
        let find_items = &tools["find_items"];

        let output = renamed
            .call_tool("find_items", serde_json::json!({}), &ctx, find_items)
            .await
            .expect("renamed tool should be callable under its new name");

        assert_eq!(output.as_text(), Some("search result"));
    }

    #[tokio::test]
    async fn test_renamed_toolset_multiple_renames() {
        let inner = FunctionToolset::new().tool(SearchTool).tool(QueryTool);
        let renamed = RenamedToolset::new(inner)
            .rename("search", "find")
            .rename("query", "lookup");

        let ctx = RunContext::minimal("test");
        let tools = renamed.get_tools(&ctx).await.unwrap();

        assert!(["find", "lookup"]
            .iter()
            .all(|name| tools.contains_key(*name)));
    }
}