mcpkit_rocket/
handler.rs

1//! HTTP handlers for MCP requests using Rocket.
2
3use crate::state::{HasServerInfo, McpState};
4use crate::{SUPPORTED_VERSIONS, is_supported_version};
5use mcpkit_core::capability::ClientCapabilities;
6use mcpkit_core::protocol::Message;
7use mcpkit_core::protocol_version::ProtocolVersion;
8use mcpkit_server::context::{Context, NoOpPeer};
9use mcpkit_server::{
10    PromptHandler, ResourceHandler, ServerHandler, ToolHandler, route_prompts, route_resources,
11    route_tools,
12};
13use rocket::http::{ContentType, Header, Status};
14use rocket::request::{FromRequest, Outcome, Request};
15use rocket::response::stream::{Event, EventStream};
16use rocket::response::{self, Responder, Response};
17use std::io::Cursor;
18use std::sync::Arc;
19use tracing::{debug, info, warn};
20
21/// MCP protocol version header.
22pub struct ProtocolVersionHeader(pub Option<String>);
23
24#[rocket::async_trait]
25impl<'r> FromRequest<'r> for ProtocolVersionHeader {
26    type Error = ();
27
28    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
29        let version = request
30            .headers()
31            .get_one("mcp-protocol-version")
32            .map(String::from);
33        Outcome::Success(ProtocolVersionHeader(version))
34    }
35}
36
37/// MCP session ID header.
38pub struct SessionIdHeader(pub Option<String>);
39
40#[rocket::async_trait]
41impl<'r> FromRequest<'r> for SessionIdHeader {
42    type Error = ();
43
44    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
45        let session_id = request
46            .headers()
47            .get_one("mcp-session-id")
48            .map(String::from);
49        Outcome::Success(SessionIdHeader(session_id))
50    }
51}
52
53/// Last-Event-ID header for SSE reconnection.
54pub struct LastEventIdHeader(pub Option<String>);
55
56#[rocket::async_trait]
57impl<'r> FromRequest<'r> for LastEventIdHeader {
58    type Error = ();
59
60    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
61        let last_event_id = request.headers().get_one("last-event-id").map(String::from);
62        Outcome::Success(LastEventIdHeader(last_event_id))
63    }
64}
65
66/// Response wrapper for MCP POST requests.
67pub struct McpResponse {
68    status: Status,
69    content_type: ContentType,
70    session_id: Option<String>,
71    body: String,
72}
73
74impl McpResponse {
75    /// Create a success response.
76    #[must_use]
77    pub fn success(body: String, session_id: String) -> Self {
78        Self {
79            status: Status::Ok,
80            content_type: ContentType::JSON,
81            session_id: Some(session_id),
82            body,
83        }
84    }
85
86    /// Create an accepted response (for notifications).
87    #[must_use]
88    pub fn accepted(session_id: String) -> Self {
89        Self {
90            status: Status::Accepted,
91            content_type: ContentType::JSON,
92            session_id: Some(session_id),
93            body: String::new(),
94        }
95    }
96
97    /// Create an error response.
98    #[must_use]
99    pub fn error(status: Status, message: String) -> Self {
100        Self {
101            status,
102            content_type: ContentType::JSON,
103            session_id: None,
104            body: serde_json::json!({
105                "error": {
106                    "code": -32600,
107                    "message": message
108                }
109            })
110            .to_string(),
111        }
112    }
113}
114
115impl<'r> Responder<'r, 'static> for McpResponse {
116    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> {
117        let mut builder = Response::build();
118        builder.status(self.status);
119        builder.header(self.content_type);
120
121        if let Some(session_id) = self.session_id {
122            builder.header(Header::new("mcp-session-id", session_id));
123        }
124
125        if !self.body.is_empty() {
126            builder.sized_body(self.body.len(), Cursor::new(self.body));
127        }
128
129        builder.ok()
130    }
131}
132
/// Handler context wrapping the generic handler type.
///
/// The handler is stored behind an [`Arc`], so cloning a `HandlerContext` is
/// cheap and every clone refers to the same handler instance.
pub struct HandlerContext<H> {
    inner: Arc<H>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would add an
// `H: Clone` bound, but only the `Arc` needs to be cloned.
impl<H> Clone for HandlerContext<H> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<H> HandlerContext<H> {
    /// Create a new handler context, taking ownership of `handler`.
    #[must_use]
    pub fn new(handler: H) -> Self {
        Self::from_arc(Arc::new(handler))
    }

    /// Create a handler context from an already shared handler.
    ///
    /// Useful when the caller already holds an `Arc<H>`. The context shares
    /// that allocation instead of requiring a fresh, owned `H`.
    #[must_use]
    pub fn from_arc(inner: Arc<H>) -> Self {
        Self { inner }
    }

    /// Get a reference to the inner handler.
    #[must_use]
    pub fn handler(&self) -> &H {
        &self.inner
    }
}
160
161/// Handle MCP POST requests.
162///
163/// This is the core handler function that processes JSON-RPC messages.
164pub async fn handle_mcp_post<H>(
165    state: &McpState<H>,
166    version: Option<&str>,
167    session_id: Option<String>,
168    body: &str,
169) -> McpResponse
170where
171    H: ServerHandler
172        + ToolHandler
173        + ResourceHandler
174        + PromptHandler
175        + HasServerInfo
176        + Send
177        + Sync
178        + 'static,
179{
180    // Validate protocol version
181    if !is_supported_version(version) {
182        let provided = version.unwrap_or("none");
183        warn!(version = provided, "Unsupported protocol version");
184        return McpResponse::error(
185            Status::BadRequest,
186            format!(
187                "Unsupported protocol version: {} (supported: {})",
188                provided,
189                SUPPORTED_VERSIONS.join(", ")
190            ),
191        );
192    }
193
194    // Get or create session
195    let session_id = match session_id {
196        Some(id) => {
197            state.sessions.touch(&id);
198            id
199        }
200        None => state.sessions.create(),
201    };
202
203    debug!(session_id = %session_id, "Processing MCP request");
204
205    // Parse message
206    let msg: Message = match serde_json::from_str(body) {
207        Ok(m) => m,
208        Err(e) => {
209            warn!(error = %e, "Failed to parse JSON-RPC message");
210            return McpResponse::error(Status::BadRequest, format!("Invalid message: {e}"));
211        }
212    };
213
214    // Process message
215    match msg {
216        Message::Request(request) => {
217            info!(
218                method = %request.method,
219                id = ?request.id,
220                session_id = %session_id,
221                "Handling MCP request"
222            );
223
224            let response = create_response_for_request(state, &request).await;
225
226            match serde_json::to_string(&Message::Response(response)) {
227                Ok(body) => McpResponse::success(body, session_id),
228                Err(e) => McpResponse::error(
229                    Status::InternalServerError,
230                    format!("Serialization error: {e}"),
231                ),
232            }
233        }
234        Message::Notification(notification) => {
235            debug!(
236                method = %notification.method,
237                session_id = %session_id,
238                "Received notification"
239            );
240            McpResponse::accepted(session_id)
241        }
242        _ => {
243            warn!("Unexpected message type received");
244            McpResponse::error(
245                Status::BadRequest,
246                "Expected request or notification".to_string(),
247            )
248        }
249    }
250}
251
252/// Create a response for a request.
253async fn create_response_for_request<H>(
254    state: &McpState<H>,
255    request: &mcpkit_core::protocol::Request,
256) -> mcpkit_core::protocol::Response
257where
258    H: ServerHandler + ToolHandler + ResourceHandler + PromptHandler + Send + Sync + 'static,
259{
260    use mcpkit_core::error::JsonRpcError;
261    use mcpkit_core::protocol::Response;
262
263    let method = request.method.as_ref();
264    let params = request.params.as_ref();
265
266    // Create a context for the request
267    let req_id = request.id.clone();
268    let client_caps = ClientCapabilities::default();
269    let server_caps = state.handler.capabilities();
270    let protocol_version = ProtocolVersion::LATEST;
271    let peer = NoOpPeer;
272    let ctx = Context::new(
273        &req_id,
274        None,
275        &client_caps,
276        &server_caps,
277        protocol_version,
278        &peer,
279    );
280
281    match method {
282        "ping" => Response::success(request.id.clone(), serde_json::json!({})),
283        "initialize" => {
284            let init_result = serde_json::json!({
285                "protocolVersion": ProtocolVersion::LATEST.as_str(),
286                "serverInfo": state.server_info,
287                "capabilities": state.handler.capabilities(),
288            });
289            Response::success(request.id.clone(), init_result)
290        }
291        _ => {
292            // Try routing to tools
293            if let Some(result) = route_tools(state.handler.as_ref(), method, params, &ctx).await {
294                return match result {
295                    Ok(value) => Response::success(request.id.clone(), value),
296                    Err(e) => Response::error(request.id.clone(), e.into()),
297                };
298            }
299
300            // Try routing to resources
301            if let Some(result) =
302                route_resources(state.handler.as_ref(), method, params, &ctx).await
303            {
304                return match result {
305                    Ok(value) => Response::success(request.id.clone(), value),
306                    Err(e) => Response::error(request.id.clone(), e.into()),
307                };
308            }
309
310            // Try routing to prompts
311            if let Some(result) = route_prompts(state.handler.as_ref(), method, params, &ctx).await
312            {
313                return match result {
314                    Ok(value) => Response::success(request.id.clone(), value),
315                    Err(e) => Response::error(request.id.clone(), e.into()),
316                };
317            }
318
319            // Method not found
320            Response::error(
321                request.id.clone(),
322                JsonRpcError::method_not_found(format!("Method '{method}' not found")),
323            )
324        }
325    }
326}
327
328/// Handle SSE connections for server-to-client streaming.
329///
330/// This returns an `EventStream` for pushing notifications to clients.
331pub fn handle_sse<H>(state: &McpState<H>, session_id: Option<String>) -> EventStream![]
332where
333    H: HasServerInfo + Send + Sync + 'static,
334{
335    let (session_id, mut rx) = if let Some(id) = session_id {
336        if let Some(rx) = state.sse_sessions.get_receiver(&id) {
337            info!(session_id = %id, "Reconnected to SSE session");
338            (id, rx)
339        } else {
340            let (new_id, rx) = state.sse_sessions.create_session();
341            info!(session_id = %new_id, "Created new SSE session (requested not found)");
342            (new_id, rx)
343        }
344    } else {
345        let (id, rx) = state.sse_sessions.create_session();
346        info!(session_id = %id, "Created new SSE session");
347        (id, rx)
348    };
349
350    EventStream! {
351        // Send connected event with session ID
352        yield Event::data(session_id.clone()).event("connected").id("evt-connected");
353
354        // Stream new messages
355        loop {
356            match rx.recv().await {
357                Ok(msg) => {
358                    let event_id = format!("evt-{}", uuid::Uuid::new_v4());
359                    yield Event::data(msg).event("message").id(event_id);
360                }
361                Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
362                    warn!(skipped = n, "SSE client lagged, skipped messages");
363                }
364                Err(tokio::sync::broadcast::error::RecvError::Closed) => {
365                    debug!("SSE channel closed");
366                    break;
367                }
368            }
369        }
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    // Test HandlerContext
    /// Minimal stand-in handler; `HandlerContext` places no bounds on `H`.
    struct TestHandler {
        name: String,
    }

    fn test_handler() -> TestHandler {
        TestHandler {
            name: "test".to_string(),
        }
    }

    #[test]
    fn test_handler_context_creation() {
        let ctx = HandlerContext::new(test_handler());
        assert_eq!(ctx.handler().name, "test");
    }

    #[test]
    fn test_handler_context_clone() {
        let ctx = HandlerContext::new(test_handler());
        let cloned = ctx.clone();

        // Cloning must share the same handler allocation, not copy the handler.
        assert!(Arc::ptr_eq(&ctx.inner, &cloned.inner));
        assert_eq!(Arc::strong_count(&ctx.inner), 2);
        assert_eq!(ctx.handler().name, cloned.handler().name);
    }

    // Test McpResponse
    #[test]
    fn test_mcp_response_success() {
        let response =
            McpResponse::success(r#"{"result":"ok"}"#.to_string(), "session-123".to_string());
        assert_eq!(response.status, Status::Ok);
        assert_eq!(response.content_type, ContentType::JSON);
        assert_eq!(response.session_id, Some("session-123".to_string()));
        assert_eq!(response.body, r#"{"result":"ok"}"#);
    }

    #[test]
    fn test_mcp_response_accepted() {
        let response = McpResponse::accepted("session-456".to_string());
        assert_eq!(response.status, Status::Accepted);
        assert_eq!(response.content_type, ContentType::JSON);
        assert_eq!(response.session_id, Some("session-456".to_string()));
        assert!(response.body.is_empty());
    }

    #[test]
    fn test_mcp_response_error() {
        let response = McpResponse::error(Status::BadRequest, "Invalid request".to_string());
        assert_eq!(response.status, Status::BadRequest);
        assert_eq!(response.content_type, ContentType::JSON);
        assert!(response.session_id.is_none());
        assert!(response.body.contains("Invalid request"));
        assert!(response.body.contains("-32600"));
    }

    // Test header types
    #[test]
    fn test_protocol_version_header_with_value() {
        let header = ProtocolVersionHeader(Some("2025-11-25".to_string()));
        assert_eq!(header.0, Some("2025-11-25".to_string()));
    }

    #[test]
    fn test_protocol_version_header_without_value() {
        let header = ProtocolVersionHeader(None);
        assert!(header.0.is_none());
    }

    #[test]
    fn test_session_id_header_with_value() {
        let header = SessionIdHeader(Some("abc-123".to_string()));
        assert_eq!(header.0, Some("abc-123".to_string()));
    }

    #[test]
    fn test_session_id_header_without_value() {
        let header = SessionIdHeader(None);
        assert!(header.0.is_none());
    }

    #[test]
    fn test_last_event_id_header_with_value() {
        let header = LastEventIdHeader(Some("evt-999".to_string()));
        assert_eq!(header.0, Some("evt-999".to_string()));
    }

    #[test]
    fn test_last_event_id_header_without_value() {
        let header = LastEventIdHeader(None);
        assert!(header.0.is_none());
    }
}