use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use super::context::ToolContext;
use arcp_core::error::ARCPError;
/// A named tool that can be invoked asynchronously with JSON arguments.
///
/// Handlers are stored as `Arc<dyn ToolHandler>` inside a [`ToolRegistry`].
/// The `Send + Sync` bounds let a single handler instance be shared across
/// threads and tasks.
#[async_trait]
pub trait ToolHandler: Send + Sync {
/// Name this handler is registered and looked up under.
///
/// [`ToolRegistryBuilder::with`] reads this once, at registration time, and
/// uses it as the map key. It should therefore stay stable for the handler's
/// lifetime.
fn name(&self) -> &str;
/// Runs the tool on JSON `arguments` with the per-call [`ToolContext`].
///
/// The context is passed by value, so each invocation owns its own.
///
/// # Errors
///
/// Returns an [`ARCPError`] when the tool fails. Which variants are produced
/// is up to the implementation.
async fn invoke(
&self,
arguments: serde_json::Value,
ctx: ToolContext,
) -> Result<serde_json::Value, ARCPError>;
}
/// Read-only lookup table from tool name to [`ToolHandler`].
///
/// Populate one with [`ToolRegistryBuilder`]. [`ToolRegistry::new`] and
/// `Default` both yield an empty registry. The map lives behind an `Arc`, so
/// `clone()` only bumps a reference count and every clone shares the same
/// handlers.
#[derive(Clone, Default)]
pub struct ToolRegistry {
// Keyed by each handler's `name()` as reported when it was added to the builder.
tools: Arc<HashMap<String, Arc<dyn ToolHandler>>>,
}
/// Renders the registry as the list of registered tool names.
///
/// The handlers themselves are trait objects without a `Debug` bound, so only
/// their keys are shown. Names are sorted because `HashMap` iteration order is
/// randomized per process. Without sorting, the output for registries with
/// several tools would differ between runs, which breaks log diffs and
/// snapshot tests.
impl std::fmt::Debug for ToolRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut names: Vec<&str> = self.tools.keys().map(String::as_str).collect();
        names.sort_unstable();
        f.debug_struct("ToolRegistry")
            .field("names", &names)
            .finish()
    }
}
impl ToolRegistry {
    /// Creates a registry with no tools in it.
    ///
    /// Equivalent to `ToolRegistry::default()`. Use [`ToolRegistryBuilder`]
    /// to create a populated registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            tools: Arc::default(),
        }
    }

    /// Reports whether the registry holds zero tools.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Number of distinct tool names registered.
    #[must_use]
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Fetches the handler registered under exactly `name`.
    ///
    /// Yields a fresh `Arc` handle (a reference-count bump, not a copy of the
    /// handler), or `None` when nothing is registered under that name.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<Arc<dyn ToolHandler>> {
        self.tools.get(name).map(Arc::clone)
    }
}
/// Mutable accumulator of tool handlers.
///
/// Add handlers with [`ToolRegistryBuilder::with`], then call
/// [`ToolRegistryBuilder::build`] to turn them into a [`ToolRegistry`].
#[derive(Default)]
pub struct ToolRegistryBuilder {
// Handlers keyed by the `name()` each one reported when passed to `with`.
tools: HashMap<String, Arc<dyn ToolHandler>>,
}
/// Renders the builder as the list of tool names added so far.
///
/// Handlers are trait objects without a `Debug` bound, so only their keys are
/// shown. Names are sorted because `HashMap` iteration order is randomized per
/// process. Unsorted output would vary between runs whenever more than one
/// tool has been added.
impl std::fmt::Debug for ToolRegistryBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut names: Vec<&str> = self.tools.keys().map(String::as_str).collect();
        names.sort_unstable();
        f.debug_struct("ToolRegistryBuilder")
            .field("names", &names)
            .finish()
    }
}
impl ToolRegistryBuilder {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with(mut self, handler: Arc<dyn ToolHandler>) -> Self {
let name = handler.name().to_owned();
self.tools.insert(name, handler);
self
}
#[must_use]
pub fn build(self) -> ToolRegistry {
ToolRegistry {
tools: Arc::new(self.tools),
}
}
}
#[cfg(test)]
// Tests signal failure by panicking (`expect`, `unwrap`, `panic!`), so the
// crate's lints against those are relaxed for this module only.
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    clippy::panic,
    clippy::missing_panics_doc
)]
mod tests {
    use super::*;
    use tokio_util::sync::CancellationToken;

    /// Test double that hands its JSON arguments straight back.
    struct EchoTool;

    #[async_trait]
    impl ToolHandler for EchoTool {
        fn name(&self) -> &'static str {
            "echo"
        }

        async fn invoke(
            &self,
            arguments: serde_json::Value,
            _ctx: ToolContext,
        ) -> Result<serde_json::Value, ARCPError> {
            Ok(arguments)
        }
    }

    #[tokio::test]
    async fn registry_round_trips_through_builder() {
        let registry = ToolRegistryBuilder::new()
            .with(Arc::new(EchoTool))
            .build();
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        let handler = registry.get("echo").expect("registered");
        assert_eq!(handler.name(), "echo");

        // The receiver is bound (not `_`) so it lives until the end of the
        // test, keeping the output channel open while the tool runs.
        let (out, _out_rx) = tokio::sync::mpsc::channel(1);
        let ctx = ToolContext {
            cancel: CancellationToken::new(),
            job_id: arcp_core::ids::JobId::new(),
            session_id: arcp_core::ids::SessionId::new(),
            correlation_id: arcp_core::ids::MessageId::new(),
            out,
            budget: crate::runtime::context::BudgetTracker::new(),
            lease: None,
        };

        let payload = serde_json::json!({"k": 1});
        let echoed = handler
            .invoke(payload.clone(), ctx)
            .await
            .expect("invoke");
        assert_eq!(echoed, payload);
    }

    #[test]
    fn empty_registry_reports_empty() {
        let registry = ToolRegistry::new();
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
        assert!(registry.get("missing").is_none());
    }

    #[test]
    fn debug_impls_render_without_panicking() {
        let builder = ToolRegistryBuilder::new().with(Arc::new(EchoTool));
        assert!(format!("{builder:?}").contains("echo"));

        let registry = builder.build();
        assert!(format!("{registry:?}").contains("echo"));
    }
}