1pub mod claude_code;
2
3use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::future::Future;
6use std::sync::{Arc, RwLock};
7
8use futures_util::future::BoxFuture;
9use serde_json::Value;
10
11use crate::error::{SchemaError, ToolError};
12
/// Successful result produced by a tool handler.
///
/// Both variants carry a text payload. How callers treat the two differently
/// is decided outside this file — presumably `Done` tells the driving loop to
/// stop after this tool (NOTE(review): confirm against the caller).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolOutcome {
    /// Plain textual output from the tool.
    Text(String),
    /// Textual output flagged as `Done` (see the enum-level note).
    Done(String),
}
18
19type DynDependency = Arc<dyn Any + Send + Sync>;
20type ToolHandler = dyn Fn(Value, &DependencyMap) -> BoxFuture<'static, Result<ToolOutcome, ToolError>>
21 + Send
22 + Sync;
23
24#[derive(Clone, Default, Debug)]
25pub struct DependencyMap {
26 typed: Arc<RwLock<HashMap<TypeId, DynDependency>>>,
27 named: Arc<RwLock<HashMap<String, DynDependency>>>,
28}
29
30impl DependencyMap {
31 pub fn new() -> Self {
32 Self::default()
33 }
34
35 pub fn insert<T>(&self, value: T)
36 where
37 T: Send + Sync + 'static,
38 {
39 let mut typed = self
40 .typed
41 .write()
42 .expect("dependency typed map lock poisoned");
43 typed.insert(TypeId::of::<T>(), Arc::new(value));
44 }
45
46 pub fn get<T>(&self) -> Option<Arc<T>>
47 where
48 T: Send + Sync + 'static,
49 {
50 let typed = self.typed.read().ok()?;
51 let value = typed.get(&TypeId::of::<T>())?.clone();
52 Arc::downcast::<T>(value).ok()
53 }
54
55 pub fn insert_named<T>(&self, key: impl Into<String>, value: T)
56 where
57 T: Send + Sync + 'static,
58 {
59 let mut named = self
60 .named
61 .write()
62 .expect("dependency named map lock poisoned");
63 named.insert(key.into(), Arc::new(value));
64 }
65
66 pub fn get_named<T>(&self, key: &str) -> Option<Arc<T>>
67 where
68 T: Send + Sync + 'static,
69 {
70 let named = self.named.read().ok()?;
71 let value = named.get(key)?.clone();
72 Arc::downcast::<T>(value).ok()
73 }
74
75 pub fn merged_with(&self, overrides: &DependencyMap) -> DependencyMap {
76 let merged = DependencyMap::new();
77
78 {
79 let mut dst_typed = merged
80 .typed
81 .write()
82 .expect("dependency typed map lock poisoned");
83 if let Ok(src_typed) = self.typed.read() {
84 for (key, value) in &*src_typed {
85 dst_typed.insert(*key, value.clone());
86 }
87 }
88 if let Ok(src_typed_override) = overrides.typed.read() {
89 for (key, value) in &*src_typed_override {
90 dst_typed.insert(*key, value.clone());
91 }
92 }
93 }
94
95 {
96 let mut dst_named = merged
97 .named
98 .write()
99 .expect("dependency named map lock poisoned");
100 if let Ok(src_named) = self.named.read() {
101 for (key, value) in &*src_named {
102 dst_named.insert(key.clone(), value.clone());
103 }
104 }
105 if let Ok(src_named_override) = overrides.named.read() {
106 for (key, value) in &*src_named_override {
107 dst_named.insert(key.clone(), value.clone());
108 }
109 }
110 }
111
112 merged
113 }
114}
115
116#[derive(Clone)]
117pub struct ToolSpec {
118 name: String,
119 description: String,
120 json_schema: Value,
121 handler: Arc<ToolHandler>,
122}
123
124impl std::fmt::Debug for ToolSpec {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 f.debug_struct("ToolSpec")
127 .field("name", &self.name)
128 .field("description", &self.description)
129 .field("json_schema", &self.json_schema)
130 .finish()
131 }
132}
133
134impl ToolSpec {
135 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
136 Self {
137 name: name.into(),
138 description: description.into(),
139 json_schema: serde_json::json!({
140 "type": "object",
141 "properties": {},
142 "required": [],
143 "additionalProperties": true,
144 }),
145 handler: Arc::new(|_args, _deps| {
146 Box::pin(async {
147 Err(ToolError::Execution(
148 "tool handler not configured".to_string(),
149 ))
150 })
151 }),
152 }
153 }
154
155 pub fn with_schema(mut self, schema: Value) -> Result<Self, SchemaError> {
156 validate_schema(&schema)?;
157 self.json_schema = schema;
158 Ok(self)
159 }
160
161 pub fn with_handler<F, Fut>(mut self, handler: F) -> Self
162 where
163 F: Fn(Value, &DependencyMap) -> Fut + Send + Sync + 'static,
164 Fut: Future<Output = Result<ToolOutcome, ToolError>> + Send + 'static,
165 {
166 self.handler = Arc::new(move |args, deps| Box::pin(handler(args, deps)));
167 self
168 }
169
170 pub fn name(&self) -> &str {
171 &self.name
172 }
173
174 pub fn description(&self) -> &str {
175 &self.description
176 }
177
178 pub fn json_schema(&self) -> &Value {
179 &self.json_schema
180 }
181
182 pub async fn execute(
183 &self,
184 args: Value,
185 dependencies: &DependencyMap,
186 ) -> Result<ToolOutcome, ToolError> {
187 validate_arguments(self.name(), &self.json_schema, &args)?;
188 (self.handler)(args, dependencies).await
189 }
190}
191
192fn validate_schema(schema: &Value) -> Result<(), SchemaError> {
193 let schema_obj = schema.as_object().ok_or(SchemaError::SchemaNotObject)?;
194
195 let root_type = schema_obj
196 .get("type")
197 .and_then(Value::as_str)
198 .ok_or(SchemaError::RootTypeMustBeObject)?;
199
200 if root_type != "object" {
201 return Err(SchemaError::RootTypeMustBeObject);
202 }
203
204 if let Some(required) = schema_obj.get("required") {
205 let required_arr = required.as_array().ok_or(SchemaError::InvalidRequired)?;
206 for item in required_arr {
207 if !item.is_string() {
208 return Err(SchemaError::InvalidRequired);
209 }
210 }
211 }
212
213 Ok(())
214}
215
216fn validate_arguments(tool_name: &str, schema: &Value, args: &Value) -> Result<(), ToolError> {
217 let args_obj = args
218 .as_object()
219 .ok_or_else(|| ToolError::InvalidArguments {
220 tool: tool_name.to_string(),
221 message: "arguments must be a JSON object".to_string(),
222 })?;
223
224 let schema_obj = schema
225 .as_object()
226 .ok_or_else(|| ToolError::InvalidArguments {
227 tool: tool_name.to_string(),
228 message: "tool schema must be a JSON object".to_string(),
229 })?;
230
231 if let Some(required) = schema_obj.get("required").and_then(Value::as_array) {
232 for field in required {
233 let Some(field_name) = field.as_str() else {
234 continue;
235 };
236 if !args_obj.contains_key(field_name) {
237 return Err(ToolError::InvalidArguments {
238 tool: tool_name.to_string(),
239 message: format!("missing required field: {field_name}"),
240 });
241 }
242 }
243 }
244
245 let properties = schema_obj
246 .get("properties")
247 .and_then(Value::as_object)
248 .cloned()
249 .unwrap_or_default();
250
251 if schema_obj
252 .get("additionalProperties")
253 .and_then(Value::as_bool)
254 == Some(false)
255 {
256 for key in args_obj.keys() {
257 if !properties.contains_key(key) {
258 return Err(ToolError::InvalidArguments {
259 tool: tool_name.to_string(),
260 message: format!("unknown field: {key}"),
261 });
262 }
263 }
264 }
265
266 for (key, value) in args_obj {
267 if let Some(field_schema) = properties.get(key) {
268 if let Some(type_name) = field_schema.get("type").and_then(Value::as_str) {
269 if !value_matches_type(value, type_name) {
270 return Err(ToolError::InvalidArguments {
271 tool: tool_name.to_string(),
272 message: format!("field '{key}' must be of type {type_name}"),
273 });
274 }
275 }
276 }
277 }
278
279 Ok(())
280}
281
282fn value_matches_type(value: &Value, type_name: &str) -> bool {
283 match type_name {
284 "string" => value.is_string(),
285 "integer" => value.as_i64().is_some() || value.as_u64().is_some(),
286 "number" => value.as_f64().is_some(),
287 "boolean" => value.is_boolean(),
288 "object" => value.is_object(),
289 "array" => value.is_array(),
290 "null" => value.is_null(),
291 _ => true,
292 }
293}
294
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds a tool whose schema requires a string `value` field and rejects
    /// undeclared fields. Its handler always returns `Text("ok")`.
    fn strict_string_tool() -> ToolSpec {
        ToolSpec::new("req", "required")
            .with_schema(json!({
                "type": "object",
                "properties": {"value": {"type": "string"}},
                "required": ["value"],
                "additionalProperties": false
            }))
            .expect("schema valid")
            .with_handler(|_args, _deps| async move { Ok(ToolOutcome::Text("ok".into())) })
    }

    #[test]
    fn schema_validation_rejects_non_object_root() {
        let result = ToolSpec::new("bad", "bad").with_schema(json!({"type": "string"}));
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn dependency_overrides_win() {
        let base = DependencyMap::new();
        base.insert::<u32>(1);

        let overrides = DependencyMap::new();
        overrides.insert::<u32>(9);

        let merged = base.merged_with(&overrides);
        assert_eq!(merged.get::<u32>().as_deref(), Some(&9));

        let tool = ToolSpec::new("read", "read dep")
            .with_schema(json!({
                "type": "object",
                "properties": {},
                "required": [],
                "additionalProperties": false
            }))
            .expect("schema should be valid")
            .with_handler(|_args, deps| {
                // Resolve the dependency up front, because the handler's future
                // is `'static` and cannot borrow `deps`. A missing dependency is
                // reported as an error instead of silently becoming 0.
                let value = deps.get::<u32>().as_deref().copied();
                async move {
                    value
                        .map(|v| ToolOutcome::Text(v.to_string()))
                        .ok_or(ToolError::MissingDependency("u32"))
                }
            });

        let outcome = tool
            .execute(json!({}), &merged)
            .await
            .expect("tool executes");
        assert_eq!(outcome, ToolOutcome::Text("9".to_string()));

        let missing = tool
            .execute(json!({}), &DependencyMap::new())
            .await
            .expect_err("missing dependency must be reported");
        assert!(matches!(missing, ToolError::MissingDependency("u32")));
    }

    #[test]
    fn named_dependencies_require_matching_type() {
        let deps = DependencyMap::new();
        deps.insert_named("label", String::from("hello"));

        assert_eq!(
            deps.get_named::<String>("label").as_deref().map(String::as_str),
            Some("hello")
        );
        assert!(deps.get_named::<u32>("label").is_none());
        assert!(deps.get_named::<String>("missing").is_none());
    }

    #[tokio::test]
    async fn argument_validation_reports_missing_required() {
        let tool = strict_string_tool();

        let err = tool
            .execute(json!({}), &DependencyMap::new())
            .await
            .expect_err("should fail");

        let message = err.to_string();
        assert!(message.contains("missing required field"));
    }

    #[tokio::test]
    async fn argument_validation_rejects_unknown_and_mistyped_fields() {
        let tool = strict_string_tool();
        let deps = DependencyMap::new();

        let unknown = tool
            .execute(json!({"value": "x", "extra": 1}), &deps)
            .await
            .expect_err("undeclared field must be rejected");
        assert!(matches!(
            unknown,
            ToolError::InvalidArguments { ref message, .. }
                if message.contains("unknown field: extra")
        ));

        let mistyped = tool
            .execute(json!({"value": 3}), &deps)
            .await
            .expect_err("wrong type must be rejected");
        assert!(matches!(
            mistyped,
            ToolError::InvalidArguments { ref message, .. }
                if message.contains("must be of type string")
        ));

        let ok = tool
            .execute(json!({"value": "x"}), &deps)
            .await
            .expect("valid arguments pass");
        assert_eq!(ok, ToolOutcome::Text("ok".into()));
    }
}