httpmcp_rust/
transport.rs

1use crate::context::RequestContext;
2use crate::error::{McpError, Result};
3use crate::handlers::lifecycle::{handle_initialize, handle_ping};
4use crate::jsonrpc::{JsonRpcRequest, JsonRpcResponse};
5use crate::protocol::*;
6use crate::server::HttpMcpServer;
7use actix_multipart::Multipart;
8use actix_web::{
9    get, post,
10    web::{self, Data},
11    HttpRequest, HttpResponse, Responder,
12};
13use actix_web_lab::sse;
14use serde_json::Value;
15use std::sync::Arc;
16
17/// Configure actix-web application
18pub fn create_app(cfg: &mut web::ServiceConfig, server: Arc<HttpMcpServer>) {
19    if server.enable_cors {
20        cfg.default_service(web::to(|| async {
21            HttpResponse::Ok()
22                .insert_header(("Access-Control-Allow-Origin", "*"))
23                .insert_header((
24                    "Access-Control-Allow-Methods",
25                    "GET, POST, PUT, DELETE, PATCH, OPTIONS",
26                ))
27                .insert_header(("Access-Control-Allow-Headers", "*"))
28                .finish()
29        }));
30    }
31
32    cfg.app_data(Data::new(server.clone()))
33        .service(handle_post)
34        .service(handle_get);
35
36    // Register custom endpoints dynamically
37    for endpoint in &server.endpoints {
38        let route = endpoint.route.clone();
39        let method = endpoint.method.clone();
40        let handler = endpoint.handler.clone();
41        let server_clone = server.clone();
42
43        cfg.route(
44            &route,
45            web::method(parse_http_method(&method)).to(
46                move |req: HttpRequest, body: Option<web::Json<Value>>| {
47                    let handler = handler.clone();
48                    let server_clone = server_clone.clone();
49                    async move {
50                        let ctx = create_request_context(&req);
51
52                        // Validate OAuth if configured
53                        if let Some(oauth) = &server_clone.oauth_config {
54                            if let Err(e) = oauth.validate_token(&ctx).await {
55                                return Ok::<HttpResponse, actix_web::Error>(
56                                    HttpResponse::Unauthorized().json(serde_json::json!({
57                                        "error": e.to_string()
58                                    })),
59                                );
60                            }
61                        }
62
63                        let body_value = body.map(|json| json.into_inner());
64                        match handler(ctx, body_value).await {
65                            Ok(response) => Ok(response),
66                            Err(e) => {
67                                Ok(HttpResponse::InternalServerError().json(serde_json::json!({
68                                    "error": e.to_string()
69                                })))
70                            }
71                        }
72                    }
73                },
74            ),
75        );
76    }
77
78    // Register multipart endpoints dynamically
79    for endpoint in &server.multipart_endpoints {
80        let route = endpoint.route.clone();
81        let method = endpoint.method.clone();
82        let handler = endpoint.handler.clone();
83        let server_clone = server.clone();
84
85        cfg.route(
86            &route,
87            web::method(parse_http_method(&method)).to(
88                move |req: HttpRequest, multipart: Multipart| {
89                    let handler = handler.clone();
90                    let server_clone = server_clone.clone();
91                    let ctx = create_request_context(&req);
92
93                    async move {
94                        // Validate OAuth if configured
95                        if let Some(oauth) = &server_clone.oauth_config {
96                            if let Err(e) = oauth.validate_token(&ctx).await {
97                                return Ok::<HttpResponse, actix_web::Error>(
98                                    HttpResponse::Unauthorized().json(serde_json::json!({
99                                        "error": e.to_string()
100                                    })),
101                                );
102                            }
103                        }
104
105                        // Call handler directly - multipart processing happens on the same task
106                        match handler(ctx, multipart).await {
107                            Ok(response) => Ok(response),
108                            Err(e) => {
109                                Ok(HttpResponse::InternalServerError().json(serde_json::json!({
110                                    "error": e.to_string()
111                                })))
112                            }
113                        }
114                    }
115                },
116            ),
117        );
118    }
119}
120
121/// Parse HTTP method string to actix-web Method
122fn parse_http_method(method: &str) -> actix_web::http::Method {
123    match method.to_uppercase().as_str() {
124        "GET" => actix_web::http::Method::GET,
125        "POST" => actix_web::http::Method::POST,
126        "PUT" => actix_web::http::Method::PUT,
127        "DELETE" => actix_web::http::Method::DELETE,
128        "PATCH" => actix_web::http::Method::PATCH,
129        "HEAD" => actix_web::http::Method::HEAD,
130        "OPTIONS" => actix_web::http::Method::OPTIONS,
131        _ => actix_web::http::Method::GET,
132    }
133}
134
135/// POST /mcp - Handle JSON-RPC requests
136#[post("/mcp")]
137async fn handle_post(
138    req: HttpRequest,
139    body: web::Json<JsonRpcRequest>,
140    server: Data<Arc<HttpMcpServer>>,
141) -> Result<impl Responder> {
142    let ctx = create_request_context(&req);
143
144    // Validate OAuth if configured
145    if let Some(oauth) = &server.oauth_config {
146        oauth.validate_token(&ctx).await?;
147    }
148
149    // Validate JSON-RPC request
150    body.validate()?;
151
152    // Check if this is a notification (no id field)
153    let is_notification = body.id.is_none();
154
155    // Check if client accepts SSE (streaming mode)
156    let accept_sse = req
157        .headers()
158        .get("accept")
159        .and_then(|v| v.to_str().ok())
160        .map(|s| s.contains("text/event-stream"))
161        .unwrap_or(false);
162
163    // Route and execute the request
164    let response = route_request(&body, &ctx, &server).await?;
165
166    // Notifications MUST NOT receive a response per JSON-RPC 2.0 spec
167    if is_notification {
168        tracing::debug!(
169            "Notification received ({}), returning 204 No Content",
170            body.method
171        );
172        let mut resp = HttpResponse::NoContent();
173        if server.enable_cors {
174            resp.insert_header(("Access-Control-Allow-Origin", "*"));
175        }
176        return Ok(resp.finish());
177    }
178
179    // For SSE mode, broadcast response and return 202 Accepted
180    if accept_sse {
181        let subscriber_count = server.response_tx.receiver_count();
182        tracing::debug!("Broadcasting response to {} subscribers", subscriber_count);
183
184        // If there are active SSE subscribers, send via broadcast
185        if subscriber_count > 0 {
186            let _ = server.response_tx.send(response);
187            let mut resp = HttpResponse::Accepted();
188            if server.enable_cors {
189                resp.insert_header(("Access-Control-Allow-Origin", "*"));
190            }
191            return Ok(resp.finish());
192        }
193
194        // If no subscribers, fallback to direct response
195        tracing::warn!("No SSE subscribers, falling back to direct HTTP response");
196    }
197
198    // For non-SSE mode or fallback, return JSON response directly
199    let mut resp = HttpResponse::Ok();
200    if server.enable_cors {
201        resp.insert_header(("Access-Control-Allow-Origin", "*"));
202    }
203    Ok(resp.json(response))
204}
205
206/// GET /mcp - SSE stream for server-to-client messages
207#[get("/mcp")]
208async fn handle_get(req: HttpRequest, server: Data<Arc<HttpMcpServer>>) -> Result<impl Responder> {
209    let ctx = create_request_context(&req);
210
211    // Validate OAuth if configured
212    if let Some(oauth) = &server.oauth_config {
213        oauth.validate_token(&ctx).await?;
214    }
215
216    // Check for Last-Event-ID header for resumption
217    let _last_event_id = req
218        .headers()
219        .get("Last-Event-ID")
220        .and_then(|v| v.to_str().ok())
221        .map(|s| s.to_string());
222
223    // Subscribe to response broadcast channel
224    let mut rx = server.response_tx.subscribe();
225
226    tracing::debug!("SSE stream connected");
227
228    // Create SSE stream from broadcast channel
229    let event_stream = async_stream::stream! {
230        while let Ok(response) = rx.recv().await {
231            if let Ok(json) = serde_json::to_string(&response) {
232                tracing::debug!("Sending response via SSE: {}", json);
233                // Send as "message" event with the JSON-RPC response
234                yield Ok::<_, actix_web::Error>(sse::Event::Data(
235                    sse::Data::new(json)
236                ));
237            }
238        }
239    };
240
241    Ok(sse::Sse::from_stream(event_stream))
242}
243
244/// Route JSON-RPC request to appropriate handler
245async fn route_request(
246    req: &JsonRpcRequest,
247    ctx: &RequestContext,
248    server: &HttpMcpServer,
249) -> Result<JsonRpcResponse> {
250    tracing::debug!("Routing request: method={}", req.method);
251
252    match req.method.as_str() {
253        // Lifecycle
254        "initialize" => {
255            handle_initialize(req, server.server_info.clone(), server.capabilities.clone())
256        }
257        "ping" => handle_ping(req),
258
259        // Notifications
260        "notifications/initialized" => handle_notifications_initialized(req),
261
262        // Resources
263        "resources/list" => handle_resources_list(req, ctx, server).await,
264        "resources/read" => handle_resources_read(req, ctx, server).await,
265        "resources/templates/list" => handle_resources_templates(req, ctx, server).await,
266        "resources/subscribe" => handle_resources_subscribe(req, ctx, server).await,
267
268        // Tools
269        "tools/list" => handle_tools_list(req, ctx, server).await,
270        "tools/call" => handle_tools_call(req, ctx, server).await,
271
272        // Prompts
273        "prompts/list" => handle_prompts_list(req, ctx, server).await,
274        "prompts/get" => handle_prompts_get(req, ctx, server).await,
275
276        // Logging
277        "logging/setLevel" => handle_logging_set_level(req),
278
279        _ => Err(McpError::MethodNotFound(req.method.clone())),
280    }
281}
282
283// ============================================================================
284// Resource Handlers
285// ============================================================================
286
287async fn handle_resources_list(
288    req: &JsonRpcRequest,
289    ctx: &RequestContext,
290    server: &HttpMcpServer,
291) -> Result<JsonRpcResponse> {
292    let params: ResourcesListParams =
293        serde_json::from_value(req.params.clone().unwrap_or(Value::Null))
294            .unwrap_or(ResourcesListParams { cursor: None });
295
296    // Collect all resources from registered handlers
297    let mut all_resources = Vec::new();
298    for registered in server.resources.values() {
299        let (resources, _) = (registered.list_handler)(params.cursor.clone(), ctx.clone()).await?;
300        all_resources.extend(resources);
301    }
302
303    let result = ResourcesListResult {
304        resources: all_resources,
305        next_cursor: None,
306    };
307
308    Ok(JsonRpcResponse::success(
309        serde_json::to_value(result)?,
310        req.id.clone(),
311    ))
312}
313
314async fn handle_resources_read(
315    req: &JsonRpcRequest,
316    ctx: &RequestContext,
317    server: &HttpMcpServer,
318) -> Result<JsonRpcResponse> {
319    let params: ResourcesReadParams =
320        serde_json::from_value(req.params.clone().unwrap_or(Value::Null))
321            .map_err(|e| McpError::InvalidParams(format!("Invalid params: {}", e)))?;
322
323    // Try to find matching resource handler
324    let mut contents = Vec::new();
325    for registered in server.resources.values() {
326        let result = (registered.read_handler)(params.uri.clone(), ctx.clone()).await?;
327        contents.extend(result);
328    }
329
330    if contents.is_empty() {
331        return Err(McpError::ResourceNotFound(params.uri));
332    }
333
334    let result = ResourcesReadResult { contents };
335
336    Ok(JsonRpcResponse::success(
337        serde_json::to_value(result)?,
338        req.id.clone(),
339    ))
340}
341
342async fn handle_resources_templates(
343    req: &JsonRpcRequest,
344    _ctx: &RequestContext,
345    _server: &HttpMcpServer,
346) -> Result<JsonRpcResponse> {
347    // Resource templates are not supported in the new function-based API
348    Ok(JsonRpcResponse::success(
349        serde_json::json!({ "resourceTemplates": [] }),
350        req.id.clone(),
351    ))
352}
353
354async fn handle_resources_subscribe(
355    req: &JsonRpcRequest,
356    _ctx: &RequestContext,
357    _server: &HttpMcpServer,
358) -> Result<JsonRpcResponse> {
359    // Resource subscription is not supported in the new function-based API
360    Ok(JsonRpcResponse::success(Value::Null, req.id.clone()))
361}
362
363// ============================================================================
364// Tool Handlers
365// ============================================================================
366
367async fn handle_tools_list(
368    req: &JsonRpcRequest,
369    _ctx: &RequestContext,
370    server: &HttpMcpServer,
371) -> Result<JsonRpcResponse> {
372    // Collect all registered tools
373    let tools: Vec<Tool> = server
374        .tools
375        .values()
376        .map(|registered| registered.meta.clone())
377        .collect();
378
379    let result = ToolsListResult {
380        tools,
381        next_cursor: None,
382    };
383
384    Ok(JsonRpcResponse::success(
385        serde_json::to_value(result)?,
386        req.id.clone(),
387    ))
388}
389
390async fn handle_tools_call(
391    req: &JsonRpcRequest,
392    ctx: &RequestContext,
393    server: &HttpMcpServer,
394) -> Result<JsonRpcResponse> {
395    let params: ToolsCallParams = serde_json::from_value(req.params.clone().unwrap_or(Value::Null))
396        .map_err(|e| McpError::InvalidParams(format!("Invalid params: {}", e)))?;
397
398    // Find the registered tool
399    let registered = server
400        .tools
401        .get(&params.name)
402        .ok_or_else(|| McpError::ToolNotFound(params.name.clone()))?;
403
404    // Call the tool handler
405    let result_value =
406        (registered.handler)(params.arguments.unwrap_or_default(), ctx.clone()).await?;
407
408    // Convert result to ToolContent
409    let content = vec![ToolContent::Text {
410        text: result_value.to_string(),
411    }];
412
413    let result = ToolsCallResult {
414        content,
415        is_error: None,
416    };
417
418    Ok(JsonRpcResponse::success(
419        serde_json::to_value(result)?,
420        req.id.clone(),
421    ))
422}
423
424// ============================================================================
425// Prompt Handlers
426// ============================================================================
427
428async fn handle_prompts_list(
429    req: &JsonRpcRequest,
430    _ctx: &RequestContext,
431    server: &HttpMcpServer,
432) -> Result<JsonRpcResponse> {
433    // Collect all registered prompts
434    let prompts: Vec<Prompt> = server
435        .prompts
436        .values()
437        .map(|registered| registered.meta.clone())
438        .collect();
439
440    let result = PromptsListResult {
441        prompts,
442        next_cursor: None,
443    };
444
445    Ok(JsonRpcResponse::success(
446        serde_json::to_value(result)?,
447        req.id.clone(),
448    ))
449}
450
451async fn handle_prompts_get(
452    req: &JsonRpcRequest,
453    ctx: &RequestContext,
454    server: &HttpMcpServer,
455) -> Result<JsonRpcResponse> {
456    let params: PromptsGetParams =
457        serde_json::from_value(req.params.clone().unwrap_or(Value::Null))
458            .map_err(|e| McpError::InvalidParams(format!("Invalid params: {}", e)))?;
459
460    // Find the registered prompt
461    let registered = server
462        .prompts
463        .get(&params.name)
464        .ok_or_else(|| McpError::PromptNotFound(params.name.clone()))?;
465
466    // Call the prompt handler
467    let (description, messages) =
468        (registered.handler)(params.name.clone(), params.arguments, ctx.clone()).await?;
469
470    let result = PromptsGetResult {
471        description,
472        messages,
473    };
474
475    Ok(JsonRpcResponse::success(
476        serde_json::to_value(result)?,
477        req.id.clone(),
478    ))
479}
480
481// ============================================================================
482// Notification Handlers
483// ============================================================================
484
485fn handle_notifications_initialized(req: &JsonRpcRequest) -> Result<JsonRpcResponse> {
486    tracing::debug!("Client initialized notification received");
487    // Return empty object instead of null
488    Ok(JsonRpcResponse::success(
489        serde_json::json!({}),
490        req.id.clone(),
491    ))
492}
493
494// ============================================================================
495// Logging Handlers
496// ============================================================================
497
498fn handle_logging_set_level(req: &JsonRpcRequest) -> Result<JsonRpcResponse> {
499    let _params: LoggingSetLevelParams =
500        serde_json::from_value(req.params.clone().unwrap_or(Value::Null))
501            .map_err(|e| McpError::InvalidParams(format!("Invalid params: {}", e)))?;
502
503    // TODO: Implement actual log level setting
504    Ok(JsonRpcResponse::success(
505        serde_json::json!({}),
506        req.id.clone(),
507    ))
508}
509
510// ============================================================================
511// Utilities
512// ============================================================================
513
514fn create_request_context(req: &HttpRequest) -> RequestContext {
515    RequestContext::new(
516        req.headers().clone(),
517        req.method().to_string(),
518        req.path().to_string(),
519        req.peer_addr(),
520    )
521}