Skip to main content

tower_mcp/
tool.rs

1//! Tool definition and builder API
2//!
3//! Provides ergonomic ways to define MCP tools:
4//!
5//! 1. **Builder pattern** - Fluent API for defining tools
6//! 2. **Trait-based** - Implement `McpTool` for full control
7//! 3. **Function-based** - Quick tools from async functions
8//!
9//! ## Per-Tool Middleware
10//!
11//! Tools are implemented as Tower services internally, enabling middleware
12//! composition via the `.layer()` method:
13//!
14//! ```rust
15//! use std::time::Duration;
16//! use tower::timeout::TimeoutLayer;
17//! use tower_mcp::{ToolBuilder, CallToolResult};
18//! use schemars::JsonSchema;
19//! use serde::Deserialize;
20//!
21//! #[derive(Debug, Deserialize, JsonSchema)]
22//! struct SearchInput { query: String }
23//!
24//! let tool = ToolBuilder::new("slow_search")
25//!     .description("Search with extended timeout")
26//!     .handler(|input: SearchInput| async move {
27//!         Ok(CallToolResult::text("result"))
28//!     })
29//!     .layer(TimeoutLayer::new(Duration::from_secs(30)))
30//!     .build();
31//! ```
32
33use std::borrow::Cow;
34use std::convert::Infallible;
35use std::fmt;
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::task::{Context, Poll};
40
41use pin_project_lite::pin_project;
42
43use schemars::{JsonSchema, Schema, SchemaGenerator};
44use serde::Serialize;
45use serde::de::DeserializeOwned;
46use serde_json::Value;
47use tower::util::BoxCloneService;
48use tower_service::Service;
49
50use crate::context::RequestContext;
51use crate::error::{Error, Result, ResultExt};
52use crate::protocol::{
53    CallToolResult, TaskSupportMode, ToolAnnotations, ToolDefinition, ToolExecution, ToolIcon,
54};
55
56// =============================================================================
57// Service Types for Per-Tool Middleware
58// =============================================================================
59
60/// Request type for tool services.
61///
62/// Contains the request context (for progress reporting, cancellation, etc.)
63/// and the tool arguments as raw JSON.
64#[derive(Debug, Clone)]
65pub struct ToolRequest {
66    /// Request context for progress reporting, cancellation, and client requests
67    pub ctx: RequestContext,
68    /// Tool arguments as raw JSON
69    pub args: Value,
70}
71
72impl ToolRequest {
73    /// Create a new tool request
74    pub fn new(ctx: RequestContext, args: Value) -> Self {
75        Self { ctx, args }
76    }
77}
78
79/// A boxed, cloneable tool service with `Error = Infallible`.
80///
81/// This is the internal service type that tools use. Middleware errors are
82/// caught and converted to `CallToolResult::error()` responses, so the
83/// service never fails at the Tower level.
84pub type BoxToolService = BoxCloneService<ToolRequest, CallToolResult, Infallible>;
85
/// Service wrapper that turns inner-service errors into `CallToolResult::error()`.
///
/// Errors produced by middleware (timeouts, rate limits, ...) and by the
/// handler itself are reported as tool-level error responses with
/// `is_error: true` instead of being propagated as Tower service errors.
#[doc(hidden)]
pub struct ToolCatchError<S> {
    inner: S,
}
95
96impl<S> ToolCatchError<S> {
97    /// Create a new `ToolCatchError` wrapping the given service.
98    pub fn new(inner: S) -> Self {
99        Self { inner }
100    }
101}
102
103impl<S: Clone> Clone for ToolCatchError<S> {
104    fn clone(&self) -> Self {
105        Self {
106            inner: self.inner.clone(),
107        }
108    }
109}
110
111impl<S: fmt::Debug> fmt::Debug for ToolCatchError<S> {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        f.debug_struct("ToolCatchError")
114            .field("inner", &self.inner)
115            .finish()
116    }
117}
118
119pin_project! {
120    /// Future for [`ToolCatchError`].
121    #[doc(hidden)]
122    pub struct ToolCatchErrorFuture<F> {
123        #[pin]
124        inner: F,
125    }
126}
127
128impl<F, E> Future for ToolCatchErrorFuture<F>
129where
130    F: Future<Output = std::result::Result<CallToolResult, E>>,
131    E: fmt::Display,
132{
133    type Output = std::result::Result<CallToolResult, Infallible>;
134
135    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136        match self.project().inner.poll(cx) {
137            Poll::Pending => Poll::Pending,
138            Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)),
139            Poll::Ready(Err(err)) => Poll::Ready(Ok(CallToolResult::error(err.to_string()))),
140        }
141    }
142}
143
144impl<S> Service<ToolRequest> for ToolCatchError<S>
145where
146    S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
147    S::Error: fmt::Display + Send,
148    S::Future: Send,
149{
150    type Response = CallToolResult;
151    type Error = Infallible;
152    type Future = ToolCatchErrorFuture<S::Future>;
153
154    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
155        // Map any readiness error to Infallible (we catch it on call)
156        match self.inner.poll_ready(cx) {
157            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
158            Poll::Ready(Err(_)) => Poll::Ready(Ok(())),
159            Poll::Pending => Poll::Pending,
160        }
161    }
162
163    fn call(&mut self, req: ToolRequest) -> Self::Future {
164        ToolCatchErrorFuture {
165            inner: self.inner.call(req),
166        }
167    }
168}
169
/// A tower [`Layer`](tower::Layer) that runs a guard function in front of the inner service.
///
/// The guard is evaluated before the tool handler; returning an error message
/// from it short-circuits the call. Prefer [`ToolBuilderWithHandler::guard`] or
/// [`Tool::with_guard`] over constructing this layer by hand.
///
/// # Example
///
/// ```rust
/// use tower_mcp::{ToolBuilder, ToolRequest, CallToolResult};
/// use schemars::JsonSchema;
/// use serde::Deserialize;
///
/// #[derive(Debug, Deserialize, JsonSchema)]
/// struct DeleteInput { id: String, confirm: bool }
///
/// let tool = ToolBuilder::new("delete")
///     .description("Delete a record")
///     .handler(|input: DeleteInput| async move {
///         Ok(CallToolResult::text(format!("deleted {}", input.id)))
///     })
///     .guard(|req: &ToolRequest| {
///         let confirm = req.args.get("confirm").and_then(|v| v.as_bool()).unwrap_or(false);
///         if !confirm {
///             return Err("Must set confirm=true to delete".to_string());
///         }
///         Ok(())
///     })
///     .build();
/// ```
#[derive(Clone)]
pub struct GuardLayer<G> {
    guard: G,
}
204
205impl<G> GuardLayer<G> {
206    /// Create a new guard layer from a closure.
207    ///
208    /// The closure receives a `&ToolRequest` and returns `Ok(())` to proceed
209    /// or `Err(String)` to reject with an error message.
210    pub fn new(guard: G) -> Self {
211        Self { guard }
212    }
213}
214
215impl<G, S> tower::Layer<S> for GuardLayer<G>
216where
217    G: Clone,
218{
219    type Service = GuardService<G, S>;
220
221    fn layer(&self, inner: S) -> Self::Service {
222        GuardService {
223            guard: self.guard.clone(),
224            inner,
225        }
226    }
227}
228
/// Service that evaluates a guard before delegating to the inner service.
///
/// Produced by [`GuardLayer`]; see that type for usage.
#[doc(hidden)]
#[derive(Clone)]
pub struct GuardService<G, S> {
    guard: G,
    inner: S,
}
238
239impl<G, S> Service<ToolRequest> for GuardService<G, S>
240where
241    G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
242    S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
243    S::Error: Into<Error> + Send,
244    S::Future: Send,
245{
246    type Response = CallToolResult;
247    type Error = Error;
248    type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
249
250    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
251        self.inner.poll_ready(cx).map_err(Into::into)
252    }
253
254    fn call(&mut self, req: ToolRequest) -> Self::Future {
255        match (self.guard)(&req) {
256            Ok(()) => {
257                let fut = self.inner.call(req);
258                Box::pin(async move { fut.await.map_err(Into::into) })
259            }
260            Err(msg) => Box::pin(async move { Err(Error::tool(msg)) }),
261        }
262    }
263}
264
/// Marker input type for tools that accept no parameters.
///
/// Prefer this over `()` for parameterless tools. The unit type `()` yields
/// `"type": "null"` in JSON Schema, which many MCP clients reject, whereas
/// `NoParams` yields `"type": "object"` with no required properties — the
/// correct schema for a tool without parameters.
///
/// # Example
///
/// ```rust
/// use tower_mcp::{ToolBuilder, CallToolResult, NoParams};
///
/// let tool = ToolBuilder::new("get_status")
///     .description("Get current status")
///     .handler(|_input: NoParams| async move {
///         Ok(CallToolResult::text("OK"))
///     })
///     .build();
/// ```
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct NoParams;
286
287impl<'de> serde::Deserialize<'de> for NoParams {
288    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
289    where
290        D: serde::Deserializer<'de>,
291    {
292        // Accept null, empty object, or any object (ignoring all fields)
293        struct NoParamsVisitor;
294
295        impl<'de> serde::de::Visitor<'de> for NoParamsVisitor {
296            type Value = NoParams;
297
298            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
299                formatter.write_str("null or an object")
300            }
301
302            fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
303            where
304                E: serde::de::Error,
305            {
306                Ok(NoParams)
307            }
308
309            fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
310            where
311                E: serde::de::Error,
312            {
313                Ok(NoParams)
314            }
315
316            fn visit_some<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
317            where
318                D: serde::Deserializer<'de>,
319            {
320                serde::Deserialize::deserialize(deserializer)
321            }
322
323            fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
324            where
325                A: serde::de::MapAccess<'de>,
326            {
327                // Drain the map, ignoring all entries
328                while map
329                    .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
330                    .is_some()
331                {}
332                Ok(NoParams)
333            }
334        }
335
336        deserializer.deserialize_any(NoParamsVisitor)
337    }
338}
339
340impl JsonSchema for NoParams {
341    fn schema_name() -> Cow<'static, str> {
342        Cow::Borrowed("NoParams")
343    }
344
345    fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
346        serde_json::json!({
347            "type": "object"
348        })
349        .try_into()
350        .expect("valid schema")
351    }
352}
353
354/// Validate a tool name according to MCP spec (SEP-986).
355///
356/// Tool names must be:
357/// - 1-64 characters long
358/// - Contain only ASCII alphanumeric characters, underscores, hyphens, dots,
359///   and forward slashes
360///
361/// Returns `Ok(())` if valid, `Err` with description if invalid.
362pub(crate) fn validate_tool_name(name: &str) -> Result<()> {
363    if name.is_empty() {
364        return Err(Error::tool("Tool name cannot be empty"));
365    }
366    if name.len() > 64 {
367        return Err(Error::tool(format!(
368            "Tool name '{}' exceeds maximum length of 64 characters (got {})",
369            name,
370            name.len()
371        )));
372    }
373    if let Some(invalid_char) = name
374        .chars()
375        .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.' && *c != '/')
376    {
377        return Err(Error::tool(format!(
378            "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, dot, and forward slash are allowed.",
379            name, invalid_char
380        )));
381    }
382    Ok(())
383}
384
385/// Ensure a JSON Schema value has `"type": "object"`.
386///
387/// The MCP spec requires tool input schemas to be JSON objects with a `"type"` field.
388/// Some types (e.g., `serde_json::Value`) generate schemas via schemars that lack
389/// the `"type"` field, which causes MCP clients to reject the tool.
390pub(crate) fn ensure_object_schema(mut schema: Value) -> Value {
391    if let Some(obj) = schema.as_object_mut()
392        && !obj.contains_key("type")
393    {
394        obj.insert("type".to_string(), serde_json::json!("object"));
395    }
396    schema
397}
398
/// Heap-allocated, pinned, `Send` future used as the return type of tool handlers.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
401
402/// Tool handler trait - the core abstraction for tool execution
403pub trait ToolHandler: Send + Sync {
404    /// Execute the tool with the given arguments
405    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;
406
407    /// Execute the tool with request context for progress/cancellation support
408    ///
409    /// The default implementation ignores the context and calls `call`.
410    /// Override this to receive progress/cancellation context.
411    fn call_with_context(
412        &self,
413        _ctx: RequestContext,
414        args: Value,
415    ) -> BoxFuture<'_, Result<CallToolResult>> {
416        self.call(args)
417    }
418
419    /// Returns true if this handler uses context (for optimization)
420    fn uses_context(&self) -> bool {
421        false
422    }
423
424    /// Get the tool's input schema
425    fn input_schema(&self) -> Value;
426}
427
/// Bridges a [`ToolHandler`] to a Tower `Service<ToolRequest>`.
///
/// Internal adapter: it lets handler-based tools take part in middleware
/// composition. The handler is held behind an `Arc` so the service can be
/// cloned without requiring `H: Clone`.
pub(crate) struct ToolHandlerService<H> {
    handler: Arc<H>,
}
435
436impl<H> ToolHandlerService<H> {
437    pub(crate) fn new(handler: H) -> Self {
438        Self {
439            handler: Arc::new(handler),
440        }
441    }
442}
443
444impl<H> Clone for ToolHandlerService<H> {
445    fn clone(&self) -> Self {
446        Self {
447            handler: self.handler.clone(),
448        }
449    }
450}
451
452impl<H> Service<ToolRequest> for ToolHandlerService<H>
453where
454    H: ToolHandler + 'static,
455{
456    type Response = CallToolResult;
457    type Error = Error;
458    type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
459
460    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
461        Poll::Ready(Ok(()))
462    }
463
464    fn call(&mut self, req: ToolRequest) -> Self::Future {
465        let handler = self.handler.clone();
466        Box::pin(async move { handler.call_with_context(req.ctx, req.args).await })
467    }
468}
469
470/// A complete tool definition with service-based execution.
471///
472/// Tools are implemented as Tower services internally, enabling middleware
473/// composition via the builder's `.layer()` method. The service is wrapped
474/// in [`ToolCatchError`] to convert any errors (from handlers or middleware)
475/// into `CallToolResult::error()` responses.
476pub struct Tool {
477    /// Tool name (must be 1-128 chars, alphanumeric/underscore/hyphen/dot only)
478    pub name: String,
479    /// Human-readable title for the tool
480    pub title: Option<String>,
481    /// Description of what the tool does
482    pub description: Option<String>,
483    /// JSON Schema for the tool's output (optional)
484    pub output_schema: Option<Value>,
485    /// Icons for the tool
486    pub icons: Option<Vec<ToolIcon>>,
487    /// Tool annotations (hints about behavior)
488    pub annotations: Option<ToolAnnotations>,
489    /// Task support mode for this tool
490    pub task_support: TaskSupportMode,
491    /// The boxed service that executes the tool
492    pub(crate) service: BoxToolService,
493    /// JSON Schema for the tool's input
494    pub(crate) input_schema: Value,
495}
496
497impl std::fmt::Debug for Tool {
498    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
499        f.debug_struct("Tool")
500            .field("name", &self.name)
501            .field("title", &self.title)
502            .field("description", &self.description)
503            .field("output_schema", &self.output_schema)
504            .field("icons", &self.icons)
505            .field("annotations", &self.annotations)
506            .field("task_support", &self.task_support)
507            .finish_non_exhaustive()
508    }
509}
510
511// SAFETY: BoxCloneService is Send + Sync (tower provides unsafe impl Sync),
512// and all other fields in Tool are Send + Sync.
513unsafe impl Send for Tool {}
514unsafe impl Sync for Tool {}
515
516impl Clone for Tool {
517    fn clone(&self) -> Self {
518        Self {
519            name: self.name.clone(),
520            title: self.title.clone(),
521            description: self.description.clone(),
522            output_schema: self.output_schema.clone(),
523            icons: self.icons.clone(),
524            annotations: self.annotations.clone(),
525            task_support: self.task_support,
526            service: self.service.clone(),
527            input_schema: self.input_schema.clone(),
528        }
529    }
530}
531
532impl Tool {
533    /// Create a new tool builder
534    pub fn builder(name: impl Into<String>) -> ToolBuilder {
535        ToolBuilder::new(name)
536    }
537
538    /// Get the tool definition for tools/list
539    pub fn definition(&self) -> ToolDefinition {
540        let execution = match self.task_support {
541            TaskSupportMode::Forbidden => None,
542            mode => Some(ToolExecution {
543                task_support: Some(mode),
544            }),
545        };
546        ToolDefinition {
547            name: self.name.clone(),
548            title: self.title.clone(),
549            description: self.description.clone(),
550            input_schema: self.input_schema.clone(),
551            output_schema: self.output_schema.clone(),
552            icons: self.icons.clone(),
553            annotations: self.annotations.clone(),
554            execution,
555            meta: None,
556        }
557    }
558
559    /// Call the tool without context
560    ///
561    /// Creates a dummy request context. For full context support, use
562    /// [`call_with_context`](Self::call_with_context).
563    pub fn call(&self, args: Value) -> BoxFuture<'static, CallToolResult> {
564        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
565        self.call_with_context(ctx, args)
566    }
567
568    /// Call the tool with request context
569    ///
570    /// The context provides progress reporting, cancellation support, and
571    /// access to client requests (for sampling, etc.).
572    ///
573    /// # Note
574    ///
575    /// This method returns `CallToolResult` directly (not `Result<CallToolResult>`).
576    /// Any errors from the handler or middleware are converted to
577    /// `CallToolResult::error()` with `is_error: true`.
578    pub fn call_with_context(
579        &self,
580        ctx: RequestContext,
581        args: Value,
582    ) -> BoxFuture<'static, CallToolResult> {
583        use tower::ServiceExt;
584        let service = self.service.clone();
585        Box::pin(async move {
586            // ServiceExt::oneshot properly handles poll_ready before call
587            // Service is Infallible, so unwrap is safe
588            service.oneshot(ToolRequest::new(ctx, args)).await.unwrap()
589        })
590    }
591
592    /// Apply a guard to this built tool.
593    ///
594    /// The guard runs before the handler and can short-circuit with an error.
595    /// This is useful for applying the same guard to multiple tools (per-group
596    /// pattern):
597    ///
598    /// ```rust
599    /// use tower_mcp::{ToolBuilder, CallToolResult};
600    /// use tower_mcp::tool::ToolRequest;
601    /// use schemars::JsonSchema;
602    /// use serde::Deserialize;
603    ///
604    /// #[derive(Debug, Deserialize, JsonSchema)]
605    /// struct Input { value: String }
606    ///
607    /// fn build_tool(name: &str) -> tower_mcp::tool::Tool {
608    ///     ToolBuilder::new(name)
609    ///         .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
610    ///         .build()
611    /// }
612    ///
613    /// let guard = |_req: &ToolRequest| -> Result<(), String> { Ok(()) };
614    ///
615    /// let tools: Vec<_> = vec![build_tool("a"), build_tool("b")]
616    ///     .into_iter()
617    ///     .map(|t| t.with_guard(guard.clone()))
618    ///     .collect();
619    /// ```
620    pub fn with_guard<G>(self, guard: G) -> Self
621    where
622        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
623    {
624        let guarded = GuardService {
625            guard,
626            inner: self.service,
627        };
628        let caught = ToolCatchError::new(guarded);
629        Tool {
630            service: BoxCloneService::new(caught),
631            ..self
632        }
633    }
634
635    /// Create a new tool with a prefixed name.
636    ///
637    /// This creates a copy of the tool with its name prefixed by the given
638    /// string and a dot separator. For example, if the tool is named "query"
639    /// and the prefix is "db", the new tool will be named "db.query".
640    ///
641    /// This is used internally by `McpRouter::nest()` to namespace tools.
642    ///
643    /// # Example
644    ///
645    /// ```rust
646    /// use tower_mcp::{ToolBuilder, CallToolResult};
647    /// use schemars::JsonSchema;
648    /// use serde::Deserialize;
649    ///
650    /// #[derive(Debug, Deserialize, JsonSchema)]
651    /// struct Input { value: String }
652    ///
653    /// let tool = ToolBuilder::new("query")
654    ///     .description("Query the database")
655    ///     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
656    ///     .build();
657    ///
658    /// let prefixed = tool.with_name_prefix("db");
659    /// assert_eq!(prefixed.name, "db.query");
660    /// ```
661    pub fn with_name_prefix(&self, prefix: &str) -> Self {
662        Self {
663            name: format!("{}.{}", prefix, self.name),
664            title: self.title.clone(),
665            description: self.description.clone(),
666            output_schema: self.output_schema.clone(),
667            icons: self.icons.clone(),
668            annotations: self.annotations.clone(),
669            task_support: self.task_support,
670            service: self.service.clone(),
671            input_schema: self.input_schema.clone(),
672        }
673    }
674
675    /// Create a tool from a handler (internal helper)
676    #[allow(clippy::too_many_arguments)]
677    fn from_handler<H: ToolHandler + 'static>(
678        name: String,
679        title: Option<String>,
680        description: Option<String>,
681        output_schema: Option<Value>,
682        icons: Option<Vec<ToolIcon>>,
683        annotations: Option<ToolAnnotations>,
684        task_support: TaskSupportMode,
685        handler: H,
686    ) -> Self {
687        let input_schema = ensure_object_schema(handler.input_schema());
688        let handler_service = ToolHandlerService::new(handler);
689        let catch_error = ToolCatchError::new(handler_service);
690        let service = BoxCloneService::new(catch_error);
691
692        Self {
693            name,
694            title,
695            description,
696            output_schema,
697            icons,
698            annotations,
699            task_support,
700            service,
701            input_schema,
702        }
703    }
704}
705
706// =============================================================================
707// Builder API
708// =============================================================================
709
710/// Builder for creating tools with a fluent API
711///
712/// # Example
713///
714/// ```rust
715/// use tower_mcp::{ToolBuilder, CallToolResult};
716/// use schemars::JsonSchema;
717/// use serde::Deserialize;
718///
719/// #[derive(Debug, Deserialize, JsonSchema)]
720/// struct GreetInput {
721///     name: String,
722/// }
723///
724/// let tool = ToolBuilder::new("greet")
725///     .description("Greet someone by name")
726///     .handler(|input: GreetInput| async move {
727///         Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
728///     })
729///     .build();
730///
731/// assert_eq!(tool.name, "greet");
732/// ```
733pub struct ToolBuilder {
734    name: String,
735    title: Option<String>,
736    description: Option<String>,
737    output_schema: Option<Value>,
738    icons: Option<Vec<ToolIcon>>,
739    annotations: Option<ToolAnnotations>,
740    task_support: TaskSupportMode,
741}
742
743impl ToolBuilder {
744    /// Create a new tool builder with the given name.
745    ///
746    /// Tool names must be 1-64 characters and contain only ASCII alphanumeric
747    /// characters, underscores, hyphens, dots, and forward slashes (per
748    /// [SEP-986](https://github.com/modelcontextprotocol/specification/issues/986)).
749    ///
750    /// Use [`try_new`](Self::try_new) if the name comes from runtime input.
751    ///
752    /// # Panics
753    ///
754    /// Panics if `name` is empty, exceeds 64 characters, or contains
755    /// characters other than ASCII alphanumerics, `_`, `-`, `.`, and `/`.
756    pub fn new(name: impl Into<String>) -> Self {
757        let name = name.into();
758        if let Err(e) = validate_tool_name(&name) {
759            panic!("{e}");
760        }
761        Self {
762            name,
763            title: None,
764            description: None,
765            output_schema: None,
766            icons: None,
767            annotations: None,
768            task_support: TaskSupportMode::default(),
769        }
770    }
771
772    /// Create a new tool builder, returning an error if the name is invalid.
773    ///
774    /// This is the fallible alternative to [`new`](Self::new) for cases where
775    /// the tool name comes from runtime input (e.g., user configuration or
776    /// database).
777    pub fn try_new(name: impl Into<String>) -> Result<Self> {
778        let name = name.into();
779        validate_tool_name(&name)?;
780        Ok(Self {
781            name,
782            title: None,
783            description: None,
784            output_schema: None,
785            icons: None,
786            annotations: None,
787            task_support: TaskSupportMode::default(),
788        })
789    }
790
791    /// Set a human-readable title for the tool.
792    ///
793    /// The title is displayed by MCP clients (e.g., Claude Code's `/mcp` tool list)
794    /// as a friendly label instead of the raw tool name. For example, a tool named
795    /// `search_crates` with title `"Search Crates"` will display the title in UIs
796    /// that support it.
797    ///
798    /// ```
799    /// # use tower_mcp::ToolBuilder;
800    /// let tool = ToolBuilder::new("search_crates")
801    ///     .title("Search Crates")
802    ///     .description("Search for Rust crates on crates.io")
803    ///     .handler(|()| async { Ok(tower_mcp::CallToolResult::text("results")) })
804    ///     .build();
805    /// ```
806    pub fn title(mut self, title: impl Into<String>) -> Self {
807        self.title = Some(title.into());
808        self
809    }
810
811    /// Set the output schema (JSON Schema for structured output)
812    pub fn output_schema(mut self, schema: Value) -> Self {
813        self.output_schema = Some(schema);
814        self
815    }
816
817    /// Add an icon for the tool
818    pub fn icon(mut self, src: impl Into<String>) -> Self {
819        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
820            src: src.into(),
821            mime_type: None,
822            sizes: None,
823            theme: None,
824        });
825        self
826    }
827
828    /// Add an icon with metadata
829    pub fn icon_with_meta(
830        mut self,
831        src: impl Into<String>,
832        mime_type: Option<String>,
833        sizes: Option<Vec<String>>,
834    ) -> Self {
835        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
836            src: src.into(),
837            mime_type,
838            sizes,
839            theme: None,
840        });
841        self
842    }
843
844    /// Set the tool description
845    pub fn description(mut self, description: impl Into<String>) -> Self {
846        self.description = Some(description.into());
847        self
848    }
849
850    /// Mark the tool as read-only (does not modify state)
851    pub fn read_only(mut self) -> Self {
852        self.annotations
853            .get_or_insert_with(ToolAnnotations::default)
854            .read_only_hint = true;
855        self
856    }
857
858    /// Mark the tool as non-destructive
859    pub fn non_destructive(mut self) -> Self {
860        self.annotations
861            .get_or_insert_with(ToolAnnotations::default)
862            .destructive_hint = false;
863        self
864    }
865
866    /// Mark the tool as destructive (may perform irreversible operations)
867    pub fn destructive(mut self) -> Self {
868        self.annotations
869            .get_or_insert_with(ToolAnnotations::default)
870            .destructive_hint = true;
871        self
872    }
873
874    /// Mark the tool as idempotent (same args = same effect)
875    pub fn idempotent(mut self) -> Self {
876        self.annotations
877            .get_or_insert_with(ToolAnnotations::default)
878            .idempotent_hint = true;
879        self
880    }
881
882    /// Mark the tool as read-only, idempotent, and non-destructive.
883    ///
884    /// This is a convenience method for safe, side-effect-free tools.
885    /// For finer control, use `.read_only()`, `.idempotent()`, and
886    /// `.non_destructive()` individually.
887    pub fn read_only_safe(mut self) -> Self {
888        let ann = self
889            .annotations
890            .get_or_insert_with(ToolAnnotations::default);
891        ann.read_only_hint = true;
892        ann.idempotent_hint = true;
893        ann.destructive_hint = false;
894        self
895    }
896
897    /// Set tool annotations directly
898    pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
899        self.annotations = Some(annotations);
900        self
901    }
902
903    /// Set the task support mode for this tool
904    pub fn task_support(mut self, mode: TaskSupportMode) -> Self {
905        self.task_support = mode;
906        self
907    }
908
909    /// Create a tool that takes no parameters.
910    ///
911    /// This is a convenience method for tools that don't require any input.
912    /// It generates the correct `{"type": "object"}` schema that MCP clients expect.
913    ///
914    /// # Example
915    ///
916    /// ```rust
917    /// use tower_mcp::{ToolBuilder, CallToolResult};
918    ///
919    /// let tool = ToolBuilder::new("get_status")
920    ///     .description("Get current status")
921    ///     .no_params_handler(|| async {
922    ///         Ok(CallToolResult::text("OK"))
923    ///     })
924    ///     .build();
925    /// ```
926    pub fn no_params_handler<F, Fut>(self, handler: F) -> ToolBuilderWithNoParamsHandler<F>
927    where
928        F: Fn() -> Fut + Send + Sync + 'static,
929        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
930    {
931        ToolBuilderWithNoParamsHandler {
932            name: self.name,
933            title: self.title,
934            description: self.description,
935            output_schema: self.output_schema,
936            icons: self.icons,
937            annotations: self.annotations,
938            task_support: self.task_support,
939            handler,
940        }
941    }
942
943    /// Specify input type and handler.
944    ///
945    /// The input type must implement `JsonSchema` and `DeserializeOwned`.
946    /// The handler receives the deserialized input and returns a `CallToolResult`.
947    ///
948    /// # State Sharing
949    ///
950    /// To share state across tool calls (e.g., database connections, API clients),
951    /// wrap your state in an `Arc` and clone it into the async block:
952    ///
953    /// ```rust
954    /// use std::sync::Arc;
955    /// use tower_mcp::{ToolBuilder, CallToolResult};
956    /// use schemars::JsonSchema;
957    /// use serde::Deserialize;
958    ///
959    /// struct AppState {
960    ///     api_key: String,
961    /// }
962    ///
963    /// #[derive(Debug, Deserialize, JsonSchema)]
964    /// struct MyInput {
965    ///     query: String,
966    /// }
967    ///
968    /// let state = Arc::new(AppState { api_key: "secret".to_string() });
969    ///
970    /// let tool = ToolBuilder::new("my_tool")
971    ///     .description("A tool that uses shared state")
972    ///     .handler(move |input: MyInput| {
973    ///         let state = state.clone(); // Clone Arc for the async block
974    ///         async move {
975    ///             // Use state.api_key here...
976    ///             Ok(CallToolResult::text(format!("Query: {}", input.query)))
977    ///         }
978    ///     })
979    ///     .build();
980    /// ```
981    ///
982    /// The `move` keyword on the closure captures the `Arc<AppState>`, and
983    /// cloning it inside the closure body allows each async invocation to
984    /// have its own reference to the shared state.
985    pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
986    where
987        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
988        F: Fn(I) -> Fut + Send + Sync + 'static,
989        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
990    {
991        ToolBuilderWithHandler {
992            name: self.name,
993            title: self.title,
994            description: self.description,
995            output_schema: self.output_schema,
996            icons: self.icons,
997            annotations: self.annotations,
998            task_support: self.task_support,
999            handler,
1000            _phantom: std::marker::PhantomData,
1001        }
1002    }
1003
1004    /// Create a tool using the extractor pattern.
1005    ///
1006    /// This method provides an axum-inspired way to define handlers where state,
1007    /// context, and input are extracted declaratively from function parameters.
1008    /// This reduces the combinatorial explosion of handler variants like
1009    /// `handler_with_state`, `handler_with_context`, etc.
1010    ///
1011    /// # Schema Auto-Detection
1012    ///
1013    /// When a [`Json<T>`](crate::extract::Json) extractor is used, the proper JSON
1014    /// schema is automatically generated from `T`'s `JsonSchema` implementation.
1015    /// No turbofish is needed -- the schema type is inferred from the closure
1016    /// parameters.
1017    ///
1018    /// # Extractors
1019    ///
1020    /// Built-in extractors available in [`crate::extract`]:
1021    /// - [`Json<T>`](crate::extract::Json) - Deserialize JSON arguments to type `T`
1022    /// - [`State<T>`](crate::extract::State) - Extract cloned state
1023    /// - [`Extension<T>`](crate::extract::Extension) - Extract router-level state
1024    /// - [`Context`](crate::extract::Context) - Extract request context
1025    /// - [`RawArgs`](crate::extract::RawArgs) - Extract raw JSON arguments
1026    ///
1027    /// # Per-Tool Middleware
1028    ///
1029    /// The returned builder supports `.layer()` to apply Tower middleware:
1030    ///
1031    /// ```rust
1032    /// use std::sync::Arc;
1033    /// use std::time::Duration;
1034    /// use tower::timeout::TimeoutLayer;
1035    /// use tower_mcp::{ToolBuilder, CallToolResult};
1036    /// use tower_mcp::extract::{Json, State};
1037    /// use schemars::JsonSchema;
1038    /// use serde::Deserialize;
1039    ///
1040    /// #[derive(Clone)]
1041    /// struct Database { url: String }
1042    ///
1043    /// #[derive(Debug, Deserialize, JsonSchema)]
1044    /// struct QueryInput { query: String }
1045    ///
1046    /// let db = Arc::new(Database { url: "postgres://...".to_string() });
1047    ///
1048    /// let tool = ToolBuilder::new("search")
1049    ///     .description("Search the database")
1050    ///     .extractor_handler(db, |
1051    ///         State(db): State<Arc<Database>>,
1052    ///         Json(input): Json<QueryInput>,
1053    ///     | async move {
1054    ///         Ok(CallToolResult::text(format!("Searched {} with: {}", db.url, input.query)))
1055    ///     })
1056    ///     .layer(TimeoutLayer::new(Duration::from_secs(30)))
1057    ///     .build();
1058    /// ```
1059    ///
1060    /// # Example
1061    ///
1062    /// ```rust
1063    /// use std::sync::Arc;
1064    /// use tower_mcp::{ToolBuilder, CallToolResult};
1065    /// use tower_mcp::extract::{Json, State, Context};
1066    /// use schemars::JsonSchema;
1067    /// use serde::Deserialize;
1068    ///
1069    /// #[derive(Clone)]
1070    /// struct Database { url: String }
1071    ///
1072    /// #[derive(Debug, Deserialize, JsonSchema)]
1073    /// struct QueryInput { query: String }
1074    ///
1075    /// let db = Arc::new(Database { url: "postgres://...".to_string() });
1076    ///
1077    /// let tool = ToolBuilder::new("search")
1078    ///     .description("Search the database")
1079    ///     .extractor_handler(db, |
1080    ///         State(db): State<Arc<Database>>,
1081    ///         ctx: Context,
1082    ///         Json(input): Json<QueryInput>,
1083    ///     | async move {
1084    ///         if ctx.is_cancelled() {
1085    ///             return Ok(CallToolResult::error("Cancelled"));
1086    ///         }
1087    ///         ctx.report_progress(0.5, Some(1.0), Some("Searching...")).await;
1088    ///         Ok(CallToolResult::text(format!("Searched {} with: {}", db.url, input.query)))
1089    ///     })
1090    ///     .build();
1091    /// ```
1092    ///
1093    /// # Type Inference
1094    ///
1095    /// The compiler infers extractor types from the function signature. Make sure
1096    /// to annotate the extractor types explicitly in the closure parameters.
1097    pub fn extractor_handler<S, F, T>(
1098        self,
1099        state: S,
1100        handler: F,
1101    ) -> crate::extract::ToolBuilderWithExtractor<S, F, T>
1102    where
1103        S: Clone + Send + Sync + 'static,
1104        F: crate::extract::ExtractorHandler<S, T> + Clone,
1105        T: Send + Sync + 'static,
1106    {
1107        crate::extract::ToolBuilderWithExtractor {
1108            name: self.name,
1109            title: self.title,
1110            description: self.description,
1111            output_schema: self.output_schema,
1112            icons: self.icons,
1113            annotations: self.annotations,
1114            task_support: self.task_support,
1115            state,
1116            handler,
1117            input_schema: F::input_schema(),
1118            _phantom: std::marker::PhantomData,
1119        }
1120    }
1121
1122    /// Create a tool using the extractor pattern with typed JSON input.
1123    ///
1124    /// # Deprecated
1125    ///
1126    /// Use [`extractor_handler`](Self::extractor_handler) instead. It auto-detects
1127    /// the JSON schema from `Json<T>` extractors, producing identical results
1128    /// without requiring a turbofish.
1129    ///
1130    /// ```rust
1131    /// # use std::sync::Arc;
1132    /// # use tower_mcp::{ToolBuilder, CallToolResult};
1133    /// # use tower_mcp::extract::{Json, State};
1134    /// # use schemars::JsonSchema;
1135    /// # use serde::Deserialize;
1136    /// # #[derive(Clone)]
1137    /// # struct AppState { prefix: String }
1138    /// # #[derive(Debug, Deserialize, JsonSchema)]
1139    /// # struct GreetInput { name: String }
1140    /// # let state = Arc::new(AppState { prefix: "Hello".to_string() });
1141    /// // Before (deprecated):
1142    /// // .extractor_handler_typed::<_, _, _, GreetInput>(state, handler)
1143    ///
1144    /// // After:
1145    /// let tool = ToolBuilder::new("greet")
1146    ///     .description("Greet someone")
1147    ///     .extractor_handler(state, |
1148    ///         State(app): State<Arc<AppState>>,
1149    ///         Json(input): Json<GreetInput>,
1150    ///     | async move {
1151    ///         Ok(CallToolResult::text(format!("{}, {}!", app.prefix, input.name)))
1152    ///     })
1153    ///     .build();
1154    /// ```
1155    #[deprecated(
1156        since = "0.8.0",
1157        note = "Use `extractor_handler` instead -- it auto-detects JSON schema from `Json<T>` extractors without requiring a turbofish"
1158    )]
1159    #[allow(deprecated)]
1160    pub fn extractor_handler_typed<S, F, T, I>(
1161        self,
1162        state: S,
1163        handler: F,
1164    ) -> crate::extract::ToolBuilderWithTypedExtractor<S, F, T, I>
1165    where
1166        S: Clone + Send + Sync + 'static,
1167        F: crate::extract::TypedExtractorHandler<S, T, I> + Clone,
1168        T: Send + Sync + 'static,
1169        I: schemars::JsonSchema + Send + Sync + 'static,
1170    {
1171        crate::extract::ToolBuilderWithTypedExtractor {
1172            name: self.name,
1173            title: self.title,
1174            description: self.description,
1175            output_schema: self.output_schema,
1176            icons: self.icons,
1177            annotations: self.annotations,
1178            task_support: self.task_support,
1179            state,
1180            handler,
1181            _phantom: std::marker::PhantomData,
1182        }
1183    }
1184}
1185
1186/// Handler for tools with no parameters.
1187///
1188/// Used internally by [`ToolBuilder::no_params_handler`].
1189struct NoParamsTypedHandler<F> {
1190    handler: F,
1191}
1192
1193impl<F, Fut> ToolHandler for NoParamsTypedHandler<F>
1194where
1195    F: Fn() -> Fut + Send + Sync + 'static,
1196    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1197{
1198    fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1199        Box::pin(async move { (self.handler)().await })
1200    }
1201
1202    fn input_schema(&self) -> Value {
1203        serde_json::json!({ "type": "object" })
1204    }
1205}
1206
/// Builder state after handler is specified
///
/// Produced by `ToolBuilder::handler`; carries the builder's metadata plus
/// the typed handler until `build`, `layer`, or `guard` is called.
#[doc(hidden)]
pub struct ToolBuilderWithHandler<I, F> {
    // Tool name, moved over from `ToolBuilder`.
    name: String,
    // Optional title, moved over from `ToolBuilder`.
    title: Option<String>,
    // Optional description, moved over from `ToolBuilder`.
    description: Option<String>,
    // Optional output schema as a JSON value.
    output_schema: Option<Value>,
    // Optional icons for the tool.
    icons: Option<Vec<ToolIcon>>,
    // Optional annotations (hints) for the tool.
    annotations: Option<ToolAnnotations>,
    // Task support mode for the tool.
    task_support: TaskSupportMode,
    // Handler closure passed to `ToolBuilder::handler`.
    handler: F,
    // Zero-sized marker recording the input type `I`; no `I` value is stored.
    _phantom: std::marker::PhantomData<I>,
}
1220
/// Builder state for tools with no parameters.
///
/// Created by [`ToolBuilder::no_params_handler`]. Carries the builder's
/// metadata plus the zero-argument handler until `build`, `layer`, or
/// `guard` is called.
#[doc(hidden)]
pub struct ToolBuilderWithNoParamsHandler<F> {
    // Tool name, moved over from `ToolBuilder`.
    name: String,
    // Optional title, moved over from `ToolBuilder`.
    title: Option<String>,
    // Optional description, moved over from `ToolBuilder`.
    description: Option<String>,
    // Optional output schema as a JSON value.
    output_schema: Option<Value>,
    // Optional icons for the tool.
    icons: Option<Vec<ToolIcon>>,
    // Optional annotations (hints) for the tool.
    annotations: Option<ToolAnnotations>,
    // Task support mode for the tool.
    task_support: TaskSupportMode,
    // Zero-argument handler closure passed to `no_params_handler`.
    handler: F,
}
1235
1236impl<F, Fut> ToolBuilderWithNoParamsHandler<F>
1237where
1238    F: Fn() -> Fut + Send + Sync + 'static,
1239    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1240{
1241    /// Build the tool.
1242    pub fn build(self) -> Tool {
1243        Tool::from_handler(
1244            self.name,
1245            self.title,
1246            self.description,
1247            self.output_schema,
1248            self.icons,
1249            self.annotations,
1250            self.task_support,
1251            NoParamsTypedHandler {
1252                handler: self.handler,
1253            },
1254        )
1255    }
1256
1257    /// Apply a Tower layer (middleware) to this tool.
1258    ///
1259    /// See [`ToolBuilderWithHandler::layer`] for details.
1260    pub fn layer<L>(self, layer: L) -> ToolBuilderWithNoParamsHandlerLayer<F, L> {
1261        ToolBuilderWithNoParamsHandlerLayer {
1262            name: self.name,
1263            title: self.title,
1264            description: self.description,
1265            output_schema: self.output_schema,
1266            icons: self.icons,
1267            annotations: self.annotations,
1268            task_support: self.task_support,
1269            handler: self.handler,
1270            layer,
1271        }
1272    }
1273
1274    /// Apply a guard to this tool.
1275    ///
1276    /// See [`ToolBuilderWithHandler::guard`] for details.
1277    pub fn guard<G>(self, guard: G) -> ToolBuilderWithNoParamsHandlerLayer<F, GuardLayer<G>>
1278    where
1279        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1280    {
1281        self.layer(GuardLayer::new(guard))
1282    }
1283}
1284
/// Builder state after a layer has been applied to a no-params handler.
///
/// Holds the same metadata as `ToolBuilderWithNoParamsHandler` plus the
/// layer (or stack of layers) that `build` wraps around the handler service.
#[doc(hidden)]
pub struct ToolBuilderWithNoParamsHandlerLayer<F, L> {
    // Tool name, carried over from the previous builder stage.
    name: String,
    // Optional title, carried over from the previous builder stage.
    title: Option<String>,
    // Optional description, carried over from the previous builder stage.
    description: Option<String>,
    // Optional output schema as a JSON value.
    output_schema: Option<Value>,
    // Optional icons for the tool.
    icons: Option<Vec<ToolIcon>>,
    // Optional annotations (hints) for the tool.
    annotations: Option<ToolAnnotations>,
    // Task support mode for the tool.
    task_support: TaskSupportMode,
    // Zero-argument handler closure.
    handler: F,
    // Tower layer (or `Stack` of layers) applied to the handler service in `build`.
    layer: L,
}
1298
1299#[allow(private_bounds)]
1300impl<F, Fut, L> ToolBuilderWithNoParamsHandlerLayer<F, L>
1301where
1302    F: Fn() -> Fut + Send + Sync + 'static,
1303    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1304    L: tower::Layer<ToolHandlerService<NoParamsTypedHandler<F>>> + Clone + Send + Sync + 'static,
1305    L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1306    <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
1307    <L::Service as Service<ToolRequest>>::Future: Send,
1308{
1309    /// Build the tool with the applied layer(s).
1310    pub fn build(self) -> Tool {
1311        let input_schema = serde_json::json!({ "type": "object" });
1312
1313        let handler_service = ToolHandlerService::new(NoParamsTypedHandler {
1314            handler: self.handler,
1315        });
1316        let layered = self.layer.layer(handler_service);
1317        let catch_error = ToolCatchError::new(layered);
1318        let service = BoxCloneService::new(catch_error);
1319
1320        Tool {
1321            name: self.name,
1322            title: self.title,
1323            description: self.description,
1324            output_schema: self.output_schema,
1325            icons: self.icons,
1326            annotations: self.annotations,
1327            task_support: self.task_support,
1328            service,
1329            input_schema,
1330        }
1331    }
1332
1333    /// Apply an additional Tower layer (middleware).
1334    pub fn layer<L2>(
1335        self,
1336        layer: L2,
1337    ) -> ToolBuilderWithNoParamsHandlerLayer<F, tower::layer::util::Stack<L2, L>> {
1338        ToolBuilderWithNoParamsHandlerLayer {
1339            name: self.name,
1340            title: self.title,
1341            description: self.description,
1342            output_schema: self.output_schema,
1343            icons: self.icons,
1344            annotations: self.annotations,
1345            task_support: self.task_support,
1346            handler: self.handler,
1347            layer: tower::layer::util::Stack::new(layer, self.layer),
1348        }
1349    }
1350
1351    /// Apply a guard to this tool.
1352    ///
1353    /// See [`ToolBuilderWithHandler::guard`] for details.
1354    pub fn guard<G>(
1355        self,
1356        guard: G,
1357    ) -> ToolBuilderWithNoParamsHandlerLayer<F, tower::layer::util::Stack<GuardLayer<G>, L>>
1358    where
1359        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1360    {
1361        self.layer(GuardLayer::new(guard))
1362    }
1363}
1364
1365impl<I, F, Fut> ToolBuilderWithHandler<I, F>
1366where
1367    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1368    F: Fn(I) -> Fut + Send + Sync + 'static,
1369    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1370{
1371    /// Build the tool.
1372    pub fn build(self) -> Tool {
1373        Tool::from_handler(
1374            self.name,
1375            self.title,
1376            self.description,
1377            self.output_schema,
1378            self.icons,
1379            self.annotations,
1380            self.task_support,
1381            TypedHandler {
1382                handler: self.handler,
1383                _phantom: std::marker::PhantomData,
1384            },
1385        )
1386    }
1387
1388    /// Apply a Tower layer (middleware) to this tool.
1389    ///
1390    /// The layer wraps the tool's handler service, enabling functionality like
1391    /// timeouts, rate limiting, and metrics collection at the per-tool level.
1392    ///
1393    /// # Example
1394    ///
1395    /// ```rust
1396    /// use std::time::Duration;
1397    /// use tower::timeout::TimeoutLayer;
1398    /// use tower_mcp::{ToolBuilder, CallToolResult};
1399    /// use schemars::JsonSchema;
1400    /// use serde::Deserialize;
1401    ///
1402    /// #[derive(Debug, Deserialize, JsonSchema)]
1403    /// struct Input { query: String }
1404    ///
1405    /// let tool = ToolBuilder::new("search")
1406    ///     .description("Search with timeout")
1407    ///     .handler(|input: Input| async move {
1408    ///         Ok(CallToolResult::text("result"))
1409    ///     })
1410    ///     .layer(TimeoutLayer::new(Duration::from_secs(30)))
1411    ///     .build();
1412    /// ```
1413    pub fn layer<L>(self, layer: L) -> ToolBuilderWithLayer<I, F, L> {
1414        ToolBuilderWithLayer {
1415            name: self.name,
1416            title: self.title,
1417            description: self.description,
1418            output_schema: self.output_schema,
1419            icons: self.icons,
1420            annotations: self.annotations,
1421            task_support: self.task_support,
1422            handler: self.handler,
1423            layer,
1424            _phantom: std::marker::PhantomData,
1425        }
1426    }
1427
1428    /// Apply a guard to this tool.
1429    ///
1430    /// The guard runs before the handler and can short-circuit with an error
1431    /// message. This is syntactic sugar for `.layer(GuardLayer::new(f))`.
1432    ///
1433    /// See [`GuardLayer`] for a full example.
1434    pub fn guard<G>(self, guard: G) -> ToolBuilderWithLayer<I, F, GuardLayer<G>>
1435    where
1436        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1437    {
1438        self.layer(GuardLayer::new(guard))
1439    }
1440}
1441
/// Builder state after a layer has been applied to the handler.
///
/// This builder allows chaining additional layers and building the final tool.
#[doc(hidden)]
pub struct ToolBuilderWithLayer<I, F, L> {
    // Tool name, carried over from the previous builder stage.
    name: String,
    // Optional title, carried over from the previous builder stage.
    title: Option<String>,
    // Optional description, carried over from the previous builder stage.
    description: Option<String>,
    // Optional output schema as a JSON value.
    output_schema: Option<Value>,
    // Optional icons for the tool.
    icons: Option<Vec<ToolIcon>>,
    // Optional annotations (hints) for the tool.
    annotations: Option<ToolAnnotations>,
    // Task support mode for the tool.
    task_support: TaskSupportMode,
    // Typed handler closure.
    handler: F,
    // Tower layer (or `Stack` of layers) applied to the handler service in `build`.
    layer: L,
    // Zero-sized marker recording the input type `I`; no `I` value is stored.
    _phantom: std::marker::PhantomData<I>,
}
1458
1459// Allow private_bounds because these internal types (ToolHandlerService, TypedHandler, etc.)
1460// are implementation details that users don't interact with directly.
1461#[allow(private_bounds)]
1462impl<I, F, Fut, L> ToolBuilderWithLayer<I, F, L>
1463where
1464    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1465    F: Fn(I) -> Fut + Send + Sync + 'static,
1466    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1467    L: tower::Layer<ToolHandlerService<TypedHandler<I, F>>> + Clone + Send + Sync + 'static,
1468    L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1469    <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
1470    <L::Service as Service<ToolRequest>>::Future: Send,
1471{
1472    /// Build the tool with the applied layer(s).
1473    pub fn build(self) -> Tool {
1474        let input_schema = schemars::schema_for!(I);
1475        let input_schema = serde_json::to_value(input_schema)
1476            .unwrap_or_else(|_| serde_json::json!({ "type": "object" }));
1477        let input_schema = ensure_object_schema(input_schema);
1478
1479        let handler_service = ToolHandlerService::new(TypedHandler {
1480            handler: self.handler,
1481            _phantom: std::marker::PhantomData,
1482        });
1483        let layered = self.layer.layer(handler_service);
1484        let catch_error = ToolCatchError::new(layered);
1485        let service = BoxCloneService::new(catch_error);
1486
1487        Tool {
1488            name: self.name,
1489            title: self.title,
1490            description: self.description,
1491            output_schema: self.output_schema,
1492            icons: self.icons,
1493            annotations: self.annotations,
1494            task_support: self.task_support,
1495            service,
1496            input_schema,
1497        }
1498    }
1499
1500    /// Apply an additional Tower layer (middleware).
1501    ///
1502    /// Layers are applied in order, with earlier layers wrapping later ones.
1503    /// This means the first layer added is the outermost middleware.
1504    pub fn layer<L2>(
1505        self,
1506        layer: L2,
1507    ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<L2, L>> {
1508        ToolBuilderWithLayer {
1509            name: self.name,
1510            title: self.title,
1511            description: self.description,
1512            output_schema: self.output_schema,
1513            icons: self.icons,
1514            annotations: self.annotations,
1515            task_support: self.task_support,
1516            handler: self.handler,
1517            layer: tower::layer::util::Stack::new(layer, self.layer),
1518            _phantom: std::marker::PhantomData,
1519        }
1520    }
1521
1522    /// Apply a guard to this tool.
1523    ///
1524    /// See [`ToolBuilderWithHandler::guard`] for details.
1525    pub fn guard<G>(
1526        self,
1527        guard: G,
1528    ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<GuardLayer<G>, L>>
1529    where
1530        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1531    {
1532        self.layer(GuardLayer::new(guard))
1533    }
1534}
1535
1536// =============================================================================
1537// Handler implementations
1538// =============================================================================
1539
1540/// Handler that deserializes input to a specific type
1541struct TypedHandler<I, F> {
1542    handler: F,
1543    _phantom: std::marker::PhantomData<I>,
1544}
1545
1546impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
1547where
1548    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1549    F: Fn(I) -> Fut + Send + Sync + 'static,
1550    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1551{
1552    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1553        Box::pin(async move {
1554            let input: I = match serde_json::from_value(args) {
1555                Ok(input) => input,
1556                Err(e) => return Ok(CallToolResult::error(format!("Invalid input: {e}"))),
1557            };
1558            (self.handler)(input).await
1559        })
1560    }
1561
1562    fn input_schema(&self) -> Value {
1563        let schema = schemars::schema_for!(I);
1564        let schema = serde_json::to_value(schema).unwrap_or_else(|_| {
1565            serde_json::json!({
1566                "type": "object"
1567            })
1568        });
1569        ensure_object_schema(schema)
1570    }
1571}
1572
1573// =============================================================================
1574// Trait-based tool definition
1575// =============================================================================
1576
1577/// Trait for defining tools with full control
1578///
1579/// Implement this trait when you need more control than the builder provides,
1580/// or when you want to define tools as standalone types.
1581///
1582/// # Example
1583///
1584/// ```rust
1585/// use tower_mcp::tool::McpTool;
1586/// use tower_mcp::error::Result;
1587/// use schemars::JsonSchema;
1588/// use serde::{Deserialize, Serialize};
1589///
1590/// #[derive(Debug, Deserialize, JsonSchema)]
1591/// struct AddInput {
1592///     a: i64,
1593///     b: i64,
1594/// }
1595///
1596/// struct AddTool;
1597///
1598/// impl McpTool for AddTool {
1599///     const NAME: &'static str = "add";
1600///     const DESCRIPTION: &'static str = "Add two numbers";
1601///
1602///     type Input = AddInput;
1603///     type Output = i64;
1604///
1605///     async fn call(&self, input: Self::Input) -> Result<Self::Output> {
1606///         Ok(input.a + input.b)
1607///     }
1608/// }
1609///
1610/// let tool = AddTool.into_tool();
1611/// assert_eq!(tool.name, "add");
1612/// ```
1613pub trait McpTool: Send + Sync + 'static {
1614    /// The tool name (must be unique within the router).
1615    const NAME: &'static str;
1616    /// A human-readable description of the tool.
1617    const DESCRIPTION: &'static str;
1618
1619    /// The input type, deserialized from tool call arguments.
1620    type Input: JsonSchema + DeserializeOwned + Send;
1621    /// The output type, serialized into the tool call result.
1622    type Output: Serialize + Send;
1623
1624    /// Execute the tool with the given input.
1625    fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
1626
1627    /// Optional annotations for the tool
1628    fn annotations(&self) -> Option<ToolAnnotations> {
1629        None
1630    }
1631
1632    /// Convert to a [`Tool`] instance.
1633    ///
1634    /// # Panics
1635    ///
1636    /// Panics if [`NAME`](Self::NAME) is not a valid tool name. Since `NAME`
1637    /// is a `&'static str`, invalid names are caught immediately during
1638    /// development.
1639    fn into_tool(self) -> Tool
1640    where
1641        Self: Sized,
1642    {
1643        if let Err(e) = validate_tool_name(Self::NAME) {
1644            panic!("{e}");
1645        }
1646        let annotations = self.annotations();
1647        let tool = Arc::new(self);
1648        Tool::from_handler(
1649            Self::NAME.to_string(),
1650            None,
1651            Some(Self::DESCRIPTION.to_string()),
1652            None,
1653            None,
1654            annotations,
1655            TaskSupportMode::default(),
1656            McpToolHandler { tool },
1657        )
1658    }
1659}
1660
1661/// Wrapper to make McpTool implement ToolHandler
1662struct McpToolHandler<T: McpTool> {
1663    tool: Arc<T>,
1664}
1665
1666impl<T: McpTool> ToolHandler for McpToolHandler<T> {
1667    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1668        let tool = self.tool.clone();
1669        Box::pin(async move {
1670            let input: T::Input = match serde_json::from_value(args) {
1671                Ok(input) => input,
1672                Err(e) => return Ok(CallToolResult::error(format!("Invalid input: {e}"))),
1673            };
1674            let output = tool.call(input).await?;
1675            let value = serde_json::to_value(output).tool_context("Failed to serialize output")?;
1676            Ok(CallToolResult::json(value))
1677        })
1678    }
1679
1680    fn input_schema(&self) -> Value {
1681        let schema = schemars::schema_for!(T::Input);
1682        let schema = serde_json::to_value(schema).unwrap_or_else(|_| {
1683            serde_json::json!({
1684                "type": "object"
1685            })
1686        });
1687        ensure_object_schema(schema)
1688    }
1689}
1690
1691#[cfg(test)]
1692mod tests {
1693    use super::*;
1694    use crate::extract::{Context, Json, RawArgs, State};
1695    use crate::protocol::Content;
1696    use schemars::JsonSchema;
1697    use serde::Deserialize;
1698
    // Minimal typed input for the builder tests: a single `name` string field.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct GreetInput {
        name: String,
    }
1703
1704    #[tokio::test]
1705    async fn test_builder_tool() {
1706        let tool = ToolBuilder::new("greet")
1707            .description("Greet someone")
1708            .handler(|input: GreetInput| async move {
1709                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1710            })
1711            .build();
1712
1713        assert_eq!(tool.name, "greet");
1714        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1715
1716        let result = tool.call(serde_json::json!({"name": "World"})).await;
1717
1718        assert!(!result.is_error);
1719    }
1720
1721    #[tokio::test]
1722    async fn test_raw_handler() {
1723        let tool = ToolBuilder::new("echo")
1724            .description("Echo input")
1725            .extractor_handler((), |RawArgs(args): RawArgs| async move {
1726                Ok(CallToolResult::json(args))
1727            })
1728            .build();
1729
1730        let result = tool.call(serde_json::json!({"foo": "bar"})).await;
1731
1732        assert!(!result.is_error);
1733    }
1734
1735    #[test]
1736    fn test_invalid_tool_name_empty() {
1737        let err = ToolBuilder::try_new("").err().expect("should fail");
1738        assert!(err.to_string().contains("cannot be empty"));
1739    }
1740
1741    #[test]
1742    fn test_invalid_tool_name_too_long() {
1743        let long_name = "a".repeat(65);
1744        let err = ToolBuilder::try_new(long_name).err().expect("should fail");
1745        assert!(err.to_string().contains("exceeds maximum"));
1746    }
1747
1748    #[test]
1749    fn test_invalid_tool_name_bad_chars() {
1750        let err = ToolBuilder::try_new("my tool!").err().expect("should fail");
1751        assert!(err.to_string().contains("invalid character"));
1752    }
1753
1754    #[test]
1755    #[should_panic(expected = "cannot be empty")]
1756    fn test_new_panics_on_empty_name() {
1757        ToolBuilder::new("");
1758    }
1759
1760    #[test]
1761    #[should_panic(expected = "exceeds maximum")]
1762    fn test_new_panics_on_too_long_name() {
1763        ToolBuilder::new("a".repeat(65));
1764    }
1765
1766    #[test]
1767    #[should_panic(expected = "invalid character")]
1768    fn test_new_panics_on_invalid_chars() {
1769        ToolBuilder::new("my tool!");
1770    }
1771
1772    #[test]
1773    fn test_valid_tool_names() {
1774        // All valid characters per SEP-986
1775        let names = [
1776            "my_tool",
1777            "my-tool",
1778            "my.tool",
1779            "my/tool",
1780            "user-profile/update",
1781            "MyTool123",
1782            "a",
1783            &"a".repeat(64),
1784        ];
1785        for name in names {
1786            assert!(
1787                ToolBuilder::try_new(name).is_ok(),
1788                "Expected '{}' to be valid",
1789                name
1790            );
1791        }
1792    }
1793
1794    #[tokio::test]
1795    async fn test_context_aware_handler() {
1796        use crate::context::notification_channel;
1797        use crate::protocol::{ProgressToken, RequestId};
1798
1799        #[derive(Debug, Deserialize, JsonSchema)]
1800        struct ProcessInput {
1801            count: i32,
1802        }
1803
1804        let tool = ToolBuilder::new("process")
1805            .description("Process with context")
1806            .extractor_handler(
1807                (),
1808                |ctx: Context, Json(input): Json<ProcessInput>| async move {
1809                    // Simulate progress reporting
1810                    for i in 0..input.count {
1811                        if ctx.is_cancelled() {
1812                            return Ok(CallToolResult::error("Cancelled"));
1813                        }
1814                        ctx.report_progress(i as f64, Some(input.count as f64), None)
1815                            .await;
1816                    }
1817                    Ok(CallToolResult::text(format!(
1818                        "Processed {} items",
1819                        input.count
1820                    )))
1821                },
1822            )
1823            .build();
1824
1825        assert_eq!(tool.name, "process");
1826
1827        // Test with a context that has progress token and notification sender
1828        let (tx, mut rx) = notification_channel(10);
1829        let ctx = RequestContext::new(RequestId::Number(1))
1830            .with_progress_token(ProgressToken::Number(42))
1831            .with_notification_sender(tx);
1832
1833        let result = tool
1834            .call_with_context(ctx, serde_json::json!({"count": 3}))
1835            .await;
1836
1837        assert!(!result.is_error);
1838
1839        // Check that progress notifications were sent
1840        let mut progress_count = 0;
1841        while rx.try_recv().is_ok() {
1842            progress_count += 1;
1843        }
1844        assert_eq!(progress_count, 3);
1845    }
1846
1847    #[tokio::test]
1848    async fn test_context_aware_handler_cancellation() {
1849        use crate::protocol::RequestId;
1850        use std::sync::atomic::{AtomicI32, Ordering};
1851
1852        #[derive(Debug, Deserialize, JsonSchema)]
1853        struct LongRunningInput {
1854            iterations: i32,
1855        }
1856
1857        let iterations_completed = Arc::new(AtomicI32::new(0));
1858        let iterations_ref = iterations_completed.clone();
1859
1860        let tool = ToolBuilder::new("long_running")
1861            .description("Long running task")
1862            .extractor_handler(
1863                (),
1864                move |ctx: Context, Json(input): Json<LongRunningInput>| {
1865                    let completed = iterations_ref.clone();
1866                    async move {
1867                        for i in 0..input.iterations {
1868                            if ctx.is_cancelled() {
1869                                return Ok(CallToolResult::error("Cancelled"));
1870                            }
1871                            completed.fetch_add(1, Ordering::SeqCst);
1872                            // Simulate work
1873                            tokio::task::yield_now().await;
1874                            // Cancel after iteration 2
1875                            if i == 2 {
1876                                ctx.cancellation_token().cancel();
1877                            }
1878                        }
1879                        Ok(CallToolResult::text("Done"))
1880                    }
1881                },
1882            )
1883            .build();
1884
1885        let ctx = RequestContext::new(RequestId::Number(1));
1886
1887        let result = tool
1888            .call_with_context(ctx, serde_json::json!({"iterations": 10}))
1889            .await;
1890
1891        // Should have been cancelled after 3 iterations (0, 1, 2)
1892        // The next iteration (3) checks cancellation and returns
1893        assert!(result.is_error);
1894        assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
1895    }
1896
1897    #[tokio::test]
1898    async fn test_tool_builder_with_enhanced_fields() {
1899        let output_schema = serde_json::json!({
1900            "type": "object",
1901            "properties": {
1902                "greeting": {"type": "string"}
1903            }
1904        });
1905
1906        let tool = ToolBuilder::new("greet")
1907            .title("Greeting Tool")
1908            .description("Greet someone")
1909            .output_schema(output_schema.clone())
1910            .icon("https://example.com/icon.png")
1911            .icon_with_meta(
1912                "https://example.com/icon-large.png",
1913                Some("image/png".to_string()),
1914                Some(vec!["96x96".to_string()]),
1915            )
1916            .handler(|input: GreetInput| async move {
1917                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1918            })
1919            .build();
1920
1921        assert_eq!(tool.name, "greet");
1922        assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
1923        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1924        assert_eq!(tool.output_schema, Some(output_schema));
1925        assert!(tool.icons.is_some());
1926        assert_eq!(tool.icons.as_ref().unwrap().len(), 2);
1927
1928        // Test definition includes new fields
1929        let def = tool.definition();
1930        assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
1931        assert!(def.output_schema.is_some());
1932        assert!(def.icons.is_some());
1933    }
1934
1935    #[tokio::test]
1936    async fn test_handler_with_state() {
1937        let shared = Arc::new("shared-state".to_string());
1938
1939        let tool = ToolBuilder::new("stateful")
1940            .description("Uses shared state")
1941            .extractor_handler(
1942                shared,
1943                |State(state): State<Arc<String>>, Json(input): Json<GreetInput>| async move {
1944                    Ok(CallToolResult::text(format!(
1945                        "{}: Hello, {}!",
1946                        state, input.name
1947                    )))
1948                },
1949            )
1950            .build();
1951
1952        let result = tool.call(serde_json::json!({"name": "World"})).await;
1953        assert!(!result.is_error);
1954    }
1955
1956    #[tokio::test]
1957    async fn test_handler_with_state_and_context() {
1958        use crate::protocol::RequestId;
1959
1960        let shared = Arc::new(42_i32);
1961
1962        let tool =
1963            ToolBuilder::new("stateful_ctx")
1964                .description("Uses state and context")
1965                .extractor_handler(
1966                    shared,
1967                    |State(state): State<Arc<i32>>,
1968                     _ctx: Context,
1969                     Json(input): Json<GreetInput>| async move {
1970                        Ok(CallToolResult::text(format!(
1971                            "{}: Hello, {}!",
1972                            state, input.name
1973                        )))
1974                    },
1975                )
1976                .build();
1977
1978        let ctx = RequestContext::new(RequestId::Number(1));
1979        let result = tool
1980            .call_with_context(ctx, serde_json::json!({"name": "World"}))
1981            .await;
1982        assert!(!result.is_error);
1983    }
1984
1985    #[tokio::test]
1986    async fn test_handler_no_params() {
1987        let tool = ToolBuilder::new("no_params")
1988            .description("Takes no parameters")
1989            .extractor_handler((), |Json(_): Json<NoParams>| async {
1990                Ok(CallToolResult::text("no params result"))
1991            })
1992            .build();
1993
1994        assert_eq!(tool.name, "no_params");
1995
1996        // Should work with empty args
1997        let result = tool.call(serde_json::json!({})).await;
1998        assert!(!result.is_error);
1999
2000        // Should also work with unexpected args (ignored)
2001        let result = tool.call(serde_json::json!({"unexpected": "value"})).await;
2002        assert!(!result.is_error);
2003
2004        // Check input schema includes type: object
2005        let schema = tool.definition().input_schema;
2006        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
2007    }
2008
2009    #[tokio::test]
2010    async fn test_handler_with_state_no_params() {
2011        let shared = Arc::new("shared_value".to_string());
2012
2013        let tool = ToolBuilder::new("with_state_no_params")
2014            .description("Takes no parameters but has state")
2015            .extractor_handler(
2016                shared,
2017                |State(state): State<Arc<String>>, Json(_): Json<NoParams>| async move {
2018                    Ok(CallToolResult::text(format!("state: {}", state)))
2019                },
2020            )
2021            .build();
2022
2023        assert_eq!(tool.name, "with_state_no_params");
2024
2025        // Should work with empty args
2026        let result = tool.call(serde_json::json!({})).await;
2027        assert!(!result.is_error);
2028        assert_eq!(result.first_text().unwrap(), "state: shared_value");
2029
2030        // Check input schema includes type: object
2031        let schema = tool.definition().input_schema;
2032        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
2033    }
2034
2035    #[tokio::test]
2036    async fn test_handler_no_params_with_context() {
2037        let tool = ToolBuilder::new("no_params_with_context")
2038            .description("Takes no parameters but has context")
2039            .extractor_handler((), |_ctx: Context, Json(_): Json<NoParams>| async move {
2040                Ok(CallToolResult::text("context available"))
2041            })
2042            .build();
2043
2044        assert_eq!(tool.name, "no_params_with_context");
2045
2046        let result = tool.call(serde_json::json!({})).await;
2047        assert!(!result.is_error);
2048        assert_eq!(result.first_text().unwrap(), "context available");
2049    }
2050
2051    #[tokio::test]
2052    async fn test_handler_with_state_and_context_no_params() {
2053        let shared = Arc::new("shared".to_string());
2054
2055        let tool = ToolBuilder::new("state_context_no_params")
2056            .description("Has state and context, no params")
2057            .extractor_handler(
2058                shared,
2059                |State(state): State<Arc<String>>,
2060                 _ctx: Context,
2061                 Json(_): Json<NoParams>| async move {
2062                    Ok(CallToolResult::text(format!("state: {}", state)))
2063                },
2064            )
2065            .build();
2066
2067        assert_eq!(tool.name, "state_context_no_params");
2068
2069        let result = tool.call(serde_json::json!({})).await;
2070        assert!(!result.is_error);
2071        assert_eq!(result.first_text().unwrap(), "state: shared");
2072    }
2073
2074    #[tokio::test]
2075    async fn test_raw_handler_with_state() {
2076        let prefix = Arc::new("prefix:".to_string());
2077
2078        let tool = ToolBuilder::new("raw_with_state")
2079            .description("Raw handler with state")
2080            .extractor_handler(
2081                prefix,
2082                |State(state): State<Arc<String>>, RawArgs(args): RawArgs| async move {
2083                    Ok(CallToolResult::text(format!("{} {}", state, args)))
2084                },
2085            )
2086            .build();
2087
2088        assert_eq!(tool.name, "raw_with_state");
2089
2090        let result = tool.call(serde_json::json!({"key": "value"})).await;
2091        assert!(!result.is_error);
2092        assert!(result.first_text().unwrap().starts_with("prefix:"));
2093    }
2094
2095    #[tokio::test]
2096    async fn test_raw_handler_with_state_and_context() {
2097        let prefix = Arc::new("prefix:".to_string());
2098
2099        let tool = ToolBuilder::new("raw_state_context")
2100            .description("Raw handler with state and context")
2101            .extractor_handler(
2102                prefix,
2103                |State(state): State<Arc<String>>,
2104                 _ctx: Context,
2105                 RawArgs(args): RawArgs| async move {
2106                    Ok(CallToolResult::text(format!("{} {}", state, args)))
2107                },
2108            )
2109            .build();
2110
2111        assert_eq!(tool.name, "raw_state_context");
2112
2113        let result = tool.call(serde_json::json!({"key": "value"})).await;
2114        assert!(!result.is_error);
2115        assert!(result.first_text().unwrap().starts_with("prefix:"));
2116    }
2117
2118    #[tokio::test]
2119    async fn test_tool_with_timeout_layer() {
2120        use std::time::Duration;
2121        use tower::timeout::TimeoutLayer;
2122
2123        #[derive(Debug, Deserialize, JsonSchema)]
2124        struct SlowInput {
2125            delay_ms: u64,
2126        }
2127
2128        // Create a tool with a short timeout
2129        let tool = ToolBuilder::new("slow_tool")
2130            .description("A slow tool")
2131            .handler(|input: SlowInput| async move {
2132                tokio::time::sleep(Duration::from_millis(input.delay_ms)).await;
2133                Ok(CallToolResult::text("completed"))
2134            })
2135            .layer(TimeoutLayer::new(Duration::from_millis(50)))
2136            .build();
2137
2138        // Fast call should succeed
2139        let result = tool.call(serde_json::json!({"delay_ms": 10})).await;
2140        assert!(!result.is_error);
2141        assert_eq!(result.first_text().unwrap(), "completed");
2142
2143        // Slow call should timeout and return an error result
2144        let result = tool.call(serde_json::json!({"delay_ms": 200})).await;
2145        assert!(result.is_error);
2146        // Tower's timeout error message is "request timed out"
2147        let msg = result.first_text().unwrap().to_lowercase();
2148        assert!(
2149            msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
2150            "Expected timeout error, got: {}",
2151            msg
2152        );
2153    }
2154
2155    #[tokio::test]
2156    async fn test_tool_with_concurrency_limit_layer() {
2157        use std::sync::atomic::{AtomicU32, Ordering};
2158        use std::time::Duration;
2159        use tower::limit::ConcurrencyLimitLayer;
2160
2161        #[derive(Debug, Deserialize, JsonSchema)]
2162        struct WorkInput {
2163            id: u32,
2164        }
2165
2166        let max_concurrent = Arc::new(AtomicU32::new(0));
2167        let current_concurrent = Arc::new(AtomicU32::new(0));
2168        let max_ref = max_concurrent.clone();
2169        let current_ref = current_concurrent.clone();
2170
2171        // Create a tool with concurrency limit of 2
2172        let tool = ToolBuilder::new("concurrent_tool")
2173            .description("A concurrent tool")
2174            .handler(move |input: WorkInput| {
2175                let max = max_ref.clone();
2176                let current = current_ref.clone();
2177                async move {
2178                    // Track concurrency
2179                    let prev = current.fetch_add(1, Ordering::SeqCst);
2180                    max.fetch_max(prev + 1, Ordering::SeqCst);
2181
2182                    // Simulate work
2183                    tokio::time::sleep(Duration::from_millis(50)).await;
2184
2185                    current.fetch_sub(1, Ordering::SeqCst);
2186                    Ok(CallToolResult::text(format!("completed {}", input.id)))
2187                }
2188            })
2189            .layer(ConcurrencyLimitLayer::new(2))
2190            .build();
2191
2192        // Launch 4 concurrent calls
2193        let handles: Vec<_> = (0..4)
2194            .map(|i| {
2195                let t = tool.call(serde_json::json!({"id": i}));
2196                tokio::spawn(t)
2197            })
2198            .collect();
2199
2200        for handle in handles {
2201            let result = handle.await.unwrap();
2202            assert!(!result.is_error);
2203        }
2204
2205        // Max concurrent should not exceed 2
2206        assert!(max_concurrent.load(Ordering::SeqCst) <= 2);
2207    }
2208
2209    #[tokio::test]
2210    async fn test_tool_with_multiple_layers() {
2211        use std::time::Duration;
2212        use tower::limit::ConcurrencyLimitLayer;
2213        use tower::timeout::TimeoutLayer;
2214
2215        #[derive(Debug, Deserialize, JsonSchema)]
2216        struct Input {
2217            value: String,
2218        }
2219
2220        // Create a tool with multiple layers stacked
2221        let tool = ToolBuilder::new("multi_layer_tool")
2222            .description("Tool with multiple layers")
2223            .handler(|input: Input| async move {
2224                Ok(CallToolResult::text(format!("processed: {}", input.value)))
2225            })
2226            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2227            .layer(ConcurrencyLimitLayer::new(10))
2228            .build();
2229
2230        let result = tool.call(serde_json::json!({"value": "test"})).await;
2231        assert!(!result.is_error);
2232        assert_eq!(result.first_text().unwrap(), "processed: test");
2233    }
2234
2235    #[test]
2236    fn test_tool_catch_error_clone() {
2237        // ToolCatchError should be Clone when inner is Clone
2238        // Use a simple tool that we can clone
2239        let tool = ToolBuilder::new("test")
2240            .description("test")
2241            .extractor_handler((), |RawArgs(_args): RawArgs| async {
2242                Ok(CallToolResult::text("ok"))
2243            })
2244            .build();
2245        // The tool contains a BoxToolService which is cloneable
2246        let _clone = tool.call(serde_json::json!({}));
2247    }
2248
2249    #[test]
2250    fn test_tool_catch_error_debug() {
2251        // ToolCatchError implements Debug when inner implements Debug
2252        // Since our internal services don't require Debug, just verify
2253        // that ToolCatchError has a Debug impl for appropriate types
2254        #[derive(Debug, Clone)]
2255        struct DebugService;
2256
2257        impl Service<ToolRequest> for DebugService {
2258            type Response = CallToolResult;
2259            type Error = crate::error::Error;
2260            type Future = Pin<
2261                Box<
2262                    dyn Future<Output = std::result::Result<CallToolResult, crate::error::Error>>
2263                        + Send,
2264                >,
2265            >;
2266
2267            fn poll_ready(
2268                &mut self,
2269                _cx: &mut std::task::Context<'_>,
2270            ) -> Poll<std::result::Result<(), Self::Error>> {
2271                Poll::Ready(Ok(()))
2272            }
2273
2274            fn call(&mut self, _req: ToolRequest) -> Self::Future {
2275                Box::pin(async { Ok(CallToolResult::text("ok")) })
2276            }
2277        }
2278
2279        let catch_error = ToolCatchError::new(DebugService);
2280        let debug = format!("{:?}", catch_error);
2281        assert!(debug.contains("ToolCatchError"));
2282    }
2283
2284    #[test]
2285    fn test_tool_request_new() {
2286        use crate::protocol::RequestId;
2287
2288        let ctx = RequestContext::new(RequestId::Number(42));
2289        let args = serde_json::json!({"key": "value"});
2290        let req = ToolRequest::new(ctx.clone(), args.clone());
2291
2292        assert_eq!(req.args, args);
2293    }
2294
2295    #[test]
2296    fn test_no_params_schema() {
2297        // NoParams should produce a schema with type: "object"
2298        let schema = schemars::schema_for!(NoParams);
2299        let schema_value = serde_json::to_value(&schema).unwrap();
2300        assert_eq!(
2301            schema_value.get("type").and_then(|v| v.as_str()),
2302            Some("object"),
2303            "NoParams should generate type: object schema"
2304        );
2305    }
2306
2307    #[test]
2308    fn test_no_params_deserialize() {
2309        // NoParams should deserialize from various inputs
2310        let from_empty_object: NoParams = serde_json::from_str("{}").unwrap();
2311        assert_eq!(from_empty_object, NoParams);
2312
2313        let from_null: NoParams = serde_json::from_str("null").unwrap();
2314        assert_eq!(from_null, NoParams);
2315
2316        // Should also accept objects with unexpected fields (ignored)
2317        let from_object_with_fields: NoParams =
2318            serde_json::from_str(r#"{"unexpected": "value"}"#).unwrap();
2319        assert_eq!(from_object_with_fields, NoParams);
2320    }
2321
2322    #[tokio::test]
2323    async fn test_no_params_type_in_handler() {
2324        // NoParams can be used as a handler input type
2325        let tool = ToolBuilder::new("status")
2326            .description("Get status")
2327            .handler(|_input: NoParams| async move { Ok(CallToolResult::text("OK")) })
2328            .build();
2329
2330        // Check schema has type: object (not type: null like () would produce)
2331        let schema = tool.definition().input_schema;
2332        assert_eq!(
2333            schema.get("type").and_then(|v| v.as_str()),
2334            Some("object"),
2335            "NoParams handler should produce type: object schema"
2336        );
2337
2338        // Should work with empty input
2339        let result = tool.call(serde_json::json!({})).await;
2340        assert!(!result.is_error);
2341    }
2342
2343    #[tokio::test]
2344    async fn test_serde_json_value_handler_has_type_object() {
2345        // serde_json::Value generates a schema without "type" via schemars.
2346        // We must ensure "type": "object" is added for MCP compliance.
2347        let tool = ToolBuilder::new("any_input")
2348            .description("Accepts any input")
2349            .handler(|_input: serde_json::Value| async move { Ok(CallToolResult::text("ok")) })
2350            .build();
2351
2352        let schema = tool.definition().input_schema;
2353        assert_eq!(
2354            schema.get("type").and_then(|v| v.as_str()),
2355            Some("object"),
2356            "serde_json::Value handler should produce schema with type: object"
2357        );
2358    }
2359
2360    #[tokio::test]
2361    async fn test_tool_with_name_prefix() {
2362        #[derive(Debug, Deserialize, JsonSchema)]
2363        struct Input {
2364            value: String,
2365        }
2366
2367        let tool = ToolBuilder::new("query")
2368            .description("Query something")
2369            .title("Query Tool")
2370            .handler(|input: Input| async move { Ok(CallToolResult::text(&input.value)) })
2371            .build();
2372
2373        // Create prefixed version
2374        let prefixed = tool.with_name_prefix("db");
2375
2376        // Check name is prefixed
2377        assert_eq!(prefixed.name, "db.query");
2378
2379        // Check other fields are preserved
2380        assert_eq!(prefixed.description.as_deref(), Some("Query something"));
2381        assert_eq!(prefixed.title.as_deref(), Some("Query Tool"));
2382
2383        // Check the tool still works
2384        let result = prefixed
2385            .call(serde_json::json!({"value": "test input"}))
2386            .await;
2387        assert!(!result.is_error);
2388        match &result.content[0] {
2389            Content::Text { text, .. } => assert_eq!(text, "test input"),
2390            _ => panic!("Expected text content"),
2391        }
2392    }
2393
2394    #[tokio::test]
2395    async fn test_tool_with_name_prefix_multiple_levels() {
2396        let tool = ToolBuilder::new("action")
2397            .description("Do something")
2398            .handler(|_: NoParams| async move { Ok(CallToolResult::text("done")) })
2399            .build();
2400
2401        // Apply multiple prefixes
2402        let prefixed = tool.with_name_prefix("level1");
2403        assert_eq!(prefixed.name, "level1.action");
2404
2405        let double_prefixed = prefixed.with_name_prefix("level0");
2406        assert_eq!(double_prefixed.name, "level0.level1.action");
2407    }
2408
2409    // =============================================================================
2410    // no_params_handler tests
2411    // =============================================================================
2412
2413    #[tokio::test]
2414    async fn test_no_params_handler_basic() {
2415        let tool = ToolBuilder::new("get_status")
2416            .description("Get current status")
2417            .no_params_handler(|| async { Ok(CallToolResult::text("OK")) })
2418            .build();
2419
2420        assert_eq!(tool.name, "get_status");
2421        assert_eq!(tool.description.as_deref(), Some("Get current status"));
2422
2423        // Should work with empty args
2424        let result = tool.call(serde_json::json!({})).await;
2425        assert!(!result.is_error);
2426        assert_eq!(result.first_text().unwrap(), "OK");
2427
2428        // Should also work with null args
2429        let result = tool.call(serde_json::json!(null)).await;
2430        assert!(!result.is_error);
2431
2432        // Check input schema has type: object
2433        let schema = tool.definition().input_schema;
2434        assert_eq!(schema.get("type").and_then(|v| v.as_str()), Some("object"));
2435    }
2436
2437    #[tokio::test]
2438    async fn test_no_params_handler_with_captured_state() {
2439        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
2440        let counter_ref = counter.clone();
2441
2442        let tool = ToolBuilder::new("increment")
2443            .description("Increment counter")
2444            .no_params_handler(move || {
2445                let c = counter_ref.clone();
2446                async move {
2447                    let prev = c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2448                    Ok(CallToolResult::text(format!("Incremented from {}", prev)))
2449                }
2450            })
2451            .build();
2452
2453        // Call multiple times
2454        let _ = tool.call(serde_json::json!({})).await;
2455        let _ = tool.call(serde_json::json!({})).await;
2456        let result = tool.call(serde_json::json!({})).await;
2457
2458        assert!(!result.is_error);
2459        assert_eq!(result.first_text().unwrap(), "Incremented from 2");
2460        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 3);
2461    }
2462
2463    #[tokio::test]
2464    async fn test_no_params_handler_with_layer() {
2465        use std::time::Duration;
2466        use tower::timeout::TimeoutLayer;
2467
2468        let tool = ToolBuilder::new("slow_status")
2469            .description("Slow status check")
2470            .no_params_handler(|| async {
2471                tokio::time::sleep(Duration::from_millis(10)).await;
2472                Ok(CallToolResult::text("done"))
2473            })
2474            .layer(TimeoutLayer::new(Duration::from_secs(1)))
2475            .build();
2476
2477        let result = tool.call(serde_json::json!({})).await;
2478        assert!(!result.is_error);
2479        assert_eq!(result.first_text().unwrap(), "done");
2480    }
2481
2482    #[tokio::test]
2483    async fn test_no_params_handler_timeout() {
2484        use std::time::Duration;
2485        use tower::timeout::TimeoutLayer;
2486
2487        let tool = ToolBuilder::new("very_slow_status")
2488            .description("Very slow status check")
2489            .no_params_handler(|| async {
2490                tokio::time::sleep(Duration::from_millis(200)).await;
2491                Ok(CallToolResult::text("done"))
2492            })
2493            .layer(TimeoutLayer::new(Duration::from_millis(50)))
2494            .build();
2495
2496        let result = tool.call(serde_json::json!({})).await;
2497        assert!(result.is_error);
2498        let msg = result.first_text().unwrap().to_lowercase();
2499        assert!(
2500            msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
2501            "Expected timeout error, got: {}",
2502            msg
2503        );
2504    }
2505
2506    #[tokio::test]
2507    async fn test_no_params_handler_with_multiple_layers() {
2508        use std::time::Duration;
2509        use tower::limit::ConcurrencyLimitLayer;
2510        use tower::timeout::TimeoutLayer;
2511
2512        let tool = ToolBuilder::new("multi_layer_status")
2513            .description("Status with multiple layers")
2514            .no_params_handler(|| async { Ok(CallToolResult::text("status ok")) })
2515            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2516            .layer(ConcurrencyLimitLayer::new(10))
2517            .build();
2518
2519        let result = tool.call(serde_json::json!({})).await;
2520        assert!(!result.is_error);
2521        assert_eq!(result.first_text().unwrap(), "status ok");
2522    }
2523
2524    // =========================================================================
2525    // Guard tests
2526    // =========================================================================
2527
2528    #[tokio::test]
2529    async fn test_guard_allows_request() {
2530        #[derive(Debug, Deserialize, JsonSchema)]
2531        #[allow(dead_code)]
2532        struct DeleteInput {
2533            id: String,
2534            confirm: bool,
2535        }
2536
2537        let tool = ToolBuilder::new("delete")
2538            .description("Delete a record")
2539            .handler(|input: DeleteInput| async move {
2540                Ok(CallToolResult::text(format!("deleted {}", input.id)))
2541            })
2542            .guard(|req: &ToolRequest| {
2543                let confirm = req
2544                    .args
2545                    .get("confirm")
2546                    .and_then(|v| v.as_bool())
2547                    .unwrap_or(false);
2548                if !confirm {
2549                    return Err("Must set confirm=true to delete".to_string());
2550                }
2551                Ok(())
2552            })
2553            .build();
2554
2555        let result = tool
2556            .call(serde_json::json!({"id": "abc", "confirm": true}))
2557            .await;
2558        assert!(!result.is_error);
2559        assert_eq!(result.first_text().unwrap(), "deleted abc");
2560    }
2561
2562    #[tokio::test]
2563    async fn test_guard_rejects_request() {
2564        #[derive(Debug, Deserialize, JsonSchema)]
2565        #[allow(dead_code)]
2566        struct DeleteInput2 {
2567            id: String,
2568            confirm: bool,
2569        }
2570
2571        let tool = ToolBuilder::new("delete2")
2572            .description("Delete a record")
2573            .handler(|input: DeleteInput2| async move {
2574                Ok(CallToolResult::text(format!("deleted {}", input.id)))
2575            })
2576            .guard(|req: &ToolRequest| {
2577                let confirm = req
2578                    .args
2579                    .get("confirm")
2580                    .and_then(|v| v.as_bool())
2581                    .unwrap_or(false);
2582                if !confirm {
2583                    return Err("Must set confirm=true to delete".to_string());
2584                }
2585                Ok(())
2586            })
2587            .build();
2588
2589        let result = tool
2590            .call(serde_json::json!({"id": "abc", "confirm": false}))
2591            .await;
2592        assert!(result.is_error);
2593        assert!(
2594            result
2595                .first_text()
2596                .unwrap()
2597                .contains("Must set confirm=true")
2598        );
2599    }
2600
2601    #[tokio::test]
2602    async fn test_guard_with_layer() {
2603        use std::time::Duration;
2604        use tower::timeout::TimeoutLayer;
2605
2606        let tool = ToolBuilder::new("guarded_timeout")
2607            .description("Guarded with timeout")
2608            .handler(|input: GreetInput| async move {
2609                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
2610            })
2611            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2612            .guard(|_req: &ToolRequest| Ok(()))
2613            .build();
2614
2615        let result = tool.call(serde_json::json!({"name": "World"})).await;
2616        assert!(!result.is_error);
2617        assert_eq!(result.first_text().unwrap(), "Hello, World!");
2618    }
2619
2620    #[tokio::test]
2621    async fn test_guard_on_no_params_handler() {
2622        let allowed = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(true));
2623        let allowed_clone = allowed.clone();
2624
2625        let tool = ToolBuilder::new("status")
2626            .description("Get status")
2627            .no_params_handler(|| async { Ok(CallToolResult::text("ok")) })
2628            .guard(move |_req: &ToolRequest| {
2629                if allowed_clone.load(std::sync::atomic::Ordering::Relaxed) {
2630                    Ok(())
2631                } else {
2632                    Err("Access denied".to_string())
2633                }
2634            })
2635            .build();
2636
2637        // Allowed
2638        let result = tool.call(serde_json::json!({})).await;
2639        assert!(!result.is_error);
2640        assert_eq!(result.first_text().unwrap(), "ok");
2641
2642        // Denied
2643        allowed.store(false, std::sync::atomic::Ordering::Relaxed);
2644        let result = tool.call(serde_json::json!({})).await;
2645        assert!(result.is_error);
2646        assert!(result.first_text().unwrap().contains("Access denied"));
2647    }
2648
2649    #[tokio::test]
2650    async fn test_guard_on_no_params_handler_with_layer() {
2651        use std::time::Duration;
2652        use tower::timeout::TimeoutLayer;
2653
2654        let tool = ToolBuilder::new("status_layered")
2655            .description("Get status with layers")
2656            .no_params_handler(|| async { Ok(CallToolResult::text("ok")) })
2657            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2658            .guard(|_req: &ToolRequest| Ok(()))
2659            .build();
2660
2661        let result = tool.call(serde_json::json!({})).await;
2662        assert!(!result.is_error);
2663        assert_eq!(result.first_text().unwrap(), "ok");
2664    }
2665
2666    #[tokio::test]
2667    async fn test_guard_on_extractor_handler() {
2668        use std::sync::Arc;
2669
2670        #[derive(Clone)]
2671        struct AppState {
2672            prefix: String,
2673        }
2674
2675        #[derive(Debug, Deserialize, JsonSchema)]
2676        struct QueryInput {
2677            query: String,
2678        }
2679
2680        let state = Arc::new(AppState {
2681            prefix: "db".to_string(),
2682        });
2683
2684        let tool = ToolBuilder::new("search")
2685            .description("Search")
2686            .extractor_handler(
2687                state,
2688                |State(app): State<Arc<AppState>>, Json(input): Json<QueryInput>| async move {
2689                    Ok(CallToolResult::text(format!(
2690                        "{}: {}",
2691                        app.prefix, input.query
2692                    )))
2693                },
2694            )
2695            .guard(|req: &ToolRequest| {
2696                let query = req.args.get("query").and_then(|v| v.as_str()).unwrap_or("");
2697                if query.is_empty() {
2698                    return Err("Query cannot be empty".to_string());
2699                }
2700                Ok(())
2701            })
2702            .build();
2703
2704        // Valid query
2705        let result = tool.call(serde_json::json!({"query": "hello"})).await;
2706        assert!(!result.is_error);
2707        assert_eq!(result.first_text().unwrap(), "db: hello");
2708
2709        // Empty query rejected by guard
2710        let result = tool.call(serde_json::json!({"query": ""})).await;
2711        assert!(result.is_error);
2712        assert!(
2713            result
2714                .first_text()
2715                .unwrap()
2716                .contains("Query cannot be empty")
2717        );
2718    }
2719
2720    #[tokio::test]
2721    async fn test_guard_on_extractor_handler_with_layer() {
2722        use std::sync::Arc;
2723        use std::time::Duration;
2724        use tower::timeout::TimeoutLayer;
2725
2726        #[derive(Clone)]
2727        struct AppState2 {
2728            prefix: String,
2729        }
2730
2731        #[derive(Debug, Deserialize, JsonSchema)]
2732        struct QueryInput2 {
2733            query: String,
2734        }
2735
2736        let state = Arc::new(AppState2 {
2737            prefix: "db".to_string(),
2738        });
2739
2740        let tool = ToolBuilder::new("search2")
2741            .description("Search with layer and guard")
2742            .extractor_handler(
2743                state,
2744                |State(app): State<Arc<AppState2>>, Json(input): Json<QueryInput2>| async move {
2745                    Ok(CallToolResult::text(format!(
2746                        "{}: {}",
2747                        app.prefix, input.query
2748                    )))
2749                },
2750            )
2751            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2752            .guard(|_req: &ToolRequest| Ok(()))
2753            .build();
2754
2755        let result = tool.call(serde_json::json!({"query": "hello"})).await;
2756        assert!(!result.is_error);
2757        assert_eq!(result.first_text().unwrap(), "db: hello");
2758    }
2759
2760    #[tokio::test]
2761    async fn test_tool_with_guard_post_build() {
2762        let tool = ToolBuilder::new("admin_action")
2763            .description("Admin action")
2764            .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("done")) })
2765            .build();
2766
2767        // Apply guard after building
2768        let guarded = tool.with_guard(|req: &ToolRequest| {
2769            let name = req.args.get("name").and_then(|v| v.as_str()).unwrap_or("");
2770            if name == "admin" {
2771                Ok(())
2772            } else {
2773                Err("Only admin allowed".to_string())
2774            }
2775        });
2776
2777        // Admin passes
2778        let result = guarded.call(serde_json::json!({"name": "admin"})).await;
2779        assert!(!result.is_error);
2780
2781        // Non-admin blocked
2782        let result = guarded.call(serde_json::json!({"name": "user"})).await;
2783        assert!(result.is_error);
2784        assert!(result.first_text().unwrap().contains("Only admin allowed"));
2785    }
2786
2787    #[tokio::test]
2788    async fn test_with_guard_preserves_tool_metadata() {
2789        let tool = ToolBuilder::new("my_tool")
2790            .description("A tool")
2791            .title("My Tool")
2792            .read_only()
2793            .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("done")) })
2794            .build();
2795
2796        let guarded = tool.with_guard(|_req: &ToolRequest| Ok(()));
2797
2798        assert_eq!(guarded.name, "my_tool");
2799        assert_eq!(guarded.description.as_deref(), Some("A tool"));
2800        assert_eq!(guarded.title.as_deref(), Some("My Tool"));
2801        assert!(guarded.annotations.is_some());
2802    }
2803
2804    #[tokio::test]
2805    async fn test_guard_group_pattern() {
2806        // Demonstrate applying the same guard to multiple tools (per-group pattern)
2807        let require_auth = |req: &ToolRequest| {
2808            let token = req
2809                .args
2810                .get("_token")
2811                .and_then(|v| v.as_str())
2812                .unwrap_or("");
2813            if token == "valid" {
2814                Ok(())
2815            } else {
2816                Err("Authentication required".to_string())
2817            }
2818        };
2819
2820        let tool1 = ToolBuilder::new("action1")
2821            .description("Action 1")
2822            .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("action1")) })
2823            .build();
2824        let tool2 = ToolBuilder::new("action2")
2825            .description("Action 2")
2826            .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("action2")) })
2827            .build();
2828
2829        // Apply same guard to both
2830        let guarded1 = tool1.with_guard(require_auth);
2831        let guarded2 = tool2.with_guard(require_auth);
2832
2833        // Without auth
2834        let r1 = guarded1
2835            .call(serde_json::json!({"name": "test", "_token": "invalid"}))
2836            .await;
2837        let r2 = guarded2
2838            .call(serde_json::json!({"name": "test", "_token": "invalid"}))
2839            .await;
2840        assert!(r1.is_error);
2841        assert!(r2.is_error);
2842
2843        // With auth
2844        let r1 = guarded1
2845            .call(serde_json::json!({"name": "test", "_token": "valid"}))
2846            .await;
2847        let r2 = guarded2
2848            .call(serde_json::json!({"name": "test", "_token": "valid"}))
2849            .await;
2850        assert!(!r1.is_error);
2851        assert!(!r2.is_error);
2852    }
2853
2854    #[tokio::test]
2855    async fn test_input_validation_returns_tool_error() {
2856        // Per SEP-1303: input validation errors should be returned as
2857        // CallToolResult with isError=true, not as protocol errors.
2858        #[derive(Debug, Deserialize, JsonSchema)]
2859        struct StrictInput {
2860            name: String,
2861            count: u32,
2862        }
2863
2864        let tool = ToolBuilder::new("strict_tool")
2865            .description("requires specific input")
2866            .handler(|input: StrictInput| async move {
2867                Ok(CallToolResult::text(format!(
2868                    "{}: {}",
2869                    input.name, input.count
2870                )))
2871            })
2872            .build();
2873
2874        // Valid input works
2875        let result = tool
2876            .call(serde_json::json!({"name": "test", "count": 5}))
2877            .await;
2878        assert!(!result.is_error);
2879
2880        // Missing required field returns isError, not protocol error
2881        let result = tool.call(serde_json::json!({"name": "test"})).await;
2882        assert!(result.is_error);
2883        let text = result.first_text().unwrap();
2884        assert!(text.contains("Invalid input"), "got: {text}");
2885
2886        // Wrong type returns isError, not protocol error
2887        let result = tool
2888            .call(serde_json::json!({"name": "test", "count": "not_a_number"}))
2889            .await;
2890        assert!(result.is_error);
2891        let text = result.first_text().unwrap();
2892        assert!(text.contains("Invalid input"), "got: {text}");
2893    }
2894}