1use crate::policy::Capability;
8use serde::{Deserialize, Serialize};
9use std::collections::BTreeMap;
10use std::sync::Arc;
11
/// A single request to invoke a tool.
///
/// Field order here is also the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call; [`ToolCall::new`] fills it with a random UUID v4 string.
    pub call_id: String,
    /// Name of the tool to invoke (presumably the key looked up in a
    /// [`ToolRegistry`] — confirm against the dispatcher).
    pub tool_name: String,
    /// Tool-specific input, as arbitrary JSON.
    pub input: serde_json::Value,
    /// Capabilities this call asks for; deserializes as an empty list when the
    /// field is absent.
    #[serde(default)]
    pub requested_capabilities: Vec<Capability>,
}
23
24impl ToolCall {
25 pub fn new(
26 tool_name: impl Into<String>,
27 input: serde_json::Value,
28 requested_capabilities: Vec<Capability>,
29 ) -> Self {
30 Self {
31 call_id: uuid::Uuid::new_v4().to_string(),
32 tool_name: tool_name.into(),
33 input,
34 requested_capabilities,
35 }
36 }
37}
38
/// Condensed success/failure view of a tool invocation.
///
/// Serialized as an internally tagged object: a `status` field holding
/// `"success"` or `"failure"` sits next to the variant's own field, e.g.
/// `{"status":"failure","error":"not found"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ToolOutcome {
    /// The tool succeeded; `output` is its JSON result.
    Success { output: serde_json::Value },
    /// The tool failed; `error` is a textual description of the failure.
    Failure { error: String },
}
49
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
57pub struct ToolAnnotations {
58 #[serde(default)]
60 pub read_only: bool,
61 #[serde(default)]
63 pub destructive: bool,
64 #[serde(default)]
66 pub idempotent: bool,
67 #[serde(default)]
69 pub open_world: bool,
70 #[serde(default)]
72 pub requires_confirmation: bool,
73}
74
/// Static description of a tool: its name, description and input schema, plus
/// optional metadata.
///
/// Optional fields are skipped when serializing if they are `None` (or, for
/// `tags`, empty), and default to `None`/empty when absent on input.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolDefinition {
    /// Tool name. [`ToolRegistry`] uses it as the registration key, so a later
    /// tool with the same name replaces an earlier one.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Schema describing the expected `ToolCall::input` (presumably a JSON
    /// Schema object, as in this file's tests — confirm with consumers).
    pub input_schema: serde_json::Value,

    /// Optional display title.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
    /// Optional schema describing the tool's output.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_schema: Option<serde_json::Value>,
    /// Optional behavioural hints; see [`ToolAnnotations`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub annotations: Option<ToolAnnotations>,

    /// Optional free-form grouping label (e.g. `"filesystem"`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub category: Option<String>,
    /// Free-form tags; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,
    /// Optional timeout in seconds. Enforcement is not done here; presumably the
    /// executor applies it — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout_secs: Option<u32>,
}
112
/// One item of rich tool output.
///
/// Serialized as an internally tagged object whose `type` field is `"text"`,
/// `"image"` or `"json"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolContent {
    /// Plain text.
    Text { text: String },
    /// Image payload; `data` is presumably base64-encoded bytes (not validated
    /// here — confirm) and `mime_type` names the format, e.g. `"image/png"`.
    Image { data: String, mime_type: String },
    /// Arbitrary structured JSON.
    Json { value: serde_json::Value },
}
125
/// Result of executing a tool call.
///
/// `output` carries a JSON value and `content` optionally carries the same result
/// as typed items. When deserializing, a missing `output` becomes `null`, a
/// missing `content` becomes `None`, and a missing `is_error` becomes `false`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolResult {
    /// Id of the [`ToolCall`] this result answers.
    pub call_id: String,
    /// Name of the tool that produced the result.
    pub tool_name: String,
    /// JSON form of the result; `null` when absent from serialized input.
    #[serde(default)]
    pub output: serde_json::Value,
    /// Optional typed content items; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content: Option<Vec<ToolContent>>,
    /// Whether the tool reported an error.
    #[serde(default)]
    pub is_error: bool,
}
149
150impl ToolResult {
151 pub fn text(call_id: impl Into<String>, tool_name: impl Into<String>, text: &str) -> Self {
153 Self {
154 call_id: call_id.into(),
155 tool_name: tool_name.into(),
156 output: serde_json::Value::String(text.to_string()),
157 content: Some(vec![ToolContent::Text {
158 text: text.to_string(),
159 }]),
160 is_error: false,
161 }
162 }
163
164 pub fn json(
166 call_id: impl Into<String>,
167 tool_name: impl Into<String>,
168 value: serde_json::Value,
169 ) -> Self {
170 Self {
171 call_id: call_id.into(),
172 tool_name: tool_name.into(),
173 output: value.clone(),
174 content: Some(vec![ToolContent::Json { value }]),
175 is_error: false,
176 }
177 }
178
179 pub fn error(call_id: impl Into<String>, tool_name: impl Into<String>, message: &str) -> Self {
181 Self {
182 call_id: call_id.into(),
183 tool_name: tool_name.into(),
184 output: serde_json::json!({ "error": message }),
185 content: Some(vec![ToolContent::Text {
186 text: message.to_string(),
187 }]),
188 is_error: true,
189 }
190 }
191}
192
193impl From<&ToolResult> for ToolOutcome {
195 fn from(result: &ToolResult) -> Self {
196 if result.is_error {
197 ToolOutcome::Failure {
198 error: match &result.output {
199 serde_json::Value::String(s) => s.clone(),
200 other => other.to_string(),
201 },
202 }
203 } else {
204 ToolOutcome::Success {
205 output: result.output.clone(),
206 }
207 }
208 }
209}
210
/// Per-invocation metadata handed to [`Tool::execute`].
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Identifier of the enclosing run (format is chosen by the caller).
    pub run_id: String,
    /// Identifier of the enclosing session (format is chosen by the caller).
    pub session_id: String,
    /// Iteration counter supplied by the caller; presumably the agent-loop step
    /// number — TODO confirm.
    pub iteration: u32,
}
223
/// Errors a tool or the tool layer can report.
///
/// `Display` output comes from the `#[error(...)]` attributes (via `thiserror`).
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// No tool is registered under `tool_name`.
    #[error("tool not found: {tool_name}")]
    NotFound { tool_name: String },

    /// The tool ran but failed; `message` explains why.
    #[error("[{tool_name}] execution failed: {message}")]
    ExecutionFailed { tool_name: String, message: String },

    /// The call's input was rejected.
    #[error("invalid input: {message}")]
    InvalidInput { message: String },

    /// The tool did not finish within `timeout_secs` seconds.
    #[error("[{tool_name}] timed out after {timeout_secs}s")]
    Timeout {
        tool_name: String,
        timeout_secs: u32,
    },

    /// The call was blocked by a workspace policy.
    #[error("workspace policy violation: {message}")]
    PolicyViolation { message: String },

    /// Any other error; displays the contained message verbatim.
    #[error("{0}")]
    Other(String),
}
250
251pub trait Tool: Send + Sync {
263 fn definition(&self) -> ToolDefinition;
265
266 fn execute(&self, call: &ToolCall, ctx: &ToolContext) -> Result<ToolResult, ToolError>;
268}
269
/// Collection of tools keyed by their definition name.
///
/// Backed by a `BTreeMap`, so iteration (and therefore `names()` and
/// `definitions()`) is ordered by name. Cloning the registry clones the `Arc`
/// handles, so clones share the same tool instances.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    // Name -> tool. Names come from `Tool::definition().name` at registration time.
    tools: BTreeMap<String, Arc<dyn Tool>>,
}
277
278impl ToolRegistry {
279 pub fn register<T: Tool + 'static>(&mut self, tool: T) {
281 self.tools
282 .insert(tool.definition().name.clone(), Arc::new(tool));
283 }
284
285 pub fn register_arc(&mut self, tool: Arc<dyn Tool>) {
287 self.tools.insert(tool.definition().name.clone(), tool);
288 }
289
290 pub fn get(&self, tool_name: &str) -> Option<Arc<dyn Tool>> {
292 self.tools.get(tool_name).cloned()
293 }
294
295 pub fn definitions(&self) -> Vec<ToolDefinition> {
297 self.tools.values().map(|tool| tool.definition()).collect()
298 }
299
300 pub fn len(&self) -> usize {
302 self.tools.len()
303 }
304
305 pub fn is_empty(&self) -> bool {
307 self.tools.is_empty()
308 }
309
310 pub fn names(&self) -> Vec<String> {
312 self.tools.keys().cloned().collect()
313 }
314}
315
316impl std::fmt::Debug for ToolRegistry {
317 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318 f.debug_struct("ToolRegistry")
319 .field("tools", &self.tools.keys().collect::<Vec<_>>())
320 .finish()
321 }
322}
323
#[cfg(test)]
mod tests {
    // Unit tests for the tool data types, their serde representation, the
    // `ToolResult` helpers, the `Tool` trait and `ToolRegistry`.
    use super::*;
    use serde_json::json;

    // ----- ToolCall -------------------------------------------------------

    #[test]
    fn tool_call_new() {
        let tc = ToolCall::new("read_file", json!({"path": "/tmp"}), vec![]);
        assert_eq!(tc.tool_name, "read_file");
        assert_eq!(tc.input, json!({"path": "/tmp"}));
        assert!(tc.requested_capabilities.is_empty());
        assert!(!tc.call_id.is_empty());
    }

    #[test]
    fn tool_call_new_generates_distinct_ids() {
        let first = ToolCall::new("read_file", json!({}), vec![]);
        let second = ToolCall::new("read_file", json!({}), vec![]);
        assert_ne!(first.call_id, second.call_id);
    }

    #[test]
    fn tool_call_missing_capabilities_default_empty() {
        let json_str = r#"{"call_id": "c1", "tool_name": "read_file", "input": {}}"#;
        let tc: ToolCall = serde_json::from_str(json_str).unwrap();
        assert_eq!(tc.call_id, "c1");
        assert!(tc.requested_capabilities.is_empty());
    }

    // ----- ToolOutcome ----------------------------------------------------

    #[test]
    fn tool_outcome_serde_roundtrip() {
        let success = ToolOutcome::Success {
            output: json!({"data": 42}),
        };
        let json_str = serde_json::to_string(&success).unwrap();
        assert!(json_str.contains("\"status\":\"success\""));
        match serde_json::from_str::<ToolOutcome>(&json_str).unwrap() {
            ToolOutcome::Success { output } => assert_eq!(output, json!({"data": 42})),
            other => panic!("expected success, got {other:?}"),
        }

        let failure = ToolOutcome::Failure {
            error: "not found".into(),
        };
        let json_str = serde_json::to_string(&failure).unwrap();
        assert!(json_str.contains("\"status\":\"failure\""));
        match serde_json::from_str::<ToolOutcome>(&json_str).unwrap() {
            ToolOutcome::Failure { error } => assert_eq!(error, "not found"),
            other => panic!("expected failure, got {other:?}"),
        }
    }

    // ----- ToolAnnotations ------------------------------------------------

    #[test]
    fn annotations_default_all_false() {
        let ann = ToolAnnotations::default();
        assert!(!ann.read_only);
        assert!(!ann.destructive);
        assert!(!ann.idempotent);
        assert!(!ann.open_world);
        assert!(!ann.requires_confirmation);
    }

    #[test]
    fn annotations_serde_roundtrip() {
        let ann = ToolAnnotations {
            read_only: true,
            destructive: false,
            idempotent: true,
            open_world: false,
            requires_confirmation: true,
        };
        let json_str = serde_json::to_string(&ann).unwrap();
        let back: ToolAnnotations = serde_json::from_str(&json_str).unwrap();
        assert_eq!(ann, back);
    }

    #[test]
    fn annotations_missing_fields_default_false() {
        let json_str = r#"{"read_only": true}"#;
        let ann: ToolAnnotations = serde_json::from_str(json_str).unwrap();
        assert_eq!(
            ann,
            ToolAnnotations {
                read_only: true,
                ..Default::default()
            }
        );
    }

    #[test]
    fn annotations_empty_object_is_default() {
        let ann: ToolAnnotations = serde_json::from_str("{}").unwrap();
        assert_eq!(ann, ToolAnnotations::default());
    }

    // ----- ToolDefinition -------------------------------------------------

    #[test]
    fn tool_definition_minimal() {
        let def = ToolDefinition {
            name: "test_tool".into(),
            description: "A test tool".into(),
            input_schema: json!({"type": "object"}),
            title: None,
            output_schema: None,
            annotations: None,
            category: None,
            tags: vec![],
            timeout_secs: None,
        };
        let json_str = serde_json::to_string(&def).unwrap();
        // Every optional/empty field is skipped during serialization.
        for absent in [
            "title",
            "output_schema",
            "annotations",
            "category",
            "tags",
            "timeout_secs",
        ] {
            assert!(!json_str.contains(absent), "{absent} should be omitted");
        }
        let back: ToolDefinition = serde_json::from_str(&json_str).unwrap();
        assert_eq!(def, back);
    }

    #[test]
    fn tool_definition_full() {
        let def = ToolDefinition {
            name: "read_file".into(),
            description: "Read a file from the workspace".into(),
            input_schema: json!({
                "type": "object",
                "properties": { "path": { "type": "string" } },
                "required": ["path"]
            }),
            title: Some("Read File".into()),
            output_schema: Some(json!({"type": "string"})),
            annotations: Some(ToolAnnotations {
                read_only: true,
                idempotent: true,
                ..Default::default()
            }),
            category: Some("filesystem".into()),
            tags: vec!["fs".into(), "read".into()],
            timeout_secs: Some(30),
        };
        let json_str = serde_json::to_string(&def).unwrap();
        let back: ToolDefinition = serde_json::from_str(&json_str).unwrap();
        assert_eq!(def, back);
        assert!(json_str.contains("\"category\":\"filesystem\""));
    }

    // ----- ToolContent ----------------------------------------------------

    #[test]
    fn tool_content_text_serde() {
        let content = ToolContent::Text {
            text: "hello".into(),
        };
        let json_str = serde_json::to_string(&content).unwrap();
        assert!(json_str.contains("\"type\":\"text\""));
        let back: ToolContent = serde_json::from_str(&json_str).unwrap();
        assert_eq!(content, back);
    }

    #[test]
    fn tool_content_json_serde() {
        let content = ToolContent::Json {
            value: json!({"key": "value"}),
        };
        let json_str = serde_json::to_string(&content).unwrap();
        assert!(json_str.contains("\"type\":\"json\""));
        let back: ToolContent = serde_json::from_str(&json_str).unwrap();
        assert_eq!(content, back);
    }

    #[test]
    fn tool_content_image_serde() {
        let content = ToolContent::Image {
            data: "base64data".into(),
            mime_type: "image/png".into(),
        };
        let json_str = serde_json::to_string(&content).unwrap();
        assert!(json_str.contains("\"type\":\"image\""));
        let back: ToolContent = serde_json::from_str(&json_str).unwrap();
        assert_eq!(content, back);
    }

    // ----- ToolResult -----------------------------------------------------

    #[test]
    fn tool_result_text_helper() {
        let result = ToolResult::text("call-1", "echo", "hello world");
        assert_eq!(result.call_id, "call-1");
        assert_eq!(result.tool_name, "echo");
        assert!(!result.is_error);
        assert_eq!(result.output, json!("hello world"));
        assert_eq!(
            result.content,
            Some(vec![ToolContent::Text {
                text: "hello world".into()
            }])
        );
    }

    #[test]
    fn tool_result_json_helper() {
        let result = ToolResult::json("call-2", "search", json!({"matches": 5}));
        assert!(!result.is_error);
        assert_eq!(result.output, json!({"matches": 5}));
        assert_eq!(
            result.content,
            Some(vec![ToolContent::Json {
                value: json!({"matches": 5})
            }])
        );
    }

    #[test]
    fn tool_result_error_helper() {
        let result = ToolResult::error("call-3", "bash", "permission denied");
        assert!(result.is_error);
        assert_eq!(result.output["error"], "permission denied");
        assert_eq!(
            result.content,
            Some(vec![ToolContent::Text {
                text: "permission denied".into()
            }])
        );
    }

    #[test]
    fn tool_result_serde_roundtrip() {
        let result = ToolResult {
            call_id: "c1".into(),
            tool_name: "test".into(),
            output: json!({"ok": true}),
            content: Some(vec![ToolContent::Text {
                text: "success".into(),
            }]),
            is_error: false,
        };
        let json_str = serde_json::to_string(&result).unwrap();
        let back: ToolResult = serde_json::from_str(&json_str).unwrap();
        assert_eq!(result, back);
    }

    #[test]
    fn tool_result_missing_fields_use_defaults() {
        let json_str = r#"{"call_id": "c1", "tool_name": "test"}"#;
        let result: ToolResult = serde_json::from_str(json_str).unwrap();
        assert_eq!(result.output, serde_json::Value::Null);
        assert!(result.content.is_none());
        assert!(!result.is_error);
    }

    #[test]
    fn tool_result_to_outcome_success() {
        let result = ToolResult::json("c1", "test", json!({"data": 42}));
        match ToolOutcome::from(&result) {
            ToolOutcome::Success { output } => assert_eq!(output, json!({"data": 42})),
            other => panic!("expected success, got {other:?}"),
        }
    }

    #[test]
    fn tool_result_to_outcome_failure() {
        let result = ToolResult::error("c1", "test", "oops");
        let outcome: ToolOutcome = ToolOutcome::from(&result);
        match outcome {
            ToolOutcome::Failure { error } => assert!(error.contains("oops")),
            other => panic!("expected failure, got {other:?}"),
        }
    }

    #[test]
    fn tool_result_to_outcome_string_error() {
        let result = ToolResult {
            call_id: "c1".into(),
            tool_name: "test".into(),
            output: json!("disk full"),
            content: None,
            is_error: true,
        };
        match ToolOutcome::from(&result) {
            ToolOutcome::Failure { error } => assert_eq!(error, "disk full"),
            other => panic!("expected failure, got {other:?}"),
        }
    }

    // ----- Tool trait -----------------------------------------------------

    /// Test tool that returns its `value` input field as JSON output.
    struct EchoTool;

    impl Tool for EchoTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".into(),
                description: "Echoes the input value".into(),
                input_schema: json!({
                    "type": "object",
                    "properties": { "value": { "type": "string" } },
                    "required": ["value"]
                }),
                title: None,
                output_schema: None,
                annotations: Some(ToolAnnotations {
                    read_only: true,
                    idempotent: true,
                    ..Default::default()
                }),
                category: Some("test".into()),
                tags: vec![],
                timeout_secs: Some(10),
            }
        }

        fn execute(&self, call: &ToolCall, _ctx: &ToolContext) -> Result<ToolResult, ToolError> {
            let value = call.input.get("value").cloned().unwrap_or(json!(null));
            Ok(ToolResult::json(&call.call_id, &call.tool_name, value))
        }
    }

    /// Test tool whose execution always fails.
    struct FailTool;

    impl Tool for FailTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "fail".into(),
                description: "Always fails".into(),
                input_schema: json!({"type": "object"}),
                title: None,
                output_schema: None,
                annotations: None,
                category: None,
                tags: vec![],
                timeout_secs: None,
            }
        }

        fn execute(&self, call: &ToolCall, _ctx: &ToolContext) -> Result<ToolResult, ToolError> {
            Err(ToolError::ExecutionFailed {
                tool_name: call.tool_name.clone(),
                message: "always fails".into(),
            })
        }
    }

    fn test_context() -> ToolContext {
        ToolContext {
            run_id: "run-1".into(),
            session_id: "sess-1".into(),
            iteration: 1,
        }
    }

    #[test]
    fn tool_trait_execute_success() {
        let tool = EchoTool;
        let call = ToolCall::new("echo", json!({"value": "hello"}), vec![]);
        let result = tool.execute(&call, &test_context()).unwrap();
        assert!(!result.is_error);
        assert_eq!(result.call_id, call.call_id);
        assert_eq!(result.output, json!("hello"));
    }

    #[test]
    fn tool_trait_execute_missing_value_yields_null() {
        let call = ToolCall::new("echo", json!({}), vec![]);
        let result = EchoTool.execute(&call, &test_context()).unwrap();
        assert_eq!(result.output, serde_json::Value::Null);
    }

    #[test]
    fn tool_trait_execute_error() {
        let tool = FailTool;
        let call = ToolCall::new("fail", json!({}), vec![]);
        let err = tool.execute(&call, &test_context()).unwrap_err();
        assert!(matches!(err, ToolError::ExecutionFailed { .. }));
        assert!(err.to_string().contains("always fails"));
    }

    // ----- ToolRegistry ---------------------------------------------------

    #[test]
    fn registry_register_and_get() {
        let mut reg = ToolRegistry::default();
        assert!(reg.is_empty());

        reg.register(EchoTool);
        assert_eq!(reg.len(), 1);
        assert!(!reg.is_empty());

        let tool = reg.get("echo").expect("should find echo");
        let def = tool.definition();
        assert_eq!(def.name, "echo");
    }

    #[test]
    fn registry_register_arc() {
        let mut reg = ToolRegistry::default();
        let tool: std::sync::Arc<dyn Tool> = std::sync::Arc::new(EchoTool);
        reg.register_arc(tool);
        assert_eq!(reg.names(), vec!["echo"]);

        let call = ToolCall::new("echo", json!({"value": "hi"}), vec![]);
        let result = reg
            .get("echo")
            .expect("should find echo")
            .execute(&call, &test_context())
            .unwrap();
        assert_eq!(result.output, json!("hi"));
    }

    #[test]
    fn registry_get_missing() {
        let reg = ToolRegistry::default();
        assert!(reg.get("nonexistent").is_none());
    }

    #[test]
    fn registry_definitions() {
        let mut reg = ToolRegistry::default();
        reg.register(EchoTool);
        reg.register(FailTool);

        let defs = reg.definitions();
        assert_eq!(defs.len(), 2);
        let names: Vec<_> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"echo"));
        assert!(names.contains(&"fail"));
    }

    #[test]
    fn registry_names() {
        let mut reg = ToolRegistry::default();
        reg.register(EchoTool);
        reg.register(FailTool);

        let names = reg.names();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"echo".to_string()));
        assert!(names.contains(&"fail".to_string()));
    }

    #[test]
    fn registry_is_ordered_by_name() {
        let mut reg = ToolRegistry::default();
        // Registered in reverse order; iteration must still be sorted.
        reg.register(FailTool);
        reg.register(EchoTool);

        assert_eq!(reg.names(), vec!["echo", "fail"]);
        let def_names: Vec<String> = reg.definitions().into_iter().map(|d| d.name).collect();
        assert_eq!(def_names, reg.names());
    }

    #[test]
    fn registry_register_replaces_existing() {
        let mut reg = ToolRegistry::default();
        reg.register(EchoTool);
        reg.register(EchoTool);
        assert_eq!(reg.len(), 1);
        assert!(reg.get("echo").is_some());
    }

    #[test]
    fn registry_debug_format() {
        let mut reg = ToolRegistry::default();
        reg.register(EchoTool);
        let debug = format!("{:?}", reg);
        assert_eq!(debug, r#"ToolRegistry { tools: ["echo"] }"#);
    }

    // ----- ToolError ------------------------------------------------------

    #[test]
    fn tool_error_display() {
        let err = ToolError::NotFound {
            tool_name: "ghost".into(),
        };
        assert_eq!(err.to_string(), "tool not found: ghost");

        let err = ToolError::Timeout {
            tool_name: "slow".into(),
            timeout_secs: 30,
        };
        assert_eq!(err.to_string(), "[slow] timed out after 30s");

        let err = ToolError::ExecutionFailed {
            tool_name: "bash".into(),
            message: "boom".into(),
        };
        assert_eq!(err.to_string(), "[bash] execution failed: boom");

        let err = ToolError::InvalidInput {
            message: "missing path".into(),
        };
        assert_eq!(err.to_string(), "invalid input: missing path");

        let err = ToolError::PolicyViolation {
            message: "outside workspace".into(),
        };
        assert_eq!(
            err.to_string(),
            "workspace policy violation: outside workspace"
        );

        let err = ToolError::Other("misc".into());
        assert_eq!(err.to_string(), "misc");
    }
}