// mcp_kit/server/handler.rs

1use std::{future::Future, pin::Pin, sync::Arc};
2
3use crate::{
4    error::{McpError, McpResult},
5    types::{
6        messages::{
7            CallToolRequest, CompleteRequest, CompleteResult, GetPromptRequest, ReadResourceRequest,
8        },
9        prompt::GetPromptResult,
10        resource::ReadResourceResult,
11        tool::CallToolResult,
12    },
13};
14use serde::de::DeserializeOwned;
15
/// Heap-allocated, pinned future whose concrete type has been erased.
///
/// The future resolves to `T`, is `Send`, and may borrow data that lives for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
18
19/// Boxed handler function for tool calls
20pub type HandlerFn<Req, Res> =
21    Arc<dyn Fn(Req) -> BoxFuture<'static, McpResult<Res>> + Send + Sync + 'static>;
22
23pub type ToolHandlerFn = HandlerFn<CallToolRequest, CallToolResult>;
24pub type ResourceHandlerFn = HandlerFn<ReadResourceRequest, ReadResourceResult>;
25pub type PromptHandlerFn = HandlerFn<GetPromptRequest, GetPromptResult>;
26pub type CompletionHandlerFn = HandlerFn<CompleteRequest, CompleteResult>;
27
28// ─── IntoToolResult ───────────────────────────────────────────────────────────
29
30/// Anything that can be returned from a tool handler
31pub trait IntoToolResult {
32    fn into_tool_result(self) -> CallToolResult;
33}
34
35impl IntoToolResult for CallToolResult {
36    fn into_tool_result(self) -> CallToolResult {
37        self
38    }
39}
40
41impl IntoToolResult for String {
42    fn into_tool_result(self) -> CallToolResult {
43        CallToolResult::text(self)
44    }
45}
46
47impl IntoToolResult for &str {
48    fn into_tool_result(self) -> CallToolResult {
49        CallToolResult::text(self)
50    }
51}
52
53impl IntoToolResult for serde_json::Value {
54    fn into_tool_result(self) -> CallToolResult {
55        CallToolResult::text(self.to_string())
56    }
57}
58
59impl<T: IntoToolResult, E: std::fmt::Display> IntoToolResult for Result<T, E> {
60    fn into_tool_result(self) -> CallToolResult {
61        match self {
62            Ok(v) => v.into_tool_result(),
63            Err(e) => CallToolResult::error(e.to_string()),
64        }
65    }
66}
67
68// ─── ToolHandler trait ────────────────────────────────────────────────────────
69
70/// Implemented for async functions that can serve as tool handlers.
71///
72/// Supports two calling conventions:
73///   1. `|args: serde_json::Value| async { ... }` – raw JSON args
74///   2. `|params: MyStruct| async { ... }` – typed, deserialized args
75pub trait ToolHandler<Marker>: Clone + Send + Sync + 'static {
76    fn into_handler_fn(self) -> ToolHandlerFn;
77}
78
79/// Marker for typed (deserialised) handlers.
80/// Works for both `|params: MyStruct|` and `|args: serde_json::Value|` since
81/// `Value` implements `DeserializeOwned`.
82pub struct TypedMarker<T>(std::marker::PhantomData<T>);
83
84impl<F, Fut, R, T> ToolHandler<TypedMarker<T>> for F
85where
86    F: Fn(T) -> Fut + Clone + Send + Sync + 'static,
87    Fut: Future<Output = R> + Send + 'static,
88    R: IntoToolResult + Send + 'static,
89    T: DeserializeOwned + Send + 'static,
90{
91    fn into_handler_fn(self) -> ToolHandlerFn {
92        Arc::new(move |req: CallToolRequest| {
93            let f = self.clone();
94            let args = req.arguments.clone();
95            Box::pin(async move {
96                let params: T = serde_json::from_value(args)
97                    .map_err(|e| McpError::InvalidParams(e.to_string()))?;
98                Ok(f(params).await.into_tool_result())
99            })
100        })
101    }
102}
103
104// ─── AuthenticatedMarker — handler receives (T, Auth) ─────────────────────────
105
/// Marker for handlers that declare an [`Auth`] extractor as their second parameter.
///
/// The generated handler reads the current identity from the task-local auth
/// context set by `core.rs` before each dispatch. If no identity is present,
/// the handler returns [`McpError::Unauthorized`] before the user function is
/// even called.
///
/// [`Auth`]: crate::server::extract::Auth
/// [`McpError::Unauthorized`]: crate::error::McpError::Unauthorized
#[cfg(feature = "auth")]
pub struct AuthenticatedMarker<T>(std::marker::PhantomData<T>);

#[cfg(feature = "auth")]
impl<F, Fut, R, T> ToolHandler<AuthenticatedMarker<T>> for F
where
    F: Fn(T, crate::server::extract::Auth) -> Fut + Clone + Send + Sync + 'static,
    Fut: Future<Output = R> + Send + 'static,
    R: IntoToolResult + Send + 'static,
    T: DeserializeOwned + Send + 'static,
{
    /// Wraps the user function in a type-erased [`ToolHandlerFn`].
    ///
    /// When the returned future runs, it first resolves the caller's identity,
    /// propagating any error from `Auth::from_context`. It then deserializes the
    /// arguments into `T`, reporting failure as [`McpError::InvalidParams`].
    /// Only after both steps succeed is the user function called.
    fn into_handler_fn(self) -> ToolHandlerFn {
        Arc::new(move |req: CallToolRequest| {
            // The returned future must be `'static`, so it owns its own copy of
            // the handler instead of borrowing `self` from this closure.
            let f = self.clone();
            // `req` is owned here: move the arguments out rather than
            // deep-copying the JSON value on every call.
            let args = req.arguments;
            Box::pin(async move {
                // Looked up inside the future, i.e. where the future is polled,
                // not at the moment the handler fn is invoked.
                let auth = crate::server::extract::Auth::from_context()?;
                let params: T = serde_json::from_value(args)
                    .map_err(|e| McpError::InvalidParams(e.to_string()))?;
                Ok(f(params, auth).await.into_tool_result())
            })
        })
    }
}
139
140// ─── PromptHandler ───────────────────────────────────────────────────────────
141
142pub trait PromptHandler<Marker>: Clone + Send + Sync + 'static {
143    fn into_handler_fn(self) -> PromptHandlerFn;
144}
145
146pub struct PromptRawMarker;
147
148impl<F, Fut> PromptHandler<PromptRawMarker> for F
149where
150    F: Fn(GetPromptRequest) -> Fut + Clone + Send + Sync + 'static,
151    Fut: Future<Output = McpResult<GetPromptResult>> + Send + 'static,
152{
153    fn into_handler_fn(self) -> PromptHandlerFn {
154        Arc::new(move |req| {
155            let f = self.clone();
156            Box::pin(async move { f(req).await })
157        })
158    }
159}
160
161// ─── ResourceHandler ─────────────────────────────────────────────────────────
162
163pub trait ResourceHandler<Marker>: Clone + Send + Sync + 'static {
164    fn into_handler_fn(self) -> ResourceHandlerFn;
165}
166
167pub struct ResourceRawMarker;
168
169impl<F, Fut> ResourceHandler<ResourceRawMarker> for F
170where
171    F: Fn(ReadResourceRequest) -> Fut + Clone + Send + Sync + 'static,
172    Fut: Future<Output = McpResult<ReadResourceResult>> + Send + 'static,
173{
174    fn into_handler_fn(self) -> ResourceHandlerFn {
175        Arc::new(move |req| {
176            let f = self.clone();
177            Box::pin(async move { f(req).await })
178        })
179    }
180}
181
182// ─── CompletionHandler ───────────────────────────────────────────────────────
183
184pub trait CompletionHandler<Marker>: Clone + Send + Sync + 'static {
185    fn into_handler_fn(self) -> CompletionHandlerFn;
186}
187
188pub struct CompletionRawMarker;
189
190impl<F, Fut> CompletionHandler<CompletionRawMarker> for F
191where
192    F: Fn(CompleteRequest) -> Fut + Clone + Send + Sync + 'static,
193    Fut: Future<Output = McpResult<CompleteResult>> + Send + 'static,
194{
195    fn into_handler_fn(self) -> CompletionHandlerFn {
196        Arc::new(move |req| {
197            let f = self.clone();
198            Box::pin(async move { f(req).await })
199        })
200    }
201}