Skip to main content

atd_runtime/
registry.rs

1//! `Tool` trait + `Registry` — the contract third-party implementers see.
2
3use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use atd_protocol::{ToolDefinition, ToolSummary};
9
10use crate::context::CallContext;
11use crate::error::ToolCallError;
12
13/// Boxed future returned by [`Tool::call`].
14pub type CallFuture<'a> =
15    Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolCallError>> + Send + 'a>>;
16
17/// Boxed future returned by [`Tool::call_paginated`]. SP-pagination-v1 §4.4.
18pub type PaginatedCallFuture<'a> =
19    Pin<Box<dyn Future<Output = Result<PaginatedResult, ToolCallError>> + Send + 'a>>;
20
21/// Result of a paginated tool call. Tools that don't paginate return
22/// `next_cursor: None`; the default `Tool::call_paginated` impl wraps
23/// `Tool::call` to produce this shape automatically.
24///
25/// SP-pagination-v1 §4.4.
26#[derive(Debug)]
27pub struct PaginatedResult {
28    /// The page body. Same shape the tool would return from `call`.
29    pub value: serde_json::Value,
30    /// Server-opaque continuation handle. `Some(_)` if more pages exist;
31    /// `None` on terminal pages. Tools using ATD's reference HMAC-signed
32    /// cursors call `ctx.cursor_issuer().issue(payload)` to produce this.
33    pub next_cursor: Option<String>,
34}
35
36/// A tool. One `impl Tool for MyTool` per tool; registered once at startup.
37/// Tools MUST NOT panic; they return `Err(ToolCallError)` instead.
38///
39/// `call` returns a boxed future so the trait is dyn-compatible without
40/// requiring the `async_trait` macro.
41pub trait Tool: Send + Sync {
42    /// Stable borrow of the tool's definition. Registry calls this once at
43    /// registration time (for summaries/schema lookup) — implementers
44    /// typically store a single `ToolDefinition` in the struct.
45    fn definition(&self) -> &ToolDefinition;
46
47    /// Invoke the tool. Args are the deserialized JSON from the wire.
48    fn call<'a>(&'a self, args: serde_json::Value, ctx: &'a CallContext) -> CallFuture<'a>;
49
50    /// SP-pagination-v1 §4.4 — whether this tool's `call_paginated` impl
51    /// produces meaningful pages (cursors emitted on first page, cursors
52    /// consumed on continuations). Default `false`: dispatch routes the
53    /// tool through `Binding::call` unchanged (preserving CLI / future
54    /// MCP / REST binding semantics). Tools that override `call_paginated`
55    /// must also override this to return `true`.
56    ///
57    /// v1 constraint: paginated tools execute through native (in-process)
58    /// semantics, bypassing the `Binding` abstraction — pagination state
59    /// doesn't survive subprocess boundaries. CLI-backed tools that need
60    /// pagination would require an out-of-band stateful protocol, deferred
61    /// to a future SP.
62    fn supports_pagination(&self) -> bool {
63        false
64    }
65
66    /// SP-pagination-v1 §4.4 — paginated variant. Default impl wraps `call`
67    /// and returns `next_cursor: None`, so existing tools work unchanged.
68    /// Tools that want to paginate override this method AND `supports_pagination`.
69    ///
70    /// `cursor`:
71    /// - `None` on the initial `Request::RunTool` — produce page 1.
72    /// - `Some(s)` on `Request::RunToolContinue` — the cursor's payload
73    ///   has already been HMAC-verified by dispatch; the tool decodes the
74    ///   `opaque_state` (or its own scheme) to resume.
75    fn call_paginated<'a>(
76        &'a self,
77        args: serde_json::Value,
78        ctx: &'a CallContext,
79        _cursor: Option<&'a str>,
80    ) -> PaginatedCallFuture<'a> {
81        let fut = self.call(args, ctx);
82        Box::pin(async move {
83            let value = fut.await?;
84            Ok(PaginatedResult {
85                value,
86                next_cursor: None,
87            })
88        })
89    }
90}
91
92/// One registered tool plus the binding dispatch uses to execute it.
93/// SP-12 Task 4: `Binding` sits between dispatch and the `Tool` impl so
94/// the same tool can be served via different execution strategies
95/// (in-process, CLI subprocess, future MCP / REST / AppFunction).
96///
97/// SP-operability-v1 C2: `semaphore` enforces `max_concurrent` at
98/// dispatch time via `try_acquire_owned` — saturation returns 1002
99/// before the tool runs. Sized from
100/// `tool.definition().resources.max_concurrent`: a positive value gives
101/// that many permits, `0` maps to `Semaphore::MAX_PERMITS` (effectively
102/// unlimited). `#[non_exhaustive]` is load-bearing — downstream code
103/// goes through `Registry::register*` so new fields can be added here
104/// without breaking external struct-literal callers.
105#[derive(Clone)]
106#[non_exhaustive]
107pub struct RegisteredTool {
108    pub tool: Arc<dyn Tool>,
109    pub binding: Arc<dyn crate::binding::Binding>,
110    pub semaphore: Arc<tokio::sync::Semaphore>,
111}
112
113impl RegisteredTool {
114    pub fn definition(&self) -> &ToolDefinition {
115        self.tool.definition()
116    }
117}
118
119pub struct Registry {
120    tools: HashMap<String, RegisteredTool>,
121}
122
123impl Registry {
124    pub fn new() -> Self {
125        Self {
126            tools: HashMap::new(),
127        }
128    }
129
130    /// Register a tool with the default `NativeBinding` — dispatch will call
131    /// the tool's `Tool::call` directly. Panics on duplicate tool_id:
132    /// startup misconfiguration should fail loud, not at request time.
133    pub fn register(&mut self, tool: Arc<dyn Tool>) {
134        let binding: Arc<dyn crate::binding::Binding> =
135            Arc::new(crate::binding::NativeBinding::new(tool.clone()));
136        self.register_with_binding(tool, binding);
137    }
138
139    /// Register a tool paired with an explicit binding. Use this for tools
140    /// whose execution strategy differs from "run the `Tool::call` future"
141    /// (e.g. `CliBinding` for subprocess-backed tools).
142    pub fn register_with_binding(
143        &mut self,
144        tool: Arc<dyn Tool>,
145        binding: Arc<dyn crate::binding::Binding>,
146    ) {
147        let id = tool.definition().id.clone();
148        if self.tools.contains_key(&id) {
149            panic!("duplicate tool registration: {id}");
150        }
151        // SP-operability-v1 C2: pre-build the per-tool semaphore here
152        // (not lazily at dispatch) so the permit count is an invariant of
153        // registration, and so `RegisteredTool` stays cheaply `Clone`.
154        // `max_concurrent == 0` is treated as "unlimited" — MAX_PERMITS
155        // is the tokio-sanctioned sentinel for this.
156        let max = tool.definition().resources.max_concurrent;
157        let permits = if max == 0 {
158            tokio::sync::Semaphore::MAX_PERMITS
159        } else {
160            max as usize
161        };
162        let semaphore = Arc::new(tokio::sync::Semaphore::new(permits));
163        self.tools.insert(
164            id,
165            RegisteredTool {
166                tool,
167                binding,
168                semaphore,
169            },
170        );
171    }
172
173    pub fn get(&self, tool_id: &str) -> Option<&RegisteredTool> {
174        self.tools.get(tool_id)
175    }
176
177    pub fn summaries(&self) -> Vec<ToolSummary> {
178        self.tools
179            .values()
180            .map(|r| ToolSummary::from(r.tool.definition()))
181            .collect()
182    }
183
184    pub fn count(&self) -> usize {
185        self.tools.len()
186    }
187}
188
189impl Default for Registry {
190    fn default() -> Self {
191        Self::new()
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use atd_protocol::{
        BindingProtocol, SafetyLevel, ToolBinding, ToolCapability, ToolResources, ToolSafety,
        ToolTrust, ToolVisibility, TrustLevel,
    };

    /// Minimal `Tool` for registry tests: carries a definition and answers
    /// every call with an empty JSON object.
    struct StubTool {
        def: ToolDefinition,
    }

    impl StubTool {
        /// Stub with the given id and `max_concurrent = 1`.
        fn new(id: &str) -> Self {
            Self::with_max_concurrent(id, 1)
        }

        /// Stub with the given id and concurrency limit. This is the single
        /// place the `ToolDefinition` fixture is spelled out, so tests that
        /// only vary the limit do not need their own copy of the literal.
        fn with_max_concurrent(id: &str, max_concurrent: u32) -> Self {
            Self {
                def: ToolDefinition {
                    id: id.into(),
                    name: id.into(),
                    description: "stub".into(),
                    version: "0.0.0".into(),
                    capability: ToolCapability {
                        domain: "stub".into(),
                        actions: vec![],
                        tags: vec![],
                        intent_examples: vec![],
                    },
                    input_schema: serde_json::json!({}),
                    output_schema: serde_json::json!({}),
                    bindings: vec![ToolBinding {
                        protocol: BindingProtocol::Cli,
                        config: serde_json::json!({}),
                    }],
                    safety: ToolSafety {
                        level: SafetyLevel::Read,
                        dry_run: false,
                        side_effects: vec![],
                        data_sensitivity: None,
                    },
                    resources: ToolResources {
                        timeout_ms: 1000,
                        max_concurrent,
                        rate_limit_per_min: None,
                        estimated_tokens: None,
                    },
                    trust: ToolTrust {
                        publisher: "test".into(),
                        trust_level: TrustLevel::L0Unverified,
                        signature: None,
                    },
                    visibility: ToolVisibility::Read,
                    required_capabilities: vec![],
                    tier: None,
                    errors: vec![],
                },
            }
        }
    }

    impl Tool for StubTool {
        fn definition(&self) -> &ToolDefinition {
            &self.def
        }
        fn call<'a>(&'a self, _args: serde_json::Value, _ctx: &'a CallContext) -> CallFuture<'a> {
            Box::pin(async move { Ok(serde_json::json!({})) })
        }
    }

    #[test]
    fn register_and_get_returns_the_tool() {
        let mut r = Registry::new();
        r.register(Arc::new(StubTool::new("test:a")));
        assert_eq!(r.count(), 1);
        assert!(r.get("test:a").is_some());
        assert!(r.get("test:missing").is_none());
    }

    #[test]
    fn summaries_projects_registered_tools() {
        let mut r = Registry::new();
        r.register(Arc::new(StubTool::new("test:a")));
        r.register(Arc::new(StubTool::new("test:b")));
        assert_eq!(r.count(), 2);
        let sums = r.summaries();
        assert_eq!(sums.len(), 2);
        // Summary order is unspecified, so compare as a set.
        let ids: std::collections::HashSet<_> = sums.iter().map(|s| s.id.clone()).collect();
        assert!(ids.contains("test:a"));
        assert!(ids.contains("test:b"));
    }

    #[test]
    #[should_panic(expected = "duplicate tool registration: test:a")]
    fn duplicate_registration_panics() {
        let mut r = Registry::new();
        r.register(Arc::new(StubTool::new("test:a")));
        r.register(Arc::new(StubTool::new("test:a")));
    }

    #[test]
    fn empty_registry_reports_zero() {
        let r = Registry::new();
        assert_eq!(r.count(), 0);
        assert!(r.summaries().is_empty());
    }

    /// SP-operability-v1 C2. Verify that `register_with_binding` sizes
    /// the per-tool semaphore from `resources.max_concurrent` — including
    /// the `0 → MAX_PERMITS` unlimited-sentinel rule.
    #[test]
    fn semaphore_permits_match_max_concurrent() {
        let mut reg = Registry::new();
        reg.register(Arc::new(StubTool::with_max_concurrent("stub:a", 5)));
        reg.register(Arc::new(StubTool::with_max_concurrent("stub:b", 0)));

        let a = reg.get("stub:a").unwrap();
        assert_eq!(a.semaphore.available_permits(), 5);

        let b = reg.get("stub:b").unwrap();
        assert_eq!(
            b.semaphore.available_permits(),
            tokio::sync::Semaphore::MAX_PERMITS,
            "max_concurrent=0 should map to MAX_PERMITS"
        );
    }
}