Skip to main content

aios_protocol/
tool.rs

1//! Tool types: calls, outcomes, definitions, results, and the canonical Tool trait.
2//!
3//! This module provides the shared vocabulary for tool execution across all
4//! Agent OS projects. Tool implementations (in Praxis or other runtimes)
5//! implement the [`Tool`] trait defined here.
6
7use crate::policy::Capability;
8use serde::{Deserialize, Serialize};
9use std::collections::BTreeMap;
10use std::sync::Arc;
11
12// ── Existing types (stable) ───────────────────────────────────────────
13
14/// A tool invocation request with capabilities.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ToolCall {
17    pub call_id: String,
18    pub tool_name: String,
19    pub input: serde_json::Value,
20    #[serde(default)]
21    pub requested_capabilities: Vec<Capability>,
22}
23
24impl ToolCall {
25    pub fn new(
26        tool_name: impl Into<String>,
27        input: serde_json::Value,
28        requested_capabilities: Vec<Capability>,
29    ) -> Self {
30        Self {
31            call_id: uuid::Uuid::new_v4().to_string(),
32            tool_name: tool_name.into(),
33            input,
34            requested_capabilities,
35        }
36    }
37}
38
39/// Tool execution outcome (kernel-level, simplified).
40///
41/// Used at the kernel boundary ([`ToolExecutionReport`](crate::ports::ToolExecutionReport)).
42/// For richer tool results with typed content, see [`ToolResult`].
43#[derive(Debug, Clone, Serialize, Deserialize)]
44#[serde(tag = "status", rename_all = "snake_case")]
45pub enum ToolOutcome {
46    Success { output: serde_json::Value },
47    Failure { error: String },
48}
49
50// ── MCP-compatible behavioral annotations ─────────────────────────────
51
52/// Behavioral annotations for tools (MCP-compatible).
53///
54/// These hints inform the runtime about a tool's side effects,
55/// enabling policy enforcement and user confirmation flows.
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
57pub struct ToolAnnotations {
58    /// Tool does not modify its environment.
59    #[serde(default)]
60    pub read_only: bool,
61    /// Tool may perform destructive updates.
62    #[serde(default)]
63    pub destructive: bool,
64    /// Repeated calls with same args produce same result.
65    #[serde(default)]
66    pub idempotent: bool,
67    /// Tool interacts with external entities (network, APIs).
68    #[serde(default)]
69    pub open_world: bool,
70    /// Tool requires user confirmation before execution.
71    #[serde(default)]
72    pub requires_confirmation: bool,
73}
74
75// ── Tool definition ───────────────────────────────────────────────────
76
77/// Complete description of a tool's interface and behavior.
78///
79/// This is the canonical tool definition used across all Agent OS projects.
80/// It is MCP-aligned with additional fields for categorization and timeouts.
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
82pub struct ToolDefinition {
83    /// Unique tool name (e.g. "read_file", "bash").
84    pub name: String,
85    /// Human-readable description of what the tool does.
86    pub description: String,
87    /// JSON Schema describing the tool's input parameters.
88    pub input_schema: serde_json::Value,
89
90    // ── MCP-aligned fields (all optional, backward-compatible) ──
91    /// Human-readable display name (MCP: title).
92    #[serde(default, skip_serializing_if = "Option::is_none")]
93    pub title: Option<String>,
94    /// JSON Schema for structured output (MCP: outputSchema).
95    #[serde(default, skip_serializing_if = "Option::is_none")]
96    pub output_schema: Option<serde_json::Value>,
97    /// Behavioral hints (MCP: annotations).
98    #[serde(default, skip_serializing_if = "Option::is_none")]
99    pub annotations: Option<ToolAnnotations>,
100
101    // ── Agent OS extensions ──
102    /// Tool category for grouping ("filesystem", "code", "shell", "mcp").
103    #[serde(default, skip_serializing_if = "Option::is_none")]
104    pub category: Option<String>,
105    /// Tags for filtering and matching.
106    #[serde(default, skip_serializing_if = "Vec::is_empty")]
107    pub tags: Vec<String>,
108    /// Maximum execution timeout in seconds.
109    #[serde(default, skip_serializing_if = "Option::is_none")]
110    pub timeout_secs: Option<u32>,
111}
112
113// ── Typed content blocks ──────────────────────────────────────────────
114
115/// Typed content block in a tool result (MCP-compatible).
116///
117/// Tools can return structured content alongside the legacy JSON `output` field.
118#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
119#[serde(tag = "type", rename_all = "snake_case")]
120pub enum ToolContent {
121    Text { text: String },
122    Image { data: String, mime_type: String },
123    Json { value: serde_json::Value },
124}
125
126// ── Rich tool result ──────────────────────────────────────────────────
127
128/// Rich tool execution result with typed content.
129///
130/// This is the canonical result type returned by [`Tool::execute`].
131/// It includes both a legacy JSON `output` and optional MCP-style
132/// typed content blocks for richer responses.
133///
134/// For the simplified kernel-level outcome, see [`ToolOutcome`].
135#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
136pub struct ToolResult {
137    pub call_id: String,
138    pub tool_name: String,
139    /// Legacy JSON output (always present for backward compatibility).
140    #[serde(default)]
141    pub output: serde_json::Value,
142    /// MCP-style typed content blocks (optional, alongside output).
143    #[serde(default, skip_serializing_if = "Option::is_none")]
144    pub content: Option<Vec<ToolContent>>,
145    /// Whether this result represents an error (MCP: isError).
146    #[serde(default)]
147    pub is_error: bool,
148}
149
150impl ToolResult {
151    /// Create a successful text result.
152    pub fn text(call_id: impl Into<String>, tool_name: impl Into<String>, text: &str) -> Self {
153        Self {
154            call_id: call_id.into(),
155            tool_name: tool_name.into(),
156            output: serde_json::Value::String(text.to_string()),
157            content: Some(vec![ToolContent::Text {
158                text: text.to_string(),
159            }]),
160            is_error: false,
161        }
162    }
163
164    /// Create a successful JSON result.
165    pub fn json(
166        call_id: impl Into<String>,
167        tool_name: impl Into<String>,
168        value: serde_json::Value,
169    ) -> Self {
170        Self {
171            call_id: call_id.into(),
172            tool_name: tool_name.into(),
173            output: value.clone(),
174            content: Some(vec![ToolContent::Json { value }]),
175            is_error: false,
176        }
177    }
178
179    /// Create an error result.
180    pub fn error(call_id: impl Into<String>, tool_name: impl Into<String>, message: &str) -> Self {
181        Self {
182            call_id: call_id.into(),
183            tool_name: tool_name.into(),
184            output: serde_json::json!({ "error": message }),
185            content: Some(vec![ToolContent::Text {
186                text: message.to_string(),
187            }]),
188            is_error: true,
189        }
190    }
191}
192
193/// Convert a rich `ToolResult` to a simplified `ToolOutcome` for kernel boundaries.
194impl From<&ToolResult> for ToolOutcome {
195    fn from(result: &ToolResult) -> Self {
196        if result.is_error {
197            ToolOutcome::Failure {
198                error: match &result.output {
199                    serde_json::Value::String(s) => s.clone(),
200                    other => other.to_string(),
201                },
202            }
203        } else {
204            ToolOutcome::Success {
205                output: result.output.clone(),
206            }
207        }
208    }
209}
210
211// ── Tool execution context ────────────────────────────────────────────
212
/// Execution context handed to a tool when it runs.
///
/// Carries the identifiers of the current run, session, and iteration so a
/// tool can correlate what it does with the agent loop.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Identifier of the current run.
    pub run_id: String,
    /// Identifier of the current session.
    pub session_id: String,
    /// Current iteration of the agent loop.
    pub iteration: u32,
}
223
224// ── Tool errors ───────────────────────────────────────────────────────
225
226/// Errors that can occur during tool execution.
227#[derive(Debug, thiserror::Error)]
228pub enum ToolError {
229    #[error("tool not found: {tool_name}")]
230    NotFound { tool_name: String },
231
232    #[error("[{tool_name}] execution failed: {message}")]
233    ExecutionFailed { tool_name: String, message: String },
234
235    #[error("invalid input: {message}")]
236    InvalidInput { message: String },
237
238    #[error("[{tool_name}] timed out after {timeout_secs}s")]
239    Timeout {
240        tool_name: String,
241        timeout_secs: u32,
242    },
243
244    #[error("workspace policy violation: {message}")]
245    PolicyViolation { message: String },
246
247    #[error("{0}")]
248    Other(String),
249}
250
251// ── Canonical Tool trait ──────────────────────────────────────────────
252
253/// The canonical tool interface for the Agent OS.
254///
255/// All tool implementations (filesystem, shell, MCP bridges, skills)
256/// implement this trait. The trait is synchronous — runtimes wrap
257/// execution in `spawn_blocking` when needed.
258///
259/// # Object Safety
260///
261/// This trait is dyn-compatible (`Arc<dyn Tool>`) for use in registries.
262pub trait Tool: Send + Sync {
263    /// Returns the tool's definition (name, schema, annotations).
264    fn definition(&self) -> ToolDefinition;
265
266    /// Execute the tool with the given call and context.
267    fn execute(&self, call: &ToolCall, ctx: &ToolContext) -> Result<ToolResult, ToolError>;
268}
269
270// ── Tool registry ─────────────────────────────────────────────────────
271
272/// A registry of named tools, used by the orchestrator to dispatch tool calls.
273#[derive(Clone, Default)]
274pub struct ToolRegistry {
275    tools: BTreeMap<String, Arc<dyn Tool>>,
276}
277
278impl ToolRegistry {
279    /// Register a tool. Replaces any existing tool with the same name.
280    pub fn register<T: Tool + 'static>(&mut self, tool: T) {
281        self.tools
282            .insert(tool.definition().name.clone(), Arc::new(tool));
283    }
284
285    /// Register a pre-wrapped `Arc<dyn Tool>`.
286    pub fn register_arc(&mut self, tool: Arc<dyn Tool>) {
287        self.tools.insert(tool.definition().name.clone(), tool);
288    }
289
290    /// Look up a tool by name.
291    pub fn get(&self, tool_name: &str) -> Option<Arc<dyn Tool>> {
292        self.tools.get(tool_name).cloned()
293    }
294
295    /// Return definitions for all registered tools.
296    pub fn definitions(&self) -> Vec<ToolDefinition> {
297        self.tools.values().map(|tool| tool.definition()).collect()
298    }
299
300    /// Return the number of registered tools.
301    pub fn len(&self) -> usize {
302        self.tools.len()
303    }
304
305    /// Returns `true` if no tools are registered.
306    pub fn is_empty(&self) -> bool {
307        self.tools.is_empty()
308    }
309
310    /// Return all registered tool names.
311    pub fn names(&self) -> Vec<String> {
312        self.tools.keys().cloned().collect()
313    }
314}
315
316impl std::fmt::Debug for ToolRegistry {
317    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        f.debug_struct("ToolRegistry")
319            .field("tools", &self.tools.keys().collect::<Vec<_>>())
320            .finish()
321    }
322}
323
324// ── Tests ─────────────────────────────────────────────────────────────
325
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ── Shared helpers ──

    /// Serialize `value` to a JSON string, panicking if serialization fails.
    fn encode<T: serde::Serialize>(value: &T) -> String {
        serde_json::to_string(value).unwrap()
    }

    /// Parse `text` as JSON into `T`, panicking if parsing fails.
    fn decode<T: serde::de::DeserializeOwned>(text: &str) -> T {
        serde_json::from_str(text).unwrap()
    }

    /// A definition with only the required fields set and a bare object schema.
    fn bare_definition(name: &str, description: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.into(),
            description: description.into(),
            input_schema: json!({"type": "object"}),
            title: None,
            output_schema: None,
            annotations: None,
            category: None,
            tags: Vec::new(),
            timeout_secs: None,
        }
    }

    /// Context used by every test that executes a tool.
    fn test_context() -> ToolContext {
        ToolContext {
            run_id: "run-1".into(),
            session_id: "sess-1".into(),
            iteration: 1,
        }
    }

    // ── Test tools ──

    /// Tool that returns the `value` field of its input (or JSON null).
    struct EchoTool;

    impl Tool for EchoTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                input_schema: json!({
                    "type": "object",
                    "properties": { "value": { "type": "string" } },
                    "required": ["value"]
                }),
                annotations: Some(ToolAnnotations {
                    read_only: true,
                    idempotent: true,
                    ..Default::default()
                }),
                category: Some("test".into()),
                timeout_secs: Some(10),
                ..bare_definition("echo", "Echoes the input value")
            }
        }

        fn execute(&self, call: &ToolCall, _ctx: &ToolContext) -> Result<ToolResult, ToolError> {
            let echoed = match call.input.get("value") {
                Some(found) => found.clone(),
                None => serde_json::Value::Null,
            };
            Ok(ToolResult::json(
                call.call_id.as_str(),
                call.tool_name.as_str(),
                echoed,
            ))
        }
    }

    /// Tool whose execution always returns `ExecutionFailed`.
    struct FailTool;

    impl Tool for FailTool {
        fn definition(&self) -> ToolDefinition {
            bare_definition("fail", "Always fails")
        }

        fn execute(&self, call: &ToolCall, _ctx: &ToolContext) -> Result<ToolResult, ToolError> {
            let failure = ToolError::ExecutionFailed {
                tool_name: call.tool_name.clone(),
                message: "always fails".into(),
            };
            Err(failure)
        }
    }

    // ── ToolCall / ToolOutcome ──

    #[test]
    fn tool_call_new() {
        let call = ToolCall::new("read_file", json!({"path": "/tmp"}), Vec::new());
        assert_eq!(call.tool_name, "read_file");
        assert!(!call.call_id.is_empty());
    }

    #[test]
    fn tool_outcome_serde_roundtrip() {
        let ok = ToolOutcome::Success {
            output: json!({"data": 42}),
        };
        let wire = encode(&ok);
        assert!(wire.contains("\"status\":\"success\""));
        let decoded: ToolOutcome = decode(&wire);
        assert!(matches!(decoded, ToolOutcome::Success { .. }));

        let failed = ToolOutcome::Failure {
            error: "not found".into(),
        };
        let wire = encode(&failed);
        assert!(wire.contains("\"status\":\"failure\""));
    }

    // ── ToolAnnotations ──

    #[test]
    fn annotations_default_all_false() {
        let flags = ToolAnnotations::default();
        let all = [
            flags.read_only,
            flags.destructive,
            flags.idempotent,
            flags.open_world,
            flags.requires_confirmation,
        ];
        for flag in all {
            assert!(!flag);
        }
    }

    #[test]
    fn annotations_serde_roundtrip() {
        let original = ToolAnnotations {
            read_only: true,
            destructive: false,
            idempotent: true,
            open_world: false,
            requires_confirmation: true,
        };
        let decoded: ToolAnnotations = decode(&encode(&original));
        assert_eq!(original, decoded);
    }

    #[test]
    fn annotations_missing_fields_default_false() {
        let partial: ToolAnnotations = decode(r#"{"read_only": true}"#);
        assert!(partial.read_only);
        assert!(!partial.destructive);
    }

    // ── ToolDefinition ──

    #[test]
    fn tool_definition_minimal() {
        let original = bare_definition("test_tool", "A test tool");
        let wire = encode(&original);
        // Unset optional fields must not appear in the serialized form.
        assert!(!wire.contains("title"));
        assert!(!wire.contains("tags"));
        let decoded: ToolDefinition = decode(&wire);
        assert_eq!(original, decoded);
    }

    #[test]
    fn tool_definition_full() {
        let original = ToolDefinition {
            name: "read_file".into(),
            description: "Read a file from the workspace".into(),
            input_schema: json!({
                "type": "object",
                "properties": { "path": { "type": "string" } },
                "required": ["path"]
            }),
            title: Some("Read File".into()),
            output_schema: Some(json!({"type": "string"})),
            annotations: Some(ToolAnnotations {
                read_only: true,
                idempotent: true,
                ..Default::default()
            }),
            category: Some("filesystem".into()),
            tags: vec!["fs".into(), "read".into()],
            timeout_secs: Some(30),
        };
        let wire = encode(&original);
        let decoded: ToolDefinition = decode(&wire);
        assert_eq!(original, decoded);
        assert!(wire.contains("\"category\":\"filesystem\""));
    }

    // ── ToolContent ──

    #[test]
    fn tool_content_text_serde() {
        let block = ToolContent::Text {
            text: "hello".into(),
        };
        let wire = encode(&block);
        assert!(wire.contains("\"type\":\"text\""));
        let decoded: ToolContent = decode(&wire);
        assert_eq!(block, decoded);
    }

    #[test]
    fn tool_content_json_serde() {
        let block = ToolContent::Json {
            value: json!({"key": "value"}),
        };
        let wire = encode(&block);
        assert!(wire.contains("\"type\":\"json\""));
        let decoded: ToolContent = decode(&wire);
        assert_eq!(block, decoded);
    }

    #[test]
    fn tool_content_image_serde() {
        let block = ToolContent::Image {
            data: "base64data".into(),
            mime_type: "image/png".into(),
        };
        let decoded: ToolContent = decode(&encode(&block));
        assert_eq!(block, decoded);
    }

    // ── ToolResult ──

    #[test]
    fn tool_result_text_helper() {
        let outcome = ToolResult::text("call-1", "echo", "hello world");
        assert_eq!(outcome.call_id, "call-1");
        assert_eq!(outcome.tool_name, "echo");
        assert!(!outcome.is_error);
        assert!(outcome.content.is_some());
    }

    #[test]
    fn tool_result_json_helper() {
        let outcome = ToolResult::json("call-2", "search", json!({"matches": 5}));
        assert!(!outcome.is_error);
        assert_eq!(outcome.output, json!({"matches": 5}));
    }

    #[test]
    fn tool_result_error_helper() {
        let outcome = ToolResult::error("call-3", "bash", "permission denied");
        assert!(outcome.is_error);
        assert_eq!(outcome.output["error"], "permission denied");
    }

    #[test]
    fn tool_result_serde_roundtrip() {
        let original = ToolResult {
            call_id: "c1".into(),
            tool_name: "test".into(),
            output: json!({"ok": true}),
            content: Some(vec![ToolContent::Text {
                text: "success".into(),
            }]),
            is_error: false,
        };
        let decoded: ToolResult = decode(&encode(&original));
        assert_eq!(original, decoded);
    }

    // ── ToolResult → ToolOutcome ──

    #[test]
    fn tool_result_to_outcome_success() {
        let rich = ToolResult::json("c1", "test", json!({"data": 42}));
        let simple = ToolOutcome::from(&rich);
        assert!(matches!(simple, ToolOutcome::Success { .. }));
    }

    #[test]
    fn tool_result_to_outcome_failure() {
        let rich = ToolResult::error("c1", "test", "oops");
        let simple = ToolOutcome::from(&rich);
        if let ToolOutcome::Failure { error } = simple {
            assert!(error.contains("oops"));
        } else {
            panic!("expected failure");
        }
    }

    // ── Tool trait ──

    #[test]
    fn tool_trait_execute_success() {
        let request = ToolCall::new("echo", json!({"value": "hello"}), Vec::new());
        let outcome = EchoTool.execute(&request, &test_context()).unwrap();
        assert!(!outcome.is_error);
        assert_eq!(outcome.output, json!("hello"));
    }

    #[test]
    fn tool_trait_execute_error() {
        let request = ToolCall::new("fail", json!({}), Vec::new());
        let failure = FailTool.execute(&request, &test_context()).unwrap_err();
        assert!(matches!(failure, ToolError::ExecutionFailed { .. }));
        assert!(failure.to_string().contains("always fails"));
    }

    // ── ToolRegistry ──

    #[test]
    fn registry_register_and_get() {
        let mut registry = ToolRegistry::default();
        assert!(registry.is_empty());

        registry.register(EchoTool);
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        let found = registry.get("echo").expect("should find echo");
        assert_eq!(found.definition().name, "echo");
    }

    #[test]
    fn registry_get_missing() {
        let registry = ToolRegistry::default();
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn registry_definitions() {
        let mut registry = ToolRegistry::default();
        registry.register(EchoTool);
        registry.register(FailTool);

        let listed = registry.definitions();
        assert_eq!(listed.len(), 2);
        assert!(listed.iter().any(|def| def.name == "echo"));
        assert!(listed.iter().any(|def| def.name == "fail"));
    }

    #[test]
    fn registry_names() {
        let mut registry = ToolRegistry::default();
        registry.register(EchoTool);
        registry.register(FailTool);

        let listed = registry.names();
        assert_eq!(listed.len(), 2);
        assert!(listed.iter().any(|name| name == "echo"));
        assert!(listed.iter().any(|name| name == "fail"));
    }

    #[test]
    fn registry_register_replaces_existing() {
        let mut registry = ToolRegistry::default();
        // Registering the same name twice must leave a single entry.
        registry.register(EchoTool);
        registry.register(EchoTool);
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn registry_debug_format() {
        let mut registry = ToolRegistry::default();
        registry.register(EchoTool);
        let rendered = format!("{:?}", registry);
        assert!(rendered.contains("echo"));
    }

    // ── ToolError ──

    #[test]
    fn tool_error_display() {
        let missing = ToolError::NotFound {
            tool_name: "ghost".into(),
        };
        assert_eq!(missing.to_string(), "tool not found: ghost");

        let timed_out = ToolError::Timeout {
            tool_name: "slow".into(),
            timeout_secs: 30,
        };
        assert_eq!(timed_out.to_string(), "[slow] timed out after 30s");
    }
}