Skip to main content

ai_agent/utils/hooks/
can_use_tool.rs

1//! CanUseToolFn type and related types for tool permission checking.
2//!
3//! This module provides the core function type for checking whether a tool
4//! can be used, along with supporting types.
5
6use crate::permission::PermissionDecision;
7use crate::types::ToolDefinition;
8use crate::utils::messages::{AssistantMessage, AssistantMessageContent};
9use serde::{Deserialize, Serialize};
10
11/// Context for tool use, containing information about the current execution context.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13#[serde(rename_all = "camelCase")]
14pub struct ToolUseContext {
15    /// Session ID
16    pub session_id: String,
17    /// Current working directory
18    pub cwd: Option<String>,
19    /// Whether this is a non-interactive session
20    pub is_non_interactive_session: bool,
21    /// Additional options
22    #[serde(default, skip_serializing_if = "Option::is_none")]
23    pub options: Option<ToolUseContextOptions>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27#[serde(rename_all = "camelCase")]
28pub struct ToolUseContextOptions {
29    /// Available tools
30    #[serde(default, skip_serializing_if = "Option::is_none")]
31    pub tools: Option<Vec<ToolDefinition>>,
32}
33
34/// Options for permission checking
35#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(rename_all = "camelCase")]
37pub struct ToolPermissionContext {
38    /// Permission mode
39    pub mode: crate::permission::PermissionMode,
40    /// Whether to wait for automated checks before showing dialog
41    #[serde(default, skip_serializing_if = "Option::is_none")]
42    pub await_automated_checks_before_dialog: Option<bool>,
43}
44
45/// Function type for checking if a tool can be used.
46///
47/// This is the core permission check function that determines whether
48/// a tool can be executed based on the current permission settings.
49///
50/// # Type Parameters
51/// * `Input` - The tool input type (defaults to a map of string to unknown)
52///
53/// # Arguments
54/// * `tool` - The tool definition
55/// * `input` - The input arguments for the tool
56/// * `tool_use_context` - Context about the current tool use
57/// * `assistant_message` - The assistant message that triggered this tool use
58/// * `tool_use_id` - Unique identifier for this tool use
59/// * `force_decision` - Optional forced decision (bypasses normal permission checking)
60///
61/// # Returns
62/// A future that resolves to a permission decision
63pub type CanUseToolFn<Input = std::collections::HashMap<String, serde_json::Value>> = Box<
64    dyn Fn(
65            ToolDefinition,
66            Input,
67            ToolUseContext,
68            AssistantMessage,
69            String,
70            Option<PermissionDecision>,
71        ) -> std::pin::Pin<
72            Box<dyn std::future::Future<Output = PermissionDecision> + Send + 'static>,
73        > + Send
74        + Sync,
75>;
76
77/// Simplified CanUseToolFn that works with JSON values
78pub type CanUseToolFnJson = Box<
79    dyn Fn(
80            ToolDefinition,
81            serde_json::Value,
82            ToolUseContext,
83            AssistantMessage,
84            String,
85            Option<PermissionDecision>,
86        ) -> std::pin::Pin<
87            Box<dyn std::future::Future<Output = PermissionDecision> + Send + 'static>,
88        > + Send
89        + Sync,
90>;
91
/// Human-readable sketch of the [`CanUseToolFn`] signature, for documentation or help text.
///
/// This is a simplified rendering. The real alias is a boxed `Fn(...) + Send + Sync`
/// whose return type is
/// `Pin<Box<dyn Future<Output = PermissionDecision> + Send + 'static>>`.
/// The string starts and ends with a newline.
pub const CAN_USE_TOOL_FN_SIGNATURE: &str = r#"
CanUseToolFn<Input> = Fn(
    tool: ToolDefinition,
    input: Input,
    tool_use_context: ToolUseContext,
    assistant_message: AssistantMessage,
    tool_use_id: String,
    force_decision: Option<PermissionDecision>,
) -> impl Future<Output = PermissionDecision>
"#;
103
104/// Helper to create a default CanUseToolFn that uses the permission module
105pub fn create_default_can_use_tool_fn(
106    permission_context: ToolPermissionContext,
107) -> CanUseToolFnJson {
108    Box::new(
109        move |tool: ToolDefinition,
110              input: serde_json::Value,
111              _tool_use_context: ToolUseContext,
112              _assistant_message: AssistantMessage,
113              _tool_use_id: String,
114              force_decision: Option<PermissionDecision>| {
115            let ctx =
116                crate::permission::PermissionContext::new().with_mode(permission_context.mode);
117
118            Box::pin(async move {
119                // If force_decision is provided, use it directly
120                if let Some(decision) = force_decision {
121                    return decision;
122                }
123
124                // Check using permission context
125                let result = ctx.check_tool(&tool.name, Some(&input));
126
127                // Convert result to decision
128                match result {
129                    crate::permission::PermissionResult::Allow(allow) => {
130                        PermissionDecision::Allow(crate::permission::PermissionAllowDecision {
131                            behavior: allow.behavior,
132                            updated_input: allow.updated_input,
133                            user_modified: allow.user_modified,
134                            decision_reason: allow.decision_reason,
135                        })
136                    }
137                    crate::permission::PermissionResult::Ask(ask) => {
138                        PermissionDecision::Ask(crate::permission::PermissionAskDecision {
139                            behavior: ask.behavior,
140                            message: ask.message,
141                            updated_input: ask.updated_input,
142                            decision_reason: ask.decision_reason,
143                            blocked_path: ask.blocked_path,
144                        })
145                    }
146                    crate::permission::PermissionResult::Deny(deny) => {
147                        PermissionDecision::Deny(crate::permission::PermissionDenyDecision {
148                            behavior: deny.behavior,
149                            message: deny.message,
150                            decision_reason: deny.decision_reason,
151                        })
152                    }
153                    crate::permission::PermissionResult::Passthrough {
154                        message: _,
155                        decision_reason,
156                    } => {
157                        // Passthrough treated as allow with notification
158                        PermissionDecision::Allow(crate::permission::PermissionAllowDecision {
159                            behavior: crate::permission::PermissionBehavior::Allow,
160                            updated_input: None,
161                            user_modified: None,
162                            decision_reason,
163                        })
164                    }
165                }
166            })
167        },
168    )
169}
170
171/// Create a CanUseToolFn that always allows
172pub fn create_allow_all_can_use_tool_fn() -> CanUseToolFnJson {
173    Box::new(
174        |_tool: ToolDefinition,
175         input: serde_json::Value,
176         _context: ToolUseContext,
177         _message: AssistantMessage,
178         _tool_use_id: String,
179         _force: Option<PermissionDecision>| {
180            Box::pin(async move {
181                PermissionDecision::Allow(crate::permission::PermissionAllowDecision {
182                    behavior: crate::permission::PermissionBehavior::Allow,
183                    updated_input: Some(input),
184                    user_modified: None,
185                    decision_reason: Some(crate::permission::PermissionDecisionReason::Other {
186                        reason: "Allowed by default can_use_tool function".to_string(),
187                    }),
188                })
189            })
190        },
191    )
192}
193
194/// Create a CanUseToolFn that always denies
195pub fn create_deny_all_can_use_tool_fn() -> CanUseToolFnJson {
196    Box::new(
197        |tool: ToolDefinition,
198         _input: serde_json::Value,
199         _context: ToolUseContext,
200         _message: AssistantMessage,
201         _tool_use_id: String,
202         _force: Option<PermissionDecision>| {
203            let tool_name = tool.name.clone();
204            Box::pin(async move {
205                PermissionDecision::Deny(crate::permission::PermissionDenyDecision {
206                    behavior: crate::permission::PermissionBehavior::Deny,
207                    message: format!("Tool '{}' is denied", tool_name),
208                    decision_reason: crate::permission::PermissionDecisionReason::Other {
209                        reason: "Denied by default can_use_tool function".to_string(),
210                    },
211                })
212            })
213        },
214    )
215}
216
/// Builds a placeholder [`AssistantMessage`] for unit tests.
///
/// Uses fixed identifiers (`"test-id"`, `"test-uuid"`), model `"test-model"`,
/// role `"assistant"`, type `"message"`, timestamp `"2024-01-01"`, and empty content.
/// Every optional field is `None`.
#[cfg(test)]
fn create_test_assistant_message() -> AssistantMessage {
    let body = AssistantMessageContent {
        id: String::from("test-id"),
        container: None,
        model: String::from("test-model"),
        role: String::from("assistant"),
        stop_reason: None,
        stop_sequence: None,
        message_type: String::from("message"),
        usage: None,
        content: Vec::new(),
        context_management: None,
    };

    AssistantMessage {
        message: body,
        uuid: String::from("test-uuid"),
        timestamp: String::from("2024-01-01"),
        parent_uuid: None,
        request_id: None,
        api_error: None,
        error: None,
        error_details: None,
        is_api_error_message: None,
        is_virtual: None,
        is_meta: None,
        advisor_model: None,
    }
}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::permission::{PermissionBehavior, PermissionDecisionReason, PermissionMode};

    /// Minimal `ToolUseContext` used by every async test below.
    fn fixture_context() -> ToolUseContext {
        ToolUseContext {
            session_id: "test".to_string(),
            cwd: None,
            is_non_interactive_session: false,
            options: None,
        }
    }

    /// `ToolPermissionContext` with the given mode and no automated-check preference.
    fn permission_context(mode: PermissionMode) -> ToolPermissionContext {
        ToolPermissionContext {
            mode,
            await_automated_checks_before_dialog: None,
        }
    }

    /// Calls `check` for a tool named `name`, using the shared fixtures and id `"tool-use-1"`.
    async fn invoke(
        check: &CanUseToolFnJson,
        name: &'static str,
        description: &'static str,
        input: serde_json::Value,
        force_decision: Option<PermissionDecision>,
    ) -> PermissionDecision {
        let tool = ToolDefinition::new(name, description, crate::types::ToolInputSchema::default());
        check(
            tool,
            input,
            fixture_context(),
            create_test_assistant_message(),
            "tool-use-1".to_string(),
            force_decision,
        )
        .await
    }

    #[test]
    fn test_tool_use_context_default() {
        let ctx = ToolUseContext {
            session_id: "test".to_string(),
            cwd: Some("/home".to_string()),
            is_non_interactive_session: false,
            options: None,
        };
        assert_eq!(ctx.session_id, "test");
        assert_eq!(ctx.cwd, Some("/home".to_string()));
    }

    #[test]
    fn test_tool_permission_context_default() {
        let ctx = permission_context(PermissionMode::Default);
        assert_eq!(ctx.mode, PermissionMode::Default);
    }

    #[tokio::test]
    async fn test_create_default_can_use_tool_fn_allow() {
        let check = create_default_can_use_tool_fn(permission_context(PermissionMode::Bypass));
        let input = serde_json::json!({"path": "/test"});

        let decision = invoke(&check, "Read", "Read files", input, None).await;

        assert!(decision.is_allowed());
    }

    #[tokio::test]
    async fn test_create_default_can_use_tool_fn_deny() {
        let check = create_default_can_use_tool_fn(permission_context(PermissionMode::DontAsk));
        let input = serde_json::json!({"command": "ls"});

        let decision = invoke(&check, "Bash", "Run commands", input, None).await;

        assert!(decision.is_denied());
    }

    #[tokio::test]
    async fn test_create_allow_all_can_use_tool_fn() {
        let check = create_allow_all_can_use_tool_fn();
        let input = serde_json::json!({"command": "rm -rf /"});

        let decision = invoke(&check, "Bash", "Run commands", input, None).await;

        assert!(decision.is_allowed());
    }

    #[tokio::test]
    async fn test_create_deny_all_can_use_tool_fn() {
        let check = create_deny_all_can_use_tool_fn();
        let input = serde_json::json!({"path": "/test"});

        let decision = invoke(&check, "Read", "Read files", input, None).await;

        assert!(decision.is_denied());
    }

    #[tokio::test]
    async fn test_force_decision_override() {
        let check = create_default_can_use_tool_fn(permission_context(PermissionMode::Bypass));
        let input = serde_json::json!({"command": "ls"});

        // Bypass mode would normally allow; the forced Deny must win.
        let forced = PermissionDecision::Deny(crate::permission::PermissionDenyDecision {
            behavior: PermissionBehavior::Deny,
            message: "Forced deny".to_string(),
            decision_reason: PermissionDecisionReason::Other {
                reason: "test".to_string(),
            },
        });

        let decision = invoke(&check, "Bash", "Run commands", input, Some(forced)).await;

        assert!(decision.is_denied());
    }
}