1use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::errors::ToolError;
9use crate::types::AgentId;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ToolSchema {
14 pub name: String,
16
17 pub description: String,
19
20 pub parameters: serde_json::Value,
22
23 pub dangerous: bool,
25
26 pub metadata: HashMap<String, serde_json::Value>,
28
29 #[serde(default)]
32 pub required_scopes: Vec<String>,
33}
34
35impl ToolSchema {
36 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
37 Self {
38 name: name.into(),
39 description: description.into(),
40 parameters: serde_json::json!({
41 "type": "object",
42 "properties": {},
43 "required": []
44 }),
45 dangerous: false,
46 metadata: HashMap::new(),
47 required_scopes: Vec::new(),
48 }
49 }
50
51 pub fn with_required_scopes(mut self, scopes: Vec<String>) -> Self {
52 self.required_scopes = scopes;
53 self
54 }
55
56 pub fn with_parameters(mut self, parameters: serde_json::Value) -> Self {
57 self.parameters = parameters;
58 self
59 }
60
61 pub fn with_dangerous(mut self, dangerous: bool) -> Self {
62 self.dangerous = dangerous;
63 self
64 }
65
66 pub fn add_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
67 self.metadata.insert(key.into(), value);
68 self
69 }
70}
71
72#[derive(Debug, Clone)]
74pub struct ExecutionContext {
75 pub agent_id: AgentId,
77
78 pub data: HashMap<String, serde_json::Value>,
80}
81
82impl ExecutionContext {
83 pub fn new(agent_id: AgentId) -> Self {
84 Self {
85 agent_id,
86 data: HashMap::new(),
87 }
88 }
89
90 pub fn with_data(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
91 self.data.insert(key.into(), value);
92 self
93 }
94
95 pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
96 self.data.get(key)
97 }
98}
99
100#[async_trait]
102pub trait Tool: Send + Sync {
103 fn schema(&self) -> ToolSchema;
105
106 async fn execute(
108 &self,
109 context: &ExecutionContext,
110 arguments: serde_json::Value,
111 ) -> Result<serde_json::Value, ToolError>;
112
113 fn validate(&self, _arguments: &serde_json::Value) -> Result<(), ToolError> {
115 Ok(())
116 }
117}
118
119#[derive(Clone)]
121pub struct ToolRegistry {
122 tools: Arc<HashMap<String, Arc<dyn Tool>>>,
123}
124
125impl ToolRegistry {
126 pub fn new() -> Self {
127 Self {
128 tools: Arc::new(HashMap::new()),
129 }
130 }
131
132 pub fn register(&mut self, tool: Arc<dyn Tool>) {
134 let schema = tool.schema();
135 Arc::get_mut(&mut self.tools)
136 .expect("Cannot register tools after cloning")
137 .insert(schema.name.clone(), tool);
138 }
139
140 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
142 self.tools.get(name).cloned()
143 }
144
145 pub fn list_schemas(&self) -> Vec<ToolSchema> {
147 self.tools.values().map(|tool| tool.schema()).collect()
148 }
149
150 pub fn has(&self, name: &str) -> bool {
152 self.tools.contains_key(name)
153 }
154
155 pub fn len(&self) -> usize {
157 self.tools.len()
158 }
159
160 pub fn is_empty(&self) -> bool {
162 self.tools.is_empty()
163 }
164}
165
166impl Default for ToolRegistry {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed schema carries no required scopes.
    #[test]
    fn test_tool_schema_required_scopes_default_empty() {
        assert!(ToolSchema::new("my_tool", "desc").required_scopes.is_empty());
    }

    /// `with_required_scopes` stores exactly the scopes it is given.
    #[test]
    fn test_tool_schema_with_required_scopes() {
        let scopes = vec!["fs:read".to_string(), "network:external".to_string()];
        let schema = ToolSchema::new("my_tool", "desc").with_required_scopes(scopes);

        assert_eq!(schema.required_scopes.len(), 2);
        for expected in ["fs:read", "network:external"] {
            assert!(
                schema.required_scopes.iter().any(|scope| scope == expected),
                "missing scope {expected}"
            );
        }
    }

    /// Building the struct with a literal that names every field keeps compiling.
    #[test]
    fn test_tool_schema_backward_compat_struct_literal() {
        let literal = ToolSchema {
            name: String::from("tool"),
            description: String::from("desc"),
            parameters: serde_json::json!({}),
            dangerous: false,
            metadata: HashMap::new(),
            required_scopes: Vec::new(),
        };
        assert!(literal.required_scopes.is_empty());
    }
}
205
/// Defines a tool type whose schema and async execution body are supplied
/// inline.
///
/// Invocation shape: `define_tool!(Name, schema: <expr>, execute: |ctx, args| <expr>)`.
///
/// The expansion produces:
/// - `pub struct Name`, which holds the schema;
/// - `Name::new()`, which evaluates the `schema:` expression once;
/// - an `impl Tool for Name` whose `schema()` returns a clone of the stored
///   schema and whose `execute` binds `ctx: &ExecutionContext` and
///   `args: serde_json::Value`, then awaits the body.
///
/// The body must evaluate to a future yielding
/// `Result<serde_json::Value, ToolError>`, typically an `async move { ... }` block.
///
/// All paths in the expansion are absolute (`::core::result::Result`,
/// `::serde_json::Value`, `::async_trait::async_trait`). A `Result` alias or a
/// local module named `serde_json`/`async_trait` at the call site therefore
/// cannot change what they resolve to. The calling crate must still depend on
/// `async_trait` and `serde_json` directly.
#[macro_export]
macro_rules! define_tool {
    (
        $name:ident,
        schema: $schema:expr,
        execute: |$ctx:ident, $args:ident| $body:expr
    ) => {
        pub struct $name {
            schema: $crate::tool::ToolSchema,
        }

        impl $name {
            pub fn new() -> Self {
                Self { schema: $schema }
            }
        }

        #[::async_trait::async_trait]
        impl $crate::tool::Tool for $name {
            fn schema(&self) -> $crate::tool::ToolSchema {
                ::core::clone::Clone::clone(&self.schema)
            }

            async fn execute(
                &self,
                $ctx: &$crate::tool::ExecutionContext,
                $args: ::serde_json::Value,
            ) -> ::core::result::Result<::serde_json::Value, $crate::errors::ToolError> {
                $body.await
            }
        }
    };
}