1use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolDefinition, ToolError, ToolReturn};
9use std::collections::HashMap;
10use std::marker::PhantomData;
11use std::sync::Arc;
12
13use crate::{AbstractToolset, ToolsetTool};
14
15pub type ApprovalChecker<Deps> =
17 dyn Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync;
18
19pub struct ApprovalRequiredToolset<T, Deps = ()> {
40 inner: T,
41 approval_checker: Arc<ApprovalChecker<Deps>>,
42 _phantom: PhantomData<fn() -> Deps>,
43}
44
45impl<T, Deps> ApprovalRequiredToolset<T, Deps>
46where
47 T: AbstractToolset<Deps>,
48{
49 pub fn new(inner: T) -> Self {
51 Self {
52 inner,
53 approval_checker: Arc::new(|_, _, _| true), _phantom: PhantomData,
55 }
56 }
57
58 pub fn with_checker<F>(inner: T, checker: F) -> Self
63 where
64 F: Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync + 'static,
65 {
66 Self {
67 inner,
68 approval_checker: Arc::new(checker),
69 _phantom: PhantomData,
70 }
71 }
72
73 #[must_use]
75 pub fn inner(&self) -> &T {
76 &self.inner
77 }
78}
79
80#[async_trait]
81impl<T, Deps> AbstractToolset<Deps> for ApprovalRequiredToolset<T, Deps>
82where
83 T: AbstractToolset<Deps>,
84 Deps: Send + Sync,
85{
86 fn id(&self) -> Option<&str> {
87 self.inner.id()
88 }
89
90 fn type_name(&self) -> &'static str {
91 "ApprovalRequiredToolset"
92 }
93
94 fn label(&self) -> String {
95 format!("ApprovalRequiredToolset({})", self.inner.label())
96 }
97
98 async fn get_tools(
99 &self,
100 ctx: &RunContext<Deps>,
101 ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
102 self.inner.get_tools(ctx).await
103 }
104
105 async fn call_tool(
106 &self,
107 name: &str,
108 args: JsonValue,
109 ctx: &RunContext<Deps>,
110 tool: &ToolsetTool,
111 ) -> Result<ToolReturn, ToolError> {
112 if (self.approval_checker)(ctx, &tool.tool_def, &args) {
114 return Err(ToolError::ApprovalRequired {
115 tool_name: name.to_string(),
116 args,
117 });
118 }
119
120 self.inner.call_tool(name, args, ctx, tool).await
122 }
123
124 async fn enter(&self) -> Result<(), ToolError> {
125 self.inner.enter().await
126 }
127
128 async fn exit(&self) -> Result<(), ToolError> {
129 self.inner.exit().await
130 }
131}
132
133impl<T: std::fmt::Debug, Deps> std::fmt::Debug for ApprovalRequiredToolset<T, Deps> {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 f.debug_struct("ApprovalRequiredToolset")
136 .field("inner", &self.inner)
137 .finish()
138 }
139}
140
141pub mod checkers {
143 use serde_json::Value as JsonValue;
144 use serdes_ai_tools::{RunContext, ToolDefinition};
145
146 pub fn always<Deps>(
148 ) -> impl Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync {
149 |_, _, _| true
150 }
151
152 pub fn never<Deps>(
154 ) -> impl Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync {
155 |_, _, _| false
156 }
157
158 pub fn name_contains<Deps>(
160 substrings: Vec<String>,
161 ) -> impl Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync {
162 move |_, def, _| substrings.iter().any(|s| def.name.contains(s.as_str()))
163 }
164
165 pub fn tool_names<Deps>(
167 names: Vec<String>,
168 ) -> impl Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync {
169 move |_, def, _| names.iter().any(|n| n == &def.name)
170 }
171
172 pub fn name_prefix<Deps>(
174 prefix: String,
175 ) -> impl Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync {
176 move |_, def, _| def.name.starts_with(&prefix)
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::Tool;

    /// Harmless tool used to check that non-matching calls pass through.
    struct SafeTool;

    #[async_trait]
    impl Tool<()> for SafeTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("safe_read", "Safe read operation")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("read data"))
        }
    }

    /// Tool whose name the selective checkers are expected to flag.
    struct DangerousTool;

    #[async_trait]
    impl Tool<()> for DangerousTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("delete_all", "Delete all data")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("deleted"))
        }
    }

    #[tokio::test]
    async fn test_approval_required_all() {
        let toolset = FunctionToolset::new().tool(SafeTool);
        let approved = ApprovalRequiredToolset::new(toolset);

        let ctx = RunContext::minimal("test");
        let tools = approved.get_tools(&ctx).await.unwrap();
        let tool = tools.get("safe_read").unwrap();

        let result = approved
            .call_tool("safe_read", serde_json::json!({}), &ctx, tool)
            .await;

        assert!(matches!(result, Err(ToolError::ApprovalRequired { .. })));
    }

    #[tokio::test]
    async fn test_approval_required_selective() {
        let toolset = FunctionToolset::new().tool(SafeTool).tool(DangerousTool);

        let approved =
            ApprovalRequiredToolset::with_checker(toolset, |_, def, _| def.name.contains("delete"));

        let ctx = RunContext::minimal("test");
        let tools = approved.get_tools(&ctx).await.unwrap();

        let safe_tool = tools.get("safe_read").unwrap();
        let result = approved
            .call_tool("safe_read", serde_json::json!({}), &ctx, safe_tool)
            .await;
        assert!(result.is_ok());

        let dangerous_tool = tools.get("delete_all").unwrap();
        let result = approved
            .call_tool("delete_all", serde_json::json!({}), &ctx, dangerous_tool)
            .await;
        assert!(matches!(result, Err(ToolError::ApprovalRequired { .. })));
    }

    #[tokio::test]
    async fn test_approval_never() {
        let toolset = FunctionToolset::new().tool(SafeTool);
        let approved = ApprovalRequiredToolset::with_checker(toolset, checkers::never());

        let ctx = RunContext::minimal("test");
        let tools = approved.get_tools(&ctx).await.unwrap();
        let tool = tools.get("safe_read").unwrap();

        let result = approved
            .call_tool("safe_read", serde_json::json!({}), &ctx, tool)
            .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_approval_checker_name_contains() {
        let toolset = FunctionToolset::new().tool(SafeTool).tool(DangerousTool);
        let approved = ApprovalRequiredToolset::with_checker(
            toolset,
            checkers::name_contains(vec!["delete".to_string(), "remove".to_string()]),
        );

        let ctx = RunContext::minimal("test");
        let tools = approved.get_tools(&ctx).await.unwrap();

        let dangerous = tools.get("delete_all").unwrap();
        let result = approved
            .call_tool("delete_all", serde_json::json!({}), &ctx, dangerous)
            .await;
        assert!(matches!(result, Err(ToolError::ApprovalRequired { .. })));

        // A tool whose name matches none of the substrings must reach the inner toolset.
        let safe = tools.get("safe_read").unwrap();
        let result = approved
            .call_tool("safe_read", serde_json::json!({}), &ctx, safe)
            .await;
        assert!(result.is_ok());
    }

    /// Exercises every predicate in `checkers` directly, without a toolset.
    #[test]
    fn test_checker_helpers() {
        let ctx = RunContext::minimal("test");
        let args = serde_json::json!({});
        let safe = ToolDefinition::new("safe_read", "Safe read operation");
        let dangerous = ToolDefinition::new("delete_all", "Delete all data");

        let always = checkers::always::<()>();
        assert!(always(&ctx, &safe, &args));
        assert!(always(&ctx, &dangerous, &args));

        let never = checkers::never::<()>();
        assert!(!never(&ctx, &safe, &args));
        assert!(!never(&ctx, &dangerous, &args));

        let by_substring = checkers::name_contains::<()>(vec!["delete".to_string()]);
        assert!(by_substring(&ctx, &dangerous, &args));
        assert!(!by_substring(&ctx, &safe, &args));

        let by_name = checkers::tool_names::<()>(vec!["delete_all".to_string()]);
        assert!(by_name(&ctx, &dangerous, &args));
        assert!(!by_name(&ctx, &safe, &args));

        let by_prefix = checkers::name_prefix::<()>("delete_".to_string());
        assert!(by_prefix(&ctx, &dangerous, &args));
        assert!(!by_prefix(&ctx, &safe, &args));
    }

    #[test]
    fn test_type_name_and_label() {
        let approved: ApprovalRequiredToolset<_, ()> =
            ApprovalRequiredToolset::new(FunctionToolset::new().tool(SafeTool));

        assert_eq!(approved.type_name(), "ApprovalRequiredToolset");
        let label = approved.label();
        assert!(label.starts_with("ApprovalRequiredToolset("));
        assert!(label.ends_with(')'));
    }
}