1use crate::state::{HasServerInfo, McpState};
4use crate::{SUPPORTED_VERSIONS, is_supported_version};
5use mcpkit_core::capability::ClientCapabilities;
6use mcpkit_core::protocol::Message;
7use mcpkit_core::protocol_version::ProtocolVersion;
8use mcpkit_server::context::{Context, NoOpPeer};
9use mcpkit_server::{
10 PromptHandler, ResourceHandler, ServerHandler, ToolHandler, route_prompts, route_resources,
11 route_tools,
12};
13use rocket::http::{ContentType, Header, Status};
14use rocket::request::{FromRequest, Outcome, Request};
15use rocket::response::stream::{Event, EventStream};
16use rocket::response::{self, Responder, Response};
17use std::io::Cursor;
18use std::sync::Arc;
19use tracing::{debug, info, warn};
20
21pub struct ProtocolVersionHeader(pub Option<String>);
23
24#[rocket::async_trait]
25impl<'r> FromRequest<'r> for ProtocolVersionHeader {
26 type Error = ();
27
28 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
29 let version = request
30 .headers()
31 .get_one("mcp-protocol-version")
32 .map(String::from);
33 Outcome::Success(ProtocolVersionHeader(version))
34 }
35}
36
37pub struct SessionIdHeader(pub Option<String>);
39
40#[rocket::async_trait]
41impl<'r> FromRequest<'r> for SessionIdHeader {
42 type Error = ();
43
44 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
45 let session_id = request
46 .headers()
47 .get_one("mcp-session-id")
48 .map(String::from);
49 Outcome::Success(SessionIdHeader(session_id))
50 }
51}
52
53pub struct LastEventIdHeader(pub Option<String>);
55
56#[rocket::async_trait]
57impl<'r> FromRequest<'r> for LastEventIdHeader {
58 type Error = ();
59
60 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
61 let last_event_id = request.headers().get_one("last-event-id").map(String::from);
62 Outcome::Success(LastEventIdHeader(last_event_id))
63 }
64}
65
66pub struct McpResponse {
68 status: Status,
69 content_type: ContentType,
70 session_id: Option<String>,
71 body: String,
72}
73
74impl McpResponse {
75 #[must_use]
77 pub fn success(body: String, session_id: String) -> Self {
78 Self {
79 status: Status::Ok,
80 content_type: ContentType::JSON,
81 session_id: Some(session_id),
82 body,
83 }
84 }
85
86 #[must_use]
88 pub fn accepted(session_id: String) -> Self {
89 Self {
90 status: Status::Accepted,
91 content_type: ContentType::JSON,
92 session_id: Some(session_id),
93 body: String::new(),
94 }
95 }
96
97 #[must_use]
99 pub fn error(status: Status, message: String) -> Self {
100 Self {
101 status,
102 content_type: ContentType::JSON,
103 session_id: None,
104 body: serde_json::json!({
105 "error": {
106 "code": -32600,
107 "message": message
108 }
109 })
110 .to_string(),
111 }
112 }
113}
114
115impl<'r> Responder<'r, 'static> for McpResponse {
116 fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> {
117 let mut builder = Response::build();
118 builder.status(self.status);
119 builder.header(self.content_type);
120
121 if let Some(session_id) = self.session_id {
122 builder.header(Header::new("mcp-session-id", session_id));
123 }
124
125 if !self.body.is_empty() {
126 builder.sized_body(self.body.len(), Cursor::new(self.body));
127 }
128
129 builder.ok()
130 }
131}
132
/// Cheaply clonable, shared handle to a handler value.
///
/// The handler lives behind an [`Arc`], so every clone refers to the same
/// instance.
pub struct HandlerContext<H> {
    inner: Arc<H>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would require
// `H: Clone`, but cloning only needs to bump the `Arc` reference count.
impl<H> Clone for HandlerContext<H> {
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.inner);
        Self { inner: shared }
    }
}

impl<H> HandlerContext<H> {
    /// Takes ownership of `handler` and stores it behind an [`Arc`].
    pub fn new(handler: H) -> Self {
        let shared = Arc::new(handler);
        HandlerContext { inner: shared }
    }

    /// Borrows the wrapped handler.
    #[must_use]
    pub fn handler(&self) -> &H {
        &*self.inner
    }
}
160
161pub async fn handle_mcp_post<H>(
165 state: &McpState<H>,
166 version: Option<&str>,
167 session_id: Option<String>,
168 body: &str,
169) -> McpResponse
170where
171 H: ServerHandler
172 + ToolHandler
173 + ResourceHandler
174 + PromptHandler
175 + HasServerInfo
176 + Send
177 + Sync
178 + 'static,
179{
180 if !is_supported_version(version) {
182 let provided = version.unwrap_or("none");
183 warn!(version = provided, "Unsupported protocol version");
184 return McpResponse::error(
185 Status::BadRequest,
186 format!(
187 "Unsupported protocol version: {} (supported: {})",
188 provided,
189 SUPPORTED_VERSIONS.join(", ")
190 ),
191 );
192 }
193
194 let session_id = match session_id {
196 Some(id) => {
197 state.sessions.touch(&id);
198 id
199 }
200 None => state.sessions.create(),
201 };
202
203 debug!(session_id = %session_id, "Processing MCP request");
204
205 let msg: Message = match serde_json::from_str(body) {
207 Ok(m) => m,
208 Err(e) => {
209 warn!(error = %e, "Failed to parse JSON-RPC message");
210 return McpResponse::error(Status::BadRequest, format!("Invalid message: {e}"));
211 }
212 };
213
214 match msg {
216 Message::Request(request) => {
217 info!(
218 method = %request.method,
219 id = ?request.id,
220 session_id = %session_id,
221 "Handling MCP request"
222 );
223
224 let response = create_response_for_request(state, &request).await;
225
226 match serde_json::to_string(&Message::Response(response)) {
227 Ok(body) => McpResponse::success(body, session_id),
228 Err(e) => McpResponse::error(
229 Status::InternalServerError,
230 format!("Serialization error: {e}"),
231 ),
232 }
233 }
234 Message::Notification(notification) => {
235 debug!(
236 method = %notification.method,
237 session_id = %session_id,
238 "Received notification"
239 );
240 McpResponse::accepted(session_id)
241 }
242 _ => {
243 warn!("Unexpected message type received");
244 McpResponse::error(
245 Status::BadRequest,
246 "Expected request or notification".to_string(),
247 )
248 }
249 }
250}
251
252async fn create_response_for_request<H>(
254 state: &McpState<H>,
255 request: &mcpkit_core::protocol::Request,
256) -> mcpkit_core::protocol::Response
257where
258 H: ServerHandler + ToolHandler + ResourceHandler + PromptHandler + Send + Sync + 'static,
259{
260 use mcpkit_core::error::JsonRpcError;
261 use mcpkit_core::protocol::Response;
262
263 let method = request.method.as_ref();
264 let params = request.params.as_ref();
265
266 let req_id = request.id.clone();
268 let client_caps = ClientCapabilities::default();
269 let server_caps = state.handler.capabilities();
270 let protocol_version = ProtocolVersion::LATEST;
271 let peer = NoOpPeer;
272 let ctx = Context::new(
273 &req_id,
274 None,
275 &client_caps,
276 &server_caps,
277 protocol_version,
278 &peer,
279 );
280
281 match method {
282 "ping" => Response::success(request.id.clone(), serde_json::json!({})),
283 "initialize" => {
284 let init_result = serde_json::json!({
285 "protocolVersion": ProtocolVersion::LATEST.as_str(),
286 "serverInfo": state.server_info,
287 "capabilities": state.handler.capabilities(),
288 });
289 Response::success(request.id.clone(), init_result)
290 }
291 _ => {
292 if let Some(result) = route_tools(state.handler.as_ref(), method, params, &ctx).await {
294 return match result {
295 Ok(value) => Response::success(request.id.clone(), value),
296 Err(e) => Response::error(request.id.clone(), e.into()),
297 };
298 }
299
300 if let Some(result) =
302 route_resources(state.handler.as_ref(), method, params, &ctx).await
303 {
304 return match result {
305 Ok(value) => Response::success(request.id.clone(), value),
306 Err(e) => Response::error(request.id.clone(), e.into()),
307 };
308 }
309
310 if let Some(result) = route_prompts(state.handler.as_ref(), method, params, &ctx).await
312 {
313 return match result {
314 Ok(value) => Response::success(request.id.clone(), value),
315 Err(e) => Response::error(request.id.clone(), e.into()),
316 };
317 }
318
319 Response::error(
321 request.id.clone(),
322 JsonRpcError::method_not_found(format!("Method '{method}' not found")),
323 )
324 }
325 }
326}
327
328pub fn handle_sse<H>(state: &McpState<H>, session_id: Option<String>) -> EventStream![]
332where
333 H: HasServerInfo + Send + Sync + 'static,
334{
335 let (session_id, mut rx) = if let Some(id) = session_id {
336 if let Some(rx) = state.sse_sessions.get_receiver(&id) {
337 info!(session_id = %id, "Reconnected to SSE session");
338 (id, rx)
339 } else {
340 let (new_id, rx) = state.sse_sessions.create_session();
341 info!(session_id = %new_id, "Created new SSE session (requested not found)");
342 (new_id, rx)
343 }
344 } else {
345 let (id, rx) = state.sse_sessions.create_session();
346 info!(session_id = %id, "Created new SSE session");
347 (id, rx)
348 };
349
350 EventStream! {
351 yield Event::data(session_id.clone()).event("connected").id("evt-connected");
353
354 loop {
356 match rx.recv().await {
357 Ok(msg) => {
358 let event_id = format!("evt-{}", uuid::Uuid::new_v4());
359 yield Event::data(msg).event("message").id(event_id);
360 }
361 Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
362 warn!(skipped = n, "SSE client lagged, skipped messages");
363 }
364 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
365 debug!("SSE channel closed");
366 break;
367 }
368 }
369 }
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Stand-in handler; the tests only inspect `name`.
    struct TestHandler {
        name: String,
    }

    /// Builds a [`TestHandler`] with the given name.
    fn make_handler(name: &str) -> TestHandler {
        TestHandler {
            name: name.to_owned(),
        }
    }

    #[test]
    fn test_handler_context_creation() {
        let ctx = HandlerContext::new(make_handler("test"));
        assert_eq!(ctx.handler().name, "test");
    }

    #[test]
    fn test_handler_context_clone() {
        let original = HandlerContext::new(make_handler("test"));
        let copy = original.clone();
        assert_eq!(original.handler().name, copy.handler().name);
    }

    #[test]
    fn test_mcp_response_success() {
        let payload = r#"{"result":"ok"}"#;
        let resp = McpResponse::success(payload.to_owned(), "session-123".to_owned());
        assert_eq!(resp.status, Status::Ok);
        assert_eq!(resp.content_type, ContentType::JSON);
        assert_eq!(resp.session_id.as_deref(), Some("session-123"));
        assert_eq!(resp.body, payload);
    }

    #[test]
    fn test_mcp_response_accepted() {
        let resp = McpResponse::accepted("session-456".to_owned());
        assert_eq!(resp.status, Status::Accepted);
        assert_eq!(resp.content_type, ContentType::JSON);
        assert_eq!(resp.session_id.as_deref(), Some("session-456"));
        assert!(resp.body.is_empty());
    }

    #[test]
    fn test_mcp_response_error() {
        let resp = McpResponse::error(Status::BadRequest, "Invalid request".to_owned());
        assert_eq!(resp.status, Status::BadRequest);
        assert_eq!(resp.content_type, ContentType::JSON);
        assert!(resp.session_id.is_none());
        assert!(resp.body.contains("Invalid request"));
        assert!(resp.body.contains("-32600"));
    }

    #[test]
    fn test_protocol_version_header_with_value() {
        let header = ProtocolVersionHeader(Some("2025-11-25".to_owned()));
        assert_eq!(header.0.as_deref(), Some("2025-11-25"));
    }

    #[test]
    fn test_protocol_version_header_without_value() {
        assert!(ProtocolVersionHeader(None).0.is_none());
    }

    #[test]
    fn test_session_id_header_with_value() {
        let header = SessionIdHeader(Some("abc-123".to_owned()));
        assert_eq!(header.0.as_deref(), Some("abc-123"));
    }

    #[test]
    fn test_session_id_header_without_value() {
        assert!(SessionIdHeader(None).0.is_none());
    }

    #[test]
    fn test_last_event_id_header_with_value() {
        let header = LastEventIdHeader(Some("evt-999".to_owned()));
        assert_eq!(header.0.as_deref(), Some("evt-999"));
    }

    #[test]
    fn test_last_event_id_header_without_value() {
        assert!(LastEventIdHeader(None).0.is_none());
    }
}