Skip to main content

agent_sdk/
tools.rs

1//! Tool definition and registry.
2//!
3//! Tools allow the LLM to perform actions in the real world. This module provides:
4//!
5//! - [`Tool`] trait - Define custom tools the LLM can call
6//! - [`ToolName`] trait - Marker trait for strongly-typed tool names
7//! - [`PrimitiveToolName`] - Tool names for SDK's built-in tools
8//! - [`DynamicToolName`] - Tool names created at runtime (MCP bridges)
9//! - [`ToolRegistry`] - Collection of available tools
10//! - [`ToolContext`] - Context passed to tool execution
11//!
//! # Implementing a Tool
//!
//! ```ignore
//! use agent_sdk::{DynamicToolName, Tool, ToolContext, ToolResult, ToolTier};
//! use anyhow::Result;
//! use serde_json::{json, Value};
//!
//! struct MyTool;
//!
//! // No #[async_trait] needed - Rust 1.75+ supports native async traits
//! impl Tool<MyContext> for MyTool {
//!     // Give custom tools their own names (a `DynamicToolName` or your own
//!     // `ToolName` enum). Reusing a `PrimitiveToolName` such as `Read` would
//!     // replace the built-in tool of that name when registered.
//!     type Name = DynamicToolName;
//!
//!     fn name(&self) -> DynamicToolName { DynamicToolName::new("my_tool") }
//!     fn display_name(&self) -> &'static str { "My Tool" }
//!     fn description(&self) -> &'static str { "Does something useful" }
//!     fn input_schema(&self) -> Value { json!({ "type": "object" }) }
//!     fn tier(&self) -> ToolTier { ToolTier::Observe }
//!
//!     async fn execute(&self, _ctx: &ToolContext<MyContext>, _input: Value) -> Result<ToolResult> {
//!         Ok(ToolResult::success("Done!"))
//!     }
//! }
//! ```
34
35use crate::events::AgentEvent;
36use crate::llm;
37use crate::types::{ToolOutcome, ToolResult, ToolTier};
38use anyhow::Result;
39use async_trait::async_trait;
40use futures::Stream;
41use serde::{Deserialize, Serialize, de::DeserializeOwned};
42use serde_json::Value;
43use std::collections::HashMap;
44use std::future::Future;
45use std::marker::PhantomData;
46use std::pin::Pin;
47use std::sync::Arc;
48use tokio::sync::mpsc;
49
50// ============================================================================
51// Tool Name Types
52// ============================================================================
53
54/// Marker trait for tool names.
55///
56/// Tool names must be serializable (for storage/logging) and deserializable
57/// (for parsing from LLM responses). The string representation is derived
58/// from serde serialization.
59///
60/// # Example
61///
62/// ```ignore
63/// #[derive(Serialize, Deserialize)]
64/// #[serde(rename_all = "snake_case")]
65/// pub enum MyToolName {
66///     Read,
67///     Write,
68/// }
69///
70/// impl ToolName for MyToolName {}
71/// ```
72pub trait ToolName: Send + Sync + Serialize + DeserializeOwned + 'static {}
73
74/// Helper to get string representation of a tool name via serde.
75///
76/// # Panics
77///
78/// Panics if the tool name cannot be serialized to a string. This should
79/// never happen with properly implemented `ToolName` types that use
80/// `#[derive(Serialize)]`.
81#[must_use]
82pub fn tool_name_to_string<N: ToolName>(name: &N) -> String {
83    serde_json::to_string(name)
84        .expect("ToolName must serialize to string")
85        .trim_matches('"')
86        .to_string()
87}
88
89/// Parse a tool name from string via serde.
90///
91/// # Errors
92/// Returns error if the string doesn't match a valid tool name.
93pub fn tool_name_from_str<N: ToolName>(s: &str) -> Result<N, serde_json::Error> {
94    serde_json::from_str(&format!("\"{s}\""))
95}
96
97/// Tool names for SDK's built-in primitive tools.
98#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
99#[serde(rename_all = "snake_case")]
100pub enum PrimitiveToolName {
101    Read,
102    Write,
103    Edit,
104    MultiEdit,
105    Bash,
106    Glob,
107    Grep,
108    NotebookRead,
109    NotebookEdit,
110    TodoRead,
111    TodoWrite,
112    AskUser,
113    LinkFetch,
114    WebSearch,
115}
116
117impl ToolName for PrimitiveToolName {}
118
119/// Dynamic tool name for runtime-created tools (MCP bridges, subagents).
120#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
121#[serde(transparent)]
122pub struct DynamicToolName(String);
123
124impl DynamicToolName {
125    #[must_use]
126    pub fn new(name: impl Into<String>) -> Self {
127        Self(name.into())
128    }
129
130    #[must_use]
131    pub fn as_str(&self) -> &str {
132        &self.0
133    }
134}
135
136impl ToolName for DynamicToolName {}
137
138// ============================================================================
139// Progress Stage Types (for AsyncTool)
140// ============================================================================
141
142/// Marker trait for tool progress stages (type-safe, like [`ToolName`]).
143///
144/// Progress stages are used by async tools to indicate the current phase
145/// of a long-running operation. They must be serializable for event streaming.
146///
147/// # Example
148///
149/// ```ignore
150/// #[derive(Clone, Debug, Serialize, Deserialize)]
151/// #[serde(rename_all = "snake_case")]
152/// pub enum PixTransferStage {
153///     Initiated,
154///     Processing,
155///     SentToBank,
156/// }
157///
158/// impl ProgressStage for PixTransferStage {}
159/// ```
160pub trait ProgressStage: Clone + Send + Sync + Serialize + DeserializeOwned + 'static {}
161
162/// Helper to get string representation of a progress stage via serde.
163///
164/// # Panics
165///
166/// Panics if the stage cannot be serialized to a string. This should
167/// never happen with properly implemented `ProgressStage` types.
168#[must_use]
169pub fn stage_to_string<S: ProgressStage>(stage: &S) -> String {
170    serde_json::to_string(stage)
171        .expect("ProgressStage must serialize to string")
172        .trim_matches('"')
173        .to_string()
174}
175
176/// Status update from an async tool operation.
177#[derive(Clone, Debug, Serialize)]
178pub enum ToolStatus<S: ProgressStage> {
179    /// Operation is making progress
180    Progress {
181        stage: S,
182        message: String,
183        data: Option<serde_json::Value>,
184    },
185
186    /// Operation completed successfully
187    Completed(ToolResult),
188
189    /// Operation failed
190    Failed(ToolResult),
191}
192
193/// Type-erased status for the agent loop.
194#[derive(Clone, Debug, Serialize, Deserialize)]
195pub enum ErasedToolStatus {
196    /// Operation is making progress
197    Progress {
198        stage: String,
199        message: String,
200        data: Option<serde_json::Value>,
201    },
202    /// Operation completed successfully
203    Completed(ToolResult),
204    /// Operation failed
205    Failed(ToolResult),
206}
207
208impl<S: ProgressStage> From<ToolStatus<S>> for ErasedToolStatus {
209    fn from(status: ToolStatus<S>) -> Self {
210        match status {
211            ToolStatus::Progress {
212                stage,
213                message,
214                data,
215            } => Self::Progress {
216                stage: stage_to_string(&stage),
217                message,
218                data,
219            },
220            ToolStatus::Completed(r) => Self::Completed(r),
221            ToolStatus::Failed(r) => Self::Failed(r),
222        }
223    }
224}
225
226/// Context passed to tool execution
227pub struct ToolContext<Ctx> {
228    /// Application-specific context (e.g., `user_id`, db connection)
229    pub app: Ctx,
230    /// Tool-specific metadata
231    pub metadata: HashMap<String, Value>,
232    /// Optional channel for tools to emit events (e.g., subagent progress)
233    event_tx: Option<mpsc::Sender<AgentEvent>>,
234}
235
236impl<Ctx> ToolContext<Ctx> {
237    #[must_use]
238    pub fn new(app: Ctx) -> Self {
239        Self {
240            app,
241            metadata: HashMap::new(),
242            event_tx: None,
243        }
244    }
245
246    #[must_use]
247    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
248        self.metadata.insert(key.into(), value);
249        self
250    }
251
252    /// Set the event channel for tools that need to emit events during execution.
253    #[must_use]
254    pub fn with_event_tx(mut self, tx: mpsc::Sender<AgentEvent>) -> Self {
255        self.event_tx = Some(tx);
256        self
257    }
258
259    /// Emit an event through the event channel (if set).
260    ///
261    /// This uses `try_send` to avoid blocking and to ensure the future is `Send`.
262    /// The event is silently dropped if the channel is full.
263    pub fn emit_event(&self, event: AgentEvent) {
264        if let Some(tx) = &self.event_tx {
265            let _ = tx.try_send(event);
266        }
267    }
268
269    /// Get a clone of the event channel sender (if set).
270    ///
271    /// This is useful for tools that spawn subprocesses (like subagents)
272    /// and need to forward events to the parent's event stream.
273    #[must_use]
274    pub fn event_tx(&self) -> Option<mpsc::Sender<AgentEvent>> {
275        self.event_tx.clone()
276    }
277}
278
279// ============================================================================
280// Tool Trait
281// ============================================================================
282
283/// Definition of a tool that can be called by the agent.
284///
285/// Tools have a strongly-typed `Name` associated type that determines
286/// how the tool name is serialized for LLM communication.
287///
288/// # Native Async Support
289///
290/// This trait uses Rust's native async functions in traits (stabilized in Rust 1.75).
291/// You do NOT need the `async_trait` crate to implement this trait.
292pub trait Tool<Ctx>: Send + Sync {
293    /// The type of name for this tool.
294    type Name: ToolName;
295
296    /// Returns the tool's strongly-typed name.
297    fn name(&self) -> Self::Name;
298
299    /// Human-readable display name for UI (e.g., "Read File" vs "read").
300    ///
301    /// Defaults to empty string. Override for better UX.
302    fn display_name(&self) -> &'static str;
303
304    /// Human-readable description of what the tool does.
305    fn description(&self) -> &'static str;
306
307    /// JSON schema for the tool's input parameters.
308    fn input_schema(&self) -> Value;
309
310    /// Permission tier for this tool.
311    fn tier(&self) -> ToolTier {
312        ToolTier::Observe
313    }
314
315    /// Execute the tool with the given input.
316    ///
317    /// # Errors
318    /// Returns an error if tool execution fails.
319    fn execute(
320        &self,
321        ctx: &ToolContext<Ctx>,
322        input: Value,
323    ) -> impl Future<Output = Result<ToolResult>> + Send;
324}
325
326// ============================================================================
327// AsyncTool Trait
328// ============================================================================
329
330/// A tool that performs long-running async operations.
331///
332/// `AsyncTool`s have two phases:
333/// 1. `execute()` - Start the operation (lightweight, returns quickly)
334/// 2. `check_status()` - Stream progress until completion
335///
336/// The actual work should happen externally (background task, external service)
337/// and persist results to a durable store. The tool is just an orchestrator.
338///
339/// # Example
340///
341/// ```ignore
342/// impl AsyncTool<MyCtx> for ExecutePixTransferTool {
343///     type Name = PixToolName;
344///     type Stage = PixTransferStage;
345///
346///     async fn execute(&self, ctx: &ToolContext<MyCtx>, input: Value) -> Result<ToolOutcome> {
347///         let params = parse_input(&input)?;
348///         let operation_id = ctx.app.pix_service.start_transfer(params).await?;
349///         Ok(ToolOutcome::in_progress(
350///             operation_id,
351///             format!("PIX transfer of {} initiated", params.amount),
352///         ))
353///     }
354///
355///     fn check_status(&self, ctx: &ToolContext<MyCtx>, operation_id: &str)
356///         -> impl Stream<Item = ToolStatus<PixTransferStage>> + Send
357///     {
358///         async_stream::stream! {
359///             loop {
360///                 let status = ctx.app.pix_service.get_status(operation_id).await;
361///                 match status {
362///                     PixStatus::Success { id } => {
363///                         yield ToolStatus::Completed(ToolResult::success(id));
364///                         break;
365///                     }
366///                     _ => yield ToolStatus::Progress { ... };
367///                 }
368///                 tokio::time::sleep(Duration::from_millis(500)).await;
369///             }
370///         }
371///     }
372/// }
373/// ```
374pub trait AsyncTool<Ctx>: Send + Sync {
375    /// The type of name for this tool.
376    type Name: ToolName;
377    /// The type of progress stages for this tool.
378    type Stage: ProgressStage;
379
380    /// Returns the tool's strongly-typed name.
381    fn name(&self) -> Self::Name;
382
383    /// Human-readable display name for UI.
384    fn display_name(&self) -> &'static str;
385
386    /// Human-readable description of what the tool does.
387    fn description(&self) -> &'static str;
388
389    /// JSON schema for the tool's input parameters.
390    fn input_schema(&self) -> Value;
391
392    /// Permission tier for this tool.
393    fn tier(&self) -> ToolTier {
394        ToolTier::Observe
395    }
396
397    /// Execute the tool. Returns immediately with one of:
398    /// - Success/Failed: Operation completed synchronously
399    /// - `InProgress`: Operation started, use `check_status()` to stream updates
400    ///
401    /// # Errors
402    /// Returns an error if tool execution fails.
403    fn execute(
404        &self,
405        ctx: &ToolContext<Ctx>,
406        input: Value,
407    ) -> impl Future<Output = Result<ToolOutcome>> + Send;
408
409    /// Stream status updates for an in-progress operation.
410    /// Must yield until Completed or Failed.
411    fn check_status(
412        &self,
413        ctx: &ToolContext<Ctx>,
414        operation_id: &str,
415    ) -> impl Stream<Item = ToolStatus<Self::Stage>> + Send;
416}
417
418// ============================================================================
419// Type-Erased Tool (for Registry)
420// ============================================================================
421
422/// Type-erased tool trait for registry storage.
423///
424/// This allows tools with different `Name` associated types to be stored
425/// in the same registry by erasing the type information.
426///
427/// # Example
428///
429/// ```ignore
430/// for tool in registry.all() {
431///     println!("Tool: {} - {}", tool.name_str(), tool.description());
432/// }
433/// ```
434#[async_trait]
435pub trait ErasedTool<Ctx>: Send + Sync {
436    /// Get the tool name as a string.
437    fn name_str(&self) -> &str;
438    /// Get a human-friendly display name for the tool.
439    fn display_name(&self) -> &'static str;
440    /// Get the tool description.
441    fn description(&self) -> &'static str;
442    /// Get the JSON schema for tool inputs.
443    fn input_schema(&self) -> Value;
444    /// Get the tool's permission tier.
445    fn tier(&self) -> ToolTier;
446    /// Execute the tool with the given input.
447    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult>;
448}
449
450/// Wrapper that erases the Name associated type from a Tool.
451struct ToolWrapper<T, Ctx>
452where
453    T: Tool<Ctx>,
454{
455    inner: T,
456    name_cache: String,
457    _marker: PhantomData<Ctx>,
458}
459
460impl<T, Ctx> ToolWrapper<T, Ctx>
461where
462    T: Tool<Ctx>,
463{
464    fn new(tool: T) -> Self {
465        let name_cache = tool_name_to_string(&tool.name());
466        Self {
467            inner: tool,
468            name_cache,
469            _marker: PhantomData,
470        }
471    }
472}
473
474#[async_trait]
475impl<T, Ctx> ErasedTool<Ctx> for ToolWrapper<T, Ctx>
476where
477    T: Tool<Ctx> + 'static,
478    Ctx: Send + Sync + 'static,
479{
480    fn name_str(&self) -> &str {
481        &self.name_cache
482    }
483
484    fn display_name(&self) -> &'static str {
485        self.inner.display_name()
486    }
487
488    fn description(&self) -> &'static str {
489        self.inner.description()
490    }
491
492    fn input_schema(&self) -> Value {
493        self.inner.input_schema()
494    }
495
496    fn tier(&self) -> ToolTier {
497        self.inner.tier()
498    }
499
500    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
501        self.inner.execute(ctx, input).await
502    }
503}
504
505// ============================================================================
506// Type-Erased AsyncTool (for Registry)
507// ============================================================================
508
509/// Type-erased async tool trait for registry storage.
510///
511/// This allows async tools with different `Name` and `Stage` associated types
512/// to be stored in the same registry by erasing the type information.
513#[async_trait]
514pub trait ErasedAsyncTool<Ctx>: Send + Sync {
515    /// Get the tool name as a string.
516    fn name_str(&self) -> &str;
517    /// Get a human-friendly display name for the tool.
518    fn display_name(&self) -> &'static str;
519    /// Get the tool description.
520    fn description(&self) -> &'static str;
521    /// Get the JSON schema for tool inputs.
522    fn input_schema(&self) -> Value;
523    /// Get the tool's permission tier.
524    fn tier(&self) -> ToolTier;
525    /// Execute the tool with the given input.
526    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome>;
527    /// Stream status updates for an in-progress operation (type-erased).
528    fn check_status_stream<'a>(
529        &'a self,
530        ctx: &'a ToolContext<Ctx>,
531        operation_id: &'a str,
532    ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>>;
533}
534
535/// Wrapper that erases the Name and Stage associated types from an [`AsyncTool`].
536struct AsyncToolWrapper<T, Ctx>
537where
538    T: AsyncTool<Ctx>,
539{
540    inner: T,
541    name_cache: String,
542    _marker: PhantomData<Ctx>,
543}
544
545impl<T, Ctx> AsyncToolWrapper<T, Ctx>
546where
547    T: AsyncTool<Ctx>,
548{
549    fn new(tool: T) -> Self {
550        let name_cache = tool_name_to_string(&tool.name());
551        Self {
552            inner: tool,
553            name_cache,
554            _marker: PhantomData,
555        }
556    }
557}
558
559#[async_trait]
560impl<T, Ctx> ErasedAsyncTool<Ctx> for AsyncToolWrapper<T, Ctx>
561where
562    T: AsyncTool<Ctx> + 'static,
563    Ctx: Send + Sync + 'static,
564{
565    fn name_str(&self) -> &str {
566        &self.name_cache
567    }
568
569    fn display_name(&self) -> &'static str {
570        self.inner.display_name()
571    }
572
573    fn description(&self) -> &'static str {
574        self.inner.description()
575    }
576
577    fn input_schema(&self) -> Value {
578        self.inner.input_schema()
579    }
580
581    fn tier(&self) -> ToolTier {
582        self.inner.tier()
583    }
584
585    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome> {
586        self.inner.execute(ctx, input).await
587    }
588
589    fn check_status_stream<'a>(
590        &'a self,
591        ctx: &'a ToolContext<Ctx>,
592        operation_id: &'a str,
593    ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>> {
594        use futures::StreamExt;
595        let stream = self.inner.check_status(ctx, operation_id);
596        Box::pin(stream.map(ErasedToolStatus::from))
597    }
598}
599
600// ============================================================================
601// Tool Registry
602// ============================================================================
603
604/// Registry of available tools.
605///
606/// Tools are stored with their names erased to allow different `Name` types
607/// in the same registry. The registry uses string-based lookup for LLM
608/// compatibility.
609///
610/// Supports both synchronous [`Tool`]s and asynchronous [`AsyncTool`]s.
611pub struct ToolRegistry<Ctx> {
612    tools: HashMap<String, Arc<dyn ErasedTool<Ctx>>>,
613    async_tools: HashMap<String, Arc<dyn ErasedAsyncTool<Ctx>>>,
614}
615
616impl<Ctx> Clone for ToolRegistry<Ctx> {
617    fn clone(&self) -> Self {
618        Self {
619            tools: self.tools.clone(),
620            async_tools: self.async_tools.clone(),
621        }
622    }
623}
624
625impl<Ctx: Send + Sync + 'static> Default for ToolRegistry<Ctx> {
626    fn default() -> Self {
627        Self::new()
628    }
629}
630
631impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
632    #[must_use]
633    pub fn new() -> Self {
634        Self {
635            tools: HashMap::new(),
636            async_tools: HashMap::new(),
637        }
638    }
639
640    /// Register a synchronous tool in the registry.
641    ///
642    /// The tool's name is converted to a string via serde serialization
643    /// and used as the lookup key.
644    pub fn register<T>(&mut self, tool: T) -> &mut Self
645    where
646        T: Tool<Ctx> + 'static,
647    {
648        let wrapper = ToolWrapper::new(tool);
649        let name = wrapper.name_str().to_string();
650        self.tools.insert(name, Arc::new(wrapper));
651        self
652    }
653
654    /// Register an async tool in the registry.
655    ///
656    /// Async tools have two phases: execute (lightweight, starts operation)
657    /// and `check_status` (streams progress until completion).
658    pub fn register_async<T>(&mut self, tool: T) -> &mut Self
659    where
660        T: AsyncTool<Ctx> + 'static,
661    {
662        let wrapper = AsyncToolWrapper::new(tool);
663        let name = wrapper.name_str().to_string();
664        self.async_tools.insert(name, Arc::new(wrapper));
665        self
666    }
667
668    /// Get a synchronous tool by name.
669    #[must_use]
670    pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool<Ctx>>> {
671        self.tools.get(name)
672    }
673
674    /// Get an async tool by name.
675    #[must_use]
676    pub fn get_async(&self, name: &str) -> Option<&Arc<dyn ErasedAsyncTool<Ctx>>> {
677        self.async_tools.get(name)
678    }
679
680    /// Check if a tool name refers to an async tool.
681    #[must_use]
682    pub fn is_async(&self, name: &str) -> bool {
683        self.async_tools.contains_key(name)
684    }
685
686    /// Get all registered synchronous tools.
687    pub fn all(&self) -> impl Iterator<Item = &Arc<dyn ErasedTool<Ctx>>> {
688        self.tools.values()
689    }
690
691    /// Get all registered async tools.
692    pub fn all_async(&self) -> impl Iterator<Item = &Arc<dyn ErasedAsyncTool<Ctx>>> {
693        self.async_tools.values()
694    }
695
696    /// Get the number of registered tools (sync + async).
697    #[must_use]
698    pub fn len(&self) -> usize {
699        self.tools.len() + self.async_tools.len()
700    }
701
702    /// Check if the registry is empty.
703    #[must_use]
704    pub fn is_empty(&self) -> bool {
705        self.tools.is_empty() && self.async_tools.is_empty()
706    }
707
708    /// Filter tools by a predicate.
709    ///
710    /// Removes tools for which the predicate returns false.
711    /// The predicate receives the tool name.
712    /// Applies to both sync and async tools.
713    ///
714    /// # Example
715    ///
716    /// ```ignore
717    /// registry.filter(|name| name != "bash");
718    /// ```
719    pub fn filter<F>(&mut self, predicate: F)
720    where
721        F: Fn(&str) -> bool,
722    {
723        self.tools.retain(|name, _| predicate(name));
724        self.async_tools.retain(|name, _| predicate(name));
725    }
726
727    /// Convert all tools (sync + async) to LLM tool definitions.
728    #[must_use]
729    pub fn to_llm_tools(&self) -> Vec<llm::Tool> {
730        let mut tools: Vec<_> = self
731            .tools
732            .values()
733            .map(|tool| llm::Tool {
734                name: tool.name_str().to_string(),
735                description: tool.description().to_string(),
736                input_schema: tool.input_schema(),
737            })
738            .collect();
739
740        tools.extend(self.async_tools.values().map(|tool| llm::Tool {
741            name: tool.name_str().to_string(),
742            description: tool.description().to_string(),
743            input_schema: tool.input_schema(),
744        }));
745
746        tools
747    }
748}
749
#[cfg(test)]
mod tests {
    use super::*;

    // Name type shared by the test-only tools below.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
    #[serde(rename_all = "snake_case")]
    enum TestToolName {
        MockTool,
        AnotherTool,
    }

    impl ToolName for TestToolName {}

    /// Test tool that echoes back its `message` input field.
    struct MockTool;

    impl Tool<()> for MockTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }

        fn display_name(&self) -> &'static str {
            "Mock Tool"
        }

        fn description(&self) -> &'static str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                }
            })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
            let message = match input.get("message").and_then(Value::as_str) {
                Some(text) => text,
                None => "no message",
            };
            Ok(ToolResult::success(format!("Received: {message}")))
        }
    }

    /// Minimal second tool, so registration and filtering are exercised with
    /// more than one entry.
    struct AnotherTool;

    impl Tool<()> for AnotherTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::AnotherTool
        }

        fn display_name(&self) -> &'static str {
            "Another Tool"
        }

        fn description(&self) -> &'static str {
            "Another tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, _input: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("Done"))
        }
    }

    /// Builds a registry holding both `MockTool` and `AnotherTool`.
    fn registry_with_both() -> ToolRegistry<()> {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool).register(AnotherTool);
        registry
    }

    #[test]
    fn test_tool_name_serialization() {
        assert_eq!(tool_name_to_string(&TestToolName::MockTool), "mock_tool");

        let parsed = tool_name_from_str::<TestToolName>("mock_tool").unwrap();
        assert_eq!(parsed, TestToolName::MockTool);
    }

    #[test]
    fn test_dynamic_tool_name() {
        let name = DynamicToolName::new("my_mcp_tool");
        assert_eq!(name.as_str(), "my_mcp_tool");
        assert_eq!(tool_name_to_string(&name), "my_mcp_tool");
    }

    #[test]
    fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_to_llm_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        let llm_tools = registry.to_llm_tools();
        assert_eq!(llm_tools.len(), 1);
        assert_eq!(llm_tools[0].name, "mock_tool");
    }

    #[test]
    fn test_filter_tools() {
        let mut registry = registry_with_both();
        assert_eq!(registry.len(), 2);

        // Drop only `mock_tool`.
        registry.filter(|name| name != "mock_tool");

        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_none());
        assert!(registry.get("another_tool").is_some());
    }

    #[test]
    fn test_filter_tools_keep_all() {
        let mut registry = registry_with_both();
        registry.filter(|_| true);
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_filter_tools_remove_all() {
        let mut registry = registry_with_both();
        registry.filter(|_| false);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_display_name() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        let tool = registry.get("mock_tool").expect("mock_tool is registered");
        assert_eq!(tool.display_name(), "Mock Tool");
    }
}