httpmcp_rust/
transport.rs

1use crate::context::RequestContext;
2use crate::error::{McpError, Result};
3use crate::handlers::lifecycle::{handle_initialize, handle_ping};
4use crate::jsonrpc::{JsonRpcRequest, JsonRpcResponse};
5use crate::protocol::*;
6use crate::server::HttpMcpServer;
7use actix_web::{
8    get, post,
9    web::{self, Data},
10    HttpRequest, HttpResponse, Responder,
11};
12use actix_web_lab::sse;
13use serde_json::Value;
14use std::sync::Arc;
15
16/// Configure actix-web application
17pub fn create_app(cfg: &mut web::ServiceConfig, server: Arc<HttpMcpServer>) {
18    if server.enable_cors {
19        cfg.default_service(web::to(|| async {
20            HttpResponse::Ok()
21                .insert_header(("Access-Control-Allow-Origin", "*"))
22                .insert_header(("Access-Control-Allow-Methods", "GET, POST, OPTIONS"))
23                .insert_header(("Access-Control-Allow-Headers", "*"))
24                .finish()
25        }));
26    }
27
28    cfg.app_data(Data::new(server.clone()))
29        .service(handle_post)
30        .service(handle_get);
31}
32
33/// POST /mcp - Handle JSON-RPC requests
34#[post("/mcp")]
35async fn handle_post(
36    req: HttpRequest,
37    body: web::Json<JsonRpcRequest>,
38    server: Data<Arc<HttpMcpServer>>,
39) -> Result<impl Responder> {
40    let ctx = create_request_context(&req);
41
42    // Validate OAuth if configured
43    if let Some(oauth) = &server.oauth_config {
44        oauth.validate_token(&ctx).await?;
45    }
46
47    // Validate JSON-RPC request
48    body.validate()?;
49
50    // Check if this is a notification (no id field)
51    let is_notification = body.id.is_none();
52
53    // Check if client accepts SSE (streaming mode)
54    let accept_sse = req
55        .headers()
56        .get("accept")
57        .and_then(|v| v.to_str().ok())
58        .map(|s| s.contains("text/event-stream"))
59        .unwrap_or(false);
60
61    // Route and execute the request
62    let response = route_request(&body, &ctx, &server).await?;
63
64    // Notifications MUST NOT receive a response per JSON-RPC 2.0 spec
65    if is_notification {
66        tracing::debug!("Notification received ({}), returning 204 No Content", body.method);
67        let mut resp = HttpResponse::NoContent();
68        if server.enable_cors {
69            resp.insert_header(("Access-Control-Allow-Origin", "*"));
70        }
71        return Ok(resp.finish());
72    }
73
74    // For SSE mode, broadcast response and return 202 Accepted
75    if accept_sse {
76        let subscriber_count = server.response_tx.receiver_count();
77        tracing::debug!("Broadcasting response to {} subscribers", subscriber_count);
78
79        // If there are active SSE subscribers, send via broadcast
80        if subscriber_count > 0 {
81            let _ = server.response_tx.send(response);
82            let mut resp = HttpResponse::Accepted();
83            if server.enable_cors {
84                resp.insert_header(("Access-Control-Allow-Origin", "*"));
85            }
86            return Ok(resp.finish());
87        }
88
89        // If no subscribers, fallback to direct response
90        tracing::warn!("No SSE subscribers, falling back to direct HTTP response");
91    }
92
93    // For non-SSE mode or fallback, return JSON response directly
94    let mut resp = HttpResponse::Ok();
95    if server.enable_cors {
96        resp.insert_header(("Access-Control-Allow-Origin", "*"));
97    }
98    Ok(resp.json(response))
99}
100
101/// GET /mcp - SSE stream for server-to-client messages
102#[get("/mcp")]
103async fn handle_get(req: HttpRequest, server: Data<Arc<HttpMcpServer>>) -> Result<impl Responder> {
104    let ctx = create_request_context(&req);
105
106    // Validate OAuth if configured
107    if let Some(oauth) = &server.oauth_config {
108        oauth.validate_token(&ctx).await?;
109    }
110
111    // Check for Last-Event-ID header for resumption
112    let _last_event_id = req
113        .headers()
114        .get("Last-Event-ID")
115        .and_then(|v| v.to_str().ok())
116        .map(|s| s.to_string());
117
118    // Subscribe to response broadcast channel
119    let mut rx = server.response_tx.subscribe();
120
121    tracing::debug!("SSE stream connected");
122
123    // Create SSE stream from broadcast channel
124    let event_stream = async_stream::stream! {
125        loop {
126            match rx.recv().await {
127                Ok(response) => {
128                    if let Ok(json) = serde_json::to_string(&response) {
129                        tracing::debug!("Sending response via SSE: {}", json);
130                        // Send as "message" event with the JSON-RPC response
131                        yield Ok::<_, actix_web::Error>(sse::Event::Data(
132                            sse::Data::new(json)
133                        ));
134                    }
135                }
136                Err(_) => break,
137            }
138        }
139    };
140
141    Ok(sse::Sse::from_stream(event_stream))
142}
143
144/// Route JSON-RPC request to appropriate handler
145async fn route_request(
146    req: &JsonRpcRequest,
147    ctx: &RequestContext,
148    server: &HttpMcpServer,
149) -> Result<JsonRpcResponse> {
150    tracing::debug!("Routing request: method={}", req.method);
151
152    match req.method.as_str() {
153        // Lifecycle
154        "initialize" => {
155            handle_initialize(req, server.server_info.clone(), server.capabilities.clone())
156        }
157        "ping" => handle_ping(req),
158
159        // Notifications
160        "notifications/initialized" => handle_notifications_initialized(req),
161
162        // Resources
163        "resources/list" => handle_resources_list(req, ctx, server).await,
164        "resources/read" => handle_resources_read(req, ctx, server).await,
165        "resources/templates/list" => handle_resources_templates(req, ctx, server).await,
166        "resources/subscribe" => handle_resources_subscribe(req, ctx, server).await,
167
168        // Tools
169        "tools/list" => handle_tools_list(req, ctx, server).await,
170        "tools/call" => handle_tools_call(req, ctx, server).await,
171
172        // Prompts
173        "prompts/list" => handle_prompts_list(req, ctx, server).await,
174        "prompts/get" => handle_prompts_get(req, ctx, server).await,
175
176        // Logging
177        "logging/setLevel" => handle_logging_set_level(req),
178
179        _ => Err(McpError::MethodNotFound(req.method.clone())),
180    }
181}
182
183// ============================================================================
184// Resource Handlers
185// ============================================================================
186
187async fn handle_resources_list(
188    req: &JsonRpcRequest,
189    ctx: &RequestContext,
190    server: &HttpMcpServer,
191) -> Result<JsonRpcResponse> {
192    let params: ResourcesListParams =
193        serde_json::from_value(req.params.clone().unwrap_or(Value::Null))
194            .unwrap_or(ResourcesListParams { cursor: None });
195
196    // Collect all resources from registered handlers
197    let mut all_resources = Vec::new();
198    for registered in server.resources.values() {
199        let (resources, _) = (registered.list_handler)(params.cursor.clone(), ctx.clone()).await?;
200        all_resources.extend(resources);
201    }
202
203    let result = ResourcesListResult {
204        resources: all_resources,
205        next_cursor: None,
206    };
207
208    Ok(JsonRpcResponse::success(
209        serde_json::to_value(result)?,
210        req.id.clone(),
211    ))
212}
213
214async fn handle_resources_read(
215    req: &JsonRpcRequest,
216    ctx: &RequestContext,
217    server: &HttpMcpServer,
218) -> Result<JsonRpcResponse> {
219    let params: ResourcesReadParams =
220        serde_json::from_value(req.params.clone().unwrap_or(Value::Null))
221            .map_err(|e| McpError::InvalidParams(format!("Invalid params: {}", e)))?;
222
223    // Try to find matching resource handler
224    let mut contents = Vec::new();
225    for registered in server.resources.values() {
226        let result = (registered.read_handler)(params.uri.clone(), ctx.clone()).await?;
227        contents.extend(result);
228    }
229
230    if contents.is_empty() {
231        return Err(McpError::ResourceNotFound(params.uri));
232    }
233
234    let result = ResourcesReadResult { contents };
235
236    Ok(JsonRpcResponse::success(
237        serde_json::to_value(result)?,
238        req.id.clone(),
239    ))
240}
241
242async fn handle_resources_templates(
243    req: &JsonRpcRequest,
244    _ctx: &RequestContext,
245    _server: &HttpMcpServer,
246) -> Result<JsonRpcResponse> {
247    // Resource templates are not supported in the new function-based API
248    Ok(JsonRpcResponse::success(
249        serde_json::json!({ "resourceTemplates": [] }),
250        req.id.clone(),
251    ))
252}
253
254async fn handle_resources_subscribe(
255    req: &JsonRpcRequest,
256    _ctx: &RequestContext,
257    _server: &HttpMcpServer,
258) -> Result<JsonRpcResponse> {
259    // Resource subscription is not supported in the new function-based API
260    Ok(JsonRpcResponse::success(Value::Null, req.id.clone()))
261}
262
263// ============================================================================
264// Tool Handlers
265// ============================================================================
266
267async fn handle_tools_list(
268    req: &JsonRpcRequest,
269    _ctx: &RequestContext,
270    server: &HttpMcpServer,
271) -> Result<JsonRpcResponse> {
272    // Collect all registered tools
273    let tools: Vec<Tool> = server
274        .tools
275        .values()
276        .map(|registered| registered.meta.clone())
277        .collect();
278
279    let result = ToolsListResult {
280        tools,
281        next_cursor: None,
282    };
283
284    Ok(JsonRpcResponse::success(
285        serde_json::to_value(result)?,
286        req.id.clone(),
287    ))
288}
289
290async fn handle_tools_call(
291    req: &JsonRpcRequest,
292    ctx: &RequestContext,
293    server: &HttpMcpServer,
294) -> Result<JsonRpcResponse> {
295    let params: ToolsCallParams = serde_json::from_value(req.params.clone().unwrap_or(Value::Null))
296        .map_err(|e| McpError::InvalidParams(format!("Invalid params: {}", e)))?;
297
298    // Find the registered tool
299    let registered = server
300        .tools
301        .get(&params.name)
302        .ok_or_else(|| McpError::ToolNotFound(params.name.clone()))?;
303
304    // Call the tool handler
305    let result_value =
306        (registered.handler)(params.arguments.unwrap_or_default(), ctx.clone()).await?;
307
308    // Convert result to ToolContent
309    let content = vec![ToolContent::Text {
310        text: result_value.to_string(),
311    }];
312
313    let result = ToolsCallResult {
314        content,
315        is_error: None,
316    };
317
318    Ok(JsonRpcResponse::success(
319        serde_json::to_value(result)?,
320        req.id.clone(),
321    ))
322}
323
324// ============================================================================
325// Prompt Handlers
326// ============================================================================
327
328async fn handle_prompts_list(
329    req: &JsonRpcRequest,
330    _ctx: &RequestContext,
331    server: &HttpMcpServer,
332) -> Result<JsonRpcResponse> {
333    // Collect all registered prompts
334    let prompts: Vec<Prompt> = server
335        .prompts
336        .values()
337        .map(|registered| registered.meta.clone())
338        .collect();
339
340    let result = PromptsListResult {
341        prompts,
342        next_cursor: None,
343    };
344
345    Ok(JsonRpcResponse::success(
346        serde_json::to_value(result)?,
347        req.id.clone(),
348    ))
349}
350
351async fn handle_prompts_get(
352    req: &JsonRpcRequest,
353    ctx: &RequestContext,
354    server: &HttpMcpServer,
355) -> Result<JsonRpcResponse> {
356    let params: PromptsGetParams =
357        serde_json::from_value(req.params.clone().unwrap_or(Value::Null))
358            .map_err(|e| McpError::InvalidParams(format!("Invalid params: {}", e)))?;
359
360    // Find the registered prompt
361    let registered = server
362        .prompts
363        .get(&params.name)
364        .ok_or_else(|| McpError::PromptNotFound(params.name.clone()))?;
365
366    // Call the prompt handler
367    let (description, messages) =
368        (registered.handler)(params.name.clone(), params.arguments, ctx.clone()).await?;
369
370    let result = PromptsGetResult {
371        description,
372        messages,
373    };
374
375    Ok(JsonRpcResponse::success(
376        serde_json::to_value(result)?,
377        req.id.clone(),
378    ))
379}
380
381// ============================================================================
382// Notification Handlers
383// ============================================================================
384
385fn handle_notifications_initialized(req: &JsonRpcRequest) -> Result<JsonRpcResponse> {
386    tracing::debug!("Client initialized notification received");
387    // Return empty object instead of null
388    Ok(JsonRpcResponse::success(serde_json::json!({}), req.id.clone()))
389}
390
391// ============================================================================
392// Logging Handlers
393// ============================================================================
394
395fn handle_logging_set_level(req: &JsonRpcRequest) -> Result<JsonRpcResponse> {
396    let _params: LoggingSetLevelParams =
397        serde_json::from_value(req.params.clone().unwrap_or(Value::Null))
398            .map_err(|e| McpError::InvalidParams(format!("Invalid params: {}", e)))?;
399
400    // TODO: Implement actual log level setting
401    Ok(JsonRpcResponse::success(serde_json::json!({}), req.id.clone()))
402}
403
404// ============================================================================
405// Utilities
406// ============================================================================
407
408fn create_request_context(req: &HttpRequest) -> RequestContext {
409    RequestContext::new(
410        req.headers().clone(),
411        req.method().to_string(),
412        req.path().to_string(),
413        req.peer_addr(),
414    )
415}