mcp_protocol_sdk/transport/
http.rs

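//! HTTP transport implementation for MCP
//!
//! This module provides HTTP-based transports for MCP communication: a client that sends
//! JSON-RPC messages over HTTP POST and can receive server notifications via Server-Sent
//! Events (SSE), and an Axum-based server that exposes the matching endpoints.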
5
6use async_trait::async_trait;
7use axum::{
8    extract::State,
9    http::{HeaderMap, StatusCode},
10    response::{sse::Event, Sse},
11    routing::{get, post},
12    Json, Router,
13};
14use reqwest::Client;
15use serde_json::Value;
16use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
17use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
18
19#[cfg(all(feature = "futures", feature = "tokio-stream"))]
20use futures::stream::Stream;
21
22#[cfg(feature = "tokio-stream")]
23use tokio_stream::{wrappers::BroadcastStream, StreamExt};
24
25use tower::ServiceBuilder;
26use tower_http::cors::{Any, CorsLayer};
27
28use crate::core::error::{McpError, McpResult};
29use crate::protocol::types::{
30    error_codes, JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
31};
32use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
33
34// ============================================================================
35// HTTP Client Transport
36// ============================================================================
37
38/// HTTP transport for MCP clients
39///
40/// This transport communicates with an MCP server via HTTP requests and
41/// optionally uses Server-Sent Events for real-time notifications.
42pub struct HttpClientTransport {
43    client: Client,
44    base_url: String,
45    sse_url: Option<String>,
46    headers: HeaderMap,
47    /// For tracking active requests (currently used for metrics/debugging)
48    pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
49    notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
50    config: TransportConfig,
51    state: ConnectionState,
52    request_id_counter: Arc<Mutex<u64>>,
53}
54
55impl HttpClientTransport {
56    /// Create a new HTTP client transport
57    ///
58    /// # Arguments
59    /// * `base_url` - Base URL for the MCP server
60    /// * `sse_url` - Optional URL for Server-Sent Events (for notifications)
61    ///
62    /// # Returns
63    /// Result containing the transport or an error
64    pub async fn new<S: AsRef<str>>(base_url: S, sse_url: Option<S>) -> McpResult<Self> {
65        Self::with_config(base_url, sse_url, TransportConfig::default()).await
66    }
67
68    /// Create a new HTTP client transport with custom configuration
69    ///
70    /// # Arguments
71    /// * `base_url` - Base URL for the MCP server
72    /// * `sse_url` - Optional URL for Server-Sent Events
73    /// * `config` - Transport configuration
74    ///
75    /// # Returns
76    /// Result containing the transport or an error
77    pub async fn with_config<S: AsRef<str>>(
78        base_url: S,
79        sse_url: Option<S>,
80        config: TransportConfig,
81    ) -> McpResult<Self> {
82        let client_builder = Client::builder()
83            .timeout(Duration::from_millis(
84                config.read_timeout_ms.unwrap_or(60_000),
85            ))
86            .connect_timeout(Duration::from_millis(
87                config.connect_timeout_ms.unwrap_or(30_000),
88            ));
89
90        // Note: reqwest doesn't have a gzip() method, it's enabled by default with features
91
92        let client = client_builder
93            .build()
94            .map_err(|e| McpError::Http(format!("Failed to create HTTP client: {}", e)))?;
95
96        let mut headers = HeaderMap::new();
97        headers.insert("Content-Type", "application/json".parse().unwrap());
98        headers.insert("Accept", "application/json".parse().unwrap());
99
100        // Add custom headers from config
101        for (key, value) in &config.headers {
102            if let (Ok(header_name), Ok(header_value)) = (
103                key.parse::<axum::http::HeaderName>(),
104                value.parse::<axum::http::HeaderValue>(),
105            ) {
106                headers.insert(header_name, header_value);
107            }
108        }
109
110        let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
111
112        // Set up SSE connection for notifications if URL provided
113        if let Some(sse_url) = &sse_url {
114            let sse_url = sse_url.as_ref().to_string();
115            let client_clone = client.clone();
116            let headers_clone = headers.clone();
117
118            tokio::spawn(async move {
119                if let Err(e) = Self::handle_sse_stream(
120                    client_clone,
121                    sse_url,
122                    headers_clone,
123                    notification_sender,
124                )
125                .await
126                {
127                    tracing::error!("SSE stream error: {}", e);
128                }
129            });
130        }
131
132        Ok(Self {
133            client,
134            base_url: base_url.as_ref().to_string(),
135            sse_url: sse_url.map(|s| s.as_ref().to_string()),
136            headers,
137            pending_requests: Arc::new(Mutex::new(HashMap::new())),
138            notification_receiver: Some(notification_receiver),
139            config,
140            state: ConnectionState::Connected,
141            request_id_counter: Arc::new(Mutex::new(0)),
142        })
143    }
144
145    async fn handle_sse_stream(
146        client: Client,
147        sse_url: String,
148        headers: HeaderMap,
149        notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
150    ) -> McpResult<()> {
151        let mut request = client.get(&sse_url);
152        for (name, value) in headers.iter() {
153            // Convert axum headers to reqwest headers
154            let name_str = name.as_str();
155            let value_bytes = value.as_bytes();
156            request = request.header(name_str, value_bytes);
157        }
158
159        let response = request
160            .send()
161            .await
162            .map_err(|e| McpError::Http(format!("SSE connection failed: {}", e)))?;
163
164        let mut stream = response.bytes_stream();
165
166        #[cfg(feature = "tokio-stream")]
167        {
168            while let Some(chunk) = stream.next().await {
169                match chunk {
170                    Ok(bytes) => {
171                        let text = String::from_utf8_lossy(&bytes);
172                        for line in text.lines() {
173                            if line.starts_with("data: ") {
174                                let data = &line[6..]; // Remove "data: " prefix
175                                if let Ok(notification) =
176                                    serde_json::from_str::<JsonRpcNotification>(data)
177                                {
178                                    if notification_sender.send(notification).is_err() {
179                                        tracing::debug!("Notification receiver dropped");
180                                        return Ok(());
181                                    }
182                                }
183                            }
184                        }
185                    }
186                    Err(e) => {
187                        tracing::error!("SSE stream error: {}", e);
188                        break;
189                    }
190                }
191            }
192        }
193
194        #[cfg(not(feature = "tokio-stream"))]
195        {
196            tracing::warn!("SSE streaming requires tokio-stream feature");
197        }
198
199        Ok(())
200    }
201
202    async fn next_request_id(&self) -> u64 {
203        let mut counter = self.request_id_counter.lock().await;
204        *counter += 1;
205        *counter
206    }
207
208    /// Track request for metrics/debugging purposes
209    async fn track_request(&self, request_id: &Value) {
210        // For HTTP transport, we mainly use this for debugging and metrics
211        // Since HTTP is synchronous request/response, we don't need the async
212        // tracking that WebSocket uses, but we keep the interface for consistency
213        let mut pending = self.pending_requests.lock().await;
214        let (sender, _receiver) = tokio::sync::oneshot::channel();
215        pending.insert(request_id.clone(), sender);
216    }
217
218    /// Remove tracked request
219    async fn untrack_request(&self, request_id: &Value) {
220        let mut pending = self.pending_requests.lock().await;
221        pending.remove(request_id);
222    }
223
224    /// Get count of active requests (for debugging/metrics)
225    pub async fn active_request_count(&self) -> usize {
226        let pending = self.pending_requests.lock().await;
227        pending.len()
228    }
229}
230
231#[async_trait]
232impl Transport for HttpClientTransport {
233    async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
234        // Generate request ID if not present or ensure we have a valid ID
235        let request_with_id = if request.id == Value::Null {
236            let request_id = self.next_request_id().await;
237            JsonRpcRequest {
238                id: Value::from(request_id),
239                ..request
240            }
241        } else {
242            request
243        };
244
245        // Track the request for debugging/metrics
246        self.track_request(&request_with_id.id).await;
247
248        let url = format!("{}/mcp", self.base_url);
249
250        let mut http_request = self.client.post(&url);
251
252        // Apply headers from config and defaults
253        for (name, value) in self.headers.iter() {
254            let name_str = name.as_str();
255            let value_bytes = value.as_bytes();
256            http_request = http_request.header(name_str, value_bytes);
257        }
258
259        // Apply timeout from config if specified
260        if let Some(timeout_ms) = self.config.read_timeout_ms {
261            http_request = http_request.timeout(Duration::from_millis(timeout_ms));
262        }
263
264        let response = http_request
265            .json(&request_with_id)
266            .send()
267            .await
268            .map_err(|e| {
269                // Untrack request on error
270                let request_id = request_with_id.id.clone();
271                let pending_requests = self.pending_requests.clone();
272                tokio::spawn(async move {
273                    let mut pending = pending_requests.lock().await;
274                    pending.remove(&request_id);
275                });
276                McpError::Http(format!("HTTP request failed: {}", e))
277            })?;
278
279        if !response.status().is_success() {
280            // Untrack request on HTTP error
281            self.untrack_request(&request_with_id.id).await;
282            return Err(McpError::Http(format!(
283                "HTTP error: {} {}",
284                response.status().as_u16(),
285                response.status().canonical_reason().unwrap_or("Unknown")
286            )));
287        }
288
289        let json_response: JsonRpcResponse = response.json().await.map_err(|e| {
290            // Untrack request on parse error
291            let request_id = request_with_id.id.clone();
292            let pending_requests = self.pending_requests.clone();
293            tokio::spawn(async move {
294                let mut pending = pending_requests.lock().await;
295                pending.remove(&request_id);
296            });
297            McpError::Http(format!("Failed to parse response: {}", e))
298        })?;
299
300        // Validate response ID matches request ID
301        if json_response.id != request_with_id.id {
302            self.untrack_request(&request_with_id.id).await;
303            return Err(McpError::Http(format!(
304                "Response ID {:?} does not match request ID {:?}",
305                json_response.id, request_with_id.id
306            )));
307        }
308
309        // Untrack successful request
310        self.untrack_request(&request_with_id.id).await;
311
312        Ok(json_response)
313    }
314
315    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
316        let url = format!("{}/mcp/notify", self.base_url);
317
318        let mut http_request = self.client.post(&url);
319
320        // Apply headers from config and defaults
321        for (name, value) in self.headers.iter() {
322            let name_str = name.as_str();
323            let value_bytes = value.as_bytes();
324            http_request = http_request.header(name_str, value_bytes);
325        }
326
327        // Apply write timeout from config if specified
328        if let Some(timeout_ms) = self.config.write_timeout_ms {
329            http_request = http_request.timeout(Duration::from_millis(timeout_ms));
330        }
331
332        let response = http_request
333            .json(&notification)
334            .send()
335            .await
336            .map_err(|e| McpError::Http(format!("HTTP notification failed: {}", e)))?;
337
338        if !response.status().is_success() {
339            return Err(McpError::Http(format!(
340                "HTTP notification error: {} {}",
341                response.status().as_u16(),
342                response.status().canonical_reason().unwrap_or("Unknown")
343            )));
344        }
345
346        Ok(())
347    }
348
349    async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
350        if let Some(ref mut receiver) = self.notification_receiver {
351            match receiver.try_recv() {
352                Ok(notification) => Ok(Some(notification)),
353                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
354                Err(mpsc::error::TryRecvError::Disconnected) => Err(McpError::Http(
355                    "Notification channel disconnected".to_string(),
356                )),
357            }
358        } else {
359            Ok(None)
360        }
361    }
362
363    async fn close(&mut self) -> McpResult<()> {
364        self.state = ConnectionState::Disconnected;
365        self.notification_receiver = None;
366        Ok(())
367    }
368
369    fn is_connected(&self) -> bool {
370        matches!(self.state, ConnectionState::Connected)
371    }
372
373    fn connection_info(&self) -> String {
374        format!(
375            "HTTP transport (base: {}, sse: {:?}, state: {:?})",
376            self.base_url, self.sse_url, self.state
377        )
378    }
379}
380
381// ============================================================================
382// HTTP Server Transport
383// ============================================================================
384
385/// Shared state for HTTP server transport
386#[derive(Clone)]
387struct HttpServerState {
388    notification_sender: broadcast::Sender<JsonRpcNotification>,
389    request_handler: Option<
390        Arc<
391            dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse> + Send + Sync,
392        >,
393    >,
394}
395
396/// HTTP transport for MCP servers
397///
398/// This transport serves MCP requests over HTTP and provides Server-Sent Events
399/// for real-time notifications to clients.
400pub struct HttpServerTransport {
401    bind_addr: String,
402    config: TransportConfig,
403    state: Arc<RwLock<HttpServerState>>,
404    server_handle: Option<tokio::task::JoinHandle<()>>,
405    running: Arc<RwLock<bool>>,
406}
407
408impl HttpServerTransport {
409    /// Create a new HTTP server transport
410    ///
411    /// # Arguments
412    /// * `bind_addr` - Address to bind the HTTP server to (e.g., "0.0.0.0:3000")
413    ///
414    /// # Returns
415    /// New HTTP server transport instance
416    pub fn new<S: Into<String>>(bind_addr: S) -> Self {
417        Self::with_config(bind_addr, TransportConfig::default())
418    }
419
420    /// Create a new HTTP server transport with custom configuration
421    ///
422    /// # Arguments
423    /// * `bind_addr` - Address to bind the HTTP server to
424    /// * `config` - Transport configuration
425    ///
426    /// # Returns
427    /// New HTTP server transport instance
428    pub fn with_config<S: Into<String>>(bind_addr: S, config: TransportConfig) -> Self {
429        let (notification_sender, _) = broadcast::channel(1000);
430
431        Self {
432            bind_addr: bind_addr.into(),
433            config,
434            state: Arc::new(RwLock::new(HttpServerState {
435                notification_sender,
436                request_handler: None,
437            })),
438            server_handle: None,
439            running: Arc::new(RwLock::new(false)),
440        }
441    }
442
443    /// Set the request handler function
444    ///
445    /// # Arguments
446    /// * `handler` - Function that processes incoming requests
447    pub async fn set_request_handler<F>(&mut self, handler: F)
448    where
449        F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
450            + Send
451            + Sync
452            + 'static,
453    {
454        let mut state = self.state.write().await;
455        state.request_handler = Some(Arc::new(handler));
456    }
457}
458
459#[async_trait]
460impl ServerTransport for HttpServerTransport {
461    async fn start(&mut self) -> McpResult<()> {
462        tracing::info!("Starting HTTP server on {}", self.bind_addr);
463
464        let state = self.state.clone();
465        let bind_addr = self.bind_addr.clone();
466        let running = self.running.clone();
467        let _config = self.config.clone(); // TODO: Use config for timeouts/limits
468
469        // Create the Axum app with configuration-based settings
470        let mut app = Router::new()
471            .route("/mcp", post(handle_mcp_request))
472            .route("/mcp/notify", post(handle_mcp_notification))
473            .route("/mcp/events", get(handle_sse_events))
474            .route("/health", get(handle_health_check))
475            .with_state(state);
476
477        // Apply CORS configuration
478        let cors_layer = CorsLayer::new()
479            .allow_origin(Any)
480            .allow_methods(Any)
481            .allow_headers(Any);
482
483        app = app.layer(ServiceBuilder::new().layer(cors_layer).into_inner());
484
485        // Start the server
486        let listener = tokio::net::TcpListener::bind(&bind_addr)
487            .await
488            .map_err(|e| McpError::Http(format!("Failed to bind to {}: {}", bind_addr, e)))?;
489
490        *running.write().await = true;
491
492        let server_handle = tokio::spawn(async move {
493            if let Err(e) = axum::serve(listener, app).await {
494                tracing::error!("HTTP server error: {}", e);
495            }
496        });
497
498        self.server_handle = Some(server_handle);
499
500        tracing::info!("HTTP server started successfully on {}", self.bind_addr);
501        Ok(())
502    }
503
504    async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
505        // This is now handled by the HTTP server itself and should not be called directly
506        // The HTTP transport handles requests through the HTTP server routes
507        tracing::warn!("handle_request called directly on HTTP transport - this may indicate a configuration issue");
508
509        let state = self.state.read().await;
510
511        if let Some(ref handler) = state.request_handler {
512            let response_rx = handler(request);
513            drop(state); // Release the lock
514
515            match response_rx.await {
516                Ok(response) => Ok(response),
517                Err(_) => Err(McpError::Http("Request handler channel closed".to_string())),
518            }
519        } else {
520            // Return an error indicating no handler is configured
521            Err(McpError::Http(
522                "No request handler configured for HTTP transport".to_string(),
523            ))
524        }
525    }
526
527    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
528        let state = self.state.read().await;
529
530        if let Err(_) = state.notification_sender.send(notification) {
531            tracing::warn!("No SSE clients connected to receive notification");
532        }
533
534        Ok(())
535    }
536
537    async fn stop(&mut self) -> McpResult<()> {
538        tracing::info!("Stopping HTTP server");
539
540        *self.running.write().await = false;
541
542        if let Some(handle) = self.server_handle.take() {
543            handle.abort();
544        }
545
546        Ok(())
547    }
548
549    fn is_running(&self) -> bool {
550        // Check if we have an active server handle
551        self.server_handle.is_some()
552    }
553
554    fn server_info(&self) -> String {
555        format!("HTTP server transport (bind: {})", self.bind_addr)
556    }
557}
558
559// ============================================================================
560// HTTP Route Handlers
561// ============================================================================
562
563/// Handle MCP JSON-RPC requests
564async fn handle_mcp_request(
565    State(state): State<Arc<RwLock<HttpServerState>>>,
566    Json(request): Json<JsonRpcRequest>,
567) -> Result<Json<JsonRpcMessage>, StatusCode> {
568    let state_guard = state.read().await;
569
570    if let Some(ref handler) = state_guard.request_handler {
571        let response_rx = handler(request);
572        drop(state_guard); // Release the lock
573
574        match response_rx.await {
575            Ok(response) => Ok(Json(JsonRpcMessage::Response(response))),
576            Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
577        }
578    } else {
579        let error_response = JsonRpcError::error(
580            request.id,
581            error_codes::METHOD_NOT_FOUND,
582            "No request handler configured".to_string(),
583            None,
584        );
585        Ok(Json(JsonRpcMessage::Error(error_response)))
586    }
587}
588
589/// Handle MCP notification requests
590async fn handle_mcp_notification(Json(_notification): Json<JsonRpcNotification>) -> StatusCode {
591    // Notifications don't require a response
592    StatusCode::OK
593}
594
/// `GET /mcp/events`: stream broadcast notifications to the client as Server-Sent Events.
///
/// Each notification is sent as one event whose `data` is its JSON serialization. A
/// keep-alive comment is sent every 30 seconds.
///
/// If the client falls behind the broadcast buffer (lag) or a notification fails to
/// serialize, the error is logged and the item is skipped. No placeholder event is sent,
/// because a placeholder would not be a valid JSON-RPC message.
#[cfg(all(feature = "tokio-stream", feature = "futures"))]
async fn handle_sse_events(
    State(state): State<Arc<RwLock<HttpServerState>>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    // The read guard is a temporary and is released at the end of this statement.
    let receiver = state.read().await.notification_sender.subscribe();

    let stream = BroadcastStream::new(receiver).filter_map(|result| match result {
        Ok(notification) => match serde_json::to_string(&notification) {
            Ok(json) => Some(Ok(Event::default().data(json))),
            Err(e) => {
                tracing::error!("Failed to serialize notification: {}", e);
                None
            }
        },
        Err(e) => {
            tracing::warn!("SSE client lagged; notifications were dropped: {}", e);
            None
        }
    });

    Sse::new(stream).keep_alive(
        axum::response::sse::KeepAlive::new()
            .interval(Duration::from_secs(30))
            .text("keep-alive"),
    )
}
623
624/// Handle Server-Sent Events (fallback when features not available)
625#[cfg(not(all(feature = "tokio-stream", feature = "futures")))]
626async fn handle_sse_events(_state: State<Arc<RwLock<HttpServerState>>>) -> StatusCode {
627    StatusCode::NOT_IMPLEMENTED
628}
629
630/// Handle health check requests
631async fn handle_health_check() -> Json<Value> {
632    #[cfg(feature = "chrono")]
633    let timestamp = chrono::Utc::now().to_rfc3339();
634    #[cfg(not(feature = "chrono"))]
635    let timestamp = "unavailable";
636
637    Json(serde_json::json!({
638        "status": "healthy",
639        "transport": "http",
640        "timestamp": timestamp
641    }))
642}
643
#[cfg(test)]
mod tests {
    use super::*;

    const BASE_URL: &str = "http://localhost:3000";

    #[tokio::test]
    async fn test_http_client_creation() {
        let created = HttpClientTransport::new(BASE_URL, None).await;
        assert!(created.is_ok());

        let client = created.unwrap();
        assert!(client.is_connected());
        assert_eq!(client.base_url, BASE_URL);
    }

    #[tokio::test]
    async fn test_http_server_creation() {
        let server = HttpServerTransport::new("127.0.0.1:0");
        assert_eq!(server.bind_addr, "127.0.0.1:0");
        assert!(!server.is_running());
    }

    #[test]
    fn test_http_server_with_config() {
        let mut config = TransportConfig::default();
        config.compression = true;

        let server = HttpServerTransport::with_config("0.0.0.0:8080", config);
        assert_eq!(server.bind_addr, "0.0.0.0:8080");
        assert!(server.config.compression);
    }

    #[tokio::test]
    async fn test_http_client_with_sse() {
        let created =
            HttpClientTransport::new(BASE_URL, Some("http://localhost:3000/events")).await;
        assert!(created.is_ok());

        let client = created.unwrap();
        assert_eq!(
            client.sse_url.as_deref(),
            Some("http://localhost:3000/events")
        );
    }
}