// cortexai_core/tool.rs

//! Tool/function definitions for agent capabilities.
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::errors::ToolError;
9use crate::types::AgentId;
10
11/// Tool/function schema for LLM function calling
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ToolSchema {
14    /// Tool name
15    pub name: String,
16
17    /// Tool description
18    pub description: String,
19
20    /// Input parameters schema (JSON Schema)
21    pub parameters: serde_json::Value,
22
23    /// Whether tool is dangerous and requires confirmation
24    pub dangerous: bool,
25
26    /// Tool metadata
27    pub metadata: HashMap<String, serde_json::Value>,
28
29    /// Scopes required to execute this tool (e.g. "fs:read", "network:external")
30    /// Empty means no scope restrictions — tool is always allowed.
31    #[serde(default)]
32    pub required_scopes: Vec<String>,
33}
34
35impl ToolSchema {
36    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
37        Self {
38            name: name.into(),
39            description: description.into(),
40            parameters: serde_json::json!({
41                "type": "object",
42                "properties": {},
43                "required": []
44            }),
45            dangerous: false,
46            metadata: HashMap::new(),
47            required_scopes: Vec::new(),
48        }
49    }
50
51    pub fn with_required_scopes(mut self, scopes: Vec<String>) -> Self {
52        self.required_scopes = scopes;
53        self
54    }
55
56    pub fn with_parameters(mut self, parameters: serde_json::Value) -> Self {
57        self.parameters = parameters;
58        self
59    }
60
61    pub fn with_dangerous(mut self, dangerous: bool) -> Self {
62        self.dangerous = dangerous;
63        self
64    }
65
66    pub fn add_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
67        self.metadata.insert(key.into(), value);
68        self
69    }
70}
71
72/// Execution context for tool calls
73#[derive(Debug, Clone)]
74pub struct ExecutionContext {
75    /// Agent making the tool call
76    pub agent_id: AgentId,
77
78    /// Additional context data
79    pub data: HashMap<String, serde_json::Value>,
80}
81
82impl ExecutionContext {
83    pub fn new(agent_id: AgentId) -> Self {
84        Self {
85            agent_id,
86            data: HashMap::new(),
87        }
88    }
89
90    pub fn with_data(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
91        self.data.insert(key.into(), value);
92        self
93    }
94
95    pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
96        self.data.get(key)
97    }
98}
99
100/// Tool/function trait
101#[async_trait]
102pub trait Tool: Send + Sync {
103    /// Get tool schema
104    fn schema(&self) -> ToolSchema;
105
106    /// Execute the tool
107    async fn execute(
108        &self,
109        context: &ExecutionContext,
110        arguments: serde_json::Value,
111    ) -> Result<serde_json::Value, ToolError>;
112
113    /// Validate arguments before execution
114    fn validate(&self, _arguments: &serde_json::Value) -> Result<(), ToolError> {
115        Ok(())
116    }
117}
118
119/// Tool registry
120#[derive(Clone)]
121pub struct ToolRegistry {
122    tools: Arc<HashMap<String, Arc<dyn Tool>>>,
123}
124
125impl ToolRegistry {
126    pub fn new() -> Self {
127        Self {
128            tools: Arc::new(HashMap::new()),
129        }
130    }
131
132    /// Register a tool
133    pub fn register(&mut self, tool: Arc<dyn Tool>) {
134        let schema = tool.schema();
135        Arc::get_mut(&mut self.tools)
136            .expect("Cannot register tools after cloning")
137            .insert(schema.name.clone(), tool);
138    }
139
140    /// Get a tool by name
141    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
142        self.tools.get(name).cloned()
143    }
144
145    /// List all tool schemas
146    pub fn list_schemas(&self) -> Vec<ToolSchema> {
147        self.tools.values().map(|tool| tool.schema()).collect()
148    }
149
150    /// Check if tool exists
151    pub fn has(&self, name: &str) -> bool {
152        self.tools.contains_key(name)
153    }
154
155    /// Get tool count
156    pub fn len(&self) -> usize {
157        self.tools.len()
158    }
159
160    /// Check if registry is empty
161    pub fn is_empty(&self) -> bool {
162        self.tools.is_empty()
163    }
164}
165
166impl Default for ToolRegistry {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// A schema built with `new` carries no scope requirements.
    #[test]
    fn test_tool_schema_required_scopes_default_empty() {
        let fresh = ToolSchema::new("my_tool", "desc");
        assert!(fresh.required_scopes.is_empty());
    }

    /// `with_required_scopes` stores every scope it is given.
    #[test]
    fn test_tool_schema_with_required_scopes() {
        let expected = ["fs:read", "network:external"];
        let scopes: Vec<String> = expected.iter().map(|scope| scope.to_string()).collect();

        let schema = ToolSchema::new("my_tool", "desc").with_required_scopes(scopes);

        assert_eq!(schema.required_scopes.len(), 2);
        for scope in expected.iter() {
            assert!(schema.required_scopes.contains(&scope.to_string()));
        }
    }

    /// Struct-literal construction compiles when every field, including an
    /// empty `required_scopes`, is spelled out.
    #[test]
    fn test_tool_schema_backward_compat_struct_literal() {
        let literal = ToolSchema {
            name: String::from("tool"),
            description: String::from("desc"),
            parameters: serde_json::json!({}),
            dangerous: false,
            metadata: HashMap::new(),
            required_scopes: Vec::new(),
        };
        assert!(literal.required_scopes.is_empty());
    }
}
205
/// Defines a tool type backed by a fixed schema and an async body.
///
/// The expansion produces:
///
/// * `pub struct $name` holding a [`ToolSchema`](crate::tool::ToolSchema),
/// * `$name::new()`, which evaluates `$schema` to build that schema, and
/// * a `Tool` implementation whose `schema()` clones the stored schema and
///   whose `execute` awaits `$body`.
///
/// Because `$body` is awaited, it must be an expression that evaluates to a
/// future (for example an `async move { ... }` block) resolving to
/// `Result<serde_json::Value, ToolError>`. Inside `$body`, `$ctx` names the
/// `&ExecutionContext` and `$args` names the `serde_json::Value` arguments.
///
/// The expansion refers to `async_trait` and `serde_json` by crate name, so
/// both must be resolvable at the call site.
///
/// A trailing comma after the `execute:` closure is accepted.
///
/// # Example
///
/// ```ignore
/// define_tool!(
///     EchoTool,
///     schema: ToolSchema::new("echo", "Echo the arguments back"),
///     execute: |_ctx, args| async move { Ok(args) },
/// );
/// ```
#[macro_export]
macro_rules! define_tool {
    (
        $name:ident,
        schema: $schema:expr,
        execute: |$ctx:ident, $args:ident| $body:expr $(,)?
    ) => {
        pub struct $name {
            schema: $crate::tool::ToolSchema,
        }

        impl $name {
            pub fn new() -> Self {
                Self { schema: $schema }
            }
        }

        #[async_trait::async_trait]
        impl $crate::tool::Tool for $name {
            fn schema(&self) -> $crate::tool::ToolSchema {
                self.schema.clone()
            }

            async fn execute(
                &self,
                $ctx: &$crate::tool::ExecutionContext,
                $args: serde_json::Value,
            ) -> Result<serde_json::Value, $crate::errors::ToolError> {
                $body.await
            }
        }
    };
}