1use crate::permission::PermissionDecision;
7use crate::types::ToolDefinition;
8use crate::utils::messages::{AssistantMessage, AssistantMessageContent};
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13#[serde(rename_all = "camelCase")]
14pub struct ToolUseContext {
15 pub session_id: String,
17 pub cwd: Option<String>,
19 pub is_non_interactive_session: bool,
21 #[serde(default, skip_serializing_if = "Option::is_none")]
23 pub options: Option<ToolUseContextOptions>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27#[serde(rename_all = "camelCase")]
28pub struct ToolUseContextOptions {
29 #[serde(default, skip_serializing_if = "Option::is_none")]
31 pub tools: Option<Vec<ToolDefinition>>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(rename_all = "camelCase")]
37pub struct ToolPermissionContext {
38 pub mode: crate::permission::PermissionMode,
40 #[serde(default, skip_serializing_if = "Option::is_none")]
42 pub await_automated_checks_before_dialog: Option<bool>,
43}
44
45pub type CanUseToolFn<Input = std::collections::HashMap<String, serde_json::Value>> = Box<
64 dyn Fn(
65 ToolDefinition,
66 Input,
67 ToolUseContext,
68 AssistantMessage,
69 String,
70 Option<PermissionDecision>,
71 ) -> std::pin::Pin<
72 Box<dyn std::future::Future<Output = PermissionDecision> + Send + 'static>,
73 > + Send
74 + Sync,
75>;
76
77pub type CanUseToolFnJson = Box<
79 dyn Fn(
80 ToolDefinition,
81 serde_json::Value,
82 ToolUseContext,
83 AssistantMessage,
84 String,
85 Option<PermissionDecision>,
86 ) -> std::pin::Pin<
87 Box<dyn std::future::Future<Output = PermissionDecision> + Send + 'static>,
88 > + Send
89 + Sync,
90>;
91
/// Textual description of the `CanUseToolFn` callback signature, naming each positional
/// argument (tool, input, tool-use context, assistant message, tool-use id, forced decision).
///
/// This is plain text only; the compiled types are the `CanUseToolFn` and
/// `CanUseToolFnJson` aliases.
pub const CAN_USE_TOOL_FN_SIGNATURE: &str = r#"
CanUseToolFn<Input> = Fn(
 tool: ToolDefinition,
 input: Input,
 tool_use_context: ToolUseContext,
 assistant_message: AssistantMessage,
 tool_use_id: String,
 force_decision: Option<PermissionDecision>,
) -> impl Future<Output = PermissionDecision>
"#;
103
104pub fn create_default_can_use_tool_fn(
106 permission_context: ToolPermissionContext,
107) -> CanUseToolFnJson {
108 Box::new(
109 move |tool: ToolDefinition,
110 input: serde_json::Value,
111 _tool_use_context: ToolUseContext,
112 _assistant_message: AssistantMessage,
113 _tool_use_id: String,
114 force_decision: Option<PermissionDecision>| {
115 let ctx =
116 crate::permission::PermissionContext::new().with_mode(permission_context.mode);
117
118 Box::pin(async move {
119 if let Some(decision) = force_decision {
121 return decision;
122 }
123
124 let result = ctx.check_tool(&tool.name, Some(&input));
126
127 match result {
129 crate::permission::PermissionResult::Allow(allow) => {
130 PermissionDecision::Allow(crate::permission::PermissionAllowDecision {
131 behavior: allow.behavior,
132 updated_input: allow.updated_input,
133 user_modified: allow.user_modified,
134 decision_reason: allow.decision_reason,
135 })
136 }
137 crate::permission::PermissionResult::Ask(ask) => {
138 PermissionDecision::Ask(crate::permission::PermissionAskDecision {
139 behavior: ask.behavior,
140 message: ask.message,
141 updated_input: ask.updated_input,
142 decision_reason: ask.decision_reason,
143 blocked_path: ask.blocked_path,
144 })
145 }
146 crate::permission::PermissionResult::Deny(deny) => {
147 PermissionDecision::Deny(crate::permission::PermissionDenyDecision {
148 behavior: deny.behavior,
149 message: deny.message,
150 decision_reason: deny.decision_reason,
151 })
152 }
153 crate::permission::PermissionResult::Passthrough {
154 message: _,
155 decision_reason,
156 } => {
157 PermissionDecision::Allow(crate::permission::PermissionAllowDecision {
159 behavior: crate::permission::PermissionBehavior::Allow,
160 updated_input: None,
161 user_modified: None,
162 decision_reason,
163 })
164 }
165 }
166 })
167 },
168 )
169}
170
171pub fn create_allow_all_can_use_tool_fn() -> CanUseToolFnJson {
173 Box::new(
174 |_tool: ToolDefinition,
175 input: serde_json::Value,
176 _context: ToolUseContext,
177 _message: AssistantMessage,
178 _tool_use_id: String,
179 _force: Option<PermissionDecision>| {
180 Box::pin(async move {
181 PermissionDecision::Allow(crate::permission::PermissionAllowDecision {
182 behavior: crate::permission::PermissionBehavior::Allow,
183 updated_input: Some(input),
184 user_modified: None,
185 decision_reason: Some(crate::permission::PermissionDecisionReason::Other {
186 reason: "Allowed by default can_use_tool function".to_string(),
187 }),
188 })
189 })
190 },
191 )
192}
193
194pub fn create_deny_all_can_use_tool_fn() -> CanUseToolFnJson {
196 Box::new(
197 |tool: ToolDefinition,
198 _input: serde_json::Value,
199 _context: ToolUseContext,
200 _message: AssistantMessage,
201 _tool_use_id: String,
202 _force: Option<PermissionDecision>| {
203 let tool_name = tool.name.clone();
204 Box::pin(async move {
205 PermissionDecision::Deny(crate::permission::PermissionDenyDecision {
206 behavior: crate::permission::PermissionBehavior::Deny,
207 message: format!("Tool '{}' is denied", tool_name),
208 decision_reason: crate::permission::PermissionDecisionReason::Other {
209 reason: "Denied by default can_use_tool function".to_string(),
210 },
211 })
212 })
213 },
214 )
215}
216
/// Test fixture: a minimal `AssistantMessage`.
///
/// It has a fixed id, model, uuid and timestamp, role `"assistant"`, type `"message"`,
/// empty content, and every optional field set to `None`.
#[cfg(test)]
fn create_test_assistant_message() -> AssistantMessage {
    let body = AssistantMessageContent {
        id: String::from("test-id"),
        container: None,
        model: String::from("test-model"),
        role: String::from("assistant"),
        stop_reason: None,
        stop_sequence: None,
        message_type: String::from("message"),
        usage: None,
        content: Vec::new(),
        context_management: None,
    };

    AssistantMessage {
        message: body,
        request_id: None,
        api_error: None,
        error: None,
        error_details: None,
        is_api_error_message: None,
        is_virtual: None,
        is_meta: None,
        advisor_model: None,
        uuid: String::from("test-uuid"),
        timestamp: String::from("2024-01-01"),
        parent_uuid: None,
    }
}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool-use context shared by the callback tests: session "test", no cwd, interactive.
    fn sample_context() -> ToolUseContext {
        ToolUseContext {
            session_id: String::from("test"),
            cwd: None,
            is_non_interactive_session: false,
            options: None,
        }
    }

    /// Permission context with the given mode and no automated-check flag.
    fn permission_context(mode: crate::permission::PermissionMode) -> ToolPermissionContext {
        ToolPermissionContext {
            mode,
            await_automated_checks_before_dialog: None,
        }
    }

    /// Tool definition with the given name/description and a default input schema.
    fn sample_tool(name: &'static str, description: &'static str) -> ToolDefinition {
        ToolDefinition::new(name, description, crate::types::ToolInputSchema::default())
    }

    /// Invokes `callback` with the shared context, the fixture assistant message, the
    /// tool-use id "tool-use-1" and the given forced decision.
    async fn run(
        callback: &CanUseToolFnJson,
        tool: ToolDefinition,
        input: serde_json::Value,
        force: Option<PermissionDecision>,
    ) -> PermissionDecision {
        callback(
            tool,
            input,
            sample_context(),
            create_test_assistant_message(),
            "tool-use-1".to_string(),
            force,
        )
        .await
    }

    #[test]
    fn test_tool_use_context_default() {
        let ctx = ToolUseContext {
            cwd: Some(String::from("/home")),
            ..sample_context()
        };
        assert_eq!(ctx.session_id, "test");
        assert_eq!(ctx.cwd.as_deref(), Some("/home"));
    }

    #[test]
    fn test_tool_permission_context_default() {
        let ctx = permission_context(crate::permission::PermissionMode::Default);
        assert_eq!(ctx.mode, crate::permission::PermissionMode::Default);
    }

    #[tokio::test]
    async fn test_create_default_can_use_tool_fn_allow() {
        let callback = create_default_can_use_tool_fn(permission_context(
            crate::permission::PermissionMode::Bypass,
        ));
        let decision = run(
            &callback,
            sample_tool("Read", "Read files"),
            serde_json::json!({"path": "/test"}),
            None,
        )
        .await;
        assert!(decision.is_allowed());
    }

    #[tokio::test]
    async fn test_create_default_can_use_tool_fn_deny() {
        let callback = create_default_can_use_tool_fn(permission_context(
            crate::permission::PermissionMode::DontAsk,
        ));
        let decision = run(
            &callback,
            sample_tool("Bash", "Run commands"),
            serde_json::json!({"command": "ls"}),
            None,
        )
        .await;
        assert!(decision.is_denied());
    }

    #[tokio::test]
    async fn test_create_allow_all_can_use_tool_fn() {
        let callback = create_allow_all_can_use_tool_fn();
        let decision = run(
            &callback,
            sample_tool("Bash", "Run commands"),
            serde_json::json!({"command": "rm -rf /"}),
            None,
        )
        .await;
        assert!(decision.is_allowed());
    }

    #[tokio::test]
    async fn test_create_deny_all_can_use_tool_fn() {
        let callback = create_deny_all_can_use_tool_fn();
        let decision = run(
            &callback,
            sample_tool("Read", "Read files"),
            serde_json::json!({"path": "/test"}),
            None,
        )
        .await;
        assert!(decision.is_denied());
    }

    #[tokio::test]
    async fn test_force_decision_override() {
        let callback = create_default_can_use_tool_fn(permission_context(
            crate::permission::PermissionMode::Bypass,
        ));
        let forced = PermissionDecision::Deny(crate::permission::PermissionDenyDecision {
            behavior: crate::permission::PermissionBehavior::Deny,
            message: "Forced deny".to_string(),
            decision_reason: crate::permission::PermissionDecisionReason::Other {
                reason: "test".to_string(),
            },
        });
        let decision = run(
            &callback,
            sample_tool("Bash", "Run commands"),
            serde_json::json!({"command": "ls"}),
            Some(forced),
        )
        .await;
        // The forced denial must take precedence over the Bypass-mode check.
        assert!(decision.is_denied());
    }
}