Skip to main content

aster/tools/
base.rs

//! Tool base trait and types
//!
//! This module defines the core `Tool` trait that every tool must implement.
//! It provides a unified interface for tool execution with:
//! - a name and description for identification
//! - a JSON Schema for input validation
//! - async execution with a context
//! - permission checking before execution
//! - configurable options
//!
//! Requirements: 1.1, 1.2
12
13use async_trait::async_trait;
14use serde::{Deserialize, Serialize};
15
16use super::context::{ToolContext, ToolDefinition, ToolOptions, ToolResult};
17use super::error::ToolError;
18
19/// Permission check behavior
20///
21/// Determines how the tool execution should proceed after permission check.
22#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
23pub enum PermissionBehavior {
24    /// Allow execution to proceed
25    Allow,
26    /// Deny execution with a reason
27    Deny,
28    /// Ask user for confirmation before proceeding
29    Ask,
30}
31
32/// Result of a permission check
33///
34/// Contains the behavior decision and optional additional information.
35/// Requirements: 1.2
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct PermissionCheckResult {
38    /// The permission behavior (Allow/Deny/Ask)
39    pub behavior: PermissionBehavior,
40    /// Optional message explaining the decision
41    pub message: Option<String>,
42    /// Optional updated parameters (e.g., sanitized inputs)
43    pub updated_params: Option<serde_json::Value>,
44}
45
46impl PermissionCheckResult {
47    /// Create an Allow result
48    pub fn allow() -> Self {
49        Self {
50            behavior: PermissionBehavior::Allow,
51            message: None,
52            updated_params: None,
53        }
54    }
55
56    /// Create a Deny result with a reason
57    pub fn deny(reason: impl Into<String>) -> Self {
58        Self {
59            behavior: PermissionBehavior::Deny,
60            message: Some(reason.into()),
61            updated_params: None,
62        }
63    }
64
65    /// Create an Ask result with a message for the user
66    pub fn ask(message: impl Into<String>) -> Self {
67        Self {
68            behavior: PermissionBehavior::Ask,
69            message: Some(message.into()),
70            updated_params: None,
71        }
72    }
73
74    /// Set updated parameters
75    pub fn with_updated_params(mut self, params: serde_json::Value) -> Self {
76        self.updated_params = Some(params);
77        self
78    }
79
80    /// Check if permission is allowed
81    pub fn is_allowed(&self) -> bool {
82        self.behavior == PermissionBehavior::Allow
83    }
84
85    /// Check if permission is denied
86    pub fn is_denied(&self) -> bool {
87        self.behavior == PermissionBehavior::Deny
88    }
89
90    /// Check if user confirmation is required
91    pub fn requires_confirmation(&self) -> bool {
92        self.behavior == PermissionBehavior::Ask
93    }
94}
95
96impl Default for PermissionCheckResult {
97    fn default() -> Self {
98        Self::allow()
99    }
100}
101
102/// Tool trait - the core interface for all tools
103///
104/// All tools in the system must implement this trait. It provides:
105/// - Identification (name, description)
106/// - Input schema for validation
107/// - Async execution
108/// - Permission checking
109/// - Configuration options
110///
111/// Requirements: 1.1, 1.2
112#[async_trait]
113pub trait Tool: Send + Sync {
114    /// Returns the unique name of the tool
115    ///
116    /// This name is used for registration and lookup in the tool registry.
117    fn name(&self) -> &str;
118
119    /// Returns a human-readable description of the tool
120    ///
121    /// This description is provided to the LLM to help it understand
122    /// when and how to use the tool.
123    fn description(&self) -> &str;
124
125    /// Returns a dynamically generated description of the tool
126    ///
127    /// Override this method when the tool description needs to include
128    /// dynamic content (e.g., available skills, current state).
129    /// Default implementation returns None, falling back to `description()`.
130    fn dynamic_description(&self) -> Option<String> {
131        None
132    }
133
134    /// Returns the JSON Schema for the tool's input parameters
135    ///
136    /// This schema is used for:
137    /// - Input validation before execution
138    /// - Providing parameter information to the LLM
139    fn input_schema(&self) -> serde_json::Value;
140
141    /// Execute the tool with the given parameters and context
142    ///
143    /// This is the main entry point for tool execution.
144    ///
145    /// # Arguments
146    /// * `params` - The input parameters as a JSON value
147    /// * `context` - The execution context containing environment info
148    ///
149    /// # Returns
150    /// * `Ok(ToolResult)` - The execution result
151    /// * `Err(ToolError)` - If execution fails
152    async fn execute(
153        &self,
154        params: serde_json::Value,
155        context: &ToolContext,
156    ) -> Result<ToolResult, ToolError>;
157
158    /// Check permissions before executing the tool
159    ///
160    /// This method is called before `execute` to determine if the tool
161    /// should be allowed to run with the given parameters.
162    ///
163    /// Default implementation allows all executions.
164    ///
165    /// # Arguments
166    /// * `params` - The input parameters to check
167    /// * `context` - The execution context
168    ///
169    /// # Returns
170    /// A `PermissionCheckResult` indicating whether to allow, deny, or ask
171    async fn check_permissions(
172        &self,
173        _params: &serde_json::Value,
174        _context: &ToolContext,
175    ) -> PermissionCheckResult {
176        PermissionCheckResult::allow()
177    }
178
179    /// Get the tool definition for LLM consumption
180    ///
181    /// Returns a `ToolDefinition` containing the name, description,
182    /// and input schema in a format suitable for LLM tool calling.
183    ///
184    /// Default implementation constructs from name(), dynamic_description() or description(),
185    /// and input_schema(). Prefers dynamic_description() if available.
186    fn get_definition(&self) -> ToolDefinition {
187        let description = self
188            .dynamic_description()
189            .unwrap_or_else(|| self.description().to_string());
190        ToolDefinition {
191            name: self.name().to_string(),
192            description,
193            input_schema: self.input_schema(),
194        }
195    }
196
197    /// Get the tool's configuration options
198    ///
199    /// Returns the `ToolOptions` for this tool, including retry settings,
200    /// timeout configuration, etc.
201    ///
202    /// Default implementation returns default options.
203    fn options(&self) -> ToolOptions {
204        ToolOptions::default()
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// A simple test tool for unit testing
    struct TestTool {
        name: String,
        should_fail: bool,
    }

    impl TestTool {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                should_fail: false,
            }
        }

        fn failing(name: &str) -> Self {
            Self {
                name: name.to_string(),
                should_fail: true,
            }
        }
    }

    #[async_trait]
    impl Tool for TestTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A test tool for unit testing"
        }

        fn input_schema(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "input": { "type": "string" }
                },
                "required": ["input"]
            })
        }

        async fn execute(
            &self,
            params: serde_json::Value,
            _context: &ToolContext,
        ) -> Result<ToolResult, ToolError> {
            if self.should_fail {
                return Err(ToolError::execution_failed("Test failure"));
            }

            let input = params
                .get("input")
                .and_then(|v| v.as_str())
                .unwrap_or("default");

            Ok(ToolResult::success(format!("Processed: {}", input)))
        }
    }

    /// Tool that overrides `dynamic_description`, used to check that
    /// `get_definition` prefers it over the static description.
    struct DynamicDescriptionTool;

    #[async_trait]
    impl Tool for DynamicDescriptionTool {
        fn name(&self) -> &str {
            "dynamic_tool"
        }

        fn description(&self) -> &str {
            "Static description"
        }

        fn dynamic_description(&self) -> Option<String> {
            Some("Dynamic description".to_string())
        }

        fn input_schema(&self) -> serde_json::Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(
            &self,
            _params: serde_json::Value,
            _context: &ToolContext,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("ok".to_string()))
        }
    }

    #[test]
    fn test_permission_check_result_allow() {
        let result = PermissionCheckResult::allow();
        assert!(result.is_allowed());
        assert!(!result.is_denied());
        assert!(!result.requires_confirmation());
        assert!(result.message.is_none());
        assert!(result.updated_params.is_none());
    }

    #[test]
    fn test_permission_check_result_deny() {
        let result = PermissionCheckResult::deny("Access denied");
        assert!(!result.is_allowed());
        assert!(result.is_denied());
        assert!(!result.requires_confirmation());
        assert_eq!(result.message, Some("Access denied".to_string()));
    }

    #[test]
    fn test_permission_check_result_ask() {
        let result = PermissionCheckResult::ask("Do you want to proceed?");
        assert!(!result.is_allowed());
        assert!(!result.is_denied());
        assert!(result.requires_confirmation());
        assert_eq!(result.message, Some("Do you want to proceed?".to_string()));
    }

    #[test]
    fn test_permission_check_result_with_updated_params() {
        let params = serde_json::json!({"sanitized": true});
        let result = PermissionCheckResult::allow().with_updated_params(params.clone());
        assert!(result.is_allowed());
        assert_eq!(result.updated_params, Some(params));
    }

    #[test]
    fn test_permission_check_result_default() {
        let result = PermissionCheckResult::default();
        assert!(result.is_allowed());
    }

    // Purely synchronous: nothing is awaited, so no async runtime is needed.
    #[test]
    fn test_tool_trait_basic() {
        let tool = TestTool::new("test_tool");

        assert_eq!(tool.name(), "test_tool");
        assert_eq!(tool.description(), "A test tool for unit testing");

        let schema = tool.input_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["input"].is_object());
    }

    #[tokio::test]
    async fn test_tool_execute_success() {
        let tool = TestTool::new("test_tool");
        let context = ToolContext::new(PathBuf::from("/tmp"));
        let params = serde_json::json!({"input": "hello"});

        let result = tool.execute(params, &context).await.unwrap();
        assert!(result.is_success());
        assert_eq!(result.output, Some("Processed: hello".to_string()));
    }

    #[tokio::test]
    async fn test_tool_execute_missing_input_uses_default() {
        let tool = TestTool::new("test_tool");
        let context = ToolContext::new(PathBuf::from("/tmp"));
        let params = serde_json::json!({});

        let result = tool.execute(params, &context).await.unwrap();
        assert!(result.is_success());
        assert_eq!(result.output, Some("Processed: default".to_string()));
    }

    #[tokio::test]
    async fn test_tool_execute_failure() {
        let tool = TestTool::failing("failing_tool");
        let context = ToolContext::new(PathBuf::from("/tmp"));
        let params = serde_json::json!({"input": "hello"});

        let result = tool.execute(params, &context).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ToolError::ExecutionFailed(_)));
    }

    #[tokio::test]
    async fn test_tool_default_check_permissions() {
        let tool = TestTool::new("test_tool");
        let context = ToolContext::new(PathBuf::from("/tmp"));
        let params = serde_json::json!({"input": "hello"});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_allowed());
    }

    #[test]
    fn test_tool_get_definition() {
        let tool = TestTool::new("test_tool");
        let def = tool.get_definition();

        assert_eq!(def.name, "test_tool");
        assert_eq!(def.description, "A test tool for unit testing");
        assert_eq!(def.input_schema["type"], "object");
    }

    #[test]
    fn test_tool_get_definition_prefers_dynamic_description() {
        let def = DynamicDescriptionTool.get_definition();

        assert_eq!(def.name, "dynamic_tool");
        assert_eq!(def.description, "Dynamic description");
        assert_eq!(def.input_schema["type"], "object");
    }

    #[test]
    fn test_tool_default_options() {
        let tool = TestTool::new("test_tool");
        let opts = tool.options();

        assert_eq!(opts.max_retries, 3);
        assert!(opts.enable_dynamic_timeout);
    }

    #[test]
    fn test_permission_behavior_equality() {
        assert_eq!(PermissionBehavior::Allow, PermissionBehavior::Allow);
        assert_eq!(PermissionBehavior::Deny, PermissionBehavior::Deny);
        assert_eq!(PermissionBehavior::Ask, PermissionBehavior::Ask);
        assert_ne!(PermissionBehavior::Allow, PermissionBehavior::Deny);
        assert_ne!(PermissionBehavior::Allow, PermissionBehavior::Ask);
        assert_ne!(PermissionBehavior::Deny, PermissionBehavior::Ask);
    }

    #[test]
    fn test_permission_behavior_serialization_roundtrip() {
        for behavior in &[
            PermissionBehavior::Allow,
            PermissionBehavior::Deny,
            PermissionBehavior::Ask,
        ] {
            let json = serde_json::to_string(behavior).unwrap();
            let back: PermissionBehavior = serde_json::from_str(&json).unwrap();
            assert_eq!(&back, behavior);
        }
    }

    #[test]
    fn test_permission_check_result_serialization() {
        let result = PermissionCheckResult::deny("test reason");
        let json = serde_json::to_string(&result).unwrap();
        let deserialized: PermissionCheckResult = serde_json::from_str(&json).unwrap();

        assert_eq!(result.behavior, deserialized.behavior);
        assert_eq!(result.message, deserialized.message);
        assert_eq!(result.updated_params, deserialized.updated_params);
    }

    #[test]
    fn test_permission_check_result_serialization_with_updated_params() {
        let result = PermissionCheckResult::ask("Confirm?")
            .with_updated_params(serde_json::json!({"path": "/tmp/file"}));
        let json = serde_json::to_string(&result).unwrap();
        let deserialized: PermissionCheckResult = serde_json::from_str(&json).unwrap();

        assert_eq!(result.behavior, deserialized.behavior);
        assert_eq!(result.message, deserialized.message);
        assert_eq!(result.updated_params, deserialized.updated_params);
    }
}