Skip to main content

serdes_ai_toolsets/
approval.rs

1//! Approval-required toolset implementation.
2//!
3//! This module provides `ApprovalRequiredToolset`, which requires approval
4//! before executing any tool.
5
6use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolDefinition, ToolError, ToolReturn};
9use std::collections::HashMap;
10use std::marker::PhantomData;
11use std::sync::Arc;
12
13use crate::{AbstractToolset, ToolsetTool};
14
15/// Type alias for approval checker functions.
16pub type ApprovalChecker<Deps> =
17    dyn Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync;
18
19/// Requires approval for tool calls.
20///
21/// When a tool is called, if approval is required, the toolset returns
22/// `ToolError::ApprovalRequired` instead of executing the tool.
23///
24/// # Example
25///
26/// ```ignore
27/// use serdes_ai_toolsets::{ApprovalRequiredToolset, FunctionToolset};
28///
29/// let toolset = FunctionToolset::new().tool(dangerous_tool);
30///
31/// // Require approval for all tools
32/// let approved = ApprovalRequiredToolset::new(toolset);
33///
34/// // Or with a custom checker
35/// let approved = ApprovalRequiredToolset::with_checker(toolset, |ctx, def, args| {
36///     def.name.contains("delete") || def.name.contains("modify")
37/// });
38/// ```
39pub struct ApprovalRequiredToolset<T, Deps = ()> {
40    inner: T,
41    approval_checker: Arc<ApprovalChecker<Deps>>,
42    _phantom: PhantomData<fn() -> Deps>,
43}
44
45impl<T, Deps> ApprovalRequiredToolset<T, Deps>
46where
47    T: AbstractToolset<Deps>,
48{
49    /// Create a toolset that requires approval for ALL tool calls.
50    pub fn new(inner: T) -> Self {
51        Self {
52            inner,
53            approval_checker: Arc::new(|_, _, _| true), // Always require approval
54            _phantom: PhantomData,
55        }
56    }
57
58    /// Create a toolset with a custom approval checker.
59    ///
60    /// The checker returns `true` if approval is required for the given
61    /// tool call, `false` if the call can proceed without approval.
62    pub fn with_checker<F>(inner: T, checker: F) -> Self
63    where
64        F: Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync + 'static,
65    {
66        Self {
67            inner,
68            approval_checker: Arc::new(checker),
69            _phantom: PhantomData,
70        }
71    }
72
73    /// Get the inner toolset.
74    #[must_use]
75    pub fn inner(&self) -> &T {
76        &self.inner
77    }
78}
79
80#[async_trait]
81impl<T, Deps> AbstractToolset<Deps> for ApprovalRequiredToolset<T, Deps>
82where
83    T: AbstractToolset<Deps>,
84    Deps: Send + Sync,
85{
86    fn id(&self) -> Option<&str> {
87        self.inner.id()
88    }
89
90    fn type_name(&self) -> &'static str {
91        "ApprovalRequiredToolset"
92    }
93
94    fn label(&self) -> String {
95        format!("ApprovalRequiredToolset({})", self.inner.label())
96    }
97
98    async fn get_tools(
99        &self,
100        ctx: &RunContext<Deps>,
101    ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
102        self.inner.get_tools(ctx).await
103    }
104
105    async fn call_tool(
106        &self,
107        name: &str,
108        args: JsonValue,
109        ctx: &RunContext<Deps>,
110        tool: &ToolsetTool,
111    ) -> Result<ToolReturn, ToolError> {
112        // Check if approval is required
113        if (self.approval_checker)(ctx, &tool.tool_def, &args) {
114            return Err(ToolError::ApprovalRequired {
115                tool_name: name.to_string(),
116                args,
117            });
118        }
119
120        // No approval needed, proceed with the call
121        self.inner.call_tool(name, args, ctx, tool).await
122    }
123
124    async fn enter(&self) -> Result<(), ToolError> {
125        self.inner.enter().await
126    }
127
128    async fn exit(&self) -> Result<(), ToolError> {
129        self.inner.exit().await
130    }
131}
132
133impl<T: std::fmt::Debug, Deps> std::fmt::Debug for ApprovalRequiredToolset<T, Deps> {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        f.debug_struct("ApprovalRequiredToolset")
136            .field("inner", &self.inner)
137            .finish()
138    }
139}
140
141/// Common approval checkers.
142pub mod checkers {
143    use serde_json::Value as JsonValue;
144    use serdes_ai_tools::{RunContext, ToolDefinition};
145
146    /// Always require approval.
147    pub fn always<Deps>(
148    ) -> impl Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync {
149        |_, _, _| true
150    }
151
152    /// Never require approval.
153    pub fn never<Deps>(
154    ) -> impl Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync {
155        |_, _, _| false
156    }
157
158    /// Require approval for tools with names containing any of the given substrings.
159    pub fn name_contains<Deps>(
160        substrings: Vec<String>,
161    ) -> impl Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync {
162        move |_, def, _| substrings.iter().any(|s| def.name.contains(s.as_str()))
163    }
164
165    /// Require approval for tools with names in the given list.
166    pub fn tool_names<Deps>(
167        names: Vec<String>,
168    ) -> impl Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync {
169        move |_, def, _| names.iter().any(|n| n == &def.name)
170    }
171
172    /// Require approval for tools with names matching a prefix.
173    pub fn name_prefix<Deps>(
174        prefix: String,
175    ) -> impl Fn(&RunContext<Deps>, &ToolDefinition, &JsonValue) -> bool + Send + Sync {
176        move |_, def, _| def.name.starts_with(&prefix)
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::Tool;

    /// Harmless tool named `safe_read`; selective checkers should let it through.
    struct SafeTool;

    #[async_trait]
    impl Tool<()> for SafeTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("safe_read", "Safe read operation")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("read data"))
        }
    }

    /// Destructive-sounding tool named `delete_all`; selective checkers should gate it.
    struct DangerousTool;

    #[async_trait]
    impl Tool<()> for DangerousTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("delete_all", "Delete all data")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("deleted"))
        }
    }

    /// True when `result` is the approval-required rejection.
    fn needs_approval(result: &Result<ToolReturn, ToolError>) -> bool {
        matches!(result, Err(ToolError::ApprovalRequired { .. }))
    }

    #[tokio::test]
    async fn test_approval_required_all() {
        let toolset = FunctionToolset::new().tool(SafeTool);
        let approved = ApprovalRequiredToolset::new(toolset);

        let ctx = RunContext::minimal("test");
        let tools = approved.get_tools(&ctx).await.unwrap();
        let tool = tools.get("safe_read").unwrap();

        let result = approved
            .call_tool("safe_read", serde_json::json!({}), &ctx, tool)
            .await;

        assert!(needs_approval(&result));
    }

    #[tokio::test]
    async fn test_approval_error_carries_call_details() {
        let toolset = FunctionToolset::new().tool(SafeTool);
        let approved = ApprovalRequiredToolset::new(toolset);

        let ctx = RunContext::minimal("test");
        let tools = approved.get_tools(&ctx).await.unwrap();
        let tool = tools.get("safe_read").unwrap();

        let args = serde_json::json!({ "path": "/tmp/data" });
        let result = approved
            .call_tool("safe_read", args.clone(), &ctx, tool)
            .await;

        // The rejection must hand back exactly the name and arguments of the call.
        if let Err(ToolError::ApprovalRequired {
            tool_name,
            args: returned,
        }) = result
        {
            assert_eq!(tool_name, "safe_read");
            assert_eq!(returned, args);
        } else {
            panic!("expected ToolError::ApprovalRequired");
        }
    }

    #[tokio::test]
    async fn test_approval_required_selective() {
        let toolset = FunctionToolset::new().tool(SafeTool).tool(DangerousTool);

        // Only require approval for delete operations
        let approved =
            ApprovalRequiredToolset::with_checker(toolset, |_, def, _| def.name.contains("delete"));

        let ctx = RunContext::minimal("test");
        let tools = approved.get_tools(&ctx).await.unwrap();

        // Safe tool should work
        let safe_tool = tools.get("safe_read").unwrap();
        let result = approved
            .call_tool("safe_read", serde_json::json!({}), &ctx, safe_tool)
            .await;
        assert!(result.is_ok());

        // Dangerous tool should require approval
        let dangerous_tool = tools.get("delete_all").unwrap();
        let result = approved
            .call_tool("delete_all", serde_json::json!({}), &ctx, dangerous_tool)
            .await;
        assert!(needs_approval(&result));
    }

    #[tokio::test]
    async fn test_approval_always() {
        let toolset = FunctionToolset::new().tool(SafeTool);
        let approved = ApprovalRequiredToolset::with_checker(toolset, checkers::always());

        let ctx = RunContext::minimal("test");
        let tools = approved.get_tools(&ctx).await.unwrap();
        let tool = tools.get("safe_read").unwrap();

        let result = approved
            .call_tool("safe_read", serde_json::json!({}), &ctx, tool)
            .await;

        assert!(needs_approval(&result));
    }

    #[tokio::test]
    async fn test_approval_never() {
        let toolset = FunctionToolset::new().tool(SafeTool);
        let approved = ApprovalRequiredToolset::with_checker(toolset, checkers::never());

        let ctx = RunContext::minimal("test");
        let tools = approved.get_tools(&ctx).await.unwrap();
        let tool = tools.get("safe_read").unwrap();

        let result = approved
            .call_tool("safe_read", serde_json::json!({}), &ctx, tool)
            .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_approval_checker_name_contains() {
        let toolset = FunctionToolset::new().tool(SafeTool).tool(DangerousTool);
        let approved = ApprovalRequiredToolset::with_checker(
            toolset,
            checkers::name_contains(vec!["delete".to_string(), "remove".to_string()]),
        );

        let ctx = RunContext::minimal("test");
        let tools = approved.get_tools(&ctx).await.unwrap();

        let safe = tools.get("safe_read").unwrap();
        let result = approved
            .call_tool("safe_read", serde_json::json!({}), &ctx, safe)
            .await;
        assert!(result.is_ok());

        let dangerous = tools.get("delete_all").unwrap();
        let result = approved
            .call_tool("delete_all", serde_json::json!({}), &ctx, dangerous)
            .await;
        assert!(needs_approval(&result));
    }

    #[tokio::test]
    async fn test_approval_checker_tool_names() {
        let toolset = FunctionToolset::new().tool(SafeTool).tool(DangerousTool);
        let approved = ApprovalRequiredToolset::with_checker(
            toolset,
            checkers::tool_names(vec!["delete_all".to_string()]),
        );

        let ctx = RunContext::minimal("test");
        let tools = approved.get_tools(&ctx).await.unwrap();

        let safe = tools.get("safe_read").unwrap();
        let result = approved
            .call_tool("safe_read", serde_json::json!({}), &ctx, safe)
            .await;
        assert!(result.is_ok());

        let dangerous = tools.get("delete_all").unwrap();
        let result = approved
            .call_tool("delete_all", serde_json::json!({}), &ctx, dangerous)
            .await;
        assert!(needs_approval(&result));
    }

    #[tokio::test]
    async fn test_approval_checker_name_prefix() {
        let toolset = FunctionToolset::new().tool(SafeTool).tool(DangerousTool);
        let approved = ApprovalRequiredToolset::with_checker(
            toolset,
            checkers::name_prefix("delete_".to_string()),
        );

        let ctx = RunContext::minimal("test");
        let tools = approved.get_tools(&ctx).await.unwrap();

        let safe = tools.get("safe_read").unwrap();
        let result = approved
            .call_tool("safe_read", serde_json::json!({}), &ctx, safe)
            .await;
        assert!(result.is_ok());

        let dangerous = tools.get("delete_all").unwrap();
        let result = approved
            .call_tool("delete_all", serde_json::json!({}), &ctx, dangerous)
            .await;
        assert!(needs_approval(&result));
    }

    #[test]
    fn test_metadata_wraps_inner() {
        let toolset = FunctionToolset::new().tool(SafeTool);
        let approved = ApprovalRequiredToolset::new(toolset);

        assert_eq!(approved.type_name(), "ApprovalRequiredToolset");
        let label = approved.label();
        assert!(label.starts_with("ApprovalRequiredToolset("));
        assert!(label.ends_with(')'));
    }
}