1use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::path::PathBuf;
14use std::time::Duration;
15use tokio_util::sync::CancellationToken;
16
/// Context that accompanies a tool invocation.
///
/// Carries the working directory, session and user identity, a set of
/// environment variables, and an optional cancellation token. Create one with
/// `ToolContext::new` or `ToolContext::default` and refine it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Working directory for the invocation.
    pub working_directory: PathBuf,

    /// Identifier of the session this invocation belongs to; empty when unset.
    pub session_id: String,

    /// User associated with the invocation, if any.
    pub user: Option<String>,

    /// Environment variables attached to the context, keyed by variable name.
    // NOTE(review): whether consumers merge these into the process environment
    // or use them as-is is not visible in this file — confirm against callers.
    pub environment: HashMap<String, String>,

    /// Token used to signal cancellation. With `None`, `is_cancelled` always
    /// reports `false`.
    pub cancellation_token: Option<CancellationToken>,
}
38
39impl Default for ToolContext {
40 fn default() -> Self {
41 Self {
42 working_directory: std::env::current_dir().unwrap_or_default(),
43 session_id: String::new(),
44 user: None,
45 environment: HashMap::new(),
46 cancellation_token: None,
47 }
48 }
49}
50
51impl ToolContext {
52 pub fn new(working_directory: PathBuf) -> Self {
54 Self {
55 working_directory,
56 ..Default::default()
57 }
58 }
59
60 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
62 self.session_id = session_id.into();
63 self
64 }
65
66 pub fn with_user(mut self, user: impl Into<String>) -> Self {
68 self.user = Some(user.into());
69 self
70 }
71
72 pub fn with_environment(mut self, environment: HashMap<String, String>) -> Self {
74 self.environment = environment;
75 self
76 }
77
78 pub fn with_env_var(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
80 self.environment.insert(key.into(), value.into());
81 self
82 }
83
84 pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
86 self.cancellation_token = Some(token);
87 self
88 }
89
90 pub fn is_cancelled(&self) -> bool {
92 self.cancellation_token
93 .as_ref()
94 .is_some_and(|t| t.is_cancelled())
95 }
96}
97
/// Retry and timeout settings for tool execution.
///
/// Serializable with serde; `base_timeout` is encoded through the
/// `duration_serde` helper module as a whole number of seconds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolOptions {
    /// Maximum number of retries (3 by default).
    // NOTE(review): whether this counts retries after the first attempt or
    // total attempts is decided by the executor, which is not in this file.
    pub max_retries: u32,

    /// Base timeout (30 seconds by default); (de)serialized as whole seconds.
    #[serde(with = "duration_serde")]
    pub base_timeout: Duration,

    /// Whether dynamic timeout adjustment is enabled (`true` by default).
    // NOTE(review): the adjustment logic itself lives outside this file.
    pub enable_dynamic_timeout: bool,

    /// Substrings that `is_error_retryable` looks for, case-insensitively,
    /// in an error message.
    pub retryable_errors: Vec<String>,
}
117
118impl Default for ToolOptions {
119 fn default() -> Self {
120 Self {
121 max_retries: 3,
122 base_timeout: Duration::from_secs(30),
123 enable_dynamic_timeout: true,
124 retryable_errors: vec![
125 "timeout".to_string(),
126 "connection refused".to_string(),
127 "temporary failure".to_string(),
128 ],
129 }
130 }
131}
132
133impl ToolOptions {
134 pub fn new() -> Self {
136 Self::default()
137 }
138
139 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
141 self.max_retries = max_retries;
142 self
143 }
144
145 pub fn with_base_timeout(mut self, timeout: Duration) -> Self {
147 self.base_timeout = timeout;
148 self
149 }
150
151 pub fn with_dynamic_timeout(mut self, enabled: bool) -> Self {
153 self.enable_dynamic_timeout = enabled;
154 self
155 }
156
157 pub fn with_retryable_errors(mut self, errors: Vec<String>) -> Self {
159 self.retryable_errors = errors;
160 self
161 }
162
163 pub fn is_error_retryable(&self, error_msg: &str) -> bool {
165 let error_lower = error_msg.to_lowercase();
166 self.retryable_errors
167 .iter()
168 .any(|pattern| error_lower.contains(&pattern.to_lowercase()))
169 }
170}
171
/// Describes a tool: its name, a human-readable description, and a JSON value
/// describing its input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,

    /// Human-readable description of the tool.
    pub description: String,

    /// JSON value describing the tool's input.
    // NOTE(review): the tests pass a JSON-Schema-style object
    // (`type` / `properties` / `required`); nothing here validates that shape.
    pub input_schema: serde_json::Value,
}
186
187impl ToolDefinition {
188 pub fn new(
190 name: impl Into<String>,
191 description: impl Into<String>,
192 input_schema: serde_json::Value,
193 ) -> Self {
194 Self {
195 name: name.into(),
196 description: description.into(),
197 input_schema,
198 }
199 }
200}
201
/// Outcome of a tool invocation.
///
/// `success` decides which of `output` / `error` the `message` and `content`
/// accessors read; `metadata` carries arbitrary extra JSON values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// `true` when the invocation succeeded.
    pub success: bool,

    /// Output text; set by the `success` constructor.
    pub output: Option<String>,

    /// Error text; set by the `error` constructor.
    pub error: Option<String>,

    /// Additional key/value data attached to the result.
    pub metadata: HashMap<String, serde_json::Value>,
}
220
221impl Default for ToolResult {
222 fn default() -> Self {
223 Self {
224 success: true,
225 output: None,
226 error: None,
227 metadata: HashMap::new(),
228 }
229 }
230}
231
232impl ToolResult {
233 pub fn success(output: impl Into<String>) -> Self {
235 Self {
236 success: true,
237 output: Some(output.into()),
238 error: None,
239 metadata: HashMap::new(),
240 }
241 }
242
243 pub fn success_empty() -> Self {
245 Self {
246 success: true,
247 output: None,
248 error: None,
249 metadata: HashMap::new(),
250 }
251 }
252
253 pub fn error(error: impl Into<String>) -> Self {
255 Self {
256 success: false,
257 output: None,
258 error: Some(error.into()),
259 metadata: HashMap::new(),
260 }
261 }
262
263 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
265 self.metadata.insert(key.into(), value);
266 self
267 }
268
269 pub fn with_metadata_map(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
271 self.metadata.extend(metadata);
272 self
273 }
274
275 pub fn is_success(&self) -> bool {
277 self.success
278 }
279
280 pub fn is_error(&self) -> bool {
282 !self.success
283 }
284
285 pub fn message(&self) -> Option<&str> {
287 if self.success {
288 self.output.as_deref()
289 } else {
290 self.error.as_deref()
291 }
292 }
293
294 pub fn content(&self) -> &str {
296 self.message().unwrap_or("")
297 }
298
299 pub fn with_content(mut self, content: impl Into<String>) -> Self {
301 let content = content.into();
302 if self.success {
303 self.output = Some(content);
304 } else {
305 self.error = Some(content);
306 }
307 self
308 }
309}
310
311mod duration_serde {
313 use serde::{Deserialize, Deserializer, Serialize, Serializer};
314 use std::time::Duration;
315
316 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
317 where
318 S: Serializer,
319 {
320 duration.as_secs().serialize(serializer)
321 }
322
323 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
324 where
325 D: Deserializer<'de>,
326 {
327 let secs = u64::deserialize(deserializer)?;
328 Ok(Duration::from_secs(secs))
329 }
330}
331
/// Unit tests for the context, options, definition and result types.
#[cfg(test)]
mod tests {
    use super::*;

    /// A default context carries no session, user, environment or token.
    #[test]
    fn test_tool_context_default() {
        let context = ToolContext::default();

        assert_eq!(context.session_id, "");
        assert_eq!(context.user, None);
        assert_eq!(context.environment.len(), 0);
        assert!(context.cancellation_token.is_none());
    }

    /// Builder methods store exactly what they are given.
    #[test]
    fn test_tool_context_builder() {
        let working_dir = PathBuf::from("/tmp");
        let context = ToolContext::new(working_dir.clone())
            .with_session_id("session-123")
            .with_user("test-user")
            .with_env_var("HOME", "/home/test");

        assert_eq!(context.working_directory, working_dir);
        assert_eq!(context.session_id, "session-123");
        assert_eq!(context.user.as_deref(), Some("test-user"));
        assert_eq!(
            context.environment.get("HOME").map(String::as_str),
            Some("/home/test")
        );
    }

    /// The context reflects the state of the token it was given.
    #[test]
    fn test_tool_context_cancellation() {
        let source = CancellationToken::new();
        let context = ToolContext::default().with_cancellation_token(source.clone());

        assert!(!context.is_cancelled());
        source.cancel();
        assert!(context.is_cancelled());
    }

    /// Default options match the documented values.
    #[test]
    fn test_tool_options_default() {
        let ToolOptions {
            max_retries,
            base_timeout,
            enable_dynamic_timeout,
            retryable_errors,
        } = ToolOptions::default();

        assert_eq!(max_retries, 3);
        assert_eq!(base_timeout, Duration::from_secs(30));
        assert!(enable_dynamic_timeout);
        assert!(!retryable_errors.is_empty());
    }

    /// Builder methods override the corresponding defaults.
    #[test]
    fn test_tool_options_builder() {
        let timeout = Duration::from_secs(60);
        let options = ToolOptions::new()
            .with_max_retries(5)
            .with_base_timeout(timeout)
            .with_dynamic_timeout(false);

        assert_eq!(options.max_retries, 5);
        assert_eq!(options.base_timeout, timeout);
        assert!(!options.enable_dynamic_timeout);
    }

    /// Default patterns match case-insensitively and reject unrelated errors.
    #[test]
    fn test_tool_options_is_error_retryable() {
        let options = ToolOptions::default();

        for message in [
            "Connection timeout occurred",
            "TIMEOUT",
            "connection refused by server",
        ] {
            assert!(
                options.is_error_retryable(message),
                "expected {message:?} to be retryable"
            );
        }

        for message in ["permission denied", "file not found"] {
            assert!(
                !options.is_error_retryable(message),
                "expected {message:?} not to be retryable"
            );
        }
    }

    /// `ToolDefinition::new` stores its arguments unchanged.
    #[test]
    fn test_tool_definition() {
        let input_schema = serde_json::json!({
            "type": "object",
            "properties": {
                "command": { "type": "string" }
            },
            "required": ["command"]
        });

        let definition =
            ToolDefinition::new("bash", "Execute shell commands", input_schema.clone());

        assert_eq!(definition.name, "bash");
        assert_eq!(definition.description, "Execute shell commands");
        assert_eq!(definition.input_schema, input_schema);
    }

    /// A successful result exposes its output as the message.
    #[test]
    fn test_tool_result_success() {
        let outcome = ToolResult::success("Hello, World!");

        assert!(outcome.is_success());
        assert!(!outcome.is_error());
        assert_eq!(outcome.output.as_deref(), Some("Hello, World!"));
        assert_eq!(outcome.error, None);
        assert_eq!(outcome.message(), Some("Hello, World!"));
    }

    /// An empty success has neither output nor error.
    #[test]
    fn test_tool_result_success_empty() {
        let outcome = ToolResult::success_empty();

        assert!(outcome.is_success());
        assert_eq!(outcome.output, None);
        assert_eq!(outcome.error, None);
    }

    /// A failed result exposes its error as the message.
    #[test]
    fn test_tool_result_error() {
        let outcome = ToolResult::error("Something went wrong");

        assert!(!outcome.is_success());
        assert!(outcome.is_error());
        assert_eq!(outcome.output, None);
        assert_eq!(outcome.error.as_deref(), Some("Something went wrong"));
        assert_eq!(outcome.message(), Some("Something went wrong"));
    }

    /// Metadata entries added one by one are all retrievable.
    #[test]
    fn test_tool_result_with_metadata() {
        let outcome = ToolResult::success("output")
            .with_metadata("duration_ms", serde_json::json!(100))
            .with_metadata("exit_code", serde_json::json!(0));

        let expected = [
            ("duration_ms", serde_json::json!(100)),
            ("exit_code", serde_json::json!(0)),
        ];
        for (key, value) in &expected {
            assert_eq!(outcome.metadata.get(*key), Some(value));
        }
    }

    /// Options survive a JSON round trip.
    #[test]
    fn test_tool_options_serialization() {
        let original = ToolOptions::default();

        let encoded = serde_json::to_string(&original).expect("options serialize to JSON");
        let decoded: ToolOptions =
            serde_json::from_str(&encoded).expect("options deserialize from JSON");

        assert_eq!(original.max_retries, decoded.max_retries);
        assert_eq!(original.base_timeout, decoded.base_timeout);
        assert_eq!(
            original.enable_dynamic_timeout,
            decoded.enable_dynamic_timeout
        );
    }

    /// Results survive a JSON round trip.
    #[test]
    fn test_tool_result_serialization() {
        let original =
            ToolResult::success("test output").with_metadata("key", serde_json::json!("value"));

        let encoded = serde_json::to_string(&original).expect("result serializes to JSON");
        let decoded: ToolResult =
            serde_json::from_str(&encoded).expect("result deserializes from JSON");

        assert_eq!(original.success, decoded.success);
        assert_eq!(original.output, decoded.output);
        assert_eq!(original.metadata, decoded.metadata);
    }
}