use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use atd_protocol::{ToolDefinition, ToolSummary};
use crate::context::CallContext;
use crate::error::ToolCallError;
/// Boxed, pinned, `Send` future returned by [`Tool::call`].
///
/// Resolves to the tool's JSON result or a [`ToolCallError`]. The `'a`
/// lifetime bounds whatever borrows the future holds.
pub type CallFuture<'a> =
Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolCallError>> + Send + 'a>>;
/// Boxed, pinned, `Send` future returned by [`Tool::call_paginated`].
///
/// Resolves to a [`PaginatedResult`] (JSON value plus optional cursor) or a
/// [`ToolCallError`]. The `'a` lifetime bounds whatever borrows the future holds.
pub type PaginatedCallFuture<'a> =
Pin<Box<dyn Future<Output = Result<PaginatedResult, ToolCallError>> + Send + 'a>>;
/// Outcome of a paginated tool call: one page of output plus an optional cursor.
#[derive(Debug)]
pub struct PaginatedResult {
/// JSON payload produced by the call.
pub value: serde_json::Value,
/// Cursor for the following page. The default [`Tool::call_paginated`]
/// always sets this to `None`; presumably `None` means "no further
/// pages" — confirm against the code that consumes the cursor.
pub next_cursor: Option<String>,
}
/// A callable tool that can be shared across threads (`Send + Sync`).
///
/// Implementors provide [`Tool::definition`] and [`Tool::call`]; the
/// pagination methods have defaults that treat the tool as single-page.
pub trait Tool: Send + Sync {
    /// Returns the definition describing this tool.
    fn definition(&self) -> &ToolDefinition;

    /// Invokes the tool with JSON `args` under `ctx`, yielding a JSON value
    /// or a [`ToolCallError`].
    fn call<'a>(&'a self, args: serde_json::Value, ctx: &'a CallContext) -> CallFuture<'a>;

    /// Whether this tool overrides the pagination behaviour; `false` unless
    /// an implementor says otherwise.
    fn supports_pagination(&self) -> bool {
        false
    }

    /// Paginated variant of [`Tool::call`].
    ///
    /// The default implementation ignores `_cursor`, forwards to
    /// [`Tool::call`], and wraps the result as a single page whose
    /// `next_cursor` is `None`.
    fn call_paginated<'a>(
        &'a self,
        args: serde_json::Value,
        ctx: &'a CallContext,
        _cursor: Option<&'a str>,
    ) -> PaginatedCallFuture<'a> {
        // `call` is invoked eagerly, before the returned future is first
        // polled; only awaiting it is deferred.
        let single_page = self.call(args, ctx);
        Box::pin(async move {
            single_page.await.map(|value| PaginatedResult {
                value,
                next_cursor: None,
            })
        })
    }
}
/// A [`Tool`] together with its binding and semaphore, as stored in a
/// [`Registry`].
///
/// Every field is an `Arc`, so cloning shares the same tool, binding and
/// semaphore rather than duplicating them.
#[derive(Clone)]
#[non_exhaustive]
pub struct RegisteredTool {
/// The registered tool implementation.
pub tool: Arc<dyn Tool>,
/// Binding associated with the tool. [`Registry::register`] supplies a
/// `NativeBinding` wrapping the tool; `register_with_binding` stores
/// whatever the caller passes.
pub binding: Arc<dyn crate::binding::Binding>,
/// Semaphore created at registration with a permit count taken from the
/// definition's `resources.max_concurrent` (`0` maps to
/// `Semaphore::MAX_PERMITS`). Presumably callers acquire a permit per
/// call to cap concurrency — confirm against the dispatch code.
pub semaphore: Arc<tokio::sync::Semaphore>,
}
impl RegisteredTool {
pub fn definition(&self) -> &ToolDefinition {
self.tool.definition()
}
}
/// In-memory collection of registered tools, looked up by tool id.
pub struct Registry {
// Keyed by the `id` of each tool's `ToolDefinition`.
tools: HashMap<String, RegisteredTool>,
}
impl Registry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            tools: HashMap::new(),
        }
    }

    /// Registers `tool` with a `NativeBinding` that wraps the tool itself.
    ///
    /// # Panics
    ///
    /// Panics if a tool with the same id is already registered.
    pub fn register(&mut self, tool: Arc<dyn Tool>) {
        let binding: Arc<dyn crate::binding::Binding> =
            Arc::new(crate::binding::NativeBinding::new(Arc::clone(&tool)));
        self.register_with_binding(tool, binding);
    }

    /// Registers `tool` under its definition id, paired with `binding` and a
    /// freshly created semaphore.
    ///
    /// The semaphore's permit count comes from the definition's
    /// `resources.max_concurrent`: `0` means "no limit" and maps to
    /// `Semaphore::MAX_PERMITS`; any other value is used as-is, clamped to
    /// `Semaphore::MAX_PERMITS`.
    ///
    /// # Panics
    ///
    /// Panics if a tool with the same id is already registered.
    pub fn register_with_binding(
        &mut self,
        tool: Arc<dyn Tool>,
        binding: Arc<dyn crate::binding::Binding>,
    ) {
        use std::collections::hash_map::Entry;

        const MAX_PERMITS: usize = tokio::sync::Semaphore::MAX_PERMITS;

        // A single `entry` lookup replaces the `contains_key` + `insert` pair.
        let id = tool.definition().id.clone();
        let slot = match self.tools.entry(id) {
            Entry::Occupied(existing) => {
                panic!("duplicate tool registration: {}", existing.key())
            }
            Entry::Vacant(slot) => slot,
        };

        // `Semaphore::new` panics above `MAX_PERMITS`, which a large
        // `max_concurrent` can exceed on targets with a narrow `usize`,
        // so convert checked and clamp instead of casting with `as`.
        let max = tool.definition().resources.max_concurrent;
        let permits = match usize::try_from(max) {
            Ok(0) | Err(_) => MAX_PERMITS,
            Ok(n) => n.min(MAX_PERMITS),
        };
        let semaphore = Arc::new(tokio::sync::Semaphore::new(permits));

        slot.insert(RegisteredTool {
            tool,
            binding,
            semaphore,
        });
    }

    /// Looks up a registered tool by id, returning `None` if it is unknown.
    pub fn get(&self, tool_id: &str) -> Option<&RegisteredTool> {
        self.tools.get(tool_id)
    }

    /// Builds a [`ToolSummary`] for every registered tool.
    ///
    /// The order follows `HashMap` iteration and is therefore unspecified.
    pub fn summaries(&self) -> Vec<ToolSummary> {
        self.tools
            .values()
            .map(|r| ToolSummary::from(r.tool.definition()))
            .collect()
    }

    /// Number of registered tools.
    pub fn count(&self) -> usize {
        self.tools.len()
    }
}
impl Default for Registry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use atd_protocol::{
        BindingProtocol, SafetyLevel, ToolBinding, ToolCapability, ToolResources, ToolSafety,
        ToolTrust, ToolVisibility, TrustLevel,
    };

    /// Minimal [`Tool`] whose `call` always resolves to an empty JSON object.
    struct StubTool {
        def: ToolDefinition,
    }

    impl StubTool {
        /// Stub with the given id and `max_concurrent = 1`.
        fn new(id: &str) -> Self {
            Self::with_max_concurrent(id, 1)
        }

        /// Stub with the given id and an explicit `max_concurrent`.
        ///
        /// This is the single place the fixture definition is spelled out,
        /// so tests that vary only the concurrency limit need not repeat it.
        fn with_max_concurrent(id: &str, max_concurrent: u32) -> Self {
            Self {
                def: ToolDefinition {
                    id: id.into(),
                    name: id.into(),
                    description: "stub".into(),
                    version: "0.0.0".into(),
                    capability: ToolCapability {
                        domain: "stub".into(),
                        actions: vec![],
                        tags: vec![],
                        intent_examples: vec![],
                    },
                    input_schema: serde_json::json!({}),
                    output_schema: serde_json::json!({}),
                    bindings: vec![ToolBinding {
                        protocol: BindingProtocol::Cli,
                        config: serde_json::json!({}),
                    }],
                    safety: ToolSafety {
                        level: SafetyLevel::Read,
                        dry_run: false,
                        side_effects: vec![],
                        data_sensitivity: None,
                    },
                    resources: ToolResources {
                        timeout_ms: 1000,
                        max_concurrent,
                        rate_limit_per_min: None,
                        estimated_tokens: None,
                    },
                    trust: ToolTrust {
                        publisher: "test".into(),
                        trust_level: TrustLevel::L0Unverified,
                        signature: None,
                    },
                    visibility: ToolVisibility::Read,
                    required_capabilities: vec![],
                    tier: None,
                    errors: vec![],
                },
            }
        }
    }

    impl Tool for StubTool {
        fn definition(&self) -> &ToolDefinition {
            &self.def
        }

        fn call<'a>(&'a self, _args: serde_json::Value, _ctx: &'a CallContext) -> CallFuture<'a> {
            Box::pin(async move { Ok(serde_json::json!({})) })
        }
    }

    #[test]
    fn register_and_get_returns_the_tool() {
        let mut r = Registry::new();
        r.register(Arc::new(StubTool::new("test:a")));
        assert!(r.get("test:a").is_some());
        assert!(r.get("test:missing").is_none());
    }

    #[test]
    fn summaries_projects_registered_tools() {
        let mut r = Registry::new();
        r.register(Arc::new(StubTool::new("test:a")));
        r.register(Arc::new(StubTool::new("test:b")));
        assert_eq!(r.count(), 2);
        let sums = r.summaries();
        assert_eq!(sums.len(), 2);
        let ids: std::collections::HashSet<_> = sums.iter().map(|s| s.id.clone()).collect();
        assert!(ids.contains("test:a"));
        assert!(ids.contains("test:b"));
    }

    #[test]
    #[should_panic(expected = "duplicate tool registration: test:a")]
    fn duplicate_registration_panics() {
        let mut r = Registry::new();
        r.register(Arc::new(StubTool::new("test:a")));
        r.register(Arc::new(StubTool::new("test:a")));
    }

    #[test]
    fn empty_registry_reports_zero() {
        let r = Registry::new();
        assert_eq!(r.count(), 0);
        assert!(r.summaries().is_empty());
    }

    #[test]
    fn semaphore_permits_match_max_concurrent() {
        let mut reg = Registry::new();
        reg.register(Arc::new(StubTool::with_max_concurrent("stub:a", 5)));
        reg.register(Arc::new(StubTool::with_max_concurrent("stub:b", 0)));

        let a = reg.get("stub:a").unwrap();
        assert_eq!(a.semaphore.available_permits(), 5);

        let b = reg.get("stub:b").unwrap();
        assert_eq!(
            b.semaphore.available_permits(),
            tokio::sync::Semaphore::MAX_PERMITS,
            "max_concurrent=0 should map to MAX_PERMITS"
        );
    }
}