use async_trait::async_trait;
use serde_json::Value as JsonValue;
use serdes_ai_tools::{RunContext, ToolDefinition, ToolError, ToolReturn};
use std::collections::HashMap;
use std::marker::PhantomData;

use crate::{AbstractToolset, ToolsetTool};
/// A toolset wrapper that runs the wrapped toolset's tool definitions through
/// a caller-supplied preparation function before exposing them.
///
/// Type parameters:
/// - `T`: the wrapped toolset.
/// - `F`: the preparation function; the impl blocks require it to take the run
///   context plus the current definitions and return `Option<Vec<ToolDefinition>>`.
/// - `Deps`: the dependency type of the run context (defaults to `()`).
pub struct PreparedToolset<T, F, Deps = ()> {
    /// The wrapped toolset that actually provides and executes the tools.
    inner: T,
    /// The function applied to the wrapped toolset's tool definitions.
    prepare_fn: F,
    /// Binds `Deps` to the type without storing a value of it. Using
    /// `fn() -> Deps` means the wrapper does not own a `Deps`, so its
    /// `Send`/`Sync` auto traits do not depend on `Deps`.
    _phantom: PhantomData<fn() -> Deps>,
}
42impl<T, F, Deps> PreparedToolset<T, F, Deps>
43where
44 T: AbstractToolset<Deps>,
45 F: Fn(&RunContext<Deps>, Vec<ToolDefinition>) -> Option<Vec<ToolDefinition>> + Send + Sync,
46{
47 pub fn new(inner: T, prepare_fn: F) -> Self {
49 Self {
50 inner,
51 prepare_fn,
52 _phantom: PhantomData,
53 }
54 }
55
56 #[must_use]
58 pub fn inner(&self) -> &T {
59 &self.inner
60 }
61}
63#[async_trait]
64impl<T, F, Deps> AbstractToolset<Deps> for PreparedToolset<T, F, Deps>
65where
66 T: AbstractToolset<Deps>,
67 F: Fn(&RunContext<Deps>, Vec<ToolDefinition>) -> Option<Vec<ToolDefinition>> + Send + Sync,
68 Deps: Send + Sync,
69{
70 fn id(&self) -> Option<&str> {
71 self.inner.id()
72 }
73
74 fn type_name(&self) -> &'static str {
75 "PreparedToolset"
76 }
77
78 fn label(&self) -> String {
79 format!("PreparedToolset({})", self.inner.label())
80 }
81
82 async fn get_tools(
83 &self,
84 ctx: &RunContext<Deps>,
85 ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
86 let inner_tools = self.inner.get_tools(ctx).await?;
87
88 let defs: Vec<ToolDefinition> = inner_tools.values().map(|t| t.tool_def.clone()).collect();
90
91 let prepared_defs = match (self.prepare_fn)(ctx, defs) {
93 Some(defs) => defs,
94 None => return Ok(HashMap::new()), };
96
97 let prepared_names: std::collections::HashSet<_> =
99 prepared_defs.iter().map(|d| d.name.clone()).collect();
100
101 let def_map: HashMap<String, ToolDefinition> = prepared_defs
103 .into_iter()
104 .map(|d| (d.name.clone(), d))
105 .collect();
106
107 Ok(inner_tools
108 .into_iter()
109 .filter(|(name, _)| prepared_names.contains(name))
110 .map(|(name, mut tool)| {
111 if let Some(prepared_def) = def_map.get(&name) {
113 tool.tool_def = prepared_def.clone();
114 }
115 (name, tool)
116 })
117 .collect())
118 }
119
120 async fn call_tool(
121 &self,
122 name: &str,
123 args: JsonValue,
124 ctx: &RunContext<Deps>,
125 tool: &ToolsetTool,
126 ) -> Result<ToolReturn, ToolError> {
127 self.inner.call_tool(name, args, ctx, tool).await
128 }
129
130 async fn enter(&self) -> Result<(), ToolError> {
131 self.inner.enter().await
132 }
133
134 async fn exit(&self) -> Result<(), ToolError> {
135 self.inner.exit().await
136 }
137}
139impl<T: std::fmt::Debug, F, Deps> std::fmt::Debug for PreparedToolset<T, F, Deps> {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 f.debug_struct("PreparedToolset")
142 .field("inner", &self.inner)
143 .finish()
144 }
145}
147pub mod preparers {
149 use serdes_ai_tools::{RunContext, ToolDefinition};
150
151 pub fn add_description_suffix<Deps>(
153 suffix: &str,
154 ) -> impl Fn(&RunContext<Deps>, Vec<ToolDefinition>) -> Option<Vec<ToolDefinition>> + Send + Sync + '_
155 {
156 move |_, defs| {
157 Some(
158 defs.into_iter()
159 .map(|mut d| {
160 d.description = format!("{} {}", d.description, suffix);
161 d
162 })
163 .collect(),
164 )
165 }
166 }
167
168 pub fn filter<Deps, F>(
170 pred: F,
171 ) -> impl Fn(&RunContext<Deps>, Vec<ToolDefinition>) -> Option<Vec<ToolDefinition>> + Send + Sync
172 where
173 F: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
174 {
175 move |ctx, defs| Some(defs.into_iter().filter(|d| pred(ctx, d)).collect())
176 }
177
178 pub fn sort_by_name<Deps>(
180 ) -> impl Fn(&RunContext<Deps>, Vec<ToolDefinition>) -> Option<Vec<ToolDefinition>> + Send + Sync
181 {
182 |_, mut defs| {
183 defs.sort_by(|a, b| a.name.cmp(&b.name));
184 Some(defs)
185 }
186 }
187
188 pub fn limit<Deps>(
190 max: usize,
191 ) -> impl Fn(&RunContext<Deps>, Vec<ToolDefinition>) -> Option<Vec<ToolDefinition>> + Send + Sync
192 {
193 move |_, defs| Some(defs.into_iter().take(max).collect())
194 }
195}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::Tool;

    /// Minimal tool named `tool_a` that always returns the text `"A"`.
    struct ToolA;

    #[async_trait]
    impl Tool<()> for ToolA {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_a", "Tool A")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("A"))
        }
    }

    /// Tool with an `admin_` name prefix, used to exercise filtering.
    struct AdminTool;

    #[async_trait]
    impl Tool<()> for AdminTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("admin_delete", "Admin delete")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("Deleted"))
        }
    }

    /// Definitions dropped by the preparation function are not exposed.
    #[tokio::test]
    async fn test_prepared_toolset_filter() {
        let toolset = FunctionToolset::new().tool(ToolA).tool(AdminTool);

        let prepared = PreparedToolset::new(toolset, |_, defs| {
            Some(
                defs.into_iter()
                    .filter(|d| !d.name.starts_with("admin_"))
                    .collect(),
            )
        });

        let ctx = RunContext::minimal("test");
        let tools = prepared.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("tool_a"));
        assert!(!tools.contains_key("admin_delete"));
    }

    /// Changes made to a definition show up on the exposed tool.
    #[tokio::test]
    async fn test_prepared_toolset_modify_description() {
        let toolset = FunctionToolset::new().tool(ToolA);

        let prepared = PreparedToolset::new(toolset, |_, defs| {
            Some(
                defs.into_iter()
                    .map(|mut d| {
                        d.description = format!("[MODIFIED] {}", d.description);
                        d
                    })
                    .collect(),
            )
        });

        let ctx = RunContext::minimal("test");
        let tools = prepared.get_tools(&ctx).await.unwrap();

        let tool = tools.get("tool_a").unwrap();
        assert!(tool.tool_def.description.starts_with("[MODIFIED]"));
    }

    /// A preparation function returning `None` exposes no tools.
    #[tokio::test]
    async fn test_prepared_toolset_returns_none() {
        let toolset = FunctionToolset::new().tool(ToolA);

        let prepared = PreparedToolset::new(toolset, |_, _| None);

        let ctx = RunContext::minimal("test");
        let tools = prepared.get_tools(&ctx).await.unwrap();

        assert!(tools.is_empty());
    }

    /// Calls through the wrapper still reach the wrapped toolset.
    #[tokio::test]
    async fn test_prepared_toolset_call_still_works() {
        let toolset = FunctionToolset::new().tool(ToolA);

        let prepared = PreparedToolset::new(toolset, |_, defs| Some(defs));

        let ctx = RunContext::minimal("test");
        let tools = prepared.get_tools(&ctx).await.unwrap();
        let tool = tools.get("tool_a").unwrap();

        let result = prepared
            .call_tool("tool_a", serde_json::json!({}), &ctx, tool)
            .await
            .unwrap();

        assert_eq!(result.as_text(), Some("A"));
    }
}