agent_sdk_tools/tools.rs
1//! Tool definition and registry.
2//!
3//! Tools allow the LLM to perform actions in the real world. This module provides:
4//!
5//! - [`Tool`] trait - Define custom tools the LLM can call
6//! - [`ToolName`] trait - Marker trait for strongly-typed tool names
7//! - [`PrimitiveToolName`] - Tool names for SDK's built-in tools
8//! - [`DynamicToolName`] - Tool names created at runtime (MCP bridges)
9//! - [`ToolRegistry`] - Collection of available tools
10//! - [`ToolContext`] - Context passed to tool execution
11//! - [`ListenExecuteTool`] - Tools that listen for updates, then execute later
12//!
13//! # Implementing a Tool
14//!
15//! ```
16//! use agent_sdk_tools::tools::{Tool, ToolContext, DynamicToolName};
17//! use agent_sdk_foundation::types::{ToolResult, ToolTier};
18//! use serde_json::{json, Value};
19//! use std::future::Future;
20//!
21//! struct MyTool;
22//!
23//! // No #[async_trait] needed - Rust 1.75+ supports native async traits
24//! impl Tool<()> for MyTool {
25//! type Name = DynamicToolName;
26//!
27//! fn name(&self) -> DynamicToolName { DynamicToolName::new("my_tool") }
28//! // `display_name` defaults to "" — override it for nicer UI.
29//! fn description(&self) -> &'static str { "Does something useful" }
30//! fn input_schema(&self) -> Value { json!({ "type": "object" }) }
31//! fn tier(&self) -> ToolTier { ToolTier::Observe }
32//!
33//! fn execute(
34//! &self,
35//! _ctx: &ToolContext<()>,
36//! _input: Value,
37//! ) -> impl Future<Output = anyhow::Result<ToolResult>> + Send {
38//! async move { Ok(ToolResult::success("Done!")) }
39//! }
40//! }
41//! ```
42
43use crate::authority::{EventAuthority, LocalEventAuthority};
44use crate::seed::{HostDependencies, ToolContextSeed};
45use crate::stores::EventStore;
46use agent_sdk_foundation::events::AgentEvent;
47use agent_sdk_foundation::llm;
48use agent_sdk_foundation::types::{ToolOutcome, ToolResult, ToolTier};
49use anyhow::Result;
50use async_trait::async_trait;
51use futures::Stream;
52use serde::{Deserialize, Serialize, de::DeserializeOwned};
53use serde_json::Value;
54use std::collections::HashMap;
55use std::future::Future;
56use std::marker::PhantomData;
57use std::pin::Pin;
58use std::sync::Arc;
59use time::OffsetDateTime;
60use tokio_util::sync::CancellationToken;
61
62// ============================================================================
63// Tool Name Types
64// ============================================================================
65
66/// Marker trait for tool names.
67///
68/// Tool names must be serializable (for storage/logging) and deserializable
69/// (for parsing from LLM responses). The string representation is derived
70/// from serde serialization.
71///
72/// # Example
73///
74/// ```ignore
75/// #[derive(Serialize, Deserialize)]
76/// #[serde(rename_all = "snake_case")]
77/// pub enum MyToolName {
78/// Read,
79/// Write,
80/// }
81///
82/// impl ToolName for MyToolName {}
83/// ```
84pub trait ToolName: Send + Sync + Serialize + DeserializeOwned + 'static {}
85
86/// Helper to get string representation of a tool name via serde.
87///
88/// Returns `"<unknown_tool>"` if serialization fails (should never happen
89/// with properly implemented `ToolName` types that use `#[derive(Serialize)]`).
90#[must_use]
91pub fn tool_name_to_string<N: ToolName>(name: &N) -> String {
92 serde_json::to_string(name)
93 .unwrap_or_else(|_| "\"<unknown_tool>\"".to_string())
94 .trim_matches('"')
95 .to_string()
96}
97
98/// Parse a tool name from string via serde.
99///
100/// The input is encoded as a JSON string with `serde_json::to_string` (not
101/// interpolated with `format!`) so names containing quotes or backslashes —
102/// possible for [`DynamicToolName`]s bridged from remote MCP servers — are
103/// escaped correctly and round-trip with [`tool_name_to_string`].
104///
105/// # Errors
106/// Returns error if the string doesn't match a valid tool name.
107pub fn tool_name_from_str<N: ToolName>(s: &str) -> Result<N, serde_json::Error> {
108 let json = serde_json::to_string(s)?;
109 serde_json::from_str(&json)
110}
111
112/// Tool names for SDK's built-in primitive tools.
113#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
114#[serde(rename_all = "snake_case")]
115pub enum PrimitiveToolName {
116 Read,
117 Write,
118 Edit,
119 MultiEdit,
120 Bash,
121 Glob,
122 Grep,
123 NotebookRead,
124 NotebookEdit,
125 TodoRead,
126 TodoWrite,
127 AskUser,
128 LinkFetch,
129 WebSearch,
130}
131
132impl ToolName for PrimitiveToolName {}
133
134/// Dynamic tool name for runtime-created tools (MCP bridges, subagents).
135#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
136#[serde(transparent)]
137pub struct DynamicToolName(String);
138
139impl DynamicToolName {
140 #[must_use]
141 pub fn new(name: impl Into<String>) -> Self {
142 Self(name.into())
143 }
144
145 #[must_use]
146 pub fn as_str(&self) -> &str {
147 &self.0
148 }
149}
150
151impl ToolName for DynamicToolName {}
152
153// ============================================================================
154// Progress Stage Types (for AsyncTool)
155// ============================================================================
156
157/// Marker trait for tool progress stages (type-safe, like [`ToolName`]).
158///
159/// Progress stages are used by async tools to indicate the current phase
160/// of a long-running operation. They must be serializable for event streaming.
161///
162/// # Example
163///
164/// ```ignore
165/// #[derive(Clone, Debug, Serialize, Deserialize)]
166/// #[serde(rename_all = "snake_case")]
167/// pub enum PixTransferStage {
168/// Initiated,
169/// Processing,
170/// SentToBank,
171/// }
172///
173/// impl ProgressStage for PixTransferStage {}
174/// ```
175pub trait ProgressStage: Clone + Send + Sync + Serialize + DeserializeOwned + 'static {}
176
177/// Helper to get string representation of a progress stage via serde.
178///
179/// Returns `"<unknown_stage>"` if serialization fails (should never happen with
180/// properly implemented `ProgressStage` types). This mirrors
181/// [`tool_name_to_string`]'s non-panicking fallback so a failing `Serialize`
182/// impl cannot panic the turn loop on the async-tool progress hot path.
183#[must_use]
184pub fn stage_to_string<S: ProgressStage>(stage: &S) -> String {
185 serde_json::to_string(stage)
186 .unwrap_or_else(|_| "\"<unknown_stage>\"".to_string())
187 .trim_matches('"')
188 .to_string()
189}
190
191/// Status update from an async tool operation.
192#[derive(Clone, Debug, Serialize)]
193pub enum ToolStatus<S: ProgressStage> {
194 /// Operation is making progress
195 Progress {
196 stage: S,
197 message: String,
198 data: Option<serde_json::Value>,
199 },
200
201 /// Operation completed successfully
202 Completed(ToolResult),
203
204 /// Operation failed
205 Failed(ToolResult),
206}
207
208/// Type-erased status for the agent loop.
209#[derive(Clone, Debug, Serialize, Deserialize)]
210pub enum ErasedToolStatus {
211 /// Operation is making progress
212 Progress {
213 stage: String,
214 message: String,
215 data: Option<serde_json::Value>,
216 },
217 /// Operation completed successfully
218 Completed(ToolResult),
219 /// Operation failed
220 Failed(ToolResult),
221}
222
223/// Update emitted from a `listen()` stream.
224///
225/// This models workflows where a runtime prepares an operation over time, and
226/// execution happens later using an operation identifier and revision.
227#[derive(Clone, Debug, Serialize, Deserialize)]
228pub enum ListenToolUpdate {
229 /// Preparation is still running and should keep listening.
230 Listening {
231 /// Opaque operation identifier used for later execute/cancel calls.
232 operation_id: String,
233 /// Monotonic revision number for optimistic concurrency.
234 revision: u64,
235 /// Human-readable status message.
236 message: String,
237 /// Optional current snapshot for UI rendering.
238 snapshot: Option<serde_json::Value>,
239 /// Optional expiration timestamp (RFC3339).
240 #[serde(with = "time::serde::rfc3339::option")]
241 expires_at: Option<OffsetDateTime>,
242 },
243
244 /// Preparation is complete and execution can be confirmed.
245 Ready {
246 /// Opaque operation identifier used for later execute/cancel calls.
247 operation_id: String,
248 /// Monotonic revision number for optimistic concurrency.
249 revision: u64,
250 /// Human-readable status message.
251 message: String,
252 /// Snapshot shown in confirmation UI.
253 snapshot: serde_json::Value,
254 /// Optional expiration timestamp (RFC3339).
255 #[serde(with = "time::serde::rfc3339::option")]
256 expires_at: Option<OffsetDateTime>,
257 },
258
259 /// Operation is no longer valid.
260 Invalidated {
261 /// Opaque operation identifier.
262 operation_id: String,
263 /// Human-readable reason.
264 message: String,
265 /// Whether caller may recover by starting a new listen operation.
266 recoverable: bool,
267 },
268}
269
270/// Reason for stopping a listen session.
271#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
272pub enum ListenStopReason {
273 /// User explicitly rejected confirmation.
274 UserRejected,
275 /// Agent policy/hook blocked execution before confirmation.
276 Blocked,
277 /// Consumer disconnected while listen stream was active.
278 StreamDisconnected,
279 /// Listen stream ended unexpectedly before terminal state.
280 StreamEnded,
281}
282
283impl<S: ProgressStage> From<ToolStatus<S>> for ErasedToolStatus {
284 fn from(status: ToolStatus<S>) -> Self {
285 match status {
286 ToolStatus::Progress {
287 stage,
288 message,
289 data,
290 } => Self::Progress {
291 stage: stage_to_string(&stage),
292 message,
293 data,
294 },
295 ToolStatus::Completed(r) => Self::Completed(r),
296 ToolStatus::Failed(r) => Self::Failed(r),
297 }
298 }
299}
300
301/// Context passed to tool execution
302#[derive(Clone)]
303pub struct ToolContext<Ctx> {
304 /// Application-specific context (e.g., `user_id`, db connection)
305 pub app: Ctx,
306 /// Tool-specific metadata
307 pub metadata: HashMap<String, Value>,
308 /// Optional event store for tools to emit turn-scoped events.
309 event_store: Option<Arc<dyn EventStore>>,
310 /// Thread associated with the bound event store.
311 event_thread_id: Option<agent_sdk_foundation::types::ThreadId>,
312 /// Turn associated with the bound event store.
313 event_turn: Option<usize>,
314 /// Optional event authority for wrapping events in envelopes
315 event_authority: Option<Arc<dyn EventAuthority>>,
316 /// Optional cancellation token for propagating cancellation to subtasks
317 cancel_token: Option<CancellationToken>,
318 /// Optional semaphore for limiting concurrent subagent threads.
319 subagent_semaphore: Option<Arc<tokio::sync::Semaphore>>,
320 /// Optional per-tool execution timeout enforced at the SDK boundary.
321 ///
322 /// When set, the agent loop races each tool's `execute()` future
323 /// against this duration. A tool that does not finish within the
324 /// budget is stopped at the boundary and reported with a synthetic
325 /// timeout [`ToolResult`] so the `tool_use` / `tool_result` pair stays
326 /// balanced. Tools that hold OS resources (subprocesses, sockets) must
327 /// observe the [cooperative-cancel contract](Tool#cooperative-cancellation)
328 /// so the timeout actually reclaims them.
329 tool_timeout: Option<std::time::Duration>,
330}
331
332impl<Ctx> ToolContext<Ctx> {
333 #[must_use]
334 pub fn new(app: Ctx) -> Self {
335 Self {
336 app,
337 metadata: HashMap::new(),
338 event_store: None,
339 event_thread_id: None,
340 event_turn: None,
341 event_authority: None,
342 cancel_token: None,
343 subagent_semaphore: None,
344 tool_timeout: None,
345 }
346 }
347
348 /// Reconstruct a `ToolContext` from a durable seed and host-provided
349 /// runtime dependencies.
350 ///
351 /// This is the authoritative reconstruction path. Workers should use
352 /// this (or a host's [`crate::seed::ExecutionContextFactory`]) instead
353 /// of chaining builder methods, so that the context shape is
354 /// deterministic and auditable.
355 ///
356 /// The event authority is constructed internally from
357 /// [`ToolContextSeed::sequence_offset`] to guarantee monotonic
358 /// sequencing — callers cannot accidentally supply a misaligned
359 /// authority.
360 #[must_use]
361 pub fn from_seed(seed: &ToolContextSeed, app: Ctx, deps: HostDependencies) -> Self {
362 let authority: Arc<dyn EventAuthority> =
363 Arc::new(LocalEventAuthority::with_offset(seed.sequence_offset));
364 Self {
365 app,
366 metadata: seed.metadata.clone(),
367 event_store: Some(deps.event_store),
368 event_thread_id: Some(seed.thread_id.clone()),
369 event_turn: Some(seed.turn),
370 event_authority: Some(authority),
371 cancel_token: Some(deps.cancel_token),
372 subagent_semaphore: deps.subagent_semaphore,
373 tool_timeout: None,
374 }
375 }
376
377 #[must_use]
378 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
379 self.metadata.insert(key.into(), value);
380 self
381 }
382
383 /// Bind the tool context to the event store for a specific thread/turn.
384 #[must_use]
385 pub fn with_event_store(
386 mut self,
387 store: Arc<dyn EventStore>,
388 thread_id: agent_sdk_foundation::types::ThreadId,
389 turn: usize,
390 authority: Arc<dyn EventAuthority>,
391 ) -> Self {
392 self.event_store = Some(store);
393 self.event_thread_id = Some(thread_id);
394 self.event_turn = Some(turn);
395 self.event_authority = Some(authority);
396 self
397 }
398
399 /// Emit an event through the configured event store (if set).
400 ///
401 /// The event is wrapped in an [`agent_sdk_foundation::AgentEventEnvelope`] with a unique ID,
402 /// sequence number, and timestamp before publishing.
403 ///
404 /// # Errors
405 /// Returns an error if the configured event store cannot persist the event.
406 pub async fn emit_event(&self, event: AgentEvent) -> Result<()>
407 where
408 Ctx: Sync,
409 {
410 let Some((store, authority, thread_id, turn)) = self
411 .event_store
412 .as_ref()
413 .zip(self.event_authority.as_ref())
414 .zip(self.event_thread_id.as_ref())
415 .zip(self.event_turn)
416 .map(|(((store, authority), thread_id), turn)| (store, authority, thread_id, turn))
417 else {
418 // Surface the misconfiguration instead of silently dropping the
419 // event: a tool written for the durable host but run under a
420 // hand-built `ToolContext::new()` would otherwise lose every
421 // emitted event with no trace, undermining the audit trail.
422 let kind = serde_json::to_value(&event)
423 .ok()
424 .and_then(|v| {
425 v.get("type")
426 .and_then(|t| t.as_str().map(ToOwned::to_owned))
427 })
428 .unwrap_or_else(|| "unknown".to_string());
429 log::warn!(
430 "ToolContext::emit_event called on an unbound context; dropping {kind} event \
431 (no event store/authority/thread/turn bound)"
432 );
433 return Ok(());
434 };
435 let envelope = authority.wrap(event);
436 store.append(thread_id, turn, envelope).await
437 }
438
439 /// Get a clone of the event authority (if set).
440 ///
441 /// This is useful for tools that spawn subprocesses (like subagents)
442 /// and need to wrap events with the same sequencing authority as the
443 /// parent's turn log.
444 #[must_use]
445 pub fn event_authority(&self) -> Option<Arc<dyn EventAuthority>> {
446 self.event_authority.clone()
447 }
448
449 /// Set the cancellation token for propagating cancellation to subtasks.
450 #[must_use]
451 pub fn with_cancel_token(mut self, token: CancellationToken) -> Self {
452 self.cancel_token = Some(token);
453 self
454 }
455
456 /// Get the cancellation token (if set).
457 ///
458 /// Used by tools that spawn long-running subtasks (like subagents)
459 /// to propagate cancellation from the parent.
460 #[must_use]
461 pub fn cancel_token(&self) -> Option<CancellationToken> {
462 self.cancel_token.clone()
463 }
464
465 /// Set the per-tool execution timeout enforced at the SDK boundary.
466 ///
467 /// The agent loop populates this from `AgentConfig::tool_timeout_ms`;
468 /// callers can also set it directly when constructing a context.
469 #[must_use]
470 pub const fn with_tool_timeout(mut self, timeout: std::time::Duration) -> Self {
471 self.tool_timeout = Some(timeout);
472 self
473 }
474
475 /// Get the per-tool execution timeout (if set).
476 ///
477 /// Read by the agent loop's SDK-boundary execution race; tools do not
478 /// normally need to consult this themselves.
479 #[must_use]
480 pub const fn tool_timeout(&self) -> Option<std::time::Duration> {
481 self.tool_timeout
482 }
483
484 /// Set a shared semaphore for limiting concurrent subagent threads.
485 #[must_use]
486 pub fn with_subagent_semaphore(mut self, semaphore: Arc<tokio::sync::Semaphore>) -> Self {
487 self.subagent_semaphore = Some(semaphore);
488 self
489 }
490
491 /// Get the subagent thread-limiting semaphore (if set).
492 #[must_use]
493 pub fn subagent_semaphore(&self) -> Option<Arc<tokio::sync::Semaphore>> {
494 self.subagent_semaphore.clone()
495 }
496}
497
498// ============================================================================
499// Tool Trait
500// ============================================================================
501
502/// Definition of a tool that can be called by the agent.
503///
504/// Tools have a strongly-typed `Name` associated type that determines
505/// how the tool name is serialized for LLM communication.
506///
507/// # Native Async Support
508///
509/// This trait uses Rust's native async functions in traits (stabilized in Rust 1.75).
510/// You do NOT need the `async_trait` crate to implement this trait.
511///
512/// # Cooperative cancellation
513///
514/// The agent loop races every tool's `execute()` future against the run's
515/// [`ToolContext::cancel_token`] and, when configured, against
516/// [`ToolContext::tool_timeout`]. If either fires the SDK drops the
517/// in-flight `execute()` future and synthesises a balanced `tool_result`
518/// (`"Cancelled by user"` or a timeout message). Dropping a future runs
519/// its destructors but cannot, on its own, reclaim OS resources a tool
520/// has handed to the kernel.
521///
522/// **Subprocess contract:** a tool that spawns a child process MUST make
523/// the process die when its `execute()` future is dropped. The two
524/// supported ways to satisfy this are:
525///
526/// * Build the command with `tokio::process::Command::kill_on_drop(true)`,
527/// so the child is killed when the `Child` handle is dropped together
528/// with the cancelled future (this is what the SDK's MCP stdio transport
529/// does), or
530/// * Observe [`ToolContext::cancel_token`] directly and `kill()` the child
531/// when it fires.
532///
533/// A tool that holds a subprocess open without either of these will leak
534/// the process when cancelled or timed out — the synthesised `tool_result`
535/// keeps the conversation balanced, but the orphaned OS process is the
536/// tool author's bug, not the SDK's.
537pub trait Tool<Ctx>: Send + Sync {
538 /// The type of name for this tool.
539 type Name: ToolName;
540
541 /// Returns the tool's strongly-typed name.
542 fn name(&self) -> Self::Name;
543
544 /// Human-readable display name for UI (e.g., "Read File" vs "read").
545 ///
546 /// Defaults to the empty string. Override for better UX.
547 fn display_name(&self) -> &'static str {
548 ""
549 }
550
551 /// Human-readable description of what the tool does.
552 fn description(&self) -> &'static str;
553
554 /// JSON schema for the tool's input parameters.
555 fn input_schema(&self) -> Value;
556
557 /// Permission tier for this tool.
558 ///
559 /// Defaults to [`ToolTier::Confirm`] (fail-closed): a tool author who
560 /// forgets to declare a tier gets confirmation gating, not silent
561 /// auto-execution. Read-only tools should explicitly opt in to
562 /// [`ToolTier::Observe`].
563 fn tier(&self) -> ToolTier {
564 ToolTier::Confirm
565 }
566
567 /// Execute the tool with the given input.
568 ///
569 /// # Errors
570 /// Returns an error if tool execution fails.
571 fn execute(
572 &self,
573 ctx: &ToolContext<Ctx>,
574 input: Value,
575 ) -> impl Future<Output = Result<ToolResult>> + Send;
576}
577
578// ============================================================================
579// TypedTool Trait (typed input + runtime validation / self-correction)
580// ============================================================================
581
582/// A tool whose model-emitted arguments are validated against a typed,
583/// deserializable [`Input`](TypedTool::Input) **before** [`execute`](TypedTool::execute)
584/// runs.
585///
586/// Today a raw [`serde_json::Value`] is handed straight to [`Tool::execute`],
587/// so a malformed tool call reaches tool code unvalidated. `TypedTool` closes
588/// that gap: you declare a `Serialize` / `Deserialize` argument struct as
589/// [`Input`](TypedTool::Input), and the runtime deserializes the model's args
590/// into it at the dispatch boundary. On a deserialization/validation failure
591/// the runtime synthesises a structured error [`ToolResult`] (carrying the
592/// serde error message) so the model can self-correct on its next turn —
593/// `execute` is **never** called with invalid arguments.
594///
595/// # Relationship to [`Tool`]
596///
597/// `TypedTool` is the typed, opt-in *sugar* layer; [`Tool`] remains the
598/// untyped baseline. A [`TypedTool`] becomes a full [`Tool`] through
599/// [`TypedToolAdapter`] (mirroring how [`SimpleTool`] becomes a [`Tool`] via
600/// [`SimpleToolAdapter`]). Register one with
601/// [`ToolRegistry::register_typed`], which wraps it in the adapter for you;
602/// the adapter performs the deserialize-then-dispatch (or
603/// deserialize-then-synthesise-error) described above.
604///
605/// # Back-compat / migration
606///
607/// Existing [`Tool`] impls (and [`SimpleTool`] / [`DynamicToolName`] tools)
608/// keep compiling and running unchanged — they stay on the `Value`-in
609/// baseline, which is the identity passthrough (a `Value` always
610/// "deserializes" into a `Value`). Migrate a tool to typed args by moving its
611/// `impl Tool<Ctx>` to `impl TypedTool<Ctx>`, setting `type Input = MyArgs`,
612/// and changing `execute`'s signature from `input: Value` to `input: MyArgs`.
613/// The hand-written [`input_schema`](TypedTool::input_schema) JSON stays
614/// user-declared; this trait does **not** auto-derive a schema from `Input`.
615///
616/// # Example
617///
618/// ```
619/// use agent_sdk_tools::tools::{TypedTool, ToolContext};
620/// use agent_sdk_foundation::types::ToolResult;
621/// use serde::{Deserialize, Serialize};
622/// use serde_json::{json, Value};
623/// use std::future::Future;
624///
625/// #[derive(Debug, Serialize, Deserialize)]
626/// struct WeatherArgs {
627/// city: String,
628/// }
629///
630/// struct WeatherTool;
631///
632/// impl TypedTool<()> for WeatherTool {
633/// type Input = WeatherArgs;
634///
635/// fn name(&self) -> &'static str { "get_weather" }
636/// fn description(&self) -> &'static str { "Get current weather for a city" }
637/// fn input_schema(&self) -> Value {
638/// json!({
639/// "type": "object",
640/// "properties": { "city": { "type": "string" } },
641/// "required": ["city"]
642/// })
643/// }
644///
645/// fn execute(
646/// &self,
647/// _ctx: &ToolContext<()>,
648/// input: WeatherArgs,
649/// ) -> impl Future<Output = anyhow::Result<ToolResult>> + Send {
650/// async move { Ok(ToolResult::success(format!("Weather in {}: Sunny", input.city))) }
651/// }
652/// }
653/// ```
654///
655/// Like [`SimpleTool`], a `TypedTool` has a single fixed `&'static str`
656/// [`name`](TypedTool::name) (mapping to [`DynamicToolName`] via
657/// [`TypedToolAdapter`]). Reach for a hand-written [`Tool`] with a
658/// strongly-typed [`ToolName`] when the name must be computed at runtime or
659/// constrained to an enum.
660pub trait TypedTool<Ctx>: Send + Sync {
661 /// The typed input the model's arguments are deserialized into before
662 /// [`execute`](TypedTool::execute) runs.
663 ///
664 /// Must be [`DeserializeOwned`] (to parse model args), [`Serialize`] (so
665 /// the typed value round-trips for logging/storage), and `Send + 'static`
666 /// (to cross the async dispatch boundary).
667 type Input: DeserializeOwned + Serialize + Send + 'static;
668
669 /// The tool's name as sent to (and parsed from) the model.
670 fn name(&self) -> &'static str;
671
672 /// Human-readable display name for UI. Defaults to an empty string.
673 fn display_name(&self) -> &'static str {
674 ""
675 }
676
677 /// Human-readable description of what the tool does.
678 fn description(&self) -> &'static str;
679
680 /// User-declared JSON schema for the tool's input parameters.
681 ///
682 /// This stays hand-written JSON — it is **not** auto-derived from
683 /// [`Input`](TypedTool::Input). Keeping the schema explicit lets the
684 /// declared provider-facing contract diverge from the Rust type when that
685 /// is useful (descriptions, examples, provider-specific keywords).
686 fn input_schema(&self) -> Value;
687
688 /// Permission tier for this tool. Defaults to [`ToolTier::Confirm`]
689 /// (fail-closed); read-only tools should opt in to [`ToolTier::Observe`].
690 fn tier(&self) -> ToolTier {
691 ToolTier::Confirm
692 }
693
694 /// Execute the tool with the already-validated, typed input.
695 ///
696 /// The runtime guarantees `input` deserialized cleanly from the model's
697 /// arguments; a malformed call is turned into a structured error
698 /// [`ToolResult`] before this method is reached, so implementations never
699 /// see invalid arguments.
700 ///
701 /// # Errors
702 /// Returns an error if tool execution fails.
703 fn execute(
704 &self,
705 ctx: &ToolContext<Ctx>,
706 input: Self::Input,
707 ) -> impl Future<Output = Result<ToolResult>> + Send;
708}
709
710/// Synthesise the structured validation-error [`ToolResult`] returned to the
711/// model when its arguments fail to deserialize into a [`TypedTool::Input`].
712///
713/// Factored out (and `pub`) so the exact self-correction wording is
714/// consistent with [`TypedToolAdapter`] and is directly unit-testable. The
715/// error is an *error* [`ToolResult`] (not a thrown `anyhow::Error`): it flows
716/// through the normal balanced `tool_use` / `tool_result` path so history
717/// stays balanced and the model gets a concrete, machine-actionable hint on
718/// its next turn.
719#[must_use]
720pub fn invalid_tool_input_result(tool_name: &str, error: &serde_json::Error) -> ToolResult {
721 ToolResult::error(format!(
722 "Invalid arguments for tool `{tool_name}`: {error}. \
723 The arguments did not match the tool's input schema — \
724 re-read the schema and call the tool again with corrected arguments."
725 ))
726}
727
728/// Deserialize raw model args into a typed `Input`, or synthesise the
729/// structured validation-error result.
730///
731/// Returns `Ok(typed)` for the happy path and `Err(result)` carrying the
732/// balanced error [`ToolResult`] for the self-correction path.
733/// [`TypedToolAdapter`] uses this to ensure [`TypedTool::execute`] is never
734/// reached with invalid arguments.
735///
736/// # Errors
737/// Returns the synthesised error [`ToolResult`] when `raw` does not
738/// deserialize into `Input`.
739pub fn validate_tool_input<Input>(tool_name: &str, raw: Value) -> Result<Input, ToolResult>
740where
741 Input: DeserializeOwned,
742{
743 serde_json::from_value(raw).map_err(|error| invalid_tool_input_result(tool_name, &error))
744}
745
/// Adapter that makes any [`TypedTool`] usable as a full [`Tool`].
///
/// The wrapped tool gets `Name = DynamicToolName`; the model's `Value`
/// arguments are deserialized into [`TypedTool::Input`] before dispatch, and
/// a structured validation-error [`ToolResult`] is synthesised when that
/// deserialization fails.
///
/// This type seldom needs to be named directly — registering a [`TypedTool`]
/// with [`ToolRegistry::register_typed`] wraps it for you. An adapter (as
/// opposed to a blanket `impl Tool for T: TypedTool`) is required for
/// coherence: a blanket impl would conflict with the existing
/// [`SimpleToolAdapter`] impl, since the compiler cannot rule out a
/// downstream `TypedTool` impl for `SimpleToolAdapter`.
///
/// The adapter is also the place where the typed `Input` is threaded through
/// the erased-tool machinery without leaking the generic into trait objects:
/// the registry's [`ErasedTool`] wrapper only ever sees `Value`, while the
/// concrete `Input` type (and the deserialize step) live here, inside the
/// adapter's concrete `T`.
pub struct TypedToolAdapter<T> {
    inner: T,
}

impl<T> TypedToolAdapter<T> {
    /// Wrap a [`TypedTool`] so it can stand in wherever a [`Tool`] is expected.
    pub const fn new(tool: T) -> Self {
        TypedToolAdapter { inner: tool }
    }

    /// Consume the adapter and hand back the wrapped [`TypedTool`].
    pub fn into_inner(self) -> T {
        let Self { inner } = self;
        inner
    }
}
779
780impl<Ctx, T> Tool<Ctx> for TypedToolAdapter<T>
781where
782 T: TypedTool<Ctx>,
783 Ctx: Send + Sync,
784{
785 type Name = DynamicToolName;
786
787 fn name(&self) -> DynamicToolName {
788 DynamicToolName::new(TypedTool::name(&self.inner))
789 }
790
791 fn display_name(&self) -> &'static str {
792 TypedTool::display_name(&self.inner)
793 }
794
795 fn description(&self) -> &'static str {
796 TypedTool::description(&self.inner)
797 }
798
799 fn input_schema(&self) -> Value {
800 TypedTool::input_schema(&self.inner)
801 }
802
803 fn tier(&self) -> ToolTier {
804 TypedTool::tier(&self.inner)
805 }
806
807 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
808 match validate_tool_input::<<T as TypedTool<Ctx>>::Input>(
809 TypedTool::name(&self.inner),
810 input,
811 ) {
812 Ok(typed) => TypedTool::execute(&self.inner, ctx, typed).await,
813 // A validation failure is returned as an error `ToolResult`,
814 // never `?`-bailed: it must reach the model as a balanced
815 // `tool_result` for self-correction. `execute` is not called.
816 Err(result) => Ok(result),
817 }
818 }
819}
820
821// ============================================================================
822// ToolLogic Trait (execute-only companion for the derive macros)
823// ============================================================================
824
825/// The `execute`-only half of a tool, used as the target of the
826/// `#[derive(Tool)]` / `#[derive(TypedTool)]` ergonomics macros.
827///
828/// The derives generate everything *except* the behaviour — `name`,
829/// `description`, `input_schema`, `tier` come from `#[tool(...)]` attributes —
830/// and delegate execution to this trait. You implement `ToolLogic` to supply
831/// the one thing a macro cannot: the `execute` body.
832///
833/// It is deliberately a **trait** (not an inherent method): a trait-method
834/// `async fn` that performs no `await` is fine, whereas an inherent one trips
835/// `clippy::unused_async`. Writing the body here keeps trivial, fully
836/// synchronous tools lint-clean without an `#[allow]`.
837///
838/// You rarely name this trait in prose — the derive docs show it in context —
839/// but the shape is:
840///
841/// ```
842/// use agent_sdk_tools::tools::{ToolLogic, ToolContext};
843/// use agent_sdk_foundation::types::ToolResult;
844/// use serde_json::Value;
845///
846/// struct MyTool;
847///
848/// impl ToolLogic<()> for MyTool {
849/// type Input = Value; // typed tools set this to their `Input` struct
850///
851/// async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> anyhow::Result<ToolResult> {
852/// Ok(ToolResult::success(format!("got {input}")))
853/// }
854/// }
855/// ```
856pub trait ToolLogic<Ctx>: Send + Sync {
857 /// The input the tool's `execute` receives. For `#[derive(Tool)]` this is
858 /// [`serde_json::Value`]; for `#[derive(TypedTool)]` it is the typed
859 /// `Input` (validated before `execute` runs).
860 type Input;
861
862 /// The tool's behaviour. Receives the (already-validated, for typed tools)
863 /// input.
864 ///
865 /// # Errors
866 /// Returns an error if tool execution fails.
867 fn execute(
868 &self,
869 ctx: &ToolContext<Ctx>,
870 input: Self::Input,
871 ) -> impl Future<Output = Result<ToolResult>> + Send;
872}
873
874// ============================================================================
875// SimpleTool Trait
876// ============================================================================
877
878/// An ergonomic [`Tool`] whose name is a plain string.
879///
880/// Most custom tools don't need a strongly-typed [`ToolName`] enum — they have
881/// a single, fixed name. `SimpleTool` lets you write a tool by returning a
882/// `&str` from [`name`](SimpleTool::name) instead of defining a `ToolName`
883/// type and an associated [`Tool::Name`].
884///
885/// Any `SimpleTool` is automatically a [`Tool`] (via a blanket impl) with
886/// `Name = DynamicToolName`, so it can be registered and used exactly like a
887/// hand-written `Tool`.
888///
889/// # Example
890///
891/// ```
892/// use agent_sdk_tools::tools::{SimpleTool, ToolContext};
893/// use agent_sdk_foundation::types::ToolResult;
894/// use serde_json::{json, Value};
895/// use std::future::Future;
896///
897/// struct WeatherTool;
898///
899/// impl SimpleTool<()> for WeatherTool {
900/// fn name(&self) -> &'static str { "get_weather" }
901/// fn description(&self) -> &'static str { "Get current weather for a city" }
902/// fn input_schema(&self) -> Value {
903/// json!({ "type": "object", "properties": { "city": { "type": "string" } } })
904/// }
905///
906/// fn execute(
907/// &self,
908/// _ctx: &ToolContext<()>,
909/// input: Value,
910/// ) -> impl Future<Output = anyhow::Result<ToolResult>> + Send {
911/// async move {
912/// let city = input["city"].as_str().unwrap_or("Unknown");
913/// Ok(ToolResult::success(format!("Weather in {city}: Sunny")))
914/// }
915/// }
916/// }
917/// ```
918pub trait SimpleTool<Ctx>: Send + Sync {
919 /// The tool's name as sent to (and parsed from) the LLM.
920 ///
921 /// Returns `&'static str` because a simple tool has one fixed name; reach
922 /// for the full [`Tool`] trait with a [`DynamicToolName`] when the name is
923 /// computed at runtime.
924 fn name(&self) -> &'static str;
925
926 /// Human-readable display name for UI.
927 ///
928 /// Defaults to an empty string; override for a friendlier label.
929 fn display_name(&self) -> &'static str {
930 ""
931 }
932
933 /// Human-readable description of what the tool does.
934 fn description(&self) -> &'static str;
935
936 /// JSON schema for the tool's input parameters.
937 fn input_schema(&self) -> Value;
938
939 /// Permission tier for this tool. Defaults to [`ToolTier::Confirm`]
940 /// (fail-closed); read-only tools should opt in to [`ToolTier::Observe`].
941 fn tier(&self) -> ToolTier {
942 ToolTier::Confirm
943 }
944
945 /// Execute the tool with the given input.
946 ///
947 /// # Errors
948 /// Returns an error if tool execution fails.
949 fn execute(
950 &self,
951 ctx: &ToolContext<Ctx>,
952 input: Value,
953 ) -> impl Future<Output = Result<ToolResult>> + Send;
954}
955
/// Bridges a [`SimpleTool`] to the full [`Tool`] trait, fixing
/// `Name = DynamicToolName`.
///
/// Most callers never spell this type out:
/// [`ToolRegistry::register_simple`] performs the wrapping for you. Build the
/// adapter by hand only when an actual `Tool` value is required — for
/// example, to hand a simple tool to code that is generic over [`Tool`].
pub struct SimpleToolAdapter<T> {
    /// The wrapped simple tool.
    inner: T,
}

impl<T> SimpleToolAdapter<T> {
    /// Wrap a [`SimpleTool`] so it is usable wherever a [`Tool`] is expected.
    pub const fn new(tool: T) -> Self {
        SimpleToolAdapter { inner: tool }
    }

    /// Consume the adapter and hand back the wrapped [`SimpleTool`].
    pub fn into_inner(self) -> T {
        let Self { inner } = self;
        inner
    }
}
978
979impl<Ctx, T> Tool<Ctx> for SimpleToolAdapter<T>
980where
981 T: SimpleTool<Ctx>,
982{
983 type Name = DynamicToolName;
984
985 fn name(&self) -> DynamicToolName {
986 DynamicToolName::new(SimpleTool::name(&self.inner))
987 }
988
989 fn display_name(&self) -> &'static str {
990 SimpleTool::display_name(&self.inner)
991 }
992
993 fn description(&self) -> &'static str {
994 SimpleTool::description(&self.inner)
995 }
996
997 fn input_schema(&self) -> Value {
998 SimpleTool::input_schema(&self.inner)
999 }
1000
1001 fn tier(&self) -> ToolTier {
1002 SimpleTool::tier(&self.inner)
1003 }
1004
1005 fn execute(
1006 &self,
1007 ctx: &ToolContext<Ctx>,
1008 input: Value,
1009 ) -> impl Future<Output = Result<ToolResult>> + Send {
1010 SimpleTool::execute(&self.inner, ctx, input)
1011 }
1012}
1013
1014// ============================================================================
1015// AsyncTool Trait
1016// ============================================================================
1017
1018/// A tool that performs long-running async operations.
1019///
1020/// `AsyncTool`s have two phases:
1021/// 1. `execute()` - Start the operation (lightweight, returns quickly)
1022/// 2. `check_status()` - Stream progress until completion
1023///
1024/// The actual work should happen externally (background task, external service)
1025/// and persist results to a durable store. The tool is just an orchestrator.
1026///
1027/// # Example
1028///
1029/// ```ignore
1030/// impl AsyncTool<MyCtx> for ExecutePixTransferTool {
1031/// type Name = PixToolName;
1032/// type Stage = PixTransferStage;
1033///
1034/// async fn execute(&self, ctx: &ToolContext<MyCtx>, input: Value) -> Result<ToolOutcome> {
1035/// let params = parse_input(&input)?;
1036/// let operation_id = ctx.app.pix_service.start_transfer(params).await?;
1037/// Ok(ToolOutcome::in_progress(
1038/// operation_id,
1039/// format!("PIX transfer of {} initiated", params.amount),
1040/// ))
1041/// }
1042///
1043/// fn check_status(&self, ctx: &ToolContext<MyCtx>, operation_id: &str)
1044/// -> impl Stream<Item = ToolStatus<PixTransferStage>> + Send
1045/// {
1046/// async_stream::stream! {
1047/// loop {
1048/// let status = ctx.app.pix_service.get_status(operation_id).await;
1049/// match status {
1050/// PixStatus::Success { id } => {
1051/// yield ToolStatus::Completed(ToolResult::success(id));
1052/// break;
1053/// }
1054/// _ => yield ToolStatus::Progress { ... };
1055/// }
1056/// tokio::time::sleep(Duration::from_millis(500)).await;
1057/// }
1058/// }
1059/// }
1060/// }
1061/// ```
1062pub trait AsyncTool<Ctx>: Send + Sync {
1063 /// The type of name for this tool.
1064 type Name: ToolName;
1065 /// The type of progress stages for this tool.
1066 type Stage: ProgressStage;
1067
1068 /// Returns the tool's strongly-typed name.
1069 fn name(&self) -> Self::Name;
1070
1071 /// Human-readable display name for UI. Defaults to the empty string.
1072 fn display_name(&self) -> &'static str {
1073 ""
1074 }
1075
1076 /// Human-readable description of what the tool does.
1077 fn description(&self) -> &'static str;
1078
1079 /// JSON schema for the tool's input parameters.
1080 fn input_schema(&self) -> Value;
1081
1082 /// Permission tier for this tool. Defaults to [`ToolTier::Confirm`]
1083 /// (fail-closed); read-only tools should opt in to [`ToolTier::Observe`].
1084 fn tier(&self) -> ToolTier {
1085 ToolTier::Confirm
1086 }
1087
1088 /// Execute the tool. Returns immediately with one of:
1089 /// - Success/Failed: Operation completed synchronously
1090 /// - `InProgress`: Operation started, use `check_status()` to stream updates
1091 ///
1092 /// # Errors
1093 /// Returns an error if tool execution fails.
1094 fn execute(
1095 &self,
1096 ctx: &ToolContext<Ctx>,
1097 input: Value,
1098 ) -> impl Future<Output = Result<ToolOutcome>> + Send;
1099
1100 /// Stream status updates for an in-progress operation.
1101 /// Must yield until Completed or Failed.
1102 fn check_status(
1103 &self,
1104 ctx: &ToolContext<Ctx>,
1105 operation_id: &str,
1106 ) -> impl Stream<Item = ToolStatus<Self::Stage>> + Send;
1107}
1108
1109// ============================================================================
1110// ListenExecuteTool Trait
1111// ============================================================================
1112
1113/// A tool whose runtime has two phases:
1114/// 1. `listen()` - starts preparation and streams updates
1115/// 2. `execute()` - performs final execution after confirmation
1116///
1117/// This abstraction is useful when runtime state can expire or evolve before
1118/// execution (quotes, challenge windows, leases, approvals).
1119///
1120/// Ordering note: the agent loop consumes `listen()` updates before
1121/// `AgentHooks::pre_tool_use()` runs. Hooks can therefore block `execute()`, but
1122/// any side effects done during `listen()` have already happened.
1123pub trait ListenExecuteTool<Ctx>: Send + Sync {
1124 /// The type of name for this tool.
1125 type Name: ToolName;
1126
1127 /// Returns the tool's strongly-typed name.
1128 fn name(&self) -> Self::Name;
1129
1130 /// Human-readable display name for UI. Defaults to the empty string.
1131 fn display_name(&self) -> &'static str {
1132 ""
1133 }
1134
1135 /// Human-readable description of what the tool does.
1136 fn description(&self) -> &'static str;
1137
1138 /// JSON schema for the tool's input parameters.
1139 fn input_schema(&self) -> Value;
1140
1141 /// Permission tier for this tool.
1142 fn tier(&self) -> ToolTier {
1143 ToolTier::Confirm
1144 }
1145
1146 /// Start and stream runtime preparation updates.
1147 fn listen(
1148 &self,
1149 ctx: &ToolContext<Ctx>,
1150 input: Value,
1151 ) -> impl Stream<Item = ListenToolUpdate> + Send;
1152
1153 /// Execute using operation ID and optimistic concurrency revision.
1154 ///
1155 /// # Errors
1156 /// Returns an error if execution fails or revision is stale.
1157 fn execute(
1158 &self,
1159 ctx: &ToolContext<Ctx>,
1160 operation_id: &str,
1161 expected_revision: u64,
1162 ) -> impl Future<Output = Result<ToolResult>> + Send;
1163
1164 /// Stop a listen operation (best effort).
1165 ///
1166 /// # Errors
1167 /// Returns an error if cancellation fails.
1168 fn cancel(
1169 &self,
1170 _ctx: &ToolContext<Ctx>,
1171 _operation_id: &str,
1172 _reason: ListenStopReason,
1173 ) -> impl Future<Output = Result<()>> + Send {
1174 async { Ok(()) }
1175 }
1176}
1177
1178// ============================================================================
1179// Type-Erased Tool (for Registry)
1180// ============================================================================
1181
1182/// Type-erased tool trait for registry storage.
1183///
1184/// This allows tools with different `Name` associated types to be stored
1185/// in the same registry by erasing the type information.
1186///
1187/// # Example
1188///
1189/// ```ignore
1190/// for tool in registry.all() {
1191/// println!("Tool: {} - {}", tool.name_str(), tool.description());
1192/// }
1193/// ```
1194#[async_trait]
1195pub trait ErasedTool<Ctx>: Send + Sync {
1196 /// Get the tool name as a string.
1197 fn name_str(&self) -> &str;
1198 /// Get a human-friendly display name for the tool.
1199 fn display_name(&self) -> &'static str;
1200 /// Get the tool description.
1201 fn description(&self) -> &'static str;
1202 /// Get the JSON schema for tool inputs.
1203 fn input_schema(&self) -> Value;
1204 /// Get the tool's permission tier.
1205 fn tier(&self) -> ToolTier;
1206 /// Execute the tool with the given input.
1207 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult>;
1208}
1209
1210/// Wrapper that erases the Name associated type from a Tool.
1211struct ToolWrapper<T, Ctx>
1212where
1213 T: Tool<Ctx>,
1214{
1215 inner: T,
1216 name_cache: String,
1217 _marker: PhantomData<Ctx>,
1218}
1219
1220impl<T, Ctx> ToolWrapper<T, Ctx>
1221where
1222 T: Tool<Ctx>,
1223{
1224 fn new(tool: T) -> Self {
1225 let name_cache = tool_name_to_string(&tool.name());
1226 Self {
1227 inner: tool,
1228 name_cache,
1229 _marker: PhantomData,
1230 }
1231 }
1232}
1233
1234#[async_trait]
1235impl<T, Ctx> ErasedTool<Ctx> for ToolWrapper<T, Ctx>
1236where
1237 T: Tool<Ctx> + 'static,
1238 Ctx: Send + Sync + 'static,
1239{
1240 fn name_str(&self) -> &str {
1241 &self.name_cache
1242 }
1243
1244 fn display_name(&self) -> &'static str {
1245 self.inner.display_name()
1246 }
1247
1248 fn description(&self) -> &'static str {
1249 self.inner.description()
1250 }
1251
1252 fn input_schema(&self) -> Value {
1253 self.inner.input_schema()
1254 }
1255
1256 fn tier(&self) -> ToolTier {
1257 self.inner.tier()
1258 }
1259
1260 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
1261 self.inner.execute(ctx, input).await
1262 }
1263}
1264
1265// ============================================================================
1266// Type-Erased AsyncTool (for Registry)
1267// ============================================================================
1268
1269/// Type-erased async tool trait for registry storage.
1270///
1271/// This allows async tools with different `Name` and `Stage` associated types
1272/// to be stored in the same registry by erasing the type information.
1273#[async_trait]
1274pub trait ErasedAsyncTool<Ctx>: Send + Sync {
1275 /// Get the tool name as a string.
1276 fn name_str(&self) -> &str;
1277 /// Get a human-friendly display name for the tool.
1278 fn display_name(&self) -> &'static str;
1279 /// Get the tool description.
1280 fn description(&self) -> &'static str;
1281 /// Get the JSON schema for tool inputs.
1282 fn input_schema(&self) -> Value;
1283 /// Get the tool's permission tier.
1284 fn tier(&self) -> ToolTier;
1285 /// Execute the tool with the given input.
1286 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome>;
1287 /// Stream status updates for an in-progress operation (type-erased).
1288 fn check_status_stream<'a>(
1289 &'a self,
1290 ctx: &'a ToolContext<Ctx>,
1291 operation_id: &'a str,
1292 ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>>;
1293}
1294
1295/// Wrapper that erases the Name and Stage associated types from an [`AsyncTool`].
1296struct AsyncToolWrapper<T, Ctx>
1297where
1298 T: AsyncTool<Ctx>,
1299{
1300 inner: T,
1301 name_cache: String,
1302 _marker: PhantomData<Ctx>,
1303}
1304
1305impl<T, Ctx> AsyncToolWrapper<T, Ctx>
1306where
1307 T: AsyncTool<Ctx>,
1308{
1309 fn new(tool: T) -> Self {
1310 let name_cache = tool_name_to_string(&tool.name());
1311 Self {
1312 inner: tool,
1313 name_cache,
1314 _marker: PhantomData,
1315 }
1316 }
1317}
1318
1319#[async_trait]
1320impl<T, Ctx> ErasedAsyncTool<Ctx> for AsyncToolWrapper<T, Ctx>
1321where
1322 T: AsyncTool<Ctx> + 'static,
1323 Ctx: Send + Sync + 'static,
1324{
1325 fn name_str(&self) -> &str {
1326 &self.name_cache
1327 }
1328
1329 fn display_name(&self) -> &'static str {
1330 self.inner.display_name()
1331 }
1332
1333 fn description(&self) -> &'static str {
1334 self.inner.description()
1335 }
1336
1337 fn input_schema(&self) -> Value {
1338 self.inner.input_schema()
1339 }
1340
1341 fn tier(&self) -> ToolTier {
1342 self.inner.tier()
1343 }
1344
1345 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome> {
1346 self.inner.execute(ctx, input).await
1347 }
1348
1349 fn check_status_stream<'a>(
1350 &'a self,
1351 ctx: &'a ToolContext<Ctx>,
1352 operation_id: &'a str,
1353 ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>> {
1354 use futures::StreamExt;
1355 let stream = self.inner.check_status(ctx, operation_id);
1356 Box::pin(stream.map(ErasedToolStatus::from))
1357 }
1358}
1359
1360// ============================================================================
1361// Type-Erased ListenExecuteTool (for Registry)
1362// ============================================================================
1363
1364/// Type-erased listen/execute tool trait for registry storage.
1365#[async_trait]
1366pub trait ErasedListenTool<Ctx>: Send + Sync {
1367 /// Get the tool name as a string.
1368 fn name_str(&self) -> &str;
1369 /// Get a human-friendly display name for the tool.
1370 fn display_name(&self) -> &'static str;
1371 /// Get the tool description.
1372 fn description(&self) -> &'static str;
1373 /// Get the JSON schema for tool inputs.
1374 fn input_schema(&self) -> Value;
1375 /// Get the tool's permission tier.
1376 fn tier(&self) -> ToolTier;
1377 /// Start listen stream.
1378 fn listen_stream<'a>(
1379 &'a self,
1380 ctx: &'a ToolContext<Ctx>,
1381 input: Value,
1382 ) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>>;
1383 /// Execute using a prepared operation.
1384 async fn execute(
1385 &self,
1386 ctx: &ToolContext<Ctx>,
1387 operation_id: &str,
1388 expected_revision: u64,
1389 ) -> Result<ToolResult>;
1390 /// Cancel operation.
1391 async fn cancel(
1392 &self,
1393 ctx: &ToolContext<Ctx>,
1394 operation_id: &str,
1395 reason: ListenStopReason,
1396 ) -> Result<()>;
1397}
1398
1399/// Wrapper that erases the Name associated type from a [`ListenExecuteTool`].
1400struct ListenToolWrapper<T, Ctx>
1401where
1402 T: ListenExecuteTool<Ctx>,
1403{
1404 inner: T,
1405 name_cache: String,
1406 _marker: PhantomData<Ctx>,
1407}
1408
1409impl<T, Ctx> ListenToolWrapper<T, Ctx>
1410where
1411 T: ListenExecuteTool<Ctx>,
1412{
1413 fn new(tool: T) -> Self {
1414 let name_cache = tool_name_to_string(&tool.name());
1415 Self {
1416 inner: tool,
1417 name_cache,
1418 _marker: PhantomData,
1419 }
1420 }
1421}
1422
1423#[async_trait]
1424impl<T, Ctx> ErasedListenTool<Ctx> for ListenToolWrapper<T, Ctx>
1425where
1426 T: ListenExecuteTool<Ctx> + 'static,
1427 Ctx: Send + Sync + 'static,
1428{
1429 fn name_str(&self) -> &str {
1430 &self.name_cache
1431 }
1432
1433 fn display_name(&self) -> &'static str {
1434 self.inner.display_name()
1435 }
1436
1437 fn description(&self) -> &'static str {
1438 self.inner.description()
1439 }
1440
1441 fn input_schema(&self) -> Value {
1442 self.inner.input_schema()
1443 }
1444
1445 fn tier(&self) -> ToolTier {
1446 self.inner.tier()
1447 }
1448
1449 fn listen_stream<'a>(
1450 &'a self,
1451 ctx: &'a ToolContext<Ctx>,
1452 input: Value,
1453 ) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>> {
1454 let stream = self.inner.listen(ctx, input);
1455 Box::pin(stream)
1456 }
1457
1458 async fn execute(
1459 &self,
1460 ctx: &ToolContext<Ctx>,
1461 operation_id: &str,
1462 expected_revision: u64,
1463 ) -> Result<ToolResult> {
1464 self.inner
1465 .execute(ctx, operation_id, expected_revision)
1466 .await
1467 }
1468
1469 async fn cancel(
1470 &self,
1471 ctx: &ToolContext<Ctx>,
1472 operation_id: &str,
1473 reason: ListenStopReason,
1474 ) -> Result<()> {
1475 self.inner.cancel(ctx, operation_id, reason).await
1476 }
1477}
1478
1479/// Registry of available tools.
1480///
1481/// Tools are stored with their names erased to allow different `Name` types
1482/// in the same registry. The registry uses string-based lookup for LLM
1483/// compatibility.
1484///
1485/// Supports both synchronous [`Tool`]s and asynchronous [`AsyncTool`]s.
1486pub struct ToolRegistry<Ctx> {
1487 tools: HashMap<String, Arc<dyn ErasedTool<Ctx>>>,
1488 async_tools: HashMap<String, Arc<dyn ErasedAsyncTool<Ctx>>>,
1489 listen_tools: HashMap<String, Arc<dyn ErasedListenTool<Ctx>>>,
1490}
1491
1492impl<Ctx> Clone for ToolRegistry<Ctx> {
1493 fn clone(&self) -> Self {
1494 Self {
1495 tools: self.tools.clone(),
1496 async_tools: self.async_tools.clone(),
1497 listen_tools: self.listen_tools.clone(),
1498 }
1499 }
1500}
1501
1502impl<Ctx: Send + Sync + 'static> Default for ToolRegistry<Ctx> {
1503 fn default() -> Self {
1504 Self::new()
1505 }
1506}
1507
1508impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
1509 #[must_use]
1510 pub fn new() -> Self {
1511 Self {
1512 tools: HashMap::new(),
1513 async_tools: HashMap::new(),
1514 listen_tools: HashMap::new(),
1515 }
1516 }
1517
1518 /// Evict any existing registration for `name` across **all three** maps so
1519 /// a name lives in exactly one map, then warn about the replacement.
1520 ///
1521 /// Without this, re-registering a name silently replaced the tool, and the
1522 /// same name registered as both (say) a sync and a listen tool coexisted —
1523 /// [`len`](ToolRegistry::len) double-counted it and
1524 /// [`to_llm_tools`](ToolRegistry::to_llm_tools) emitted two definitions with
1525 /// identical names (which providers reject). A remote MCP server could also
1526 /// silently shadow a vetted built-in (`read`, `bash`). We keep the
1527 /// non-breaking last-registration-wins behavior but make it loud; callers
1528 /// that need fail-closed semantics should use the `try_register*` variants.
1529 fn evict_existing(&mut self, name: &str, new_kind: &str) {
1530 // Evict from all three maps (each is side-effecting and must run); a
1531 // name lives in at most one map, so the listen > async > sync ordering
1532 // only disambiguates the pathological double-registration case.
1533 let previous_kind = [
1534 (self.listen_tools.remove(name).is_some(), "listen"),
1535 (self.async_tools.remove(name).is_some(), "async"),
1536 (self.tools.remove(name).is_some(), "sync"),
1537 ]
1538 .into_iter()
1539 .find_map(|(removed, kind)| removed.then_some(kind));
1540 if let Some(previous_kind) = previous_kind {
1541 log::warn!(
1542 "tool registry: name {name:?} already registered as a {previous_kind} tool; \
1543 replacing it with a {new_kind} tool (last registration wins)"
1544 );
1545 }
1546 }
1547
1548 /// Error if `name` is already registered in any of the three maps.
1549 fn ensure_unique(&self, name: &str) -> Result<()> {
1550 anyhow::ensure!(
1551 !self.tools.contains_key(name)
1552 && !self.async_tools.contains_key(name)
1553 && !self.listen_tools.contains_key(name),
1554 "tool {name:?} is already registered",
1555 );
1556 Ok(())
1557 }
1558
1559 /// Register a synchronous tool in the registry.
1560 ///
1561 /// The tool's name is converted to a string via serde serialization
1562 /// and used as the lookup key. If the name is already registered (in any
1563 /// map), the previous tool is evicted and a warning is logged; use
1564 /// [`try_register`](ToolRegistry::try_register) for fail-closed semantics.
1565 pub fn register<T>(&mut self, tool: T) -> &mut Self
1566 where
1567 T: Tool<Ctx> + 'static,
1568 {
1569 let wrapper = ToolWrapper::new(tool);
1570 let name = wrapper.name_str().to_string();
1571 self.evict_existing(&name, "sync");
1572 self.tools.insert(name, Arc::new(wrapper));
1573 self
1574 }
1575
1576 /// Register a synchronous tool, returning an error on name collision.
1577 ///
1578 /// Unlike [`register`](ToolRegistry::register), this never silently
1579 /// replaces an existing tool — it checks all three maps and fails if the
1580 /// name is taken. Useful for registering untrusted (e.g. MCP-supplied)
1581 /// tools without letting them squat over vetted built-ins.
1582 ///
1583 /// # Errors
1584 /// Returns an error if a tool with the same name is already registered.
1585 pub fn try_register<T>(&mut self, tool: T) -> Result<&mut Self>
1586 where
1587 T: Tool<Ctx> + 'static,
1588 {
1589 let wrapper = ToolWrapper::new(tool);
1590 let name = wrapper.name_str().to_string();
1591 self.ensure_unique(&name)?;
1592 self.tools.insert(name, Arc::new(wrapper));
1593 Ok(self)
1594 }
1595
1596 /// Register a [`SimpleTool`] — a tool whose name is a plain `&str` and
1597 /// which needs no [`ToolName`] type.
1598 ///
1599 /// The tool is wrapped in a [`SimpleToolAdapter`] (giving it
1600 /// `Name = DynamicToolName`) and registered like any other [`Tool`].
1601 /// This is the lowest-ceremony way to add a first custom tool.
1602 pub fn register_simple<T>(&mut self, tool: T) -> &mut Self
1603 where
1604 T: SimpleTool<Ctx> + 'static,
1605 {
1606 self.register(SimpleToolAdapter::new(tool))
1607 }
1608
1609 /// Register a [`TypedTool`] — a tool whose model-emitted arguments are
1610 /// deserialized into a typed [`TypedTool::Input`] and validated **before**
1611 /// `execute` runs.
1612 ///
1613 /// The tool is wrapped in a [`TypedToolAdapter`] (giving it
1614 /// `Name = DynamicToolName`) and registered like any other [`Tool`]. A
1615 /// malformed tool call is turned into a structured validation-error
1616 /// [`ToolResult`] at the dispatch boundary so the model can self-correct;
1617 /// `execute` is never reached with invalid arguments.
1618 pub fn register_typed<T>(&mut self, tool: T) -> &mut Self
1619 where
1620 T: TypedTool<Ctx> + 'static,
1621 {
1622 self.register(TypedToolAdapter::new(tool))
1623 }
1624
1625 /// Register an async tool in the registry.
1626 ///
1627 /// Async tools have two phases: execute (lightweight, starts operation)
1628 /// and `check_status` (streams progress until completion).
1629 pub fn register_async<T>(&mut self, tool: T) -> &mut Self
1630 where
1631 T: AsyncTool<Ctx> + 'static,
1632 {
1633 let wrapper = AsyncToolWrapper::new(tool);
1634 let name = wrapper.name_str().to_string();
1635 self.evict_existing(&name, "async");
1636 self.async_tools.insert(name, Arc::new(wrapper));
1637 self
1638 }
1639
1640 /// Register an async tool, returning an error on name collision.
1641 ///
1642 /// The fail-closed counterpart to [`register_async`](ToolRegistry::register_async).
1643 ///
1644 /// # Errors
1645 /// Returns an error if a tool with the same name is already registered.
1646 pub fn try_register_async<T>(&mut self, tool: T) -> Result<&mut Self>
1647 where
1648 T: AsyncTool<Ctx> + 'static,
1649 {
1650 let wrapper = AsyncToolWrapper::new(tool);
1651 let name = wrapper.name_str().to_string();
1652 self.ensure_unique(&name)?;
1653 self.async_tools.insert(name, Arc::new(wrapper));
1654 Ok(self)
1655 }
1656
1657 /// Register a listen/execute tool in the registry.
1658 ///
1659 /// Listen/execute tools start by streaming updates via `listen()`, then run
1660 /// final execution with `execute()` once confirmed.
1661 pub fn register_listen<T>(&mut self, tool: T) -> &mut Self
1662 where
1663 T: ListenExecuteTool<Ctx> + 'static,
1664 {
1665 let wrapper = ListenToolWrapper::new(tool);
1666 let name = wrapper.name_str().to_string();
1667 self.evict_existing(&name, "listen");
1668 self.listen_tools.insert(name, Arc::new(wrapper));
1669 self
1670 }
1671
1672 /// Register a listen/execute tool, returning an error on name collision.
1673 ///
1674 /// The fail-closed counterpart to [`register_listen`](ToolRegistry::register_listen).
1675 ///
1676 /// # Errors
1677 /// Returns an error if a tool with the same name is already registered.
1678 pub fn try_register_listen<T>(&mut self, tool: T) -> Result<&mut Self>
1679 where
1680 T: ListenExecuteTool<Ctx> + 'static,
1681 {
1682 let wrapper = ListenToolWrapper::new(tool);
1683 let name = wrapper.name_str().to_string();
1684 self.ensure_unique(&name)?;
1685 self.listen_tools.insert(name, Arc::new(wrapper));
1686 Ok(self)
1687 }
1688
1689 /// Get a synchronous tool by name.
1690 #[must_use]
1691 pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool<Ctx>>> {
1692 self.tools.get(name)
1693 }
1694
1695 /// Get an async tool by name.
1696 #[must_use]
1697 pub fn get_async(&self, name: &str) -> Option<&Arc<dyn ErasedAsyncTool<Ctx>>> {
1698 self.async_tools.get(name)
1699 }
1700
1701 /// Get a listen/execute tool by name.
1702 #[must_use]
1703 pub fn get_listen(&self, name: &str) -> Option<&Arc<dyn ErasedListenTool<Ctx>>> {
1704 self.listen_tools.get(name)
1705 }
1706
1707 /// Check if a tool name refers to an async tool.
1708 #[must_use]
1709 pub fn is_async(&self, name: &str) -> bool {
1710 self.async_tools.contains_key(name)
1711 }
1712
1713 /// Check if a tool name refers to a listen/execute tool.
1714 #[must_use]
1715 pub fn is_listen(&self, name: &str) -> bool {
1716 self.listen_tools.contains_key(name)
1717 }
1718
1719 /// Get all registered synchronous tools.
1720 pub fn all(&self) -> impl Iterator<Item = &Arc<dyn ErasedTool<Ctx>>> {
1721 self.tools.values()
1722 }
1723
1724 /// Get all registered async tools.
1725 pub fn all_async(&self) -> impl Iterator<Item = &Arc<dyn ErasedAsyncTool<Ctx>>> {
1726 self.async_tools.values()
1727 }
1728
1729 /// Get all registered listen/execute tools.
1730 pub fn all_listen(&self) -> impl Iterator<Item = &Arc<dyn ErasedListenTool<Ctx>>> {
1731 self.listen_tools.values()
1732 }
1733
1734 /// Get the number of registered tools (sync + async).
1735 #[must_use]
1736 pub fn len(&self) -> usize {
1737 self.tools.len() + self.async_tools.len() + self.listen_tools.len()
1738 }
1739
1740 /// Check if the registry is empty.
1741 #[must_use]
1742 pub fn is_empty(&self) -> bool {
1743 self.tools.is_empty() && self.async_tools.is_empty() && self.listen_tools.is_empty()
1744 }
1745
1746 /// Filter tools by a predicate.
1747 ///
1748 /// Removes tools for which the predicate returns false.
1749 /// The predicate receives the tool name.
1750 /// Applies to both sync and async tools.
1751 ///
1752 /// # Example
1753 ///
1754 /// ```ignore
1755 /// registry.filter(|name| name != "bash");
1756 /// ```
1757 pub fn filter<F>(&mut self, predicate: F)
1758 where
1759 F: Fn(&str) -> bool,
1760 {
1761 self.tools.retain(|name, _| predicate(name));
1762 self.async_tools.retain(|name, _| predicate(name));
1763 self.listen_tools.retain(|name, _| predicate(name));
1764 }
1765
1766 /// Convert all tools (sync + async + listen) to LLM tool
1767 /// definitions. The output is sorted by tool name so the order
1768 /// is deterministic across builds and across calls.
1769 ///
1770 /// Determinism matters for **prompt caching**. Anthropic's
1771 /// `cache_control: ephemeral` keys on the byte content of the
1772 /// system + tool list. Anything that perturbs the order of the
1773 /// tool list invalidates the cache. The three backing maps are
1774 /// `HashMap`s, whose `values()` order is randomized (DoS-safe
1775 /// `RandomState` by default), so two consecutive turns with the
1776 /// same registered tool set were producing different orderings
1777 /// and silently zeroing the cache hit rate.
1778 ///
1779 /// Sorting by name is the cheapest fix that holds across
1780 /// insertion order, internal map type changes, and concurrent
1781 /// builds. The tool count is small (tens, not thousands) so the
1782 /// sort cost is negligible compared to a single LLM call.
1783 #[must_use]
1784 pub fn to_llm_tools(&self) -> Vec<llm::Tool> {
1785 /// Build the LLM tool descriptor from the accessors every erased tool
1786 /// trait shares. Extracted so the five-field `llm::Tool` literal —
1787 /// whose byte content is prompt-cache load-bearing — exists in exactly
1788 /// one place across the sync / async / listen iterators.
1789 fn descriptor(
1790 name: &str,
1791 display_name: &str,
1792 description: &str,
1793 input_schema: Value,
1794 tier: ToolTier,
1795 ) -> llm::Tool {
1796 llm::Tool {
1797 name: name.to_string(),
1798 description: description.to_string(),
1799 input_schema,
1800 display_name: display_name.to_string(),
1801 tier,
1802 }
1803 }
1804
1805 let mut tools: Vec<_> = self
1806 .tools
1807 .values()
1808 .map(|tool| {
1809 descriptor(
1810 tool.name_str(),
1811 tool.display_name(),
1812 tool.description(),
1813 tool.input_schema(),
1814 tool.tier(),
1815 )
1816 })
1817 .collect();
1818
1819 tools.extend(self.async_tools.values().map(|tool| {
1820 descriptor(
1821 tool.name_str(),
1822 tool.display_name(),
1823 tool.description(),
1824 tool.input_schema(),
1825 tool.tier(),
1826 )
1827 }));
1828
1829 tools.extend(self.listen_tools.values().map(|tool| {
1830 descriptor(
1831 tool.name_str(),
1832 tool.display_name(),
1833 tool.description(),
1834 tool.input_schema(),
1835 tool.tier(),
1836 )
1837 }));
1838
1839 tools.sort_by(|a, b| a.name.cmp(&b.name));
1840 tools
1841 }
1842}
1843
1844#[cfg(test)]
1845mod tests {
1846 use super::*;
1847 use anyhow::Context;
1848
1849 // Test tool name enum for tests
1850 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
1851 #[serde(rename_all = "snake_case")]
1852 enum TestToolName {
1853 MockTool,
1854 AnotherTool,
1855 }
1856
1857 impl ToolName for TestToolName {}
1858
1859 struct MockTool;
1860
1861 impl Tool<()> for MockTool {
1862 type Name = TestToolName;
1863
1864 fn name(&self) -> TestToolName {
1865 TestToolName::MockTool
1866 }
1867
1868 fn display_name(&self) -> &'static str {
1869 "Mock Tool"
1870 }
1871
1872 fn description(&self) -> &'static str {
1873 "A mock tool for testing"
1874 }
1875
1876 fn input_schema(&self) -> Value {
1877 serde_json::json!({
1878 "type": "object",
1879 "properties": {
1880 "message": { "type": "string" }
1881 }
1882 })
1883 }
1884
1885 async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
1886 let message = input
1887 .get("message")
1888 .and_then(|v| v.as_str())
1889 .unwrap_or("no message");
1890 Ok(ToolResult::success(format!("Received: {message}")))
1891 }
1892 }
1893
/// A typed tool name round-trips through its snake_case string form.
#[test]
fn test_tool_name_serialization() {
    // Enum variant -> wire string.
    let serialized = tool_name_to_string(&TestToolName::MockTool);
    assert_eq!(serialized, "mock_tool");

    // Wire string -> enum variant.
    let round_tripped: TestToolName = tool_name_from_str("mock_tool").unwrap();
    assert_eq!(round_tripped, TestToolName::MockTool);
}
1902
/// A runtime-created name reports the same string through both accessors.
#[test]
fn test_dynamic_tool_name() {
    let dynamic = DynamicToolName::new("my_mcp_tool");

    let serialized = tool_name_to_string(&dynamic);
    assert_eq!(serialized, "my_mcp_tool");

    let borrowed = dynamic.as_str();
    assert_eq!(borrowed, "my_mcp_tool");
}
1909
/// Registering one tool makes it (and only it) retrievable by name.
#[test]
fn test_tool_registry() {
    let mut registry = ToolRegistry::new();
    registry.register(MockTool);

    let registered_count = registry.len();
    assert_eq!(registered_count, 1);

    let known = registry.get("mock_tool");
    assert!(known.is_some());

    let unknown = registry.get("nonexistent");
    assert!(unknown.is_none());
}
1919
/// A single registered tool yields a single LLM descriptor with its name.
#[test]
fn test_to_llm_tools() {
    let mut registry = ToolRegistry::new();
    registry.register(MockTool);

    let descriptors = registry.to_llm_tools();
    assert_eq!(descriptors.len(), 1);

    let only_name = descriptors.first().map(|tool| tool.name.as_str());
    assert_eq!(only_name, Some("mock_tool"));
}
1929
/// Descriptors come back sorted by name, not in registration order.
#[test]
fn to_llm_tools_returns_alphabetical_order() {
    let mut registry = ToolRegistry::new();
    // Deliberately registered out of alphabetical order: if the registry
    // ever went back to returning insertion order, the assertion below
    // would catch it.
    registry.register(MockTool); // "mock_tool"
    registry.register(AnotherTool); // "another_tool"

    let mut names: Vec<String> = Vec::new();
    for tool in registry.to_llm_tools() {
        names.push(tool.name);
    }

    assert_eq!(names, vec!["another_tool", "mock_tool"]);
}
1945
/// Regression test: the tool list sent to the LLM must be identical for
/// every registry holding the same set of tools.
///
/// Prompt caching depends on byte-stable tool list ordering. The `HashMap`
/// behind the registry randomizes its `values()` order, so without an
/// explicit sort two builds with the same registered set could ship
/// different tool orderings to the LLM and silently invalidate the cache.
#[test]
fn to_llm_tools_is_deterministic_across_calls() {
    /// Build a brand-new registry holding both mock tools and return the
    /// descriptor names in the order `to_llm_tools` reports them.
    ///
    /// Calling `to_llm_tools` repeatedly on one unmodified map cannot
    /// expose an ordering problem, because a single map iterates the same
    /// way every time. A fresh registry per call gives a fresh map
    /// (presumably with its own hasher seed — the registry's map type is
    /// not visible from this test), and `reverse_registration` varies the
    /// insertion order as well.
    fn names_for(reverse_registration: bool) -> Vec<String> {
        let mut registry = ToolRegistry::new();
        if reverse_registration {
            registry.register(AnotherTool);
            registry.register(MockTool);
        } else {
            registry.register(MockTool);
            registry.register(AnotherTool);
        }

        let names: Vec<String> = registry
            .to_llm_tools()
            .into_iter()
            .map(|t| t.name)
            .collect();

        // The original guarantee still holds: a second call on the very
        // same registry reports the same order.
        let repeated: Vec<String> = registry
            .to_llm_tools()
            .into_iter()
            .map(|t| t.name)
            .collect();
        assert_eq!(repeated, names, "tool ordering must be stable across calls");

        names
    }

    let first = names_for(false);

    for round in 0..32 {
        // Alternate the registration order so insertion order is ruled out
        // as the source of stability too.
        let next = names_for(round % 2 == 1);
        assert_eq!(next, first, "tool ordering must be stable across calls");
    }
}
1972
1973 struct AnotherTool;
1974
1975 impl Tool<()> for AnotherTool {
1976 type Name = TestToolName;
1977
1978 fn name(&self) -> TestToolName {
1979 TestToolName::AnotherTool
1980 }
1981
1982 fn display_name(&self) -> &'static str {
1983 "Another Tool"
1984 }
1985
1986 fn description(&self) -> &'static str {
1987 "Another tool for testing"
1988 }
1989
1990 fn input_schema(&self) -> Value {
1991 serde_json::json!({ "type": "object" })
1992 }
1993
1994 async fn execute(&self, _ctx: &ToolContext<()>, _input: Value) -> Result<ToolResult> {
1995 Ok(ToolResult::success("Done"))
1996 }
1997 }
1998
/// `filter` drops exactly the tools whose name the predicate rejects.
#[test]
fn test_filter_tools() {
    let mut registry = ToolRegistry::new();
    registry.register(MockTool);
    registry.register(AnotherTool);

    let before = registry.len();
    assert_eq!(before, 2);

    // Keep everything except "mock_tool".
    registry.filter(|name| name != "mock_tool");

    let after = registry.len();
    assert_eq!(after, 1);

    let removed = registry.get("mock_tool");
    assert!(removed.is_none());

    let kept = registry.get("another_tool");
    assert!(kept.is_some());
}
2014
/// An always-true predicate leaves the registry untouched.
#[test]
fn test_filter_tools_keep_all() {
    let mut registry = ToolRegistry::new();
    registry.register(MockTool);
    registry.register(AnotherTool);

    registry.filter(|_| true);

    let remaining = registry.len();
    assert_eq!(remaining, 2);
}
2025
/// An always-false predicate empties the registry.
#[test]
fn test_filter_tools_remove_all() {
    let mut registry = ToolRegistry::new();
    registry.register(MockTool);
    registry.register(AnotherTool);

    registry.filter(|_| false);

    let emptied = registry.is_empty();
    assert!(emptied);
}
2036
/// The tool's overridden display name is visible through the registry.
#[test]
fn test_display_name() {
    let mut registry = ToolRegistry::new();
    registry.register(MockTool);

    let tool = registry.get("mock_tool").unwrap();
    let shown = tool.display_name();
    assert_eq!(shown, "Mock Tool");
}
2045
2046 struct ListenMockTool;
2047
2048 impl ListenExecuteTool<()> for ListenMockTool {
2049 type Name = TestToolName;
2050
2051 fn name(&self) -> TestToolName {
2052 TestToolName::MockTool
2053 }
2054
2055 fn display_name(&self) -> &'static str {
2056 "Listen Mock Tool"
2057 }
2058
2059 fn description(&self) -> &'static str {
2060 "A listen/execute mock tool for testing"
2061 }
2062
2063 fn input_schema(&self) -> Value {
2064 serde_json::json!({ "type": "object" })
2065 }
2066
2067 fn listen(
2068 &self,
2069 _ctx: &ToolContext<()>,
2070 _input: Value,
2071 ) -> impl futures::Stream<Item = ListenToolUpdate> + Send {
2072 futures::stream::iter(vec![ListenToolUpdate::Ready {
2073 operation_id: "op_1".to_string(),
2074 revision: 1,
2075 message: "ready".to_string(),
2076 snapshot: serde_json::json!({"ok": true}),
2077 expires_at: None,
2078 }])
2079 }
2080
2081 async fn execute(
2082 &self,
2083 _ctx: &ToolContext<()>,
2084 _operation_id: &str,
2085 _expected_revision: u64,
2086 ) -> Result<ToolResult> {
2087 Ok(ToolResult::success("Executed"))
2088 }
2089 }
2090
/// A listen tool is counted, retrievable, and flagged as a listen tool.
#[test]
fn test_listen_tool_registry() {
    let mut registry = ToolRegistry::new();
    registry.register_listen(ListenMockTool);

    let registered_count = registry.len();
    assert_eq!(registered_count, 1);

    let listen_entry = registry.get_listen("mock_tool");
    assert!(listen_entry.is_some());

    let flagged_as_listen = registry.is_listen("mock_tool");
    assert!(flagged_as_listen);
}
2100
2101 // ── TypedTool: typed input + validation / self-correction ───────────
2102
2103 use std::sync::atomic::{AtomicBool, Ordering};
2104
2105 #[derive(Debug, Serialize, Deserialize)]
2106 struct GreetArgs {
2107 name: String,
2108 // Required so a missing/typo'd field is a hard validation error.
2109 greeting: String,
2110 }
2111
2112 /// A typed tool that records whether `execute` was reached, so tests can
2113 /// assert the validation boundary never calls `execute` with bad args.
2114 struct GreetTool {
2115 executed: Arc<AtomicBool>,
2116 }
2117
2118 impl TypedTool<()> for GreetTool {
2119 type Input = GreetArgs;
2120
2121 fn name(&self) -> &'static str {
2122 "greet"
2123 }
2124
2125 fn description(&self) -> &'static str {
2126 "Greet someone by name"
2127 }
2128
2129 fn input_schema(&self) -> Value {
2130 serde_json::json!({
2131 "type": "object",
2132 "properties": {
2133 "name": { "type": "string" },
2134 "greeting": { "type": "string" }
2135 },
2136 "required": ["name", "greeting"]
2137 })
2138 }
2139
2140 async fn execute(&self, _ctx: &ToolContext<()>, input: GreetArgs) -> Result<ToolResult> {
2141 self.executed.store(true, Ordering::SeqCst);
2142 Ok(ToolResult::success(format!(
2143 "{}, {}!",
2144 input.greeting, input.name
2145 )))
2146 }
2147 }
2148
2149 #[tokio::test]
2150 async fn typed_tool_happy_path_receives_typed_input() -> Result<()> {
2151 let executed = Arc::new(AtomicBool::new(false));
2152 let adapter = TypedToolAdapter::new(GreetTool {
2153 executed: executed.clone(),
2154 });
2155 let ctx = ToolContext::new(());
2156
2157 let result = Tool::execute(
2158 &adapter,
2159 &ctx,
2160 serde_json::json!({ "name": "Ada", "greeting": "Hello" }),
2161 )
2162 .await?;
2163
2164 assert!(executed.load(Ordering::SeqCst), "execute must be called");
2165 assert!(result.success);
2166 assert_eq!(result.output, "Hello, Ada!");
2167 Ok(())
2168 }
2169
2170 #[tokio::test]
2171 async fn typed_tool_invalid_args_self_correct_without_executing() -> Result<()> {
2172 let executed = Arc::new(AtomicBool::new(false));
2173 let adapter = TypedToolAdapter::new(GreetTool {
2174 executed: executed.clone(),
2175 });
2176 let ctx = ToolContext::new(());
2177
2178 // `greeting` is missing — must not deserialize into `GreetArgs`.
2179 let result = Tool::execute(&adapter, &ctx, serde_json::json!({ "name": "Ada" })).await?;
2180
2181 assert!(
2182 !executed.load(Ordering::SeqCst),
2183 "execute must NOT be called with invalid arguments"
2184 );
2185 assert!(!result.success, "validation failure is an error result");
2186 assert!(
2187 result.output.contains("Invalid arguments for tool `greet`"),
2188 "error must identify the tool: {}",
2189 result.output
2190 );
2191 assert!(
2192 result.output.contains("greeting"),
2193 "error must surface the serde message naming the bad field: {}",
2194 result.output
2195 );
2196 Ok(())
2197 }
2198
2199 #[tokio::test]
2200 async fn typed_tool_wrong_type_self_corrects() -> Result<()> {
2201 let executed = Arc::new(AtomicBool::new(false));
2202 let adapter = TypedToolAdapter::new(GreetTool {
2203 executed: executed.clone(),
2204 });
2205 let ctx = ToolContext::new(());
2206
2207 // `name` is a number, not a string.
2208 let result = Tool::execute(
2209 &adapter,
2210 &ctx,
2211 serde_json::json!({ "name": 42, "greeting": "Hi" }),
2212 )
2213 .await?;
2214
2215 assert!(!executed.load(Ordering::SeqCst));
2216 assert!(!result.success);
2217 Ok(())
2218 }
2219
2220 /// Back-compat: a `TypedTool` whose `Input = Value` is the identity
2221 /// passthrough — any JSON deserializes, mirroring today's untyped tools.
2222 struct ValueTypedTool;
2223
2224 impl TypedTool<()> for ValueTypedTool {
2225 type Input = Value;
2226
2227 fn name(&self) -> &'static str {
2228 "value_typed"
2229 }
2230
2231 fn description(&self) -> &'static str {
2232 "Accepts any JSON, like an untyped tool"
2233 }
2234
2235 fn input_schema(&self) -> Value {
2236 serde_json::json!({ "type": "object" })
2237 }
2238
2239 async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
2240 Ok(ToolResult::success(input.to_string()))
2241 }
2242 }
2243
2244 #[tokio::test]
2245 async fn typed_tool_value_input_is_identity_passthrough() -> Result<()> {
2246 let adapter = TypedToolAdapter::new(ValueTypedTool);
2247 let ctx = ToolContext::new(());
2248
2249 // Arbitrary shape — Value always "deserializes".
2250 let result = Tool::execute(
2251 &adapter,
2252 &ctx,
2253 serde_json::json!({ "anything": [1, 2, 3], "nested": { "ok": true } }),
2254 )
2255 .await?;
2256
2257 assert!(result.success);
2258 Ok(())
2259 }
2260
/// `register_typed` makes the tool reachable by name with its own schema.
#[test]
fn register_typed_exposes_tool_via_registry() -> Result<()> {
    let mut registry = ToolRegistry::new();
    let tool = GreetTool {
        executed: Arc::new(AtomicBool::new(false)),
    };
    registry.register_typed(tool);

    let registered_count = registry.len();
    assert_eq!(registered_count, 1);

    let registered = registry.get("greet").context("typed tool registered")?;
    // The schema the tool author declared arrives untouched.
    let schema = registered.input_schema();
    assert_eq!(schema["required"][0], "name");
    Ok(())
}
2274
/// The shared invalid-input result is an error that names the tool and
/// tells the model to retry.
#[test]
fn invalid_tool_input_result_is_balanced_error() -> Result<()> {
    // Provoke a real serde error to feed into the helper.
    let err = match serde_json::from_str::<GreetArgs>("{}") {
        Err(err) => err,
        Ok(_) => anyhow::bail!("empty object must fail to deserialize GreetArgs"),
    };

    let result = invalid_tool_input_result("greet", &err);

    assert!(!result.success);

    let mentions_tool = result.output.contains("greet");
    assert!(mentions_tool);

    let asks_for_retry = result.output.contains("call the tool again");
    assert!(asks_for_retry);
    Ok(())
}
2287
2288 // ── Fail-closed tier + display_name defaults (findings 8 & 19) ───────
2289
2290 /// A tool that overrides neither `tier()` nor `display_name()`, exercising
2291 /// the trait defaults.
2292 struct DefaultsTool;
2293
2294 impl Tool<()> for DefaultsTool {
2295 type Name = DynamicToolName;
2296
2297 fn name(&self) -> DynamicToolName {
2298 DynamicToolName::new("defaults")
2299 }
2300
2301 fn description(&self) -> &'static str {
2302 "uses trait defaults"
2303 }
2304
2305 fn input_schema(&self) -> Value {
2306 serde_json::json!({ "type": "object" })
2307 }
2308
2309 async fn execute(&self, _ctx: &ToolContext<()>, _input: Value) -> Result<ToolResult> {
2310 Ok(ToolResult::success("ok"))
2311 }
2312 }
2313
/// The `Tool` trait defaults are the safe ones: empty display name and the
/// `Confirm` tier.
#[test]
fn tool_trait_defaults_are_fail_closed() {
    let tool = DefaultsTool;

    // Finding 19: the documentation promised a "" default for
    // `display_name`; this pins that the default really exists.
    let default_display_name = <DefaultsTool as Tool<()>>::display_name(&tool);
    assert_eq!(default_display_name, "");

    // Finding 8: an author who forgets to declare a tier on a
    // side-effecting tool gets a gated tool (Confirm), never an auto-run one.
    let default_tier = <DefaultsTool as Tool<()>>::tier(&tool);
    assert_eq!(default_tier, ToolTier::Confirm);
}
2324
2325 // ── Registry name-collision handling (findings 9 & 10) ───────────────
2326
/// Registering the same name twice leaves one entry and one descriptor.
#[test]
fn re_registering_same_name_replaces_without_duplicates() {
    let mut registry = ToolRegistry::new();
    registry.register(MockTool);
    registry.register(MockTool); // same name "mock_tool"

    let registered_count = registry.len();
    assert_eq!(registered_count, 1, "re-register must replace, not add");

    let mut names: Vec<String> = Vec::new();
    for tool in registry.to_llm_tools() {
        names.push(tool.name);
    }
    assert_eq!(names, vec!["mock_tool"]);
}
2341
/// A listen tool registered under a name already used by a sync tool
/// replaces it rather than coexisting with it.
#[test]
fn cross_kind_name_collision_keeps_single_entry() {
    let mut registry = ToolRegistry::new();
    registry.register(MockTool); // sync "mock_tool"
    registry.register_listen(ListenMockTool); // listen "mock_tool"

    // Registering the listen tool pushes the sync one out: any given name
    // is held by exactly one map, so neither `len()` nor `to_llm_tools()`
    // can count it twice.
    let registered_count = registry.len();
    assert_eq!(registered_count, 1);

    let held_by_listen_map = registry.is_listen("mock_tool");
    assert!(held_by_listen_map);

    let sync_entry = registry.get("mock_tool");
    assert!(
        sync_entry.is_none(),
        "the shadowed sync tool must be evicted"
    );

    let mut names: Vec<String> = Vec::new();
    for tool in registry.to_llm_tools() {
        names.push(tool.name);
    }
    assert_eq!(names, vec!["mock_tool"], "no duplicate LLM definitions");
}
2363
/// The fallible registration methods refuse a name that is already taken,
/// whether by the same kind of tool or a different one.
#[test]
fn try_register_rejects_name_collision() {
    let mut registry = ToolRegistry::new();
    registry.register(MockTool); // "mock_tool"

    let sync_duplicate_rejected = registry.try_register(MockTool).is_err();
    assert!(
        sync_duplicate_rejected,
        "duplicate sync name must be rejected"
    );

    let cross_kind_rejected = registry.try_register_listen(ListenMockTool).is_err();
    assert!(
        cross_kind_rejected,
        "cross-map duplicate (squatting) must be rejected"
    );

    let registered_count = registry.len();
    assert_eq!(
        registered_count, 1,
        "rejected registrations must not be stored"
    );
}
2383
2384 // ── Non-panicking serde helpers (findings 16 & 17) ───────────────────
2385
2386 #[derive(Clone)]
2387 struct FailingStage;
2388
2389 impl Serialize for FailingStage {
2390 fn serialize<S>(&self, _serializer: S) -> core::result::Result<S::Ok, S::Error>
2391 where
2392 S: serde::Serializer,
2393 {
2394 Err(serde::ser::Error::custom("intentionally unserializable"))
2395 }
2396 }
2397
2398 impl<'de> Deserialize<'de> for FailingStage {
2399 fn deserialize<D>(_deserializer: D) -> core::result::Result<Self, D::Error>
2400 where
2401 D: serde::Deserializer<'de>,
2402 {
2403 Ok(Self)
2404 }
2405 }
2406
2407 impl ProgressStage for FailingStage {}
2408
/// A stage that cannot be serialized maps to the placeholder string.
#[test]
fn stage_to_string_falls_back_instead_of_panicking() {
    // On the async-tool progress hot path, a `ProgressStage` with a failing
    // `Serialize` impl must yield a fallback, never a panic in the turn loop.
    let rendered = stage_to_string(&FailingStage);
    assert_eq!(rendered, "<unknown_stage>");
}
2415
/// Names containing a quote and a backslash survive parsing unchanged.
#[test]
fn tool_name_from_str_round_trips_special_characters() -> Result<()> {
    // Remote MCP servers may hand us names with quotes or backslashes. They
    // have to be JSON-escaped rather than spliced in raw, or parsing breaks.
    let raw = "weird\"name\\with-escapes";

    let parsed: DynamicToolName = tool_name_from_str(raw).context("must parse escaped name")?;

    assert_eq!(parsed.as_str(), raw);
    Ok(())
}
2425
2426 // ── emit_event surfaces unbound misuse instead of silently dropping ──
2427
2428 #[tokio::test]
2429 async fn emit_event_persists_when_bound_and_is_noop_when_unbound() -> Result<()> {
2430 use crate::stores::InMemoryEventStore;
2431 use agent_sdk_foundation::types::ThreadId;
2432
2433 let store = Arc::new(InMemoryEventStore::new());
2434 let thread_id = ThreadId::new();
2435 let authority: Arc<dyn EventAuthority> = Arc::new(LocalEventAuthority::new());
2436
2437 let bound =
2438 ToolContext::new(()).with_event_store(store.clone(), thread_id.clone(), 1, authority);
2439 bound.emit_event(AgentEvent::text("m1", "hi")).await?;
2440 assert_eq!(
2441 store.event_count(&thread_id).await?,
2442 1,
2443 "a bound context persists the event"
2444 );
2445
2446 // An unbound context is a no-op (the fix also logs a warning) — it must
2447 // not silently append elsewhere or error.
2448 let unbound = ToolContext::new(());
2449 unbound.emit_event(AgentEvent::text("m2", "lost")).await?;
2450 assert_eq!(
2451 store.event_count(&thread_id).await?,
2452 1,
2453 "an unbound context changes nothing"
2454 );
2455 Ok(())
2456 }
2457}