Skip to main content

agent_sdk/
tools.rs

1//! Tool definition and registry.
2//!
3//! Tools allow the LLM to perform actions in the real world. This module provides:
4//!
5//! - [`Tool`] trait - Define custom tools the LLM can call
6//! - [`ToolName`] trait - Marker trait for strongly-typed tool names
7//! - [`PrimitiveToolName`] - Tool names for SDK's built-in tools
8//! - [`DynamicToolName`] - Tool names created at runtime (MCP bridges)
9//! - [`ToolRegistry`] - Collection of available tools
10//! - [`ToolContext`] - Context passed to tool execution
//! - [`AsyncTool`] - Tools that start long-running operations and stream progress
//! - [`ListenExecuteTool`] - Tools that listen for updates, then execute later
12//!
13//! # Implementing a Tool
14//!
15//! ```ignore
16//! use agent_sdk::{Tool, ToolContext, ToolResult, ToolTier, PrimitiveToolName};
17//!
18//! struct MyTool;
19//!
20//! // No #[async_trait] needed - Rust 1.75+ supports native async traits
21//! impl Tool<MyContext> for MyTool {
22//!     type Name = PrimitiveToolName;
23//!
24//!     fn name(&self) -> PrimitiveToolName { PrimitiveToolName::Read }
25//!     fn display_name(&self) -> &'static str { "My Tool" }
26//!     fn description(&self) -> &'static str { "Does something useful" }
27//!     fn input_schema(&self) -> Value { json!({ "type": "object" }) }
28//!     fn tier(&self) -> ToolTier { ToolTier::Observe }
29//!
30//!     async fn execute(&self, ctx: &ToolContext<MyContext>, input: Value) -> Result<ToolResult> {
31//!         Ok(ToolResult::success("Done!"))
32//!     }
33//! }
34//! ```
35
36use crate::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
37use crate::llm;
38use crate::types::{ToolOutcome, ToolResult, ToolTier};
39use anyhow::Result;
40use async_trait::async_trait;
41use futures::Stream;
42use serde::{Deserialize, Serialize, de::DeserializeOwned};
43use serde_json::Value;
44use std::collections::HashMap;
45use std::future::Future;
46use std::marker::PhantomData;
47use std::pin::Pin;
48use std::sync::Arc;
49use time::OffsetDateTime;
50use tokio::sync::mpsc;
51use tokio_util::sync::CancellationToken;
52
53// ============================================================================
54// Tool Name Types
55// ============================================================================
56
57/// Marker trait for tool names.
58///
59/// Tool names must be serializable (for storage/logging) and deserializable
60/// (for parsing from LLM responses). The string representation is derived
61/// from serde serialization.
62///
63/// # Example
64///
65/// ```ignore
66/// #[derive(Serialize, Deserialize)]
67/// #[serde(rename_all = "snake_case")]
68/// pub enum MyToolName {
69///     Read,
70///     Write,
71/// }
72///
73/// impl ToolName for MyToolName {}
74/// ```
75pub trait ToolName: Send + Sync + Serialize + DeserializeOwned + 'static {}
76
77/// Helper to get string representation of a tool name via serde.
78///
79/// Returns `"<unknown_tool>"` if serialization fails (should never happen
80/// with properly implemented `ToolName` types that use `#[derive(Serialize)]`).
81#[must_use]
82pub fn tool_name_to_string<N: ToolName>(name: &N) -> String {
83    serde_json::to_string(name)
84        .unwrap_or_else(|_| "\"<unknown_tool>\"".to_string())
85        .trim_matches('"')
86        .to_string()
87}
88
89/// Parse a tool name from string via serde.
90///
91/// # Errors
92/// Returns error if the string doesn't match a valid tool name.
93pub fn tool_name_from_str<N: ToolName>(s: &str) -> Result<N, serde_json::Error> {
94    serde_json::from_str(&format!("\"{s}\""))
95}
96
97/// Tool names for SDK's built-in primitive tools.
98#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
99#[serde(rename_all = "snake_case")]
100pub enum PrimitiveToolName {
101    Read,
102    Write,
103    Edit,
104    MultiEdit,
105    Bash,
106    Glob,
107    Grep,
108    NotebookRead,
109    NotebookEdit,
110    TodoRead,
111    TodoWrite,
112    AskUser,
113    LinkFetch,
114    WebSearch,
115}
116
117impl ToolName for PrimitiveToolName {}
118
119/// Dynamic tool name for runtime-created tools (MCP bridges, subagents).
120#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
121#[serde(transparent)]
122pub struct DynamicToolName(String);
123
124impl DynamicToolName {
125    #[must_use]
126    pub fn new(name: impl Into<String>) -> Self {
127        Self(name.into())
128    }
129
130    #[must_use]
131    pub fn as_str(&self) -> &str {
132        &self.0
133    }
134}
135
136impl ToolName for DynamicToolName {}
137
138// ============================================================================
139// Progress Stage Types (for AsyncTool)
140// ============================================================================
141
142/// Marker trait for tool progress stages (type-safe, like [`ToolName`]).
143///
144/// Progress stages are used by async tools to indicate the current phase
145/// of a long-running operation. They must be serializable for event streaming.
146///
147/// # Example
148///
149/// ```ignore
150/// #[derive(Clone, Debug, Serialize, Deserialize)]
151/// #[serde(rename_all = "snake_case")]
152/// pub enum PixTransferStage {
153///     Initiated,
154///     Processing,
155///     SentToBank,
156/// }
157///
158/// impl ProgressStage for PixTransferStage {}
159/// ```
160pub trait ProgressStage: Clone + Send + Sync + Serialize + DeserializeOwned + 'static {}
161
162/// Helper to get string representation of a progress stage via serde.
163///
164/// # Panics
165///
166/// Panics if the stage cannot be serialized to a string. This should
167/// never happen with properly implemented `ProgressStage` types.
168#[must_use]
169pub fn stage_to_string<S: ProgressStage>(stage: &S) -> String {
170    serde_json::to_string(stage)
171        .expect("ProgressStage must serialize to string")
172        .trim_matches('"')
173        .to_string()
174}
175
176/// Status update from an async tool operation.
177#[derive(Clone, Debug, Serialize)]
178pub enum ToolStatus<S: ProgressStage> {
179    /// Operation is making progress
180    Progress {
181        stage: S,
182        message: String,
183        data: Option<serde_json::Value>,
184    },
185
186    /// Operation completed successfully
187    Completed(ToolResult),
188
189    /// Operation failed
190    Failed(ToolResult),
191}
192
193/// Type-erased status for the agent loop.
194#[derive(Clone, Debug, Serialize, Deserialize)]
195pub enum ErasedToolStatus {
196    /// Operation is making progress
197    Progress {
198        stage: String,
199        message: String,
200        data: Option<serde_json::Value>,
201    },
202    /// Operation completed successfully
203    Completed(ToolResult),
204    /// Operation failed
205    Failed(ToolResult),
206}
207
208/// Update emitted from a `listen()` stream.
209///
210/// This models workflows where a runtime prepares an operation over time, and
211/// execution happens later using an operation identifier and revision.
212#[derive(Clone, Debug, Serialize, Deserialize)]
213pub enum ListenToolUpdate {
214    /// Preparation is still running and should keep listening.
215    Listening {
216        /// Opaque operation identifier used for later execute/cancel calls.
217        operation_id: String,
218        /// Monotonic revision number for optimistic concurrency.
219        revision: u64,
220        /// Human-readable status message.
221        message: String,
222        /// Optional current snapshot for UI rendering.
223        snapshot: Option<serde_json::Value>,
224        /// Optional expiration timestamp (RFC3339).
225        #[serde(with = "time::serde::rfc3339::option")]
226        expires_at: Option<OffsetDateTime>,
227    },
228
229    /// Preparation is complete and execution can be confirmed.
230    Ready {
231        /// Opaque operation identifier used for later execute/cancel calls.
232        operation_id: String,
233        /// Monotonic revision number for optimistic concurrency.
234        revision: u64,
235        /// Human-readable status message.
236        message: String,
237        /// Snapshot shown in confirmation UI.
238        snapshot: serde_json::Value,
239        /// Optional expiration timestamp (RFC3339).
240        #[serde(with = "time::serde::rfc3339::option")]
241        expires_at: Option<OffsetDateTime>,
242    },
243
244    /// Operation is no longer valid.
245    Invalidated {
246        /// Opaque operation identifier.
247        operation_id: String,
248        /// Human-readable reason.
249        message: String,
250        /// Whether caller may recover by starting a new listen operation.
251        recoverable: bool,
252    },
253}
254
255/// Reason for stopping a listen session.
256#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
257pub enum ListenStopReason {
258    /// User explicitly rejected confirmation.
259    UserRejected,
260    /// Agent policy/hook blocked execution before confirmation.
261    Blocked,
262    /// Consumer disconnected while listen stream was active.
263    StreamDisconnected,
264    /// Listen stream ended unexpectedly before terminal state.
265    StreamEnded,
266}
267
268impl<S: ProgressStage> From<ToolStatus<S>> for ErasedToolStatus {
269    fn from(status: ToolStatus<S>) -> Self {
270        match status {
271            ToolStatus::Progress {
272                stage,
273                message,
274                data,
275            } => Self::Progress {
276                stage: stage_to_string(&stage),
277                message,
278                data,
279            },
280            ToolStatus::Completed(r) => Self::Completed(r),
281            ToolStatus::Failed(r) => Self::Failed(r),
282        }
283    }
284}
285
286/// Context passed to tool execution
287pub struct ToolContext<Ctx> {
288    /// Application-specific context (e.g., `user_id`, db connection)
289    pub app: Ctx,
290    /// Tool-specific metadata
291    pub metadata: HashMap<String, Value>,
292    /// Optional channel for tools to emit events (e.g., subagent progress)
293    event_tx: Option<mpsc::Sender<AgentEventEnvelope>>,
294    /// Optional sequence counter for wrapping events in envelopes
295    event_seq: Option<SequenceCounter>,
296    /// Optional cancellation token for propagating cancellation to subtasks
297    cancel_token: Option<CancellationToken>,
298}
299
300impl<Ctx> ToolContext<Ctx> {
301    #[must_use]
302    pub fn new(app: Ctx) -> Self {
303        Self {
304            app,
305            metadata: HashMap::new(),
306            event_tx: None,
307            event_seq: None,
308            cancel_token: None,
309        }
310    }
311
312    #[must_use]
313    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
314        self.metadata.insert(key.into(), value);
315        self
316    }
317
318    /// Set the event channel and sequence counter for tools that need to emit
319    /// events during execution.
320    #[must_use]
321    pub fn with_event_tx(
322        mut self,
323        tx: mpsc::Sender<AgentEventEnvelope>,
324        seq: SequenceCounter,
325    ) -> Self {
326        self.event_tx = Some(tx);
327        self.event_seq = Some(seq);
328        self
329    }
330
331    /// Emit an event through the event channel (if set).
332    ///
333    /// The event is wrapped in an [`AgentEventEnvelope`] with a unique ID,
334    /// sequence number, and timestamp before sending.
335    ///
336    /// This uses `try_send` to avoid blocking and to ensure the future is `Send`.
337    /// The event is silently dropped if the channel is full.
338    pub fn emit_event(&self, event: AgentEvent) {
339        if let Some((tx, seq)) = self.event_tx.as_ref().zip(self.event_seq.as_ref()) {
340            let envelope = AgentEventEnvelope::wrap(event, seq);
341            let _ = tx.try_send(envelope);
342        }
343    }
344
345    /// Get a clone of the event channel sender (if set).
346    ///
347    /// This is useful for tools that spawn subprocesses (like subagents)
348    /// and need to forward events to the parent's event stream.
349    #[must_use]
350    pub fn event_tx(&self) -> Option<mpsc::Sender<AgentEventEnvelope>> {
351        self.event_tx.clone()
352    }
353
354    /// Get a clone of the sequence counter (if set).
355    ///
356    /// This is useful for tools that spawn subprocesses (like subagents)
357    /// and need to assign sequence numbers to events sent to the parent's stream.
358    #[must_use]
359    pub fn event_seq(&self) -> Option<SequenceCounter> {
360        self.event_seq.clone()
361    }
362
363    /// Set the cancellation token for propagating cancellation to subtasks.
364    #[must_use]
365    pub fn with_cancel_token(mut self, token: CancellationToken) -> Self {
366        self.cancel_token = Some(token);
367        self
368    }
369
370    /// Get the cancellation token (if set).
371    ///
372    /// Used by tools that spawn long-running subtasks (like subagents)
373    /// to propagate cancellation from the parent.
374    #[must_use]
375    pub fn cancel_token(&self) -> Option<CancellationToken> {
376        self.cancel_token.clone()
377    }
378}
379
380// ============================================================================
381// Tool Trait
382// ============================================================================
383
384/// Definition of a tool that can be called by the agent.
385///
386/// Tools have a strongly-typed `Name` associated type that determines
387/// how the tool name is serialized for LLM communication.
388///
389/// # Native Async Support
390///
391/// This trait uses Rust's native async functions in traits (stabilized in Rust 1.75).
392/// You do NOT need the `async_trait` crate to implement this trait.
393pub trait Tool<Ctx>: Send + Sync {
394    /// The type of name for this tool.
395    type Name: ToolName;
396
397    /// Returns the tool's strongly-typed name.
398    fn name(&self) -> Self::Name;
399
400    /// Human-readable display name for UI (e.g., "Read File" vs "read").
401    ///
402    /// Defaults to empty string. Override for better UX.
403    fn display_name(&self) -> &'static str;
404
405    /// Human-readable description of what the tool does.
406    fn description(&self) -> &'static str;
407
408    /// JSON schema for the tool's input parameters.
409    fn input_schema(&self) -> Value;
410
411    /// Permission tier for this tool.
412    fn tier(&self) -> ToolTier {
413        ToolTier::Observe
414    }
415
416    /// Execute the tool with the given input.
417    ///
418    /// # Errors
419    /// Returns an error if tool execution fails.
420    fn execute(
421        &self,
422        ctx: &ToolContext<Ctx>,
423        input: Value,
424    ) -> impl Future<Output = Result<ToolResult>> + Send;
425}
426
427// ============================================================================
428// AsyncTool Trait
429// ============================================================================
430
431/// A tool that performs long-running async operations.
432///
433/// `AsyncTool`s have two phases:
434/// 1. `execute()` - Start the operation (lightweight, returns quickly)
435/// 2. `check_status()` - Stream progress until completion
436///
437/// The actual work should happen externally (background task, external service)
438/// and persist results to a durable store. The tool is just an orchestrator.
439///
440/// # Example
441///
442/// ```ignore
443/// impl AsyncTool<MyCtx> for ExecutePixTransferTool {
444///     type Name = PixToolName;
445///     type Stage = PixTransferStage;
446///
447///     async fn execute(&self, ctx: &ToolContext<MyCtx>, input: Value) -> Result<ToolOutcome> {
448///         let params = parse_input(&input)?;
449///         let operation_id = ctx.app.pix_service.start_transfer(params).await?;
450///         Ok(ToolOutcome::in_progress(
451///             operation_id,
452///             format!("PIX transfer of {} initiated", params.amount),
453///         ))
454///     }
455///
456///     fn check_status(&self, ctx: &ToolContext<MyCtx>, operation_id: &str)
457///         -> impl Stream<Item = ToolStatus<PixTransferStage>> + Send
458///     {
459///         async_stream::stream! {
460///             loop {
461///                 let status = ctx.app.pix_service.get_status(operation_id).await;
462///                 match status {
463///                     PixStatus::Success { id } => {
464///                         yield ToolStatus::Completed(ToolResult::success(id));
465///                         break;
466///                     }
467///                     _ => yield ToolStatus::Progress { ... };
468///                 }
469///                 tokio::time::sleep(Duration::from_millis(500)).await;
470///             }
471///         }
472///     }
473/// }
474/// ```
475pub trait AsyncTool<Ctx>: Send + Sync {
476    /// The type of name for this tool.
477    type Name: ToolName;
478    /// The type of progress stages for this tool.
479    type Stage: ProgressStage;
480
481    /// Returns the tool's strongly-typed name.
482    fn name(&self) -> Self::Name;
483
484    /// Human-readable display name for UI.
485    fn display_name(&self) -> &'static str;
486
487    /// Human-readable description of what the tool does.
488    fn description(&self) -> &'static str;
489
490    /// JSON schema for the tool's input parameters.
491    fn input_schema(&self) -> Value;
492
493    /// Permission tier for this tool.
494    fn tier(&self) -> ToolTier {
495        ToolTier::Observe
496    }
497
498    /// Execute the tool. Returns immediately with one of:
499    /// - Success/Failed: Operation completed synchronously
500    /// - `InProgress`: Operation started, use `check_status()` to stream updates
501    ///
502    /// # Errors
503    /// Returns an error if tool execution fails.
504    fn execute(
505        &self,
506        ctx: &ToolContext<Ctx>,
507        input: Value,
508    ) -> impl Future<Output = Result<ToolOutcome>> + Send;
509
510    /// Stream status updates for an in-progress operation.
511    /// Must yield until Completed or Failed.
512    fn check_status(
513        &self,
514        ctx: &ToolContext<Ctx>,
515        operation_id: &str,
516    ) -> impl Stream<Item = ToolStatus<Self::Stage>> + Send;
517}
518
519// ============================================================================
520// ListenExecuteTool Trait
521// ============================================================================
522
523/// A tool whose runtime has two phases:
524/// 1. `listen()` - starts preparation and streams updates
525/// 2. `execute()` - performs final execution after confirmation
526///
527/// This abstraction is useful when runtime state can expire or evolve before
528/// execution (quotes, challenge windows, leases, approvals).
529///
530/// Ordering note: the agent loop consumes `listen()` updates before
531/// `AgentHooks::pre_tool_use()` runs. Hooks can therefore block `execute()`, but
532/// any side effects done during `listen()` have already happened.
533pub trait ListenExecuteTool<Ctx>: Send + Sync {
534    /// The type of name for this tool.
535    type Name: ToolName;
536
537    /// Returns the tool's strongly-typed name.
538    fn name(&self) -> Self::Name;
539
540    /// Human-readable display name for UI.
541    fn display_name(&self) -> &'static str;
542
543    /// Human-readable description of what the tool does.
544    fn description(&self) -> &'static str;
545
546    /// JSON schema for the tool's input parameters.
547    fn input_schema(&self) -> Value;
548
549    /// Permission tier for this tool.
550    fn tier(&self) -> ToolTier {
551        ToolTier::Confirm
552    }
553
554    /// Start and stream runtime preparation updates.
555    fn listen(
556        &self,
557        ctx: &ToolContext<Ctx>,
558        input: Value,
559    ) -> impl Stream<Item = ListenToolUpdate> + Send;
560
561    /// Execute using operation ID and optimistic concurrency revision.
562    ///
563    /// # Errors
564    /// Returns an error if execution fails or revision is stale.
565    fn execute(
566        &self,
567        ctx: &ToolContext<Ctx>,
568        operation_id: &str,
569        expected_revision: u64,
570    ) -> impl Future<Output = Result<ToolResult>> + Send;
571
572    /// Stop a listen operation (best effort).
573    ///
574    /// # Errors
575    /// Returns an error if cancellation fails.
576    fn cancel(
577        &self,
578        _ctx: &ToolContext<Ctx>,
579        _operation_id: &str,
580        _reason: ListenStopReason,
581    ) -> impl Future<Output = Result<()>> + Send {
582        async { Ok(()) }
583    }
584}
585
586// ============================================================================
587// Type-Erased Tool (for Registry)
588// ============================================================================
589
590/// Type-erased tool trait for registry storage.
591///
592/// This allows tools with different `Name` associated types to be stored
593/// in the same registry by erasing the type information.
594///
595/// # Example
596///
597/// ```ignore
598/// for tool in registry.all() {
599///     println!("Tool: {} - {}", tool.name_str(), tool.description());
600/// }
601/// ```
602#[async_trait]
603pub trait ErasedTool<Ctx>: Send + Sync {
604    /// Get the tool name as a string.
605    fn name_str(&self) -> &str;
606    /// Get a human-friendly display name for the tool.
607    fn display_name(&self) -> &'static str;
608    /// Get the tool description.
609    fn description(&self) -> &'static str;
610    /// Get the JSON schema for tool inputs.
611    fn input_schema(&self) -> Value;
612    /// Get the tool's permission tier.
613    fn tier(&self) -> ToolTier;
614    /// Execute the tool with the given input.
615    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult>;
616}
617
618/// Wrapper that erases the Name associated type from a Tool.
619struct ToolWrapper<T, Ctx>
620where
621    T: Tool<Ctx>,
622{
623    inner: T,
624    name_cache: String,
625    _marker: PhantomData<Ctx>,
626}
627
628impl<T, Ctx> ToolWrapper<T, Ctx>
629where
630    T: Tool<Ctx>,
631{
632    fn new(tool: T) -> Self {
633        let name_cache = tool_name_to_string(&tool.name());
634        Self {
635            inner: tool,
636            name_cache,
637            _marker: PhantomData,
638        }
639    }
640}
641
642#[async_trait]
643impl<T, Ctx> ErasedTool<Ctx> for ToolWrapper<T, Ctx>
644where
645    T: Tool<Ctx> + 'static,
646    Ctx: Send + Sync + 'static,
647{
648    fn name_str(&self) -> &str {
649        &self.name_cache
650    }
651
652    fn display_name(&self) -> &'static str {
653        self.inner.display_name()
654    }
655
656    fn description(&self) -> &'static str {
657        self.inner.description()
658    }
659
660    fn input_schema(&self) -> Value {
661        self.inner.input_schema()
662    }
663
664    fn tier(&self) -> ToolTier {
665        self.inner.tier()
666    }
667
668    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
669        self.inner.execute(ctx, input).await
670    }
671}
672
673// ============================================================================
674// Type-Erased AsyncTool (for Registry)
675// ============================================================================
676
677/// Type-erased async tool trait for registry storage.
678///
679/// This allows async tools with different `Name` and `Stage` associated types
680/// to be stored in the same registry by erasing the type information.
681#[async_trait]
682pub trait ErasedAsyncTool<Ctx>: Send + Sync {
683    /// Get the tool name as a string.
684    fn name_str(&self) -> &str;
685    /// Get a human-friendly display name for the tool.
686    fn display_name(&self) -> &'static str;
687    /// Get the tool description.
688    fn description(&self) -> &'static str;
689    /// Get the JSON schema for tool inputs.
690    fn input_schema(&self) -> Value;
691    /// Get the tool's permission tier.
692    fn tier(&self) -> ToolTier;
693    /// Execute the tool with the given input.
694    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome>;
695    /// Stream status updates for an in-progress operation (type-erased).
696    fn check_status_stream<'a>(
697        &'a self,
698        ctx: &'a ToolContext<Ctx>,
699        operation_id: &'a str,
700    ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>>;
701}
702
703/// Wrapper that erases the Name and Stage associated types from an [`AsyncTool`].
704struct AsyncToolWrapper<T, Ctx>
705where
706    T: AsyncTool<Ctx>,
707{
708    inner: T,
709    name_cache: String,
710    _marker: PhantomData<Ctx>,
711}
712
713impl<T, Ctx> AsyncToolWrapper<T, Ctx>
714where
715    T: AsyncTool<Ctx>,
716{
717    fn new(tool: T) -> Self {
718        let name_cache = tool_name_to_string(&tool.name());
719        Self {
720            inner: tool,
721            name_cache,
722            _marker: PhantomData,
723        }
724    }
725}
726
727#[async_trait]
728impl<T, Ctx> ErasedAsyncTool<Ctx> for AsyncToolWrapper<T, Ctx>
729where
730    T: AsyncTool<Ctx> + 'static,
731    Ctx: Send + Sync + 'static,
732{
733    fn name_str(&self) -> &str {
734        &self.name_cache
735    }
736
737    fn display_name(&self) -> &'static str {
738        self.inner.display_name()
739    }
740
741    fn description(&self) -> &'static str {
742        self.inner.description()
743    }
744
745    fn input_schema(&self) -> Value {
746        self.inner.input_schema()
747    }
748
749    fn tier(&self) -> ToolTier {
750        self.inner.tier()
751    }
752
753    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome> {
754        self.inner.execute(ctx, input).await
755    }
756
757    fn check_status_stream<'a>(
758        &'a self,
759        ctx: &'a ToolContext<Ctx>,
760        operation_id: &'a str,
761    ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>> {
762        use futures::StreamExt;
763        let stream = self.inner.check_status(ctx, operation_id);
764        Box::pin(stream.map(ErasedToolStatus::from))
765    }
766}
767
768// ============================================================================
769// Type-Erased ListenExecuteTool (for Registry)
770// ============================================================================
771
772/// Type-erased listen/execute tool trait for registry storage.
773#[async_trait]
774pub trait ErasedListenTool<Ctx>: Send + Sync {
775    /// Get the tool name as a string.
776    fn name_str(&self) -> &str;
777    /// Get a human-friendly display name for the tool.
778    fn display_name(&self) -> &'static str;
779    /// Get the tool description.
780    fn description(&self) -> &'static str;
781    /// Get the JSON schema for tool inputs.
782    fn input_schema(&self) -> Value;
783    /// Get the tool's permission tier.
784    fn tier(&self) -> ToolTier;
785    /// Start listen stream.
786    fn listen_stream<'a>(
787        &'a self,
788        ctx: &'a ToolContext<Ctx>,
789        input: Value,
790    ) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>>;
791    /// Execute using a prepared operation.
792    async fn execute(
793        &self,
794        ctx: &ToolContext<Ctx>,
795        operation_id: &str,
796        expected_revision: u64,
797    ) -> Result<ToolResult>;
798    /// Cancel operation.
799    async fn cancel(
800        &self,
801        ctx: &ToolContext<Ctx>,
802        operation_id: &str,
803        reason: ListenStopReason,
804    ) -> Result<()>;
805}
806
807/// Wrapper that erases the Name associated type from a [`ListenExecuteTool`].
808struct ListenToolWrapper<T, Ctx>
809where
810    T: ListenExecuteTool<Ctx>,
811{
812    inner: T,
813    name_cache: String,
814    _marker: PhantomData<Ctx>,
815}
816
817impl<T, Ctx> ListenToolWrapper<T, Ctx>
818where
819    T: ListenExecuteTool<Ctx>,
820{
821    fn new(tool: T) -> Self {
822        let name_cache = tool_name_to_string(&tool.name());
823        Self {
824            inner: tool,
825            name_cache,
826            _marker: PhantomData,
827        }
828    }
829}
830
831#[async_trait]
832impl<T, Ctx> ErasedListenTool<Ctx> for ListenToolWrapper<T, Ctx>
833where
834    T: ListenExecuteTool<Ctx> + 'static,
835    Ctx: Send + Sync + 'static,
836{
837    fn name_str(&self) -> &str {
838        &self.name_cache
839    }
840
841    fn display_name(&self) -> &'static str {
842        self.inner.display_name()
843    }
844
845    fn description(&self) -> &'static str {
846        self.inner.description()
847    }
848
849    fn input_schema(&self) -> Value {
850        self.inner.input_schema()
851    }
852
853    fn tier(&self) -> ToolTier {
854        self.inner.tier()
855    }
856
857    fn listen_stream<'a>(
858        &'a self,
859        ctx: &'a ToolContext<Ctx>,
860        input: Value,
861    ) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>> {
862        let stream = self.inner.listen(ctx, input);
863        Box::pin(stream)
864    }
865
866    async fn execute(
867        &self,
868        ctx: &ToolContext<Ctx>,
869        operation_id: &str,
870        expected_revision: u64,
871    ) -> Result<ToolResult> {
872        self.inner
873            .execute(ctx, operation_id, expected_revision)
874            .await
875    }
876
877    async fn cancel(
878        &self,
879        ctx: &ToolContext<Ctx>,
880        operation_id: &str,
881        reason: ListenStopReason,
882    ) -> Result<()> {
883        self.inner.cancel(ctx, operation_id, reason).await
884    }
885}
886
887/// Registry of available tools.
888///
889/// Tools are stored with their names erased to allow different `Name` types
890/// in the same registry. The registry uses string-based lookup for LLM
891/// compatibility.
892///
893/// Supports both synchronous [`Tool`]s and asynchronous [`AsyncTool`]s.
894pub struct ToolRegistry<Ctx> {
895    tools: HashMap<String, Arc<dyn ErasedTool<Ctx>>>,
896    async_tools: HashMap<String, Arc<dyn ErasedAsyncTool<Ctx>>>,
897    listen_tools: HashMap<String, Arc<dyn ErasedListenTool<Ctx>>>,
898}
899
900impl<Ctx> Clone for ToolRegistry<Ctx> {
901    fn clone(&self) -> Self {
902        Self {
903            tools: self.tools.clone(),
904            async_tools: self.async_tools.clone(),
905            listen_tools: self.listen_tools.clone(),
906        }
907    }
908}
909
910impl<Ctx: Send + Sync + 'static> Default for ToolRegistry<Ctx> {
911    fn default() -> Self {
912        Self::new()
913    }
914}
915
916impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
917    #[must_use]
918    pub fn new() -> Self {
919        Self {
920            tools: HashMap::new(),
921            async_tools: HashMap::new(),
922            listen_tools: HashMap::new(),
923        }
924    }
925
926    /// Register a synchronous tool in the registry.
927    ///
928    /// The tool's name is converted to a string via serde serialization
929    /// and used as the lookup key.
930    pub fn register<T>(&mut self, tool: T) -> &mut Self
931    where
932        T: Tool<Ctx> + 'static,
933    {
934        let wrapper = ToolWrapper::new(tool);
935        let name = wrapper.name_str().to_string();
936        self.tools.insert(name, Arc::new(wrapper));
937        self
938    }
939
940    /// Register an async tool in the registry.
941    ///
942    /// Async tools have two phases: execute (lightweight, starts operation)
943    /// and `check_status` (streams progress until completion).
944    pub fn register_async<T>(&mut self, tool: T) -> &mut Self
945    where
946        T: AsyncTool<Ctx> + 'static,
947    {
948        let wrapper = AsyncToolWrapper::new(tool);
949        let name = wrapper.name_str().to_string();
950        self.async_tools.insert(name, Arc::new(wrapper));
951        self
952    }
953
954    /// Register a listen/execute tool in the registry.
955    ///
956    /// Listen/execute tools start by streaming updates via `listen()`, then run
957    /// final execution with `execute()` once confirmed.
958    pub fn register_listen<T>(&mut self, tool: T) -> &mut Self
959    where
960        T: ListenExecuteTool<Ctx> + 'static,
961    {
962        let wrapper = ListenToolWrapper::new(tool);
963        let name = wrapper.name_str().to_string();
964        self.listen_tools.insert(name, Arc::new(wrapper));
965        self
966    }
967
968    /// Get a synchronous tool by name.
969    #[must_use]
970    pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool<Ctx>>> {
971        self.tools.get(name)
972    }
973
974    /// Get an async tool by name.
975    #[must_use]
976    pub fn get_async(&self, name: &str) -> Option<&Arc<dyn ErasedAsyncTool<Ctx>>> {
977        self.async_tools.get(name)
978    }
979
980    /// Get a listen/execute tool by name.
981    #[must_use]
982    pub fn get_listen(&self, name: &str) -> Option<&Arc<dyn ErasedListenTool<Ctx>>> {
983        self.listen_tools.get(name)
984    }
985
986    /// Check if a tool name refers to an async tool.
987    #[must_use]
988    pub fn is_async(&self, name: &str) -> bool {
989        self.async_tools.contains_key(name)
990    }
991
992    /// Check if a tool name refers to a listen/execute tool.
993    #[must_use]
994    pub fn is_listen(&self, name: &str) -> bool {
995        self.listen_tools.contains_key(name)
996    }
997
998    /// Get all registered synchronous tools.
999    pub fn all(&self) -> impl Iterator<Item = &Arc<dyn ErasedTool<Ctx>>> {
1000        self.tools.values()
1001    }
1002
1003    /// Get all registered async tools.
1004    pub fn all_async(&self) -> impl Iterator<Item = &Arc<dyn ErasedAsyncTool<Ctx>>> {
1005        self.async_tools.values()
1006    }
1007
1008    /// Get all registered listen/execute tools.
1009    pub fn all_listen(&self) -> impl Iterator<Item = &Arc<dyn ErasedListenTool<Ctx>>> {
1010        self.listen_tools.values()
1011    }
1012
1013    /// Get the number of registered tools (sync + async).
1014    #[must_use]
1015    pub fn len(&self) -> usize {
1016        self.tools.len() + self.async_tools.len() + self.listen_tools.len()
1017    }
1018
1019    /// Check if the registry is empty.
1020    #[must_use]
1021    pub fn is_empty(&self) -> bool {
1022        self.tools.is_empty() && self.async_tools.is_empty() && self.listen_tools.is_empty()
1023    }
1024
1025    /// Filter tools by a predicate.
1026    ///
1027    /// Removes tools for which the predicate returns false.
1028    /// The predicate receives the tool name.
1029    /// Applies to both sync and async tools.
1030    ///
1031    /// # Example
1032    ///
1033    /// ```ignore
1034    /// registry.filter(|name| name != "bash");
1035    /// ```
1036    pub fn filter<F>(&mut self, predicate: F)
1037    where
1038        F: Fn(&str) -> bool,
1039    {
1040        self.tools.retain(|name, _| predicate(name));
1041        self.async_tools.retain(|name, _| predicate(name));
1042        self.listen_tools.retain(|name, _| predicate(name));
1043    }
1044
1045    /// Convert all tools (sync + async) to LLM tool definitions.
1046    #[must_use]
1047    pub fn to_llm_tools(&self) -> Vec<llm::Tool> {
1048        let mut tools: Vec<_> = self
1049            .tools
1050            .values()
1051            .map(|tool| llm::Tool {
1052                name: tool.name_str().to_string(),
1053                description: tool.description().to_string(),
1054                input_schema: tool.input_schema(),
1055            })
1056            .collect();
1057
1058        tools.extend(self.async_tools.values().map(|tool| llm::Tool {
1059            name: tool.name_str().to_string(),
1060            description: tool.description().to_string(),
1061            input_schema: tool.input_schema(),
1062        }));
1063
1064        tools.extend(self.listen_tools.values().map(|tool| llm::Tool {
1065            name: tool.name_str().to_string(),
1066            description: tool.description().to_string(),
1067            input_schema: tool.input_schema(),
1068        }));
1069
1070        tools
1071    }
1072}
1073
#[cfg(test)]
mod tests {
    use super::*;

    // Test tool name enum for tests
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
    #[serde(rename_all = "snake_case")]
    enum TestToolName {
        MockTool,
        AnotherTool,
    }

    impl ToolName for TestToolName {}

    struct MockTool;

    impl Tool<()> for MockTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }

        fn display_name(&self) -> &'static str {
            "Mock Tool"
        }

        fn description(&self) -> &'static str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                }
            })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message");
            Ok(ToolResult::success(format!("Received: {message}")))
        }
    }

    #[test]
    fn test_tool_name_serialization() {
        let name = TestToolName::MockTool;
        assert_eq!(tool_name_to_string(&name), "mock_tool");

        let parsed: TestToolName = tool_name_from_str("mock_tool").unwrap();
        assert_eq!(parsed, TestToolName::MockTool);
    }

    #[test]
    fn test_dynamic_tool_name() {
        let name = DynamicToolName::new("my_mcp_tool");
        assert_eq!(tool_name_to_string(&name), "my_mcp_tool");
        assert_eq!(name.as_str(), "my_mcp_tool");
    }

    #[test]
    fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_to_llm_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        let llm_tools = registry.to_llm_tools();
        assert_eq!(llm_tools.len(), 1);
        assert_eq!(llm_tools[0].name, "mock_tool");
    }

    struct AnotherTool;

    impl Tool<()> for AnotherTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::AnotherTool
        }

        fn display_name(&self) -> &'static str {
            "Another Tool"
        }

        fn description(&self) -> &'static str {
            "Another tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, _input: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("Done"))
        }
    }

    #[test]
    fn test_filter_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);

        assert_eq!(registry.len(), 2);

        // Filter out mock_tool
        registry.filter(|name| name != "mock_tool");

        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_none());
        assert!(registry.get("another_tool").is_some());
    }

    #[test]
    fn test_filter_tools_keep_all() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);

        registry.filter(|_| true);

        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_filter_tools_remove_all() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);

        registry.filter(|_| false);

        assert!(registry.is_empty());
    }

    #[test]
    fn test_display_name() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        let tool = registry.get("mock_tool").unwrap();
        assert_eq!(tool.display_name(), "Mock Tool");
    }

    // NOTE: deliberately reuses `TestToolName::MockTool`, so it serializes to
    // "mock_tool" just like `MockTool`. Tests that put both in one registry
    // would be exercising a name collision.
    struct ListenMockTool;

    impl ListenExecuteTool<()> for ListenMockTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }

        fn display_name(&self) -> &'static str {
            "Listen Mock Tool"
        }

        fn description(&self) -> &'static str {
            "A listen/execute mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        fn listen(
            &self,
            _ctx: &ToolContext<()>,
            _input: Value,
        ) -> impl futures::Stream<Item = ListenToolUpdate> + Send {
            futures::stream::iter(vec![ListenToolUpdate::Ready {
                operation_id: "op_1".to_string(),
                revision: 1,
                message: "ready".to_string(),
                snapshot: serde_json::json!({"ok": true}),
                expires_at: None,
            }])
        }

        async fn execute(
            &self,
            _ctx: &ToolContext<()>,
            _operation_id: &str,
            _expected_revision: u64,
        ) -> Result<ToolResult> {
            Ok(ToolResult::success("Executed"))
        }
    }

    #[test]
    fn test_listen_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register_listen(ListenMockTool);

        assert_eq!(registry.len(), 1);
        assert!(registry.get_listen("mock_tool").is_some());
        assert!(registry.is_listen("mock_tool"));
    }

    #[test]
    fn test_filter_listen_tools() {
        let mut registry = ToolRegistry::new();
        registry.register_listen(ListenMockTool);

        registry.filter(|name| name != "mock_tool");

        assert!(registry.is_empty());
        assert!(!registry.is_listen("mock_tool"));
    }

    #[test]
    fn test_to_llm_tools_includes_listen_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(AnotherTool);
        registry.register_listen(ListenMockTool);

        // Sort locally so the assertion does not depend on the registry's ordering.
        let mut names: Vec<String> = registry
            .to_llm_tools()
            .into_iter()
            .map(|tool| tool.name)
            .collect();
        names.sort();

        assert_eq!(names, ["another_tool", "mock_tool"]);
    }

    #[test]
    fn test_clone_is_independent() {
        let mut original = ToolRegistry::new();
        original.register(MockTool);
        original.register(AnotherTool);

        let mut copy = original.clone();
        copy.filter(|name| name == "another_tool");

        // Filtering the clone must not affect the original's maps.
        assert_eq!(original.len(), 2);
        assert_eq!(copy.len(), 1);
        assert!(copy.get("another_tool").is_some());
        assert!(original.get("mock_tool").is_some());
    }
}