serdes_ai_toolsets/
prefixed.rs1use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolError, ToolReturn};
9use std::collections::HashMap;
10use std::marker::PhantomData;
11
12use crate::{AbstractToolset, ToolsetTool};
13
/// A toolset wrapper that namespaces every tool of `inner` under `prefix`.
///
/// Tools are exposed to callers as `{prefix}{separator}{name}`, where `name`
/// is the name reported by the wrapped toolset.
pub struct PrefixedToolset<T, Deps = ()> {
    /// Leading namespace prepended to every tool name.
    prefix: String,
    /// Text inserted between `prefix` and the original tool name.
    separator: String,
    /// The wrapped toolset whose tools are being renamed.
    inner: T,
    /// Carries the `Deps` type parameter without storing a value; the
    /// `fn() -> Deps` form keeps the wrapper's auto traits independent of `Deps`.
    _phantom: PhantomData<fn() -> Deps>,
}
39
40impl<T, Deps> PrefixedToolset<T, Deps>
41where
42 T: AbstractToolset<Deps>,
43{
44 pub fn new(inner: T, prefix: impl Into<String>) -> Self {
46 Self {
47 inner,
48 prefix: prefix.into(),
49 separator: "_".to_string(),
50 _phantom: PhantomData,
51 }
52 }
53
54 #[must_use]
56 pub fn with_separator(mut self, sep: impl Into<String>) -> Self {
57 self.separator = sep.into();
58 self
59 }
60
61 #[must_use]
63 pub fn prefix(&self) -> &str {
64 &self.prefix
65 }
66
67 #[must_use]
69 pub fn separator(&self) -> &str {
70 &self.separator
71 }
72
73 #[must_use]
75 pub fn inner(&self) -> &T {
76 &self.inner
77 }
78
79 fn prefixed_name(&self, name: &str) -> String {
81 format!("{}{}{}", self.prefix, self.separator, name)
82 }
83
84 fn strip_prefix(&self, prefixed: &str) -> Option<String> {
86 let prefix_with_sep = format!("{}{}", self.prefix, self.separator);
87 prefixed
88 .strip_prefix(&prefix_with_sep)
89 .map(|s| s.to_string())
90 }
91}
92
93#[async_trait]
94impl<T, Deps> AbstractToolset<Deps> for PrefixedToolset<T, Deps>
95where
96 T: AbstractToolset<Deps>,
97 Deps: Send + Sync,
98{
99 fn id(&self) -> Option<&str> {
100 self.inner.id()
101 }
102
103 fn type_name(&self) -> &'static str {
104 "PrefixedToolset"
105 }
106
107 fn label(&self) -> String {
108 format!(
109 "PrefixedToolset('{}{}', {})",
110 self.prefix,
111 self.separator,
112 self.inner.label()
113 )
114 }
115
116 async fn get_tools(
117 &self,
118 ctx: &RunContext<Deps>,
119 ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
120 let inner_tools = self.inner.get_tools(ctx).await?;
121
122 Ok(inner_tools
123 .into_iter()
124 .map(|(name, mut tool)| {
125 let prefixed = self.prefixed_name(&name);
126 tool.tool_def.name = prefixed.clone();
128 (prefixed, tool)
129 })
130 .collect())
131 }
132
133 async fn call_tool(
134 &self,
135 name: &str,
136 args: JsonValue,
137 ctx: &RunContext<Deps>,
138 tool: &ToolsetTool,
139 ) -> Result<ToolReturn, ToolError> {
140 let original_name = self.strip_prefix(name).ok_or_else(|| {
142 ToolError::not_found(format!(
143 "Tool '{}' does not have expected prefix '{}{}'",
144 name, self.prefix, self.separator
145 ))
146 })?;
147
148 let mut original_tool = tool.clone();
150 original_tool.tool_def.name = original_name.clone();
151
152 self.inner
153 .call_tool(&original_name, args, ctx, &original_tool)
154 .await
155 }
156
157 async fn enter(&self) -> Result<(), ToolError> {
158 self.inner.enter().await
159 }
160
161 async fn exit(&self) -> Result<(), ToolError> {
162 self.inner.exit().await
163 }
164}
165
166impl<T: std::fmt::Debug, Deps> std::fmt::Debug for PrefixedToolset<T, Deps> {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 f.debug_struct("PrefixedToolset")
169 .field("prefix", &self.prefix)
170 .field("separator", &self.separator)
171 .field("inner", &self.inner)
172 .finish()
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::{Tool, ToolDefinition};

    /// Minimal tool named `search` that echoes its `query` argument
    /// (or `*` when the argument is missing).
    struct SearchTool;

    #[async_trait]
    impl Tool<()> for SearchTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("search", "Search for items")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            let query = args["query"].as_str().unwrap_or("*");
            Ok(ToolReturn::text(format!("Searching for: {}", query)))
        }
    }

    #[test]
    fn test_prefixed_name() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web");

        assert_eq!(prefixed.prefixed_name("search"), "web_search");
    }

    #[test]
    fn test_strip_prefix() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web");

        assert_eq!(
            prefixed.strip_prefix("web_search"),
            Some("search".to_string())
        );
        assert_eq!(prefixed.strip_prefix("local_search"), None);
    }

    #[test]
    fn test_strip_prefix_requires_separator() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web");

        // The prefix alone is not enough: the separator must follow it directly.
        assert_eq!(prefixed.strip_prefix("websearch"), None);
        assert_eq!(prefixed.strip_prefix("web"), None);
        // Prefix plus separator with nothing after it yields an empty inner name.
        assert_eq!(prefixed.strip_prefix("web_"), Some(String::new()));
    }

    #[test]
    fn test_custom_separator() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web").with_separator("::");

        assert_eq!(prefixed.prefixed_name("search"), "web::search");
    }

    #[test]
    fn test_custom_separator_strip_prefix() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web").with_separator("::");

        assert_eq!(
            prefixed.strip_prefix("web::search"),
            Some("search".to_string())
        );
        // The default separator must no longer be accepted.
        assert_eq!(prefixed.strip_prefix("web_search"), None);
        // A partial separator match is rejected.
        assert_eq!(prefixed.strip_prefix("web:search"), None);
    }

    #[tokio::test]
    async fn test_prefixed_toolset_get_tools() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web");

        let ctx = RunContext::minimal("test");
        let tools = prefixed.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("web_search"));
        assert!(!tools.contains_key("search"));

        let tool = tools.get("web_search").unwrap();
        assert_eq!(tool.tool_def.name, "web_search");
    }

    #[tokio::test]
    async fn test_prefixed_toolset_call_tool() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web");

        let ctx = RunContext::minimal("test");
        let tools = prefixed.get_tools(&ctx).await.unwrap();
        let tool = tools.get("web_search").unwrap();

        let result = prefixed
            .call_tool(
                "web_search",
                serde_json::json!({"query": "rust"}),
                &ctx,
                tool,
            )
            .await
            .unwrap();

        assert!(result.as_text().unwrap().contains("rust"));
    }

    #[tokio::test]
    async fn test_custom_separator_call_tool() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web").with_separator("::");

        let ctx = RunContext::minimal("test");
        let tools = prefixed.get_tools(&ctx).await.unwrap();
        let tool = tools.get("web::search").unwrap();
        assert_eq!(tool.tool_def.name, "web::search");

        let result = prefixed
            .call_tool(
                "web::search",
                serde_json::json!({"query": "rust"}),
                &ctx,
                tool,
            )
            .await
            .unwrap();

        assert!(result.as_text().unwrap().contains("rust"));
    }

    #[tokio::test]
    async fn test_prefixed_toolset_wrong_prefix() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web");

        let ctx = RunContext::minimal("test");
        let fake_tool = ToolsetTool::new(ToolDefinition::new("local_search", "Local search"));

        let result = prefixed
            .call_tool("local_search", serde_json::json!({}), &ctx, &fake_tool)
            .await;

        assert!(result.is_err());
    }
}