1use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use atd_protocol::{ToolDefinition, ToolSummary};
9
10use crate::context::CallContext;
11use crate::error::ToolCallError;
12
13pub type CallFuture<'a> =
15 Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolCallError>> + Send + 'a>>;
16
17pub type PaginatedCallFuture<'a> =
19 Pin<Box<dyn Future<Output = Result<PaginatedResult, ToolCallError>> + Send + 'a>>;
20
21#[derive(Debug)]
27pub struct PaginatedResult {
28 pub value: serde_json::Value,
30 pub next_cursor: Option<String>,
34}
35
36pub trait Tool: Send + Sync {
42 fn definition(&self) -> &ToolDefinition;
46
47 fn call<'a>(&'a self, args: serde_json::Value, ctx: &'a CallContext) -> CallFuture<'a>;
49
50 fn supports_pagination(&self) -> bool {
63 false
64 }
65
66 fn call_paginated<'a>(
76 &'a self,
77 args: serde_json::Value,
78 ctx: &'a CallContext,
79 _cursor: Option<&'a str>,
80 ) -> PaginatedCallFuture<'a> {
81 let fut = self.call(args, ctx);
82 Box::pin(async move {
83 let value = fut.await?;
84 Ok(PaginatedResult {
85 value,
86 next_cursor: None,
87 })
88 })
89 }
90}
91
92#[derive(Clone)]
106#[non_exhaustive]
107pub struct RegisteredTool {
108 pub tool: Arc<dyn Tool>,
109 pub binding: Arc<dyn crate::binding::Binding>,
110 pub semaphore: Arc<tokio::sync::Semaphore>,
111}
112
113impl RegisteredTool {
114 pub fn definition(&self) -> &ToolDefinition {
115 self.tool.definition()
116 }
117}
118
119pub struct Registry {
120 tools: HashMap<String, RegisteredTool>,
121}
122
123impl Registry {
124 pub fn new() -> Self {
125 Self {
126 tools: HashMap::new(),
127 }
128 }
129
130 pub fn register(&mut self, tool: Arc<dyn Tool>) {
134 let binding: Arc<dyn crate::binding::Binding> =
135 Arc::new(crate::binding::NativeBinding::new(tool.clone()));
136 self.register_with_binding(tool, binding);
137 }
138
139 pub fn register_with_binding(
143 &mut self,
144 tool: Arc<dyn Tool>,
145 binding: Arc<dyn crate::binding::Binding>,
146 ) {
147 let id = tool.definition().id.clone();
148 if self.tools.contains_key(&id) {
149 panic!("duplicate tool registration: {id}");
150 }
151 let max = tool.definition().resources.max_concurrent;
157 let permits = if max == 0 {
158 tokio::sync::Semaphore::MAX_PERMITS
159 } else {
160 max as usize
161 };
162 let semaphore = Arc::new(tokio::sync::Semaphore::new(permits));
163 self.tools.insert(
164 id,
165 RegisteredTool {
166 tool,
167 binding,
168 semaphore,
169 },
170 );
171 }
172
173 pub fn get(&self, tool_id: &str) -> Option<&RegisteredTool> {
174 self.tools.get(tool_id)
175 }
176
177 pub fn summaries(&self) -> Vec<ToolSummary> {
178 self.tools
179 .values()
180 .map(|r| ToolSummary::from(r.tool.definition()))
181 .collect()
182 }
183
184 pub fn count(&self) -> usize {
185 self.tools.len()
186 }
187}
188
189impl Default for Registry {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use atd_protocol::{
        BindingProtocol, SafetyLevel, ToolBinding, ToolCapability, ToolResources, ToolSafety,
        ToolTrust, ToolVisibility, TrustLevel,
    };

    /// Minimal [`Tool`] whose `call` always resolves to an empty JSON object.
    struct StubTool {
        def: ToolDefinition,
    }

    impl StubTool {
        /// Builds a stub with the given id and `max_concurrent = 1`.
        fn new(id: &str) -> Self {
            Self::with_max_concurrent(id, 1)
        }

        /// Builds a stub with the given id and concurrency limit.
        ///
        /// This is the single place the fixture `ToolDefinition` is spelled
        /// out, so tests do not duplicate the full literal.
        fn with_max_concurrent(id: &str, max_concurrent: u32) -> Self {
            Self {
                def: ToolDefinition {
                    id: id.into(),
                    name: id.into(),
                    description: "stub".into(),
                    version: "0.0.0".into(),
                    capability: ToolCapability {
                        domain: "stub".into(),
                        actions: vec![],
                        tags: vec![],
                        intent_examples: vec![],
                    },
                    input_schema: serde_json::json!({}),
                    output_schema: serde_json::json!({}),
                    bindings: vec![ToolBinding {
                        protocol: BindingProtocol::Cli,
                        config: serde_json::json!({}),
                    }],
                    safety: ToolSafety {
                        level: SafetyLevel::Read,
                        dry_run: false,
                        side_effects: vec![],
                        data_sensitivity: None,
                    },
                    resources: ToolResources {
                        timeout_ms: 1000,
                        max_concurrent,
                        rate_limit_per_min: None,
                        estimated_tokens: None,
                    },
                    trust: ToolTrust {
                        publisher: "test".into(),
                        trust_level: TrustLevel::L0Unverified,
                        signature: None,
                    },
                    visibility: ToolVisibility::Read,
                    required_capabilities: vec![],
                    tier: None,
                    errors: vec![],
                },
            }
        }
    }

    impl Tool for StubTool {
        fn definition(&self) -> &ToolDefinition {
            &self.def
        }
        fn call<'a>(&'a self, _args: serde_json::Value, _ctx: &'a CallContext) -> CallFuture<'a> {
            Box::pin(async move { Ok(serde_json::json!({})) })
        }
    }

    #[test]
    fn register_and_get_returns_the_tool() {
        let mut r = Registry::new();
        r.register(Arc::new(StubTool::new("test:a")));
        assert!(r.get("test:a").is_some());
        assert!(r.get("test:missing").is_none());
    }

    #[test]
    fn summaries_projects_registered_tools() {
        let mut r = Registry::new();
        r.register(Arc::new(StubTool::new("test:a")));
        r.register(Arc::new(StubTool::new("test:b")));
        let sums = r.summaries();
        assert_eq!(sums.len(), 2);
        let ids: std::collections::HashSet<_> = sums.iter().map(|s| s.id.clone()).collect();
        assert!(ids.contains("test:a"));
        assert!(ids.contains("test:b"));
    }

    #[test]
    #[should_panic(expected = "duplicate tool registration: test:a")]
    fn duplicate_registration_panics() {
        let mut r = Registry::new();
        r.register(Arc::new(StubTool::new("test:a")));
        r.register(Arc::new(StubTool::new("test:a")));
    }

    #[test]
    fn empty_registry_reports_zero() {
        let r = Registry::new();
        assert_eq!(r.count(), 0);
        assert!(r.summaries().is_empty());
    }

    #[test]
    fn count_tracks_registrations() {
        let mut r = Registry::new();
        r.register(Arc::new(StubTool::new("test:a")));
        r.register(Arc::new(StubTool::new("test:b")));
        assert_eq!(r.count(), 2);
    }

    #[test]
    fn semaphore_permits_match_max_concurrent() {
        let mut reg = Registry::new();
        reg.register(Arc::new(StubTool::with_max_concurrent("stub:a", 5)));
        reg.register(Arc::new(StubTool::with_max_concurrent("stub:b", 0)));

        let a = reg.get("stub:a").unwrap();
        assert_eq!(a.semaphore.available_permits(), 5);

        let b = reg.get("stub:b").unwrap();
        assert_eq!(
            b.semaphore.available_permits(),
            tokio::sync::Semaphore::MAX_PERMITS,
            "max_concurrent=0 should map to MAX_PERMITS"
        );
    }
}