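// mcp_protocol_sdk/transport/http.rs

//! HTTP transport implementation for MCP
//!
//! This module provides HTTP-based transport for MCP communication,
//! including Server-Sent Events (SSE) for real-time server-to-client notifications.
//!
//! The server side exposes `POST /mcp` (JSON-RPC requests), `POST /mcp/notify`
//! (client notifications), `GET /mcp/events` (SSE notification stream) and
//! `GET /health`.
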
6use async_trait::async_trait;
7use axum::{
8    extract::State,
9    http::{HeaderMap, StatusCode},
10    response::{sse::Event, Sse},
11    routing::{get, post},
12    Json, Router,
13};
14use reqwest::Client;
15use serde_json::Value;
16use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
17use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
18
19#[cfg(all(feature = "futures", feature = "tokio-stream"))]
20use futures::stream::Stream;
21
22#[cfg(feature = "tokio-stream")]
23use tokio_stream::{wrappers::BroadcastStream, StreamExt};
24
25use tower::ServiceBuilder;
26use tower_http::cors::{Any, CorsLayer};
27
28use crate::core::error::{McpError, McpResult};
29use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
30use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
31
32// ============================================================================
33// HTTP Client Transport
34// ============================================================================
35
36/// HTTP transport for MCP clients
37///
38/// This transport communicates with an MCP server via HTTP requests and
39/// optionally uses Server-Sent Events for real-time notifications.
40pub struct HttpClientTransport {
41    client: Client,
42    base_url: String,
43    sse_url: Option<String>,
44    headers: HeaderMap,
45    pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
46    notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
47    config: TransportConfig,
48    state: ConnectionState,
49    request_id_counter: Arc<Mutex<u64>>,
50}
51
52impl HttpClientTransport {
53    /// Create a new HTTP client transport
54    ///
55    /// # Arguments
56    /// * `base_url` - Base URL for the MCP server
57    /// * `sse_url` - Optional URL for Server-Sent Events (for notifications)
58    ///
59    /// # Returns
60    /// Result containing the transport or an error
61    pub async fn new<S: AsRef<str>>(base_url: S, sse_url: Option<S>) -> McpResult<Self> {
62        Self::with_config(base_url, sse_url, TransportConfig::default()).await
63    }
64
65    /// Create a new HTTP client transport with custom configuration
66    ///
67    /// # Arguments
68    /// * `base_url` - Base URL for the MCP server
69    /// * `sse_url` - Optional URL for Server-Sent Events
70    /// * `config` - Transport configuration
71    ///
72    /// # Returns
73    /// Result containing the transport or an error
74    pub async fn with_config<S: AsRef<str>>(
75        base_url: S,
76        sse_url: Option<S>,
77        config: TransportConfig,
78    ) -> McpResult<Self> {
79        let client_builder = Client::builder()
80            .timeout(Duration::from_millis(
81                config.read_timeout_ms.unwrap_or(60_000),
82            ))
83            .connect_timeout(Duration::from_millis(
84                config.connect_timeout_ms.unwrap_or(30_000),
85            ));
86
87        // Note: reqwest doesn't have a gzip() method, it's enabled by default with features
88
89        let client = client_builder
90            .build()
91            .map_err(|e| McpError::Http(format!("Failed to create HTTP client: {}", e)))?;
92
93        let mut headers = HeaderMap::new();
94        headers.insert("Content-Type", "application/json".parse().unwrap());
95        headers.insert("Accept", "application/json".parse().unwrap());
96
97        // Add custom headers from config
98        for (key, value) in &config.headers {
99            if let (Ok(header_name), Ok(header_value)) = (
100                key.parse::<axum::http::HeaderName>(),
101                value.parse::<axum::http::HeaderValue>(),
102            ) {
103                headers.insert(header_name, header_value);
104            }
105        }
106
107        let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
108
109        // Set up SSE connection for notifications if URL provided
110        if let Some(sse_url) = &sse_url {
111            let sse_url = sse_url.as_ref().to_string();
112            let client_clone = client.clone();
113            let headers_clone = headers.clone();
114
115            tokio::spawn(async move {
116                if let Err(e) = Self::handle_sse_stream(
117                    client_clone,
118                    sse_url,
119                    headers_clone,
120                    notification_sender,
121                )
122                .await
123                {
124                    tracing::error!("SSE stream error: {}", e);
125                }
126            });
127        }
128
129        Ok(Self {
130            client,
131            base_url: base_url.as_ref().to_string(),
132            sse_url: sse_url.map(|s| s.as_ref().to_string()),
133            headers,
134            pending_requests: Arc::new(Mutex::new(HashMap::new())),
135            notification_receiver: Some(notification_receiver),
136            config,
137            state: ConnectionState::Connected,
138            request_id_counter: Arc::new(Mutex::new(0)),
139        })
140    }
141
142    async fn handle_sse_stream(
143        client: Client,
144        sse_url: String,
145        headers: HeaderMap,
146        notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
147    ) -> McpResult<()> {
148        let mut request = client.get(&sse_url);
149        for (name, value) in headers.iter() {
150            // Convert axum headers to reqwest headers
151            let name_str = name.as_str();
152            let value_bytes = value.as_bytes();
153            request = request.header(name_str, value_bytes);
154        }
155
156        let response = request
157            .send()
158            .await
159            .map_err(|e| McpError::Http(format!("SSE connection failed: {}", e)))?;
160
161        let mut stream = response.bytes_stream();
162
163        #[cfg(feature = "tokio-stream")]
164        {
165            while let Some(chunk) = stream.next().await {
166                match chunk {
167                    Ok(bytes) => {
168                        let text = String::from_utf8_lossy(&bytes);
169                        for line in text.lines() {
170                            if line.starts_with("data: ") {
171                                let data = &line[6..]; // Remove "data: " prefix
172                                if let Ok(notification) =
173                                    serde_json::from_str::<JsonRpcNotification>(data)
174                                {
175                                    if notification_sender.send(notification).is_err() {
176                                        tracing::debug!("Notification receiver dropped");
177                                        return Ok(());
178                                    }
179                                }
180                            }
181                        }
182                    }
183                    Err(e) => {
184                        tracing::error!("SSE stream error: {}", e);
185                        break;
186                    }
187                }
188            }
189        }
190
191        #[cfg(not(feature = "tokio-stream"))]
192        {
193            tracing::warn!("SSE streaming requires tokio-stream feature");
194        }
195
196        Ok(())
197    }
198
199    async fn next_request_id(&self) -> u64 {
200        let mut counter = self.request_id_counter.lock().await;
201        *counter += 1;
202        *counter
203    }
204}
205
206#[async_trait]
207impl Transport for HttpClientTransport {
208    async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
209        let url = format!("{}/mcp", self.base_url);
210
211        let mut http_request = self.client.post(&url);
212        for (name, value) in self.headers.iter() {
213            let name_str = name.as_str();
214            let value_bytes = value.as_bytes();
215            http_request = http_request.header(name_str, value_bytes);
216        }
217
218        let response = http_request
219            .json(&request)
220            .send()
221            .await
222            .map_err(|e| McpError::Http(format!("HTTP request failed: {}", e)))?;
223
224        if !response.status().is_success() {
225            return Err(McpError::Http(format!(
226                "HTTP error: {} {}",
227                response.status().as_u16(),
228                response.status().canonical_reason().unwrap_or("Unknown")
229            )));
230        }
231
232        let json_response: JsonRpcResponse = response
233            .json()
234            .await
235            .map_err(|e| McpError::Http(format!("Failed to parse response: {}", e)))?;
236
237        Ok(json_response)
238    }
239
240    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
241        let url = format!("{}/mcp/notify", self.base_url);
242
243        let mut http_request = self.client.post(&url);
244        for (name, value) in self.headers.iter() {
245            let name_str = name.as_str();
246            let value_bytes = value.as_bytes();
247            http_request = http_request.header(name_str, value_bytes);
248        }
249
250        let response = http_request
251            .json(&notification)
252            .send()
253            .await
254            .map_err(|e| McpError::Http(format!("HTTP notification failed: {}", e)))?;
255
256        if !response.status().is_success() {
257            return Err(McpError::Http(format!(
258                "HTTP notification error: {} {}",
259                response.status().as_u16(),
260                response.status().canonical_reason().unwrap_or("Unknown")
261            )));
262        }
263
264        Ok(())
265    }
266
267    async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
268        if let Some(ref mut receiver) = self.notification_receiver {
269            match receiver.try_recv() {
270                Ok(notification) => Ok(Some(notification)),
271                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
272                Err(mpsc::error::TryRecvError::Disconnected) => Err(McpError::Http(
273                    "Notification channel disconnected".to_string(),
274                )),
275            }
276        } else {
277            Ok(None)
278        }
279    }
280
281    async fn close(&mut self) -> McpResult<()> {
282        self.state = ConnectionState::Disconnected;
283        self.notification_receiver = None;
284        Ok(())
285    }
286
287    fn is_connected(&self) -> bool {
288        matches!(self.state, ConnectionState::Connected)
289    }
290
291    fn connection_info(&self) -> String {
292        format!(
293            "HTTP transport (base: {}, sse: {:?}, state: {:?})",
294            self.base_url, self.sse_url, self.state
295        )
296    }
297}
298
299// ============================================================================
300// HTTP Server Transport
301// ============================================================================
302
303/// Shared state for HTTP server transport
304#[derive(Clone)]
305struct HttpServerState {
306    notification_sender: broadcast::Sender<JsonRpcNotification>,
307    request_handler: Option<
308        Arc<
309            dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse> + Send + Sync,
310        >,
311    >,
312}
313
314/// HTTP transport for MCP servers
315///
316/// This transport serves MCP requests over HTTP and provides Server-Sent Events
317/// for real-time notifications to clients.
318pub struct HttpServerTransport {
319    bind_addr: String,
320    config: TransportConfig,
321    state: Arc<RwLock<HttpServerState>>,
322    server_handle: Option<tokio::task::JoinHandle<()>>,
323    running: Arc<RwLock<bool>>,
324}
325
326impl HttpServerTransport {
327    /// Create a new HTTP server transport
328    ///
329    /// # Arguments
330    /// * `bind_addr` - Address to bind the HTTP server to (e.g., "0.0.0.0:3000")
331    ///
332    /// # Returns
333    /// New HTTP server transport instance
334    pub fn new<S: Into<String>>(bind_addr: S) -> Self {
335        Self::with_config(bind_addr, TransportConfig::default())
336    }
337
338    /// Create a new HTTP server transport with custom configuration
339    ///
340    /// # Arguments
341    /// * `bind_addr` - Address to bind the HTTP server to
342    /// * `config` - Transport configuration
343    ///
344    /// # Returns
345    /// New HTTP server transport instance
346    pub fn with_config<S: Into<String>>(bind_addr: S, config: TransportConfig) -> Self {
347        let (notification_sender, _) = broadcast::channel(1000);
348
349        Self {
350            bind_addr: bind_addr.into(),
351            config,
352            state: Arc::new(RwLock::new(HttpServerState {
353                notification_sender,
354                request_handler: None,
355            })),
356            server_handle: None,
357            running: Arc::new(RwLock::new(false)),
358        }
359    }
360
361    /// Set the request handler function
362    ///
363    /// # Arguments
364    /// * `handler` - Function that processes incoming requests
365    pub async fn set_request_handler<F>(&mut self, handler: F)
366    where
367        F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
368            + Send
369            + Sync
370            + 'static,
371    {
372        let mut state = self.state.write().await;
373        state.request_handler = Some(Arc::new(handler));
374    }
375}
376
377#[async_trait]
378impl ServerTransport for HttpServerTransport {
379    async fn start(&mut self) -> McpResult<()> {
380        tracing::info!("Starting HTTP server on {}", self.bind_addr);
381
382        let state = self.state.clone();
383        let bind_addr = self.bind_addr.clone();
384        let running = self.running.clone();
385
386        // Create the Axum app
387        let app = Router::new()
388            .route("/mcp", post(handle_mcp_request))
389            .route("/mcp/notify", post(handle_mcp_notification))
390            .route("/mcp/events", get(handle_sse_events))
391            .route("/health", get(handle_health_check))
392            .layer(
393                ServiceBuilder::new()
394                    .layer(
395                        CorsLayer::new()
396                            .allow_origin(Any)
397                            .allow_methods(Any)
398                            .allow_headers(Any),
399                    )
400                    .into_inner(),
401            )
402            .with_state(state);
403
404        // Start the server
405        let listener = tokio::net::TcpListener::bind(&bind_addr)
406            .await
407            .map_err(|e| McpError::Http(format!("Failed to bind to {}: {}", bind_addr, e)))?;
408
409        *running.write().await = true;
410
411        let server_handle = tokio::spawn(async move {
412            if let Err(e) = axum::serve(listener, app).await {
413                tracing::error!("HTTP server error: {}", e);
414            }
415        });
416
417        self.server_handle = Some(server_handle);
418
419        tracing::info!("HTTP server started successfully on {}", self.bind_addr);
420        Ok(())
421    }
422
423    async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
424        let state = self.state.read().await;
425
426        if let Some(ref handler) = state.request_handler {
427            let response_rx = handler(request);
428            drop(state); // Release the lock
429
430            match response_rx.await {
431                Ok(response) => Ok(response),
432                Err(_) => Err(McpError::Http("Request handler channel closed".to_string())),
433            }
434        } else {
435            Ok(JsonRpcResponse {
436                jsonrpc: "2.0".to_string(),
437                id: request.id,
438                result: None,
439                error: Some(crate::protocol::types::JsonRpcError {
440                    code: crate::protocol::types::METHOD_NOT_FOUND,
441                    message: "No request handler configured".to_string(),
442                    data: None,
443                }),
444            })
445        }
446    }
447
448    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
449        let state = self.state.read().await;
450
451        if let Err(_) = state.notification_sender.send(notification) {
452            tracing::warn!("No SSE clients connected to receive notification");
453        }
454
455        Ok(())
456    }
457
458    async fn stop(&mut self) -> McpResult<()> {
459        tracing::info!("Stopping HTTP server");
460
461        *self.running.write().await = false;
462
463        if let Some(handle) = self.server_handle.take() {
464            handle.abort();
465        }
466
467        Ok(())
468    }
469
470    fn is_running(&self) -> bool {
471        // Check if we have an active server handle
472        self.server_handle.is_some()
473    }
474
475    fn server_info(&self) -> String {
476        format!("HTTP server transport (bind: {})", self.bind_addr)
477    }
478}
479
480// ============================================================================
481// HTTP Route Handlers
482// ============================================================================
483
484/// Handle MCP JSON-RPC requests
485async fn handle_mcp_request(
486    State(state): State<Arc<RwLock<HttpServerState>>>,
487    Json(request): Json<JsonRpcRequest>,
488) -> Result<Json<JsonRpcResponse>, StatusCode> {
489    let state_guard = state.read().await;
490
491    if let Some(ref handler) = state_guard.request_handler {
492        let response_rx = handler(request);
493        drop(state_guard); // Release the lock
494
495        match response_rx.await {
496            Ok(response) => Ok(Json(response)),
497            Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
498        }
499    } else {
500        let error_response = JsonRpcResponse {
501            jsonrpc: "2.0".to_string(),
502            id: request.id,
503            result: None,
504            error: Some(crate::protocol::types::JsonRpcError {
505                code: crate::protocol::types::METHOD_NOT_FOUND,
506                message: "No request handler configured".to_string(),
507                data: None,
508            }),
509        };
510        Ok(Json(error_response))
511    }
512}
513
514/// Handle MCP notification requests
515async fn handle_mcp_notification(Json(_notification): Json<JsonRpcNotification>) -> StatusCode {
516    // Notifications don't require a response
517    StatusCode::OK
518}
519
/// `GET /mcp/events`: stream server notifications to the client as SSE.
///
/// Each connection gets its own broadcast subscription, so it only receives
/// notifications sent after it connected. Every notification becomes one event
/// whose `data` is the notification's JSON. When this subscriber lags behind
/// the broadcast buffer, or a notification cannot be serialized, the problem is
/// logged and the item is skipped. No placeholder event is sent to the client.
/// A `keep-alive` comment is emitted every 30 seconds.
#[cfg(all(feature = "tokio-stream", feature = "futures"))]
async fn handle_sse_events(
    State(state): State<Arc<RwLock<HttpServerState>>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let receiver = state.read().await.notification_sender.subscribe();

    let stream = BroadcastStream::new(receiver).filter_map(|item| match item {
        Ok(notification) => match serde_json::to_string(&notification) {
            Ok(json) => Some(Ok::<Event, Infallible>(Event::default().data(json))),
            Err(e) => {
                tracing::error!("Failed to serialize notification: {}", e);
                None
            }
        },
        Err(e) => {
            // The only error a BroadcastStream yields is "lagged"; the missed
            // notifications are gone, so just record it and continue.
            tracing::warn!("SSE subscriber missed notifications: {}", e);
            None
        }
    });

    Sse::new(stream).keep_alive(
        axum::response::sse::KeepAlive::new()
            .interval(Duration::from_secs(30))
            .text("keep-alive"),
    )
}
548
549/// Handle Server-Sent Events (fallback when features not available)
550#[cfg(not(all(feature = "tokio-stream", feature = "futures")))]
551async fn handle_sse_events(_state: State<Arc<RwLock<HttpServerState>>>) -> StatusCode {
552    StatusCode::NOT_IMPLEMENTED
553}
554
555/// Handle health check requests
556async fn handle_health_check() -> Json<Value> {
557    #[cfg(feature = "chrono")]
558    let timestamp = chrono::Utc::now().to_rfc3339();
559    #[cfg(not(feature = "chrono"))]
560    let timestamp = "unavailable";
561
562    Json(serde_json::json!({
563        "status": "healthy",
564        "transport": "http",
565        "timestamp": timestamp
566    }))
567}
568
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_http_client_creation() {
        let transport = HttpClientTransport::new("http://localhost:3000", None).await;
        assert!(transport.is_ok());

        let transport = transport.unwrap();
        assert!(transport.is_connected());
        assert_eq!(transport.base_url, "http://localhost:3000");
    }

    #[tokio::test]
    async fn test_http_server_creation() {
        let transport = HttpServerTransport::new("127.0.0.1:0");
        assert_eq!(transport.bind_addr, "127.0.0.1:0");
        assert!(!transport.is_running());
    }

    #[test]
    fn test_http_server_with_config() {
        let mut config = TransportConfig::default();
        config.compression = true;

        let transport = HttpServerTransport::with_config("0.0.0.0:8080", config);
        assert_eq!(transport.bind_addr, "0.0.0.0:8080");
        assert!(transport.config.compression);
    }

    #[tokio::test]
    async fn test_http_client_with_sse() {
        let transport = HttpClientTransport::new(
            "http://localhost:3000",
            Some("http://localhost:3000/events"),
        )
        .await;

        assert!(transport.is_ok());
        let transport = transport.unwrap();
        assert!(transport.sse_url.is_some());
        assert_eq!(transport.sse_url.unwrap(), "http://localhost:3000/events");
    }

    /// The server binds an ephemeral port, reports running, and stops cleanly.
    #[tokio::test]
    async fn test_http_server_start_and_stop() {
        let mut transport = HttpServerTransport::new("127.0.0.1:0");

        transport
            .start()
            .await
            .expect("server should bind an ephemeral port");
        assert!(transport.is_running());

        transport.stop().await.expect("stop should succeed");
        assert!(!transport.is_running());
    }

    /// Without an installed handler, requests get a METHOD_NOT_FOUND JSON-RPC error.
    #[tokio::test]
    async fn test_handle_request_without_handler_returns_method_not_found() {
        let mut transport = HttpServerTransport::new("127.0.0.1:0");
        let request: JsonRpcRequest = serde_json::from_value(json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list",
            "params": null
        }))
        .expect("valid JSON-RPC request");

        let response = transport
            .handle_request(request)
            .await
            .expect("missing handler is reported in-band");

        assert!(response.result.is_none());
        let error = response.error.expect("an error response is expected");
        assert_eq!(error.code, crate::protocol::types::METHOD_NOT_FOUND);
    }
}