Skip to main content

tower_mcp/
tool.rs

1//! Tool definition and builder API
2//!
3//! Provides ergonomic ways to define MCP tools:
4//!
5//! 1. **Builder pattern** - Fluent API for defining tools
6//! 2. **Trait-based** - Implement `McpTool` for full control
7//! 3. **Function-based** - Quick tools from async functions
8
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use schemars::JsonSchema;
14use serde::Serialize;
15use serde::de::DeserializeOwned;
16use serde_json::Value;
17
18use crate::context::RequestContext;
19use crate::error::{Error, Result};
20use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
21
22/// Validate a tool name according to MCP spec.
23///
24/// Tool names must be:
25/// - 1-128 characters long
26/// - Contain only alphanumeric characters, underscores, hyphens, and dots
27///
28/// Returns `Ok(())` if valid, `Err` with description if invalid.
29pub fn validate_tool_name(name: &str) -> Result<()> {
30    if name.is_empty() {
31        return Err(Error::tool("Tool name cannot be empty"));
32    }
33    if name.len() > 128 {
34        return Err(Error::tool(format!(
35            "Tool name '{}' exceeds maximum length of 128 characters (got {})",
36            name,
37            name.len()
38        )));
39    }
40    if let Some(invalid_char) = name
41        .chars()
42        .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
43    {
44        return Err(Error::tool(format!(
45            "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
46            name, invalid_char
47        )));
48    }
49    Ok(())
50}
51
/// Heap-allocated, type-erased future produced by tool handlers.
///
/// The future is `Send`, and it may borrow data for the lifetime `'a`.
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
54
55/// Tool handler trait - the core abstraction for tool execution
56pub trait ToolHandler: Send + Sync {
57    /// Execute the tool with the given arguments
58    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;
59
60    /// Execute the tool with request context for progress/cancellation support
61    ///
62    /// The default implementation ignores the context and calls `call`.
63    /// Override this to receive progress/cancellation context.
64    fn call_with_context(
65        &self,
66        _ctx: RequestContext,
67        args: Value,
68    ) -> BoxFuture<'_, Result<CallToolResult>> {
69        self.call(args)
70    }
71
72    /// Returns true if this handler uses context (for optimization)
73    fn uses_context(&self) -> bool {
74        false
75    }
76
77    /// Get the tool's input schema
78    fn input_schema(&self) -> Value;
79}
80
81/// A complete tool definition with handler
82pub struct Tool {
83    pub name: String,
84    pub title: Option<String>,
85    pub description: Option<String>,
86    pub output_schema: Option<Value>,
87    pub icons: Option<Vec<ToolIcon>>,
88    pub annotations: Option<ToolAnnotations>,
89    handler: Arc<dyn ToolHandler>,
90}
91
92impl std::fmt::Debug for Tool {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        f.debug_struct("Tool")
95            .field("name", &self.name)
96            .field("title", &self.title)
97            .field("description", &self.description)
98            .field("output_schema", &self.output_schema)
99            .field("icons", &self.icons)
100            .field("annotations", &self.annotations)
101            .finish_non_exhaustive()
102    }
103}
104
105impl Tool {
106    /// Create a new tool builder
107    pub fn builder(name: impl Into<String>) -> ToolBuilder {
108        ToolBuilder::new(name)
109    }
110
111    /// Get the tool definition for tools/list
112    pub fn definition(&self) -> ToolDefinition {
113        ToolDefinition {
114            name: self.name.clone(),
115            title: self.title.clone(),
116            description: self.description.clone(),
117            input_schema: self.handler.input_schema(),
118            output_schema: self.output_schema.clone(),
119            icons: self.icons.clone(),
120            annotations: self.annotations.clone(),
121        }
122    }
123
124    /// Call the tool without context
125    pub fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
126        self.handler.call(args)
127    }
128
129    /// Call the tool with request context
130    ///
131    /// Use this when you have a RequestContext available for progress/cancellation.
132    pub fn call_with_context(
133        &self,
134        ctx: RequestContext,
135        args: Value,
136    ) -> BoxFuture<'_, Result<CallToolResult>> {
137        self.handler.call_with_context(ctx, args)
138    }
139
140    /// Returns true if this tool uses context
141    pub fn uses_context(&self) -> bool {
142        self.handler.uses_context()
143    }
144}
145
146// =============================================================================
147// Builder API
148// =============================================================================
149
150/// Builder for creating tools with a fluent API
151///
152/// # Example
153///
154/// ```rust
155/// use tower_mcp::{ToolBuilder, CallToolResult};
156/// use schemars::JsonSchema;
157/// use serde::Deserialize;
158///
159/// #[derive(Debug, Deserialize, JsonSchema)]
160/// struct GreetInput {
161///     name: String,
162/// }
163///
164/// let tool = ToolBuilder::new("greet")
165///     .description("Greet someone by name")
166///     .handler(|input: GreetInput| async move {
167///         Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
168///     })
169///     .build()
170///     .expect("valid tool name");
171///
172/// assert_eq!(tool.name, "greet");
173/// ```
174pub struct ToolBuilder {
175    name: String,
176    title: Option<String>,
177    description: Option<String>,
178    output_schema: Option<Value>,
179    icons: Option<Vec<ToolIcon>>,
180    annotations: Option<ToolAnnotations>,
181}
182
183impl ToolBuilder {
184    pub fn new(name: impl Into<String>) -> Self {
185        Self {
186            name: name.into(),
187            title: None,
188            description: None,
189            output_schema: None,
190            icons: None,
191            annotations: None,
192        }
193    }
194
195    /// Set a human-readable title for the tool
196    pub fn title(mut self, title: impl Into<String>) -> Self {
197        self.title = Some(title.into());
198        self
199    }
200
201    /// Set the output schema (JSON Schema for structured output)
202    pub fn output_schema(mut self, schema: Value) -> Self {
203        self.output_schema = Some(schema);
204        self
205    }
206
207    /// Add an icon for the tool
208    pub fn icon(mut self, src: impl Into<String>) -> Self {
209        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
210            src: src.into(),
211            mime_type: None,
212            sizes: None,
213        });
214        self
215    }
216
217    /// Add an icon with metadata
218    pub fn icon_with_meta(
219        mut self,
220        src: impl Into<String>,
221        mime_type: Option<String>,
222        sizes: Option<Vec<String>>,
223    ) -> Self {
224        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
225            src: src.into(),
226            mime_type,
227            sizes,
228        });
229        self
230    }
231
232    /// Set the tool description
233    pub fn description(mut self, description: impl Into<String>) -> Self {
234        self.description = Some(description.into());
235        self
236    }
237
238    /// Mark the tool as read-only (does not modify state)
239    pub fn read_only(mut self) -> Self {
240        self.annotations
241            .get_or_insert_with(ToolAnnotations::default)
242            .read_only_hint = true;
243        self
244    }
245
246    /// Mark the tool as non-destructive
247    pub fn non_destructive(mut self) -> Self {
248        self.annotations
249            .get_or_insert_with(ToolAnnotations::default)
250            .destructive_hint = false;
251        self
252    }
253
254    /// Mark the tool as idempotent (same args = same effect)
255    pub fn idempotent(mut self) -> Self {
256        self.annotations
257            .get_or_insert_with(ToolAnnotations::default)
258            .idempotent_hint = true;
259        self
260    }
261
262    /// Set tool annotations directly
263    pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
264        self.annotations = Some(annotations);
265        self
266    }
267
268    /// Specify input type and handler.
269    ///
270    /// The input type must implement `JsonSchema` and `DeserializeOwned`.
271    /// The handler receives the deserialized input and returns a `CallToolResult`.
272    ///
273    /// # State Sharing
274    ///
275    /// To share state across tool calls (e.g., database connections, API clients),
276    /// wrap your state in an `Arc` and clone it into the async block:
277    ///
278    /// ```rust
279    /// use std::sync::Arc;
280    /// use tower_mcp::{ToolBuilder, CallToolResult};
281    /// use schemars::JsonSchema;
282    /// use serde::Deserialize;
283    ///
284    /// struct AppState {
285    ///     api_key: String,
286    /// }
287    ///
288    /// #[derive(Debug, Deserialize, JsonSchema)]
289    /// struct MyInput {
290    ///     query: String,
291    /// }
292    ///
293    /// let state = Arc::new(AppState { api_key: "secret".to_string() });
294    ///
295    /// let tool = ToolBuilder::new("my_tool")
296    ///     .description("A tool that uses shared state")
297    ///     .handler(move |input: MyInput| {
298    ///         let state = state.clone(); // Clone Arc for the async block
299    ///         async move {
300    ///             // Use state.api_key here...
301    ///             Ok(CallToolResult::text(format!("Query: {}", input.query)))
302    ///         }
303    ///     })
304    ///     .build()
305    ///     .expect("valid tool name");
306    /// ```
307    ///
308    /// The `move` keyword on the closure captures the `Arc<AppState>`, and
309    /// cloning it inside the closure body allows each async invocation to
310    /// have its own reference to the shared state.
311    pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
312    where
313        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
314        F: Fn(I) -> Fut + Send + Sync + 'static,
315        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
316    {
317        ToolBuilderWithHandler {
318            name: self.name,
319            title: self.title,
320            description: self.description,
321            output_schema: self.output_schema,
322            icons: self.icons,
323            annotations: self.annotations,
324            handler,
325            _phantom: std::marker::PhantomData,
326        }
327    }
328
329    /// Specify input type and context-aware handler
330    ///
331    /// The handler receives a `RequestContext` for progress reporting and
332    /// cancellation checking, along with the deserialized input.
333    ///
334    /// # Example
335    ///
336    /// ```rust
337    /// use tower_mcp::{ToolBuilder, CallToolResult, RequestContext};
338    /// use schemars::JsonSchema;
339    /// use serde::Deserialize;
340    ///
341    /// #[derive(Debug, Deserialize, JsonSchema)]
342    /// struct ProcessInput {
343    ///     items: Vec<String>,
344    /// }
345    ///
346    /// let tool = ToolBuilder::new("process")
347    ///     .description("Process items with progress")
348    ///     .handler_with_context(|ctx: RequestContext, input: ProcessInput| async move {
349    ///         for (i, item) in input.items.iter().enumerate() {
350    ///             if ctx.is_cancelled() {
351    ///                 return Ok(CallToolResult::error("Cancelled"));
352    ///             }
353    ///             ctx.report_progress(i as f64, Some(input.items.len() as f64), Some("Processing...")).await;
354    ///             // Process item...
355    ///         }
356    ///         Ok(CallToolResult::text("Done"))
357    ///     })
358    ///     .build()
359    ///     .expect("valid tool name");
360    /// ```
361    pub fn handler_with_context<I, F, Fut>(self, handler: F) -> ToolBuilderWithContextHandler<I, F>
362    where
363        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
364        F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
365        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
366    {
367        ToolBuilderWithContextHandler {
368            name: self.name,
369            title: self.title,
370            description: self.description,
371            output_schema: self.output_schema,
372            icons: self.icons,
373            annotations: self.annotations,
374            handler,
375            _phantom: std::marker::PhantomData,
376        }
377    }
378
379    /// Specify input type, shared state, and handler.
380    ///
381    /// The state is cloned for each invocation, so wrapping it in an `Arc`
382    /// is recommended for expensive-to-clone types. This eliminates the
383    /// boilerplate of cloning state inside a `move` closure.
384    ///
385    /// # Example
386    ///
387    /// ```rust
388    /// use std::sync::Arc;
389    /// use tower_mcp::{ToolBuilder, CallToolResult};
390    /// use schemars::JsonSchema;
391    /// use serde::Deserialize;
392    ///
393    /// #[derive(Debug, Deserialize, JsonSchema)]
394    /// struct QueryInput { query: String }
395    ///
396    /// struct Db { connection_string: String }
397    ///
398    /// let db = Arc::new(Db { connection_string: "postgres://...".to_string() });
399    ///
400    /// let tool = ToolBuilder::new("search")
401    ///     .description("Search the database")
402    ///     .handler_with_state(db, |db: Arc<Db>, input: QueryInput| async move {
403    ///         Ok(CallToolResult::text(format!("Queried: {}", input.query)))
404    ///     })
405    ///     .build()
406    ///     .expect("valid tool name");
407    /// ```
408    pub fn handler_with_state<S, I, F, Fut>(
409        self,
410        state: S,
411        handler: F,
412    ) -> ToolBuilderWithHandler<
413        I,
414        impl Fn(I) -> BoxFuture<'static, Result<CallToolResult>> + Send + Sync + 'static,
415    >
416    where
417        S: Clone + Send + Sync + 'static,
418        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
419        F: Fn(S, I) -> Fut + Send + Sync + 'static,
420        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
421    {
422        let handler = Arc::new(handler);
423        self.handler(move |input: I| {
424            let state = state.clone();
425            let handler = handler.clone();
426            Box::pin(async move { handler(state, input).await })
427                as BoxFuture<'static, Result<CallToolResult>>
428        })
429    }
430
431    /// Specify input type, shared state, and context-aware handler.
432    ///
433    /// Combines state injection with `RequestContext` access for progress
434    /// reporting, cancellation, sampling, and logging.
435    ///
436    /// # Example
437    ///
438    /// ```rust
439    /// use std::sync::Arc;
440    /// use tower_mcp::{ToolBuilder, CallToolResult, RequestContext};
441    /// use schemars::JsonSchema;
442    /// use serde::Deserialize;
443    ///
444    /// #[derive(Debug, Deserialize, JsonSchema)]
445    /// struct QueryInput { query: String }
446    ///
447    /// struct Db { connection_string: String }
448    ///
449    /// let db = Arc::new(Db { connection_string: "postgres://...".to_string() });
450    ///
451    /// let tool = ToolBuilder::new("search")
452    ///     .description("Search the database with progress")
453    ///     .handler_with_state_and_context(db, |db: Arc<Db>, ctx: RequestContext, input: QueryInput| async move {
454    ///         ctx.report_progress(0.0, Some(1.0), Some("Searching...")).await;
455    ///         Ok(CallToolResult::text(format!("Queried: {}", input.query)))
456    ///     })
457    ///     .build()
458    ///     .expect("valid tool name");
459    /// ```
460    pub fn handler_with_state_and_context<S, I, F, Fut>(
461        self,
462        state: S,
463        handler: F,
464    ) -> ToolBuilderWithContextHandler<
465        I,
466        impl Fn(RequestContext, I) -> BoxFuture<'static, Result<CallToolResult>> + Send + Sync + 'static,
467    >
468    where
469        S: Clone + Send + Sync + 'static,
470        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
471        F: Fn(S, RequestContext, I) -> Fut + Send + Sync + 'static,
472        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
473    {
474        let handler = Arc::new(handler);
475        self.handler_with_context(move |ctx: RequestContext, input: I| {
476            let state = state.clone();
477            let handler = handler.clone();
478            Box::pin(async move { handler(state, ctx, input).await })
479                as BoxFuture<'static, Result<CallToolResult>>
480        })
481    }
482
483    /// Create a tool that takes no parameters.
484    ///
485    /// The handler receives no input arguments. An empty object input schema
486    /// is generated automatically. Returns `Result<Tool>` directly.
487    ///
488    /// # Example
489    ///
490    /// ```rust
491    /// use tower_mcp::{ToolBuilder, CallToolResult};
492    ///
493    /// let tool = ToolBuilder::new("server_time")
494    ///     .description("Get the current server time")
495    ///     .handler_no_params(|| async {
496    ///         Ok(CallToolResult::text("2025-01-01T00:00:00Z"))
497    ///     })
498    ///     .expect("valid tool name");
499    ///
500    /// assert_eq!(tool.name, "server_time");
501    /// ```
502    pub fn handler_no_params<F, Fut>(self, handler: F) -> Result<Tool>
503    where
504        F: Fn() -> Fut + Send + Sync + 'static,
505        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
506    {
507        validate_tool_name(&self.name)?;
508        Ok(Tool {
509            name: self.name,
510            title: self.title,
511            description: self.description,
512            output_schema: self.output_schema,
513            icons: self.icons,
514            annotations: self.annotations,
515            handler: Arc::new(NoParamsHandler { handler }),
516        })
517    }
518
519    /// Create a tool with raw JSON handling (no automatic deserialization)
520    ///
521    /// Returns an error if the tool name is invalid.
522    pub fn raw_handler<F, Fut>(self, handler: F) -> Result<Tool>
523    where
524        F: Fn(Value) -> Fut + Send + Sync + 'static,
525        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
526    {
527        validate_tool_name(&self.name)?;
528        Ok(Tool {
529            name: self.name,
530            title: self.title,
531            description: self.description,
532            output_schema: self.output_schema,
533            icons: self.icons,
534            annotations: self.annotations,
535            handler: Arc::new(RawHandler { handler }),
536        })
537    }
538
539    /// Create a tool with raw JSON handling and request context
540    ///
541    /// The handler receives a `RequestContext` for progress reporting,
542    /// cancellation, sampling, and logging, along with raw JSON arguments.
543    ///
544    /// Returns an error if the tool name is invalid.
545    pub fn raw_handler_with_context<F, Fut>(self, handler: F) -> Result<Tool>
546    where
547        F: Fn(RequestContext, Value) -> Fut + Send + Sync + 'static,
548        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
549    {
550        validate_tool_name(&self.name)?;
551        Ok(Tool {
552            name: self.name,
553            title: self.title,
554            description: self.description,
555            output_schema: self.output_schema,
556            icons: self.icons,
557            annotations: self.annotations,
558            handler: Arc::new(RawContextHandler { handler }),
559        })
560    }
561}
562
563/// Builder state after handler is specified
564pub struct ToolBuilderWithHandler<I, F> {
565    name: String,
566    title: Option<String>,
567    description: Option<String>,
568    output_schema: Option<Value>,
569    icons: Option<Vec<ToolIcon>>,
570    annotations: Option<ToolAnnotations>,
571    handler: F,
572    _phantom: std::marker::PhantomData<I>,
573}
574
575impl<I, F, Fut> ToolBuilderWithHandler<I, F>
576where
577    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
578    F: Fn(I) -> Fut + Send + Sync + 'static,
579    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
580{
581    /// Build the tool
582    ///
583    /// Returns an error if the tool name is invalid.
584    pub fn build(self) -> Result<Tool> {
585        validate_tool_name(&self.name)?;
586        Ok(Tool {
587            name: self.name,
588            title: self.title,
589            description: self.description,
590            output_schema: self.output_schema,
591            icons: self.icons,
592            annotations: self.annotations,
593            handler: Arc::new(TypedHandler {
594                handler: self.handler,
595                _phantom: std::marker::PhantomData,
596            }),
597        })
598    }
599}
600
601/// Builder state after context-aware handler is specified
602pub struct ToolBuilderWithContextHandler<I, F> {
603    name: String,
604    title: Option<String>,
605    description: Option<String>,
606    output_schema: Option<Value>,
607    icons: Option<Vec<ToolIcon>>,
608    annotations: Option<ToolAnnotations>,
609    handler: F,
610    _phantom: std::marker::PhantomData<I>,
611}
612
613impl<I, F, Fut> ToolBuilderWithContextHandler<I, F>
614where
615    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
616    F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
617    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
618{
619    /// Build the tool
620    ///
621    /// Returns an error if the tool name is invalid.
622    pub fn build(self) -> Result<Tool> {
623        validate_tool_name(&self.name)?;
624        Ok(Tool {
625            name: self.name,
626            title: self.title,
627            description: self.description,
628            output_schema: self.output_schema,
629            icons: self.icons,
630            annotations: self.annotations,
631            handler: Arc::new(ContextAwareHandler {
632                handler: self.handler,
633                _phantom: std::marker::PhantomData,
634            }),
635        })
636    }
637}
638
639// =============================================================================
640// Handler implementations
641// =============================================================================
642
/// Adapter that deserializes JSON arguments into `I` before calling `F`.
struct TypedHandler<I, F> {
    /// The user-supplied handler taking a typed input.
    handler: F,
    /// Records the input type `I` without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}
648
649impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
650where
651    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
652    F: Fn(I) -> Fut + Send + Sync + 'static,
653    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
654{
655    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
656        Box::pin(async move {
657            let input: I = serde_json::from_value(args)
658                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
659            (self.handler)(input).await
660        })
661    }
662
663    fn input_schema(&self) -> Value {
664        let schema = schemars::schema_for!(I);
665        serde_json::to_value(schema).unwrap_or_else(|_| {
666            serde_json::json!({
667                "type": "object"
668            })
669        })
670    }
671}
672
/// Adapter that passes JSON arguments to `F` untouched.
struct RawHandler<F> {
    /// The user-supplied handler taking raw JSON.
    handler: F,
}
677
678impl<F, Fut> ToolHandler for RawHandler<F>
679where
680    F: Fn(Value) -> Fut + Send + Sync + 'static,
681    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
682{
683    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
684        Box::pin((self.handler)(args))
685    }
686
687    fn input_schema(&self) -> Value {
688        // Raw handlers accept any JSON
689        serde_json::json!({
690            "type": "object",
691            "additionalProperties": true
692        })
693    }
694}
695
/// Adapter that passes the request context and untouched JSON arguments to `F`.
struct RawContextHandler<F> {
    /// The user-supplied handler taking a context and raw JSON.
    handler: F,
}
700
701impl<F, Fut> ToolHandler for RawContextHandler<F>
702where
703    F: Fn(RequestContext, Value) -> Fut + Send + Sync + 'static,
704    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
705{
706    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
707        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
708        self.call_with_context(ctx, args)
709    }
710
711    fn call_with_context(
712        &self,
713        ctx: RequestContext,
714        args: Value,
715    ) -> BoxFuture<'_, Result<CallToolResult>> {
716        Box::pin((self.handler)(ctx, args))
717    }
718
719    fn uses_context(&self) -> bool {
720        true
721    }
722
723    fn input_schema(&self) -> Value {
724        // Raw context handlers accept any JSON object
725        serde_json::json!({
726            "type": "object",
727            "additionalProperties": true
728        })
729    }
730}
731
/// Adapter that deserializes JSON arguments into `I` and calls `F` together
/// with the request context.
struct ContextAwareHandler<I, F> {
    /// The user-supplied context-aware handler.
    handler: F,
    /// Records the input type `I` without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}
737
738impl<I, F, Fut> ToolHandler for ContextAwareHandler<I, F>
739where
740    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
741    F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
742    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
743{
744    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
745        // When called without context, create a dummy context
746        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
747        self.call_with_context(ctx, args)
748    }
749
750    fn call_with_context(
751        &self,
752        ctx: RequestContext,
753        args: Value,
754    ) -> BoxFuture<'_, Result<CallToolResult>> {
755        Box::pin(async move {
756            let input: I = serde_json::from_value(args)
757                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
758            (self.handler)(ctx, input).await
759        })
760    }
761
762    fn uses_context(&self) -> bool {
763        true
764    }
765
766    fn input_schema(&self) -> Value {
767        let schema = schemars::schema_for!(I);
768        serde_json::to_value(schema).unwrap_or_else(|_| {
769            serde_json::json!({
770                "type": "object"
771            })
772        })
773    }
774}
775
/// Adapter for handlers that take no input; call arguments are ignored.
struct NoParamsHandler<F> {
    /// The user-supplied zero-argument handler.
    handler: F,
}
780
781impl<F, Fut> ToolHandler for NoParamsHandler<F>
782where
783    F: Fn() -> Fut + Send + Sync + 'static,
784    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
785{
786    fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
787        Box::pin((self.handler)())
788    }
789
790    fn input_schema(&self) -> Value {
791        serde_json::json!({
792            "type": "object",
793            "properties": {}
794        })
795    }
796}
797
798// =============================================================================
799// Trait-based tool definition
800// =============================================================================
801
802/// Trait for defining tools with full control
803///
804/// Implement this trait when you need more control than the builder provides,
805/// or when you want to define tools as standalone types.
806///
807/// # Example
808///
809/// ```rust
810/// use tower_mcp::tool::McpTool;
811/// use tower_mcp::error::Result;
812/// use schemars::JsonSchema;
813/// use serde::{Deserialize, Serialize};
814///
815/// #[derive(Debug, Deserialize, JsonSchema)]
816/// struct AddInput {
817///     a: i64,
818///     b: i64,
819/// }
820///
821/// struct AddTool;
822///
823/// impl McpTool for AddTool {
824///     const NAME: &'static str = "add";
825///     const DESCRIPTION: &'static str = "Add two numbers";
826///
827///     type Input = AddInput;
828///     type Output = i64;
829///
830///     async fn call(&self, input: Self::Input) -> Result<Self::Output> {
831///         Ok(input.a + input.b)
832///     }
833/// }
834///
835/// let tool = AddTool.into_tool().expect("valid tool name");
836/// assert_eq!(tool.name, "add");
837/// ```
838pub trait McpTool: Send + Sync + 'static {
839    const NAME: &'static str;
840    const DESCRIPTION: &'static str;
841
842    type Input: JsonSchema + DeserializeOwned + Send;
843    type Output: Serialize + Send;
844
845    fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
846
847    /// Optional annotations for the tool
848    fn annotations(&self) -> Option<ToolAnnotations> {
849        None
850    }
851
852    /// Convert to a Tool instance
853    ///
854    /// Returns an error if the tool name is invalid.
855    fn into_tool(self) -> Result<Tool>
856    where
857        Self: Sized,
858    {
859        validate_tool_name(Self::NAME)?;
860        let annotations = self.annotations();
861        let tool = Arc::new(self);
862        Ok(Tool {
863            name: Self::NAME.to_string(),
864            title: None,
865            description: Some(Self::DESCRIPTION.to_string()),
866            output_schema: None,
867            icons: None,
868            annotations,
869            handler: Arc::new(McpToolHandler { tool }),
870        })
871    }
872}
873
874/// Wrapper to make McpTool implement ToolHandler
875struct McpToolHandler<T: McpTool> {
876    tool: Arc<T>,
877}
878
879impl<T: McpTool> ToolHandler for McpToolHandler<T> {
880    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
881        let tool = self.tool.clone();
882        Box::pin(async move {
883            let input: T::Input = serde_json::from_value(args)
884                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
885            let output = tool.call(input).await?;
886            let value = serde_json::to_value(output)
887                .map_err(|e| Error::tool(format!("Failed to serialize output: {}", e)))?;
888            Ok(CallToolResult::json(value))
889        })
890    }
891
892    fn input_schema(&self) -> Value {
893        let schema = schemars::schema_for!(T::Input);
894        serde_json::to_value(schema).unwrap_or_else(|_| {
895            serde_json::json!({
896                "type": "object"
897            })
898        })
899    }
900}
901
// Tests for `ToolBuilder` (typed, raw, context-aware, stateful and no-params
// handlers) and for the tool-name validation applied when a tool is built.
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    // The parent module imports only `DeserializeOwned`, so the derive macro
    // must be brought in here.
    use serde::Deserialize;

    // Minimal typed input shared by several tests in this module.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct GreetInput {
        name: String,
    }

    // A typed `handler` tool keeps its name/description and runs successfully
    // on well-formed JSON arguments.
    #[tokio::test]
    async fn test_builder_tool() {
        let tool = ToolBuilder::new("greet")
            .description("Greet someone")
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "greet");
        assert_eq!(tool.description.as_deref(), Some("Greet someone"));

        let result = tool
            .call(serde_json::json!({"name": "World"}))
            .await
            .unwrap();

        assert!(!result.is_error);
    }

    // `raw_handler` receives the untyped JSON arguments and yields the tool
    // directly as a `Result` (no separate `.build()` step).
    #[tokio::test]
    async fn test_raw_handler() {
        let tool = ToolBuilder::new("echo")
            .description("Echo input")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) })
            .expect("valid tool name");

        let result = tool.call(serde_json::json!({"foo": "bar"})).await.unwrap();

        assert!(!result.is_error);
    }

    // An empty name is rejected when the tool is built.
    #[test]
    fn test_invalid_tool_name_empty() {
        let result = ToolBuilder::new("")
            .description("Empty name")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
    }

    // 129 characters is one past the 128-character limit.
    #[test]
    fn test_invalid_tool_name_too_long() {
        let long_name = "a".repeat(129);
        let result = ToolBuilder::new(long_name)
            .description("Too long")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
    }

    // The space and `!` fall outside the allowed set (ASCII alphanumerics,
    // `_`, `-`, `.`).
    #[test]
    fn test_invalid_tool_name_bad_chars() {
        let result = ToolBuilder::new("my tool!")
            .description("Bad chars")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("invalid character")
        );
    }

    // Covers each allowed character class plus both length boundaries
    // (1 and exactly 128 characters).
    #[test]
    fn test_valid_tool_names() {
        // All valid characters
        let names = [
            "my_tool",
            "my-tool",
            "my.tool",
            "MyTool123",
            "a",
            &"a".repeat(128),
        ];
        for name in names {
            let result = ToolBuilder::new(name)
                .description("Valid")
                .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });
            assert!(result.is_ok(), "Expected '{}' to be valid", name);
        }
    }

    // `handler_with_context` marks the tool as context-using, and the handler
    // receives the `RequestContext` passed to `call_with_context`. With a
    // progress token and notification sender attached, the test expects one
    // queued notification per `report_progress` call (3 for count = 3).
    #[tokio::test]
    async fn test_context_aware_handler() {
        use crate::context::{RequestContext, notification_channel};
        use crate::protocol::{ProgressToken, RequestId};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct ProcessInput {
            count: i32,
        }

        let tool = ToolBuilder::new("process")
            .description("Process with context")
            .handler_with_context(|ctx: RequestContext, input: ProcessInput| async move {
                // Simulate progress reporting
                for i in 0..input.count {
                    if ctx.is_cancelled() {
                        return Ok(CallToolResult::error("Cancelled"));
                    }
                    ctx.report_progress(i as f64, Some(input.count as f64), None)
                        .await;
                }
                Ok(CallToolResult::text(format!(
                    "Processed {} items",
                    input.count
                )))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "process");
        assert!(tool.uses_context());

        // Test with a context that has progress token and notification sender
        let (tx, mut rx) = notification_channel(10);
        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        let result = tool
            .call_with_context(ctx, serde_json::json!({"count": 3}))
            .await
            .unwrap();

        assert!(!result.is_error);

        // Check that progress notifications were sent. The handler has already
        // completed, so a non-blocking drain with `try_recv` sees everything it
        // sent.
        let mut progress_count = 0;
        while rx.try_recv().is_ok() {
            progress_count += 1;
        }
        assert_eq!(progress_count, 3);
    }

    // The handler cancels its own token mid-loop; the `is_cancelled` check at
    // the top of the following iteration must stop it and return an error
    // result.
    #[tokio::test]
    async fn test_context_aware_handler_cancellation() {
        use crate::context::RequestContext;
        use crate::protocol::RequestId;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicI32, Ordering};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct LongRunningInput {
            iterations: i32,
        }

        // Shared counter so the test can observe how many iterations actually
        // ran inside the handler.
        let iterations_completed = Arc::new(AtomicI32::new(0));
        let iterations_ref = iterations_completed.clone();

        let tool = ToolBuilder::new("long_running")
            .description("Long running task")
            .handler_with_context(move |ctx: RequestContext, input: LongRunningInput| {
                // Clone per invocation so the returned `async move` future owns
                // its own handle to the counter.
                let completed = iterations_ref.clone();
                async move {
                    for i in 0..input.iterations {
                        if ctx.is_cancelled() {
                            return Ok(CallToolResult::error("Cancelled"));
                        }
                        completed.fetch_add(1, Ordering::SeqCst);
                        // Simulate work
                        tokio::task::yield_now().await;
                        // Cancel after iteration 2
                        if i == 2 {
                            ctx.cancellation_token().cancel();
                        }
                    }
                    Ok(CallToolResult::text("Done"))
                }
            })
            .build()
            .expect("valid tool name");

        let ctx = RequestContext::new(RequestId::Number(1));

        let result = tool
            .call_with_context(ctx, serde_json::json!({"iterations": 10}))
            .await
            .unwrap();

        // Should have been cancelled after 3 iterations (0, 1, 2)
        // The next iteration (3) checks cancellation and returns
        assert!(result.is_error);
        assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
    }

    // Optional metadata set on the builder (title, output schema, icons) is
    // stored on the tool and surfaced again by `definition()`.
    #[tokio::test]
    async fn test_tool_builder_with_enhanced_fields() {
        let output_schema = serde_json::json!({
            "type": "object",
            "properties": {
                "greeting": {"type": "string"}
            }
        });

        let tool = ToolBuilder::new("greet")
            .title("Greeting Tool")
            .description("Greet someone")
            .output_schema(output_schema.clone())
            .icon("https://example.com/icon.png")
            .icon_with_meta(
                "https://example.com/icon-large.png",
                Some("image/png".to_string()),
                Some(vec!["96x96".to_string()]),
            )
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "greet");
        assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
        assert_eq!(tool.output_schema, Some(output_schema));
        assert!(tool.icons.is_some());
        // Both `icon` and `icon_with_meta` append an entry.
        assert_eq!(tool.icons.as_ref().unwrap().len(), 2);

        // Test definition includes new fields
        let def = tool.definition();
        assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
        assert!(def.output_schema.is_some());
        assert!(def.icons.is_some());
    }

    // `handler_with_state` hands the shared `Arc` state to the handler together
    // with the typed input.
    #[tokio::test]
    async fn test_handler_with_state() {
        let shared = Arc::new("shared-state".to_string());

        let tool = ToolBuilder::new("stateful")
            .description("Uses shared state")
            .handler_with_state(shared, |state: Arc<String>, input: GreetInput| async move {
                Ok(CallToolResult::text(format!(
                    "{}: Hello, {}!",
                    state, input.name
                )))
            })
            .build()
            .expect("valid tool name");

        let result = tool
            .call(serde_json::json!({"name": "World"}))
            .await
            .unwrap();
        assert!(!result.is_error);
    }

    // Combining shared state with a request context still marks the tool as
    // context-using and runs through `call_with_context`.
    #[tokio::test]
    async fn test_handler_with_state_and_context() {
        use crate::context::RequestContext;
        use crate::protocol::RequestId;

        let shared = Arc::new(42_i32);

        let tool = ToolBuilder::new("stateful_ctx")
            .description("Uses state and context")
            .handler_with_state_and_context(
                shared,
                |state: Arc<i32>, _ctx: RequestContext, input: GreetInput| async move {
                    Ok(CallToolResult::text(format!(
                        "{}: Hello, {}!",
                        state, input.name
                    )))
                },
            )
            .build()
            .expect("valid tool name");

        assert!(tool.uses_context());

        let ctx = RequestContext::new(RequestId::Number(1));
        let result = tool
            .call_with_context(ctx, serde_json::json!({"name": "World"}))
            .await
            .unwrap();
        assert!(!result.is_error);
    }

    // `handler_no_params` yields the tool directly as a `Result`, ignores
    // whatever arguments are sent, and advertises an object schema with no
    // properties.
    #[tokio::test]
    async fn test_handler_no_params() {
        let tool = ToolBuilder::new("no_params")
            .description("Takes no parameters")
            .handler_no_params(|| async { Ok(CallToolResult::text("no params result")) })
            .expect("valid tool name");

        assert_eq!(tool.name, "no_params");

        // Should work with empty args
        let result = tool.call(serde_json::json!({})).await.unwrap();
        assert!(!result.is_error);

        // Should also work with unexpected args (ignored)
        let result = tool
            .call(serde_json::json!({"unexpected": "value"}))
            .await
            .unwrap();
        assert!(!result.is_error);

        // Check input schema is an empty-properties object
        let schema = tool.definition().input_schema;
        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
        assert!(
            schema
                .get("properties")
                .unwrap()
                .as_object()
                .unwrap()
                .is_empty()
        );
    }
}