Skip to main content

tower_mcp/
extract.rs

1//! Extractor pattern for tool handlers
2//!
3//! This module provides an axum-inspired extractor pattern that makes state and context
4//! injection more declarative, reducing the combinatorial explosion of handler variants.
5//!
6//! # Overview
7//!
8//! Extractors implement [`FromToolRequest`], which extracts data from the tool request
9//! (context, state, and arguments). Multiple extractors can be combined in handler
10//! function parameters.
11//!
12//! # Built-in Extractors
13//!
14//! - [`Json<T>`] - Extract typed input from args (deserializes JSON)
15//! - [`State<T>`] - Extract shared state from per-tool state (cloned for each request)
16//! - [`Extension<T>`] - Extract data from router extensions (via `router.with_state()`)
17//! - [`Context`] - Extract the [`RequestContext`] for progress, cancellation, etc.
18//! - [`RawArgs`] - Extract raw `serde_json::Value` arguments
19//!
20//! ## State vs Extension
21//!
22//! - Use **`State<T>`** when state is passed directly to `extractor_handler()` (per-tool state)
23//! - Use **`Extension<T>`** when state is set via `McpRouter::with_state()` (router-level state)
24//!
25//! # Example
26//!
27//! ```rust
28//! use std::sync::Arc;
29//! use tower_mcp::{ToolBuilder, CallToolResult};
30//! use tower_mcp::extract::{Json, State, Context};
31//! use schemars::JsonSchema;
32//! use serde::Deserialize;
33//!
34//! #[derive(Clone)]
35//! struct AppState {
36//!     db_url: String,
37//! }
38//!
39//! #[derive(Debug, Deserialize, JsonSchema)]
40//! struct QueryInput {
41//!     query: String,
42//! }
43//!
44//! let state = Arc::new(AppState { db_url: "postgres://...".to_string() });
45//!
46//! let tool = ToolBuilder::new("search")
47//!     .description("Search the database")
48//!     .extractor_handler(state, |
49//!         State(db): State<Arc<AppState>>,
50//!         ctx: Context,
51//!         Json(input): Json<QueryInput>,
52//!     | async move {
53//!         // Check cancellation
54//!         if ctx.is_cancelled() {
55//!             return Ok(CallToolResult::error("Cancelled"));
56//!         }
57//!         // Report progress
58//!         ctx.report_progress(0.5, Some(1.0), Some("Searching...")).await;
59//!         // Use state
60//!         Ok(CallToolResult::text(format!("Searched {} with query: {}", db.db_url, input.query)))
61//!     })
62//!     .build()
63//!     .unwrap();
64//! ```
65//!
66//! # Extractor Order
67//!
68//! The order of extractors in the function signature doesn't matter. Each extractor
69//! independently extracts its data from the request.
70//!
71//! # Error Handling
72//!
73//! If an extractor fails (e.g., JSON deserialization fails), the handler returns
74//! a `CallToolResult::error()` with the rejection message.
75
76use std::future::Future;
77use std::marker::PhantomData;
78use std::ops::Deref;
79use std::pin::Pin;
80
81use schemars::JsonSchema;
82use serde::de::DeserializeOwned;
83use serde_json::Value;
84
85use crate::context::RequestContext;
86use crate::error::{Error, Result};
87use crate::protocol::CallToolResult;
88
89// =============================================================================
90// Rejection Types
91// =============================================================================
92
93/// A simple rejection with a message string.
94///
95/// This is a general-purpose rejection type for custom extractors.
96/// For more specific error information, use the typed rejection types
97/// like [`JsonRejection`] or [`ExtensionRejection`].
98#[derive(Debug, Clone)]
99pub struct Rejection {
100    message: String,
101}
102
103impl Rejection {
104    /// Create a new rejection with the given message.
105    pub fn new(message: impl Into<String>) -> Self {
106        Self {
107            message: message.into(),
108        }
109    }
110
111    /// Get the rejection message.
112    pub fn message(&self) -> &str {
113        &self.message
114    }
115}
116
117impl std::fmt::Display for Rejection {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        write!(f, "{}", self.message)
120    }
121}
122
123impl std::error::Error for Rejection {}
124
125impl From<Rejection> for Error {
126    fn from(rejection: Rejection) -> Self {
127        Error::tool(rejection.message)
128    }
129}
130
131/// Rejection returned when JSON deserialization fails.
132///
133/// This rejection provides structured information about the deserialization
134/// error, including the path to the failing field when available.
135///
136/// # Example
137///
138/// ```rust
139/// use tower_mcp::extract::JsonRejection;
140///
141/// let rejection = JsonRejection::new("missing field `name`");
142/// assert!(rejection.message().contains("name"));
143/// ```
144#[derive(Debug, Clone)]
145pub struct JsonRejection {
146    message: String,
147    /// The serde error path, if available (e.g., "users[0].name")
148    path: Option<String>,
149}
150
151impl JsonRejection {
152    /// Create a new JSON rejection from a serde error.
153    pub fn new(message: impl Into<String>) -> Self {
154        Self {
155            message: message.into(),
156            path: None,
157        }
158    }
159
160    /// Create a JSON rejection with a path to the failing field.
161    pub fn with_path(message: impl Into<String>, path: impl Into<String>) -> Self {
162        Self {
163            message: message.into(),
164            path: Some(path.into()),
165        }
166    }
167
168    /// Get the error message.
169    pub fn message(&self) -> &str {
170        &self.message
171    }
172
173    /// Get the path to the failing field, if available.
174    pub fn path(&self) -> Option<&str> {
175        self.path.as_deref()
176    }
177}
178
179impl std::fmt::Display for JsonRejection {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        if let Some(path) = &self.path {
182            write!(f, "Invalid input at `{}`: {}", path, self.message)
183        } else {
184            write!(f, "Invalid input: {}", self.message)
185        }
186    }
187}
188
189impl std::error::Error for JsonRejection {}
190
191impl From<JsonRejection> for Error {
192    fn from(rejection: JsonRejection) -> Self {
193        Error::tool(rejection.to_string())
194    }
195}
196
197impl From<serde_json::Error> for JsonRejection {
198    fn from(err: serde_json::Error) -> Self {
199        // Try to extract path information from serde error
200        let path = if err.is_data() {
201            // serde_json provides line/column but not field path in the error itself
202            // The path is embedded in the message for some error types
203            None
204        } else {
205            None
206        };
207
208        Self {
209            message: err.to_string(),
210            path,
211        }
212    }
213}
214
215/// Rejection returned when an extension is not found.
216///
217/// This rejection is returned by the [`Extension`] extractor when the
218/// requested type is not present in the router's extensions.
219///
220/// # Example
221///
222/// ```rust
223/// use tower_mcp::extract::ExtensionRejection;
224///
225/// let rejection = ExtensionRejection::not_found::<String>();
226/// assert!(rejection.type_name().contains("String"));
227/// ```
228#[derive(Debug, Clone)]
229pub struct ExtensionRejection {
230    type_name: &'static str,
231}
232
233impl ExtensionRejection {
234    /// Create a rejection for a missing extension type.
235    pub fn not_found<T>() -> Self {
236        Self {
237            type_name: std::any::type_name::<T>(),
238        }
239    }
240
241    /// Get the type name of the missing extension.
242    pub fn type_name(&self) -> &'static str {
243        self.type_name
244    }
245}
246
247impl std::fmt::Display for ExtensionRejection {
248    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        write!(
250            f,
251            "Extension of type `{}` not found. Did you call `router.with_state()` or `router.with_extension()`?",
252            self.type_name
253        )
254    }
255}
256
257impl std::error::Error for ExtensionRejection {}
258
259impl From<ExtensionRejection> for Error {
260    fn from(rejection: ExtensionRejection) -> Self {
261        Error::tool(rejection.to_string())
262    }
263}
264
265/// Trait for extracting data from a tool request.
266///
267/// Implement this trait to create custom extractors that can be used
268/// in `extractor_handler` functions.
269///
270/// # Type Parameters
271///
272/// - `S` - The state type. Defaults to `()` for extractors that don't need state.
273///
274/// # Example
275///
276/// ```rust
277/// use tower_mcp::extract::{FromToolRequest, Rejection};
278/// use tower_mcp::RequestContext;
279/// use serde_json::Value;
280///
281/// struct RequestId(String);
282///
283/// impl<S> FromToolRequest<S> for RequestId {
284///     type Rejection = Rejection;
285///
286///     fn from_tool_request(
287///         ctx: &RequestContext,
288///         _state: &S,
289///         _args: &Value,
290///     ) -> Result<Self, Self::Rejection> {
291///         Ok(RequestId(format!("{:?}", ctx.request_id())))
292///     }
293/// }
294/// ```
295pub trait FromToolRequest<S = ()>: Sized {
296    /// The rejection type returned when extraction fails.
297    type Rejection: Into<Error>;
298
299    /// Extract this type from the tool request.
300    ///
301    /// # Arguments
302    ///
303    /// * `ctx` - The request context with progress, cancellation, etc.
304    /// * `state` - The shared state passed to the handler
305    /// * `args` - The raw JSON arguments to the tool
306    fn from_tool_request(
307        ctx: &RequestContext,
308        state: &S,
309        args: &Value,
310    ) -> std::result::Result<Self, Self::Rejection>;
311}
312
313// =============================================================================
314// Built-in Extractors
315// =============================================================================
316
317/// Extract and deserialize JSON arguments into a typed struct.
318///
319/// This extractor deserializes the tool's JSON arguments into type `T`.
320/// The type must implement [`serde::de::DeserializeOwned`] and [`schemars::JsonSchema`].
321///
322/// # Example
323///
324/// ```rust
325/// use tower_mcp::extract::Json;
326/// use schemars::JsonSchema;
327/// use serde::Deserialize;
328///
329/// #[derive(Debug, Deserialize, JsonSchema)]
330/// struct MyInput {
331///     name: String,
332///     count: i32,
333/// }
334///
335/// // In an extractor handler:
336/// // |Json(input): Json<MyInput>| async move { ... }
337/// ```
338///
339/// # Rejection
340///
341/// Returns a [`JsonRejection`] if deserialization fails. The rejection contains
342/// the error message and potentially the path to the failing field.
343#[derive(Debug, Clone, Copy)]
344pub struct Json<T>(pub T);
345
346impl<T> Deref for Json<T> {
347    type Target = T;
348
349    fn deref(&self) -> &Self::Target {
350        &self.0
351    }
352}
353
354impl<S, T> FromToolRequest<S> for Json<T>
355where
356    T: DeserializeOwned,
357{
358    type Rejection = JsonRejection;
359
360    fn from_tool_request(
361        _ctx: &RequestContext,
362        _state: &S,
363        args: &Value,
364    ) -> std::result::Result<Self, Self::Rejection> {
365        serde_json::from_value(args.clone())
366            .map(Json)
367            .map_err(JsonRejection::from)
368    }
369}
370
371/// Extract shared state.
372///
373/// This extractor clones the state passed to `extractor_handler` and provides
374/// it to the handler. The state type must match the type passed to the builder.
375///
376/// # Example
377///
378/// ```rust
379/// use std::sync::Arc;
380/// use tower_mcp::extract::State;
381///
382/// #[derive(Clone)]
383/// struct AppState {
384///     db_url: String,
385/// }
386///
387/// // In an extractor handler:
388/// // |State(state): State<Arc<AppState>>| async move { ... }
389/// ```
390///
391/// # Note
392///
393/// For expensive-to-clone types, wrap them in `Arc` before passing to
394/// `extractor_handler`.
395#[derive(Debug, Clone, Copy)]
396pub struct State<T>(pub T);
397
398impl<T> Deref for State<T> {
399    type Target = T;
400
401    fn deref(&self) -> &Self::Target {
402        &self.0
403    }
404}
405
406impl<S: Clone> FromToolRequest<S> for State<S> {
407    type Rejection = Rejection;
408
409    fn from_tool_request(
410        _ctx: &RequestContext,
411        state: &S,
412        _args: &Value,
413    ) -> std::result::Result<Self, Self::Rejection> {
414        Ok(State(state.clone()))
415    }
416}
417
418/// Extract the request context.
419///
420/// This extractor provides access to the [`RequestContext`], which contains:
421/// - Progress reporting via `report_progress()`
422/// - Cancellation checking via `is_cancelled()`
423/// - Sampling capabilities via `sample()`
424/// - Elicitation capabilities via `elicit_form()` and `elicit_url()`
425/// - Log sending via `send_log()`
426///
427/// # Example
428///
429/// ```rust
430/// use tower_mcp::extract::Context;
431///
432/// // In an extractor handler:
433/// // |ctx: Context| async move {
434/// //     ctx.report_progress(0.5, Some(1.0), Some("Working...")).await;
435/// //     // ...
436/// // }
437/// ```
438#[derive(Debug, Clone)]
439pub struct Context(RequestContext);
440
441impl Context {
442    /// Get the inner RequestContext
443    pub fn into_inner(self) -> RequestContext {
444        self.0
445    }
446}
447
448impl Deref for Context {
449    type Target = RequestContext;
450
451    fn deref(&self) -> &Self::Target {
452        &self.0
453    }
454}
455
456impl<S> FromToolRequest<S> for Context {
457    type Rejection = Rejection;
458
459    fn from_tool_request(
460        ctx: &RequestContext,
461        _state: &S,
462        _args: &Value,
463    ) -> std::result::Result<Self, Self::Rejection> {
464        Ok(Context(ctx.clone()))
465    }
466}
467
468/// Extract raw JSON arguments.
469///
470/// This extractor provides the raw `serde_json::Value` arguments without
471/// any deserialization. Useful when you need full control over argument
472/// parsing or when the schema is dynamic.
473///
474/// # Example
475///
476/// ```rust
477/// use tower_mcp::extract::RawArgs;
478///
479/// // In an extractor handler:
480/// // |RawArgs(args): RawArgs| async move {
481/// //     // args is serde_json::Value
482/// //     if let Some(name) = args.get("name") { ... }
483/// // }
484/// ```
485#[derive(Debug, Clone)]
486pub struct RawArgs(pub Value);
487
488impl Deref for RawArgs {
489    type Target = Value;
490
491    fn deref(&self) -> &Self::Target {
492        &self.0
493    }
494}
495
496impl<S> FromToolRequest<S> for RawArgs {
497    type Rejection = Rejection;
498
499    fn from_tool_request(
500        _ctx: &RequestContext,
501        _state: &S,
502        args: &Value,
503    ) -> std::result::Result<Self, Self::Rejection> {
504        Ok(RawArgs(args.clone()))
505    }
506}
507
508/// Extract typed data from router extensions.
509///
510/// This extractor retrieves data that was added to the router via
511/// [`crate::McpRouter::with_state()`] or [`crate::McpRouter::with_extension()`], or
512/// inserted by middleware into the request context's extensions.
513///
514/// # Example
515///
516/// ```rust
517/// use std::sync::Arc;
518/// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
519/// use tower_mcp::extract::{Extension, Json};
520/// use schemars::JsonSchema;
521/// use serde::Deserialize;
522///
523/// #[derive(Clone)]
524/// struct DatabasePool {
525///     url: String,
526/// }
527///
528/// #[derive(Deserialize, JsonSchema)]
529/// struct QueryInput {
530///     sql: String,
531/// }
532///
533/// let pool = Arc::new(DatabasePool { url: "postgres://...".into() });
534///
535/// let tool = ToolBuilder::new("query")
536///     .description("Run a query")
537///     .extractor_handler_typed::<_, _, _, QueryInput>(
538///         (),
539///         |Extension(db): Extension<Arc<DatabasePool>>, Json(input): Json<QueryInput>| async move {
540///             Ok(CallToolResult::text(format!("Query on {}: {}", db.url, input.sql)))
541///         },
542///     )
543///     .build()
544///     .unwrap();
545///
546/// let router = McpRouter::new()
547///     .with_state(pool)
548///     .tool(tool);
549/// ```
550///
551/// # Rejection
552///
553/// Returns an [`ExtensionRejection`] if the requested type is not found in the extensions.
554/// The rejection contains the type name of the missing extension.
555#[derive(Debug, Clone)]
556pub struct Extension<T>(pub T);
557
558impl<T> Deref for Extension<T> {
559    type Target = T;
560
561    fn deref(&self) -> &Self::Target {
562        &self.0
563    }
564}
565
566impl<S, T> FromToolRequest<S> for Extension<T>
567where
568    T: Clone + Send + Sync + 'static,
569{
570    type Rejection = ExtensionRejection;
571
572    fn from_tool_request(
573        ctx: &RequestContext,
574        _state: &S,
575        _args: &Value,
576    ) -> std::result::Result<Self, Self::Rejection> {
577        ctx.extension::<T>()
578            .cloned()
579            .map(Extension)
580            .ok_or_else(ExtensionRejection::not_found::<T>)
581    }
582}
583
584// =============================================================================
585// Handler Trait
586// =============================================================================
587
588/// A handler that uses extractors.
589///
590/// This trait is implemented for functions that take extractors as arguments.
591/// You don't need to implement this trait directly; it's automatically
592/// implemented for compatible async functions.
593pub trait ExtractorHandler<S, T>: Clone + Send + Sync + 'static {
594    /// The future returned by the handler.
595    type Future: Future<Output = Result<CallToolResult>> + Send;
596
597    /// Call the handler with extracted values.
598    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
599
600    /// Get the input schema for this handler.
601    ///
602    /// Returns `None` if no `Json<T>` extractor is used.
603    fn input_schema() -> Value;
604}
605
606// Implementation for single extractor
607impl<S, F, Fut, T1> ExtractorHandler<S, (T1,)> for F
608where
609    S: Clone + Send + Sync + 'static,
610    F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
611    Fut: Future<Output = Result<CallToolResult>> + Send,
612    T1: FromToolRequest<S> + Send,
613{
614    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
615
616    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
617        Box::pin(async move {
618            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
619            self(t1).await
620        })
621    }
622
623    fn input_schema() -> Value {
624        // For single extractors, check if it's Json<T>
625        serde_json::json!({
626            "type": "object",
627            "additionalProperties": true
628        })
629    }
630}
631
632// Implementation for two extractors
633impl<S, F, Fut, T1, T2> ExtractorHandler<S, (T1, T2)> for F
634where
635    S: Clone + Send + Sync + 'static,
636    F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
637    Fut: Future<Output = Result<CallToolResult>> + Send,
638    T1: FromToolRequest<S> + Send,
639    T2: FromToolRequest<S> + Send,
640{
641    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
642
643    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
644        Box::pin(async move {
645            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
646            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
647            self(t1, t2).await
648        })
649    }
650
651    fn input_schema() -> Value {
652        serde_json::json!({
653            "type": "object",
654            "additionalProperties": true
655        })
656    }
657}
658
659// Implementation for three extractors
660impl<S, F, Fut, T1, T2, T3> ExtractorHandler<S, (T1, T2, T3)> for F
661where
662    S: Clone + Send + Sync + 'static,
663    F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
664    Fut: Future<Output = Result<CallToolResult>> + Send,
665    T1: FromToolRequest<S> + Send,
666    T2: FromToolRequest<S> + Send,
667    T3: FromToolRequest<S> + Send,
668{
669    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
670
671    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
672        Box::pin(async move {
673            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
674            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
675            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
676            self(t1, t2, t3).await
677        })
678    }
679
680    fn input_schema() -> Value {
681        serde_json::json!({
682            "type": "object",
683            "additionalProperties": true
684        })
685    }
686}
687
688// Implementation for four extractors
689impl<S, F, Fut, T1, T2, T3, T4> ExtractorHandler<S, (T1, T2, T3, T4)> for F
690where
691    S: Clone + Send + Sync + 'static,
692    F: Fn(T1, T2, T3, T4) -> Fut + Clone + Send + Sync + 'static,
693    Fut: Future<Output = Result<CallToolResult>> + Send,
694    T1: FromToolRequest<S> + Send,
695    T2: FromToolRequest<S> + Send,
696    T3: FromToolRequest<S> + Send,
697    T4: FromToolRequest<S> + Send,
698{
699    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
700
701    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
702        Box::pin(async move {
703            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
704            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
705            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
706            let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
707            self(t1, t2, t3, t4).await
708        })
709    }
710
711    fn input_schema() -> Value {
712        serde_json::json!({
713            "type": "object",
714            "additionalProperties": true
715        })
716    }
717}
718
719// Implementation for five extractors
720impl<S, F, Fut, T1, T2, T3, T4, T5> ExtractorHandler<S, (T1, T2, T3, T4, T5)> for F
721where
722    S: Clone + Send + Sync + 'static,
723    F: Fn(T1, T2, T3, T4, T5) -> Fut + Clone + Send + Sync + 'static,
724    Fut: Future<Output = Result<CallToolResult>> + Send,
725    T1: FromToolRequest<S> + Send,
726    T2: FromToolRequest<S> + Send,
727    T3: FromToolRequest<S> + Send,
728    T4: FromToolRequest<S> + Send,
729    T5: FromToolRequest<S> + Send,
730{
731    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
732
733    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
734        Box::pin(async move {
735            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
736            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
737            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
738            let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
739            let t5 = T5::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
740            self(t1, t2, t3, t4, t5).await
741        })
742    }
743
744    fn input_schema() -> Value {
745        serde_json::json!({
746            "type": "object",
747            "additionalProperties": true
748        })
749    }
750}
751
752// =============================================================================
753// Schema Extraction Helper
754// =============================================================================
755
756/// Helper trait to get schema from `Json<T>` extractor
757pub trait HasSchema {
758    fn schema() -> Option<Value>;
759}
760
761impl<T: JsonSchema> HasSchema for Json<T> {
762    fn schema() -> Option<Value> {
763        let schema = schemars::schema_for!(T);
764        serde_json::to_value(schema).ok()
765    }
766}
767
768// Default impl for non-Json extractors
769impl HasSchema for Context {
770    fn schema() -> Option<Value> {
771        None
772    }
773}
774
775impl HasSchema for RawArgs {
776    fn schema() -> Option<Value> {
777        None
778    }
779}
780
781impl<T> HasSchema for State<T> {
782    fn schema() -> Option<Value> {
783        None
784    }
785}
786
787// =============================================================================
788// Typed Extractor Handler
789// =============================================================================
790
791/// A handler that uses extractors with typed JSON input.
792///
793/// This trait is similar to [`ExtractorHandler`] but provides proper JSON
794/// schema generation for the input type when `Json<T>` is used.
795pub trait TypedExtractorHandler<S, T, I>: Clone + Send + Sync + 'static
796where
797    I: JsonSchema,
798{
799    /// The future returned by the handler.
800    type Future: Future<Output = Result<CallToolResult>> + Send;
801
802    /// Call the handler with extracted values.
803    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
804}
805
806// Single extractor with Json<T>
807impl<S, F, Fut, T> TypedExtractorHandler<S, (Json<T>,), T> for F
808where
809    S: Clone + Send + Sync + 'static,
810    F: Fn(Json<T>) -> Fut + Clone + Send + Sync + 'static,
811    Fut: Future<Output = Result<CallToolResult>> + Send,
812    T: DeserializeOwned + JsonSchema + Send,
813{
814    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
815
816    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
817        Box::pin(async move {
818            let t1 =
819                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
820            self(t1).await
821        })
822    }
823}
824
825// Two extractors ending with Json<T>
826impl<S, F, Fut, T1, T> TypedExtractorHandler<S, (T1, Json<T>), T> for F
827where
828    S: Clone + Send + Sync + 'static,
829    F: Fn(T1, Json<T>) -> Fut + Clone + Send + Sync + 'static,
830    Fut: Future<Output = Result<CallToolResult>> + Send,
831    T1: FromToolRequest<S> + Send,
832    T: DeserializeOwned + JsonSchema + Send,
833{
834    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
835
836    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
837        Box::pin(async move {
838            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
839            let t2 =
840                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
841            self(t1, t2).await
842        })
843    }
844}
845
846// Three extractors ending with Json<T>
847impl<S, F, Fut, T1, T2, T> TypedExtractorHandler<S, (T1, T2, Json<T>), T> for F
848where
849    S: Clone + Send + Sync + 'static,
850    F: Fn(T1, T2, Json<T>) -> Fut + Clone + Send + Sync + 'static,
851    Fut: Future<Output = Result<CallToolResult>> + Send,
852    T1: FromToolRequest<S> + Send,
853    T2: FromToolRequest<S> + Send,
854    T: DeserializeOwned + JsonSchema + Send,
855{
856    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
857
858    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
859        Box::pin(async move {
860            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
861            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
862            let t3 =
863                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
864            self(t1, t2, t3).await
865        })
866    }
867}
868
869// Four extractors ending with Json<T>
870impl<S, F, Fut, T1, T2, T3, T> TypedExtractorHandler<S, (T1, T2, T3, Json<T>), T> for F
871where
872    S: Clone + Send + Sync + 'static,
873    F: Fn(T1, T2, T3, Json<T>) -> Fut + Clone + Send + Sync + 'static,
874    Fut: Future<Output = Result<CallToolResult>> + Send,
875    T1: FromToolRequest<S> + Send,
876    T2: FromToolRequest<S> + Send,
877    T3: FromToolRequest<S> + Send,
878    T: DeserializeOwned + JsonSchema + Send,
879{
880    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
881
882    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
883        Box::pin(async move {
884            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
885            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
886            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
887            let t4 =
888                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
889            self(t1, t2, t3, t4).await
890        })
891    }
892}
893
894// =============================================================================
895// ToolBuilder Extensions
896// =============================================================================
897
898use crate::tool::{BoxFuture, Tool, ToolCatchError, ToolHandler, validate_tool_name};
899use tower::util::BoxCloneService;
900
901/// Internal handler wrapper for extractor-based handlers
902pub(crate) struct ExtractorToolHandler<S, F, T> {
903    state: S,
904    handler: F,
905    input_schema: Value,
906    _phantom: PhantomData<T>,
907}
908
909impl<S, F, T> ToolHandler for ExtractorToolHandler<S, F, T>
910where
911    S: Clone + Send + Sync + 'static,
912    F: ExtractorHandler<S, T> + Clone,
913    T: Send + Sync + 'static,
914{
915    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
916        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
917        self.call_with_context(ctx, args)
918    }
919
920    fn call_with_context(
921        &self,
922        ctx: RequestContext,
923        args: Value,
924    ) -> BoxFuture<'_, Result<CallToolResult>> {
925        let state = self.state.clone();
926        let handler = self.handler.clone();
927        Box::pin(async move { handler.call(ctx, state, args).await })
928    }
929
930    fn uses_context(&self) -> bool {
931        true
932    }
933
934    fn input_schema(&self) -> Value {
935        self.input_schema.clone()
936    }
937}
938
939/// Builder state for extractor-based handlers
940pub struct ToolBuilderWithExtractor<S, F, T> {
941    pub(crate) name: String,
942    pub(crate) title: Option<String>,
943    pub(crate) description: Option<String>,
944    pub(crate) output_schema: Option<Value>,
945    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
946    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
947    pub(crate) state: S,
948    pub(crate) handler: F,
949    pub(crate) input_schema: Value,
950    pub(crate) _phantom: PhantomData<T>,
951}
952
953impl<S, F, T> ToolBuilderWithExtractor<S, F, T>
954where
955    S: Clone + Send + Sync + 'static,
956    F: ExtractorHandler<S, T> + Clone,
957    T: Send + Sync + 'static,
958{
959    /// Build the tool.
960    ///
961    /// Returns an error if the tool name is invalid.
962    pub fn build(self) -> Result<Tool> {
963        validate_tool_name(&self.name)?;
964
965        let handler = ExtractorToolHandler {
966            state: self.state,
967            handler: self.handler,
968            input_schema: self.input_schema.clone(),
969            _phantom: PhantomData,
970        };
971
972        let handler_service = crate::tool::ToolHandlerService::new(handler);
973        let catch_error = ToolCatchError::new(handler_service);
974        let service = BoxCloneService::new(catch_error);
975
976        Ok(Tool {
977            name: self.name,
978            title: self.title,
979            description: self.description,
980            output_schema: self.output_schema,
981            icons: self.icons,
982            annotations: self.annotations,
983            service,
984            input_schema: self.input_schema,
985        })
986    }
987}
988
989/// Builder state for extractor-based handlers with typed JSON input
990pub struct ToolBuilderWithTypedExtractor<S, F, T, I> {
991    pub(crate) name: String,
992    pub(crate) title: Option<String>,
993    pub(crate) description: Option<String>,
994    pub(crate) output_schema: Option<Value>,
995    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
996    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
997    pub(crate) state: S,
998    pub(crate) handler: F,
999    pub(crate) _phantom: PhantomData<(T, I)>,
1000}
1001
1002impl<S, F, T, I> ToolBuilderWithTypedExtractor<S, F, T, I>
1003where
1004    S: Clone + Send + Sync + 'static,
1005    F: TypedExtractorHandler<S, T, I> + Clone,
1006    T: Send + Sync + 'static,
1007    I: JsonSchema + Send + Sync + 'static,
1008{
1009    /// Build the tool.
1010    ///
1011    /// Returns an error if the tool name is invalid.
1012    pub fn build(self) -> Result<Tool> {
1013        validate_tool_name(&self.name)?;
1014
1015        let input_schema = {
1016            let schema = schemars::schema_for!(I);
1017            serde_json::to_value(schema).unwrap_or_else(|_| {
1018                serde_json::json!({
1019                    "type": "object"
1020                })
1021            })
1022        };
1023
1024        let handler = TypedExtractorToolHandler {
1025            state: self.state,
1026            handler: self.handler,
1027            input_schema: input_schema.clone(),
1028            _phantom: PhantomData,
1029        };
1030
1031        let handler_service = crate::tool::ToolHandlerService::new(handler);
1032        let catch_error = ToolCatchError::new(handler_service);
1033        let service = BoxCloneService::new(catch_error);
1034
1035        Ok(Tool {
1036            name: self.name,
1037            title: self.title,
1038            description: self.description,
1039            output_schema: self.output_schema,
1040            icons: self.icons,
1041            annotations: self.annotations,
1042            service,
1043            input_schema,
1044        })
1045    }
1046}
1047
/// Internal handler wrapper for typed extractor-based handlers
///
/// Constructed by `ToolBuilderWithTypedExtractor::build` and adapted to the
/// [`ToolHandler`] trait; every call clones `state` and `handler`.
struct TypedExtractorToolHandler<S, F, T, I> {
    /// State passed to the handler; cloned for every call.
    state: S,
    /// The user's extractor-based handler; cloned for every call.
    handler: F,
    /// JSON schema generated from `I`, returned by `ToolHandler::input_schema`.
    input_schema: Value,
    /// Records the extractor tuple type `T` and input type `I` without storing values.
    _phantom: PhantomData<(T, I)>,
}
1055
1056impl<S, F, T, I> ToolHandler for TypedExtractorToolHandler<S, F, T, I>
1057where
1058    S: Clone + Send + Sync + 'static,
1059    F: TypedExtractorHandler<S, T, I> + Clone,
1060    T: Send + Sync + 'static,
1061    I: JsonSchema + Send + Sync + 'static,
1062{
1063    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1064        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
1065        self.call_with_context(ctx, args)
1066    }
1067
1068    fn call_with_context(
1069        &self,
1070        ctx: RequestContext,
1071        args: Value,
1072    ) -> BoxFuture<'_, Result<CallToolResult>> {
1073        let state = self.state.clone();
1074        let handler = self.handler.clone();
1075        Box::pin(async move { handler.call(ctx, state, args).await })
1076    }
1077
1078    fn uses_context(&self) -> bool {
1079        true
1080    }
1081
1082    fn input_schema(&self) -> Value {
1083        self.input_schema.clone()
1084    }
1085}
1086
#[cfg(test)]
mod tests {
    // Unit tests for the individual extractors, their rejection types, the
    // `ExtractorHandler` impls for closures, and the `ToolBuilder` integration.
    //
    // Fallible calls use `expect`/`expect_err` instead of
    // `assert!(r.is_ok())` + `unwrap()`, so a failing test prints the actual
    // error (or unexpected value) rather than a bare "assertion failed".
    use super::*;
    use crate::protocol::RequestId;
    use schemars::JsonSchema;
    use serde::Deserialize;
    use std::sync::Arc;

    /// Typed input shared by the JSON-extraction and handler tests.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct TestInput {
        name: String,
        count: i32,
    }

    #[test]
    fn test_json_extraction() {
        let args = serde_json::json!({"name": "test", "count": 42});
        let ctx = RequestContext::new(RequestId::Number(1));

        let Json(input) = Json::<TestInput>::from_tool_request(&ctx, &(), &args)
            .expect("valid args should deserialize into TestInput");
        assert_eq!(input.name, "test");
        assert_eq!(input.count, 42);
    }

    #[test]
    fn test_json_extraction_error() {
        let args = serde_json::json!({"name": "test"}); // missing count
        let ctx = RequestContext::new(RequestId::Number(1));

        let rejection = Json::<TestInput>::from_tool_request(&ctx, &(), &args)
            .expect_err("args missing `count` should be rejected");
        // JsonRejection contains the serde error message
        assert!(rejection.message().contains("count"));
    }

    #[test]
    fn test_state_extraction() {
        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("my-state".to_string());

        let State(extracted) = State::<Arc<String>>::from_tool_request(&ctx, &state, &args)
            .expect("state extraction should succeed");
        assert_eq!(*extracted, "my-state");
    }

    #[test]
    fn test_context_extraction() {
        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(42));

        let extracted = Context::from_tool_request(&ctx, &(), &args)
            .expect("context extraction should succeed");
        assert_eq!(*extracted.request_id(), RequestId::Number(42));
    }

    #[test]
    fn test_raw_args_extraction() {
        let args = serde_json::json!({"foo": "bar", "baz": 123});
        let ctx = RequestContext::new(RequestId::Number(1));

        let RawArgs(extracted) = RawArgs::from_tool_request(&ctx, &(), &args)
            .expect("raw args extraction should succeed");
        assert_eq!(extracted["foo"], "bar");
        assert_eq!(extracted["baz"], 123);
    }

    #[test]
    fn test_extension_extraction() {
        use crate::context::Extensions;

        #[derive(Clone, Debug, PartialEq)]
        struct DatabasePool {
            url: String,
        }

        let args = serde_json::json!({});

        // Create extensions with a value
        let mut extensions = Extensions::new();
        extensions.insert(Arc::new(DatabasePool {
            url: "postgres://localhost".to_string(),
        }));

        // Create context with extensions
        let ctx = RequestContext::new(RequestId::Number(1)).with_extensions(Arc::new(extensions));

        // Extract the extension
        let Extension(pool) = Extension::<Arc<DatabasePool>>::from_tool_request(&ctx, &(), &args)
            .expect("inserted extension should be extractable");
        assert_eq!(pool.url, "postgres://localhost");
    }

    #[test]
    fn test_extension_extraction_missing() {
        #[derive(Clone, Debug)]
        struct NotPresent;

        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(1));

        // Try to extract something that's not in extensions
        let rejection = Extension::<NotPresent>::from_tool_request(&ctx, &(), &args)
            .expect_err("missing extension should be rejected");
        // ExtensionRejection contains the type name
        assert!(rejection.type_name().contains("NotPresent"));
    }

    #[tokio::test]
    async fn test_single_extractor_handler() {
        let handler = |Json(input): Json<TestInput>| async move {
            Ok(CallToolResult::text(format!(
                "{}: {}",
                input.name, input.count
            )))
        };

        let ctx = RequestContext::new(RequestId::Number(1));
        let args = serde_json::json!({"name": "test", "count": 5});

        // Call through the trait explicitly to avoid method-resolution ambiguity
        let result: Result<CallToolResult> =
            ExtractorHandler::<(), (Json<TestInput>,)>::call(handler, ctx, (), args).await;
        result.expect("single-extractor handler should succeed");
    }

    #[tokio::test]
    async fn test_two_extractor_handler() {
        let handler = |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
            Ok(CallToolResult::text(format!(
                "{}: {} - {}",
                state, input.name, input.count
            )))
        };

        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("prefix".to_string());
        let args = serde_json::json!({"name": "test", "count": 5});

        // Call through the trait explicitly to avoid method-resolution ambiguity
        let result: Result<CallToolResult> = ExtractorHandler::<
            Arc<String>,
            (State<Arc<String>>, Json<TestInput>),
        >::call(handler, ctx, state, args)
        .await;
        result.expect("two-extractor handler should succeed");
    }

    #[tokio::test]
    async fn test_three_extractor_handler() {
        let handler = |State(state): State<Arc<String>>,
                       ctx: Context,
                       Json(input): Json<TestInput>| async move {
            // Verify we can access all extractors
            assert!(!ctx.is_cancelled());
            Ok(CallToolResult::text(format!(
                "{}: {} - {}",
                state, input.name, input.count
            )))
        };

        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("prefix".to_string());
        let args = serde_json::json!({"name": "test", "count": 5});

        // Call through the trait explicitly to avoid method-resolution ambiguity
        let result: Result<CallToolResult> = ExtractorHandler::<
            Arc<String>,
            (State<Arc<String>>, Context, Json<TestInput>),
        >::call(handler, ctx, state, args)
        .await;
        result.expect("three-extractor handler should succeed");
    }

    #[test]
    fn test_json_schema_generation() {
        let schema = Json::<TestInput>::schema().expect("Json<T> should provide an input schema");
        assert!(schema.get("properties").is_some());
    }

    #[test]
    fn test_rejection_into_error() {
        let rejection = Rejection::new("test error");
        let error: Error = rejection.into();
        assert!(error.to_string().contains("test error"));
    }

    #[test]
    fn test_json_rejection() {
        // Test basic JsonRejection
        let rejection = JsonRejection::new("missing field `name`");
        assert_eq!(rejection.message(), "missing field `name`");
        assert!(rejection.path().is_none());
        assert!(rejection.to_string().contains("Invalid input"));

        // Test JsonRejection with path
        let rejection = JsonRejection::with_path("expected string", "users[0].name");
        assert_eq!(rejection.message(), "expected string");
        assert_eq!(rejection.path(), Some("users[0].name"));
        assert!(rejection.to_string().contains("users[0].name"));

        // Test conversion to Error
        let error: Error = rejection.into();
        assert!(error.to_string().contains("users[0].name"));
    }

    #[test]
    fn test_json_rejection_from_serde_error() {
        // Create a real serde error by deserializing JSON that lacks a required field
        #[derive(Debug, serde::Deserialize)]
        struct TestStruct {
            // Only exists so deserialization fails when it is missing; never read.
            #[allow(dead_code)]
            name: String,
        }

        let serde_err = serde_json::from_value::<TestStruct>(serde_json::json!({"count": 42}))
            .expect_err("JSON without `name` should fail to deserialize");

        let rejection: JsonRejection = serde_err.into();
        assert!(rejection.message().contains("name"));
    }

    #[test]
    fn test_extension_rejection() {
        // Test ExtensionRejection
        let rejection = ExtensionRejection::not_found::<String>();
        assert!(rejection.type_name().contains("String"));
        assert!(rejection.to_string().contains("not found"));
        assert!(rejection.to_string().contains("with_state"));

        // Test conversion to Error
        let error: Error = rejection.into();
        assert!(error.to_string().contains("not found"));
    }

    #[tokio::test]
    async fn test_tool_builder_extractor_handler() {
        use crate::ToolBuilder;

        let state = Arc::new("shared-state".to_string());

        let tool =
            ToolBuilder::new("test_extractor")
                .description("Test extractor handler")
                .extractor_handler(
                    state,
                    |State(state): State<Arc<String>>,
                     ctx: Context,
                     Json(input): Json<TestInput>| async move {
                        assert!(!ctx.is_cancelled());
                        Ok(CallToolResult::text(format!(
                            "{}: {} - {}",
                            state, input.name, input.count
                        )))
                    },
                )
                .build()
                .expect("valid tool name");

        assert_eq!(tool.name, "test_extractor");
        assert_eq!(tool.description.as_deref(), Some("Test extractor handler"));

        // Test calling the tool
        let result = tool
            .call(serde_json::json!({"name": "test", "count": 42}))
            .await;
        assert!(!result.is_error, "extractor tool call reported an error result");
    }

    #[tokio::test]
    async fn test_tool_builder_extractor_handler_typed() {
        use crate::ToolBuilder;

        let state = Arc::new("typed-state".to_string());

        let tool = ToolBuilder::new("test_typed")
            .description("Test typed extractor handler")
            .extractor_handler_typed::<_, _, _, TestInput>(
                state,
                |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
                    Ok(CallToolResult::text(format!(
                        "{}: {} - {}",
                        state, input.name, input.count
                    )))
                },
            )
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "test_typed");

        // Verify schema is properly generated from TestInput
        let def = tool.definition();
        let schema = def.input_schema;
        assert!(
            schema.get("properties").is_some(),
            "schema generated from TestInput should list its properties"
        );

        // Test calling the tool
        let result = tool
            .call(serde_json::json!({"name": "world", "count": 99}))
            .await;
        assert!(!result.is_error, "typed extractor tool call reported an error result");
    }
}