use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use rand::Rng;
use super::ToolHandler;
use crate::chat::{ToolCall, ToolResult};
use crate::intercept::domain::{ToolExec, ToolRequest, ToolResponse};
use crate::intercept::{InterceptorStack, Operation};
use crate::provider::{ToolDefinition, ToolRetryConfig};
/// A name-indexed collection of tool handlers, together with the interceptor
/// stack that every tool execution is routed through.
///
/// `Ctx` is a caller-supplied context that is handed by reference to each
/// handler invocation; it defaults to `()` for tools that need no context.
pub struct ToolRegistry<Ctx: Send + Sync + 'static = ()> {
    /// Registered handlers, keyed by the tool name taken from each handler's
    /// definition at registration time.
    pub(crate) handlers: HashMap<String, Arc<dyn ToolHandler<Ctx>>>,
    /// Interceptors through which each tool execution is run.
    interceptors: InterceptorStack<ToolExec<Ctx>>,
}
// Implemented by hand: `#[derive(Default)]` would add an unwanted `Ctx: Default` bound.
impl<Ctx: Send + Sync + 'static> Default for ToolRegistry<Ctx> {
    /// Creates a registry with no tools and an empty interceptor stack.
    fn default() -> Self {
        let handlers = HashMap::default();
        let interceptors = InterceptorStack::new();
        Self {
            handlers,
            interceptors,
        }
    }
}
// Implemented by hand: `#[derive(Clone)]` would add an unwanted `Ctx: Clone` bound.
impl<Ctx: Send + Sync + 'static> Clone for ToolRegistry<Ctx> {
    /// Clones the registry.
    ///
    /// Handlers are stored as `Arc`s, so the clone shares the same handler
    /// objects. The interceptor stack is cloned through its own `Clone` impl.
    fn clone(&self) -> Self {
        let Self {
            handlers,
            interceptors,
        } = self;
        Self {
            handlers: handlers.clone(),
            interceptors: interceptors.clone(),
        }
    }
}
impl<Ctx> std::fmt::Debug for ToolRegistry<Ctx>
where
    Ctx: Send + Sync + 'static,
{
    /// Formats the registered tool names and the number of interceptors.
    ///
    /// Handler objects themselves are not printed, only their names.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // std `HashMap` uses a randomly seeded hasher, so its key order differs
        // between runs; sort the names so log and snapshot output is stable.
        let mut tools: Vec<&str> = self.handlers.keys().map(String::as_str).collect();
        tools.sort_unstable();
        f.debug_struct("ToolRegistry")
            .field("tools", &tools)
            .field("interceptors", &self.interceptors.len())
            .finish()
    }
}
impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
pub fn new() -> Self {
Self::default()
}
pub fn register(&mut self, handler: impl ToolHandler<Ctx> + 'static) -> &mut Self {
let name = handler.definition().name.clone();
self.handlers.insert(name, Arc::new(handler));
self
}
pub fn register_shared(&mut self, handler: Arc<dyn ToolHandler<Ctx>>) -> &mut Self {
let name = handler.definition().name.clone();
self.handlers.insert(name, handler);
self
}
pub fn get(&self, name: &str) -> Option<&Arc<dyn ToolHandler<Ctx>>> {
self.handlers.get(name)
}
pub fn contains(&self, name: &str) -> bool {
self.handlers.contains_key(name)
}
pub fn definitions(&self) -> Vec<ToolDefinition> {
self.handlers.values().map(|h| h.definition()).collect()
}
pub fn len(&self) -> usize {
self.handlers.len()
}
pub fn is_empty(&self) -> bool {
self.handlers.is_empty()
}
#[must_use]
pub fn without<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> Self {
use std::collections::HashSet;
let exclude: HashSet<&str> = names.into_iter().collect();
let mut new = Self {
handlers: HashMap::new(),
interceptors: self.interceptors.clone(),
};
for (name, handler) in &self.handlers {
if !exclude.contains(name.as_str()) {
new.handlers.insert(name.clone(), Arc::clone(handler));
}
}
new
}
#[must_use]
pub fn only<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> Self {
use std::collections::HashSet;
let include: HashSet<&str> = names.into_iter().collect();
let mut new = Self {
handlers: HashMap::new(),
interceptors: self.interceptors.clone(),
};
for (name, handler) in &self.handlers {
if include.contains(name.as_str()) {
new.handlers.insert(name.clone(), Arc::clone(handler));
}
}
new
}
#[must_use]
pub fn with_interceptors(mut self, interceptors: InterceptorStack<ToolExec<Ctx>>) -> Self {
self.interceptors = interceptors;
self
}
pub async fn execute(&self, call: &ToolCall, ctx: &Ctx) -> ToolResult {
self.execute_inner(&call.name, &call.id, call.arguments.clone(), ctx)
.await
}
pub(crate) async fn execute_by_name(
&self,
name: &str,
call_id: &str,
arguments: serde_json::Value,
ctx: &Ctx,
) -> ToolResult {
self.execute_inner(name, call_id, arguments, ctx).await
}
async fn execute_inner(
&self,
name: &str,
call_id: &str,
arguments: serde_json::Value,
ctx: &Ctx,
) -> ToolResult {
let Some(handler) = self.handlers.get(name) else {
return ToolResult {
tool_call_id: call_id.to_string(),
content: format!("Unknown tool: {name}"),
is_error: true,
};
};
#[cfg(feature = "schema")]
{
let definition = handler.definition();
if let Err(e) = definition.parameters.validate(&arguments) {
return ToolResult {
tool_call_id: call_id.to_string(),
content: format!("Invalid arguments for tool '{name}': {e}"),
is_error: true,
};
}
}
let request = ToolRequest {
name: name.to_string(),
call_id: call_id.to_string(),
arguments,
};
let operation = ToolHandlerOperation {
handler: handler.clone(),
ctx,
retry_config: handler.definition().retry,
};
let response = self.interceptors.execute(&request, &operation).await;
ToolResult {
tool_call_id: request.call_id,
content: response.content,
is_error: response.is_error,
}
}
pub async fn execute_all(
&self,
calls: &[ToolCall],
ctx: &Ctx,
parallel: bool,
) -> Vec<ToolResult> {
if !parallel || calls.len() <= 1 {
let mut results = Vec::with_capacity(calls.len());
for call in calls {
results.push(self.execute(call, ctx).await);
}
return results;
}
let futures: Vec<_> = calls.iter().map(|call| self.execute(call, ctx)).collect();
futures::future::join_all(futures).await
}
}
/// Returns how long to wait before retry number `attempt` (0 for the first retry).
///
/// The delay is `initial_backoff * backoff_multiplier^attempt`, capped at
/// `max_backoff`. When `jitter` is positive, the capped delay is multiplied by a
/// factor drawn uniformly from `[1 - jitter, 1]`. Jitter therefore only ever
/// shortens the wait.
///
/// Out-of-range configuration is tolerated instead of panicking:
/// - `jitter` is clamped to `[0, 1]`.
/// - A non-positive delay (for example from a negative multiplier) becomes zero.
/// - A delay too large for `Duration` becomes `max_backoff`.
fn compute_backoff(config: &ToolRetryConfig, attempt: u32) -> Duration {
    // `powi` takes an `i32`; saturate rather than wrap into a negative exponent.
    let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
    let base = config.initial_backoff.as_secs_f64() * config.backoff_multiplier.powi(exponent);
    // `f64::min` ignores a NaN operand, so a NaN `base` collapses to the cap.
    let capped = base.min(config.max_backoff.as_secs_f64());
    // A jitter above 1 would allow a negative factor, and so a negative delay.
    let jitter = config.jitter.clamp(0.0, 1.0);
    let jitter_factor = if jitter > 0.0 {
        rand::rng().random_range((1.0 - jitter)..=1.0)
    } else {
        1.0
    };
    let secs = capped * jitter_factor;
    // `Duration::from_secs_f64` panics on negative or overflowing input, so
    // handle those cases explicitly.
    if secs <= 0.0 {
        return Duration::ZERO;
    }
    Duration::try_from_secs_f64(secs).unwrap_or(config.max_backoff)
}
/// Adapter that wraps one tool handler, its borrowed context and its retry
/// policy as the [`Operation`] handed to `InterceptorStack::execute`.
struct ToolHandlerOperation<'a, Ctx>
where
    Ctx: Send + Sync + 'static,
{
    /// The handler to invoke.
    handler: Arc<dyn ToolHandler<Ctx>>,
    /// Caller context, forwarded by reference to the handler.
    ctx: &'a Ctx,
    /// Retry policy from the tool's definition. `None` means a single attempt.
    retry_config: Option<ToolRetryConfig>,
}
impl<Ctx> Operation<ToolExec<Ctx>> for ToolHandlerOperation<'_, Ctx>
where
    Ctx: Send + Sync + 'static,
{
    /// Runs the wrapped handler on `input`.
    ///
    /// The handler is retried according to `retry_config` when one is set.
    /// Otherwise it is invoked exactly once.
    fn execute<'b>(
        &'b self,
        input: &'b ToolRequest,
    ) -> Pin<Box<dyn Future<Output = ToolResponse> + Send + 'b>>
    where
        ToolRequest: Sync,
    {
        Box::pin(async move {
            let Some(config) = self.retry_config.as_ref() else {
                return execute_once(&self.handler, input, self.ctx).await;
            };
            execute_with_retry(&self.handler, input, self.ctx, config).await
        })
    }
}
/// Invokes `handler` once with a copy of the request's arguments.
///
/// On success the response contains the handler output's `content`. On error
/// it contains the error's `message`, and `is_error` is set.
async fn execute_once<Ctx>(
    handler: &Arc<dyn ToolHandler<Ctx>>,
    request: &ToolRequest,
    ctx: &Ctx,
) -> ToolResponse
where
    Ctx: Send + Sync + 'static,
{
    let outcome = handler.execute(request.arguments.clone(), ctx).await;
    let (content, is_error) = match outcome {
        Ok(output) => (output.content, false),
        Err(e) => (e.message, true),
    };
    ToolResponse { content, is_error }
}
async fn execute_with_retry<Ctx: Send + Sync + 'static>(
handler: &Arc<dyn ToolHandler<Ctx>>,
request: &ToolRequest,
ctx: &Ctx,
config: &ToolRetryConfig,
) -> ToolResponse {
let mut attempt = 0u32;
loop {
match handler.execute(request.arguments.clone(), ctx).await {
Ok(output) => {
return ToolResponse {
content: output.content,
is_error: false,
};
}
Err(e) => {
let error_msg = e.message;
let should_retry = config
.retry_if
.as_ref()
.is_none_or(|predicate| predicate(&error_msg));
if !should_retry || attempt >= config.max_retries {
return ToolResponse {
content: error_msg,
is_error: true,
};
}
let backoff = compute_backoff(config, attempt);
tokio::time::sleep(backoff).await;
attempt += 1;
}
}
}
}