Skip to main content

tower_mcp/
tool.rs

1//! Tool definition and builder API
2//!
3//! Provides ergonomic ways to define MCP tools:
4//!
5//! 1. **Builder pattern** - Fluent API for defining tools
6//! 2. **Trait-based** - Implement `McpTool` for full control
7//! 3. **Function-based** - Quick tools from async functions
8//!
9//! ## Per-Tool Middleware
10//!
11//! Tools are implemented as Tower services internally, enabling middleware
12//! composition via the `.layer()` method:
13//!
14//! ```rust
15//! use std::time::Duration;
16//! use tower::timeout::TimeoutLayer;
17//! use tower_mcp::{ToolBuilder, CallToolResult};
18//! use schemars::JsonSchema;
19//! use serde::Deserialize;
20//!
21//! #[derive(Debug, Deserialize, JsonSchema)]
22//! struct SearchInput { query: String }
23//!
24//! let tool = ToolBuilder::new("slow_search")
25//!     .description("Search with extended timeout")
26//!     .handler(|input: SearchInput| async move {
27//!         Ok(CallToolResult::text("result"))
28//!     })
29//!     .layer(TimeoutLayer::new(Duration::from_secs(30)))
30//!     .build()
31//!     .unwrap();
32//! ```
33
34use std::borrow::Cow;
35use std::convert::Infallible;
36use std::fmt;
37use std::future::Future;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::task::{Context, Poll};
41
42use schemars::{JsonSchema, Schema, SchemaGenerator};
43use serde::Serialize;
44use serde::de::DeserializeOwned;
45use serde_json::Value;
46use tower::util::BoxCloneService;
47use tower_service::Service;
48
49use crate::context::RequestContext;
50use crate::error::{Error, Result};
51use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
52
53// =============================================================================
54// Service Types for Per-Tool Middleware
55// =============================================================================
56
57/// Request type for tool services.
58///
59/// Contains the request context (for progress reporting, cancellation, etc.)
60/// and the tool arguments as raw JSON.
61#[derive(Debug, Clone)]
62pub struct ToolRequest {
63    /// Request context for progress reporting, cancellation, and client requests
64    pub ctx: RequestContext,
65    /// Tool arguments as raw JSON
66    pub args: Value,
67}
68
69impl ToolRequest {
70    /// Create a new tool request
71    pub fn new(ctx: RequestContext, args: Value) -> Self {
72        Self { ctx, args }
73    }
74}
75
/// A boxed, cloneable tool service with `Error = Infallible`.
///
/// This is the type-erased service stored in every [`Tool`]. Before boxing,
/// the handler (and any layered middleware) is wrapped in [`ToolCatchError`],
/// which converts errors into `CallToolResult::error()` responses. As a result
/// the service never fails at the Tower level.
pub type BoxToolService = BoxCloneService<ToolRequest, CallToolResult, Infallible>;
82
83/// Catches errors from the inner service and converts them to `CallToolResult::error()`.
84///
85/// This wrapper ensures that middleware errors (e.g., timeouts, rate limits)
86/// and handler errors are converted to tool-level error responses with
87/// `is_error: true`, rather than propagating as Tower service errors.
88pub struct ToolCatchError<S> {
89    inner: S,
90}
91
92impl<S> ToolCatchError<S> {
93    /// Create a new `ToolCatchError` wrapping the given service.
94    pub fn new(inner: S) -> Self {
95        Self { inner }
96    }
97}
98
99impl<S: Clone> Clone for ToolCatchError<S> {
100    fn clone(&self) -> Self {
101        Self {
102            inner: self.inner.clone(),
103        }
104    }
105}
106
107impl<S: fmt::Debug> fmt::Debug for ToolCatchError<S> {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        f.debug_struct("ToolCatchError")
110            .field("inner", &self.inner)
111            .finish()
112    }
113}
114
115impl<S> Service<ToolRequest> for ToolCatchError<S>
116where
117    S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
118    S::Error: fmt::Display + Send,
119    S::Future: Send,
120{
121    type Response = CallToolResult;
122    type Error = Infallible;
123    type Future =
124        Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Infallible>> + Send>>;
125
126    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
127        // Map any readiness error to Infallible (we catch it on call)
128        match self.inner.poll_ready(cx) {
129            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
130            Poll::Ready(Err(_)) => Poll::Ready(Ok(())),
131            Poll::Pending => Poll::Pending,
132        }
133    }
134
135    fn call(&mut self, req: ToolRequest) -> Self::Future {
136        let fut = self.inner.call(req);
137
138        Box::pin(async move {
139            match fut.await {
140                Ok(result) => Ok(result),
141                Err(err) => Ok(CallToolResult::error(err.to_string())),
142            }
143        })
144    }
145}
146
147/// A marker type for tools that take no parameters.
148///
149/// Use this instead of `()` when defining tools with no input parameters.
150/// The unit type `()` generates `"type": "null"` in JSON Schema, which many
151/// MCP clients reject. `NoParams` generates `"type": "object"` with no
152/// required properties, which is the correct schema for parameterless tools.
153///
154/// # Example
155///
156/// ```rust
157/// use tower_mcp::{ToolBuilder, CallToolResult, NoParams};
158///
159/// let tool = ToolBuilder::new("get_status")
160///     .description("Get current status")
161///     .handler(|_input: NoParams| async move {
162///         Ok(CallToolResult::text("OK"))
163///     })
164///     .build()
165///     .unwrap();
166/// ```
167#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
168pub struct NoParams;
169
170impl<'de> serde::Deserialize<'de> for NoParams {
171    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
172    where
173        D: serde::Deserializer<'de>,
174    {
175        // Accept null, empty object, or any object (ignoring all fields)
176        struct NoParamsVisitor;
177
178        impl<'de> serde::de::Visitor<'de> for NoParamsVisitor {
179            type Value = NoParams;
180
181            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
182                formatter.write_str("null or an object")
183            }
184
185            fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
186            where
187                E: serde::de::Error,
188            {
189                Ok(NoParams)
190            }
191
192            fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
193            where
194                E: serde::de::Error,
195            {
196                Ok(NoParams)
197            }
198
199            fn visit_some<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
200            where
201                D: serde::Deserializer<'de>,
202            {
203                serde::Deserialize::deserialize(deserializer)
204            }
205
206            fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
207            where
208                A: serde::de::MapAccess<'de>,
209            {
210                // Drain the map, ignoring all entries
211                while map
212                    .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
213                    .is_some()
214                {}
215                Ok(NoParams)
216            }
217        }
218
219        deserializer.deserialize_any(NoParamsVisitor)
220    }
221}
222
223impl JsonSchema for NoParams {
224    fn schema_name() -> Cow<'static, str> {
225        Cow::Borrowed("NoParams")
226    }
227
228    fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
229        serde_json::json!({
230            "type": "object"
231        })
232        .try_into()
233        .expect("valid schema")
234    }
235}
236
237/// Validate a tool name according to MCP spec.
238///
239/// Tool names must be:
240/// - 1-128 characters long
241/// - Contain only alphanumeric characters, underscores, hyphens, and dots
242///
243/// Returns `Ok(())` if valid, `Err` with description if invalid.
244pub fn validate_tool_name(name: &str) -> Result<()> {
245    if name.is_empty() {
246        return Err(Error::tool("Tool name cannot be empty"));
247    }
248    if name.len() > 128 {
249        return Err(Error::tool(format!(
250            "Tool name '{}' exceeds maximum length of 128 characters (got {})",
251            name,
252            name.len()
253        )));
254    }
255    if let Some(invalid_char) = name
256        .chars()
257        .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
258    {
259        return Err(Error::tool(format!(
260            "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
261            name, invalid_char
262        )));
263    }
264    Ok(())
265}
266
/// A boxed future for tool handlers.
///
/// A pinned, heap-allocated, `Send` future that may borrow for `'a` and
/// resolves to `T`. It is the return type of the [`ToolHandler`] methods and
/// of [`Tool::call`] / [`Tool::call_with_context`].
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
269
/// Tool handler trait - the core abstraction for tool execution.
///
/// Implementors receive tool arguments as raw JSON and produce a
/// `CallToolResult`. When a handler is turned into a [`Tool`], it is wrapped
/// in an internal Tower service adapter. That adapter always invokes
/// [`call_with_context`](Self::call_with_context), never `call` directly.
pub trait ToolHandler: Send + Sync {
    /// Execute the tool with the given raw JSON arguments.
    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;

    /// Execute the tool with request context for progress/cancellation support.
    ///
    /// The default implementation ignores the context and calls `call`.
    /// Override this to receive progress/cancellation context.
    fn call_with_context(
        &self,
        _ctx: RequestContext,
        args: Value,
    ) -> BoxFuture<'_, Result<CallToolResult>> {
        self.call(args)
    }

    /// Returns true if this handler uses context (for optimization).
    ///
    /// Defaults to `false`. NOTE(review): nothing in this file reads this flag;
    /// confirm how callers elsewhere in the crate use it.
    fn uses_context(&self) -> bool {
        false
    }

    /// Get the tool's input schema as a JSON value.
    ///
    /// `Tool::from_handler` calls this once and stores the result, which
    /// `Tool::definition` then reports.
    fn input_schema(&self) -> Value;
}
295
296/// Adapts a `ToolHandler` to a Tower `Service<ToolRequest>`.
297///
298/// This is an internal adapter that bridges the handler abstraction to the
299/// service abstraction, enabling middleware composition.
300pub(crate) struct ToolHandlerService<H> {
301    handler: Arc<H>,
302}
303
304impl<H> ToolHandlerService<H> {
305    pub(crate) fn new(handler: H) -> Self {
306        Self {
307            handler: Arc::new(handler),
308        }
309    }
310}
311
312impl<H> Clone for ToolHandlerService<H> {
313    fn clone(&self) -> Self {
314        Self {
315            handler: self.handler.clone(),
316        }
317    }
318}
319
320impl<H> Service<ToolRequest> for ToolHandlerService<H>
321where
322    H: ToolHandler + 'static,
323{
324    type Response = CallToolResult;
325    type Error = Error;
326    type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
327
328    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
329        Poll::Ready(Ok(()))
330    }
331
332    fn call(&mut self, req: ToolRequest) -> Self::Future {
333        let handler = self.handler.clone();
334        Box::pin(async move { handler.call_with_context(req.ctx, req.args).await })
335    }
336}
337
/// A complete tool definition with service-based execution.
///
/// Tools are implemented as Tower services internally, enabling middleware
/// composition via the builder's `.layer()` method. The service is wrapped
/// in [`ToolCatchError`] to convert any errors (from handlers or middleware)
/// into `CallToolResult::error()` responses.
///
/// The metadata fields, together with `input_schema`, are what
/// [`Tool::definition`] reports in `tools/list`.
pub struct Tool {
    /// Tool name (must be 1-128 chars, alphanumeric/underscore/hyphen/dot only)
    pub name: String,
    /// Human-readable title for the tool
    pub title: Option<String>,
    /// Description of what the tool does
    pub description: Option<String>,
    /// JSON Schema for the tool's output (optional)
    pub output_schema: Option<Value>,
    /// Icons for the tool
    pub icons: Option<Vec<ToolIcon>>,
    /// Tool annotations (hints about behavior)
    pub annotations: Option<ToolAnnotations>,
    /// The boxed service that executes the tool
    pub(crate) service: BoxToolService,
    /// JSON Schema for the tool's input
    pub(crate) input_schema: Value,
}

// Manual `Debug`: prints only the public metadata fields. `service` and
// `input_schema` are omitted, and `finish_non_exhaustive` marks the omission
// with `..`.
impl std::fmt::Debug for Tool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Tool")
            .field("name", &self.name)
            .field("title", &self.title)
            .field("description", &self.description)
            .field("output_schema", &self.output_schema)
            .field("icons", &self.icons)
            .field("annotations", &self.annotations)
            .finish_non_exhaustive()
    }
}

// SAFETY (as originally argued): BoxCloneService is Send + Sync (tower provides
// unsafe impl Sync), and all other fields in Tool are Send + Sync.
//
// NOTE(review): verify the `Sync` half of that claim against the tower version
// in use. tower 0.5.2 added a separate `BoxCloneSyncService` precisely for
// boxed cloneable services that are `Sync`, which suggests `BoxCloneService`
// itself is only `Send`. If so, this `Sync` impl is not compiler-checked.
// `Tool::call_with_context` clones `service` through `&self`, so concurrent
// calls from several threads run the boxed service's `Clone` impl at the same
// time. Consider switching `BoxToolService` to `BoxCloneSyncService` (which
// requires `Sync` inner services) so this becomes a checked auto-impl.
unsafe impl Send for Tool {}
unsafe impl Sync for Tool {}
380
381impl Tool {
382    /// Create a new tool builder
383    pub fn builder(name: impl Into<String>) -> ToolBuilder {
384        ToolBuilder::new(name)
385    }
386
387    /// Get the tool definition for tools/list
388    pub fn definition(&self) -> ToolDefinition {
389        ToolDefinition {
390            name: self.name.clone(),
391            title: self.title.clone(),
392            description: self.description.clone(),
393            input_schema: self.input_schema.clone(),
394            output_schema: self.output_schema.clone(),
395            icons: self.icons.clone(),
396            annotations: self.annotations.clone(),
397        }
398    }
399
400    /// Call the tool without context
401    ///
402    /// Creates a dummy request context. For full context support, use
403    /// [`call_with_context`](Self::call_with_context).
404    pub fn call(&self, args: Value) -> BoxFuture<'static, CallToolResult> {
405        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
406        self.call_with_context(ctx, args)
407    }
408
409    /// Call the tool with request context
410    ///
411    /// The context provides progress reporting, cancellation support, and
412    /// access to client requests (for sampling, etc.).
413    ///
414    /// # Note
415    ///
416    /// This method returns `CallToolResult` directly (not `Result<CallToolResult>`).
417    /// Any errors from the handler or middleware are converted to
418    /// `CallToolResult::error()` with `is_error: true`.
419    pub fn call_with_context(
420        &self,
421        ctx: RequestContext,
422        args: Value,
423    ) -> BoxFuture<'static, CallToolResult> {
424        use tower::ServiceExt;
425        let service = self.service.clone();
426        Box::pin(async move {
427            // ServiceExt::oneshot properly handles poll_ready before call
428            // Service is Infallible, so unwrap is safe
429            service.oneshot(ToolRequest::new(ctx, args)).await.unwrap()
430        })
431    }
432
433    /// Create a new tool with a prefixed name.
434    ///
435    /// This creates a copy of the tool with its name prefixed by the given
436    /// string and a dot separator. For example, if the tool is named "query"
437    /// and the prefix is "db", the new tool will be named "db.query".
438    ///
439    /// This is used internally by `McpRouter::nest()` to namespace tools.
440    ///
441    /// # Example
442    ///
443    /// ```rust
444    /// use tower_mcp::{ToolBuilder, CallToolResult};
445    /// use schemars::JsonSchema;
446    /// use serde::Deserialize;
447    ///
448    /// #[derive(Debug, Deserialize, JsonSchema)]
449    /// struct Input { value: String }
450    ///
451    /// let tool = ToolBuilder::new("query")
452    ///     .description("Query the database")
453    ///     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
454    ///     .build()
455    ///     .unwrap();
456    ///
457    /// let prefixed = tool.with_name_prefix("db");
458    /// assert_eq!(prefixed.name, "db.query");
459    /// ```
460    pub fn with_name_prefix(&self, prefix: &str) -> Self {
461        Self {
462            name: format!("{}.{}", prefix, self.name),
463            title: self.title.clone(),
464            description: self.description.clone(),
465            output_schema: self.output_schema.clone(),
466            icons: self.icons.clone(),
467            annotations: self.annotations.clone(),
468            service: self.service.clone(),
469            input_schema: self.input_schema.clone(),
470        }
471    }
472
473    /// Create a tool from a handler (internal helper)
474    fn from_handler<H: ToolHandler + 'static>(
475        name: String,
476        title: Option<String>,
477        description: Option<String>,
478        output_schema: Option<Value>,
479        icons: Option<Vec<ToolIcon>>,
480        annotations: Option<ToolAnnotations>,
481        handler: H,
482    ) -> Self {
483        let input_schema = handler.input_schema();
484        let handler_service = ToolHandlerService::new(handler);
485        let catch_error = ToolCatchError::new(handler_service);
486        let service = BoxCloneService::new(catch_error);
487
488        Self {
489            name,
490            title,
491            description,
492            output_schema,
493            icons,
494            annotations,
495            service,
496            input_schema,
497        }
498    }
499}
500
501// =============================================================================
502// Builder API
503// =============================================================================
504
/// Builder for creating tools with a fluent API
///
/// Start with [`ToolBuilder::new`], set optional metadata, then supply a
/// handler to move to the next builder stage.
///
/// # Example
///
/// ```rust
/// use tower_mcp::{ToolBuilder, CallToolResult};
/// use schemars::JsonSchema;
/// use serde::Deserialize;
///
/// #[derive(Debug, Deserialize, JsonSchema)]
/// struct GreetInput {
///     name: String,
/// }
///
/// let tool = ToolBuilder::new("greet")
///     .description("Greet someone by name")
///     .handler(|input: GreetInput| async move {
///         Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
///     })
///     .build()
///     .expect("valid tool name");
///
/// assert_eq!(tool.name, "greet");
/// ```
pub struct ToolBuilder {
    /// Tool name exactly as passed to `new` (not validated at this stage)
    name: String,
    /// Optional human-readable title
    title: Option<String>,
    /// Optional description
    description: Option<String>,
    /// Optional JSON Schema for structured output
    output_schema: Option<Value>,
    /// Icons; `None` until the first icon is added
    icons: Option<Vec<ToolIcon>>,
    /// Behavior hints; `None` until an annotation setter is used
    annotations: Option<ToolAnnotations>,
}
537
538impl ToolBuilder {
539    pub fn new(name: impl Into<String>) -> Self {
540        Self {
541            name: name.into(),
542            title: None,
543            description: None,
544            output_schema: None,
545            icons: None,
546            annotations: None,
547        }
548    }
549
550    /// Set a human-readable title for the tool
551    pub fn title(mut self, title: impl Into<String>) -> Self {
552        self.title = Some(title.into());
553        self
554    }
555
556    /// Set the output schema (JSON Schema for structured output)
557    pub fn output_schema(mut self, schema: Value) -> Self {
558        self.output_schema = Some(schema);
559        self
560    }
561
562    /// Add an icon for the tool
563    pub fn icon(mut self, src: impl Into<String>) -> Self {
564        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
565            src: src.into(),
566            mime_type: None,
567            sizes: None,
568        });
569        self
570    }
571
572    /// Add an icon with metadata
573    pub fn icon_with_meta(
574        mut self,
575        src: impl Into<String>,
576        mime_type: Option<String>,
577        sizes: Option<Vec<String>>,
578    ) -> Self {
579        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
580            src: src.into(),
581            mime_type,
582            sizes,
583        });
584        self
585    }
586
587    /// Set the tool description
588    pub fn description(mut self, description: impl Into<String>) -> Self {
589        self.description = Some(description.into());
590        self
591    }
592
593    /// Mark the tool as read-only (does not modify state)
594    pub fn read_only(mut self) -> Self {
595        self.annotations
596            .get_or_insert_with(ToolAnnotations::default)
597            .read_only_hint = true;
598        self
599    }
600
601    /// Mark the tool as non-destructive
602    pub fn non_destructive(mut self) -> Self {
603        self.annotations
604            .get_or_insert_with(ToolAnnotations::default)
605            .destructive_hint = false;
606        self
607    }
608
609    /// Mark the tool as idempotent (same args = same effect)
610    pub fn idempotent(mut self) -> Self {
611        self.annotations
612            .get_or_insert_with(ToolAnnotations::default)
613            .idempotent_hint = true;
614        self
615    }
616
617    /// Set tool annotations directly
618    pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
619        self.annotations = Some(annotations);
620        self
621    }
622
623    /// Specify input type and handler.
624    ///
625    /// The input type must implement `JsonSchema` and `DeserializeOwned`.
626    /// The handler receives the deserialized input and returns a `CallToolResult`.
627    ///
628    /// # State Sharing
629    ///
630    /// To share state across tool calls (e.g., database connections, API clients),
631    /// wrap your state in an `Arc` and clone it into the async block:
632    ///
633    /// ```rust
634    /// use std::sync::Arc;
635    /// use tower_mcp::{ToolBuilder, CallToolResult};
636    /// use schemars::JsonSchema;
637    /// use serde::Deserialize;
638    ///
639    /// struct AppState {
640    ///     api_key: String,
641    /// }
642    ///
643    /// #[derive(Debug, Deserialize, JsonSchema)]
644    /// struct MyInput {
645    ///     query: String,
646    /// }
647    ///
648    /// let state = Arc::new(AppState { api_key: "secret".to_string() });
649    ///
650    /// let tool = ToolBuilder::new("my_tool")
651    ///     .description("A tool that uses shared state")
652    ///     .handler(move |input: MyInput| {
653    ///         let state = state.clone(); // Clone Arc for the async block
654    ///         async move {
655    ///             // Use state.api_key here...
656    ///             Ok(CallToolResult::text(format!("Query: {}", input.query)))
657    ///         }
658    ///     })
659    ///     .build()
660    ///     .expect("valid tool name");
661    /// ```
662    ///
663    /// The `move` keyword on the closure captures the `Arc<AppState>`, and
664    /// cloning it inside the closure body allows each async invocation to
665    /// have its own reference to the shared state.
666    pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
667    where
668        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
669        F: Fn(I) -> Fut + Send + Sync + 'static,
670        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
671    {
672        ToolBuilderWithHandler {
673            name: self.name,
674            title: self.title,
675            description: self.description,
676            output_schema: self.output_schema,
677            icons: self.icons,
678            annotations: self.annotations,
679            handler,
680            _phantom: std::marker::PhantomData,
681        }
682    }
683
684    /// Create a tool using the extractor pattern.
685    ///
686    /// This method provides an axum-inspired way to define handlers where state,
687    /// context, and input are extracted declaratively from function parameters.
688    /// This reduces the combinatorial explosion of handler variants like
689    /// `handler_with_state`, `handler_with_context`, etc.
690    ///
691    /// # Extractors
692    ///
693    /// Built-in extractors available in [`crate::extract`]:
694    /// - [`Json<T>`](crate::extract::Json) - Deserialize JSON arguments to type `T`
695    /// - [`State<T>`](crate::extract::State) - Extract cloned state
696    /// - [`Context`](crate::extract::Context) - Extract request context
697    /// - [`RawArgs`](crate::extract::RawArgs) - Extract raw JSON arguments
698    ///
699    /// # Example
700    ///
701    /// ```rust
702    /// use std::sync::Arc;
703    /// use tower_mcp::{ToolBuilder, CallToolResult};
704    /// use tower_mcp::extract::{Json, State, Context};
705    /// use schemars::JsonSchema;
706    /// use serde::Deserialize;
707    ///
708    /// #[derive(Clone)]
709    /// struct Database { url: String }
710    ///
711    /// #[derive(Debug, Deserialize, JsonSchema)]
712    /// struct QueryInput { query: String }
713    ///
714    /// let db = Arc::new(Database { url: "postgres://...".to_string() });
715    ///
716    /// let tool = ToolBuilder::new("search")
717    ///     .description("Search the database")
718    ///     .extractor_handler(db, |
719    ///         State(db): State<Arc<Database>>,
720    ///         ctx: Context,
721    ///         Json(input): Json<QueryInput>,
722    ///     | async move {
723    ///         if ctx.is_cancelled() {
724    ///             return Ok(CallToolResult::error("Cancelled"));
725    ///         }
726    ///         ctx.report_progress(0.5, Some(1.0), Some("Searching...")).await;
727    ///         Ok(CallToolResult::text(format!("Searched {} with: {}", db.url, input.query)))
728    ///     })
729    ///     .build()
730    ///     .unwrap();
731    /// ```
732    ///
733    /// # Type Inference
734    ///
735    /// The compiler infers extractor types from the function signature. Make sure
736    /// to annotate the extractor types explicitly in the closure parameters.
737    pub fn extractor_handler<S, F, T>(
738        self,
739        state: S,
740        handler: F,
741    ) -> crate::extract::ToolBuilderWithExtractor<S, F, T>
742    where
743        S: Clone + Send + Sync + 'static,
744        F: crate::extract::ExtractorHandler<S, T> + Clone,
745        T: Send + Sync + 'static,
746    {
747        crate::extract::ToolBuilderWithExtractor {
748            name: self.name,
749            title: self.title,
750            description: self.description,
751            output_schema: self.output_schema,
752            icons: self.icons,
753            annotations: self.annotations,
754            state,
755            handler,
756            input_schema: F::input_schema(),
757            _phantom: std::marker::PhantomData,
758        }
759    }
760
761    /// Create a tool using the extractor pattern with typed JSON input.
762    ///
763    /// This is similar to [`extractor_handler`](Self::extractor_handler) but provides
764    /// proper JSON schema generation when using `Json<T>` as an extractor.
765    ///
766    /// # Example
767    ///
768    /// ```rust
769    /// use std::sync::Arc;
770    /// use tower_mcp::{ToolBuilder, CallToolResult};
771    /// use tower_mcp::extract::{Json, State};
772    /// use schemars::JsonSchema;
773    /// use serde::Deserialize;
774    ///
775    /// #[derive(Clone)]
776    /// struct AppState { prefix: String }
777    ///
778    /// #[derive(Debug, Deserialize, JsonSchema)]
779    /// struct GreetInput { name: String }
780    ///
781    /// let state = Arc::new(AppState { prefix: "Hello".to_string() });
782    ///
783    /// let tool = ToolBuilder::new("greet")
784    ///     .description("Greet someone")
785    ///     .extractor_handler_typed::<_, _, _, GreetInput>(state, |
786    ///         State(app): State<Arc<AppState>>,
787    ///         Json(input): Json<GreetInput>,
788    ///     | async move {
789    ///         Ok(CallToolResult::text(format!("{}, {}!", app.prefix, input.name)))
790    ///     })
791    ///     .build()
792    ///     .unwrap();
793    /// ```
794    pub fn extractor_handler_typed<S, F, T, I>(
795        self,
796        state: S,
797        handler: F,
798    ) -> crate::extract::ToolBuilderWithTypedExtractor<S, F, T, I>
799    where
800        S: Clone + Send + Sync + 'static,
801        F: crate::extract::TypedExtractorHandler<S, T, I> + Clone,
802        T: Send + Sync + 'static,
803        I: schemars::JsonSchema + Send + Sync + 'static,
804    {
805        crate::extract::ToolBuilderWithTypedExtractor {
806            name: self.name,
807            title: self.title,
808            description: self.description,
809            output_schema: self.output_schema,
810            icons: self.icons,
811            annotations: self.annotations,
812            state,
813            handler,
814            _phantom: std::marker::PhantomData,
815        }
816    }
817}
818
/// Builder state after handler is specified.
///
/// Produced by [`ToolBuilder::handler`]. Call `build()` to create the
/// [`Tool`], or `layer()` to wrap the handler in Tower middleware first.
pub struct ToolBuilderWithHandler<I, F> {
    // Metadata carried over unchanged from `ToolBuilder`.
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    /// The user-supplied async handler, called with the deserialized input `I`
    handler: F,
    /// Records the input type `I` without storing a value of it
    _phantom: std::marker::PhantomData<I>,
}
830
831impl<I, F, Fut> ToolBuilderWithHandler<I, F>
832where
833    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
834    F: Fn(I) -> Fut + Send + Sync + 'static,
835    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
836{
837    /// Build the tool
838    ///
839    /// Returns an error if the tool name is invalid.
840    pub fn build(self) -> Result<Tool> {
841        validate_tool_name(&self.name)?;
842        Ok(Tool::from_handler(
843            self.name,
844            self.title,
845            self.description,
846            self.output_schema,
847            self.icons,
848            self.annotations,
849            TypedHandler {
850                handler: self.handler,
851                _phantom: std::marker::PhantomData,
852            },
853        ))
854    }
855
856    /// Apply a Tower layer (middleware) to this tool.
857    ///
858    /// The layer wraps the tool's handler service, enabling functionality like
859    /// timeouts, rate limiting, and metrics collection at the per-tool level.
860    ///
861    /// # Example
862    ///
863    /// ```rust
864    /// use std::time::Duration;
865    /// use tower::timeout::TimeoutLayer;
866    /// use tower_mcp::{ToolBuilder, CallToolResult};
867    /// use schemars::JsonSchema;
868    /// use serde::Deserialize;
869    ///
870    /// #[derive(Debug, Deserialize, JsonSchema)]
871    /// struct Input { query: String }
872    ///
873    /// let tool = ToolBuilder::new("search")
874    ///     .description("Search with timeout")
875    ///     .handler(|input: Input| async move {
876    ///         Ok(CallToolResult::text("result"))
877    ///     })
878    ///     .layer(TimeoutLayer::new(Duration::from_secs(30)))
879    ///     .build()
880    ///     .unwrap();
881    /// ```
882    pub fn layer<L>(self, layer: L) -> ToolBuilderWithLayer<I, F, L> {
883        ToolBuilderWithLayer {
884            name: self.name,
885            title: self.title,
886            description: self.description,
887            output_schema: self.output_schema,
888            icons: self.icons,
889            annotations: self.annotations,
890            handler: self.handler,
891            layer,
892            _phantom: std::marker::PhantomData,
893        }
894    }
895}
896
/// Builder state after a layer has been applied to the handler.
///
/// Produced by `ToolBuilderWithHandler::layer`. This builder allows chaining
/// additional layers and building the final tool.
pub struct ToolBuilderWithLayer<I, F, L> {
    // Metadata carried over unchanged from the earlier builder stages.
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    /// The user-supplied async handler
    handler: F,
    /// The Tower layer to wrap around the handler service at build time
    layer: L,
    /// Records the input type `I` without storing a value of it
    _phantom: std::marker::PhantomData<I>,
}
911
912// Allow private_bounds because these internal types (ToolHandlerService, TypedHandler, etc.)
913// are implementation details that users don't interact with directly.
914#[allow(private_bounds)]
915impl<I, F, Fut, L> ToolBuilderWithLayer<I, F, L>
916where
917    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
918    F: Fn(I) -> Fut + Send + Sync + 'static,
919    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
920    L: tower::Layer<ToolHandlerService<TypedHandler<I, F>>> + Clone + Send + Sync + 'static,
921    L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
922    <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
923    <L::Service as Service<ToolRequest>>::Future: Send,
924{
925    /// Build the tool with the applied layer(s).
926    ///
927    /// Returns an error if the tool name is invalid.
928    pub fn build(self) -> Result<Tool> {
929        validate_tool_name(&self.name)?;
930
931        let input_schema = schemars::schema_for!(I);
932        let input_schema = serde_json::to_value(input_schema)
933            .unwrap_or_else(|_| serde_json::json!({ "type": "object" }));
934
935        let handler_service = ToolHandlerService::new(TypedHandler {
936            handler: self.handler,
937            _phantom: std::marker::PhantomData,
938        });
939        let layered = self.layer.layer(handler_service);
940        let catch_error = ToolCatchError::new(layered);
941        let service = BoxCloneService::new(catch_error);
942
943        Ok(Tool {
944            name: self.name,
945            title: self.title,
946            description: self.description,
947            output_schema: self.output_schema,
948            icons: self.icons,
949            annotations: self.annotations,
950            service,
951            input_schema,
952        })
953    }
954
955    /// Apply an additional Tower layer (middleware).
956    ///
957    /// Layers are applied in order, with earlier layers wrapping later ones.
958    /// This means the first layer added is the outermost middleware.
959    pub fn layer<L2>(
960        self,
961        layer: L2,
962    ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<L2, L>> {
963        ToolBuilderWithLayer {
964            name: self.name,
965            title: self.title,
966            description: self.description,
967            output_schema: self.output_schema,
968            icons: self.icons,
969            annotations: self.annotations,
970            handler: self.handler,
971            layer: tower::layer::util::Stack::new(layer, self.layer),
972            _phantom: std::marker::PhantomData,
973        }
974    }
975}
976
977// =============================================================================
978// Handler implementations
979// =============================================================================
980
/// Handler that deserializes input to a specific type
///
/// Adapts a plain async function `F: Fn(I) -> Fut` to the `ToolHandler`
/// interface. Incoming JSON arguments are deserialized into `I` before `F` is
/// invoked, and the JSON schema generated for `I` is reported as the input
/// schema.
struct TypedHandler<I, F> {
    /// The user-supplied async handler function.
    handler: F,
    // Carries the input type `I` in the struct's type without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}
986
987impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
988where
989    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
990    F: Fn(I) -> Fut + Send + Sync + 'static,
991    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
992{
993    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
994        Box::pin(async move {
995            let input: I = serde_json::from_value(args)
996                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
997            (self.handler)(input).await
998        })
999    }
1000
1001    fn input_schema(&self) -> Value {
1002        let schema = schemars::schema_for!(I);
1003        serde_json::to_value(schema).unwrap_or_else(|_| {
1004            serde_json::json!({
1005                "type": "object"
1006            })
1007        })
1008    }
1009}
1010
1011// =============================================================================
1012// Trait-based tool definition
1013// =============================================================================
1014
1015/// Trait for defining tools with full control
1016///
1017/// Implement this trait when you need more control than the builder provides,
1018/// or when you want to define tools as standalone types.
1019///
1020/// # Example
1021///
1022/// ```rust
1023/// use tower_mcp::tool::McpTool;
1024/// use tower_mcp::error::Result;
1025/// use schemars::JsonSchema;
1026/// use serde::{Deserialize, Serialize};
1027///
1028/// #[derive(Debug, Deserialize, JsonSchema)]
1029/// struct AddInput {
1030///     a: i64,
1031///     b: i64,
1032/// }
1033///
1034/// struct AddTool;
1035///
1036/// impl McpTool for AddTool {
1037///     const NAME: &'static str = "add";
1038///     const DESCRIPTION: &'static str = "Add two numbers";
1039///
1040///     type Input = AddInput;
1041///     type Output = i64;
1042///
1043///     async fn call(&self, input: Self::Input) -> Result<Self::Output> {
1044///         Ok(input.a + input.b)
1045///     }
1046/// }
1047///
1048/// let tool = AddTool.into_tool().expect("valid tool name");
1049/// assert_eq!(tool.name, "add");
1050/// ```
1051pub trait McpTool: Send + Sync + 'static {
1052    const NAME: &'static str;
1053    const DESCRIPTION: &'static str;
1054
1055    type Input: JsonSchema + DeserializeOwned + Send;
1056    type Output: Serialize + Send;
1057
1058    fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
1059
1060    /// Optional annotations for the tool
1061    fn annotations(&self) -> Option<ToolAnnotations> {
1062        None
1063    }
1064
1065    /// Convert to a Tool instance
1066    ///
1067    /// Returns an error if the tool name is invalid.
1068    fn into_tool(self) -> Result<Tool>
1069    where
1070        Self: Sized,
1071    {
1072        validate_tool_name(Self::NAME)?;
1073        let annotations = self.annotations();
1074        let tool = Arc::new(self);
1075        Ok(Tool::from_handler(
1076            Self::NAME.to_string(),
1077            None,
1078            Some(Self::DESCRIPTION.to_string()),
1079            None,
1080            None,
1081            annotations,
1082            McpToolHandler { tool },
1083        ))
1084    }
1085}
1086
/// Wrapper to make McpTool implement ToolHandler
///
/// The tool is held behind an `Arc` so that each call's future can own a
/// handle to it. The `McpTool` bound lives on the `ToolHandler` impl rather
/// than on the struct, per the usual "bounds on impls, not on types" idiom.
struct McpToolHandler<T> {
    tool: Arc<T>,
}
1091
1092impl<T: McpTool> ToolHandler for McpToolHandler<T> {
1093    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1094        let tool = self.tool.clone();
1095        Box::pin(async move {
1096            let input: T::Input = serde_json::from_value(args)
1097                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
1098            let output = tool.call(input).await?;
1099            let value = serde_json::to_value(output)
1100                .map_err(|e| Error::tool(format!("Failed to serialize output: {}", e)))?;
1101            Ok(CallToolResult::json(value))
1102        })
1103    }
1104
1105    fn input_schema(&self) -> Value {
1106        let schema = schemars::schema_for!(T::Input);
1107        serde_json::to_value(schema).unwrap_or_else(|_| {
1108            serde_json::json!({
1109                "type": "object"
1110            })
1111        })
1112    }
1113}
1114
1115#[cfg(test)]
1116mod tests {
1117    use super::*;
1118    use crate::extract::{Context, Json, RawArgs, State};
1119    use crate::protocol::Content;
1120    use schemars::JsonSchema;
1121    use serde::Deserialize;
1122
    /// Input shared by the greeting-tool tests: a single `name` field.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct GreetInput {
        name: String,
    }
1127
1128    #[tokio::test]
1129    async fn test_builder_tool() {
1130        let tool = ToolBuilder::new("greet")
1131            .description("Greet someone")
1132            .handler(|input: GreetInput| async move {
1133                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1134            })
1135            .build()
1136            .expect("valid tool name");
1137
1138        assert_eq!(tool.name, "greet");
1139        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1140
1141        let result = tool.call(serde_json::json!({"name": "World"})).await;
1142
1143        assert!(!result.is_error);
1144    }
1145
1146    #[tokio::test]
1147    async fn test_raw_handler() {
1148        let tool = ToolBuilder::new("echo")
1149            .description("Echo input")
1150            .extractor_handler((), |RawArgs(args): RawArgs| async move {
1151                Ok(CallToolResult::json(args))
1152            })
1153            .build()
1154            .expect("valid tool name");
1155
1156        let result = tool.call(serde_json::json!({"foo": "bar"})).await;
1157
1158        assert!(!result.is_error);
1159    }
1160
1161    #[test]
1162    fn test_invalid_tool_name_empty() {
1163        let result = ToolBuilder::new("")
1164            .description("Empty name")
1165            .extractor_handler((), |RawArgs(args): RawArgs| async move {
1166                Ok(CallToolResult::json(args))
1167            })
1168            .build();
1169
1170        assert!(result.is_err());
1171        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
1172    }
1173
1174    #[test]
1175    fn test_invalid_tool_name_too_long() {
1176        let long_name = "a".repeat(129);
1177        let result = ToolBuilder::new(long_name)
1178            .description("Too long")
1179            .extractor_handler((), |RawArgs(args): RawArgs| async move {
1180                Ok(CallToolResult::json(args))
1181            })
1182            .build();
1183
1184        assert!(result.is_err());
1185        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
1186    }
1187
1188    #[test]
1189    fn test_invalid_tool_name_bad_chars() {
1190        let result = ToolBuilder::new("my tool!")
1191            .description("Bad chars")
1192            .extractor_handler((), |RawArgs(args): RawArgs| async move {
1193                Ok(CallToolResult::json(args))
1194            })
1195            .build();
1196
1197        assert!(result.is_err());
1198        assert!(
1199            result
1200                .unwrap_err()
1201                .to_string()
1202                .contains("invalid character")
1203        );
1204    }
1205
1206    #[test]
1207    fn test_valid_tool_names() {
1208        // All valid characters
1209        let names = [
1210            "my_tool",
1211            "my-tool",
1212            "my.tool",
1213            "MyTool123",
1214            "a",
1215            &"a".repeat(128),
1216        ];
1217        for name in names {
1218            let result = ToolBuilder::new(name)
1219                .description("Valid")
1220                .extractor_handler((), |RawArgs(args): RawArgs| async move {
1221                    Ok(CallToolResult::json(args))
1222                })
1223                .build();
1224            assert!(result.is_ok(), "Expected '{}' to be valid", name);
1225        }
1226    }
1227
1228    #[tokio::test]
1229    async fn test_context_aware_handler() {
1230        use crate::context::notification_channel;
1231        use crate::protocol::{ProgressToken, RequestId};
1232
1233        #[derive(Debug, Deserialize, JsonSchema)]
1234        struct ProcessInput {
1235            count: i32,
1236        }
1237
1238        let tool = ToolBuilder::new("process")
1239            .description("Process with context")
1240            .extractor_handler_typed::<_, _, _, ProcessInput>(
1241                (),
1242                |ctx: Context, Json(input): Json<ProcessInput>| async move {
1243                    // Simulate progress reporting
1244                    for i in 0..input.count {
1245                        if ctx.is_cancelled() {
1246                            return Ok(CallToolResult::error("Cancelled"));
1247                        }
1248                        ctx.report_progress(i as f64, Some(input.count as f64), None)
1249                            .await;
1250                    }
1251                    Ok(CallToolResult::text(format!(
1252                        "Processed {} items",
1253                        input.count
1254                    )))
1255                },
1256            )
1257            .build()
1258            .expect("valid tool name");
1259
1260        assert_eq!(tool.name, "process");
1261
1262        // Test with a context that has progress token and notification sender
1263        let (tx, mut rx) = notification_channel(10);
1264        let ctx = RequestContext::new(RequestId::Number(1))
1265            .with_progress_token(ProgressToken::Number(42))
1266            .with_notification_sender(tx);
1267
1268        let result = tool
1269            .call_with_context(ctx, serde_json::json!({"count": 3}))
1270            .await;
1271
1272        assert!(!result.is_error);
1273
1274        // Check that progress notifications were sent
1275        let mut progress_count = 0;
1276        while rx.try_recv().is_ok() {
1277            progress_count += 1;
1278        }
1279        assert_eq!(progress_count, 3);
1280    }
1281
1282    #[tokio::test]
1283    async fn test_context_aware_handler_cancellation() {
1284        use crate::protocol::RequestId;
1285        use std::sync::atomic::{AtomicI32, Ordering};
1286
1287        #[derive(Debug, Deserialize, JsonSchema)]
1288        struct LongRunningInput {
1289            iterations: i32,
1290        }
1291
1292        let iterations_completed = Arc::new(AtomicI32::new(0));
1293        let iterations_ref = iterations_completed.clone();
1294
1295        let tool = ToolBuilder::new("long_running")
1296            .description("Long running task")
1297            .extractor_handler_typed::<_, _, _, LongRunningInput>(
1298                (),
1299                move |ctx: Context, Json(input): Json<LongRunningInput>| {
1300                    let completed = iterations_ref.clone();
1301                    async move {
1302                        for i in 0..input.iterations {
1303                            if ctx.is_cancelled() {
1304                                return Ok(CallToolResult::error("Cancelled"));
1305                            }
1306                            completed.fetch_add(1, Ordering::SeqCst);
1307                            // Simulate work
1308                            tokio::task::yield_now().await;
1309                            // Cancel after iteration 2
1310                            if i == 2 {
1311                                ctx.cancellation_token().cancel();
1312                            }
1313                        }
1314                        Ok(CallToolResult::text("Done"))
1315                    }
1316                },
1317            )
1318            .build()
1319            .expect("valid tool name");
1320
1321        let ctx = RequestContext::new(RequestId::Number(1));
1322
1323        let result = tool
1324            .call_with_context(ctx, serde_json::json!({"iterations": 10}))
1325            .await;
1326
1327        // Should have been cancelled after 3 iterations (0, 1, 2)
1328        // The next iteration (3) checks cancellation and returns
1329        assert!(result.is_error);
1330        assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
1331    }
1332
1333    #[tokio::test]
1334    async fn test_tool_builder_with_enhanced_fields() {
1335        let output_schema = serde_json::json!({
1336            "type": "object",
1337            "properties": {
1338                "greeting": {"type": "string"}
1339            }
1340        });
1341
1342        let tool = ToolBuilder::new("greet")
1343            .title("Greeting Tool")
1344            .description("Greet someone")
1345            .output_schema(output_schema.clone())
1346            .icon("https://example.com/icon.png")
1347            .icon_with_meta(
1348                "https://example.com/icon-large.png",
1349                Some("image/png".to_string()),
1350                Some(vec!["96x96".to_string()]),
1351            )
1352            .handler(|input: GreetInput| async move {
1353                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1354            })
1355            .build()
1356            .expect("valid tool name");
1357
1358        assert_eq!(tool.name, "greet");
1359        assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
1360        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1361        assert_eq!(tool.output_schema, Some(output_schema));
1362        assert!(tool.icons.is_some());
1363        assert_eq!(tool.icons.as_ref().unwrap().len(), 2);
1364
1365        // Test definition includes new fields
1366        let def = tool.definition();
1367        assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
1368        assert!(def.output_schema.is_some());
1369        assert!(def.icons.is_some());
1370    }
1371
1372    #[tokio::test]
1373    async fn test_handler_with_state() {
1374        let shared = Arc::new("shared-state".to_string());
1375
1376        let tool = ToolBuilder::new("stateful")
1377            .description("Uses shared state")
1378            .extractor_handler_typed::<_, _, _, GreetInput>(
1379                shared,
1380                |State(state): State<Arc<String>>, Json(input): Json<GreetInput>| async move {
1381                    Ok(CallToolResult::text(format!(
1382                        "{}: Hello, {}!",
1383                        state, input.name
1384                    )))
1385                },
1386            )
1387            .build()
1388            .expect("valid tool name");
1389
1390        let result = tool.call(serde_json::json!({"name": "World"})).await;
1391        assert!(!result.is_error);
1392    }
1393
1394    #[tokio::test]
1395    async fn test_handler_with_state_and_context() {
1396        use crate::protocol::RequestId;
1397
1398        let shared = Arc::new(42_i32);
1399
1400        let tool =
1401            ToolBuilder::new("stateful_ctx")
1402                .description("Uses state and context")
1403                .extractor_handler_typed::<_, _, _, GreetInput>(
1404                    shared,
1405                    |State(state): State<Arc<i32>>,
1406                     _ctx: Context,
1407                     Json(input): Json<GreetInput>| async move {
1408                        Ok(CallToolResult::text(format!(
1409                            "{}: Hello, {}!",
1410                            state, input.name
1411                        )))
1412                    },
1413                )
1414                .build()
1415                .expect("valid tool name");
1416
1417        let ctx = RequestContext::new(RequestId::Number(1));
1418        let result = tool
1419            .call_with_context(ctx, serde_json::json!({"name": "World"}))
1420            .await;
1421        assert!(!result.is_error);
1422    }
1423
1424    #[tokio::test]
1425    async fn test_handler_no_params() {
1426        let tool = ToolBuilder::new("no_params")
1427            .description("Takes no parameters")
1428            .extractor_handler_typed::<_, _, _, NoParams>((), |Json(_): Json<NoParams>| async {
1429                Ok(CallToolResult::text("no params result"))
1430            })
1431            .build()
1432            .expect("valid tool name");
1433
1434        assert_eq!(tool.name, "no_params");
1435
1436        // Should work with empty args
1437        let result = tool.call(serde_json::json!({})).await;
1438        assert!(!result.is_error);
1439
1440        // Should also work with unexpected args (ignored)
1441        let result = tool.call(serde_json::json!({"unexpected": "value"})).await;
1442        assert!(!result.is_error);
1443
1444        // Check input schema includes type: object
1445        let schema = tool.definition().input_schema;
1446        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1447    }
1448
1449    #[tokio::test]
1450    async fn test_handler_with_state_no_params() {
1451        let shared = Arc::new("shared_value".to_string());
1452
1453        let tool = ToolBuilder::new("with_state_no_params")
1454            .description("Takes no parameters but has state")
1455            .extractor_handler_typed::<_, _, _, NoParams>(
1456                shared,
1457                |State(state): State<Arc<String>>, Json(_): Json<NoParams>| async move {
1458                    Ok(CallToolResult::text(format!("state: {}", state)))
1459                },
1460            )
1461            .build()
1462            .expect("valid tool name");
1463
1464        assert_eq!(tool.name, "with_state_no_params");
1465
1466        // Should work with empty args
1467        let result = tool.call(serde_json::json!({})).await;
1468        assert!(!result.is_error);
1469        assert_eq!(result.first_text().unwrap(), "state: shared_value");
1470
1471        // Check input schema includes type: object
1472        let schema = tool.definition().input_schema;
1473        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1474    }
1475
1476    #[tokio::test]
1477    async fn test_handler_no_params_with_context() {
1478        let tool = ToolBuilder::new("no_params_with_context")
1479            .description("Takes no parameters but has context")
1480            .extractor_handler_typed::<_, _, _, NoParams>(
1481                (),
1482                |_ctx: Context, Json(_): Json<NoParams>| async move {
1483                    Ok(CallToolResult::text("context available"))
1484                },
1485            )
1486            .build()
1487            .expect("valid tool name");
1488
1489        assert_eq!(tool.name, "no_params_with_context");
1490
1491        let result = tool.call(serde_json::json!({})).await;
1492        assert!(!result.is_error);
1493        assert_eq!(result.first_text().unwrap(), "context available");
1494    }
1495
1496    #[tokio::test]
1497    async fn test_handler_with_state_and_context_no_params() {
1498        let shared = Arc::new("shared".to_string());
1499
1500        let tool = ToolBuilder::new("state_context_no_params")
1501            .description("Has state and context, no params")
1502            .extractor_handler_typed::<_, _, _, NoParams>(
1503                shared,
1504                |State(state): State<Arc<String>>,
1505                 _ctx: Context,
1506                 Json(_): Json<NoParams>| async move {
1507                    Ok(CallToolResult::text(format!("state: {}", state)))
1508                },
1509            )
1510            .build()
1511            .expect("valid tool name");
1512
1513        assert_eq!(tool.name, "state_context_no_params");
1514
1515        let result = tool.call(serde_json::json!({})).await;
1516        assert!(!result.is_error);
1517        assert_eq!(result.first_text().unwrap(), "state: shared");
1518    }
1519
1520    #[tokio::test]
1521    async fn test_raw_handler_with_state() {
1522        let prefix = Arc::new("prefix:".to_string());
1523
1524        let tool = ToolBuilder::new("raw_with_state")
1525            .description("Raw handler with state")
1526            .extractor_handler(
1527                prefix,
1528                |State(state): State<Arc<String>>, RawArgs(args): RawArgs| async move {
1529                    Ok(CallToolResult::text(format!("{} {}", state, args)))
1530                },
1531            )
1532            .build()
1533            .expect("valid tool name");
1534
1535        assert_eq!(tool.name, "raw_with_state");
1536
1537        let result = tool.call(serde_json::json!({"key": "value"})).await;
1538        assert!(!result.is_error);
1539        assert!(result.first_text().unwrap().starts_with("prefix:"));
1540    }
1541
1542    #[tokio::test]
1543    async fn test_raw_handler_with_state_and_context() {
1544        let prefix = Arc::new("prefix:".to_string());
1545
1546        let tool = ToolBuilder::new("raw_state_context")
1547            .description("Raw handler with state and context")
1548            .extractor_handler(
1549                prefix,
1550                |State(state): State<Arc<String>>,
1551                 _ctx: Context,
1552                 RawArgs(args): RawArgs| async move {
1553                    Ok(CallToolResult::text(format!("{} {}", state, args)))
1554                },
1555            )
1556            .build()
1557            .expect("valid tool name");
1558
1559        assert_eq!(tool.name, "raw_state_context");
1560
1561        let result = tool.call(serde_json::json!({"key": "value"})).await;
1562        assert!(!result.is_error);
1563        assert!(result.first_text().unwrap().starts_with("prefix:"));
1564    }
1565
1566    #[tokio::test]
1567    async fn test_tool_with_timeout_layer() {
1568        use std::time::Duration;
1569        use tower::timeout::TimeoutLayer;
1570
1571        #[derive(Debug, Deserialize, JsonSchema)]
1572        struct SlowInput {
1573            delay_ms: u64,
1574        }
1575
1576        // Create a tool with a short timeout
1577        let tool = ToolBuilder::new("slow_tool")
1578            .description("A slow tool")
1579            .handler(|input: SlowInput| async move {
1580                tokio::time::sleep(Duration::from_millis(input.delay_ms)).await;
1581                Ok(CallToolResult::text("completed"))
1582            })
1583            .layer(TimeoutLayer::new(Duration::from_millis(50)))
1584            .build()
1585            .expect("valid tool name");
1586
1587        // Fast call should succeed
1588        let result = tool.call(serde_json::json!({"delay_ms": 10})).await;
1589        assert!(!result.is_error);
1590        assert_eq!(result.first_text().unwrap(), "completed");
1591
1592        // Slow call should timeout and return an error result
1593        let result = tool.call(serde_json::json!({"delay_ms": 200})).await;
1594        assert!(result.is_error);
1595        // Tower's timeout error message is "request timed out"
1596        let msg = result.first_text().unwrap().to_lowercase();
1597        assert!(
1598            msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
1599            "Expected timeout error, got: {}",
1600            msg
1601        );
1602    }
1603
1604    #[tokio::test]
1605    async fn test_tool_with_concurrency_limit_layer() {
1606        use std::sync::atomic::{AtomicU32, Ordering};
1607        use std::time::Duration;
1608        use tower::limit::ConcurrencyLimitLayer;
1609
1610        #[derive(Debug, Deserialize, JsonSchema)]
1611        struct WorkInput {
1612            id: u32,
1613        }
1614
1615        let max_concurrent = Arc::new(AtomicU32::new(0));
1616        let current_concurrent = Arc::new(AtomicU32::new(0));
1617        let max_ref = max_concurrent.clone();
1618        let current_ref = current_concurrent.clone();
1619
1620        // Create a tool with concurrency limit of 2
1621        let tool = ToolBuilder::new("concurrent_tool")
1622            .description("A concurrent tool")
1623            .handler(move |input: WorkInput| {
1624                let max = max_ref.clone();
1625                let current = current_ref.clone();
1626                async move {
1627                    // Track concurrency
1628                    let prev = current.fetch_add(1, Ordering::SeqCst);
1629                    max.fetch_max(prev + 1, Ordering::SeqCst);
1630
1631                    // Simulate work
1632                    tokio::time::sleep(Duration::from_millis(50)).await;
1633
1634                    current.fetch_sub(1, Ordering::SeqCst);
1635                    Ok(CallToolResult::text(format!("completed {}", input.id)))
1636                }
1637            })
1638            .layer(ConcurrencyLimitLayer::new(2))
1639            .build()
1640            .expect("valid tool name");
1641
1642        // Launch 4 concurrent calls
1643        let handles: Vec<_> = (0..4)
1644            .map(|i| {
1645                let t = tool.call(serde_json::json!({"id": i}));
1646                tokio::spawn(t)
1647            })
1648            .collect();
1649
1650        for handle in handles {
1651            let result = handle.await.unwrap();
1652            assert!(!result.is_error);
1653        }
1654
1655        // Max concurrent should not exceed 2
1656        assert!(max_concurrent.load(Ordering::SeqCst) <= 2);
1657    }
1658
1659    #[tokio::test]
1660    async fn test_tool_with_multiple_layers() {
1661        use std::time::Duration;
1662        use tower::limit::ConcurrencyLimitLayer;
1663        use tower::timeout::TimeoutLayer;
1664
1665        #[derive(Debug, Deserialize, JsonSchema)]
1666        struct Input {
1667            value: String,
1668        }
1669
1670        // Create a tool with multiple layers stacked
1671        let tool = ToolBuilder::new("multi_layer_tool")
1672            .description("Tool with multiple layers")
1673            .handler(|input: Input| async move {
1674                Ok(CallToolResult::text(format!("processed: {}", input.value)))
1675            })
1676            .layer(TimeoutLayer::new(Duration::from_secs(5)))
1677            .layer(ConcurrencyLimitLayer::new(10))
1678            .build()
1679            .expect("valid tool name");
1680
1681        let result = tool.call(serde_json::json!({"value": "test"})).await;
1682        assert!(!result.is_error);
1683        assert_eq!(result.first_text().unwrap(), "processed: test");
1684    }
1685
1686    #[test]
1687    fn test_tool_catch_error_clone() {
1688        // ToolCatchError should be Clone when inner is Clone
1689        // Use a simple tool that we can clone
1690        let tool = ToolBuilder::new("test")
1691            .description("test")
1692            .extractor_handler((), |RawArgs(_args): RawArgs| async {
1693                Ok(CallToolResult::text("ok"))
1694            })
1695            .build()
1696            .unwrap();
1697        // The tool contains a BoxToolService which is cloneable
1698        let _clone = tool.call(serde_json::json!({}));
1699    }
1700
1701    #[test]
1702    fn test_tool_catch_error_debug() {
1703        // ToolCatchError implements Debug when inner implements Debug
1704        // Since our internal services don't require Debug, just verify
1705        // that ToolCatchError has a Debug impl for appropriate types
1706        #[derive(Debug, Clone)]
1707        struct DebugService;
1708
1709        impl Service<ToolRequest> for DebugService {
1710            type Response = CallToolResult;
1711            type Error = crate::error::Error;
1712            type Future = Pin<
1713                Box<
1714                    dyn Future<Output = std::result::Result<CallToolResult, crate::error::Error>>
1715                        + Send,
1716                >,
1717            >;
1718
1719            fn poll_ready(
1720                &mut self,
1721                _cx: &mut std::task::Context<'_>,
1722            ) -> Poll<std::result::Result<(), Self::Error>> {
1723                Poll::Ready(Ok(()))
1724            }
1725
1726            fn call(&mut self, _req: ToolRequest) -> Self::Future {
1727                Box::pin(async { Ok(CallToolResult::text("ok")) })
1728            }
1729        }
1730
1731        let catch_error = ToolCatchError::new(DebugService);
1732        let debug = format!("{:?}", catch_error);
1733        assert!(debug.contains("ToolCatchError"));
1734    }
1735
1736    #[test]
1737    fn test_tool_request_new() {
1738        use crate::protocol::RequestId;
1739
1740        let ctx = RequestContext::new(RequestId::Number(42));
1741        let args = serde_json::json!({"key": "value"});
1742        let req = ToolRequest::new(ctx.clone(), args.clone());
1743
1744        assert_eq!(req.args, args);
1745    }
1746
1747    #[test]
1748    fn test_no_params_schema() {
1749        // NoParams should produce a schema with type: "object"
1750        let schema = schemars::schema_for!(NoParams);
1751        let schema_value = serde_json::to_value(&schema).unwrap();
1752        assert_eq!(
1753            schema_value.get("type").and_then(|v| v.as_str()),
1754            Some("object"),
1755            "NoParams should generate type: object schema"
1756        );
1757    }
1758
1759    #[test]
1760    fn test_no_params_deserialize() {
1761        // NoParams should deserialize from various inputs
1762        let from_empty_object: NoParams = serde_json::from_str("{}").unwrap();
1763        assert_eq!(from_empty_object, NoParams);
1764
1765        let from_null: NoParams = serde_json::from_str("null").unwrap();
1766        assert_eq!(from_null, NoParams);
1767
1768        // Should also accept objects with unexpected fields (ignored)
1769        let from_object_with_fields: NoParams =
1770            serde_json::from_str(r#"{"unexpected": "value"}"#).unwrap();
1771        assert_eq!(from_object_with_fields, NoParams);
1772    }
1773
1774    #[tokio::test]
1775    async fn test_no_params_type_in_handler() {
1776        // NoParams can be used as a handler input type
1777        let tool = ToolBuilder::new("status")
1778            .description("Get status")
1779            .handler(|_input: NoParams| async move { Ok(CallToolResult::text("OK")) })
1780            .build()
1781            .expect("valid tool name");
1782
1783        // Check schema has type: object (not type: null like () would produce)
1784        let schema = tool.definition().input_schema;
1785        assert_eq!(
1786            schema.get("type").and_then(|v| v.as_str()),
1787            Some("object"),
1788            "NoParams handler should produce type: object schema"
1789        );
1790
1791        // Should work with empty input
1792        let result = tool.call(serde_json::json!({})).await;
1793        assert!(!result.is_error);
1794    }
1795
1796    #[tokio::test]
1797    async fn test_tool_with_name_prefix() {
1798        #[derive(Debug, Deserialize, JsonSchema)]
1799        struct Input {
1800            value: String,
1801        }
1802
1803        let tool = ToolBuilder::new("query")
1804            .description("Query something")
1805            .title("Query Tool")
1806            .handler(|input: Input| async move { Ok(CallToolResult::text(&input.value)) })
1807            .build()
1808            .expect("valid tool name");
1809
1810        // Create prefixed version
1811        let prefixed = tool.with_name_prefix("db");
1812
1813        // Check name is prefixed
1814        assert_eq!(prefixed.name, "db.query");
1815
1816        // Check other fields are preserved
1817        assert_eq!(prefixed.description.as_deref(), Some("Query something"));
1818        assert_eq!(prefixed.title.as_deref(), Some("Query Tool"));
1819
1820        // Check the tool still works
1821        let result = prefixed
1822            .call(serde_json::json!({"value": "test input"}))
1823            .await;
1824        assert!(!result.is_error);
1825        match &result.content[0] {
1826            Content::Text { text, .. } => assert_eq!(text, "test input"),
1827            _ => panic!("Expected text content"),
1828        }
1829    }
1830
1831    #[tokio::test]
1832    async fn test_tool_with_name_prefix_multiple_levels() {
1833        let tool = ToolBuilder::new("action")
1834            .description("Do something")
1835            .handler(|_: NoParams| async move { Ok(CallToolResult::text("done")) })
1836            .build()
1837            .expect("valid tool name");
1838
1839        // Apply multiple prefixes
1840        let prefixed = tool.with_name_prefix("level1");
1841        assert_eq!(prefixed.name, "level1.action");
1842
1843        let double_prefixed = prefixed.with_name_prefix("level0");
1844        assert_eq!(double_prefixed.name, "level0.level1.action");
1845    }
1846}