// agent_sdk_rs/tools/mod.rs

1pub mod claude_code;
2
3use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::future::Future;
6use std::sync::{Arc, RwLock};
7
8use futures_util::future::BoxFuture;
9use serde_json::Value;
10
11use crate::error::{SchemaError, ToolError};
12
/// Successful result of running a tool handler (the `Ok` side of
/// `Result<ToolOutcome, ToolError>`). Both variants carry text produced by the handler.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ToolOutcome {
    /// Ordinary textual output from the tool.
    Text(String),
    /// Textual output flagged as final.
    // NOTE(review): presumably tells the agent loop to stop after this call —
    // confirm against the code that consumes `ToolOutcome`.
    Done(String),
}
18
/// Type-erased, shareable dependency value as stored inside `DependencyMap`;
/// lookups recover the concrete type with `Arc::downcast`.
type DynDependency = Arc<dyn Any + Sync + Send>;
20type ToolHandler = dyn Fn(Value, &DependencyMap) -> BoxFuture<'static, Result<ToolOutcome, ToolError>>
21    + Send
22    + Sync;
23
24#[derive(Clone, Default, Debug)]
25pub struct DependencyMap {
26    typed: Arc<RwLock<HashMap<TypeId, DynDependency>>>,
27    named: Arc<RwLock<HashMap<String, DynDependency>>>,
28}
29
30impl DependencyMap {
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    pub fn insert<T>(&self, value: T)
36    where
37        T: Send + Sync + 'static,
38    {
39        let mut typed = self
40            .typed
41            .write()
42            .expect("dependency typed map lock poisoned");
43        typed.insert(TypeId::of::<T>(), Arc::new(value));
44    }
45
46    pub fn get<T>(&self) -> Option<Arc<T>>
47    where
48        T: Send + Sync + 'static,
49    {
50        let typed = self.typed.read().ok()?;
51        let value = typed.get(&TypeId::of::<T>())?.clone();
52        Arc::downcast::<T>(value).ok()
53    }
54
55    pub fn insert_named<T>(&self, key: impl Into<String>, value: T)
56    where
57        T: Send + Sync + 'static,
58    {
59        let mut named = self
60            .named
61            .write()
62            .expect("dependency named map lock poisoned");
63        named.insert(key.into(), Arc::new(value));
64    }
65
66    pub fn get_named<T>(&self, key: &str) -> Option<Arc<T>>
67    where
68        T: Send + Sync + 'static,
69    {
70        let named = self.named.read().ok()?;
71        let value = named.get(key)?.clone();
72        Arc::downcast::<T>(value).ok()
73    }
74
75    pub fn merged_with(&self, overrides: &DependencyMap) -> DependencyMap {
76        let merged = DependencyMap::new();
77
78        {
79            let mut dst_typed = merged
80                .typed
81                .write()
82                .expect("dependency typed map lock poisoned");
83            if let Ok(src_typed) = self.typed.read() {
84                for (key, value) in &*src_typed {
85                    dst_typed.insert(*key, value.clone());
86                }
87            }
88            if let Ok(src_typed_override) = overrides.typed.read() {
89                for (key, value) in &*src_typed_override {
90                    dst_typed.insert(*key, value.clone());
91                }
92            }
93        }
94
95        {
96            let mut dst_named = merged
97                .named
98                .write()
99                .expect("dependency named map lock poisoned");
100            if let Ok(src_named) = self.named.read() {
101                for (key, value) in &*src_named {
102                    dst_named.insert(key.clone(), value.clone());
103                }
104            }
105            if let Ok(src_named_override) = overrides.named.read() {
106                for (key, value) in &*src_named_override {
107                    dst_named.insert(key.clone(), value.clone());
108                }
109            }
110        }
111
112        merged
113    }
114}
115
116#[derive(Clone)]
117pub struct ToolSpec {
118    name: String,
119    description: String,
120    json_schema: Value,
121    handler: Arc<ToolHandler>,
122}
123
124impl std::fmt::Debug for ToolSpec {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        f.debug_struct("ToolSpec")
127            .field("name", &self.name)
128            .field("description", &self.description)
129            .field("json_schema", &self.json_schema)
130            .finish()
131    }
132}
133
134impl ToolSpec {
135    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
136        Self {
137            name: name.into(),
138            description: description.into(),
139            json_schema: serde_json::json!({
140                "type": "object",
141                "properties": {},
142                "required": [],
143                "additionalProperties": true,
144            }),
145            handler: Arc::new(|_args, _deps| {
146                Box::pin(async {
147                    Err(ToolError::Execution(
148                        "tool handler not configured".to_string(),
149                    ))
150                })
151            }),
152        }
153    }
154
155    pub fn with_schema(mut self, schema: Value) -> Result<Self, SchemaError> {
156        validate_schema(&schema)?;
157        self.json_schema = schema;
158        Ok(self)
159    }
160
161    pub fn with_handler<F, Fut>(mut self, handler: F) -> Self
162    where
163        F: Fn(Value, &DependencyMap) -> Fut + Send + Sync + 'static,
164        Fut: Future<Output = Result<ToolOutcome, ToolError>> + Send + 'static,
165    {
166        self.handler = Arc::new(move |args, deps| Box::pin(handler(args, deps)));
167        self
168    }
169
170    pub fn name(&self) -> &str {
171        &self.name
172    }
173
174    pub fn description(&self) -> &str {
175        &self.description
176    }
177
178    pub fn json_schema(&self) -> &Value {
179        &self.json_schema
180    }
181
182    pub async fn execute(
183        &self,
184        args: Value,
185        dependencies: &DependencyMap,
186    ) -> Result<ToolOutcome, ToolError> {
187        validate_arguments(self.name(), &self.json_schema, &args)?;
188        (self.handler)(args, dependencies).await
189    }
190}
191
192fn validate_schema(schema: &Value) -> Result<(), SchemaError> {
193    let schema_obj = schema.as_object().ok_or(SchemaError::SchemaNotObject)?;
194
195    let root_type = schema_obj
196        .get("type")
197        .and_then(Value::as_str)
198        .ok_or(SchemaError::RootTypeMustBeObject)?;
199
200    if root_type != "object" {
201        return Err(SchemaError::RootTypeMustBeObject);
202    }
203
204    if let Some(required) = schema_obj.get("required") {
205        let required_arr = required.as_array().ok_or(SchemaError::InvalidRequired)?;
206        for item in required_arr {
207            if !item.is_string() {
208                return Err(SchemaError::InvalidRequired);
209            }
210        }
211    }
212
213    Ok(())
214}
215
216fn validate_arguments(tool_name: &str, schema: &Value, args: &Value) -> Result<(), ToolError> {
217    let args_obj = args
218        .as_object()
219        .ok_or_else(|| ToolError::InvalidArguments {
220            tool: tool_name.to_string(),
221            message: "arguments must be a JSON object".to_string(),
222        })?;
223
224    let schema_obj = schema
225        .as_object()
226        .ok_or_else(|| ToolError::InvalidArguments {
227            tool: tool_name.to_string(),
228            message: "tool schema must be a JSON object".to_string(),
229        })?;
230
231    if let Some(required) = schema_obj.get("required").and_then(Value::as_array) {
232        for field in required {
233            let Some(field_name) = field.as_str() else {
234                continue;
235            };
236            if !args_obj.contains_key(field_name) {
237                return Err(ToolError::InvalidArguments {
238                    tool: tool_name.to_string(),
239                    message: format!("missing required field: {field_name}"),
240                });
241            }
242        }
243    }
244
245    let properties = schema_obj
246        .get("properties")
247        .and_then(Value::as_object)
248        .cloned()
249        .unwrap_or_default();
250
251    if schema_obj
252        .get("additionalProperties")
253        .and_then(Value::as_bool)
254        == Some(false)
255    {
256        for key in args_obj.keys() {
257            if !properties.contains_key(key) {
258                return Err(ToolError::InvalidArguments {
259                    tool: tool_name.to_string(),
260                    message: format!("unknown field: {key}"),
261                });
262            }
263        }
264    }
265
266    for (key, value) in args_obj {
267        if let Some(field_schema) = properties.get(key) {
268            if let Some(type_name) = field_schema.get("type").and_then(Value::as_str) {
269                if !value_matches_type(value, type_name) {
270                    return Err(ToolError::InvalidArguments {
271                        tool: tool_name.to_string(),
272                        message: format!("field '{key}' must be of type {type_name}"),
273                    });
274                }
275            }
276        }
277    }
278
279    Ok(())
280}
281
282fn value_matches_type(value: &Value, type_name: &str) -> bool {
283    match type_name {
284        "string" => value.is_string(),
285        "integer" => value.as_i64().is_some() || value.as_u64().is_some(),
286        "number" => value.as_f64().is_some(),
287        "boolean" => value.is_boolean(),
288        "object" => value.is_object(),
289        "array" => value.is_array(),
290        "null" => value.is_null(),
291        _ => true,
292    }
293}
294
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Extracts the validation message from an `InvalidArguments` error, failing the
    /// test for any other error variant.
    fn invalid_message(err: ToolError) -> String {
        match err {
            ToolError::InvalidArguments { message, .. } => message,
            other => panic!("expected InvalidArguments, got {other:?}"),
        }
    }

    #[test]
    fn schema_validation_rejects_non_object_root() {
        let result = ToolSpec::new("bad", "bad").with_schema(json!({"type": "string"}));
        assert!(result.is_err());
    }

    #[test]
    fn schema_validation_rejects_non_string_required() {
        let result = ToolSpec::new("bad", "bad")
            .with_schema(json!({"type": "object", "required": ["ok", 1]}));
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn dependency_overrides_win() {
        let base = DependencyMap::new();
        base.insert::<u32>(1);

        let overrides = DependencyMap::new();
        overrides.insert::<u32>(9);

        let merged = base.merged_with(&overrides);
        assert_eq!(merged.get::<u32>().as_deref(), Some(&9));
        // Merging must leave the base map untouched.
        assert_eq!(base.get::<u32>().as_deref(), Some(&1));

        let tool = ToolSpec::new("read", "read dep")
            .with_schema(json!({
                "type": "object",
                "properties": {},
                "required": [],
                "additionalProperties": false
            }))
            .expect("schema should be valid")
            .with_handler(|_args, deps| {
                // Read the dependency before building the `'static` future, and surface
                // a missing dependency as an error rather than a placeholder value.
                let value = deps.get::<u32>().map(|v| *v);
                async move {
                    value
                        .map(|v| ToolOutcome::Text(v.to_string()))
                        .ok_or(ToolError::MissingDependency("u32"))
                }
            });

        let outcome = tool
            .execute(json!({}), &merged)
            .await
            .expect("tool executes");
        assert_eq!(outcome, ToolOutcome::Text("9".to_string()));

        let missing = tool.execute(json!({}), &DependencyMap::new()).await;
        assert!(missing.is_err());
    }

    #[tokio::test]
    async fn argument_validation_reports_missing_required() {
        let tool = ToolSpec::new("req", "required")
            .with_schema(json!({
                "type": "object",
                "properties": {"value": {"type": "string"}},
                "required": ["value"],
                "additionalProperties": false
            }))
            .expect("schema valid")
            .with_handler(|_args, _deps| async move { Ok(ToolOutcome::Text("ok".into())) });

        let err = tool
            .execute(json!({}), &DependencyMap::new())
            .await
            .expect_err("should fail");

        let message = err.to_string();
        assert!(message.contains("missing required field"));
    }

    #[tokio::test]
    async fn argument_validation_checks_fields_and_types() {
        let tool = ToolSpec::new("typed", "typed")
            .with_schema(json!({
                "type": "object",
                "properties": {"count": {"type": "integer"}},
                "required": [],
                "additionalProperties": false
            }))
            .expect("schema valid")
            .with_handler(|_args, _deps| async move { Ok(ToolOutcome::Text("ok".into())) });
        let deps = DependencyMap::new();

        let unknown = tool
            .execute(json!({"extra": 1}), &deps)
            .await
            .expect_err("unknown field should be rejected");
        assert_eq!(invalid_message(unknown), "unknown field: extra");

        let wrong_type = tool
            .execute(json!({"count": "three"}), &deps)
            .await
            .expect_err("wrong type should be rejected");
        assert_eq!(
            invalid_message(wrong_type),
            "field 'count' must be of type integer"
        );

        let not_object = tool
            .execute(json!([1, 2]), &deps)
            .await
            .expect_err("non-object arguments should be rejected");
        assert_eq!(invalid_message(not_object), "arguments must be a JSON object");

        let ok = tool
            .execute(json!({"count": 3}), &deps)
            .await
            .expect("valid arguments");
        assert_eq!(ok, ToolOutcome::Text("ok".into()));
    }

    #[tokio::test]
    async fn unconfigured_handler_fails() {
        let err = ToolSpec::new("noop", "no handler")
            .execute(json!({}), &DependencyMap::new())
            .await
            .expect_err("default handler should fail");
        assert!(matches!(err, ToolError::Execution(_)));
    }
}