Skip to main content

aster/tools/
context.rs

1//! Tool Context and Configuration Types
2//!
3//! This module defines the core types for tool execution context and configuration:
4//! - `ToolContext`: Execution environment information
5//! - `ToolOptions`: Tool configuration options
6//! - `ToolDefinition`: Tool definition for LLM consumption
7//! - `ToolResult`: Tool execution result
8//!
9//! Requirements: 1.3, 1.4
10
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::path::PathBuf;
14use std::time::Duration;
15use tokio_util::sync::CancellationToken;
16
17/// Tool execution context
18///
19/// Contains environment information available during tool execution.
20/// This is passed to every tool's execute method.
21#[derive(Debug, Clone)]
22pub struct ToolContext {
23    /// Current working directory for the tool execution
24    pub working_directory: PathBuf,
25
26    /// Session identifier for tracking
27    pub session_id: String,
28
29    /// Optional user identifier
30    pub user: Option<String>,
31
32    /// Environment variables available to the tool
33    pub environment: HashMap<String, String>,
34
35    /// Cancellation token for cooperative cancellation
36    pub cancellation_token: Option<CancellationToken>,
37}
38
39impl Default for ToolContext {
40    fn default() -> Self {
41        Self {
42            working_directory: std::env::current_dir().unwrap_or_default(),
43            session_id: String::new(),
44            user: None,
45            environment: HashMap::new(),
46            cancellation_token: None,
47        }
48    }
49}
50
51impl ToolContext {
52    /// Create a new ToolContext with the given working directory
53    pub fn new(working_directory: PathBuf) -> Self {
54        Self {
55            working_directory,
56            ..Default::default()
57        }
58    }
59
60    /// Set the session ID
61    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
62        self.session_id = session_id.into();
63        self
64    }
65
66    /// Set the user
67    pub fn with_user(mut self, user: impl Into<String>) -> Self {
68        self.user = Some(user.into());
69        self
70    }
71
72    /// Set environment variables
73    pub fn with_environment(mut self, environment: HashMap<String, String>) -> Self {
74        self.environment = environment;
75        self
76    }
77
78    /// Add a single environment variable
79    pub fn with_env_var(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
80        self.environment.insert(key.into(), value.into());
81        self
82    }
83
84    /// Set the cancellation token
85    pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
86        self.cancellation_token = Some(token);
87        self
88    }
89
90    /// Check if cancellation has been requested
91    pub fn is_cancelled(&self) -> bool {
92        self.cancellation_token
93            .as_ref()
94            .is_some_and(|t| t.is_cancelled())
95    }
96}
97
98/// Tool configuration options
99///
100/// Configurable options for tool execution behavior.
101/// Requirements: 1.3
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct ToolOptions {
104    /// Maximum number of retry attempts for transient failures
105    pub max_retries: u32,
106
107    /// Base timeout duration for tool execution
108    #[serde(with = "duration_serde")]
109    pub base_timeout: Duration,
110
111    /// Whether to enable dynamic timeout adjustment
112    pub enable_dynamic_timeout: bool,
113
114    /// List of error patterns that are considered retryable
115    pub retryable_errors: Vec<String>,
116}
117
118impl Default for ToolOptions {
119    fn default() -> Self {
120        Self {
121            max_retries: 3,
122            base_timeout: Duration::from_secs(30),
123            enable_dynamic_timeout: true,
124            retryable_errors: vec![
125                "timeout".to_string(),
126                "connection refused".to_string(),
127                "temporary failure".to_string(),
128            ],
129        }
130    }
131}
132
133impl ToolOptions {
134    /// Create new ToolOptions with default values
135    pub fn new() -> Self {
136        Self::default()
137    }
138
139    /// Set maximum retries
140    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
141        self.max_retries = max_retries;
142        self
143    }
144
145    /// Set base timeout
146    pub fn with_base_timeout(mut self, timeout: Duration) -> Self {
147        self.base_timeout = timeout;
148        self
149    }
150
151    /// Enable or disable dynamic timeout
152    pub fn with_dynamic_timeout(mut self, enabled: bool) -> Self {
153        self.enable_dynamic_timeout = enabled;
154        self
155    }
156
157    /// Set retryable error patterns
158    pub fn with_retryable_errors(mut self, errors: Vec<String>) -> Self {
159        self.retryable_errors = errors;
160        self
161    }
162
163    /// Check if an error message matches any retryable pattern
164    pub fn is_error_retryable(&self, error_msg: &str) -> bool {
165        let error_lower = error_msg.to_lowercase();
166        self.retryable_errors
167            .iter()
168            .any(|pattern| error_lower.contains(&pattern.to_lowercase()))
169    }
170}
171
172/// Tool definition for LLM consumption
173///
174/// Contains the information needed by an LLM to understand and use a tool.
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct ToolDefinition {
177    /// Tool name (unique identifier)
178    pub name: String,
179
180    /// Human-readable description of what the tool does
181    pub description: String,
182
183    /// JSON Schema for the tool's input parameters
184    pub input_schema: serde_json::Value,
185}
186
187impl ToolDefinition {
188    /// Create a new ToolDefinition
189    pub fn new(
190        name: impl Into<String>,
191        description: impl Into<String>,
192        input_schema: serde_json::Value,
193    ) -> Self {
194        Self {
195            name: name.into(),
196            description: description.into(),
197            input_schema,
198        }
199    }
200}
201
202/// Tool execution result
203///
204/// Contains the outcome of a tool execution.
205/// Requirements: 1.4
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct ToolResult {
208    /// Whether the execution was successful
209    pub success: bool,
210
211    /// Output content (if successful)
212    pub output: Option<String>,
213
214    /// Error message (if failed)
215    pub error: Option<String>,
216
217    /// Additional metadata about the execution
218    pub metadata: HashMap<String, serde_json::Value>,
219}
220
221impl Default for ToolResult {
222    fn default() -> Self {
223        Self {
224            success: true,
225            output: None,
226            error: None,
227            metadata: HashMap::new(),
228        }
229    }
230}
231
232impl ToolResult {
233    /// Create a successful result with output
234    pub fn success(output: impl Into<String>) -> Self {
235        Self {
236            success: true,
237            output: Some(output.into()),
238            error: None,
239            metadata: HashMap::new(),
240        }
241    }
242
243    /// Create a successful result without output
244    pub fn success_empty() -> Self {
245        Self {
246            success: true,
247            output: None,
248            error: None,
249            metadata: HashMap::new(),
250        }
251    }
252
253    /// Create a failed result with error message
254    pub fn error(error: impl Into<String>) -> Self {
255        Self {
256            success: false,
257            output: None,
258            error: Some(error.into()),
259            metadata: HashMap::new(),
260        }
261    }
262
263    /// Add metadata to the result
264    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
265        self.metadata.insert(key.into(), value);
266        self
267    }
268
269    /// Add multiple metadata entries
270    pub fn with_metadata_map(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
271        self.metadata.extend(metadata);
272        self
273    }
274
275    /// Check if the result indicates success
276    pub fn is_success(&self) -> bool {
277        self.success
278    }
279
280    /// Check if the result indicates failure
281    pub fn is_error(&self) -> bool {
282        !self.success
283    }
284
285    /// Get the output or error message
286    pub fn message(&self) -> Option<&str> {
287        if self.success {
288            self.output.as_deref()
289        } else {
290            self.error.as_deref()
291        }
292    }
293
294    /// Get the content (output or error message)
295    pub fn content(&self) -> &str {
296        self.message().unwrap_or("")
297    }
298
299    /// Create a new result with updated content
300    pub fn with_content(mut self, content: impl Into<String>) -> Self {
301        let content = content.into();
302        if self.success {
303            self.output = Some(content);
304        } else {
305            self.error = Some(content);
306        }
307        self
308    }
309}
310
311/// Serde helper for Duration serialization
312mod duration_serde {
313    use serde::{Deserialize, Deserializer, Serialize, Serializer};
314    use std::time::Duration;
315
316    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
317    where
318        S: Serializer,
319    {
320        duration.as_secs().serialize(serializer)
321    }
322
323    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
324    where
325        D: Deserializer<'de>,
326    {
327        let secs = u64::deserialize(deserializer)?;
328        Ok(Duration::from_secs(secs))
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_context_default() {
        let ctx = ToolContext::default();
        assert!(ctx.session_id.is_empty());
        assert!(ctx.user.is_none());
        assert!(ctx.environment.is_empty());
        assert!(ctx.cancellation_token.is_none());
    }

    #[test]
    fn test_tool_context_builder() {
        let ctx = ToolContext::new(PathBuf::from("/tmp"))
            .with_session_id("session-123")
            .with_user("test-user")
            .with_env_var("HOME", "/home/test");

        assert_eq!(ctx.working_directory, PathBuf::from("/tmp"));
        assert_eq!(ctx.session_id, "session-123");
        assert_eq!(ctx.user, Some("test-user".to_string()));
        assert_eq!(ctx.environment.get("HOME"), Some(&"/home/test".to_string()));
    }

    #[test]
    fn test_tool_context_with_environment_replaces_map() {
        let mut env = HashMap::new();
        env.insert("A".to_string(), "1".to_string());

        // `with_environment` replaces the whole map, discarding earlier variables.
        let ctx = ToolContext::default()
            .with_env_var("OLD", "x")
            .with_environment(env.clone());

        assert_eq!(ctx.environment, env);
    }

    #[test]
    fn test_tool_context_cancellation() {
        let token = CancellationToken::new();
        let ctx = ToolContext::default().with_cancellation_token(token.clone());

        assert!(!ctx.is_cancelled());
        token.cancel();
        assert!(ctx.is_cancelled());
    }

    #[test]
    fn test_tool_context_not_cancelled_without_token() {
        assert!(!ToolContext::default().is_cancelled());
    }

    #[test]
    fn test_tool_options_default() {
        let opts = ToolOptions::default();
        assert_eq!(opts.max_retries, 3);
        assert_eq!(opts.base_timeout, Duration::from_secs(30));
        assert!(opts.enable_dynamic_timeout);
        assert!(!opts.retryable_errors.is_empty());
    }

    #[test]
    fn test_tool_options_builder() {
        let opts = ToolOptions::new()
            .with_max_retries(5)
            .with_base_timeout(Duration::from_secs(60))
            .with_dynamic_timeout(false);

        assert_eq!(opts.max_retries, 5);
        assert_eq!(opts.base_timeout, Duration::from_secs(60));
        assert!(!opts.enable_dynamic_timeout);
    }

    #[test]
    fn test_tool_options_is_error_retryable() {
        let opts = ToolOptions::default();
        assert!(opts.is_error_retryable("Connection timeout occurred"));
        assert!(opts.is_error_retryable("TIMEOUT"));
        assert!(opts.is_error_retryable("connection refused by server"));
        assert!(!opts.is_error_retryable("permission denied"));
        assert!(!opts.is_error_retryable("file not found"));
    }

    #[test]
    fn test_tool_options_custom_retryable_errors() {
        let opts = ToolOptions::new().with_retryable_errors(vec!["Rate Limit".to_string()]);

        // Pattern matching is case-insensitive in both directions.
        assert!(opts.is_error_retryable("rate limit exceeded"));
        // Replacing the list drops the default patterns.
        assert!(!opts.is_error_retryable("timeout"));
    }

    #[test]
    fn test_tool_definition() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "command": { "type": "string" }
            },
            "required": ["command"]
        });

        let def = ToolDefinition::new("bash", "Execute shell commands", schema.clone());

        assert_eq!(def.name, "bash");
        assert_eq!(def.description, "Execute shell commands");
        assert_eq!(def.input_schema, schema);
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("Hello, World!");
        assert!(result.is_success());
        assert!(!result.is_error());
        assert_eq!(result.output, Some("Hello, World!".to_string()));
        assert!(result.error.is_none());
        assert_eq!(result.message(), Some("Hello, World!"));
    }

    #[test]
    fn test_tool_result_success_empty() {
        let result = ToolResult::success_empty();
        assert!(result.is_success());
        assert!(result.output.is_none());
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_result_default_is_empty_success() {
        let result = ToolResult::default();
        assert!(result.is_success());
        assert!(result.message().is_none());
        assert!(result.metadata.is_empty());
    }

    #[test]
    fn test_tool_result_error() {
        let result = ToolResult::error("Something went wrong");
        assert!(!result.is_success());
        assert!(result.is_error());
        assert!(result.output.is_none());
        assert_eq!(result.error, Some("Something went wrong".to_string()));
        assert_eq!(result.message(), Some("Something went wrong"));
    }

    #[test]
    fn test_tool_result_content() {
        assert_eq!(ToolResult::success("out").content(), "out");
        assert_eq!(ToolResult::error("boom").content(), "boom");
        assert_eq!(ToolResult::success_empty().content(), "");
    }

    #[test]
    fn test_tool_result_with_content_targets_active_field() {
        let ok = ToolResult::success("old").with_content("new");
        assert_eq!(ok.output.as_deref(), Some("new"));
        assert!(ok.error.is_none());

        let err = ToolResult::error("old").with_content("new");
        assert_eq!(err.error.as_deref(), Some("new"));
        assert!(err.output.is_none());
    }

    #[test]
    fn test_tool_result_with_metadata() {
        let result = ToolResult::success("output")
            .with_metadata("duration_ms", serde_json::json!(100))
            .with_metadata("exit_code", serde_json::json!(0));

        assert_eq!(
            result.metadata.get("duration_ms"),
            Some(&serde_json::json!(100))
        );
        assert_eq!(
            result.metadata.get("exit_code"),
            Some(&serde_json::json!(0))
        );
    }

    #[test]
    fn test_tool_result_with_metadata_map_overwrites() {
        let mut extra = HashMap::new();
        extra.insert("a".to_string(), serde_json::json!(2));
        extra.insert("b".to_string(), serde_json::json!(3));

        let result = ToolResult::success_empty()
            .with_metadata("a", serde_json::json!(1))
            .with_metadata_map(extra);

        assert_eq!(result.metadata.get("a"), Some(&serde_json::json!(2)));
        assert_eq!(result.metadata.get("b"), Some(&serde_json::json!(3)));
    }

    #[test]
    fn test_tool_options_serialization() {
        let opts = ToolOptions::default();
        let json = serde_json::to_string(&opts).unwrap();
        let deserialized: ToolOptions = serde_json::from_str(&json).unwrap();

        assert_eq!(opts.max_retries, deserialized.max_retries);
        assert_eq!(opts.base_timeout, deserialized.base_timeout);
        assert_eq!(
            opts.enable_dynamic_timeout,
            deserialized.enable_dynamic_timeout
        );
        assert_eq!(opts.retryable_errors, deserialized.retryable_errors);
    }

    #[test]
    fn test_tool_result_serialization() {
        let result =
            ToolResult::success("test output").with_metadata("key", serde_json::json!("value"));

        let json = serde_json::to_string(&result).unwrap();
        let deserialized: ToolResult = serde_json::from_str(&json).unwrap();

        assert_eq!(result.success, deserialized.success);
        assert_eq!(result.output, deserialized.output);
        assert_eq!(result.error, deserialized.error);
        assert_eq!(result.metadata, deserialized.metadata);
    }
}