Skip to main content

agent_sdk/
tools.rs

1//! Tool definition and registry.
2//!
3//! Tools allow the LLM to perform actions in the real world. This module provides:
4//!
5//! - [`Tool`] trait - Define custom tools the LLM can call
6//! - [`ToolName`] trait - Marker trait for strongly-typed tool names
7//! - [`PrimitiveToolName`] - Tool names for SDK's built-in tools
8//! - [`DynamicToolName`] - Tool names created at runtime (MCP bridges)
9//! - [`ToolRegistry`] - Collection of available tools
10//! - [`ToolContext`] - Context passed to tool execution
//! - [`AsyncTool`] - Tools that start a long-running operation and stream its progress
//! - [`ListenExecuteTool`] - Tools that listen for updates, then execute later
12//!
13//! # Implementing a Tool
14//!
15//! ```ignore
16//! use agent_sdk::{Tool, ToolContext, ToolResult, ToolTier, PrimitiveToolName};
17//!
18//! struct MyTool;
19//!
20//! // No #[async_trait] needed - Rust 1.75+ supports native async traits
21//! impl Tool<MyContext> for MyTool {
22//!     type Name = PrimitiveToolName;
23//!
24//!     fn name(&self) -> PrimitiveToolName { PrimitiveToolName::Read }
25//!     fn display_name(&self) -> &'static str { "My Tool" }
26//!     fn description(&self) -> &'static str { "Does something useful" }
27//!     fn input_schema(&self) -> Value { json!({ "type": "object" }) }
28//!     fn tier(&self) -> ToolTier { ToolTier::Observe }
29//!
30//!     async fn execute(&self, ctx: &ToolContext<MyContext>, input: Value) -> Result<ToolResult> {
31//!         Ok(ToolResult::success("Done!"))
32//!     }
33//! }
34//! ```
35
36use crate::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
37use crate::llm;
38use crate::types::{ToolOutcome, ToolResult, ToolTier};
39use anyhow::Result;
40use async_trait::async_trait;
41use futures::Stream;
42use serde::{Deserialize, Serialize, de::DeserializeOwned};
43use serde_json::Value;
44use std::collections::HashMap;
45use std::future::Future;
46use std::marker::PhantomData;
47use std::pin::Pin;
48use std::sync::Arc;
49use time::OffsetDateTime;
50use tokio::sync::mpsc;
51
52// ============================================================================
53// Tool Name Types
54// ============================================================================
55
56/// Marker trait for tool names.
57///
58/// Tool names must be serializable (for storage/logging) and deserializable
59/// (for parsing from LLM responses). The string representation is derived
60/// from serde serialization.
61///
62/// # Example
63///
64/// ```ignore
65/// #[derive(Serialize, Deserialize)]
66/// #[serde(rename_all = "snake_case")]
67/// pub enum MyToolName {
68///     Read,
69///     Write,
70/// }
71///
72/// impl ToolName for MyToolName {}
73/// ```
74pub trait ToolName: Send + Sync + Serialize + DeserializeOwned + 'static {}
75
76/// Helper to get string representation of a tool name via serde.
77///
78/// # Panics
79///
80/// Panics if the tool name cannot be serialized to a string. This should
81/// never happen with properly implemented `ToolName` types that use
82/// `#[derive(Serialize)]`.
83#[must_use]
84pub fn tool_name_to_string<N: ToolName>(name: &N) -> String {
85    serde_json::to_string(name)
86        .expect("ToolName must serialize to string")
87        .trim_matches('"')
88        .to_string()
89}
90
91/// Parse a tool name from string via serde.
92///
93/// # Errors
94/// Returns error if the string doesn't match a valid tool name.
95pub fn tool_name_from_str<N: ToolName>(s: &str) -> Result<N, serde_json::Error> {
96    serde_json::from_str(&format!("\"{s}\""))
97}
98
99/// Tool names for SDK's built-in primitive tools.
100#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
101#[serde(rename_all = "snake_case")]
102pub enum PrimitiveToolName {
103    Read,
104    Write,
105    Edit,
106    MultiEdit,
107    Bash,
108    Glob,
109    Grep,
110    NotebookRead,
111    NotebookEdit,
112    TodoRead,
113    TodoWrite,
114    AskUser,
115    LinkFetch,
116    WebSearch,
117}
118
119impl ToolName for PrimitiveToolName {}
120
121/// Dynamic tool name for runtime-created tools (MCP bridges, subagents).
122#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
123#[serde(transparent)]
124pub struct DynamicToolName(String);
125
126impl DynamicToolName {
127    #[must_use]
128    pub fn new(name: impl Into<String>) -> Self {
129        Self(name.into())
130    }
131
132    #[must_use]
133    pub fn as_str(&self) -> &str {
134        &self.0
135    }
136}
137
138impl ToolName for DynamicToolName {}
139
140// ============================================================================
141// Progress Stage Types (for AsyncTool)
142// ============================================================================
143
144/// Marker trait for tool progress stages (type-safe, like [`ToolName`]).
145///
146/// Progress stages are used by async tools to indicate the current phase
147/// of a long-running operation. They must be serializable for event streaming.
148///
149/// # Example
150///
151/// ```ignore
152/// #[derive(Clone, Debug, Serialize, Deserialize)]
153/// #[serde(rename_all = "snake_case")]
154/// pub enum PixTransferStage {
155///     Initiated,
156///     Processing,
157///     SentToBank,
158/// }
159///
160/// impl ProgressStage for PixTransferStage {}
161/// ```
162pub trait ProgressStage: Clone + Send + Sync + Serialize + DeserializeOwned + 'static {}
163
164/// Helper to get string representation of a progress stage via serde.
165///
166/// # Panics
167///
168/// Panics if the stage cannot be serialized to a string. This should
169/// never happen with properly implemented `ProgressStage` types.
170#[must_use]
171pub fn stage_to_string<S: ProgressStage>(stage: &S) -> String {
172    serde_json::to_string(stage)
173        .expect("ProgressStage must serialize to string")
174        .trim_matches('"')
175        .to_string()
176}
177
178/// Status update from an async tool operation.
179#[derive(Clone, Debug, Serialize)]
180pub enum ToolStatus<S: ProgressStage> {
181    /// Operation is making progress
182    Progress {
183        stage: S,
184        message: String,
185        data: Option<serde_json::Value>,
186    },
187
188    /// Operation completed successfully
189    Completed(ToolResult),
190
191    /// Operation failed
192    Failed(ToolResult),
193}
194
195/// Type-erased status for the agent loop.
196#[derive(Clone, Debug, Serialize, Deserialize)]
197pub enum ErasedToolStatus {
198    /// Operation is making progress
199    Progress {
200        stage: String,
201        message: String,
202        data: Option<serde_json::Value>,
203    },
204    /// Operation completed successfully
205    Completed(ToolResult),
206    /// Operation failed
207    Failed(ToolResult),
208}
209
210/// Update emitted from a `listen()` stream.
211///
212/// This models workflows where a runtime prepares an operation over time, and
213/// execution happens later using an operation identifier and revision.
214#[derive(Clone, Debug, Serialize, Deserialize)]
215pub enum ListenToolUpdate {
216    /// Preparation is still running and should keep listening.
217    Listening {
218        /// Opaque operation identifier used for later execute/cancel calls.
219        operation_id: String,
220        /// Monotonic revision number for optimistic concurrency.
221        revision: u64,
222        /// Human-readable status message.
223        message: String,
224        /// Optional current snapshot for UI rendering.
225        snapshot: Option<serde_json::Value>,
226        /// Optional expiration timestamp (RFC3339).
227        #[serde(with = "time::serde::rfc3339::option")]
228        expires_at: Option<OffsetDateTime>,
229    },
230
231    /// Preparation is complete and execution can be confirmed.
232    Ready {
233        /// Opaque operation identifier used for later execute/cancel calls.
234        operation_id: String,
235        /// Monotonic revision number for optimistic concurrency.
236        revision: u64,
237        /// Human-readable status message.
238        message: String,
239        /// Snapshot shown in confirmation UI.
240        snapshot: serde_json::Value,
241        /// Optional expiration timestamp (RFC3339).
242        #[serde(with = "time::serde::rfc3339::option")]
243        expires_at: Option<OffsetDateTime>,
244    },
245
246    /// Operation is no longer valid.
247    Invalidated {
248        /// Opaque operation identifier.
249        operation_id: String,
250        /// Human-readable reason.
251        message: String,
252        /// Whether caller may recover by starting a new listen operation.
253        recoverable: bool,
254    },
255}
256
257/// Reason for stopping a listen session.
258#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
259pub enum ListenStopReason {
260    /// User explicitly rejected confirmation.
261    UserRejected,
262    /// Agent policy/hook blocked execution before confirmation.
263    Blocked,
264    /// Consumer disconnected while listen stream was active.
265    StreamDisconnected,
266    /// Listen stream ended unexpectedly before terminal state.
267    StreamEnded,
268}
269
270impl<S: ProgressStage> From<ToolStatus<S>> for ErasedToolStatus {
271    fn from(status: ToolStatus<S>) -> Self {
272        match status {
273            ToolStatus::Progress {
274                stage,
275                message,
276                data,
277            } => Self::Progress {
278                stage: stage_to_string(&stage),
279                message,
280                data,
281            },
282            ToolStatus::Completed(r) => Self::Completed(r),
283            ToolStatus::Failed(r) => Self::Failed(r),
284        }
285    }
286}
287
288/// Context passed to tool execution
289pub struct ToolContext<Ctx> {
290    /// Application-specific context (e.g., `user_id`, db connection)
291    pub app: Ctx,
292    /// Tool-specific metadata
293    pub metadata: HashMap<String, Value>,
294    /// Optional channel for tools to emit events (e.g., subagent progress)
295    event_tx: Option<mpsc::Sender<AgentEventEnvelope>>,
296    /// Optional sequence counter for wrapping events in envelopes
297    event_seq: Option<SequenceCounter>,
298}
299
300impl<Ctx> ToolContext<Ctx> {
301    #[must_use]
302    pub fn new(app: Ctx) -> Self {
303        Self {
304            app,
305            metadata: HashMap::new(),
306            event_tx: None,
307            event_seq: None,
308        }
309    }
310
311    #[must_use]
312    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
313        self.metadata.insert(key.into(), value);
314        self
315    }
316
317    /// Set the event channel and sequence counter for tools that need to emit
318    /// events during execution.
319    #[must_use]
320    pub fn with_event_tx(
321        mut self,
322        tx: mpsc::Sender<AgentEventEnvelope>,
323        seq: SequenceCounter,
324    ) -> Self {
325        self.event_tx = Some(tx);
326        self.event_seq = Some(seq);
327        self
328    }
329
330    /// Emit an event through the event channel (if set).
331    ///
332    /// The event is wrapped in an [`AgentEventEnvelope`] with a unique ID,
333    /// sequence number, and timestamp before sending.
334    ///
335    /// This uses `try_send` to avoid blocking and to ensure the future is `Send`.
336    /// The event is silently dropped if the channel is full.
337    pub fn emit_event(&self, event: AgentEvent) {
338        if let Some((tx, seq)) = self.event_tx.as_ref().zip(self.event_seq.as_ref()) {
339            let envelope = AgentEventEnvelope::wrap(event, seq);
340            let _ = tx.try_send(envelope);
341        }
342    }
343
344    /// Get a clone of the event channel sender (if set).
345    ///
346    /// This is useful for tools that spawn subprocesses (like subagents)
347    /// and need to forward events to the parent's event stream.
348    #[must_use]
349    pub fn event_tx(&self) -> Option<mpsc::Sender<AgentEventEnvelope>> {
350        self.event_tx.clone()
351    }
352
353    /// Get a clone of the sequence counter (if set).
354    ///
355    /// This is useful for tools that spawn subprocesses (like subagents)
356    /// and need to assign sequence numbers to events sent to the parent's stream.
357    #[must_use]
358    pub fn event_seq(&self) -> Option<SequenceCounter> {
359        self.event_seq.clone()
360    }
361}
362
363// ============================================================================
364// Tool Trait
365// ============================================================================
366
367/// Definition of a tool that can be called by the agent.
368///
369/// Tools have a strongly-typed `Name` associated type that determines
370/// how the tool name is serialized for LLM communication.
371///
372/// # Native Async Support
373///
374/// This trait uses Rust's native async functions in traits (stabilized in Rust 1.75).
375/// You do NOT need the `async_trait` crate to implement this trait.
376pub trait Tool<Ctx>: Send + Sync {
377    /// The type of name for this tool.
378    type Name: ToolName;
379
380    /// Returns the tool's strongly-typed name.
381    fn name(&self) -> Self::Name;
382
383    /// Human-readable display name for UI (e.g., "Read File" vs "read").
384    ///
385    /// Defaults to empty string. Override for better UX.
386    fn display_name(&self) -> &'static str;
387
388    /// Human-readable description of what the tool does.
389    fn description(&self) -> &'static str;
390
391    /// JSON schema for the tool's input parameters.
392    fn input_schema(&self) -> Value;
393
394    /// Permission tier for this tool.
395    fn tier(&self) -> ToolTier {
396        ToolTier::Observe
397    }
398
399    /// Execute the tool with the given input.
400    ///
401    /// # Errors
402    /// Returns an error if tool execution fails.
403    fn execute(
404        &self,
405        ctx: &ToolContext<Ctx>,
406        input: Value,
407    ) -> impl Future<Output = Result<ToolResult>> + Send;
408}
409
410// ============================================================================
411// AsyncTool Trait
412// ============================================================================
413
414/// A tool that performs long-running async operations.
415///
416/// `AsyncTool`s have two phases:
417/// 1. `execute()` - Start the operation (lightweight, returns quickly)
418/// 2. `check_status()` - Stream progress until completion
419///
420/// The actual work should happen externally (background task, external service)
421/// and persist results to a durable store. The tool is just an orchestrator.
422///
423/// # Example
424///
425/// ```ignore
426/// impl AsyncTool<MyCtx> for ExecutePixTransferTool {
427///     type Name = PixToolName;
428///     type Stage = PixTransferStage;
429///
430///     async fn execute(&self, ctx: &ToolContext<MyCtx>, input: Value) -> Result<ToolOutcome> {
431///         let params = parse_input(&input)?;
432///         let operation_id = ctx.app.pix_service.start_transfer(params).await?;
433///         Ok(ToolOutcome::in_progress(
434///             operation_id,
435///             format!("PIX transfer of {} initiated", params.amount),
436///         ))
437///     }
438///
439///     fn check_status(&self, ctx: &ToolContext<MyCtx>, operation_id: &str)
440///         -> impl Stream<Item = ToolStatus<PixTransferStage>> + Send
441///     {
442///         async_stream::stream! {
443///             loop {
444///                 let status = ctx.app.pix_service.get_status(operation_id).await;
445///                 match status {
446///                     PixStatus::Success { id } => {
447///                         yield ToolStatus::Completed(ToolResult::success(id));
448///                         break;
449///                     }
450///                     _ => yield ToolStatus::Progress { ... };
451///                 }
452///                 tokio::time::sleep(Duration::from_millis(500)).await;
453///             }
454///         }
455///     }
456/// }
457/// ```
458pub trait AsyncTool<Ctx>: Send + Sync {
459    /// The type of name for this tool.
460    type Name: ToolName;
461    /// The type of progress stages for this tool.
462    type Stage: ProgressStage;
463
464    /// Returns the tool's strongly-typed name.
465    fn name(&self) -> Self::Name;
466
467    /// Human-readable display name for UI.
468    fn display_name(&self) -> &'static str;
469
470    /// Human-readable description of what the tool does.
471    fn description(&self) -> &'static str;
472
473    /// JSON schema for the tool's input parameters.
474    fn input_schema(&self) -> Value;
475
476    /// Permission tier for this tool.
477    fn tier(&self) -> ToolTier {
478        ToolTier::Observe
479    }
480
481    /// Execute the tool. Returns immediately with one of:
482    /// - Success/Failed: Operation completed synchronously
483    /// - `InProgress`: Operation started, use `check_status()` to stream updates
484    ///
485    /// # Errors
486    /// Returns an error if tool execution fails.
487    fn execute(
488        &self,
489        ctx: &ToolContext<Ctx>,
490        input: Value,
491    ) -> impl Future<Output = Result<ToolOutcome>> + Send;
492
493    /// Stream status updates for an in-progress operation.
494    /// Must yield until Completed or Failed.
495    fn check_status(
496        &self,
497        ctx: &ToolContext<Ctx>,
498        operation_id: &str,
499    ) -> impl Stream<Item = ToolStatus<Self::Stage>> + Send;
500}
501
502// ============================================================================
503// ListenExecuteTool Trait
504// ============================================================================
505
506/// A tool whose runtime has two phases:
507/// 1. `listen()` - starts preparation and streams updates
508/// 2. `execute()` - performs final execution after confirmation
509///
510/// This abstraction is useful when runtime state can expire or evolve before
511/// execution (quotes, challenge windows, leases, approvals).
512///
513/// Ordering note: the agent loop consumes `listen()` updates before
514/// `AgentHooks::pre_tool_use()` runs. Hooks can therefore block `execute()`, but
515/// any side effects done during `listen()` have already happened.
516pub trait ListenExecuteTool<Ctx>: Send + Sync {
517    /// The type of name for this tool.
518    type Name: ToolName;
519
520    /// Returns the tool's strongly-typed name.
521    fn name(&self) -> Self::Name;
522
523    /// Human-readable display name for UI.
524    fn display_name(&self) -> &'static str;
525
526    /// Human-readable description of what the tool does.
527    fn description(&self) -> &'static str;
528
529    /// JSON schema for the tool's input parameters.
530    fn input_schema(&self) -> Value;
531
532    /// Permission tier for this tool.
533    fn tier(&self) -> ToolTier {
534        ToolTier::Confirm
535    }
536
537    /// Start and stream runtime preparation updates.
538    fn listen(
539        &self,
540        ctx: &ToolContext<Ctx>,
541        input: Value,
542    ) -> impl Stream<Item = ListenToolUpdate> + Send;
543
544    /// Execute using operation ID and optimistic concurrency revision.
545    ///
546    /// # Errors
547    /// Returns an error if execution fails or revision is stale.
548    fn execute(
549        &self,
550        ctx: &ToolContext<Ctx>,
551        operation_id: &str,
552        expected_revision: u64,
553    ) -> impl Future<Output = Result<ToolResult>> + Send;
554
555    /// Stop a listen operation (best effort).
556    ///
557    /// # Errors
558    /// Returns an error if cancellation fails.
559    fn cancel(
560        &self,
561        _ctx: &ToolContext<Ctx>,
562        _operation_id: &str,
563        _reason: ListenStopReason,
564    ) -> impl Future<Output = Result<()>> + Send {
565        async { Ok(()) }
566    }
567}
568
569// ============================================================================
570// Type-Erased Tool (for Registry)
571// ============================================================================
572
573/// Type-erased tool trait for registry storage.
574///
575/// This allows tools with different `Name` associated types to be stored
576/// in the same registry by erasing the type information.
577///
578/// # Example
579///
580/// ```ignore
581/// for tool in registry.all() {
582///     println!("Tool: {} - {}", tool.name_str(), tool.description());
583/// }
584/// ```
585#[async_trait]
586pub trait ErasedTool<Ctx>: Send + Sync {
587    /// Get the tool name as a string.
588    fn name_str(&self) -> &str;
589    /// Get a human-friendly display name for the tool.
590    fn display_name(&self) -> &'static str;
591    /// Get the tool description.
592    fn description(&self) -> &'static str;
593    /// Get the JSON schema for tool inputs.
594    fn input_schema(&self) -> Value;
595    /// Get the tool's permission tier.
596    fn tier(&self) -> ToolTier;
597    /// Execute the tool with the given input.
598    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult>;
599}
600
601/// Wrapper that erases the Name associated type from a Tool.
602struct ToolWrapper<T, Ctx>
603where
604    T: Tool<Ctx>,
605{
606    inner: T,
607    name_cache: String,
608    _marker: PhantomData<Ctx>,
609}
610
611impl<T, Ctx> ToolWrapper<T, Ctx>
612where
613    T: Tool<Ctx>,
614{
615    fn new(tool: T) -> Self {
616        let name_cache = tool_name_to_string(&tool.name());
617        Self {
618            inner: tool,
619            name_cache,
620            _marker: PhantomData,
621        }
622    }
623}
624
625#[async_trait]
626impl<T, Ctx> ErasedTool<Ctx> for ToolWrapper<T, Ctx>
627where
628    T: Tool<Ctx> + 'static,
629    Ctx: Send + Sync + 'static,
630{
631    fn name_str(&self) -> &str {
632        &self.name_cache
633    }
634
635    fn display_name(&self) -> &'static str {
636        self.inner.display_name()
637    }
638
639    fn description(&self) -> &'static str {
640        self.inner.description()
641    }
642
643    fn input_schema(&self) -> Value {
644        self.inner.input_schema()
645    }
646
647    fn tier(&self) -> ToolTier {
648        self.inner.tier()
649    }
650
651    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
652        self.inner.execute(ctx, input).await
653    }
654}
655
656// ============================================================================
657// Type-Erased AsyncTool (for Registry)
658// ============================================================================
659
660/// Type-erased async tool trait for registry storage.
661///
662/// This allows async tools with different `Name` and `Stage` associated types
663/// to be stored in the same registry by erasing the type information.
664#[async_trait]
665pub trait ErasedAsyncTool<Ctx>: Send + Sync {
666    /// Get the tool name as a string.
667    fn name_str(&self) -> &str;
668    /// Get a human-friendly display name for the tool.
669    fn display_name(&self) -> &'static str;
670    /// Get the tool description.
671    fn description(&self) -> &'static str;
672    /// Get the JSON schema for tool inputs.
673    fn input_schema(&self) -> Value;
674    /// Get the tool's permission tier.
675    fn tier(&self) -> ToolTier;
676    /// Execute the tool with the given input.
677    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome>;
678    /// Stream status updates for an in-progress operation (type-erased).
679    fn check_status_stream<'a>(
680        &'a self,
681        ctx: &'a ToolContext<Ctx>,
682        operation_id: &'a str,
683    ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>>;
684}
685
686/// Wrapper that erases the Name and Stage associated types from an [`AsyncTool`].
687struct AsyncToolWrapper<T, Ctx>
688where
689    T: AsyncTool<Ctx>,
690{
691    inner: T,
692    name_cache: String,
693    _marker: PhantomData<Ctx>,
694}
695
696impl<T, Ctx> AsyncToolWrapper<T, Ctx>
697where
698    T: AsyncTool<Ctx>,
699{
700    fn new(tool: T) -> Self {
701        let name_cache = tool_name_to_string(&tool.name());
702        Self {
703            inner: tool,
704            name_cache,
705            _marker: PhantomData,
706        }
707    }
708}
709
710#[async_trait]
711impl<T, Ctx> ErasedAsyncTool<Ctx> for AsyncToolWrapper<T, Ctx>
712where
713    T: AsyncTool<Ctx> + 'static,
714    Ctx: Send + Sync + 'static,
715{
716    fn name_str(&self) -> &str {
717        &self.name_cache
718    }
719
720    fn display_name(&self) -> &'static str {
721        self.inner.display_name()
722    }
723
724    fn description(&self) -> &'static str {
725        self.inner.description()
726    }
727
728    fn input_schema(&self) -> Value {
729        self.inner.input_schema()
730    }
731
732    fn tier(&self) -> ToolTier {
733        self.inner.tier()
734    }
735
736    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome> {
737        self.inner.execute(ctx, input).await
738    }
739
740    fn check_status_stream<'a>(
741        &'a self,
742        ctx: &'a ToolContext<Ctx>,
743        operation_id: &'a str,
744    ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>> {
745        use futures::StreamExt;
746        let stream = self.inner.check_status(ctx, operation_id);
747        Box::pin(stream.map(ErasedToolStatus::from))
748    }
749}
750
751// ============================================================================
752// Type-Erased ListenExecuteTool (for Registry)
753// ============================================================================
754
755/// Type-erased listen/execute tool trait for registry storage.
756#[async_trait]
757pub trait ErasedListenTool<Ctx>: Send + Sync {
758    /// Get the tool name as a string.
759    fn name_str(&self) -> &str;
760    /// Get a human-friendly display name for the tool.
761    fn display_name(&self) -> &'static str;
762    /// Get the tool description.
763    fn description(&self) -> &'static str;
764    /// Get the JSON schema for tool inputs.
765    fn input_schema(&self) -> Value;
766    /// Get the tool's permission tier.
767    fn tier(&self) -> ToolTier;
768    /// Start listen stream.
769    fn listen_stream<'a>(
770        &'a self,
771        ctx: &'a ToolContext<Ctx>,
772        input: Value,
773    ) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>>;
774    /// Execute using a prepared operation.
775    async fn execute(
776        &self,
777        ctx: &ToolContext<Ctx>,
778        operation_id: &str,
779        expected_revision: u64,
780    ) -> Result<ToolResult>;
781    /// Cancel operation.
782    async fn cancel(
783        &self,
784        ctx: &ToolContext<Ctx>,
785        operation_id: &str,
786        reason: ListenStopReason,
787    ) -> Result<()>;
788}
789
790/// Wrapper that erases the Name associated type from a [`ListenExecuteTool`].
791struct ListenToolWrapper<T, Ctx>
792where
793    T: ListenExecuteTool<Ctx>,
794{
795    inner: T,
796    name_cache: String,
797    _marker: PhantomData<Ctx>,
798}
799
800impl<T, Ctx> ListenToolWrapper<T, Ctx>
801where
802    T: ListenExecuteTool<Ctx>,
803{
804    fn new(tool: T) -> Self {
805        let name_cache = tool_name_to_string(&tool.name());
806        Self {
807            inner: tool,
808            name_cache,
809            _marker: PhantomData,
810        }
811    }
812}
813
814#[async_trait]
815impl<T, Ctx> ErasedListenTool<Ctx> for ListenToolWrapper<T, Ctx>
816where
817    T: ListenExecuteTool<Ctx> + 'static,
818    Ctx: Send + Sync + 'static,
819{
820    fn name_str(&self) -> &str {
821        &self.name_cache
822    }
823
824    fn display_name(&self) -> &'static str {
825        self.inner.display_name()
826    }
827
828    fn description(&self) -> &'static str {
829        self.inner.description()
830    }
831
832    fn input_schema(&self) -> Value {
833        self.inner.input_schema()
834    }
835
836    fn tier(&self) -> ToolTier {
837        self.inner.tier()
838    }
839
840    fn listen_stream<'a>(
841        &'a self,
842        ctx: &'a ToolContext<Ctx>,
843        input: Value,
844    ) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>> {
845        let stream = self.inner.listen(ctx, input);
846        Box::pin(stream)
847    }
848
849    async fn execute(
850        &self,
851        ctx: &ToolContext<Ctx>,
852        operation_id: &str,
853        expected_revision: u64,
854    ) -> Result<ToolResult> {
855        self.inner
856            .execute(ctx, operation_id, expected_revision)
857            .await
858    }
859
860    async fn cancel(
861        &self,
862        ctx: &ToolContext<Ctx>,
863        operation_id: &str,
864        reason: ListenStopReason,
865    ) -> Result<()> {
866        self.inner.cancel(ctx, operation_id, reason).await
867    }
868}
869
870/// Registry of available tools.
871///
872/// Tools are stored with their names erased to allow different `Name` types
873/// in the same registry. The registry uses string-based lookup for LLM
874/// compatibility.
875///
876/// Supports both synchronous [`Tool`]s and asynchronous [`AsyncTool`]s.
877pub struct ToolRegistry<Ctx> {
878    tools: HashMap<String, Arc<dyn ErasedTool<Ctx>>>,
879    async_tools: HashMap<String, Arc<dyn ErasedAsyncTool<Ctx>>>,
880    listen_tools: HashMap<String, Arc<dyn ErasedListenTool<Ctx>>>,
881}
882
883impl<Ctx> Clone for ToolRegistry<Ctx> {
884    fn clone(&self) -> Self {
885        Self {
886            tools: self.tools.clone(),
887            async_tools: self.async_tools.clone(),
888            listen_tools: self.listen_tools.clone(),
889        }
890    }
891}
892
893impl<Ctx: Send + Sync + 'static> Default for ToolRegistry<Ctx> {
894    fn default() -> Self {
895        Self::new()
896    }
897}
898
899impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
900    #[must_use]
901    pub fn new() -> Self {
902        Self {
903            tools: HashMap::new(),
904            async_tools: HashMap::new(),
905            listen_tools: HashMap::new(),
906        }
907    }
908
909    /// Register a synchronous tool in the registry.
910    ///
911    /// The tool's name is converted to a string via serde serialization
912    /// and used as the lookup key.
913    pub fn register<T>(&mut self, tool: T) -> &mut Self
914    where
915        T: Tool<Ctx> + 'static,
916    {
917        let wrapper = ToolWrapper::new(tool);
918        let name = wrapper.name_str().to_string();
919        self.tools.insert(name, Arc::new(wrapper));
920        self
921    }
922
923    /// Register an async tool in the registry.
924    ///
925    /// Async tools have two phases: execute (lightweight, starts operation)
926    /// and `check_status` (streams progress until completion).
927    pub fn register_async<T>(&mut self, tool: T) -> &mut Self
928    where
929        T: AsyncTool<Ctx> + 'static,
930    {
931        let wrapper = AsyncToolWrapper::new(tool);
932        let name = wrapper.name_str().to_string();
933        self.async_tools.insert(name, Arc::new(wrapper));
934        self
935    }
936
937    /// Register a listen/execute tool in the registry.
938    ///
939    /// Listen/execute tools start by streaming updates via `listen()`, then run
940    /// final execution with `execute()` once confirmed.
941    pub fn register_listen<T>(&mut self, tool: T) -> &mut Self
942    where
943        T: ListenExecuteTool<Ctx> + 'static,
944    {
945        let wrapper = ListenToolWrapper::new(tool);
946        let name = wrapper.name_str().to_string();
947        self.listen_tools.insert(name, Arc::new(wrapper));
948        self
949    }
950
951    /// Get a synchronous tool by name.
952    #[must_use]
953    pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool<Ctx>>> {
954        self.tools.get(name)
955    }
956
957    /// Get an async tool by name.
958    #[must_use]
959    pub fn get_async(&self, name: &str) -> Option<&Arc<dyn ErasedAsyncTool<Ctx>>> {
960        self.async_tools.get(name)
961    }
962
963    /// Get a listen/execute tool by name.
964    #[must_use]
965    pub fn get_listen(&self, name: &str) -> Option<&Arc<dyn ErasedListenTool<Ctx>>> {
966        self.listen_tools.get(name)
967    }
968
969    /// Check if a tool name refers to an async tool.
970    #[must_use]
971    pub fn is_async(&self, name: &str) -> bool {
972        self.async_tools.contains_key(name)
973    }
974
975    /// Check if a tool name refers to a listen/execute tool.
976    #[must_use]
977    pub fn is_listen(&self, name: &str) -> bool {
978        self.listen_tools.contains_key(name)
979    }
980
981    /// Get all registered synchronous tools.
982    pub fn all(&self) -> impl Iterator<Item = &Arc<dyn ErasedTool<Ctx>>> {
983        self.tools.values()
984    }
985
986    /// Get all registered async tools.
987    pub fn all_async(&self) -> impl Iterator<Item = &Arc<dyn ErasedAsyncTool<Ctx>>> {
988        self.async_tools.values()
989    }
990
991    /// Get all registered listen/execute tools.
992    pub fn all_listen(&self) -> impl Iterator<Item = &Arc<dyn ErasedListenTool<Ctx>>> {
993        self.listen_tools.values()
994    }
995
996    /// Get the number of registered tools (sync + async).
997    #[must_use]
998    pub fn len(&self) -> usize {
999        self.tools.len() + self.async_tools.len() + self.listen_tools.len()
1000    }
1001
1002    /// Check if the registry is empty.
1003    #[must_use]
1004    pub fn is_empty(&self) -> bool {
1005        self.tools.is_empty() && self.async_tools.is_empty() && self.listen_tools.is_empty()
1006    }
1007
1008    /// Filter tools by a predicate.
1009    ///
1010    /// Removes tools for which the predicate returns false.
1011    /// The predicate receives the tool name.
1012    /// Applies to both sync and async tools.
1013    ///
1014    /// # Example
1015    ///
1016    /// ```ignore
1017    /// registry.filter(|name| name != "bash");
1018    /// ```
1019    pub fn filter<F>(&mut self, predicate: F)
1020    where
1021        F: Fn(&str) -> bool,
1022    {
1023        self.tools.retain(|name, _| predicate(name));
1024        self.async_tools.retain(|name, _| predicate(name));
1025        self.listen_tools.retain(|name, _| predicate(name));
1026    }
1027
1028    /// Convert all tools (sync + async) to LLM tool definitions.
1029    #[must_use]
1030    pub fn to_llm_tools(&self) -> Vec<llm::Tool> {
1031        let mut tools: Vec<_> = self
1032            .tools
1033            .values()
1034            .map(|tool| llm::Tool {
1035                name: tool.name_str().to_string(),
1036                description: tool.description().to_string(),
1037                input_schema: tool.input_schema(),
1038            })
1039            .collect();
1040
1041        tools.extend(self.async_tools.values().map(|tool| llm::Tool {
1042            name: tool.name_str().to_string(),
1043            description: tool.description().to_string(),
1044            input_schema: tool.input_schema(),
1045        }));
1046
1047        tools.extend(self.listen_tools.values().map(|tool| llm::Tool {
1048            name: tool.name_str().to_string(),
1049            description: tool.description().to_string(),
1050            input_schema: tool.input_schema(),
1051        }));
1052
1053        tools
1054    }
1055}
1056
#[cfg(test)]
mod tests {
    use super::*;

    // Test tool name enum for tests
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
    #[serde(rename_all = "snake_case")]
    enum TestToolName {
        MockTool,
        AnotherTool,
    }

    impl ToolName for TestToolName {}

    struct MockTool;

    impl Tool<()> for MockTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }

        fn display_name(&self) -> &'static str {
            "Mock Tool"
        }

        fn description(&self) -> &'static str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                }
            })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message");
            Ok(ToolResult::success(format!("Received: {message}")))
        }
    }

    #[test]
    fn test_tool_name_serialization() {
        let name = TestToolName::MockTool;
        assert_eq!(tool_name_to_string(&name), "mock_tool");

        let parsed: TestToolName = tool_name_from_str("mock_tool").unwrap();
        assert_eq!(parsed, TestToolName::MockTool);
    }

    #[test]
    fn test_dynamic_tool_name() {
        let name = DynamicToolName::new("my_mcp_tool");
        assert_eq!(tool_name_to_string(&name), "my_mcp_tool");
        assert_eq!(name.as_str(), "my_mcp_tool");
    }

    #[test]
    fn test_new_registry_is_empty() {
        let registry: ToolRegistry<()> = ToolRegistry::new();

        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.to_llm_tools().is_empty());
    }

    #[test]
    fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_to_llm_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        let llm_tools = registry.to_llm_tools();
        assert_eq!(llm_tools.len(), 1);
        assert_eq!(llm_tools[0].name, "mock_tool");
    }

    struct AnotherTool;

    impl Tool<()> for AnotherTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::AnotherTool
        }

        fn display_name(&self) -> &'static str {
            "Another Tool"
        }

        fn description(&self) -> &'static str {
            "Another tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, _input: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("Done"))
        }
    }

    #[test]
    fn test_filter_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);

        assert_eq!(registry.len(), 2);

        // Filter out mock_tool
        registry.filter(|name| name != "mock_tool");

        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_none());
        assert!(registry.get("another_tool").is_some());
    }

    #[test]
    fn test_filter_tools_keep_all() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);

        registry.filter(|_| true);

        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_filter_tools_remove_all() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);

        registry.filter(|_| false);

        assert!(registry.is_empty());
    }

    #[test]
    fn test_display_name() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        let tool = registry.get("mock_tool").unwrap();
        assert_eq!(tool.display_name(), "Mock Tool");
    }

    struct ListenMockTool;

    impl ListenExecuteTool<()> for ListenMockTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }

        fn display_name(&self) -> &'static str {
            "Listen Mock Tool"
        }

        fn description(&self) -> &'static str {
            "A listen/execute mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        fn listen(
            &self,
            _ctx: &ToolContext<()>,
            _input: Value,
        ) -> impl futures::Stream<Item = ListenToolUpdate> + Send {
            futures::stream::iter(vec![ListenToolUpdate::Ready {
                operation_id: "op_1".to_string(),
                revision: 1,
                message: "ready".to_string(),
                snapshot: serde_json::json!({"ok": true}),
                expires_at: None,
            }])
        }

        async fn execute(
            &self,
            _ctx: &ToolContext<()>,
            _operation_id: &str,
            _expected_revision: u64,
        ) -> Result<ToolResult> {
            Ok(ToolResult::success("Executed"))
        }
    }

    #[test]
    fn test_listen_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register_listen(ListenMockTool);

        assert_eq!(registry.len(), 1);
        assert!(registry.get_listen("mock_tool").is_some());
        assert!(registry.is_listen("mock_tool"));

        // Listen registration must not leak into the other categories.
        assert!(registry.get("mock_tool").is_none());
        assert!(!registry.is_async("mock_tool"));
    }

    #[test]
    fn test_filter_applies_to_listen_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(AnotherTool);
        registry.register_listen(ListenMockTool);
        assert_eq!(registry.len(), 2);

        registry.filter(|name| name != "mock_tool");

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_listen("mock_tool"));
        assert!(registry.get_listen("mock_tool").is_none());
        assert!(registry.get("another_tool").is_some());
    }

    #[test]
    fn test_to_llm_tools_includes_listen_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(AnotherTool);
        registry.register_listen(ListenMockTool);

        // Sort locally so the assertion does not depend on output ordering.
        let mut names: Vec<String> = registry
            .to_llm_tools()
            .into_iter()
            .map(|tool| tool.name)
            .collect();
        names.sort();

        assert_eq!(names, vec!["another_tool", "mock_tool"]);
    }
}