1#![deny(missing_docs)]
2use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use thiserror::Error;
13
14#[non_exhaustive]
16#[derive(Debug, Error)]
17pub enum ToolError {
18 #[error("tool not found: {0}")]
20 NotFound(String),
21
22 #[error("execution failed: {0}")]
24 ExecutionFailed(String),
25
26 #[error("invalid input: {0}")]
28 InvalidInput(String),
29
30 #[error("{0}")]
32 Other(#[from] Box<dyn std::error::Error + Send + Sync>),
33}
34
/// Hint a tool reports about whether its calls may overlap with other tool calls.
///
/// This module only defines and forwards the hint; how it is acted on is up to
/// whatever code schedules tool calls.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ToolConcurrencyHint {
    /// Intended for tools that can run alongside other tool calls.
    Shared,
    /// Intended for tools that should run on their own.
    ///
    /// This is the `Default` variant and what `ToolDyn::concurrency_hint`
    /// returns unless a tool overrides it.
    #[default]
    Exclusive,
}
45
46pub trait ToolDynStreaming: Send + Sync + 'static + ToolDyn {
48 fn call_streaming<'a>(
50 &'a self,
51 input: serde_json::Value,
52 on_chunk: Box<dyn Fn(&str) + Send + Sync + 'a>,
53 ) -> Pin<Box<dyn Future<Output = Result<(), ToolError>> + Send + 'a>>;
54}
55pub trait ToolDyn: Send + Sync {
60 fn name(&self) -> &str;
62
63 fn description(&self) -> &str;
65
66 fn input_schema(&self) -> serde_json::Value;
68
69 fn call(
71 &self,
72 input: serde_json::Value,
73 ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>>;
74
75 fn maybe_streaming(&self) -> Option<&dyn ToolDynStreaming> {
78 None
79 }
80
81 fn concurrency_hint(&self) -> ToolConcurrencyHint {
85 ToolConcurrencyHint::Exclusive
86 }
87}
88
89pub struct AliasedTool {
94 alias: String,
95 inner: Arc<dyn ToolDyn>,
96}
97
98impl AliasedTool {
99 pub fn new(alias: impl Into<String>, inner: Arc<dyn ToolDyn>) -> Self {
101 Self {
102 alias: alias.into(),
103 inner,
104 }
105 }
106
107 pub fn inner(&self) -> &Arc<dyn ToolDyn> {
109 &self.inner
110 }
111}
112
113impl ToolDyn for AliasedTool {
114 fn name(&self) -> &str {
115 &self.alias
116 }
117
118 fn description(&self) -> &str {
119 self.inner.description()
120 }
121
122 fn input_schema(&self) -> serde_json::Value {
123 self.inner.input_schema()
124 }
125
126 fn call(
127 &self,
128 input: serde_json::Value,
129 ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>> {
130 self.inner.call(input)
131 }
132
133 fn concurrency_hint(&self) -> ToolConcurrencyHint {
134 self.inner.concurrency_hint()
135 }
136}
137
138#[derive(Clone)]
143pub struct ToolRegistry {
144 tools: HashMap<String, Arc<dyn ToolDyn>>,
145}
146
147impl ToolRegistry {
148 pub fn new() -> Self {
150 Self {
151 tools: HashMap::new(),
152 }
153 }
154
155 pub fn register(&mut self, tool: Arc<dyn ToolDyn>) {
157 self.tools.insert(tool.name().to_string(), tool);
158 }
159
160 pub fn get(&self, name: &str) -> Option<&Arc<dyn ToolDyn>> {
162 self.tools.get(name)
163 }
164
165 pub fn iter(&self) -> impl Iterator<Item = &Arc<dyn ToolDyn>> {
167 self.tools.values()
168 }
169
170 pub fn len(&self) -> usize {
172 self.tools.len()
173 }
174
175 pub fn is_empty(&self) -> bool {
177 self.tools.is_empty()
178 }
179}
180
181impl Default for ToolRegistry {
182 fn default() -> Self {
183 Self::new()
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    fn _assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn tool_dyn_is_object_safe() {
        // Naming `dyn ToolDyn` at all only compiles if the trait is object safe.
        _assert_send_sync::<Arc<dyn ToolDyn>>();
    }

    #[test]
    fn tool_error_display() {
        assert_eq!(
            ToolError::NotFound("bash".into()).to_string(),
            "tool not found: bash"
        );
        assert_eq!(
            ToolError::ExecutionFailed("timeout".into()).to_string(),
            "execution failed: timeout"
        );
        assert_eq!(
            ToolError::InvalidInput("missing field".into()).to_string(),
            "invalid input: missing field"
        );
        // `Other` displays exactly the wrapped error's message.
        let other: Box<dyn std::error::Error + Send + Sync> = "boom".into();
        assert_eq!(ToolError::from(other).to_string(), "boom");
    }

    /// Returns `{"echoed": <input>}`.
    struct EchoTool;

    impl ToolDyn for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }
        fn description(&self) -> &str {
            "Echoes input back"
        }
        fn input_schema(&self) -> serde_json::Value {
            json!({"type": "object"})
        }
        fn call(
            &self,
            input: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>>
        {
            Box::pin(async move { Ok(json!({"echoed": input})) })
        }
    }

    /// Always resolves to `ExecutionFailed("always fails")`.
    struct FailTool;

    impl ToolDyn for FailTool {
        fn name(&self) -> &str {
            "fail"
        }
        fn description(&self) -> &str {
            "Always fails"
        }
        fn input_schema(&self) -> serde_json::Value {
            json!({"type": "object"})
        }
        fn call(
            &self,
            _input: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>>
        {
            Box::pin(async { Err(ToolError::ExecutionFailed("always fails".into())) })
        }
    }

    #[test]
    fn default_trait_hooks() {
        // Tools that do not override the hooks get the trait defaults.
        assert!(EchoTool.maybe_streaming().is_none());
        assert_eq!(EchoTool.concurrency_hint(), ToolConcurrencyHint::Exclusive);
        assert!(StreamerTool.maybe_streaming().is_some());
    }

    #[test]
    fn registry_add_and_get() {
        let mut reg = ToolRegistry::new();
        assert!(reg.is_empty());

        reg.register(Arc::new(EchoTool));
        assert_eq!(reg.len(), 1);
        assert!(reg.get("echo").is_some());
        assert!(reg.get("nonexistent").is_none());
    }

    #[test]
    fn registry_iter() {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(EchoTool));
        reg.register(Arc::new(FailTool));

        let names: Vec<&str> = reg.iter().map(|t| t.name()).collect();
        assert!(names.contains(&"echo"));
        assert!(names.contains(&"fail"));
    }

    #[tokio::test]
    async fn registry_call_tool() {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(EchoTool));

        let tool = reg.get("echo").unwrap();
        let result = tool.call(json!({"msg": "hello"})).await.unwrap();
        assert_eq!(result, json!({"echoed": {"msg": "hello"}}));
    }

    #[tokio::test]
    async fn aliased_tool_exposes_alias_name_and_delegates() {
        let inner: Arc<dyn ToolDyn> = Arc::new(EchoTool);
        let tool: Arc<dyn ToolDyn> = Arc::new(AliasedTool::new("echo_alias", Arc::clone(&inner)));

        assert_eq!(tool.name(), "echo_alias");
        assert_eq!(tool.description(), inner.description());

        let result = tool.call(json!({"msg": "hi"})).await.unwrap();
        assert_eq!(result, json!({"echoed": {"msg": "hi"}}));
    }

    #[tokio::test]
    async fn registry_call_failing_tool() {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(FailTool));

        let tool = reg.get("fail").unwrap();
        // Pin the exact variant and message, not just "some error".
        match tool.call(json!({})).await {
            Err(ToolError::ExecutionFailed(msg)) => assert_eq!(msg, "always fails"),
            other => panic!("expected ExecutionFailed, got {other:?}"),
        }
    }

    #[test]
    fn registry_overwrite() {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(EchoTool));
        assert_eq!(reg.len(), 1);

        // Register a *different* tool under the same name so the test can tell
        // replacement apart from a no-op.
        reg.register(Arc::new(AliasedTool::new("echo", Arc::new(FailTool))));
        assert_eq!(reg.len(), 1);
        assert_eq!(reg.get("echo").unwrap().description(), "Always fails");
    }

    /// Streams the chunks "one", "two", "three"; `call` returns `{"status": "done"}`.
    struct StreamerTool;

    impl ToolDyn for StreamerTool {
        fn name(&self) -> &str {
            "streamer"
        }
        fn description(&self) -> &str {
            "Streams chunks"
        }
        fn input_schema(&self) -> serde_json::Value {
            json!({"type":"object"})
        }
        fn call(
            &self,
            _input: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>>
        {
            Box::pin(async { Ok(serde_json::json!({"status":"done"})) })
        }
        fn maybe_streaming(&self) -> Option<&dyn ToolDynStreaming> {
            Some(self)
        }
    }

    impl ToolDynStreaming for StreamerTool {
        fn call_streaming<'a>(
            &'a self,
            _input: serde_json::Value,
            on_chunk: Box<dyn Fn(&str) + Send + Sync + 'a>,
        ) -> Pin<Box<dyn Future<Output = Result<(), ToolError>> + Send + 'a>> {
            Box::pin(async move {
                on_chunk("one");
                on_chunk("two");
                on_chunk("three");
                Ok(())
            })
        }
    }

    #[tokio::test]
    async fn streaming_tool_emits_chunks_and_completes() {
        use std::sync::{
            Mutex,
            atomic::{AtomicUsize, Ordering},
        };
        let count = Arc::new(AtomicUsize::new(0));
        let seen: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(vec![]));
        let c2 = Arc::clone(&count);
        let s2 = Arc::clone(&seen);

        // Go through the `dyn` discovery path that real callers use.
        let tool = StreamerTool;
        let streaming = tool
            .maybe_streaming()
            .expect("StreamerTool advertises streaming support");

        let on_chunk = Box::new(move |c: &str| {
            c2.fetch_add(1, Ordering::SeqCst);
            s2.lock().unwrap().push(c.to_string());
        });
        let res = streaming.call_streaming(json!({}), on_chunk).await;
        assert!(res.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 3);
        let got = seen.lock().unwrap().clone();
        assert_eq!(got, vec!["one", "two", "three"]);
    }
}