Skip to main content

tower_mcp/
extract.rs

1//! Extractor pattern for tool handlers
2//!
3//! This module provides an axum-inspired extractor pattern that makes state and context
4//! injection more declarative, reducing the combinatorial explosion of handler variants.
5//!
6//! # Overview
7//!
8//! Extractors implement [`FromToolRequest`], which extracts data from the tool request
9//! (context, state, and arguments). Multiple extractors can be combined in handler
10//! function parameters.
11//!
12//! # Built-in Extractors
13//!
14//! - [`Json<T>`] - Extract typed input from args (deserializes JSON)
15//! - [`State<T>`] - Extract shared state from per-tool state (cloned for each request)
16//! - [`Extension<T>`] - Extract data from router extensions (via `router.with_state()`)
17//! - [`Context`] - Extract the [`RequestContext`] for progress, cancellation, etc.
18//! - [`RawArgs`] - Extract raw `serde_json::Value` arguments
19//!
20//! ## State vs Extension
21//!
22//! - Use **`State<T>`** when state is passed directly to `extractor_handler()` (per-tool state)
23//! - Use **`Extension<T>`** when state is set via `McpRouter::with_state()` (router-level state)
24//!
25//! # Example
26//!
27//! ```rust
28//! use std::sync::Arc;
29//! use tower_mcp::{ToolBuilder, CallToolResult};
30//! use tower_mcp::extract::{Json, State, Context};
31//! use schemars::JsonSchema;
32//! use serde::Deserialize;
33//!
34//! #[derive(Clone)]
35//! struct AppState {
36//!     db_url: String,
37//! }
38//!
39//! #[derive(Debug, Deserialize, JsonSchema)]
40//! struct QueryInput {
41//!     query: String,
42//! }
43//!
44//! let state = Arc::new(AppState { db_url: "postgres://...".to_string() });
45//!
46//! let tool = ToolBuilder::new("search")
47//!     .description("Search the database")
48//!     .extractor_handler(state, |
49//!         State(db): State<Arc<AppState>>,
50//!         ctx: Context,
51//!         Json(input): Json<QueryInput>,
52//!     | async move {
53//!         // Check cancellation
54//!         if ctx.is_cancelled() {
55//!             return Ok(CallToolResult::error("Cancelled"));
56//!         }
57//!         // Report progress
58//!         ctx.report_progress(0.5, Some(1.0), Some("Searching...")).await;
59//!         // Use state
60//!         Ok(CallToolResult::text(format!("Searched {} with query: {}", db.db_url, input.query)))
61//!     })
62//!     .build();
63//! ```
64//!
65//! # Extractor Order
66//!
67//! The order of extractors in the function signature doesn't matter. Each extractor
68//! independently extracts its data from the request.
69//!
70//! # Error Handling
71//!
72//! If an extractor fails (e.g., JSON deserialization fails), the handler returns
73//! a `CallToolResult::error()` with the rejection message.
74
75use std::future::Future;
76use std::marker::PhantomData;
77use std::ops::Deref;
78use std::pin::Pin;
79
80use schemars::JsonSchema;
81use serde::de::DeserializeOwned;
82use serde_json::Value;
83
84use crate::context::RequestContext;
85use crate::error::{Error, Result};
86use crate::protocol::CallToolResult;
87
88// =============================================================================
89// Rejection Types
90// =============================================================================
91
/// A general-purpose rejection that carries only a message.
///
/// Custom extractors can return this when they have nothing more structured
/// to report. The typed rejections [`JsonRejection`] and
/// [`ExtensionRejection`] carry more detail and are preferable when they fit.
#[derive(Debug, Clone)]
pub struct Rejection {
    message: String,
}

impl Rejection {
    /// Build a rejection from any value convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Rejection { message }
    }

    /// Borrow the rejection message.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl std::fmt::Display for Rejection {
    /// Writes the message verbatim (formatter width/fill flags are ignored).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message())
    }
}

impl std::error::Error for Rejection {}
123
124impl From<Rejection> for Error {
125    fn from(rejection: Rejection) -> Self {
126        Error::tool(rejection.message)
127    }
128}
129
/// Rejection produced when the tool arguments cannot be deserialized.
///
/// Carries the deserializer's message and, optionally, the path to the field
/// that failed (for example `users[0].name`). When a path is present it is
/// included in the `Display` output.
///
/// # Example
///
/// ```rust
/// use tower_mcp::extract::JsonRejection;
///
/// let rejection = JsonRejection::new("missing field `name`");
/// assert!(rejection.message().contains("name"));
/// ```
#[derive(Debug, Clone)]
pub struct JsonRejection {
    message: String,
    /// Path to the failing field, if known (e.g. "users[0].name").
    path: Option<String>,
}

impl JsonRejection {
    /// Create a rejection with a message and no field path.
    pub fn new(message: impl Into<String>) -> Self {
        JsonRejection {
            message: message.into(),
            path: None,
        }
    }

    /// Create a rejection that also records the path of the failing field.
    pub fn with_path(message: impl Into<String>, path: impl Into<String>) -> Self {
        let message = message.into();
        let path = Some(path.into());
        JsonRejection { message, path }
    }

    /// Borrow the error message.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }

    /// Borrow the path to the failing field, or `None` if it is unknown.
    pub fn path(&self) -> Option<&str> {
        match &self.path {
            Some(path) => Some(path.as_str()),
            None => None,
        }
    }
}

impl std::fmt::Display for JsonRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.path() {
            Some(path) => write!(f, "Invalid input at `{}`: {}", path, self.message),
            None => write!(f, "Invalid input: {}", self.message),
        }
    }
}

impl std::error::Error for JsonRejection {}
189
190impl From<JsonRejection> for Error {
191    fn from(rejection: JsonRejection) -> Self {
192        Error::tool(rejection.to_string())
193    }
194}
195
196impl From<serde_json::Error> for JsonRejection {
197    fn from(err: serde_json::Error) -> Self {
198        // Try to extract path information from serde error
199        let path = if err.is_data() {
200            // serde_json provides line/column but not field path in the error itself
201            // The path is embedded in the message for some error types
202            None
203        } else {
204            None
205        };
206
207        Self {
208            message: err.to_string(),
209            path,
210        }
211    }
212}
213
/// Rejection produced when a requested extension type is missing.
///
/// The [`Extension`] extractor returns this when the router's extensions do
/// not contain a value of the requested type. It records that type's name so
/// the error message can point at what is missing.
///
/// # Example
///
/// ```rust
/// use tower_mcp::extract::ExtensionRejection;
///
/// let rejection = ExtensionRejection::not_found::<String>();
/// assert!(rejection.type_name().contains("String"));
/// ```
#[derive(Debug, Clone)]
pub struct ExtensionRejection {
    type_name: &'static str,
}

impl ExtensionRejection {
    /// Build a rejection naming `T` as the missing extension type.
    pub fn not_found<T>() -> Self {
        let type_name = std::any::type_name::<T>();
        ExtensionRejection { type_name }
    }

    /// Name of the extension type that could not be found.
    pub fn type_name(&self) -> &'static str {
        self.type_name
    }
}

impl std::fmt::Display for ExtensionRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let missing = self.type_name();
        write!(
            f,
            "Extension of type `{}` not found. Did you call `router.with_state()` or `router.with_extension()`?",
            missing
        )
    }
}

impl std::error::Error for ExtensionRejection {}
257
258impl From<ExtensionRejection> for Error {
259    fn from(rejection: ExtensionRejection) -> Self {
260        Error::tool(rejection.to_string())
261    }
262}
263
264/// Trait for extracting data from a tool request.
265///
266/// Implement this trait to create custom extractors that can be used
267/// in `extractor_handler` functions.
268///
269/// # Type Parameters
270///
271/// - `S` - The state type. Defaults to `()` for extractors that don't need state.
272///
273/// # Example
274///
275/// ```rust
276/// use tower_mcp::extract::{FromToolRequest, Rejection};
277/// use tower_mcp::RequestContext;
278/// use serde_json::Value;
279///
280/// struct RequestId(String);
281///
282/// impl<S> FromToolRequest<S> for RequestId {
283///     type Rejection = Rejection;
284///
285///     fn from_tool_request(
286///         ctx: &RequestContext,
287///         _state: &S,
288///         _args: &Value,
289///     ) -> Result<Self, Self::Rejection> {
290///         Ok(RequestId(format!("{:?}", ctx.request_id())))
291///     }
292/// }
293/// ```
294pub trait FromToolRequest<S = ()>: Sized {
295    /// The rejection type returned when extraction fails.
296    type Rejection: Into<Error>;
297
298    /// Extract this type from the tool request.
299    ///
300    /// # Arguments
301    ///
302    /// * `ctx` - The request context with progress, cancellation, etc.
303    /// * `state` - The shared state passed to the handler
304    /// * `args` - The raw JSON arguments to the tool
305    fn from_tool_request(
306        ctx: &RequestContext,
307        state: &S,
308        args: &Value,
309    ) -> std::result::Result<Self, Self::Rejection>;
310}
311
312// =============================================================================
313// Built-in Extractors
314// =============================================================================
315
316/// Extract and deserialize JSON arguments into a typed struct.
317///
318/// This extractor deserializes the tool's JSON arguments into type `T`.
319/// The type must implement [`serde::de::DeserializeOwned`] and [`schemars::JsonSchema`].
320///
321/// # Example
322///
323/// ```rust
324/// use tower_mcp::extract::Json;
325/// use schemars::JsonSchema;
326/// use serde::Deserialize;
327///
328/// #[derive(Debug, Deserialize, JsonSchema)]
329/// struct MyInput {
330///     name: String,
331///     count: i32,
332/// }
333///
334/// // In an extractor handler:
335/// // |Json(input): Json<MyInput>| async move { ... }
336/// ```
337///
338/// # Rejection
339///
340/// Returns a [`JsonRejection`] if deserialization fails. The rejection contains
341/// the error message and potentially the path to the failing field.
342#[derive(Debug, Clone, Copy)]
343pub struct Json<T>(pub T);
344
345impl<T> Deref for Json<T> {
346    type Target = T;
347
348    fn deref(&self) -> &Self::Target {
349        &self.0
350    }
351}
352
353impl<S, T> FromToolRequest<S> for Json<T>
354where
355    T: DeserializeOwned,
356{
357    type Rejection = JsonRejection;
358
359    fn from_tool_request(
360        _ctx: &RequestContext,
361        _state: &S,
362        args: &Value,
363    ) -> std::result::Result<Self, Self::Rejection> {
364        serde_json::from_value(args.clone())
365            .map(Json)
366            .map_err(JsonRejection::from)
367    }
368}
369
370/// Extract shared state.
371///
372/// This extractor clones the state passed to `extractor_handler` and provides
373/// it to the handler. The state type must match the type passed to the builder.
374///
375/// # Example
376///
377/// ```rust
378/// use std::sync::Arc;
379/// use tower_mcp::extract::State;
380///
381/// #[derive(Clone)]
382/// struct AppState {
383///     db_url: String,
384/// }
385///
386/// // In an extractor handler:
387/// // |State(state): State<Arc<AppState>>| async move { ... }
388/// ```
389///
390/// # Note
391///
392/// For expensive-to-clone types, wrap them in `Arc` before passing to
393/// `extractor_handler`.
394#[derive(Debug, Clone, Copy)]
395pub struct State<T>(pub T);
396
397impl<T> Deref for State<T> {
398    type Target = T;
399
400    fn deref(&self) -> &Self::Target {
401        &self.0
402    }
403}
404
405impl<S: Clone> FromToolRequest<S> for State<S> {
406    type Rejection = Rejection;
407
408    fn from_tool_request(
409        _ctx: &RequestContext,
410        state: &S,
411        _args: &Value,
412    ) -> std::result::Result<Self, Self::Rejection> {
413        Ok(State(state.clone()))
414    }
415}
416
417/// Extract the request context.
418///
419/// This extractor provides access to the [`RequestContext`], which contains:
420/// - Progress reporting via `report_progress()`
421/// - Cancellation checking via `is_cancelled()`
422/// - Sampling capabilities via `sample()`
423/// - Elicitation capabilities via `elicit_form()` and `elicit_url()`
424/// - Log sending via `send_log()`
425///
426/// # Example
427///
428/// ```rust
429/// use tower_mcp::extract::Context;
430///
431/// // In an extractor handler:
432/// // |ctx: Context| async move {
433/// //     ctx.report_progress(0.5, Some(1.0), Some("Working...")).await;
434/// //     // ...
435/// // }
436/// ```
437#[derive(Debug, Clone)]
438pub struct Context(RequestContext);
439
440impl Context {
441    /// Get the inner RequestContext
442    pub fn into_inner(self) -> RequestContext {
443        self.0
444    }
445}
446
447impl Deref for Context {
448    type Target = RequestContext;
449
450    fn deref(&self) -> &Self::Target {
451        &self.0
452    }
453}
454
455impl<S> FromToolRequest<S> for Context {
456    type Rejection = Rejection;
457
458    fn from_tool_request(
459        ctx: &RequestContext,
460        _state: &S,
461        _args: &Value,
462    ) -> std::result::Result<Self, Self::Rejection> {
463        Ok(Context(ctx.clone()))
464    }
465}
466
467/// Extract raw JSON arguments.
468///
469/// This extractor provides the raw `serde_json::Value` arguments without
470/// any deserialization. Useful when you need full control over argument
471/// parsing or when the schema is dynamic.
472///
473/// # Example
474///
475/// ```rust
476/// use tower_mcp::extract::RawArgs;
477///
478/// // In an extractor handler:
479/// // |RawArgs(args): RawArgs| async move {
480/// //     // args is serde_json::Value
481/// //     if let Some(name) = args.get("name") { ... }
482/// // }
483/// ```
484#[derive(Debug, Clone)]
485pub struct RawArgs(pub Value);
486
487impl Deref for RawArgs {
488    type Target = Value;
489
490    fn deref(&self) -> &Self::Target {
491        &self.0
492    }
493}
494
495impl<S> FromToolRequest<S> for RawArgs {
496    type Rejection = Rejection;
497
498    fn from_tool_request(
499        _ctx: &RequestContext,
500        _state: &S,
501        args: &Value,
502    ) -> std::result::Result<Self, Self::Rejection> {
503        Ok(RawArgs(args.clone()))
504    }
505}
506
507/// Extract typed data from router extensions.
508///
509/// This extractor retrieves data that was added to the router via
510/// [`crate::McpRouter::with_state()`] or [`crate::McpRouter::with_extension()`], or
511/// inserted by middleware into the request context's extensions.
512///
513/// # Example
514///
515/// ```rust
516/// use std::sync::Arc;
517/// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
518/// use tower_mcp::extract::{Extension, Json};
519/// use schemars::JsonSchema;
520/// use serde::Deserialize;
521///
522/// #[derive(Clone)]
523/// struct DatabasePool {
524///     url: String,
525/// }
526///
527/// #[derive(Deserialize, JsonSchema)]
528/// struct QueryInput {
529///     sql: String,
530/// }
531///
532/// let pool = Arc::new(DatabasePool { url: "postgres://...".into() });
533///
534/// let tool = ToolBuilder::new("query")
535///     .description("Run a query")
536///     .extractor_handler(
537///         (),
538///         |Extension(db): Extension<Arc<DatabasePool>>, Json(input): Json<QueryInput>| async move {
539///             Ok(CallToolResult::text(format!("Query on {}: {}", db.url, input.sql)))
540///         },
541///     )
542///     .build();
543///
544/// let router = McpRouter::new()
545///     .with_state(pool)
546///     .tool(tool);
547/// ```
548///
549/// # Rejection
550///
551/// Returns an [`ExtensionRejection`] if the requested type is not found in the extensions.
552/// The rejection contains the type name of the missing extension.
553#[derive(Debug, Clone)]
554pub struct Extension<T>(pub T);
555
556impl<T> Deref for Extension<T> {
557    type Target = T;
558
559    fn deref(&self) -> &Self::Target {
560        &self.0
561    }
562}
563
564impl<S, T> FromToolRequest<S> for Extension<T>
565where
566    T: Clone + Send + Sync + 'static,
567{
568    type Rejection = ExtensionRejection;
569
570    fn from_tool_request(
571        ctx: &RequestContext,
572        _state: &S,
573        _args: &Value,
574    ) -> std::result::Result<Self, Self::Rejection> {
575        ctx.extension::<T>()
576            .cloned()
577            .map(Extension)
578            .ok_or_else(ExtensionRejection::not_found::<T>)
579    }
580}
581
582// =============================================================================
583// Handler Trait
584// =============================================================================
585
586/// A handler that uses extractors.
587///
588/// This trait is implemented for functions that take extractors as arguments.
589/// You don't need to implement this trait directly; it's automatically
590/// implemented for compatible async functions.
591pub trait ExtractorHandler<S, T>: Clone + Send + Sync + 'static {
592    /// The future returned by the handler.
593    type Future: Future<Output = Result<CallToolResult>> + Send;
594
595    /// Call the handler with extracted values.
596    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
597
598    /// Get the input schema for this handler.
599    ///
600    /// Returns `None` if no `Json<T>` extractor is used.
601    fn input_schema() -> Value;
602}
603
604// Implementation for single extractor
605impl<S, F, Fut, T1> ExtractorHandler<S, (T1,)> for F
606where
607    S: Clone + Send + Sync + 'static,
608    F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
609    Fut: Future<Output = Result<CallToolResult>> + Send,
610    T1: FromToolRequest<S> + HasSchema + Send,
611{
612    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
613
614    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
615        Box::pin(async move {
616            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
617            self(t1).await
618        })
619    }
620
621    fn input_schema() -> Value {
622        if let Some(schema) = T1::schema() {
623            return schema;
624        }
625        serde_json::json!({
626            "type": "object",
627            "additionalProperties": true
628        })
629    }
630}
631
632// Implementation for two extractors
633impl<S, F, Fut, T1, T2> ExtractorHandler<S, (T1, T2)> for F
634where
635    S: Clone + Send + Sync + 'static,
636    F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
637    Fut: Future<Output = Result<CallToolResult>> + Send,
638    T1: FromToolRequest<S> + HasSchema + Send,
639    T2: FromToolRequest<S> + HasSchema + Send,
640{
641    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
642
643    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
644        Box::pin(async move {
645            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
646            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
647            self(t1, t2).await
648        })
649    }
650
651    fn input_schema() -> Value {
652        if let Some(schema) = T2::schema() {
653            return schema;
654        }
655        if let Some(schema) = T1::schema() {
656            return schema;
657        }
658        serde_json::json!({
659            "type": "object",
660            "additionalProperties": true
661        })
662    }
663}
664
665// Implementation for three extractors
666impl<S, F, Fut, T1, T2, T3> ExtractorHandler<S, (T1, T2, T3)> for F
667where
668    S: Clone + Send + Sync + 'static,
669    F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
670    Fut: Future<Output = Result<CallToolResult>> + Send,
671    T1: FromToolRequest<S> + HasSchema + Send,
672    T2: FromToolRequest<S> + HasSchema + Send,
673    T3: FromToolRequest<S> + HasSchema + Send,
674{
675    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
676
677    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
678        Box::pin(async move {
679            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
680            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
681            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
682            self(t1, t2, t3).await
683        })
684    }
685
686    fn input_schema() -> Value {
687        if let Some(schema) = T3::schema() {
688            return schema;
689        }
690        if let Some(schema) = T2::schema() {
691            return schema;
692        }
693        if let Some(schema) = T1::schema() {
694            return schema;
695        }
696        serde_json::json!({
697            "type": "object",
698            "additionalProperties": true
699        })
700    }
701}
702
703// Implementation for four extractors
704impl<S, F, Fut, T1, T2, T3, T4> ExtractorHandler<S, (T1, T2, T3, T4)> for F
705where
706    S: Clone + Send + Sync + 'static,
707    F: Fn(T1, T2, T3, T4) -> Fut + Clone + Send + Sync + 'static,
708    Fut: Future<Output = Result<CallToolResult>> + Send,
709    T1: FromToolRequest<S> + HasSchema + Send,
710    T2: FromToolRequest<S> + HasSchema + Send,
711    T3: FromToolRequest<S> + HasSchema + Send,
712    T4: FromToolRequest<S> + HasSchema + Send,
713{
714    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
715
716    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
717        Box::pin(async move {
718            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
719            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
720            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
721            let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
722            self(t1, t2, t3, t4).await
723        })
724    }
725
726    fn input_schema() -> Value {
727        if let Some(schema) = T4::schema() {
728            return schema;
729        }
730        if let Some(schema) = T3::schema() {
731            return schema;
732        }
733        if let Some(schema) = T2::schema() {
734            return schema;
735        }
736        if let Some(schema) = T1::schema() {
737            return schema;
738        }
739        serde_json::json!({
740            "type": "object",
741            "additionalProperties": true
742        })
743    }
744}
745
746// Implementation for five extractors
747impl<S, F, Fut, T1, T2, T3, T4, T5> ExtractorHandler<S, (T1, T2, T3, T4, T5)> for F
748where
749    S: Clone + Send + Sync + 'static,
750    F: Fn(T1, T2, T3, T4, T5) -> Fut + Clone + Send + Sync + 'static,
751    Fut: Future<Output = Result<CallToolResult>> + Send,
752    T1: FromToolRequest<S> + HasSchema + Send,
753    T2: FromToolRequest<S> + HasSchema + Send,
754    T3: FromToolRequest<S> + HasSchema + Send,
755    T4: FromToolRequest<S> + HasSchema + Send,
756    T5: FromToolRequest<S> + HasSchema + Send,
757{
758    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
759
760    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
761        Box::pin(async move {
762            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
763            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
764            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
765            let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
766            let t5 = T5::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
767            self(t1, t2, t3, t4, t5).await
768        })
769    }
770
771    fn input_schema() -> Value {
772        if let Some(schema) = T5::schema() {
773            return schema;
774        }
775        if let Some(schema) = T4::schema() {
776            return schema;
777        }
778        if let Some(schema) = T3::schema() {
779            return schema;
780        }
781        if let Some(schema) = T2::schema() {
782            return schema;
783        }
784        if let Some(schema) = T1::schema() {
785            return schema;
786        }
787        serde_json::json!({
788            "type": "object",
789            "additionalProperties": true
790        })
791    }
792}
793
794// =============================================================================
795// Schema Extraction Helper
796// =============================================================================
797
798/// Helper trait to get schema from `Json<T>` extractor
799pub trait HasSchema {
800    fn schema() -> Option<Value>;
801}
802
803impl<T: JsonSchema> HasSchema for Json<T> {
804    fn schema() -> Option<Value> {
805        let schema = schemars::schema_for!(T);
806        serde_json::to_value(schema).ok()
807    }
808}
809
810// Default impl for non-Json extractors
811impl HasSchema for Context {
812    fn schema() -> Option<Value> {
813        None
814    }
815}
816
817impl HasSchema for RawArgs {
818    fn schema() -> Option<Value> {
819        None
820    }
821}
822
823impl<T> HasSchema for State<T> {
824    fn schema() -> Option<Value> {
825        None
826    }
827}
828
829impl<T> HasSchema for Extension<T> {
830    fn schema() -> Option<Value> {
831        None
832    }
833}
834
835// =============================================================================
836// Typed Extractor Handler
837// =============================================================================
838
839/// A handler that uses extractors with typed JSON input.
840///
841/// This trait is similar to [`ExtractorHandler`] but provides proper JSON
842/// schema generation for the input type when `Json<T>` is used.
843pub trait TypedExtractorHandler<S, T, I>: Clone + Send + Sync + 'static
844where
845    I: JsonSchema,
846{
847    /// The future returned by the handler.
848    type Future: Future<Output = Result<CallToolResult>> + Send;
849
850    /// Call the handler with extracted values.
851    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
852}
853
854// Single extractor with Json<T>
855impl<S, F, Fut, T> TypedExtractorHandler<S, (Json<T>,), T> for F
856where
857    S: Clone + Send + Sync + 'static,
858    F: Fn(Json<T>) -> Fut + Clone + Send + Sync + 'static,
859    Fut: Future<Output = Result<CallToolResult>> + Send,
860    T: DeserializeOwned + JsonSchema + Send,
861{
862    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
863
864    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
865        Box::pin(async move {
866            let t1 =
867                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
868            self(t1).await
869        })
870    }
871}
872
873// Two extractors ending with Json<T>
874impl<S, F, Fut, T1, T> TypedExtractorHandler<S, (T1, Json<T>), T> for F
875where
876    S: Clone + Send + Sync + 'static,
877    F: Fn(T1, Json<T>) -> Fut + Clone + Send + Sync + 'static,
878    Fut: Future<Output = Result<CallToolResult>> + Send,
879    T1: FromToolRequest<S> + Send,
880    T: DeserializeOwned + JsonSchema + Send,
881{
882    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
883
884    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
885        Box::pin(async move {
886            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
887            let t2 =
888                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
889            self(t1, t2).await
890        })
891    }
892}
893
894// Three extractors ending with Json<T>
895impl<S, F, Fut, T1, T2, T> TypedExtractorHandler<S, (T1, T2, Json<T>), T> for F
896where
897    S: Clone + Send + Sync + 'static,
898    F: Fn(T1, T2, Json<T>) -> Fut + Clone + Send + Sync + 'static,
899    Fut: Future<Output = Result<CallToolResult>> + Send,
900    T1: FromToolRequest<S> + Send,
901    T2: FromToolRequest<S> + Send,
902    T: DeserializeOwned + JsonSchema + Send,
903{
904    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
905
906    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
907        Box::pin(async move {
908            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
909            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
910            let t3 =
911                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
912            self(t1, t2, t3).await
913        })
914    }
915}
916
917// Four extractors ending with Json<T>
918impl<S, F, Fut, T1, T2, T3, T> TypedExtractorHandler<S, (T1, T2, T3, Json<T>), T> for F
919where
920    S: Clone + Send + Sync + 'static,
921    F: Fn(T1, T2, T3, Json<T>) -> Fut + Clone + Send + Sync + 'static,
922    Fut: Future<Output = Result<CallToolResult>> + Send,
923    T1: FromToolRequest<S> + Send,
924    T2: FromToolRequest<S> + Send,
925    T3: FromToolRequest<S> + Send,
926    T: DeserializeOwned + JsonSchema + Send,
927{
928    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
929
930    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
931        Box::pin(async move {
932            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
933            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
934            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
935            let t4 =
936                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
937            self(t1, t2, t3, t4).await
938        })
939    }
940}
941
942// =============================================================================
943// ToolBuilder Extensions
944// =============================================================================
945
946use crate::tool::{
947    BoxFuture, GuardLayer, Tool, ToolCatchError, ToolHandler, ToolHandlerService, ToolRequest,
948};
949use tower::util::BoxCloneService;
950use tower_service::Service;
951
952/// Internal handler wrapper for extractor-based handlers
953pub(crate) struct ExtractorToolHandler<S, F, T> {
954    state: S,
955    handler: F,
956    input_schema: Value,
957    _phantom: PhantomData<T>,
958}
959
960impl<S, F, T> ToolHandler for ExtractorToolHandler<S, F, T>
961where
962    S: Clone + Send + Sync + 'static,
963    F: ExtractorHandler<S, T> + Clone,
964    T: Send + Sync + 'static,
965{
966    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
967        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
968        self.call_with_context(ctx, args)
969    }
970
971    fn call_with_context(
972        &self,
973        ctx: RequestContext,
974        args: Value,
975    ) -> BoxFuture<'_, Result<CallToolResult>> {
976        let state = self.state.clone();
977        let handler = self.handler.clone();
978        Box::pin(async move { handler.call(ctx, state, args).await })
979    }
980
981    fn uses_context(&self) -> bool {
982        true
983    }
984
985    fn input_schema(&self) -> Value {
986        self.input_schema.clone()
987    }
988}
989
990/// Builder state for extractor-based handlers
991pub struct ToolBuilderWithExtractor<S, F, T> {
992    pub(crate) name: String,
993    pub(crate) title: Option<String>,
994    pub(crate) description: Option<String>,
995    pub(crate) output_schema: Option<Value>,
996    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
997    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
998    pub(crate) task_support: crate::protocol::TaskSupportMode,
999    pub(crate) state: S,
1000    pub(crate) handler: F,
1001    pub(crate) input_schema: Value,
1002    pub(crate) _phantom: PhantomData<T>,
1003}
1004
1005impl<S, F, T> ToolBuilderWithExtractor<S, F, T>
1006where
1007    S: Clone + Send + Sync + 'static,
1008    F: ExtractorHandler<S, T> + Clone,
1009    T: Send + Sync + 'static,
1010{
1011    /// Build the tool.
1012    pub fn build(self) -> Tool {
1013        let handler = ExtractorToolHandler {
1014            state: self.state,
1015            handler: self.handler,
1016            input_schema: self.input_schema.clone(),
1017            _phantom: PhantomData,
1018        };
1019
1020        let handler_service = ToolHandlerService::new(handler);
1021        let catch_error = ToolCatchError::new(handler_service);
1022        let service = BoxCloneService::new(catch_error);
1023
1024        Tool {
1025            name: self.name,
1026            title: self.title,
1027            description: self.description,
1028            output_schema: self.output_schema,
1029            icons: self.icons,
1030            annotations: self.annotations,
1031            task_support: self.task_support,
1032            service,
1033            input_schema: self.input_schema,
1034        }
1035    }
1036
1037    /// Apply a Tower layer (middleware) to this tool.
1038    ///
1039    /// The layer wraps the tool's handler service, enabling functionality like
1040    /// timeouts, rate limiting, and metrics collection at the per-tool level.
1041    ///
1042    /// # Example
1043    ///
1044    /// ```rust
1045    /// use std::sync::Arc;
1046    /// use std::time::Duration;
1047    /// use tower::timeout::TimeoutLayer;
1048    /// use tower_mcp::{ToolBuilder, CallToolResult};
1049    /// use tower_mcp::extract::{Json, State};
1050    /// use schemars::JsonSchema;
1051    /// use serde::Deserialize;
1052    ///
1053    /// #[derive(Clone)]
1054    /// struct AppState { prefix: String }
1055    ///
1056    /// #[derive(Debug, Deserialize, JsonSchema)]
1057    /// struct QueryInput { query: String }
1058    ///
1059    /// let state = Arc::new(AppState { prefix: "db".to_string() });
1060    ///
1061    /// let tool = ToolBuilder::new("search")
1062    ///     .description("Search with timeout")
1063    ///     .extractor_handler(state, |
1064    ///         State(app): State<Arc<AppState>>,
1065    ///         Json(input): Json<QueryInput>,
1066    ///     | async move {
1067    ///         Ok(CallToolResult::text(format!("{}: {}", app.prefix, input.query)))
1068    ///     })
1069    ///     .layer(TimeoutLayer::new(Duration::from_secs(30)))
1070    ///     .build();
1071    /// ```
1072    pub fn layer<L>(self, layer: L) -> ToolBuilderWithExtractorLayer<S, F, T, L> {
1073        ToolBuilderWithExtractorLayer {
1074            name: self.name,
1075            title: self.title,
1076            description: self.description,
1077            output_schema: self.output_schema,
1078            icons: self.icons,
1079            annotations: self.annotations,
1080            task_support: self.task_support,
1081            state: self.state,
1082            handler: self.handler,
1083            input_schema: self.input_schema,
1084            layer,
1085            _phantom: PhantomData,
1086        }
1087    }
1088
1089    /// Apply a guard to this tool.
1090    ///
1091    /// See [`ToolBuilderWithHandler::guard`](crate::ToolBuilder) for details.
1092    pub fn guard<G>(self, guard: G) -> ToolBuilderWithExtractorLayer<S, F, T, GuardLayer<G>>
1093    where
1094        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1095    {
1096        self.layer(GuardLayer::new(guard))
1097    }
1098}
1099
1100/// Builder state after a layer has been applied to an extractor handler.
1101///
1102/// This builder allows chaining additional layers and building the final tool.
1103pub struct ToolBuilderWithExtractorLayer<S, F, T, L> {
1104    name: String,
1105    title: Option<String>,
1106    description: Option<String>,
1107    output_schema: Option<Value>,
1108    icons: Option<Vec<crate::protocol::ToolIcon>>,
1109    annotations: Option<crate::protocol::ToolAnnotations>,
1110    task_support: crate::protocol::TaskSupportMode,
1111    state: S,
1112    handler: F,
1113    input_schema: Value,
1114    layer: L,
1115    _phantom: PhantomData<T>,
1116}
1117
1118#[allow(private_bounds)]
1119impl<S, F, T, L> ToolBuilderWithExtractorLayer<S, F, T, L>
1120where
1121    S: Clone + Send + Sync + 'static,
1122    F: ExtractorHandler<S, T> + Clone,
1123    T: Send + Sync + 'static,
1124    L: tower::Layer<ToolHandlerService<ExtractorToolHandler<S, F, T>>>
1125        + Clone
1126        + Send
1127        + Sync
1128        + 'static,
1129    L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1130    <L::Service as Service<ToolRequest>>::Error: std::fmt::Display + Send,
1131    <L::Service as Service<ToolRequest>>::Future: Send,
1132{
1133    /// Build the tool with the applied layer(s).
1134    pub fn build(self) -> Tool {
1135        let handler = ExtractorToolHandler {
1136            state: self.state,
1137            handler: self.handler,
1138            input_schema: self.input_schema.clone(),
1139            _phantom: PhantomData,
1140        };
1141
1142        let handler_service = ToolHandlerService::new(handler);
1143        let layered = self.layer.layer(handler_service);
1144        let catch_error = ToolCatchError::new(layered);
1145        let service = BoxCloneService::new(catch_error);
1146
1147        Tool {
1148            name: self.name,
1149            title: self.title,
1150            description: self.description,
1151            output_schema: self.output_schema,
1152            icons: self.icons,
1153            annotations: self.annotations,
1154            task_support: self.task_support,
1155            service,
1156            input_schema: self.input_schema,
1157        }
1158    }
1159
1160    /// Apply an additional Tower layer (middleware).
1161    ///
1162    /// Layers are applied in order, with earlier layers wrapping later ones.
1163    /// This means the first layer added is the outermost middleware.
1164    pub fn layer<L2>(
1165        self,
1166        layer: L2,
1167    ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<L2, L>> {
1168        ToolBuilderWithExtractorLayer {
1169            name: self.name,
1170            title: self.title,
1171            description: self.description,
1172            output_schema: self.output_schema,
1173            icons: self.icons,
1174            annotations: self.annotations,
1175            task_support: self.task_support,
1176            state: self.state,
1177            handler: self.handler,
1178            input_schema: self.input_schema,
1179            layer: tower::layer::util::Stack::new(layer, self.layer),
1180            _phantom: PhantomData,
1181        }
1182    }
1183
1184    /// Apply a guard to this tool.
1185    ///
1186    /// See [`ToolBuilderWithHandler::guard`](crate::ToolBuilder) for details.
1187    pub fn guard<G>(
1188        self,
1189        guard: G,
1190    ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<GuardLayer<G>, L>>
1191    where
1192        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1193    {
1194        self.layer(GuardLayer::new(guard))
1195    }
1196}
1197
1198/// Builder state for extractor-based handlers with typed JSON input
1199pub struct ToolBuilderWithTypedExtractor<S, F, T, I> {
1200    pub(crate) name: String,
1201    pub(crate) title: Option<String>,
1202    pub(crate) description: Option<String>,
1203    pub(crate) output_schema: Option<Value>,
1204    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
1205    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
1206    pub(crate) task_support: crate::protocol::TaskSupportMode,
1207    pub(crate) state: S,
1208    pub(crate) handler: F,
1209    pub(crate) _phantom: PhantomData<(T, I)>,
1210}
1211
1212impl<S, F, T, I> ToolBuilderWithTypedExtractor<S, F, T, I>
1213where
1214    S: Clone + Send + Sync + 'static,
1215    F: TypedExtractorHandler<S, T, I> + Clone,
1216    T: Send + Sync + 'static,
1217    I: JsonSchema + Send + Sync + 'static,
1218{
1219    /// Build the tool.
1220    pub fn build(self) -> Tool {
1221        let input_schema = {
1222            let schema = schemars::schema_for!(I);
1223            serde_json::to_value(schema).unwrap_or_else(|_| {
1224                serde_json::json!({
1225                    "type": "object"
1226                })
1227            })
1228        };
1229
1230        let handler = TypedExtractorToolHandler {
1231            state: self.state,
1232            handler: self.handler,
1233            input_schema: input_schema.clone(),
1234            _phantom: PhantomData,
1235        };
1236
1237        let handler_service = crate::tool::ToolHandlerService::new(handler);
1238        let catch_error = ToolCatchError::new(handler_service);
1239        let service = BoxCloneService::new(catch_error);
1240
1241        Tool {
1242            name: self.name,
1243            title: self.title,
1244            description: self.description,
1245            output_schema: self.output_schema,
1246            icons: self.icons,
1247            annotations: self.annotations,
1248            task_support: self.task_support,
1249            service,
1250            input_schema,
1251        }
1252    }
1253}
1254
1255/// Internal handler wrapper for typed extractor-based handlers
1256struct TypedExtractorToolHandler<S, F, T, I> {
1257    state: S,
1258    handler: F,
1259    input_schema: Value,
1260    _phantom: PhantomData<(T, I)>,
1261}
1262
1263impl<S, F, T, I> ToolHandler for TypedExtractorToolHandler<S, F, T, I>
1264where
1265    S: Clone + Send + Sync + 'static,
1266    F: TypedExtractorHandler<S, T, I> + Clone,
1267    T: Send + Sync + 'static,
1268    I: JsonSchema + Send + Sync + 'static,
1269{
1270    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1271        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
1272        self.call_with_context(ctx, args)
1273    }
1274
1275    fn call_with_context(
1276        &self,
1277        ctx: RequestContext,
1278        args: Value,
1279    ) -> BoxFuture<'_, Result<CallToolResult>> {
1280        let state = self.state.clone();
1281        let handler = self.handler.clone();
1282        Box::pin(async move { handler.call(ctx, state, args).await })
1283    }
1284
1285    fn uses_context(&self) -> bool {
1286        true
1287    }
1288
1289    fn input_schema(&self) -> Value {
1290        self.input_schema.clone()
1291    }
1292}
1293
1294#[cfg(test)]
1295mod tests {
1296    use super::*;
1297    use crate::protocol::RequestId;
1298    use schemars::JsonSchema;
1299    use serde::Deserialize;
1300    use std::sync::Arc;
1301
1302    #[derive(Debug, Deserialize, JsonSchema)]
1303    struct TestInput {
1304        name: String,
1305        count: i32,
1306    }
1307
1308    #[test]
1309    fn test_json_extraction() {
1310        let args = serde_json::json!({"name": "test", "count": 42});
1311        let ctx = RequestContext::new(RequestId::Number(1));
1312
1313        let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
1314        assert!(result.is_ok());
1315        let Json(input) = result.unwrap();
1316        assert_eq!(input.name, "test");
1317        assert_eq!(input.count, 42);
1318    }
1319
1320    #[test]
1321    fn test_json_extraction_error() {
1322        let args = serde_json::json!({"name": "test"}); // missing count
1323        let ctx = RequestContext::new(RequestId::Number(1));
1324
1325        let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
1326        assert!(result.is_err());
1327        let rejection = result.unwrap_err();
1328        // JsonRejection contains the serde error message
1329        assert!(rejection.message().contains("count"));
1330    }
1331
1332    #[test]
1333    fn test_state_extraction() {
1334        let args = serde_json::json!({});
1335        let ctx = RequestContext::new(RequestId::Number(1));
1336        let state = Arc::new("my-state".to_string());
1337
1338        let result = State::<Arc<String>>::from_tool_request(&ctx, &state, &args);
1339        assert!(result.is_ok());
1340        let State(extracted) = result.unwrap();
1341        assert_eq!(*extracted, "my-state");
1342    }
1343
1344    #[test]
1345    fn test_context_extraction() {
1346        let args = serde_json::json!({});
1347        let ctx = RequestContext::new(RequestId::Number(42));
1348
1349        let result = Context::from_tool_request(&ctx, &(), &args);
1350        assert!(result.is_ok());
1351        let extracted = result.unwrap();
1352        assert_eq!(*extracted.request_id(), RequestId::Number(42));
1353    }
1354
1355    #[test]
1356    fn test_raw_args_extraction() {
1357        let args = serde_json::json!({"foo": "bar", "baz": 123});
1358        let ctx = RequestContext::new(RequestId::Number(1));
1359
1360        let result = RawArgs::from_tool_request(&ctx, &(), &args);
1361        assert!(result.is_ok());
1362        let RawArgs(extracted) = result.unwrap();
1363        assert_eq!(extracted["foo"], "bar");
1364        assert_eq!(extracted["baz"], 123);
1365    }
1366
1367    #[test]
1368    fn test_extension_extraction() {
1369        use crate::context::Extensions;
1370
1371        #[derive(Clone, Debug, PartialEq)]
1372        struct DatabasePool {
1373            url: String,
1374        }
1375
1376        let args = serde_json::json!({});
1377
1378        // Create extensions with a value
1379        let mut extensions = Extensions::new();
1380        extensions.insert(Arc::new(DatabasePool {
1381            url: "postgres://localhost".to_string(),
1382        }));
1383
1384        // Create context with extensions
1385        let ctx = RequestContext::new(RequestId::Number(1)).with_extensions(Arc::new(extensions));
1386
1387        // Extract the extension
1388        let result = Extension::<Arc<DatabasePool>>::from_tool_request(&ctx, &(), &args);
1389        assert!(result.is_ok());
1390        let Extension(pool) = result.unwrap();
1391        assert_eq!(pool.url, "postgres://localhost");
1392    }
1393
1394    #[test]
1395    fn test_extension_extraction_missing() {
1396        #[derive(Clone, Debug)]
1397        struct NotPresent;
1398
1399        let args = serde_json::json!({});
1400        let ctx = RequestContext::new(RequestId::Number(1));
1401
1402        // Try to extract something that's not in extensions
1403        let result = Extension::<NotPresent>::from_tool_request(&ctx, &(), &args);
1404        assert!(result.is_err());
1405        let rejection = result.unwrap_err();
1406        // ExtensionRejection contains the type name
1407        assert!(rejection.type_name().contains("NotPresent"));
1408    }
1409
1410    #[tokio::test]
1411    async fn test_single_extractor_handler() {
1412        let handler = |Json(input): Json<TestInput>| async move {
1413            Ok(CallToolResult::text(format!(
1414                "{}: {}",
1415                input.name, input.count
1416            )))
1417        };
1418
1419        let ctx = RequestContext::new(RequestId::Number(1));
1420        let args = serde_json::json!({"name": "test", "count": 5});
1421
1422        // Use explicit trait to avoid ambiguity
1423        let result: Result<CallToolResult> =
1424            ExtractorHandler::<(), (Json<TestInput>,)>::call(handler, ctx, (), args).await;
1425        assert!(result.is_ok());
1426    }
1427
1428    #[tokio::test]
1429    async fn test_two_extractor_handler() {
1430        let handler = |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1431            Ok(CallToolResult::text(format!(
1432                "{}: {} - {}",
1433                state, input.name, input.count
1434            )))
1435        };
1436
1437        let ctx = RequestContext::new(RequestId::Number(1));
1438        let state = Arc::new("prefix".to_string());
1439        let args = serde_json::json!({"name": "test", "count": 5});
1440
1441        // Use explicit trait to avoid ambiguity
1442        let result: Result<CallToolResult> = ExtractorHandler::<
1443            Arc<String>,
1444            (State<Arc<String>>, Json<TestInput>),
1445        >::call(handler, ctx, state, args)
1446        .await;
1447        assert!(result.is_ok());
1448    }
1449
1450    #[tokio::test]
1451    async fn test_three_extractor_handler() {
1452        let handler = |State(state): State<Arc<String>>,
1453                       ctx: Context,
1454                       Json(input): Json<TestInput>| async move {
1455            // Verify we can access all extractors
1456            assert!(!ctx.is_cancelled());
1457            Ok(CallToolResult::text(format!(
1458                "{}: {} - {}",
1459                state, input.name, input.count
1460            )))
1461        };
1462
1463        let ctx = RequestContext::new(RequestId::Number(1));
1464        let state = Arc::new("prefix".to_string());
1465        let args = serde_json::json!({"name": "test", "count": 5});
1466
1467        // Use explicit trait to avoid ambiguity
1468        let result: Result<CallToolResult> = ExtractorHandler::<
1469            Arc<String>,
1470            (State<Arc<String>>, Context, Json<TestInput>),
1471        >::call(handler, ctx, state, args)
1472        .await;
1473        assert!(result.is_ok());
1474    }
1475
1476    #[test]
1477    fn test_json_schema_generation() {
1478        let schema = Json::<TestInput>::schema();
1479        assert!(schema.is_some());
1480        let schema = schema.unwrap();
1481        assert!(schema.get("properties").is_some());
1482    }
1483
1484    #[test]
1485    fn test_rejection_into_error() {
1486        let rejection = Rejection::new("test error");
1487        let error: Error = rejection.into();
1488        assert!(error.to_string().contains("test error"));
1489    }
1490
1491    #[test]
1492    fn test_json_rejection() {
1493        // Test basic JsonRejection
1494        let rejection = JsonRejection::new("missing field `name`");
1495        assert_eq!(rejection.message(), "missing field `name`");
1496        assert!(rejection.path().is_none());
1497        assert!(rejection.to_string().contains("Invalid input"));
1498
1499        // Test JsonRejection with path
1500        let rejection = JsonRejection::with_path("expected string", "users[0].name");
1501        assert_eq!(rejection.message(), "expected string");
1502        assert_eq!(rejection.path(), Some("users[0].name"));
1503        assert!(rejection.to_string().contains("users[0].name"));
1504
1505        // Test conversion to Error
1506        let error: Error = rejection.into();
1507        assert!(error.to_string().contains("users[0].name"));
1508    }
1509
1510    #[test]
1511    fn test_json_rejection_from_serde_error() {
1512        // Create a real serde error by deserializing invalid JSON
1513        #[derive(Debug, serde::Deserialize)]
1514        struct TestStruct {
1515            #[allow(dead_code)]
1516            name: String,
1517        }
1518
1519        let result: std::result::Result<TestStruct, _> =
1520            serde_json::from_value(serde_json::json!({"count": 42}));
1521        assert!(result.is_err());
1522
1523        let rejection: JsonRejection = result.unwrap_err().into();
1524        assert!(rejection.message().contains("name"));
1525    }
1526
1527    #[test]
1528    fn test_extension_rejection() {
1529        // Test ExtensionRejection
1530        let rejection = ExtensionRejection::not_found::<String>();
1531        assert!(rejection.type_name().contains("String"));
1532        assert!(rejection.to_string().contains("not found"));
1533        assert!(rejection.to_string().contains("with_state"));
1534
1535        // Test conversion to Error
1536        let error: Error = rejection.into();
1537        assert!(error.to_string().contains("not found"));
1538    }
1539
1540    #[tokio::test]
1541    async fn test_tool_builder_extractor_handler() {
1542        use crate::ToolBuilder;
1543
1544        let state = Arc::new("shared-state".to_string());
1545
1546        let tool =
1547            ToolBuilder::new("test_extractor")
1548                .description("Test extractor handler")
1549                .extractor_handler(
1550                    state,
1551                    |State(state): State<Arc<String>>,
1552                     ctx: Context,
1553                     Json(input): Json<TestInput>| async move {
1554                        assert!(!ctx.is_cancelled());
1555                        Ok(CallToolResult::text(format!(
1556                            "{}: {} - {}",
1557                            state, input.name, input.count
1558                        )))
1559                    },
1560                )
1561                .build();
1562
1563        assert_eq!(tool.name, "test_extractor");
1564        assert_eq!(tool.description.as_deref(), Some("Test extractor handler"));
1565
1566        // Test calling the tool
1567        let result = tool
1568            .call(serde_json::json!({"name": "test", "count": 42}))
1569            .await;
1570        assert!(!result.is_error);
1571    }
1572
1573    #[tokio::test]
1574    async fn test_tool_builder_extractor_handler_typed() {
1575        use crate::ToolBuilder;
1576
1577        let state = Arc::new("typed-state".to_string());
1578
1579        let tool = ToolBuilder::new("test_typed")
1580            .description("Test typed extractor handler")
1581            .extractor_handler_typed::<_, _, _, TestInput>(
1582                state,
1583                |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1584                    Ok(CallToolResult::text(format!(
1585                        "{}: {} - {}",
1586                        state, input.name, input.count
1587                    )))
1588                },
1589            )
1590            .build();
1591
1592        assert_eq!(tool.name, "test_typed");
1593
1594        // Verify schema is properly generated from TestInput
1595        let def = tool.definition();
1596        let schema = def.input_schema;
1597        assert!(schema.get("properties").is_some());
1598
1599        // Test calling the tool
1600        let result = tool
1601            .call(serde_json::json!({"name": "world", "count": 99}))
1602            .await;
1603        assert!(!result.is_error);
1604    }
1605
1606    #[tokio::test]
1607    async fn test_extractor_handler_auto_schema() {
1608        use crate::ToolBuilder;
1609
1610        let state = Arc::new("auto-schema".to_string());
1611
1612        // extractor_handler (not _typed) should auto-detect Json<TestInput> schema
1613        let tool = ToolBuilder::new("test_auto_schema")
1614            .description("Test auto schema detection")
1615            .extractor_handler(
1616                state,
1617                |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1618                    Ok(CallToolResult::text(format!(
1619                        "{}: {} - {}",
1620                        state, input.name, input.count
1621                    )))
1622                },
1623            )
1624            .build();
1625
1626        // Verify schema is properly generated from TestInput (not generic object)
1627        let def = tool.definition();
1628        let schema = def.input_schema;
1629        assert!(
1630            schema.get("properties").is_some(),
1631            "Schema should have properties from TestInput, got: {}",
1632            schema
1633        );
1634        let props = schema.get("properties").unwrap();
1635        assert!(
1636            props.get("name").is_some(),
1637            "Schema should have 'name' property"
1638        );
1639        assert!(
1640            props.get("count").is_some(),
1641            "Schema should have 'count' property"
1642        );
1643
1644        // Test calling the tool
1645        let result = tool
1646            .call(serde_json::json!({"name": "world", "count": 99}))
1647            .await;
1648        assert!(!result.is_error);
1649    }
1650
1651    #[test]
1652    fn test_extractor_handler_no_json_fallback() {
1653        use crate::ToolBuilder;
1654
1655        // extractor_handler without Json<T> should fall back to generic schema
1656        let tool = ToolBuilder::new("test_no_json")
1657            .description("Test no json fallback")
1658            .extractor_handler((), |RawArgs(args): RawArgs| async move {
1659                Ok(CallToolResult::json(args))
1660            })
1661            .build();
1662
1663        let def = tool.definition();
1664        let schema = def.input_schema;
1665        assert_eq!(
1666            schema.get("type").and_then(|v| v.as_str()),
1667            Some("object"),
1668            "Schema should be generic object"
1669        );
1670        assert_eq!(
1671            schema.get("additionalProperties").and_then(|v| v.as_bool()),
1672            Some(true),
1673            "Schema should allow additional properties"
1674        );
1675        // Should NOT have specific properties
1676        assert!(
1677            schema.get("properties").is_none(),
1678            "Generic schema should not have specific properties"
1679        );
1680    }
1681
1682    #[tokio::test]
1683    async fn test_extractor_handler_with_layer() {
1684        use crate::ToolBuilder;
1685        use std::time::Duration;
1686        use tower::timeout::TimeoutLayer;
1687
1688        let state = Arc::new("layered".to_string());
1689
1690        let tool = ToolBuilder::new("test_extractor_layer")
1691            .description("Test extractor handler with layer")
1692            .extractor_handler(
1693                state,
1694                |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1695                    Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
1696                },
1697            )
1698            .layer(TimeoutLayer::new(Duration::from_secs(5)))
1699            .build();
1700
1701        // Verify the tool works
1702        let result = tool
1703            .call(serde_json::json!({"name": "test", "count": 1}))
1704            .await;
1705        assert!(!result.is_error);
1706        assert_eq!(result.first_text().unwrap(), "layered: test");
1707
1708        // Verify schema is still properly generated
1709        let def = tool.definition();
1710        let schema = def.input_schema;
1711        assert!(
1712            schema.get("properties").is_some(),
1713            "Schema should have properties even with layer"
1714        );
1715    }
1716
1717    #[tokio::test]
1718    async fn test_extractor_handler_with_timeout_layer() {
1719        use crate::ToolBuilder;
1720        use std::time::Duration;
1721        use tower::timeout::TimeoutLayer;
1722
1723        let tool = ToolBuilder::new("test_extractor_timeout")
1724            .description("Test extractor handler timeout")
1725            .extractor_handler((), |Json(input): Json<TestInput>| async move {
1726                tokio::time::sleep(Duration::from_millis(200)).await;
1727                Ok(CallToolResult::text(input.name.to_string()))
1728            })
1729            .layer(TimeoutLayer::new(Duration::from_millis(50)))
1730            .build();
1731
1732        // Should timeout
1733        let result = tool
1734            .call(serde_json::json!({"name": "slow", "count": 1}))
1735            .await;
1736        assert!(result.is_error);
1737        let msg = result.first_text().unwrap().to_lowercase();
1738        assert!(
1739            msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
1740            "Expected timeout error, got: {}",
1741            msg
1742        );
1743    }
1744
1745    #[tokio::test]
1746    async fn test_extractor_handler_with_multiple_layers() {
1747        use crate::ToolBuilder;
1748        use std::time::Duration;
1749        use tower::limit::ConcurrencyLimitLayer;
1750        use tower::timeout::TimeoutLayer;
1751
1752        let state = Arc::new("multi".to_string());
1753
1754        let tool = ToolBuilder::new("test_multi_layer")
1755            .description("Test multiple layers")
1756            .extractor_handler(
1757                state,
1758                |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1759                    Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
1760                },
1761            )
1762            .layer(TimeoutLayer::new(Duration::from_secs(5)))
1763            .layer(ConcurrencyLimitLayer::new(10))
1764            .build();
1765
1766        let result = tool
1767            .call(serde_json::json!({"name": "test", "count": 1}))
1768            .await;
1769        assert!(!result.is_error);
1770        assert_eq!(result.first_text().unwrap(), "multi: test");
1771    }
1772}