1use async_trait::async_trait;
14use serde::{Deserialize, Serialize};
15
16use super::context::{ToolContext, ToolDefinition, ToolOptions, ToolResult};
17use super::error::ToolError;
18
19#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
23pub enum PermissionBehavior {
24 Allow,
26 Deny,
28 Ask,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct PermissionCheckResult {
38 pub behavior: PermissionBehavior,
40 pub message: Option<String>,
42 pub updated_params: Option<serde_json::Value>,
44}
45
46impl PermissionCheckResult {
47 pub fn allow() -> Self {
49 Self {
50 behavior: PermissionBehavior::Allow,
51 message: None,
52 updated_params: None,
53 }
54 }
55
56 pub fn deny(reason: impl Into<String>) -> Self {
58 Self {
59 behavior: PermissionBehavior::Deny,
60 message: Some(reason.into()),
61 updated_params: None,
62 }
63 }
64
65 pub fn ask(message: impl Into<String>) -> Self {
67 Self {
68 behavior: PermissionBehavior::Ask,
69 message: Some(message.into()),
70 updated_params: None,
71 }
72 }
73
74 pub fn with_updated_params(mut self, params: serde_json::Value) -> Self {
76 self.updated_params = Some(params);
77 self
78 }
79
80 pub fn is_allowed(&self) -> bool {
82 self.behavior == PermissionBehavior::Allow
83 }
84
85 pub fn is_denied(&self) -> bool {
87 self.behavior == PermissionBehavior::Deny
88 }
89
90 pub fn requires_confirmation(&self) -> bool {
92 self.behavior == PermissionBehavior::Ask
93 }
94}
95
96impl Default for PermissionCheckResult {
97 fn default() -> Self {
98 Self::allow()
99 }
100}
101
102#[async_trait]
113pub trait Tool: Send + Sync {
114 fn name(&self) -> &str;
118
119 fn description(&self) -> &str;
124
125 fn dynamic_description(&self) -> Option<String> {
131 None
132 }
133
134 fn input_schema(&self) -> serde_json::Value;
140
141 async fn execute(
153 &self,
154 params: serde_json::Value,
155 context: &ToolContext,
156 ) -> Result<ToolResult, ToolError>;
157
158 async fn check_permissions(
172 &self,
173 _params: &serde_json::Value,
174 _context: &ToolContext,
175 ) -> PermissionCheckResult {
176 PermissionCheckResult::allow()
177 }
178
179 fn get_definition(&self) -> ToolDefinition {
187 let description = self
188 .dynamic_description()
189 .unwrap_or_else(|| self.description().to_string());
190 ToolDefinition {
191 name: self.name().to_string(),
192 description,
193 input_schema: self.input_schema(),
194 }
195 }
196
197 fn options(&self) -> ToolOptions {
204 ToolOptions::default()
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Minimal `Tool` implementation used to exercise the trait and its defaults.
    struct TestTool {
        name: String,
        /// When `true`, `execute` returns an `ExecutionFailed` error.
        should_fail: bool,
        /// Value returned from `dynamic_description` (overrides the static description).
        dynamic: Option<String>,
    }

    impl TestTool {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                should_fail: false,
                dynamic: None,
            }
        }

        fn failing(name: &str) -> Self {
            Self {
                should_fail: true,
                ..Self::new(name)
            }
        }

        fn with_dynamic_description(mut self, description: &str) -> Self {
            self.dynamic = Some(description.to_string());
            self
        }
    }

    #[async_trait]
    impl Tool for TestTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A test tool for unit testing"
        }

        fn dynamic_description(&self) -> Option<String> {
            self.dynamic.clone()
        }

        fn input_schema(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "input": { "type": "string" }
                },
                "required": ["input"]
            })
        }

        async fn execute(
            &self,
            params: serde_json::Value,
            _context: &ToolContext,
        ) -> Result<ToolResult, ToolError> {
            if self.should_fail {
                return Err(ToolError::execution_failed("Test failure"));
            }

            let input = params
                .get("input")
                .and_then(|v| v.as_str())
                .unwrap_or("default");

            Ok(ToolResult::success(format!("Processed: {}", input)))
        }
    }

    #[test]
    fn test_permission_check_result_allow() {
        let result = PermissionCheckResult::allow();
        assert!(result.is_allowed());
        assert!(!result.is_denied());
        assert!(!result.requires_confirmation());
        assert!(result.message.is_none());
        assert!(result.updated_params.is_none());
    }

    #[test]
    fn test_permission_check_result_deny() {
        let result = PermissionCheckResult::deny("Access denied");
        assert!(!result.is_allowed());
        assert!(result.is_denied());
        assert!(!result.requires_confirmation());
        assert_eq!(result.message, Some("Access denied".to_string()));
    }

    #[test]
    fn test_permission_check_result_ask() {
        let result = PermissionCheckResult::ask("Do you want to proceed?");
        assert!(!result.is_allowed());
        assert!(!result.is_denied());
        assert!(result.requires_confirmation());
        assert_eq!(result.message, Some("Do you want to proceed?".to_string()));
    }

    #[test]
    fn test_permission_check_result_with_updated_params() {
        let params = serde_json::json!({"sanitized": true});
        let result = PermissionCheckResult::allow().with_updated_params(params.clone());
        assert!(result.is_allowed());
        assert_eq!(result.updated_params, Some(params));
    }

    #[test]
    fn test_with_updated_params_preserves_behavior_and_message() {
        let params = serde_json::json!({"x": 1});
        let result = PermissionCheckResult::ask("Confirm?").with_updated_params(params.clone());
        assert!(result.requires_confirmation());
        assert_eq!(result.message, Some("Confirm?".to_string()));
        assert_eq!(result.updated_params, Some(params));
    }

    #[test]
    fn test_permission_check_result_default() {
        let result = PermissionCheckResult::default();
        assert!(result.is_allowed());
    }

    // Purely synchronous: no async runtime needed.
    #[test]
    fn test_tool_trait_basic() {
        let tool = TestTool::new("test_tool");

        assert_eq!(tool.name(), "test_tool");
        assert_eq!(tool.description(), "A test tool for unit testing");

        let schema = tool.input_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["input"].is_object());
    }

    #[tokio::test]
    async fn test_tool_execute_success() {
        let tool = TestTool::new("test_tool");
        let context = ToolContext::new(PathBuf::from("/tmp"));
        let params = serde_json::json!({"input": "hello"});

        let result = tool.execute(params, &context).await.unwrap();
        assert!(result.is_success());
        assert_eq!(result.output, Some("Processed: hello".to_string()));
    }

    #[tokio::test]
    async fn test_tool_execute_missing_input_uses_default() {
        let tool = TestTool::new("test_tool");
        let context = ToolContext::new(PathBuf::from("/tmp"));

        let result = tool
            .execute(serde_json::json!({}), &context)
            .await
            .unwrap();
        assert!(result.is_success());
        assert_eq!(result.output, Some("Processed: default".to_string()));
    }

    #[tokio::test]
    async fn test_tool_execute_failure() {
        let tool = TestTool::failing("failing_tool");
        let context = ToolContext::new(PathBuf::from("/tmp"));
        let params = serde_json::json!({"input": "hello"});

        let result = tool.execute(params, &context).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ToolError::ExecutionFailed(_)));
    }

    #[tokio::test]
    async fn test_tool_default_check_permissions() {
        let tool = TestTool::new("test_tool");
        let context = ToolContext::new(PathBuf::from("/tmp"));
        let params = serde_json::json!({"input": "hello"});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_allowed());
    }

    #[test]
    fn test_tool_get_definition() {
        let tool = TestTool::new("test_tool");
        let def = tool.get_definition();

        assert_eq!(def.name, "test_tool");
        assert_eq!(def.description, "A test tool for unit testing");
        assert_eq!(def.input_schema["type"], "object");
    }

    #[test]
    fn test_tool_get_definition_prefers_dynamic_description() {
        let tool = TestTool::new("test_tool").with_dynamic_description("Dynamic text");
        let def = tool.get_definition();

        assert_eq!(def.name, "test_tool");
        assert_eq!(def.description, "Dynamic text");
    }

    #[test]
    fn test_tool_default_options() {
        let tool = TestTool::new("test_tool");
        let opts = tool.options();

        assert_eq!(opts.max_retries, 3);
        assert!(opts.enable_dynamic_timeout);
    }

    #[test]
    fn test_permission_behavior_equality() {
        assert_eq!(PermissionBehavior::Allow, PermissionBehavior::Allow);
        assert_eq!(PermissionBehavior::Deny, PermissionBehavior::Deny);
        assert_eq!(PermissionBehavior::Ask, PermissionBehavior::Ask);
        assert_ne!(PermissionBehavior::Allow, PermissionBehavior::Deny);
    }

    #[test]
    fn test_permission_check_result_serialization() {
        let result = PermissionCheckResult::deny("test reason")
            .with_updated_params(serde_json::json!({"path": "/tmp/x"}));
        let json = serde_json::to_string(&result).unwrap();
        let deserialized: PermissionCheckResult = serde_json::from_str(&json).unwrap();

        assert_eq!(result.behavior, deserialized.behavior);
        assert_eq!(result.message, deserialized.message);
        assert_eq!(result.updated_params, deserialized.updated_params);
    }
}