Skip to main content

tower_mcp/
tool.rs

//! Tool definition and builder API
//!
//! Provides ergonomic ways to define MCP tools:
//!
//! 1. **Builder pattern** - Fluent API for defining tools
//! 2. **Trait-based** - Implement `McpTool` for full control
//! 3. **Function-based** - Quick tools from async functions
8
9use std::borrow::Cow;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use schemars::{JsonSchema, Schema, SchemaGenerator};
15use serde::Serialize;
16use serde::de::DeserializeOwned;
17use serde_json::Value;
18
19use crate::context::RequestContext;
20use crate::error::{Error, Result};
21use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
22
/// Input type for tools that accept no arguments.
///
/// Prefer this over `()` when a tool has no input parameters. Schema
/// generation for `()` yields `"type": "null"`, which many MCP clients
/// refuse; `NoParams` instead advertises `"type": "object"` with no required
/// properties, which is the appropriate shape for a parameterless tool.
///
/// When deserializing, `null` or any JSON object is accepted; fields present
/// in the object are discarded.
///
/// # Example
///
/// ```rust
/// use tower_mcp::{ToolBuilder, CallToolResult, NoParams};
///
/// let tool = ToolBuilder::new("get_status")
///     .description("Get current status")
///     .handler(|_input: NoParams| async move {
///         Ok(CallToolResult::text("OK"))
///     })
///     .build()
///     .unwrap();
/// ```
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct NoParams;
45
46impl<'de> serde::Deserialize<'de> for NoParams {
47    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
48    where
49        D: serde::Deserializer<'de>,
50    {
51        // Accept null, empty object, or any object (ignoring all fields)
52        struct NoParamsVisitor;
53
54        impl<'de> serde::de::Visitor<'de> for NoParamsVisitor {
55            type Value = NoParams;
56
57            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
58                formatter.write_str("null or an object")
59            }
60
61            fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
62            where
63                E: serde::de::Error,
64            {
65                Ok(NoParams)
66            }
67
68            fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
69            where
70                E: serde::de::Error,
71            {
72                Ok(NoParams)
73            }
74
75            fn visit_some<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
76            where
77                D: serde::Deserializer<'de>,
78            {
79                serde::Deserialize::deserialize(deserializer)
80            }
81
82            fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
83            where
84                A: serde::de::MapAccess<'de>,
85            {
86                // Drain the map, ignoring all entries
87                while map
88                    .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
89                    .is_some()
90                {}
91                Ok(NoParams)
92            }
93        }
94
95        deserializer.deserialize_any(NoParamsVisitor)
96    }
97}
98
99impl JsonSchema for NoParams {
100    fn schema_name() -> Cow<'static, str> {
101        Cow::Borrowed("NoParams")
102    }
103
104    fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
105        serde_json::json!({
106            "type": "object"
107        })
108        .try_into()
109        .expect("valid schema")
110    }
111}
112
113/// Validate a tool name according to MCP spec.
114///
115/// Tool names must be:
116/// - 1-128 characters long
117/// - Contain only alphanumeric characters, underscores, hyphens, and dots
118///
119/// Returns `Ok(())` if valid, `Err` with description if invalid.
120pub fn validate_tool_name(name: &str) -> Result<()> {
121    if name.is_empty() {
122        return Err(Error::tool("Tool name cannot be empty"));
123    }
124    if name.len() > 128 {
125        return Err(Error::tool(format!(
126            "Tool name '{}' exceeds maximum length of 128 characters (got {})",
127            name,
128            name.len()
129        )));
130    }
131    if let Some(invalid_char) = name
132        .chars()
133        .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
134    {
135        return Err(Error::tool(format!(
136            "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
137            name, invalid_char
138        )));
139    }
140    Ok(())
141}
142
/// Heap-allocated, type-erased future that is `Send` and valid for `'a`.
///
/// Tool handlers return this so that differently-typed async bodies can sit
/// behind the same `dyn ToolHandler` interface.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
145
146/// Tool handler trait - the core abstraction for tool execution
147pub trait ToolHandler: Send + Sync {
148    /// Execute the tool with the given arguments
149    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;
150
151    /// Execute the tool with request context for progress/cancellation support
152    ///
153    /// The default implementation ignores the context and calls `call`.
154    /// Override this to receive progress/cancellation context.
155    fn call_with_context(
156        &self,
157        _ctx: RequestContext,
158        args: Value,
159    ) -> BoxFuture<'_, Result<CallToolResult>> {
160        self.call(args)
161    }
162
163    /// Returns true if this handler uses context (for optimization)
164    fn uses_context(&self) -> bool {
165        false
166    }
167
168    /// Get the tool's input schema
169    fn input_schema(&self) -> Value;
170}
171
172/// A complete tool definition with handler
173pub struct Tool {
174    pub name: String,
175    pub title: Option<String>,
176    pub description: Option<String>,
177    pub output_schema: Option<Value>,
178    pub icons: Option<Vec<ToolIcon>>,
179    pub annotations: Option<ToolAnnotations>,
180    handler: Arc<dyn ToolHandler>,
181}
182
183impl std::fmt::Debug for Tool {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        f.debug_struct("Tool")
186            .field("name", &self.name)
187            .field("title", &self.title)
188            .field("description", &self.description)
189            .field("output_schema", &self.output_schema)
190            .field("icons", &self.icons)
191            .field("annotations", &self.annotations)
192            .finish_non_exhaustive()
193    }
194}
195
196impl Tool {
197    /// Create a new tool builder
198    pub fn builder(name: impl Into<String>) -> ToolBuilder {
199        ToolBuilder::new(name)
200    }
201
202    /// Get the tool definition for tools/list
203    pub fn definition(&self) -> ToolDefinition {
204        ToolDefinition {
205            name: self.name.clone(),
206            title: self.title.clone(),
207            description: self.description.clone(),
208            input_schema: self.handler.input_schema(),
209            output_schema: self.output_schema.clone(),
210            icons: self.icons.clone(),
211            annotations: self.annotations.clone(),
212        }
213    }
214
215    /// Call the tool without context
216    pub fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
217        self.handler.call(args)
218    }
219
220    /// Call the tool with request context
221    ///
222    /// Use this when you have a RequestContext available for progress/cancellation.
223    pub fn call_with_context(
224        &self,
225        ctx: RequestContext,
226        args: Value,
227    ) -> BoxFuture<'_, Result<CallToolResult>> {
228        self.handler.call_with_context(ctx, args)
229    }
230
231    /// Returns true if this tool uses context
232    pub fn uses_context(&self) -> bool {
233        self.handler.uses_context()
234    }
235}
236
// =============================================================================
// Builder API
// =============================================================================
240
241/// Builder for creating tools with a fluent API
242///
243/// # Example
244///
245/// ```rust
246/// use tower_mcp::{ToolBuilder, CallToolResult};
247/// use schemars::JsonSchema;
248/// use serde::Deserialize;
249///
250/// #[derive(Debug, Deserialize, JsonSchema)]
251/// struct GreetInput {
252///     name: String,
253/// }
254///
255/// let tool = ToolBuilder::new("greet")
256///     .description("Greet someone by name")
257///     .handler(|input: GreetInput| async move {
258///         Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
259///     })
260///     .build()
261///     .expect("valid tool name");
262///
263/// assert_eq!(tool.name, "greet");
264/// ```
265pub struct ToolBuilder {
266    name: String,
267    title: Option<String>,
268    description: Option<String>,
269    output_schema: Option<Value>,
270    icons: Option<Vec<ToolIcon>>,
271    annotations: Option<ToolAnnotations>,
272}
273
274impl ToolBuilder {
275    pub fn new(name: impl Into<String>) -> Self {
276        Self {
277            name: name.into(),
278            title: None,
279            description: None,
280            output_schema: None,
281            icons: None,
282            annotations: None,
283        }
284    }
285
286    /// Set a human-readable title for the tool
287    pub fn title(mut self, title: impl Into<String>) -> Self {
288        self.title = Some(title.into());
289        self
290    }
291
292    /// Set the output schema (JSON Schema for structured output)
293    pub fn output_schema(mut self, schema: Value) -> Self {
294        self.output_schema = Some(schema);
295        self
296    }
297
298    /// Add an icon for the tool
299    pub fn icon(mut self, src: impl Into<String>) -> Self {
300        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
301            src: src.into(),
302            mime_type: None,
303            sizes: None,
304        });
305        self
306    }
307
308    /// Add an icon with metadata
309    pub fn icon_with_meta(
310        mut self,
311        src: impl Into<String>,
312        mime_type: Option<String>,
313        sizes: Option<Vec<String>>,
314    ) -> Self {
315        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
316            src: src.into(),
317            mime_type,
318            sizes,
319        });
320        self
321    }
322
323    /// Set the tool description
324    pub fn description(mut self, description: impl Into<String>) -> Self {
325        self.description = Some(description.into());
326        self
327    }
328
329    /// Mark the tool as read-only (does not modify state)
330    pub fn read_only(mut self) -> Self {
331        self.annotations
332            .get_or_insert_with(ToolAnnotations::default)
333            .read_only_hint = true;
334        self
335    }
336
337    /// Mark the tool as non-destructive
338    pub fn non_destructive(mut self) -> Self {
339        self.annotations
340            .get_or_insert_with(ToolAnnotations::default)
341            .destructive_hint = false;
342        self
343    }
344
345    /// Mark the tool as idempotent (same args = same effect)
346    pub fn idempotent(mut self) -> Self {
347        self.annotations
348            .get_or_insert_with(ToolAnnotations::default)
349            .idempotent_hint = true;
350        self
351    }
352
353    /// Set tool annotations directly
354    pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
355        self.annotations = Some(annotations);
356        self
357    }
358
359    /// Specify input type and handler.
360    ///
361    /// The input type must implement `JsonSchema` and `DeserializeOwned`.
362    /// The handler receives the deserialized input and returns a `CallToolResult`.
363    ///
364    /// # State Sharing
365    ///
366    /// To share state across tool calls (e.g., database connections, API clients),
367    /// wrap your state in an `Arc` and clone it into the async block:
368    ///
369    /// ```rust
370    /// use std::sync::Arc;
371    /// use tower_mcp::{ToolBuilder, CallToolResult};
372    /// use schemars::JsonSchema;
373    /// use serde::Deserialize;
374    ///
375    /// struct AppState {
376    ///     api_key: String,
377    /// }
378    ///
379    /// #[derive(Debug, Deserialize, JsonSchema)]
380    /// struct MyInput {
381    ///     query: String,
382    /// }
383    ///
384    /// let state = Arc::new(AppState { api_key: "secret".to_string() });
385    ///
386    /// let tool = ToolBuilder::new("my_tool")
387    ///     .description("A tool that uses shared state")
388    ///     .handler(move |input: MyInput| {
389    ///         let state = state.clone(); // Clone Arc for the async block
390    ///         async move {
391    ///             // Use state.api_key here...
392    ///             Ok(CallToolResult::text(format!("Query: {}", input.query)))
393    ///         }
394    ///     })
395    ///     .build()
396    ///     .expect("valid tool name");
397    /// ```
398    ///
399    /// The `move` keyword on the closure captures the `Arc<AppState>`, and
400    /// cloning it inside the closure body allows each async invocation to
401    /// have its own reference to the shared state.
402    pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
403    where
404        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
405        F: Fn(I) -> Fut + Send + Sync + 'static,
406        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
407    {
408        ToolBuilderWithHandler {
409            name: self.name,
410            title: self.title,
411            description: self.description,
412            output_schema: self.output_schema,
413            icons: self.icons,
414            annotations: self.annotations,
415            handler,
416            _phantom: std::marker::PhantomData,
417        }
418    }
419
420    /// Specify input type and context-aware handler
421    ///
422    /// The handler receives a `RequestContext` for progress reporting and
423    /// cancellation checking, along with the deserialized input.
424    ///
425    /// # Example
426    ///
427    /// ```rust
428    /// use tower_mcp::{ToolBuilder, CallToolResult, RequestContext};
429    /// use schemars::JsonSchema;
430    /// use serde::Deserialize;
431    ///
432    /// #[derive(Debug, Deserialize, JsonSchema)]
433    /// struct ProcessInput {
434    ///     items: Vec<String>,
435    /// }
436    ///
437    /// let tool = ToolBuilder::new("process")
438    ///     .description("Process items with progress")
439    ///     .handler_with_context(|ctx: RequestContext, input: ProcessInput| async move {
440    ///         for (i, item) in input.items.iter().enumerate() {
441    ///             if ctx.is_cancelled() {
442    ///                 return Ok(CallToolResult::error("Cancelled"));
443    ///             }
444    ///             ctx.report_progress(i as f64, Some(input.items.len() as f64), Some("Processing...")).await;
445    ///             // Process item...
446    ///         }
447    ///         Ok(CallToolResult::text("Done"))
448    ///     })
449    ///     .build()
450    ///     .expect("valid tool name");
451    /// ```
452    pub fn handler_with_context<I, F, Fut>(self, handler: F) -> ToolBuilderWithContextHandler<I, F>
453    where
454        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
455        F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
456        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
457    {
458        ToolBuilderWithContextHandler {
459            name: self.name,
460            title: self.title,
461            description: self.description,
462            output_schema: self.output_schema,
463            icons: self.icons,
464            annotations: self.annotations,
465            handler,
466            _phantom: std::marker::PhantomData,
467        }
468    }
469
470    /// Specify input type, shared state, and handler.
471    ///
472    /// The state is cloned for each invocation, so wrapping it in an `Arc`
473    /// is recommended for expensive-to-clone types. This eliminates the
474    /// boilerplate of cloning state inside a `move` closure.
475    ///
476    /// # Example
477    ///
478    /// ```rust
479    /// use std::sync::Arc;
480    /// use tower_mcp::{ToolBuilder, CallToolResult};
481    /// use schemars::JsonSchema;
482    /// use serde::Deserialize;
483    ///
484    /// #[derive(Debug, Deserialize, JsonSchema)]
485    /// struct QueryInput { query: String }
486    ///
487    /// struct Db { connection_string: String }
488    ///
489    /// let db = Arc::new(Db { connection_string: "postgres://...".to_string() });
490    ///
491    /// let tool = ToolBuilder::new("search")
492    ///     .description("Search the database")
493    ///     .handler_with_state(db, |db: Arc<Db>, input: QueryInput| async move {
494    ///         Ok(CallToolResult::text(format!("Queried: {}", input.query)))
495    ///     })
496    ///     .build()
497    ///     .expect("valid tool name");
498    /// ```
499    pub fn handler_with_state<S, I, F, Fut>(
500        self,
501        state: S,
502        handler: F,
503    ) -> ToolBuilderWithHandler<
504        I,
505        impl Fn(I) -> BoxFuture<'static, Result<CallToolResult>> + Send + Sync + 'static,
506    >
507    where
508        S: Clone + Send + Sync + 'static,
509        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
510        F: Fn(S, I) -> Fut + Send + Sync + 'static,
511        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
512    {
513        let handler = Arc::new(handler);
514        self.handler(move |input: I| {
515            let state = state.clone();
516            let handler = handler.clone();
517            Box::pin(async move { handler(state, input).await })
518                as BoxFuture<'static, Result<CallToolResult>>
519        })
520    }
521
522    /// Specify input type, shared state, and context-aware handler.
523    ///
524    /// Combines state injection with `RequestContext` access for progress
525    /// reporting, cancellation, sampling, and logging.
526    ///
527    /// # Example
528    ///
529    /// ```rust
530    /// use std::sync::Arc;
531    /// use tower_mcp::{ToolBuilder, CallToolResult, RequestContext};
532    /// use schemars::JsonSchema;
533    /// use serde::Deserialize;
534    ///
535    /// #[derive(Debug, Deserialize, JsonSchema)]
536    /// struct QueryInput { query: String }
537    ///
538    /// struct Db { connection_string: String }
539    ///
540    /// let db = Arc::new(Db { connection_string: "postgres://...".to_string() });
541    ///
542    /// let tool = ToolBuilder::new("search")
543    ///     .description("Search the database with progress")
544    ///     .handler_with_state_and_context(db, |db: Arc<Db>, ctx: RequestContext, input: QueryInput| async move {
545    ///         ctx.report_progress(0.0, Some(1.0), Some("Searching...")).await;
546    ///         Ok(CallToolResult::text(format!("Queried: {}", input.query)))
547    ///     })
548    ///     .build()
549    ///     .expect("valid tool name");
550    /// ```
551    pub fn handler_with_state_and_context<S, I, F, Fut>(
552        self,
553        state: S,
554        handler: F,
555    ) -> ToolBuilderWithContextHandler<
556        I,
557        impl Fn(RequestContext, I) -> BoxFuture<'static, Result<CallToolResult>> + Send + Sync + 'static,
558    >
559    where
560        S: Clone + Send + Sync + 'static,
561        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
562        F: Fn(S, RequestContext, I) -> Fut + Send + Sync + 'static,
563        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
564    {
565        let handler = Arc::new(handler);
566        self.handler_with_context(move |ctx: RequestContext, input: I| {
567            let state = state.clone();
568            let handler = handler.clone();
569            Box::pin(async move { handler(state, ctx, input).await })
570                as BoxFuture<'static, Result<CallToolResult>>
571        })
572    }
573
574    /// Create a tool that takes no parameters.
575    ///
576    /// The handler receives no input arguments. An empty object input schema
577    /// is generated automatically. Returns `Result<Tool>` directly.
578    ///
579    /// # Example
580    ///
581    /// ```rust
582    /// use tower_mcp::{ToolBuilder, CallToolResult};
583    ///
584    /// let tool = ToolBuilder::new("server_time")
585    ///     .description("Get the current server time")
586    ///     .handler_no_params(|| async {
587    ///         Ok(CallToolResult::text("2025-01-01T00:00:00Z"))
588    ///     })
589    ///     .expect("valid tool name");
590    ///
591    /// assert_eq!(tool.name, "server_time");
592    /// ```
593    pub fn handler_no_params<F, Fut>(self, handler: F) -> Result<Tool>
594    where
595        F: Fn() -> Fut + Send + Sync + 'static,
596        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
597    {
598        validate_tool_name(&self.name)?;
599        Ok(Tool {
600            name: self.name,
601            title: self.title,
602            description: self.description,
603            output_schema: self.output_schema,
604            icons: self.icons,
605            annotations: self.annotations,
606            handler: Arc::new(NoParamsHandler { handler }),
607        })
608    }
609
610    /// Create a tool with no parameters but with shared state
611    ///
612    /// Use this for tools that need access to shared state (e.g., a connection pool,
613    /// configuration, or shared registry) but don't take any input parameters.
614    ///
615    /// Returns an error if the tool name is invalid.
616    ///
617    /// # Example
618    ///
619    /// ```rust
620    /// use std::sync::Arc;
621    /// use tower_mcp::{ToolBuilder, CallToolResult};
622    ///
623    /// struct Config { version: String }
624    ///
625    /// let config = Arc::new(Config { version: "1.0.0".to_string() });
626    ///
627    /// let tool = ToolBuilder::new("get_version")
628    ///     .description("Get the server version")
629    ///     .handler_no_params_with_state(config, |config: Arc<Config>| async move {
630    ///         Ok(CallToolResult::text(&config.version))
631    ///     })
632    ///     .expect("valid tool name");
633    ///
634    /// assert_eq!(tool.name, "get_version");
635    /// ```
636    pub fn handler_no_params_with_state<S, F, Fut>(self, state: S, handler: F) -> Result<Tool>
637    where
638        S: Clone + Send + Sync + 'static,
639        F: Fn(S) -> Fut + Send + Sync + 'static,
640        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
641    {
642        validate_tool_name(&self.name)?;
643        Ok(Tool {
644            name: self.name,
645            title: self.title,
646            description: self.description,
647            output_schema: self.output_schema,
648            icons: self.icons,
649            annotations: self.annotations,
650            handler: Arc::new(NoParamsWithStateHandler { state, handler }),
651        })
652    }
653
654    /// Create a tool with raw JSON handling (no automatic deserialization)
655    ///
656    /// Returns an error if the tool name is invalid.
657    pub fn raw_handler<F, Fut>(self, handler: F) -> Result<Tool>
658    where
659        F: Fn(Value) -> Fut + Send + Sync + 'static,
660        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
661    {
662        validate_tool_name(&self.name)?;
663        Ok(Tool {
664            name: self.name,
665            title: self.title,
666            description: self.description,
667            output_schema: self.output_schema,
668            icons: self.icons,
669            annotations: self.annotations,
670            handler: Arc::new(RawHandler { handler }),
671        })
672    }
673
674    /// Create a tool with raw JSON handling and request context
675    ///
676    /// The handler receives a `RequestContext` for progress reporting,
677    /// cancellation, sampling, and logging, along with raw JSON arguments.
678    ///
679    /// Returns an error if the tool name is invalid.
680    pub fn raw_handler_with_context<F, Fut>(self, handler: F) -> Result<Tool>
681    where
682        F: Fn(RequestContext, Value) -> Fut + Send + Sync + 'static,
683        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
684    {
685        validate_tool_name(&self.name)?;
686        Ok(Tool {
687            name: self.name,
688            title: self.title,
689            description: self.description,
690            output_schema: self.output_schema,
691            icons: self.icons,
692            annotations: self.annotations,
693            handler: Arc::new(RawContextHandler { handler }),
694        })
695    }
696}
697
698/// Builder state after handler is specified
699pub struct ToolBuilderWithHandler<I, F> {
700    name: String,
701    title: Option<String>,
702    description: Option<String>,
703    output_schema: Option<Value>,
704    icons: Option<Vec<ToolIcon>>,
705    annotations: Option<ToolAnnotations>,
706    handler: F,
707    _phantom: std::marker::PhantomData<I>,
708}
709
710impl<I, F, Fut> ToolBuilderWithHandler<I, F>
711where
712    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
713    F: Fn(I) -> Fut + Send + Sync + 'static,
714    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
715{
716    /// Build the tool
717    ///
718    /// Returns an error if the tool name is invalid.
719    pub fn build(self) -> Result<Tool> {
720        validate_tool_name(&self.name)?;
721        Ok(Tool {
722            name: self.name,
723            title: self.title,
724            description: self.description,
725            output_schema: self.output_schema,
726            icons: self.icons,
727            annotations: self.annotations,
728            handler: Arc::new(TypedHandler {
729                handler: self.handler,
730                _phantom: std::marker::PhantomData,
731            }),
732        })
733    }
734}
735
736/// Builder state after context-aware handler is specified
737pub struct ToolBuilderWithContextHandler<I, F> {
738    name: String,
739    title: Option<String>,
740    description: Option<String>,
741    output_schema: Option<Value>,
742    icons: Option<Vec<ToolIcon>>,
743    annotations: Option<ToolAnnotations>,
744    handler: F,
745    _phantom: std::marker::PhantomData<I>,
746}
747
748impl<I, F, Fut> ToolBuilderWithContextHandler<I, F>
749where
750    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
751    F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
752    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
753{
754    /// Build the tool
755    ///
756    /// Returns an error if the tool name is invalid.
757    pub fn build(self) -> Result<Tool> {
758        validate_tool_name(&self.name)?;
759        Ok(Tool {
760            name: self.name,
761            title: self.title,
762            description: self.description,
763            output_schema: self.output_schema,
764            icons: self.icons,
765            annotations: self.annotations,
766            handler: Arc::new(ContextAwareHandler {
767                handler: self.handler,
768                _phantom: std::marker::PhantomData,
769            }),
770        })
771    }
772}
773
// =============================================================================
// Handler implementations
// =============================================================================
777
778/// Handler that deserializes input to a specific type
779struct TypedHandler<I, F> {
780    handler: F,
781    _phantom: std::marker::PhantomData<I>,
782}
783
784impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
785where
786    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
787    F: Fn(I) -> Fut + Send + Sync + 'static,
788    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
789{
790    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
791        Box::pin(async move {
792            let input: I = serde_json::from_value(args)
793                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
794            (self.handler)(input).await
795        })
796    }
797
798    fn input_schema(&self) -> Value {
799        let schema = schemars::schema_for!(I);
800        serde_json::to_value(schema).unwrap_or_else(|_| {
801            serde_json::json!({
802                "type": "object"
803            })
804        })
805    }
806}
807
808/// Handler that works with raw JSON
809struct RawHandler<F> {
810    handler: F,
811}
812
813impl<F, Fut> ToolHandler for RawHandler<F>
814where
815    F: Fn(Value) -> Fut + Send + Sync + 'static,
816    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
817{
818    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
819        Box::pin((self.handler)(args))
820    }
821
822    fn input_schema(&self) -> Value {
823        // Raw handlers accept any JSON
824        serde_json::json!({
825            "type": "object",
826            "additionalProperties": true
827        })
828    }
829}
830
831/// Handler that works with raw JSON and request context
832struct RawContextHandler<F> {
833    handler: F,
834}
835
836impl<F, Fut> ToolHandler for RawContextHandler<F>
837where
838    F: Fn(RequestContext, Value) -> Fut + Send + Sync + 'static,
839    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
840{
841    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
842        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
843        self.call_with_context(ctx, args)
844    }
845
846    fn call_with_context(
847        &self,
848        ctx: RequestContext,
849        args: Value,
850    ) -> BoxFuture<'_, Result<CallToolResult>> {
851        Box::pin((self.handler)(ctx, args))
852    }
853
854    fn uses_context(&self) -> bool {
855        true
856    }
857
858    fn input_schema(&self) -> Value {
859        // Raw context handlers accept any JSON object
860        serde_json::json!({
861            "type": "object",
862            "additionalProperties": true
863        })
864    }
865}
866
867/// Handler that receives request context for progress/cancellation
868struct ContextAwareHandler<I, F> {
869    handler: F,
870    _phantom: std::marker::PhantomData<I>,
871}
872
873impl<I, F, Fut> ToolHandler for ContextAwareHandler<I, F>
874where
875    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
876    F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
877    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
878{
879    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
880        // When called without context, create a dummy context
881        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
882        self.call_with_context(ctx, args)
883    }
884
885    fn call_with_context(
886        &self,
887        ctx: RequestContext,
888        args: Value,
889    ) -> BoxFuture<'_, Result<CallToolResult>> {
890        Box::pin(async move {
891            let input: I = serde_json::from_value(args)
892                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
893            (self.handler)(ctx, input).await
894        })
895    }
896
897    fn uses_context(&self) -> bool {
898        true
899    }
900
901    fn input_schema(&self) -> Value {
902        let schema = schemars::schema_for!(I);
903        serde_json::to_value(schema).unwrap_or_else(|_| {
904            serde_json::json!({
905                "type": "object"
906            })
907        })
908    }
909}
910
911/// Handler that takes no parameters
912struct NoParamsHandler<F> {
913    handler: F,
914}
915
916impl<F, Fut> ToolHandler for NoParamsHandler<F>
917where
918    F: Fn() -> Fut + Send + Sync + 'static,
919    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
920{
921    fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
922        Box::pin((self.handler)())
923    }
924
925    fn input_schema(&self) -> Value {
926        serde_json::json!({
927            "type": "object",
928            "properties": {}
929        })
930    }
931}
932
933/// Handler that takes no parameters but has shared state
934struct NoParamsWithStateHandler<S, F> {
935    state: S,
936    handler: F,
937}
938
939impl<S, F, Fut> ToolHandler for NoParamsWithStateHandler<S, F>
940where
941    S: Clone + Send + Sync + 'static,
942    F: Fn(S) -> Fut + Send + Sync + 'static,
943    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
944{
945    fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
946        let state = self.state.clone();
947        let fut = (self.handler)(state);
948        Box::pin(fut)
949    }
950
951    fn input_schema(&self) -> Value {
952        serde_json::json!({
953            "type": "object",
954            "properties": {}
955        })
956    }
957}
958
959// =============================================================================
960// Trait-based tool definition
961// =============================================================================
962
963/// Trait for defining tools with full control
964///
965/// Implement this trait when you need more control than the builder provides,
966/// or when you want to define tools as standalone types.
967///
968/// # Example
969///
970/// ```rust
971/// use tower_mcp::tool::McpTool;
972/// use tower_mcp::error::Result;
973/// use schemars::JsonSchema;
974/// use serde::{Deserialize, Serialize};
975///
976/// #[derive(Debug, Deserialize, JsonSchema)]
977/// struct AddInput {
978///     a: i64,
979///     b: i64,
980/// }
981///
982/// struct AddTool;
983///
984/// impl McpTool for AddTool {
985///     const NAME: &'static str = "add";
986///     const DESCRIPTION: &'static str = "Add two numbers";
987///
988///     type Input = AddInput;
989///     type Output = i64;
990///
991///     async fn call(&self, input: Self::Input) -> Result<Self::Output> {
992///         Ok(input.a + input.b)
993///     }
994/// }
995///
996/// let tool = AddTool.into_tool().expect("valid tool name");
997/// assert_eq!(tool.name, "add");
998/// ```
999pub trait McpTool: Send + Sync + 'static {
1000    const NAME: &'static str;
1001    const DESCRIPTION: &'static str;
1002
1003    type Input: JsonSchema + DeserializeOwned + Send;
1004    type Output: Serialize + Send;
1005
1006    fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
1007
1008    /// Optional annotations for the tool
1009    fn annotations(&self) -> Option<ToolAnnotations> {
1010        None
1011    }
1012
1013    /// Convert to a Tool instance
1014    ///
1015    /// Returns an error if the tool name is invalid.
1016    fn into_tool(self) -> Result<Tool>
1017    where
1018        Self: Sized,
1019    {
1020        validate_tool_name(Self::NAME)?;
1021        let annotations = self.annotations();
1022        let tool = Arc::new(self);
1023        Ok(Tool {
1024            name: Self::NAME.to_string(),
1025            title: None,
1026            description: Some(Self::DESCRIPTION.to_string()),
1027            output_schema: None,
1028            icons: None,
1029            annotations,
1030            handler: Arc::new(McpToolHandler { tool }),
1031        })
1032    }
1033}
1034
1035/// Wrapper to make McpTool implement ToolHandler
1036struct McpToolHandler<T: McpTool> {
1037    tool: Arc<T>,
1038}
1039
1040impl<T: McpTool> ToolHandler for McpToolHandler<T> {
1041    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1042        let tool = self.tool.clone();
1043        Box::pin(async move {
1044            let input: T::Input = serde_json::from_value(args)
1045                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
1046            let output = tool.call(input).await?;
1047            let value = serde_json::to_value(output)
1048                .map_err(|e| Error::tool(format!("Failed to serialize output: {}", e)))?;
1049            Ok(CallToolResult::json(value))
1050        })
1051    }
1052
1053    fn input_schema(&self) -> Value {
1054        let schema = schemars::schema_for!(T::Input);
1055        serde_json::to_value(schema).unwrap_or_else(|_| {
1056            serde_json::json!({
1057                "type": "object"
1058            })
1059        })
1060    }
1061}
1062
#[cfg(test)]
mod tests {
    // Unit tests for the builder, the handler adapters, `NoParams`, and the
    // trait-based `McpTool` API.
    use super::*;
    use schemars::JsonSchema;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, JsonSchema)]
    struct GreetInput {
        name: String,
    }

    /// Input for the trait-based `AddTool` fixture.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        a: i64,
        b: i64,
    }

    /// Minimal `McpTool` implementation used to exercise `into_tool`.
    struct AddTool;

    impl McpTool for AddTool {
        const NAME: &'static str = "add";
        const DESCRIPTION: &'static str = "Add two numbers";

        type Input = AddInput;
        type Output = i64;

        async fn call(&self, input: Self::Input) -> crate::error::Result<Self::Output> {
            Ok(input.a + input.b)
        }
    }

    /// `McpTool` whose name contains characters that name validation rejects.
    struct BadNameTool;

    impl McpTool for BadNameTool {
        const NAME: &'static str = "bad name!";
        const DESCRIPTION: &'static str = "Tool with an invalid name";

        type Input = NoParams;
        type Output = ();

        async fn call(&self, _input: Self::Input) -> crate::error::Result<Self::Output> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_builder_tool() {
        let tool = ToolBuilder::new("greet")
            .description("Greet someone")
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "greet");
        assert_eq!(tool.description.as_deref(), Some("Greet someone"));

        let result = tool
            .call(serde_json::json!({"name": "World"}))
            .await
            .unwrap();

        assert!(!result.is_error);
        // The typed input must actually reach the handler.
        assert_eq!(result.first_text().unwrap(), "Hello, World!");
    }

    #[tokio::test]
    async fn test_raw_handler() {
        let tool = ToolBuilder::new("echo")
            .description("Echo input")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) })
            .expect("valid tool name");

        let result = tool.call(serde_json::json!({"foo": "bar"})).await.unwrap();

        assert!(!result.is_error);
    }

    #[test]
    fn test_invalid_tool_name_empty() {
        let result = ToolBuilder::new("")
            .description("Empty name")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
    }

    #[test]
    fn test_invalid_tool_name_too_long() {
        let long_name = "a".repeat(129);
        let result = ToolBuilder::new(long_name)
            .description("Too long")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
    }

    #[test]
    fn test_invalid_tool_name_bad_chars() {
        let result = ToolBuilder::new("my tool!")
            .description("Bad chars")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("invalid character")
        );
    }

    #[test]
    fn test_valid_tool_names() {
        // All valid characters
        let names = [
            "my_tool",
            "my-tool",
            "my.tool",
            "MyTool123",
            "a",
            &"a".repeat(128),
        ];
        for name in names {
            let result = ToolBuilder::new(name)
                .description("Valid")
                .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });
            assert!(result.is_ok(), "Expected '{}' to be valid", name);
        }
    }

    #[tokio::test]
    async fn test_context_aware_handler() {
        use crate::context::{RequestContext, notification_channel};
        use crate::protocol::{ProgressToken, RequestId};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct ProcessInput {
            count: i32,
        }

        let tool = ToolBuilder::new("process")
            .description("Process with context")
            .handler_with_context(|ctx: RequestContext, input: ProcessInput| async move {
                // Simulate progress reporting
                for i in 0..input.count {
                    if ctx.is_cancelled() {
                        return Ok(CallToolResult::error("Cancelled"));
                    }
                    ctx.report_progress(i as f64, Some(input.count as f64), None)
                        .await;
                }
                Ok(CallToolResult::text(format!(
                    "Processed {} items",
                    input.count
                )))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "process");
        assert!(tool.uses_context());

        // Test with a context that has progress token and notification sender
        let (tx, mut rx) = notification_channel(10);
        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        let result = tool
            .call_with_context(ctx, serde_json::json!({"count": 3}))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.first_text().unwrap(), "Processed 3 items");

        // Check that progress notifications were sent
        let mut progress_count = 0;
        while rx.try_recv().is_ok() {
            progress_count += 1;
        }
        assert_eq!(progress_count, 3);
    }

    #[tokio::test]
    async fn test_context_aware_handler_cancellation() {
        use crate::context::RequestContext;
        use crate::protocol::RequestId;
        use std::sync::atomic::{AtomicI32, Ordering};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct LongRunningInput {
            iterations: i32,
        }

        let iterations_completed = Arc::new(AtomicI32::new(0));
        let iterations_ref = iterations_completed.clone();

        let tool = ToolBuilder::new("long_running")
            .description("Long running task")
            .handler_with_context(move |ctx: RequestContext, input: LongRunningInput| {
                let completed = iterations_ref.clone();
                async move {
                    for i in 0..input.iterations {
                        if ctx.is_cancelled() {
                            return Ok(CallToolResult::error("Cancelled"));
                        }
                        completed.fetch_add(1, Ordering::SeqCst);
                        // Simulate work
                        tokio::task::yield_now().await;
                        // Cancel after iteration 2
                        if i == 2 {
                            ctx.cancellation_token().cancel();
                        }
                    }
                    Ok(CallToolResult::text("Done"))
                }
            })
            .build()
            .expect("valid tool name");

        let ctx = RequestContext::new(RequestId::Number(1));

        let result = tool
            .call_with_context(ctx, serde_json::json!({"iterations": 10}))
            .await
            .unwrap();

        // Should have been cancelled after 3 iterations (0, 1, 2)
        // The next iteration (3) checks cancellation and returns
        assert!(result.is_error);
        assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_tool_builder_with_enhanced_fields() {
        let output_schema = serde_json::json!({
            "type": "object",
            "properties": {
                "greeting": {"type": "string"}
            }
        });

        let tool = ToolBuilder::new("greet")
            .title("Greeting Tool")
            .description("Greet someone")
            .output_schema(output_schema.clone())
            .icon("https://example.com/icon.png")
            .icon_with_meta(
                "https://example.com/icon-large.png",
                Some("image/png".to_string()),
                Some(vec!["96x96".to_string()]),
            )
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "greet");
        assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
        assert_eq!(tool.output_schema, Some(output_schema));
        assert!(tool.icons.is_some());
        assert_eq!(tool.icons.as_ref().unwrap().len(), 2);

        // Test definition includes new fields
        let def = tool.definition();
        assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
        assert!(def.output_schema.is_some());
        assert!(def.icons.is_some());
    }

    #[tokio::test]
    async fn test_handler_with_state() {
        let shared = Arc::new("shared-state".to_string());

        let tool = ToolBuilder::new("stateful")
            .description("Uses shared state")
            .handler_with_state(shared, |state: Arc<String>, input: GreetInput| async move {
                Ok(CallToolResult::text(format!(
                    "{}: Hello, {}!",
                    state, input.name
                )))
            })
            .build()
            .expect("valid tool name");

        let result = tool
            .call(serde_json::json!({"name": "World"}))
            .await
            .unwrap();
        assert!(!result.is_error);
        // Both the shared state and the typed input must reach the handler.
        assert_eq!(result.first_text().unwrap(), "shared-state: Hello, World!");
    }

    #[tokio::test]
    async fn test_handler_with_state_and_context() {
        use crate::context::RequestContext;
        use crate::protocol::RequestId;

        let shared = Arc::new(42_i32);

        let tool = ToolBuilder::new("stateful_ctx")
            .description("Uses state and context")
            .handler_with_state_and_context(
                shared,
                |state: Arc<i32>, _ctx: RequestContext, input: GreetInput| async move {
                    Ok(CallToolResult::text(format!(
                        "{}: Hello, {}!",
                        state, input.name
                    )))
                },
            )
            .build()
            .expect("valid tool name");

        assert!(tool.uses_context());

        let ctx = RequestContext::new(RequestId::Number(1));
        let result = tool
            .call_with_context(ctx, serde_json::json!({"name": "World"}))
            .await
            .unwrap();
        assert!(!result.is_error);
        assert_eq!(result.first_text().unwrap(), "42: Hello, World!");
    }

    #[tokio::test]
    async fn test_handler_no_params() {
        let tool = ToolBuilder::new("no_params")
            .description("Takes no parameters")
            .handler_no_params(|| async { Ok(CallToolResult::text("no params result")) })
            .expect("valid tool name");

        assert_eq!(tool.name, "no_params");

        // Should work with empty args
        let result = tool.call(serde_json::json!({})).await.unwrap();
        assert!(!result.is_error);
        assert_eq!(result.first_text().unwrap(), "no params result");

        // Should also work with unexpected args (ignored)
        let result = tool
            .call(serde_json::json!({"unexpected": "value"}))
            .await
            .unwrap();
        assert!(!result.is_error);

        // Check input schema is an empty-properties object
        let schema = tool.definition().input_schema;
        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
        assert!(
            schema
                .get("properties")
                .unwrap()
                .as_object()
                .unwrap()
                .is_empty()
        );
    }

    #[tokio::test]
    async fn test_handler_no_params_with_state() {
        let shared = Arc::new("shared_value".to_string());

        let tool = ToolBuilder::new("no_params_with_state")
            .description("Takes no parameters but has state")
            .handler_no_params_with_state(shared, |state: Arc<String>| async move {
                Ok(CallToolResult::text(format!("state: {}", state)))
            })
            .expect("valid tool name");

        assert_eq!(tool.name, "no_params_with_state");

        // Should work with empty args
        let result = tool.call(serde_json::json!({})).await.unwrap();
        assert!(!result.is_error);
        assert_eq!(result.first_text().unwrap(), "state: shared_value");

        // Check input schema is an empty-properties object
        let schema = tool.definition().input_schema;
        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
        assert!(
            schema
                .get("properties")
                .unwrap()
                .as_object()
                .unwrap()
                .is_empty()
        );
    }

    #[test]
    fn test_no_params_schema() {
        // NoParams should produce a schema with type: "object"
        let schema = schemars::schema_for!(NoParams);
        let schema_value = serde_json::to_value(&schema).unwrap();
        assert_eq!(
            schema_value.get("type").and_then(|v| v.as_str()),
            Some("object"),
            "NoParams should generate type: object schema"
        );
    }

    #[test]
    fn test_no_params_deserialize() {
        // NoParams should deserialize from various inputs
        let from_empty_object: NoParams = serde_json::from_str("{}").unwrap();
        assert_eq!(from_empty_object, NoParams);

        let from_null: NoParams = serde_json::from_str("null").unwrap();
        assert_eq!(from_null, NoParams);

        // Should also accept objects with unexpected fields (ignored)
        let from_object_with_fields: NoParams =
            serde_json::from_str(r#"{"unexpected": "value"}"#).unwrap();
        assert_eq!(from_object_with_fields, NoParams);
    }

    #[tokio::test]
    async fn test_no_params_type_in_handler() {
        // NoParams can be used as a handler input type
        let tool = ToolBuilder::new("status")
            .description("Get status")
            .handler(|_input: NoParams| async move { Ok(CallToolResult::text("OK")) })
            .build()
            .expect("valid tool name");

        // Check schema has type: object (not type: null like () would produce)
        let schema = tool.definition().input_schema;
        assert_eq!(
            schema.get("type").and_then(|v| v.as_str()),
            Some("object"),
            "NoParams handler should produce type: object schema"
        );

        // Should work with empty input
        let result = tool.call(serde_json::json!({})).await.unwrap();
        assert!(!result.is_error);
        assert_eq!(result.first_text().unwrap(), "OK");
    }

    #[tokio::test]
    async fn test_mcp_tool_into_tool() {
        let tool = AddTool.into_tool().expect("valid tool name");

        assert_eq!(tool.name, "add");
        assert_eq!(tool.description.as_deref(), Some("Add two numbers"));

        // The input schema is derived from `AddInput`.
        let schema = tool.definition().input_schema;
        let properties = schema
            .get("properties")
            .and_then(|p| p.as_object())
            .expect("schema should list properties");
        assert!(properties.contains_key("a"));
        assert!(properties.contains_key("b"));

        let result = tool
            .call(serde_json::json!({"a": 2, "b": 3}))
            .await
            .unwrap();
        assert!(!result.is_error);
    }

    #[test]
    fn test_mcp_tool_invalid_name() {
        let err = BadNameTool
            .into_tool()
            .err()
            .expect("a name containing a space and '!' must be rejected");
        assert!(err.to_string().contains("invalid character"));
    }
}