1use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::time::Duration;
8
9use rand::Rng;
10
11use super::ToolHandler;
12use crate::chat::{ToolCall, ToolResult};
13use crate::intercept::domain::{ToolExec, ToolRequest, ToolResponse};
14use crate::intercept::{InterceptorStack, Operation};
15use crate::provider::{ToolDefinition, ToolRetryConfig};
16
17pub struct ToolRegistry<Ctx = ()>
47where
48 Ctx: Send + Sync + 'static,
49{
50 pub(crate) handlers: HashMap<String, Arc<dyn ToolHandler<Ctx>>>,
51 interceptors: InterceptorStack<ToolExec<Ctx>>,
52}
53
54impl<Ctx> Default for ToolRegistry<Ctx>
55where
56 Ctx: Send + Sync + 'static,
57{
58 fn default() -> Self {
59 Self {
60 handlers: HashMap::new(),
61 interceptors: InterceptorStack::new(),
62 }
63 }
64}
65
66impl<Ctx> Clone for ToolRegistry<Ctx>
67where
68 Ctx: Send + Sync + 'static,
69{
70 fn clone(&self) -> Self {
75 Self {
76 handlers: self.handlers.clone(),
77 interceptors: self.interceptors.clone(),
78 }
79 }
80}
81
82impl<Ctx> std::fmt::Debug for ToolRegistry<Ctx>
83where
84 Ctx: Send + Sync + 'static,
85{
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("ToolRegistry")
88 .field("tools", &self.handlers.keys().collect::<Vec<_>>())
89 .field("interceptors", &self.interceptors.len())
90 .finish()
91 }
92}
93
94impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
95 pub fn new() -> Self {
97 Self::default()
98 }
99
100 pub fn register(&mut self, handler: impl ToolHandler<Ctx> + 'static) -> &mut Self {
104 let name = handler.definition().name.clone();
105 self.handlers.insert(name, Arc::new(handler));
106 self
107 }
108
109 pub fn register_shared(&mut self, handler: Arc<dyn ToolHandler<Ctx>>) -> &mut Self {
111 let name = handler.definition().name.clone();
112 self.handlers.insert(name, handler);
113 self
114 }
115
116 pub fn get(&self, name: &str) -> Option<&Arc<dyn ToolHandler<Ctx>>> {
118 self.handlers.get(name)
119 }
120
121 pub fn contains(&self, name: &str) -> bool {
123 self.handlers.contains_key(name)
124 }
125
126 pub fn definitions(&self) -> Vec<ToolDefinition> {
131 self.handlers.values().map(|h| h.definition()).collect()
132 }
133
134 pub fn len(&self) -> usize {
136 self.handlers.len()
137 }
138
139 pub fn is_empty(&self) -> bool {
141 self.handlers.is_empty()
142 }
143
144 #[must_use]
161 pub fn without<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> Self {
162 use std::collections::HashSet;
163 let exclude: HashSet<&str> = names.into_iter().collect();
164 let mut new = Self {
165 handlers: HashMap::new(),
166 interceptors: self.interceptors.clone(),
167 };
168 for (name, handler) in &self.handlers {
169 if !exclude.contains(name.as_str()) {
170 new.handlers.insert(name.clone(), Arc::clone(handler));
171 }
172 }
173 new
174 }
175
176 #[must_use]
192 pub fn only<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> Self {
193 use std::collections::HashSet;
194 let include: HashSet<&str> = names.into_iter().collect();
195 let mut new = Self {
196 handlers: HashMap::new(),
197 interceptors: self.interceptors.clone(),
198 };
199 for (name, handler) in &self.handlers {
200 if include.contains(name.as_str()) {
201 new.handlers.insert(name.clone(), Arc::clone(handler));
202 }
203 }
204 new
205 }
206
207 #[must_use]
232 pub fn with_interceptors(mut self, interceptors: InterceptorStack<ToolExec<Ctx>>) -> Self {
233 self.interceptors = interceptors;
234 self
235 }
236
237 pub async fn execute(&self, call: &ToolCall, ctx: &Ctx) -> ToolResult {
249 let Some(handler) = self.handlers.get(&call.name) else {
250 return ToolResult {
251 tool_call_id: call.id.clone(),
252 content: format!("Unknown tool: {}", call.name),
253 is_error: true,
254 };
255 };
256
257 #[cfg(feature = "schema")]
259 {
260 let definition = handler.definition();
261 if let Err(e) = definition.parameters.validate(&call.arguments) {
262 return ToolResult {
263 tool_call_id: call.id.clone(),
264 content: format!("Invalid arguments for tool '{}': {e}", call.name),
265 is_error: true,
266 };
267 }
268 }
269
270 let request = ToolRequest {
272 name: call.name.clone(),
273 call_id: call.id.clone(),
274 arguments: call.arguments.clone(),
275 };
276
277 let operation = ToolHandlerOperation {
279 handler: handler.clone(),
280 ctx,
281 retry_config: handler.definition().retry,
282 };
283
284 let response = self.interceptors.execute(&request, &operation).await;
286
287 ToolResult {
288 tool_call_id: call.id.clone(),
289 content: response.content,
290 is_error: response.is_error,
291 }
292 }
293
294 pub(crate) async fn execute_by_name(
300 &self,
301 name: &str,
302 call_id: &str,
303 arguments: serde_json::Value,
304 ctx: &Ctx,
305 ) -> ToolResult {
306 let Some(handler) = self.handlers.get(name) else {
307 return ToolResult {
308 tool_call_id: call_id.to_string(),
309 content: format!("Unknown tool: {name}"),
310 is_error: true,
311 };
312 };
313
314 #[cfg(feature = "schema")]
316 {
317 let definition = handler.definition();
318 if let Err(e) = definition.parameters.validate(&arguments) {
319 return ToolResult {
320 tool_call_id: call_id.to_string(),
321 content: format!("Invalid arguments for tool '{name}': {e}"),
322 is_error: true,
323 };
324 }
325 }
326
327 let request = ToolRequest {
329 name: name.to_string(),
330 call_id: call_id.to_string(),
331 arguments,
332 };
333
334 let operation = ToolHandlerOperation {
335 handler: handler.clone(),
336 ctx,
337 retry_config: handler.definition().retry,
338 };
339
340 let response = self.interceptors.execute(&request, &operation).await;
341
342 ToolResult {
343 tool_call_id: request.call_id,
344 content: response.content,
345 is_error: response.is_error,
346 }
347 }
348
349 pub async fn execute_all(
354 &self,
355 calls: &[ToolCall],
356 ctx: &Ctx,
357 parallel: bool,
358 ) -> Vec<ToolResult> {
359 if !parallel || calls.len() <= 1 {
360 let mut results = Vec::with_capacity(calls.len());
361 for call in calls {
362 results.push(self.execute(call, ctx).await);
363 }
364 return results;
365 }
366
367 let futures: Vec<_> = calls.iter().map(|call| self.execute(call, ctx)).collect();
369 futures::future::join_all(futures).await
370 }
371}
372
373fn compute_backoff(config: &ToolRetryConfig, attempt: u32) -> Duration {
377 #[allow(clippy::cast_possible_wrap)]
380 let base =
381 config.initial_backoff.as_secs_f64() * config.backoff_multiplier.powi(attempt as i32);
382 let capped = base.min(config.max_backoff.as_secs_f64());
383
384 let jitter_factor = if config.jitter > 0.0 {
386 let min_factor = 1.0 - config.jitter;
387 let mut rng = rand::rng();
388 rng.random_range(min_factor..=1.0)
389 } else {
390 1.0
391 };
392
393 Duration::from_secs_f64(capped * jitter_factor)
394}
395
396struct ToolHandlerOperation<'a, Ctx: Send + Sync + 'static> {
401 handler: Arc<dyn ToolHandler<Ctx>>,
402 ctx: &'a Ctx,
403 retry_config: Option<ToolRetryConfig>,
404}
405
406impl<Ctx: Send + Sync + 'static> Operation<ToolExec<Ctx>> for ToolHandlerOperation<'_, Ctx> {
407 fn execute<'b>(
408 &'b self,
409 input: &'b ToolRequest,
410 ) -> Pin<Box<dyn Future<Output = ToolResponse> + Send + 'b>>
411 where
412 ToolRequest: Sync,
413 {
414 Box::pin(async move {
415 match &self.retry_config {
416 Some(config) => execute_with_retry(&self.handler, input, self.ctx, config).await,
417 None => execute_once(&self.handler, input, self.ctx).await,
418 }
419 })
420 }
421}
422
423async fn execute_once<Ctx: Send + Sync + 'static>(
425 handler: &Arc<dyn ToolHandler<Ctx>>,
426 request: &ToolRequest,
427 ctx: &Ctx,
428) -> ToolResponse {
429 match handler.execute(request.arguments.clone(), ctx).await {
430 Ok(output) => ToolResponse {
431 content: output.content,
432 is_error: false,
433 },
434 Err(e) => ToolResponse {
435 content: e.message,
436 is_error: true,
437 },
438 }
439}
440
441async fn execute_with_retry<Ctx: Send + Sync + 'static>(
443 handler: &Arc<dyn ToolHandler<Ctx>>,
444 request: &ToolRequest,
445 ctx: &Ctx,
446 config: &ToolRetryConfig,
447) -> ToolResponse {
448 let mut attempt = 0u32;
449
450 loop {
451 match handler.execute(request.arguments.clone(), ctx).await {
452 Ok(output) => {
453 return ToolResponse {
454 content: output.content,
455 is_error: false,
456 };
457 }
458 Err(e) => {
459 let error_msg = e.message;
460
461 let should_retry = config
463 .retry_if
464 .as_ref()
465 .is_none_or(|predicate| predicate(&error_msg));
466
467 if !should_retry || attempt >= config.max_retries {
468 return ToolResponse {
469 content: error_msg,
470 is_error: true,
471 };
472 }
473
474 let backoff = compute_backoff(config, attempt);
476 tokio::time::sleep(backoff).await;
477
478 attempt += 1;
479 }
480 }
481 }
482}