Skip to main content

tower_mcp/
extract.rs

1//! Extractor pattern for tool handlers
2//!
3//! This module provides an axum-inspired extractor pattern that makes state and context
4//! injection more declarative, reducing the combinatorial explosion of handler variants.
5//!
6//! # Overview
7//!
8//! Extractors implement [`FromToolRequest`], which extracts data from the tool request
9//! (context, state, and arguments). Multiple extractors can be combined in handler
10//! function parameters.
11//!
12//! # Built-in Extractors
13//!
14//! - [`Json<T>`] - Extract typed input from args (deserializes JSON)
15//! - [`State<T>`] - Extract shared state from per-tool state (cloned for each request)
16//! - [`Extension<T>`] - Extract data from router extensions (via `router.with_state()`)
17//! - [`Context`] - Extract the [`RequestContext`] for progress, cancellation, etc.
18//! - [`RawArgs`] - Extract raw `serde_json::Value` arguments
19//!
20//! ## State vs Extension
21//!
22//! - Use **`State<T>`** when state is passed directly to `extractor_handler()` (per-tool state)
23//! - Use **`Extension<T>`** when state is set via `McpRouter::with_state()` (router-level state)
24//!
25//! # Example
26//!
27//! ```rust
28//! use std::sync::Arc;
29//! use tower_mcp::{ToolBuilder, CallToolResult};
30//! use tower_mcp::extract::{Json, State, Context};
31//! use schemars::JsonSchema;
32//! use serde::Deserialize;
33//!
34//! #[derive(Clone)]
35//! struct AppState {
36//!     db_url: String,
37//! }
38//!
39//! #[derive(Debug, Deserialize, JsonSchema)]
40//! struct QueryInput {
41//!     query: String,
42//! }
43//!
44//! let state = Arc::new(AppState { db_url: "postgres://...".to_string() });
45//!
46//! let tool = ToolBuilder::new("search")
47//!     .description("Search the database")
48//!     .extractor_handler(state, |
49//!         State(db): State<Arc<AppState>>,
50//!         ctx: Context,
51//!         Json(input): Json<QueryInput>,
52//!     | async move {
53//!         // Check cancellation
54//!         if ctx.is_cancelled() {
55//!             return Ok(CallToolResult::error("Cancelled"));
56//!         }
57//!         // Report progress
58//!         ctx.report_progress(0.5, Some(1.0), Some("Searching...")).await;
59//!         // Use state
60//!         Ok(CallToolResult::text(format!("Searched {} with query: {}", db.db_url, input.query)))
61//!     })
62//!     .build();
63//! ```
64//!
65//! # Extractor Order
66//!
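//! The order of extractors in the function signature doesn't matter for what each one
//! receives: every extractor independently extracts its data from the request.
//! Extractors do run left to right, however, so when more than one would fail, the
//! leftmost failure is the one reported, and when more than one extractor supplies an
//! input schema, the rightmost one is used.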
69//!
70//! # Error Handling
71//!
72//! If an extractor fails (e.g., JSON deserialization fails), the handler returns
73//! a `CallToolResult::error()` with the rejection message.
74
75use std::future::Future;
76use std::marker::PhantomData;
77use std::ops::Deref;
78use std::pin::Pin;
79
80use schemars::JsonSchema;
81use serde::de::DeserializeOwned;
82use serde_json::Value;
83
84use crate::context::RequestContext;
85use crate::error::{Error, Result};
86use crate::protocol::CallToolResult;
87
88// =============================================================================
89// Rejection Types
90// =============================================================================
91
/// A general-purpose rejection that carries a human-readable message.
///
/// Custom extractors can return this when they have nothing more specific to
/// report. Extractors with richer failure information use dedicated types such
/// as [`JsonRejection`] or [`ExtensionRejection`].
#[derive(Debug, Clone)]
pub struct Rejection {
    message: String,
}

impl Rejection {
    /// Builds a rejection from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }

    /// Returns the message this rejection was created with.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl std::fmt::Display for Rejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The message is rendered verbatim, without any prefix.
        f.write_str(&self.message)
    }
}

impl std::error::Error for Rejection {}
123
124impl From<Rejection> for Error {
125    fn from(rejection: Rejection) -> Self {
126        Error::tool(rejection.message)
127    }
128}
129
/// Rejection produced when the tool arguments cannot be deserialized.
///
/// Besides the error message it can carry the location of the offending field.
/// When present, that location is included in the `Display` output.
///
/// # Example
///
/// ```rust
/// use tower_mcp::extract::JsonRejection;
///
/// let rejection = JsonRejection::new("missing field `name`");
/// assert!(rejection.message().contains("name"));
/// ```
#[derive(Debug, Clone)]
pub struct JsonRejection {
    message: String,
    /// Location of the failing field, when known (e.g. "users[0].name")
    path: Option<String>,
}

impl JsonRejection {
    /// Creates a rejection that carries only a message and no field path.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            path: None,
            message,
        }
    }

    /// Creates a rejection that also records where the failing field lives.
    pub fn with_path(message: impl Into<String>, path: impl Into<String>) -> Self {
        let message = message.into();
        let path = Some(path.into());
        Self { message, path }
    }

    /// Returns the underlying error message (without any path prefix).
    pub fn message(&self) -> &str {
        self.message.as_str()
    }

    /// Returns the path of the failing field, if one was recorded.
    pub fn path(&self) -> Option<&str> {
        self.path.as_ref().map(String::as_str)
    }
}

impl std::fmt::Display for JsonRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.path.as_deref() {
            Some(path) => write!(f, "Invalid input at `{}`: {}", path, self.message),
            None => write!(f, "Invalid input: {}", self.message),
        }
    }
}

impl std::error::Error for JsonRejection {}
189
190impl From<JsonRejection> for Error {
191    fn from(rejection: JsonRejection) -> Self {
192        Error::tool(rejection.to_string())
193    }
194}
195
196impl From<serde_json::Error> for JsonRejection {
197    fn from(err: serde_json::Error) -> Self {
198        // Try to extract path information from serde error
199        let path = if err.is_data() {
200            // serde_json provides line/column but not field path in the error itself
201            // The path is embedded in the message for some error types
202            None
203        } else {
204            None
205        };
206
207        Self {
208            message: err.to_string(),
209            path,
210        }
211    }
212}
213
/// Rejection produced when a requested extension type is missing.
///
/// The [`Extension`] extractor returns this when the router's extensions hold
/// no value of the requested type. It records the missing type's name.
///
/// # Example
///
/// ```rust
/// use tower_mcp::extract::ExtensionRejection;
///
/// let rejection = ExtensionRejection::not_found::<String>();
/// assert!(rejection.type_name().contains("String"));
/// ```
#[derive(Debug, Clone)]
pub struct ExtensionRejection {
    type_name: &'static str,
}

impl ExtensionRejection {
    /// Creates a rejection naming `T` as the extension that could not be found.
    pub fn not_found<T>() -> Self {
        let type_name = std::any::type_name::<T>();
        Self { type_name }
    }

    /// Returns the name of the missing type, as reported by `std::any::type_name`.
    pub fn type_name(&self) -> &'static str {
        let Self { type_name } = *self;
        type_name
    }
}

impl std::fmt::Display for ExtensionRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Include a hint about the two ways extensions are usually registered.
        write!(
            f,
            "Extension of type `{}` not found. Did you call `router.with_state()` or `router.with_extension()`?",
            self.type_name()
        )
    }
}

impl std::error::Error for ExtensionRejection {}
257
258impl From<ExtensionRejection> for Error {
259    fn from(rejection: ExtensionRejection) -> Self {
260        Error::tool(rejection.to_string())
261    }
262}
263
/// Trait for extracting data from a tool request.
///
/// Implement this trait to create custom extractors that can be used
/// in `extractor_handler` functions.
///
/// Extractors receive only shared references to the context, state, and
/// arguments, so every extractor in one handler sees the same request data.
/// A failed extraction is converted into the crate's `Error` type through
/// `Into<Error>` on the rejection. It aborts the call before the handler body
/// runs.
///
/// To be usable as a parameter of an [`ExtractorHandler`], a type must also
/// implement [`HasSchema`]. Return `None` from [`HasSchema::schema`] unless the
/// extractor should supply the tool's input schema.
///
/// # Type Parameters
///
/// - `S` - The state type. Defaults to `()` for extractors that don't need state.
///
/// # Example
///
/// ```rust
/// use tower_mcp::extract::{FromToolRequest, Rejection};
/// use tower_mcp::RequestContext;
/// use serde_json::Value;
///
/// struct RequestId(String);
///
/// impl<S> FromToolRequest<S> for RequestId {
///     type Rejection = Rejection;
///
///     fn from_tool_request(
///         ctx: &RequestContext,
///         _state: &S,
///         _args: &Value,
///     ) -> Result<Self, Self::Rejection> {
///         Ok(RequestId(format!("{:?}", ctx.request_id())))
///     }
/// }
/// ```
pub trait FromToolRequest<S = ()>: Sized {
    /// The rejection type returned when extraction fails.
    ///
    /// It must convert into the crate's `Error`. That conversion is how a
    /// failure is reported to the caller.
    type Rejection: Into<Error>;

    /// Extract this type from the tool request.
    ///
    /// # Arguments
    ///
    /// * `ctx` - The request context with progress, cancellation, etc.
    /// * `state` - The shared state passed to the handler
    /// * `args` - The raw JSON arguments to the tool
    ///
    /// # Errors
    ///
    /// Returns `Self::Rejection` when the data cannot be extracted.
    fn from_tool_request(
        ctx: &RequestContext,
        state: &S,
        args: &Value,
    ) -> std::result::Result<Self, Self::Rejection>;
}
311
312// =============================================================================
313// Built-in Extractors
314// =============================================================================
315
/// Extract and deserialize JSON arguments into a typed struct.
///
/// The tool's JSON arguments are deserialized into `T`. To be used as a handler
/// parameter, `T` must implement [`serde::de::DeserializeOwned`]. It must also
/// implement [`schemars::JsonSchema`], which supplies the tool's input schema.
///
/// # Example
///
/// ```rust
/// use tower_mcp::extract::Json;
/// use schemars::JsonSchema;
/// use serde::Deserialize;
///
/// #[derive(Debug, Deserialize, JsonSchema)]
/// struct MyInput {
///     name: String,
///     count: i32,
/// }
///
/// // In an extractor handler:
/// // |Json(input): Json<MyInput>| async move { ... }
/// ```
///
/// # Rejection
///
/// Deserialization failures produce a [`JsonRejection`] describing what went
/// wrong (and, when known, which field was at fault).
#[derive(Debug, Clone, Copy)]
pub struct Json<T>(pub T);

impl<T> Deref for Json<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Json(inner) = self;
        inner
    }
}
352
353impl<S, T> FromToolRequest<S> for Json<T>
354where
355    T: DeserializeOwned,
356{
357    type Rejection = JsonRejection;
358
359    fn from_tool_request(
360        _ctx: &RequestContext,
361        _state: &S,
362        args: &Value,
363    ) -> std::result::Result<Self, Self::Rejection> {
364        serde_json::from_value(args.clone())
365            .map(Json)
366            .map_err(JsonRejection::from)
367    }
368}
369
/// Extract shared state.
///
/// Each request receives its own clone of the state that was handed to
/// `extractor_handler`. `T` must be that same state type.
///
/// # Example
///
/// ```rust
/// use std::sync::Arc;
/// use tower_mcp::extract::State;
///
/// #[derive(Clone)]
/// struct AppState {
///     db_url: String,
/// }
///
/// // In an extractor handler:
/// // |State(state): State<Arc<AppState>>| async move { ... }
/// ```
///
/// # Note
///
/// The state is cloned on every call. If it is expensive to clone, wrap it in
/// `Arc` before passing it to `extractor_handler`.
#[derive(Debug, Clone, Copy)]
pub struct State<T>(pub T);

impl<T> Deref for State<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let State(inner) = self;
        inner
    }
}
404
405impl<S: Clone> FromToolRequest<S> for State<S> {
406    type Rejection = Rejection;
407
408    fn from_tool_request(
409        _ctx: &RequestContext,
410        state: &S,
411        _args: &Value,
412    ) -> std::result::Result<Self, Self::Rejection> {
413        Ok(State(state.clone()))
414    }
415}
416
417/// Extract the request context.
418///
419/// This extractor provides access to the [`RequestContext`], which contains:
420/// - Progress reporting via `report_progress()`
421/// - Cancellation checking via `is_cancelled()`
422/// - Sampling capabilities via `sample()`
423/// - Elicitation capabilities via `elicit_form()` and `elicit_url()`
424/// - Log sending via `send_log()`
425///
426/// # Example
427///
428/// ```rust
429/// use tower_mcp::extract::Context;
430///
431/// // In an extractor handler:
432/// // |ctx: Context| async move {
433/// //     ctx.report_progress(0.5, Some(1.0), Some("Working...")).await;
434/// //     // ...
435/// // }
436/// ```
437#[derive(Debug, Clone)]
438pub struct Context(RequestContext);
439
440impl Context {
441    /// Get the inner RequestContext
442    pub fn into_inner(self) -> RequestContext {
443        self.0
444    }
445}
446
447impl Deref for Context {
448    type Target = RequestContext;
449
450    fn deref(&self) -> &Self::Target {
451        &self.0
452    }
453}
454
455impl<S> FromToolRequest<S> for Context {
456    type Rejection = Rejection;
457
458    fn from_tool_request(
459        ctx: &RequestContext,
460        _state: &S,
461        _args: &Value,
462    ) -> std::result::Result<Self, Self::Rejection> {
463        Ok(Context(ctx.clone()))
464    }
465}
466
467/// Extract raw JSON arguments.
468///
469/// This extractor provides the raw `serde_json::Value` arguments without
470/// any deserialization. Useful when you need full control over argument
471/// parsing or when the schema is dynamic.
472///
473/// # Example
474///
475/// ```rust
476/// use tower_mcp::extract::RawArgs;
477///
478/// // In an extractor handler:
479/// // |RawArgs(args): RawArgs| async move {
480/// //     // args is serde_json::Value
481/// //     if let Some(name) = args.get("name") { ... }
482/// // }
483/// ```
484#[derive(Debug, Clone)]
485pub struct RawArgs(pub Value);
486
487impl Deref for RawArgs {
488    type Target = Value;
489
490    fn deref(&self) -> &Self::Target {
491        &self.0
492    }
493}
494
495impl<S> FromToolRequest<S> for RawArgs {
496    type Rejection = Rejection;
497
498    fn from_tool_request(
499        _ctx: &RequestContext,
500        _state: &S,
501        args: &Value,
502    ) -> std::result::Result<Self, Self::Rejection> {
503        Ok(RawArgs(args.clone()))
504    }
505}
506
/// Extract typed data from router extensions.
///
/// Looks up a value of type `T` among the request context's extensions. Values
/// get there via [`crate::McpRouter::with_state()`] or
/// [`crate::McpRouter::with_extension()`], or through middleware that inserts
/// them. The handler receives a clone of the stored value.
///
/// # Example
///
/// ```rust
/// use std::sync::Arc;
/// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
/// use tower_mcp::extract::{Extension, Json};
/// use schemars::JsonSchema;
/// use serde::Deserialize;
///
/// #[derive(Clone)]
/// struct DatabasePool {
///     url: String,
/// }
///
/// #[derive(Deserialize, JsonSchema)]
/// struct QueryInput {
///     sql: String,
/// }
///
/// let pool = Arc::new(DatabasePool { url: "postgres://...".into() });
///
/// let tool = ToolBuilder::new("query")
///     .description("Run a query")
///     .extractor_handler(
///         (),
///         |Extension(db): Extension<Arc<DatabasePool>>, Json(input): Json<QueryInput>| async move {
///             Ok(CallToolResult::text(format!("Query on {}: {}", db.url, input.sql)))
///         },
///     )
///     .build();
///
/// let router = McpRouter::new()
///     .with_state(pool)
///     .tool(tool);
/// ```
///
/// # Rejection
///
/// When no value of type `T` is present, extraction fails with an
/// [`ExtensionRejection`] that names the missing type.
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);

impl<T> Deref for Extension<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Extension(inner) = self;
        inner
    }
}
563
564impl<S, T> FromToolRequest<S> for Extension<T>
565where
566    T: Clone + Send + Sync + 'static,
567{
568    type Rejection = ExtensionRejection;
569
570    fn from_tool_request(
571        ctx: &RequestContext,
572        _state: &S,
573        _args: &Value,
574    ) -> std::result::Result<Self, Self::Rejection> {
575        ctx.extension::<T>()
576            .cloned()
577            .map(Extension)
578            .ok_or_else(ExtensionRejection::not_found::<T>)
579    }
580}
581
582// =============================================================================
583// Handler Trait
584// =============================================================================
585
/// A handler that uses extractors.
///
/// This trait is implemented for functions that take extractors as arguments.
/// You don't need to implement this trait directly; it's automatically
/// implemented for compatible async functions.
///
/// Implementations are provided for handlers taking one to five extractor
/// parameters. Each parameter must implement both [`FromToolRequest`] and
/// [`HasSchema`]. Extractors run in parameter order, and the first rejection is
/// returned as an error without invoking the handler.
pub trait ExtractorHandler<S, T>: Clone + Send + Sync + 'static {
    /// The future returned by the handler.
    type Future: Future<Output = Result<CallToolResult>> + Send;

    /// Call the handler with extracted values.
    ///
    /// Takes both the handler and the state by value.
    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;

    /// Get the input schema for this handler.
    ///
    /// The schema comes from the last extractor parameter whose
    /// [`HasSchema::schema`] returns `Some`; among the built-in extractors only
    /// `Json<T>` does. If no parameter provides one, a permissive schema is
    /// returned instead: `{"type": "object", "additionalProperties": true}`.
    fn input_schema() -> Value;
}
603
604// Implementation for single extractor
605impl<S, F, Fut, T1> ExtractorHandler<S, (T1,)> for F
606where
607    S: Clone + Send + Sync + 'static,
608    F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
609    Fut: Future<Output = Result<CallToolResult>> + Send,
610    T1: FromToolRequest<S> + HasSchema + Send,
611{
612    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
613
614    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
615        Box::pin(async move {
616            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
617            self(t1).await
618        })
619    }
620
621    fn input_schema() -> Value {
622        if let Some(schema) = T1::schema() {
623            return schema;
624        }
625        serde_json::json!({
626            "type": "object",
627            "additionalProperties": true
628        })
629    }
630}
631
632// Implementation for two extractors
633impl<S, F, Fut, T1, T2> ExtractorHandler<S, (T1, T2)> for F
634where
635    S: Clone + Send + Sync + 'static,
636    F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
637    Fut: Future<Output = Result<CallToolResult>> + Send,
638    T1: FromToolRequest<S> + HasSchema + Send,
639    T2: FromToolRequest<S> + HasSchema + Send,
640{
641    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
642
643    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
644        Box::pin(async move {
645            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
646            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
647            self(t1, t2).await
648        })
649    }
650
651    fn input_schema() -> Value {
652        if let Some(schema) = T2::schema() {
653            return schema;
654        }
655        if let Some(schema) = T1::schema() {
656            return schema;
657        }
658        serde_json::json!({
659            "type": "object",
660            "additionalProperties": true
661        })
662    }
663}
664
665// Implementation for three extractors
666impl<S, F, Fut, T1, T2, T3> ExtractorHandler<S, (T1, T2, T3)> for F
667where
668    S: Clone + Send + Sync + 'static,
669    F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
670    Fut: Future<Output = Result<CallToolResult>> + Send,
671    T1: FromToolRequest<S> + HasSchema + Send,
672    T2: FromToolRequest<S> + HasSchema + Send,
673    T3: FromToolRequest<S> + HasSchema + Send,
674{
675    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
676
677    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
678        Box::pin(async move {
679            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
680            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
681            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
682            self(t1, t2, t3).await
683        })
684    }
685
686    fn input_schema() -> Value {
687        if let Some(schema) = T3::schema() {
688            return schema;
689        }
690        if let Some(schema) = T2::schema() {
691            return schema;
692        }
693        if let Some(schema) = T1::schema() {
694            return schema;
695        }
696        serde_json::json!({
697            "type": "object",
698            "additionalProperties": true
699        })
700    }
701}
702
703// Implementation for four extractors
704impl<S, F, Fut, T1, T2, T3, T4> ExtractorHandler<S, (T1, T2, T3, T4)> for F
705where
706    S: Clone + Send + Sync + 'static,
707    F: Fn(T1, T2, T3, T4) -> Fut + Clone + Send + Sync + 'static,
708    Fut: Future<Output = Result<CallToolResult>> + Send,
709    T1: FromToolRequest<S> + HasSchema + Send,
710    T2: FromToolRequest<S> + HasSchema + Send,
711    T3: FromToolRequest<S> + HasSchema + Send,
712    T4: FromToolRequest<S> + HasSchema + Send,
713{
714    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
715
716    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
717        Box::pin(async move {
718            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
719            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
720            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
721            let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
722            self(t1, t2, t3, t4).await
723        })
724    }
725
726    fn input_schema() -> Value {
727        if let Some(schema) = T4::schema() {
728            return schema;
729        }
730        if let Some(schema) = T3::schema() {
731            return schema;
732        }
733        if let Some(schema) = T2::schema() {
734            return schema;
735        }
736        if let Some(schema) = T1::schema() {
737            return schema;
738        }
739        serde_json::json!({
740            "type": "object",
741            "additionalProperties": true
742        })
743    }
744}
745
746// Implementation for five extractors
747impl<S, F, Fut, T1, T2, T3, T4, T5> ExtractorHandler<S, (T1, T2, T3, T4, T5)> for F
748where
749    S: Clone + Send + Sync + 'static,
750    F: Fn(T1, T2, T3, T4, T5) -> Fut + Clone + Send + Sync + 'static,
751    Fut: Future<Output = Result<CallToolResult>> + Send,
752    T1: FromToolRequest<S> + HasSchema + Send,
753    T2: FromToolRequest<S> + HasSchema + Send,
754    T3: FromToolRequest<S> + HasSchema + Send,
755    T4: FromToolRequest<S> + HasSchema + Send,
756    T5: FromToolRequest<S> + HasSchema + Send,
757{
758    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
759
760    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
761        Box::pin(async move {
762            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
763            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
764            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
765            let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
766            let t5 = T5::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
767            self(t1, t2, t3, t4, t5).await
768        })
769    }
770
771    fn input_schema() -> Value {
772        if let Some(schema) = T5::schema() {
773            return schema;
774        }
775        if let Some(schema) = T4::schema() {
776            return schema;
777        }
778        if let Some(schema) = T3::schema() {
779            return schema;
780        }
781        if let Some(schema) = T2::schema() {
782            return schema;
783        }
784        if let Some(schema) = T1::schema() {
785            return schema;
786        }
787        serde_json::json!({
788            "type": "object",
789            "additionalProperties": true
790        })
791    }
792}
793
794// =============================================================================
795// Schema Extraction Helper
796// =============================================================================
797
798/// Helper trait to get schema from `Json<T>` extractor
799pub trait HasSchema {
800    /// Returns the JSON Schema for this type, if available.
801    fn schema() -> Option<Value>;
802}
803
804impl<T: JsonSchema> HasSchema for Json<T> {
805    fn schema() -> Option<Value> {
806        let schema = schemars::schema_for!(T);
807        serde_json::to_value(schema)
808            .ok()
809            .map(crate::tool::ensure_object_schema)
810    }
811}
812
813// Default impl for non-Json extractors
814impl HasSchema for Context {
815    fn schema() -> Option<Value> {
816        None
817    }
818}
819
820impl HasSchema for RawArgs {
821    fn schema() -> Option<Value> {
822        None
823    }
824}
825
826impl<T> HasSchema for State<T> {
827    fn schema() -> Option<Value> {
828        None
829    }
830}
831
832impl<T> HasSchema for Extension<T> {
833    fn schema() -> Option<Value> {
834        None
835    }
836}
837
838// =============================================================================
839// Typed Extractor Handler
840// =============================================================================
841
/// A handler that uses extractors with typed JSON input.
///
/// This trait is similar to [`ExtractorHandler`] but provides proper JSON
/// schema generation for the input type when `Json<T>` is used.
///
/// `I` is the type deserialized by the handler's trailing `Json<I>` parameter.
/// Implementations exist for handlers with zero to three other extractors
/// before it. Unlike [`ExtractorHandler`], this trait has no `input_schema`
/// method. Presumably callers derive the schema from `I: JsonSchema`, but that
/// code is not visible here, so confirm before relying on it.
#[deprecated(
    since = "0.8.0",
    note = "Use `ExtractorHandler` instead -- `extractor_handler` auto-detects JSON schema from `Json<T>` extractors"
)]
pub trait TypedExtractorHandler<S, T, I>: Clone + Send + Sync + 'static
where
    I: JsonSchema,
{
    /// The future returned by the handler.
    type Future: Future<Output = Result<CallToolResult>> + Send;

    /// Call the handler with extracted values.
    ///
    /// Extractors run in parameter order; the first rejection is returned as
    /// an error without invoking the handler.
    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
}
860
861// Single extractor with Json<T>
862#[allow(deprecated)]
863impl<S, F, Fut, T> TypedExtractorHandler<S, (Json<T>,), T> for F
864where
865    S: Clone + Send + Sync + 'static,
866    F: Fn(Json<T>) -> Fut + Clone + Send + Sync + 'static,
867    Fut: Future<Output = Result<CallToolResult>> + Send,
868    T: DeserializeOwned + JsonSchema + Send,
869{
870    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
871
872    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
873        Box::pin(async move {
874            let t1 =
875                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
876            self(t1).await
877        })
878    }
879}
880
881// Two extractors ending with Json<T>
882#[allow(deprecated)]
883impl<S, F, Fut, T1, T> TypedExtractorHandler<S, (T1, Json<T>), T> for F
884where
885    S: Clone + Send + Sync + 'static,
886    F: Fn(T1, Json<T>) -> Fut + Clone + Send + Sync + 'static,
887    Fut: Future<Output = Result<CallToolResult>> + Send,
888    T1: FromToolRequest<S> + Send,
889    T: DeserializeOwned + JsonSchema + Send,
890{
891    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
892
893    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
894        Box::pin(async move {
895            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
896            let t2 =
897                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
898            self(t1, t2).await
899        })
900    }
901}
902
903// Three extractors ending with Json<T>
904#[allow(deprecated)]
905impl<S, F, Fut, T1, T2, T> TypedExtractorHandler<S, (T1, T2, Json<T>), T> for F
906where
907    S: Clone + Send + Sync + 'static,
908    F: Fn(T1, T2, Json<T>) -> Fut + Clone + Send + Sync + 'static,
909    Fut: Future<Output = Result<CallToolResult>> + Send,
910    T1: FromToolRequest<S> + Send,
911    T2: FromToolRequest<S> + Send,
912    T: DeserializeOwned + JsonSchema + Send,
913{
914    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
915
916    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
917        Box::pin(async move {
918            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
919            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
920            let t3 =
921                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
922            self(t1, t2, t3).await
923        })
924    }
925}
926
927// Four extractors ending with Json<T>
928#[allow(deprecated)]
929impl<S, F, Fut, T1, T2, T3, T> TypedExtractorHandler<S, (T1, T2, T3, Json<T>), T> for F
930where
931    S: Clone + Send + Sync + 'static,
932    F: Fn(T1, T2, T3, Json<T>) -> Fut + Clone + Send + Sync + 'static,
933    Fut: Future<Output = Result<CallToolResult>> + Send,
934    T1: FromToolRequest<S> + Send,
935    T2: FromToolRequest<S> + Send,
936    T3: FromToolRequest<S> + Send,
937    T: DeserializeOwned + JsonSchema + Send,
938{
939    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
940
941    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
942        Box::pin(async move {
943            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
944            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
945            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
946            let t4 =
947                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
948            self(t1, t2, t3, t4).await
949        })
950    }
951}
952
953// =============================================================================
954// ToolBuilder Extensions
955// =============================================================================
956
957use crate::tool::{
958    BoxFuture, GuardLayer, Tool, ToolCatchError, ToolHandler, ToolHandlerService, ToolRequest,
959};
960use tower::util::BoxCloneService;
961use tower_service::Service;
962
963/// Internal handler wrapper for extractor-based handlers
964pub(crate) struct ExtractorToolHandler<S, F, T> {
965    state: S,
966    handler: F,
967    input_schema: Value,
968    _phantom: PhantomData<T>,
969}
970
971impl<S, F, T> ToolHandler for ExtractorToolHandler<S, F, T>
972where
973    S: Clone + Send + Sync + 'static,
974    F: ExtractorHandler<S, T> + Clone,
975    T: Send + Sync + 'static,
976{
977    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
978        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
979        self.call_with_context(ctx, args)
980    }
981
982    fn call_with_context(
983        &self,
984        ctx: RequestContext,
985        args: Value,
986    ) -> BoxFuture<'_, Result<CallToolResult>> {
987        let state = self.state.clone();
988        let handler = self.handler.clone();
989        Box::pin(async move { handler.call(ctx, state, args).await })
990    }
991
992    fn uses_context(&self) -> bool {
993        true
994    }
995
996    fn input_schema(&self) -> Value {
997        self.input_schema.clone()
998    }
999}
1000
1001/// Builder state for extractor-based handlers
1002#[doc(hidden)]
1003pub struct ToolBuilderWithExtractor<S, F, T> {
1004    pub(crate) name: String,
1005    pub(crate) title: Option<String>,
1006    pub(crate) description: Option<String>,
1007    pub(crate) output_schema: Option<Value>,
1008    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
1009    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
1010    pub(crate) task_support: crate::protocol::TaskSupportMode,
1011    pub(crate) state: S,
1012    pub(crate) handler: F,
1013    pub(crate) input_schema: Value,
1014    pub(crate) _phantom: PhantomData<T>,
1015}
1016
1017impl<S, F, T> ToolBuilderWithExtractor<S, F, T>
1018where
1019    S: Clone + Send + Sync + 'static,
1020    F: ExtractorHandler<S, T> + Clone,
1021    T: Send + Sync + 'static,
1022{
1023    /// Build the tool.
1024    pub fn build(self) -> Tool {
1025        let handler = ExtractorToolHandler {
1026            state: self.state,
1027            handler: self.handler,
1028            input_schema: self.input_schema.clone(),
1029            _phantom: PhantomData,
1030        };
1031
1032        let handler_service = ToolHandlerService::new(handler);
1033        let catch_error = ToolCatchError::new(handler_service);
1034        let service = BoxCloneService::new(catch_error);
1035
1036        Tool {
1037            name: self.name,
1038            title: self.title,
1039            description: self.description,
1040            output_schema: self.output_schema,
1041            icons: self.icons,
1042            annotations: self.annotations,
1043            task_support: self.task_support,
1044            service,
1045            input_schema: self.input_schema,
1046        }
1047    }
1048
1049    /// Apply a Tower layer (middleware) to this tool.
1050    ///
1051    /// The layer wraps the tool's handler service, enabling functionality like
1052    /// timeouts, rate limiting, and metrics collection at the per-tool level.
1053    ///
1054    /// # Example
1055    ///
1056    /// ```rust
1057    /// use std::sync::Arc;
1058    /// use std::time::Duration;
1059    /// use tower::timeout::TimeoutLayer;
1060    /// use tower_mcp::{ToolBuilder, CallToolResult};
1061    /// use tower_mcp::extract::{Json, State};
1062    /// use schemars::JsonSchema;
1063    /// use serde::Deserialize;
1064    ///
1065    /// #[derive(Clone)]
1066    /// struct AppState { prefix: String }
1067    ///
1068    /// #[derive(Debug, Deserialize, JsonSchema)]
1069    /// struct QueryInput { query: String }
1070    ///
1071    /// let state = Arc::new(AppState { prefix: "db".to_string() });
1072    ///
1073    /// let tool = ToolBuilder::new("search")
1074    ///     .description("Search with timeout")
1075    ///     .extractor_handler(state, |
1076    ///         State(app): State<Arc<AppState>>,
1077    ///         Json(input): Json<QueryInput>,
1078    ///     | async move {
1079    ///         Ok(CallToolResult::text(format!("{}: {}", app.prefix, input.query)))
1080    ///     })
1081    ///     .layer(TimeoutLayer::new(Duration::from_secs(30)))
1082    ///     .build();
1083    /// ```
1084    pub fn layer<L>(self, layer: L) -> ToolBuilderWithExtractorLayer<S, F, T, L> {
1085        ToolBuilderWithExtractorLayer {
1086            name: self.name,
1087            title: self.title,
1088            description: self.description,
1089            output_schema: self.output_schema,
1090            icons: self.icons,
1091            annotations: self.annotations,
1092            task_support: self.task_support,
1093            state: self.state,
1094            handler: self.handler,
1095            input_schema: self.input_schema,
1096            layer,
1097            _phantom: PhantomData,
1098        }
1099    }
1100
1101    /// Apply a guard to this tool.
1102    ///
1103    /// See [`ToolBuilderWithHandler::guard`](crate::ToolBuilder) for details.
1104    pub fn guard<G>(self, guard: G) -> ToolBuilderWithExtractorLayer<S, F, T, GuardLayer<G>>
1105    where
1106        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1107    {
1108        self.layer(GuardLayer::new(guard))
1109    }
1110}
1111
1112/// Builder state after a layer has been applied to an extractor handler.
1113///
1114/// This builder allows chaining additional layers and building the final tool.
1115#[doc(hidden)]
1116pub struct ToolBuilderWithExtractorLayer<S, F, T, L> {
1117    name: String,
1118    title: Option<String>,
1119    description: Option<String>,
1120    output_schema: Option<Value>,
1121    icons: Option<Vec<crate::protocol::ToolIcon>>,
1122    annotations: Option<crate::protocol::ToolAnnotations>,
1123    task_support: crate::protocol::TaskSupportMode,
1124    state: S,
1125    handler: F,
1126    input_schema: Value,
1127    layer: L,
1128    _phantom: PhantomData<T>,
1129}
1130
1131#[allow(private_bounds)]
1132impl<S, F, T, L> ToolBuilderWithExtractorLayer<S, F, T, L>
1133where
1134    S: Clone + Send + Sync + 'static,
1135    F: ExtractorHandler<S, T> + Clone,
1136    T: Send + Sync + 'static,
1137    L: tower::Layer<ToolHandlerService<ExtractorToolHandler<S, F, T>>>
1138        + Clone
1139        + Send
1140        + Sync
1141        + 'static,
1142    L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1143    <L::Service as Service<ToolRequest>>::Error: std::fmt::Display + Send,
1144    <L::Service as Service<ToolRequest>>::Future: Send,
1145{
1146    /// Build the tool with the applied layer(s).
1147    pub fn build(self) -> Tool {
1148        let handler = ExtractorToolHandler {
1149            state: self.state,
1150            handler: self.handler,
1151            input_schema: self.input_schema.clone(),
1152            _phantom: PhantomData,
1153        };
1154
1155        let handler_service = ToolHandlerService::new(handler);
1156        let layered = self.layer.layer(handler_service);
1157        let catch_error = ToolCatchError::new(layered);
1158        let service = BoxCloneService::new(catch_error);
1159
1160        Tool {
1161            name: self.name,
1162            title: self.title,
1163            description: self.description,
1164            output_schema: self.output_schema,
1165            icons: self.icons,
1166            annotations: self.annotations,
1167            task_support: self.task_support,
1168            service,
1169            input_schema: self.input_schema,
1170        }
1171    }
1172
1173    /// Apply an additional Tower layer (middleware).
1174    ///
1175    /// Layers are applied in order, with earlier layers wrapping later ones.
1176    /// This means the first layer added is the outermost middleware.
1177    pub fn layer<L2>(
1178        self,
1179        layer: L2,
1180    ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<L2, L>> {
1181        ToolBuilderWithExtractorLayer {
1182            name: self.name,
1183            title: self.title,
1184            description: self.description,
1185            output_schema: self.output_schema,
1186            icons: self.icons,
1187            annotations: self.annotations,
1188            task_support: self.task_support,
1189            state: self.state,
1190            handler: self.handler,
1191            input_schema: self.input_schema,
1192            layer: tower::layer::util::Stack::new(layer, self.layer),
1193            _phantom: PhantomData,
1194        }
1195    }
1196
1197    /// Apply a guard to this tool.
1198    ///
1199    /// See [`ToolBuilderWithHandler::guard`](crate::ToolBuilder) for details.
1200    pub fn guard<G>(
1201        self,
1202        guard: G,
1203    ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<GuardLayer<G>, L>>
1204    where
1205        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1206    {
1207        self.layer(GuardLayer::new(guard))
1208    }
1209}
1210
1211/// Builder state for extractor-based handlers with typed JSON input
1212#[doc(hidden)]
1213#[deprecated(
1214    since = "0.8.0",
1215    note = "Use `ToolBuilderWithExtractor` via `extractor_handler` instead"
1216)]
1217pub struct ToolBuilderWithTypedExtractor<S, F, T, I> {
1218    pub(crate) name: String,
1219    pub(crate) title: Option<String>,
1220    pub(crate) description: Option<String>,
1221    pub(crate) output_schema: Option<Value>,
1222    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
1223    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
1224    pub(crate) task_support: crate::protocol::TaskSupportMode,
1225    pub(crate) state: S,
1226    pub(crate) handler: F,
1227    pub(crate) _phantom: PhantomData<(T, I)>,
1228}
1229
1230#[allow(deprecated)]
1231impl<S, F, T, I> ToolBuilderWithTypedExtractor<S, F, T, I>
1232where
1233    S: Clone + Send + Sync + 'static,
1234    F: TypedExtractorHandler<S, T, I> + Clone,
1235    T: Send + Sync + 'static,
1236    I: JsonSchema + Send + Sync + 'static,
1237{
1238    /// Build the tool.
1239    pub fn build(self) -> Tool {
1240        let input_schema = {
1241            let schema = schemars::schema_for!(I);
1242            let schema = serde_json::to_value(schema).unwrap_or_else(|_| {
1243                serde_json::json!({
1244                    "type": "object"
1245                })
1246            });
1247            crate::tool::ensure_object_schema(schema)
1248        };
1249
1250        let handler = TypedExtractorToolHandler {
1251            state: self.state,
1252            handler: self.handler,
1253            input_schema: input_schema.clone(),
1254            _phantom: PhantomData,
1255        };
1256
1257        let handler_service = crate::tool::ToolHandlerService::new(handler);
1258        let catch_error = ToolCatchError::new(handler_service);
1259        let service = BoxCloneService::new(catch_error);
1260
1261        Tool {
1262            name: self.name,
1263            title: self.title,
1264            description: self.description,
1265            output_schema: self.output_schema,
1266            icons: self.icons,
1267            annotations: self.annotations,
1268            task_support: self.task_support,
1269            service,
1270            input_schema,
1271        }
1272    }
1273}
1274
1275/// Internal handler wrapper for typed extractor-based handlers
1276struct TypedExtractorToolHandler<S, F, T, I> {
1277    state: S,
1278    handler: F,
1279    input_schema: Value,
1280    _phantom: PhantomData<(T, I)>,
1281}
1282
1283#[allow(deprecated)]
1284impl<S, F, T, I> ToolHandler for TypedExtractorToolHandler<S, F, T, I>
1285where
1286    S: Clone + Send + Sync + 'static,
1287    F: TypedExtractorHandler<S, T, I> + Clone,
1288    T: Send + Sync + 'static,
1289    I: JsonSchema + Send + Sync + 'static,
1290{
1291    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1292        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
1293        self.call_with_context(ctx, args)
1294    }
1295
1296    fn call_with_context(
1297        &self,
1298        ctx: RequestContext,
1299        args: Value,
1300    ) -> BoxFuture<'_, Result<CallToolResult>> {
1301        let state = self.state.clone();
1302        let handler = self.handler.clone();
1303        Box::pin(async move { handler.call(ctx, state, args).await })
1304    }
1305
1306    fn uses_context(&self) -> bool {
1307        true
1308    }
1309
1310    fn input_schema(&self) -> Value {
1311        self.input_schema.clone()
1312    }
1313}
1314
#[cfg(test)]
mod tests {
    // Unit tests for the individual extractors, the handler-trait impls, and the
    // `ToolBuilder` integrations (schema detection and Tower layers) in this file.
    use super::*;
    use crate::protocol::RequestId;
    use schemars::JsonSchema;
    use serde::Deserialize;
    use std::sync::Arc;

    // Shared input type for the tests below; both fields are required.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct TestInput {
        name: String,
        count: i32,
    }

    // `Json<T>` deserializes well-formed arguments into `T`.
    #[test]
    fn test_json_extraction() {
        let args = serde_json::json!({"name": "test", "count": 42});
        let ctx = RequestContext::new(RequestId::Number(1));

        let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
        assert!(result.is_ok());
        let Json(input) = result.unwrap();
        assert_eq!(input.name, "test");
        assert_eq!(input.count, 42);
    }

    // A missing required field produces a rejection whose message names that field.
    #[test]
    fn test_json_extraction_error() {
        let args = serde_json::json!({"name": "test"}); // missing count
        let ctx = RequestContext::new(RequestId::Number(1));

        let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
        assert!(result.is_err());
        let rejection = result.unwrap_err();
        // JsonRejection contains the serde error message
        assert!(rejection.message().contains("count"));
    }

    // `State<T>` yields the per-tool state that was passed in.
    #[test]
    fn test_state_extraction() {
        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("my-state".to_string());

        let result = State::<Arc<String>>::from_tool_request(&ctx, &state, &args);
        assert!(result.is_ok());
        let State(extracted) = result.unwrap();
        assert_eq!(*extracted, "my-state");
    }

    // `Context` exposes the request id of the originating `RequestContext`.
    #[test]
    fn test_context_extraction() {
        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(42));

        let result = Context::from_tool_request(&ctx, &(), &args);
        assert!(result.is_ok());
        let extracted = result.unwrap();
        assert_eq!(*extracted.request_id(), RequestId::Number(42));
    }

    // `RawArgs` returns the argument JSON untouched.
    #[test]
    fn test_raw_args_extraction() {
        let args = serde_json::json!({"foo": "bar", "baz": 123});
        let ctx = RequestContext::new(RequestId::Number(1));

        let result = RawArgs::from_tool_request(&ctx, &(), &args);
        assert!(result.is_ok());
        let RawArgs(extracted) = result.unwrap();
        assert_eq!(extracted["foo"], "bar");
        assert_eq!(extracted["baz"], 123);
    }

    // `Extension<T>` finds a value previously inserted into the context's extensions.
    #[test]
    fn test_extension_extraction() {
        use crate::context::Extensions;

        #[derive(Clone, Debug, PartialEq)]
        struct DatabasePool {
            url: String,
        }

        let args = serde_json::json!({});

        // Create extensions with a value
        let mut extensions = Extensions::new();
        extensions.insert(Arc::new(DatabasePool {
            url: "postgres://localhost".to_string(),
        }));

        // Create context with extensions
        let ctx = RequestContext::new(RequestId::Number(1)).with_extensions(Arc::new(extensions));

        // Extract the extension
        let result = Extension::<Arc<DatabasePool>>::from_tool_request(&ctx, &(), &args);
        assert!(result.is_ok());
        let Extension(pool) = result.unwrap();
        assert_eq!(pool.url, "postgres://localhost");
    }

    // Extracting an extension that was never inserted fails, and the rejection
    // names the missing type.
    #[test]
    fn test_extension_extraction_missing() {
        #[derive(Clone, Debug)]
        struct NotPresent;

        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(1));

        // Try to extract something that's not in extensions
        let result = Extension::<NotPresent>::from_tool_request(&ctx, &(), &args);
        assert!(result.is_err());
        let rejection = result.unwrap_err();
        // ExtensionRejection contains the type name
        assert!(rejection.type_name().contains("NotPresent"));
    }

    // A closure taking a single extractor can be invoked through `ExtractorHandler`.
    #[tokio::test]
    async fn test_single_extractor_handler() {
        let handler = |Json(input): Json<TestInput>| async move {
            Ok(CallToolResult::text(format!(
                "{}: {}",
                input.name, input.count
            )))
        };

        let ctx = RequestContext::new(RequestId::Number(1));
        let args = serde_json::json!({"name": "test", "count": 5});

        // Use explicit trait to avoid ambiguity
        let result: Result<CallToolResult> =
            ExtractorHandler::<(), (Json<TestInput>,)>::call(handler, ctx, (), args).await;
        assert!(result.is_ok());
    }

    // Two extractors (`State` + `Json`) are extracted and passed to the closure.
    #[tokio::test]
    async fn test_two_extractor_handler() {
        let handler = |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
            Ok(CallToolResult::text(format!(
                "{}: {} - {}",
                state, input.name, input.count
            )))
        };

        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("prefix".to_string());
        let args = serde_json::json!({"name": "test", "count": 5});

        // Use explicit trait to avoid ambiguity
        let result: Result<CallToolResult> = ExtractorHandler::<
            Arc<String>,
            (State<Arc<String>>, Context, Json<TestInput>),
        >::call(handler, ctx, state, args)
        .await;
        assert!(result.is_ok());
    }
}