1use async_trait::async_trait;
7use axum::{
8 extract::State,
9 http::{HeaderMap, StatusCode},
10 response::{sse::Event, Sse},
11 routing::{get, post},
12 Json, Router,
13};
14use reqwest::Client;
15use serde_json::Value;
16use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
17use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
18
19#[cfg(all(feature = "futures", feature = "tokio-stream"))]
20use futures::stream::Stream;
21
22#[cfg(feature = "tokio-stream")]
23use tokio_stream::{wrappers::BroadcastStream, StreamExt};
24
25use tower::ServiceBuilder;
26use tower_http::cors::{Any, CorsLayer};
27
28use crate::core::error::{McpError, McpResult};
29use crate::protocol::types::{
30 error_codes, JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
31};
32use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
33
34pub struct HttpClientTransport {
43 client: Client,
44 base_url: String,
45 sse_url: Option<String>,
46 headers: HeaderMap,
47 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
49 notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
50 config: TransportConfig,
51 state: ConnectionState,
52 request_id_counter: Arc<Mutex<u64>>,
53}
54
55impl HttpClientTransport {
56 pub async fn new<S: AsRef<str>>(base_url: S, sse_url: Option<S>) -> McpResult<Self> {
65 Self::with_config(base_url, sse_url, TransportConfig::default()).await
66 }
67
68 pub async fn with_config<S: AsRef<str>>(
78 base_url: S,
79 sse_url: Option<S>,
80 config: TransportConfig,
81 ) -> McpResult<Self> {
82 let client_builder = Client::builder()
83 .timeout(Duration::from_millis(
84 config.read_timeout_ms.unwrap_or(60_000),
85 ))
86 .connect_timeout(Duration::from_millis(
87 config.connect_timeout_ms.unwrap_or(30_000),
88 ));
89
90 let client = client_builder
93 .build()
94 .map_err(|e| McpError::Http(format!("Failed to create HTTP client: {}", e)))?;
95
96 let mut headers = HeaderMap::new();
97 headers.insert("Content-Type", "application/json".parse().unwrap());
98 headers.insert("Accept", "application/json".parse().unwrap());
99
100 for (key, value) in &config.headers {
102 if let (Ok(header_name), Ok(header_value)) = (
103 key.parse::<axum::http::HeaderName>(),
104 value.parse::<axum::http::HeaderValue>(),
105 ) {
106 headers.insert(header_name, header_value);
107 }
108 }
109
110 let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
111
112 if let Some(sse_url) = &sse_url {
114 let sse_url = sse_url.as_ref().to_string();
115 let client_clone = client.clone();
116 let headers_clone = headers.clone();
117
118 tokio::spawn(async move {
119 if let Err(e) = Self::handle_sse_stream(
120 client_clone,
121 sse_url,
122 headers_clone,
123 notification_sender,
124 )
125 .await
126 {
127 tracing::error!("SSE stream error: {}", e);
128 }
129 });
130 }
131
132 Ok(Self {
133 client,
134 base_url: base_url.as_ref().to_string(),
135 sse_url: sse_url.map(|s| s.as_ref().to_string()),
136 headers,
137 pending_requests: Arc::new(Mutex::new(HashMap::new())),
138 notification_receiver: Some(notification_receiver),
139 config,
140 state: ConnectionState::Connected,
141 request_id_counter: Arc::new(Mutex::new(0)),
142 })
143 }
144
145 async fn handle_sse_stream(
146 client: Client,
147 sse_url: String,
148 headers: HeaderMap,
149 notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
150 ) -> McpResult<()> {
151 let mut request = client.get(&sse_url);
152 for (name, value) in headers.iter() {
153 let name_str = name.as_str();
155 let value_bytes = value.as_bytes();
156 request = request.header(name_str, value_bytes);
157 }
158
159 let response = request
160 .send()
161 .await
162 .map_err(|e| McpError::Http(format!("SSE connection failed: {}", e)))?;
163
164 let mut stream = response.bytes_stream();
165
166 #[cfg(feature = "tokio-stream")]
167 {
168 while let Some(chunk) = stream.next().await {
169 match chunk {
170 Ok(bytes) => {
171 let text = String::from_utf8_lossy(&bytes);
172 for line in text.lines() {
173 if line.starts_with("data: ") {
174 let data = &line[6..]; if let Ok(notification) =
176 serde_json::from_str::<JsonRpcNotification>(data)
177 {
178 if notification_sender.send(notification).is_err() {
179 tracing::debug!("Notification receiver dropped");
180 return Ok(());
181 }
182 }
183 }
184 }
185 }
186 Err(e) => {
187 tracing::error!("SSE stream error: {}", e);
188 break;
189 }
190 }
191 }
192 }
193
194 #[cfg(not(feature = "tokio-stream"))]
195 {
196 tracing::warn!("SSE streaming requires tokio-stream feature");
197 }
198
199 Ok(())
200 }
201
202 async fn next_request_id(&self) -> u64 {
203 let mut counter = self.request_id_counter.lock().await;
204 *counter += 1;
205 *counter
206 }
207
208 async fn track_request(&self, request_id: &Value) {
210 let mut pending = self.pending_requests.lock().await;
214 let (sender, _receiver) = tokio::sync::oneshot::channel();
215 pending.insert(request_id.clone(), sender);
216 }
217
218 async fn untrack_request(&self, request_id: &Value) {
220 let mut pending = self.pending_requests.lock().await;
221 pending.remove(request_id);
222 }
223
224 pub async fn active_request_count(&self) -> usize {
226 let pending = self.pending_requests.lock().await;
227 pending.len()
228 }
229}
230
231#[async_trait]
232impl Transport for HttpClientTransport {
233 async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
234 let request_with_id = if request.id == Value::Null {
236 let request_id = self.next_request_id().await;
237 JsonRpcRequest {
238 id: Value::from(request_id),
239 ..request
240 }
241 } else {
242 request
243 };
244
245 self.track_request(&request_with_id.id).await;
247
248 let url = format!("{}/mcp", self.base_url);
249
250 let mut http_request = self.client.post(&url);
251
252 for (name, value) in self.headers.iter() {
254 let name_str = name.as_str();
255 let value_bytes = value.as_bytes();
256 http_request = http_request.header(name_str, value_bytes);
257 }
258
259 if let Some(timeout_ms) = self.config.read_timeout_ms {
261 http_request = http_request.timeout(Duration::from_millis(timeout_ms));
262 }
263
264 let response = http_request
265 .json(&request_with_id)
266 .send()
267 .await
268 .map_err(|e| {
269 let request_id = request_with_id.id.clone();
271 let pending_requests = self.pending_requests.clone();
272 tokio::spawn(async move {
273 let mut pending = pending_requests.lock().await;
274 pending.remove(&request_id);
275 });
276 McpError::Http(format!("HTTP request failed: {}", e))
277 })?;
278
279 if !response.status().is_success() {
280 self.untrack_request(&request_with_id.id).await;
282 return Err(McpError::Http(format!(
283 "HTTP error: {} {}",
284 response.status().as_u16(),
285 response.status().canonical_reason().unwrap_or("Unknown")
286 )));
287 }
288
289 let json_response: JsonRpcResponse = response.json().await.map_err(|e| {
290 let request_id = request_with_id.id.clone();
292 let pending_requests = self.pending_requests.clone();
293 tokio::spawn(async move {
294 let mut pending = pending_requests.lock().await;
295 pending.remove(&request_id);
296 });
297 McpError::Http(format!("Failed to parse response: {}", e))
298 })?;
299
300 if json_response.id != request_with_id.id {
302 self.untrack_request(&request_with_id.id).await;
303 return Err(McpError::Http(format!(
304 "Response ID {:?} does not match request ID {:?}",
305 json_response.id, request_with_id.id
306 )));
307 }
308
309 self.untrack_request(&request_with_id.id).await;
311
312 Ok(json_response)
313 }
314
315 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
316 let url = format!("{}/mcp/notify", self.base_url);
317
318 let mut http_request = self.client.post(&url);
319
320 for (name, value) in self.headers.iter() {
322 let name_str = name.as_str();
323 let value_bytes = value.as_bytes();
324 http_request = http_request.header(name_str, value_bytes);
325 }
326
327 if let Some(timeout_ms) = self.config.write_timeout_ms {
329 http_request = http_request.timeout(Duration::from_millis(timeout_ms));
330 }
331
332 let response = http_request
333 .json(¬ification)
334 .send()
335 .await
336 .map_err(|e| McpError::Http(format!("HTTP notification failed: {}", e)))?;
337
338 if !response.status().is_success() {
339 return Err(McpError::Http(format!(
340 "HTTP notification error: {} {}",
341 response.status().as_u16(),
342 response.status().canonical_reason().unwrap_or("Unknown")
343 )));
344 }
345
346 Ok(())
347 }
348
349 async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
350 if let Some(ref mut receiver) = self.notification_receiver {
351 match receiver.try_recv() {
352 Ok(notification) => Ok(Some(notification)),
353 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
354 Err(mpsc::error::TryRecvError::Disconnected) => Err(McpError::Http(
355 "Notification channel disconnected".to_string(),
356 )),
357 }
358 } else {
359 Ok(None)
360 }
361 }
362
363 async fn close(&mut self) -> McpResult<()> {
364 self.state = ConnectionState::Disconnected;
365 self.notification_receiver = None;
366 Ok(())
367 }
368
369 fn is_connected(&self) -> bool {
370 matches!(self.state, ConnectionState::Connected)
371 }
372
373 fn connection_info(&self) -> String {
374 format!(
375 "HTTP transport (base: {}, sse: {:?}, state: {:?})",
376 self.base_url, self.sse_url, self.state
377 )
378 }
379}
380
381#[derive(Clone)]
387struct HttpServerState {
388 notification_sender: broadcast::Sender<JsonRpcNotification>,
389 request_handler: Option<
390 Arc<
391 dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse> + Send + Sync,
392 >,
393 >,
394}
395
396pub struct HttpServerTransport {
401 bind_addr: String,
402 config: TransportConfig,
403 state: Arc<RwLock<HttpServerState>>,
404 server_handle: Option<tokio::task::JoinHandle<()>>,
405 running: Arc<RwLock<bool>>,
406}
407
408impl HttpServerTransport {
409 pub fn new<S: Into<String>>(bind_addr: S) -> Self {
417 Self::with_config(bind_addr, TransportConfig::default())
418 }
419
420 pub fn with_config<S: Into<String>>(bind_addr: S, config: TransportConfig) -> Self {
429 let (notification_sender, _) = broadcast::channel(1000);
430
431 Self {
432 bind_addr: bind_addr.into(),
433 config,
434 state: Arc::new(RwLock::new(HttpServerState {
435 notification_sender,
436 request_handler: None,
437 })),
438 server_handle: None,
439 running: Arc::new(RwLock::new(false)),
440 }
441 }
442
443 pub async fn set_request_handler<F>(&mut self, handler: F)
448 where
449 F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
450 + Send
451 + Sync
452 + 'static,
453 {
454 let mut state = self.state.write().await;
455 state.request_handler = Some(Arc::new(handler));
456 }
457}
458
459#[async_trait]
460impl ServerTransport for HttpServerTransport {
461 async fn start(&mut self) -> McpResult<()> {
462 tracing::info!("Starting HTTP server on {}", self.bind_addr);
463
464 let state = self.state.clone();
465 let bind_addr = self.bind_addr.clone();
466 let running = self.running.clone();
467 let _config = self.config.clone(); let mut app = Router::new()
471 .route("/mcp", post(handle_mcp_request))
472 .route("/mcp/notify", post(handle_mcp_notification))
473 .route("/mcp/events", get(handle_sse_events))
474 .route("/health", get(handle_health_check))
475 .with_state(state);
476
477 let cors_layer = CorsLayer::new()
479 .allow_origin(Any)
480 .allow_methods(Any)
481 .allow_headers(Any);
482
483 app = app.layer(ServiceBuilder::new().layer(cors_layer).into_inner());
484
485 let listener = tokio::net::TcpListener::bind(&bind_addr)
487 .await
488 .map_err(|e| McpError::Http(format!("Failed to bind to {}: {}", bind_addr, e)))?;
489
490 *running.write().await = true;
491
492 let server_handle = tokio::spawn(async move {
493 if let Err(e) = axum::serve(listener, app).await {
494 tracing::error!("HTTP server error: {}", e);
495 }
496 });
497
498 self.server_handle = Some(server_handle);
499
500 tracing::info!("HTTP server started successfully on {}", self.bind_addr);
501 Ok(())
502 }
503
504 async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
505 tracing::warn!("handle_request called directly on HTTP transport - this may indicate a configuration issue");
508
509 let state = self.state.read().await;
510
511 if let Some(ref handler) = state.request_handler {
512 let response_rx = handler(request);
513 drop(state); match response_rx.await {
516 Ok(response) => Ok(response),
517 Err(_) => Err(McpError::Http("Request handler channel closed".to_string())),
518 }
519 } else {
520 Err(McpError::Http(
522 "No request handler configured for HTTP transport".to_string(),
523 ))
524 }
525 }
526
527 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
528 let state = self.state.read().await;
529
530 if let Err(_) = state.notification_sender.send(notification) {
531 tracing::warn!("No SSE clients connected to receive notification");
532 }
533
534 Ok(())
535 }
536
537 async fn stop(&mut self) -> McpResult<()> {
538 tracing::info!("Stopping HTTP server");
539
540 *self.running.write().await = false;
541
542 if let Some(handle) = self.server_handle.take() {
543 handle.abort();
544 }
545
546 Ok(())
547 }
548
549 fn is_running(&self) -> bool {
550 self.server_handle.is_some()
552 }
553
554 fn server_info(&self) -> String {
555 format!("HTTP server transport (bind: {})", self.bind_addr)
556 }
557}
558
559async fn handle_mcp_request(
565 State(state): State<Arc<RwLock<HttpServerState>>>,
566 Json(request): Json<JsonRpcRequest>,
567) -> Result<Json<JsonRpcMessage>, StatusCode> {
568 let state_guard = state.read().await;
569
570 if let Some(ref handler) = state_guard.request_handler {
571 let response_rx = handler(request);
572 drop(state_guard); match response_rx.await {
575 Ok(response) => Ok(Json(JsonRpcMessage::Response(response))),
576 Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
577 }
578 } else {
579 let error_response = JsonRpcError::error(
580 request.id,
581 error_codes::METHOD_NOT_FOUND,
582 "No request handler configured".to_string(),
583 None,
584 );
585 Ok(Json(JsonRpcMessage::Error(error_response)))
586 }
587}
588
589async fn handle_mcp_notification(Json(_notification): Json<JsonRpcNotification>) -> StatusCode {
591 StatusCode::OK
593}
594
/// Axum handler for `GET /mcp/events`: streams server notifications to the client as
/// Server-Sent Events, one JSON-encoded notification per `data` event.
///
/// Each connection subscribes to the shared broadcast channel. Some notifications cannot
/// be delivered: the client fell behind and the channel dropped messages, or a notification
/// fails to serialize. These are logged and skipped, and no placeholder `{}` event is sent
/// in their place. A `keep-alive` message is sent every 30 seconds.
#[cfg(all(feature = "tokio-stream", feature = "futures"))]
async fn handle_sse_events(
    State(state): State<Arc<RwLock<HttpServerState>>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let receiver = state.read().await.notification_sender.subscribe();

    let stream = BroadcastStream::new(receiver).filter_map(
        |result| -> Option<Result<Event, Infallible>> {
            match result {
                Ok(notification) => match serde_json::to_string(&notification) {
                    Ok(json) => Some(Ok(Event::default().data(json))),
                    Err(e) => {
                        tracing::error!("Failed to serialize notification: {}", e);
                        None
                    }
                },
                Err(e) => {
                    tracing::warn!("SSE client lagged behind notifications: {}", e);
                    None
                }
            }
        },
    );

    Sse::new(stream).keep_alive(
        axum::response::sse::KeepAlive::new()
            .interval(Duration::from_secs(30))
            .text("keep-alive"),
    )
}
623
624#[cfg(not(all(feature = "tokio-stream", feature = "futures")))]
626async fn handle_sse_events(_state: State<Arc<RwLock<HttpServerState>>>) -> StatusCode {
627 StatusCode::NOT_IMPLEMENTED
628}
629
630async fn handle_health_check() -> Json<Value> {
632 #[cfg(feature = "chrono")]
633 let timestamp = chrono::Utc::now().to_rfc3339();
634 #[cfg(not(feature = "chrono"))]
635 let timestamp = "unavailable";
636
637 Json(serde_json::json!({
638 "status": "healthy",
639 "transport": "http",
640 "timestamp": timestamp
641 }))
642}
643
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain client transport starts connected and stores the base URL verbatim.
    #[tokio::test]
    async fn test_http_client_creation() {
        let transport = HttpClientTransport::new("http://localhost:3000", None)
            .await
            .expect("client transport should be constructible");

        assert!(transport.is_connected());
        assert_eq!(transport.base_url, "http://localhost:3000");
    }

    /// A freshly constructed server keeps its bind address and is not running.
    #[tokio::test]
    async fn test_http_server_creation() {
        let transport = HttpServerTransport::new("127.0.0.1:0");

        assert_eq!(transport.bind_addr, "127.0.0.1:0");
        assert!(!transport.is_running());
    }

    /// Custom configuration is stored unchanged.
    #[test]
    fn test_http_server_with_config() {
        let mut config = TransportConfig::default();
        config.compression = true;

        let transport = HttpServerTransport::with_config("0.0.0.0:8080", config);

        assert_eq!(transport.bind_addr, "0.0.0.0:8080");
        assert!(transport.config.compression);
    }

    /// Supplying an SSE URL records it on the transport.
    #[tokio::test]
    async fn test_http_client_with_sse() {
        let transport = HttpClientTransport::new(
            "http://localhost:3000",
            Some("http://localhost:3000/events"),
        )
        .await
        .expect("client transport with SSE should be constructible");

        assert_eq!(
            transport.sse_url.as_deref(),
            Some("http://localhost:3000/events")
        );
    }
}