1use async_trait::async_trait;
7use axum::{
8 extract::State,
9 http::{HeaderMap, StatusCode},
10 response::{sse::Event, Sse},
11 routing::{get, post},
12 Json, Router,
13};
14use reqwest::Client;
15use serde_json::Value;
16use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
17use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
18
19#[cfg(all(feature = "futures", feature = "tokio-stream"))]
20use futures::stream::Stream;
21
22#[cfg(feature = "tokio-stream")]
23use tokio_stream::{wrappers::BroadcastStream, StreamExt};
24
25use tower::ServiceBuilder;
26use tower_http::cors::{Any, CorsLayer};
27
28use crate::core::error::{McpError, McpResult};
29use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
30use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
31
32pub struct HttpClientTransport {
41 client: Client,
42 base_url: String,
43 sse_url: Option<String>,
44 headers: HeaderMap,
45 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
46 notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
47 config: TransportConfig,
48 state: ConnectionState,
49 request_id_counter: Arc<Mutex<u64>>,
50}
51
52impl HttpClientTransport {
53 pub async fn new<S: AsRef<str>>(base_url: S, sse_url: Option<S>) -> McpResult<Self> {
62 Self::with_config(base_url, sse_url, TransportConfig::default()).await
63 }
64
65 pub async fn with_config<S: AsRef<str>>(
75 base_url: S,
76 sse_url: Option<S>,
77 config: TransportConfig,
78 ) -> McpResult<Self> {
79 let client_builder = Client::builder()
80 .timeout(Duration::from_millis(
81 config.read_timeout_ms.unwrap_or(60_000),
82 ))
83 .connect_timeout(Duration::from_millis(
84 config.connect_timeout_ms.unwrap_or(30_000),
85 ));
86
87 let client = client_builder
90 .build()
91 .map_err(|e| McpError::Http(format!("Failed to create HTTP client: {}", e)))?;
92
93 let mut headers = HeaderMap::new();
94 headers.insert("Content-Type", "application/json".parse().unwrap());
95 headers.insert("Accept", "application/json".parse().unwrap());
96
97 for (key, value) in &config.headers {
99 if let (Ok(header_name), Ok(header_value)) = (
100 key.parse::<axum::http::HeaderName>(),
101 value.parse::<axum::http::HeaderValue>(),
102 ) {
103 headers.insert(header_name, header_value);
104 }
105 }
106
107 let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
108
109 if let Some(sse_url) = &sse_url {
111 let sse_url = sse_url.as_ref().to_string();
112 let client_clone = client.clone();
113 let headers_clone = headers.clone();
114
115 tokio::spawn(async move {
116 if let Err(e) = Self::handle_sse_stream(
117 client_clone,
118 sse_url,
119 headers_clone,
120 notification_sender,
121 )
122 .await
123 {
124 tracing::error!("SSE stream error: {}", e);
125 }
126 });
127 }
128
129 Ok(Self {
130 client,
131 base_url: base_url.as_ref().to_string(),
132 sse_url: sse_url.map(|s| s.as_ref().to_string()),
133 headers,
134 pending_requests: Arc::new(Mutex::new(HashMap::new())),
135 notification_receiver: Some(notification_receiver),
136 config,
137 state: ConnectionState::Connected,
138 request_id_counter: Arc::new(Mutex::new(0)),
139 })
140 }
141
142 async fn handle_sse_stream(
143 client: Client,
144 sse_url: String,
145 headers: HeaderMap,
146 notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
147 ) -> McpResult<()> {
148 let mut request = client.get(&sse_url);
149 for (name, value) in headers.iter() {
150 let name_str = name.as_str();
152 let value_bytes = value.as_bytes();
153 request = request.header(name_str, value_bytes);
154 }
155
156 let response = request
157 .send()
158 .await
159 .map_err(|e| McpError::Http(format!("SSE connection failed: {}", e)))?;
160
161 let mut stream = response.bytes_stream();
162
163 #[cfg(feature = "tokio-stream")]
164 {
165 while let Some(chunk) = stream.next().await {
166 match chunk {
167 Ok(bytes) => {
168 let text = String::from_utf8_lossy(&bytes);
169 for line in text.lines() {
170 if line.starts_with("data: ") {
171 let data = &line[6..]; if let Ok(notification) =
173 serde_json::from_str::<JsonRpcNotification>(data)
174 {
175 if notification_sender.send(notification).is_err() {
176 tracing::debug!("Notification receiver dropped");
177 return Ok(());
178 }
179 }
180 }
181 }
182 }
183 Err(e) => {
184 tracing::error!("SSE stream error: {}", e);
185 break;
186 }
187 }
188 }
189 }
190
191 #[cfg(not(feature = "tokio-stream"))]
192 {
193 tracing::warn!("SSE streaming requires tokio-stream feature");
194 }
195
196 Ok(())
197 }
198
199 async fn next_request_id(&self) -> u64 {
200 let mut counter = self.request_id_counter.lock().await;
201 *counter += 1;
202 *counter
203 }
204}
205
206#[async_trait]
207impl Transport for HttpClientTransport {
208 async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
209 let url = format!("{}/mcp", self.base_url);
210
211 let mut http_request = self.client.post(&url);
212 for (name, value) in self.headers.iter() {
213 let name_str = name.as_str();
214 let value_bytes = value.as_bytes();
215 http_request = http_request.header(name_str, value_bytes);
216 }
217
218 let response = http_request
219 .json(&request)
220 .send()
221 .await
222 .map_err(|e| McpError::Http(format!("HTTP request failed: {}", e)))?;
223
224 if !response.status().is_success() {
225 return Err(McpError::Http(format!(
226 "HTTP error: {} {}",
227 response.status().as_u16(),
228 response.status().canonical_reason().unwrap_or("Unknown")
229 )));
230 }
231
232 let json_response: JsonRpcResponse = response
233 .json()
234 .await
235 .map_err(|e| McpError::Http(format!("Failed to parse response: {}", e)))?;
236
237 Ok(json_response)
238 }
239
240 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
241 let url = format!("{}/mcp/notify", self.base_url);
242
243 let mut http_request = self.client.post(&url);
244 for (name, value) in self.headers.iter() {
245 let name_str = name.as_str();
246 let value_bytes = value.as_bytes();
247 http_request = http_request.header(name_str, value_bytes);
248 }
249
250 let response = http_request
251 .json(¬ification)
252 .send()
253 .await
254 .map_err(|e| McpError::Http(format!("HTTP notification failed: {}", e)))?;
255
256 if !response.status().is_success() {
257 return Err(McpError::Http(format!(
258 "HTTP notification error: {} {}",
259 response.status().as_u16(),
260 response.status().canonical_reason().unwrap_or("Unknown")
261 )));
262 }
263
264 Ok(())
265 }
266
267 async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
268 if let Some(ref mut receiver) = self.notification_receiver {
269 match receiver.try_recv() {
270 Ok(notification) => Ok(Some(notification)),
271 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
272 Err(mpsc::error::TryRecvError::Disconnected) => Err(McpError::Http(
273 "Notification channel disconnected".to_string(),
274 )),
275 }
276 } else {
277 Ok(None)
278 }
279 }
280
281 async fn close(&mut self) -> McpResult<()> {
282 self.state = ConnectionState::Disconnected;
283 self.notification_receiver = None;
284 Ok(())
285 }
286
287 fn is_connected(&self) -> bool {
288 matches!(self.state, ConnectionState::Connected)
289 }
290
291 fn connection_info(&self) -> String {
292 format!(
293 "HTTP transport (base: {}, sse: {:?}, state: {:?})",
294 self.base_url, self.sse_url, self.state
295 )
296 }
297}
298
299#[derive(Clone)]
305struct HttpServerState {
306 notification_sender: broadcast::Sender<JsonRpcNotification>,
307 request_handler: Option<
308 Arc<
309 dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse> + Send + Sync,
310 >,
311 >,
312}
313
314pub struct HttpServerTransport {
319 bind_addr: String,
320 config: TransportConfig,
321 state: Arc<RwLock<HttpServerState>>,
322 server_handle: Option<tokio::task::JoinHandle<()>>,
323 running: Arc<RwLock<bool>>,
324}
325
326impl HttpServerTransport {
327 pub fn new<S: Into<String>>(bind_addr: S) -> Self {
335 Self::with_config(bind_addr, TransportConfig::default())
336 }
337
338 pub fn with_config<S: Into<String>>(bind_addr: S, config: TransportConfig) -> Self {
347 let (notification_sender, _) = broadcast::channel(1000);
348
349 Self {
350 bind_addr: bind_addr.into(),
351 config,
352 state: Arc::new(RwLock::new(HttpServerState {
353 notification_sender,
354 request_handler: None,
355 })),
356 server_handle: None,
357 running: Arc::new(RwLock::new(false)),
358 }
359 }
360
361 pub async fn set_request_handler<F>(&mut self, handler: F)
366 where
367 F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
368 + Send
369 + Sync
370 + 'static,
371 {
372 let mut state = self.state.write().await;
373 state.request_handler = Some(Arc::new(handler));
374 }
375}
376
377#[async_trait]
378impl ServerTransport for HttpServerTransport {
379 async fn start(&mut self) -> McpResult<()> {
380 tracing::info!("Starting HTTP server on {}", self.bind_addr);
381
382 let state = self.state.clone();
383 let bind_addr = self.bind_addr.clone();
384 let running = self.running.clone();
385
386 let app = Router::new()
388 .route("/mcp", post(handle_mcp_request))
389 .route("/mcp/notify", post(handle_mcp_notification))
390 .route("/mcp/events", get(handle_sse_events))
391 .route("/health", get(handle_health_check))
392 .layer(
393 ServiceBuilder::new()
394 .layer(
395 CorsLayer::new()
396 .allow_origin(Any)
397 .allow_methods(Any)
398 .allow_headers(Any),
399 )
400 .into_inner(),
401 )
402 .with_state(state);
403
404 let listener = tokio::net::TcpListener::bind(&bind_addr)
406 .await
407 .map_err(|e| McpError::Http(format!("Failed to bind to {}: {}", bind_addr, e)))?;
408
409 *running.write().await = true;
410
411 let server_handle = tokio::spawn(async move {
412 if let Err(e) = axum::serve(listener, app).await {
413 tracing::error!("HTTP server error: {}", e);
414 }
415 });
416
417 self.server_handle = Some(server_handle);
418
419 tracing::info!("HTTP server started successfully on {}", self.bind_addr);
420 Ok(())
421 }
422
423 async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
424 let state = self.state.read().await;
425
426 if let Some(ref handler) = state.request_handler {
427 let response_rx = handler(request);
428 drop(state); match response_rx.await {
431 Ok(response) => Ok(response),
432 Err(_) => Err(McpError::Http("Request handler channel closed".to_string())),
433 }
434 } else {
435 Ok(JsonRpcResponse {
436 jsonrpc: "2.0".to_string(),
437 id: request.id,
438 result: None,
439 error: Some(crate::protocol::types::JsonRpcError {
440 code: crate::protocol::types::METHOD_NOT_FOUND,
441 message: "No request handler configured".to_string(),
442 data: None,
443 }),
444 })
445 }
446 }
447
448 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
449 let state = self.state.read().await;
450
451 if let Err(_) = state.notification_sender.send(notification) {
452 tracing::warn!("No SSE clients connected to receive notification");
453 }
454
455 Ok(())
456 }
457
458 async fn stop(&mut self) -> McpResult<()> {
459 tracing::info!("Stopping HTTP server");
460
461 *self.running.write().await = false;
462
463 if let Some(handle) = self.server_handle.take() {
464 handle.abort();
465 }
466
467 Ok(())
468 }
469
470 fn is_running(&self) -> bool {
471 self.server_handle.is_some()
473 }
474
475 fn server_info(&self) -> String {
476 format!("HTTP server transport (bind: {})", self.bind_addr)
477 }
478}
479
480async fn handle_mcp_request(
486 State(state): State<Arc<RwLock<HttpServerState>>>,
487 Json(request): Json<JsonRpcRequest>,
488) -> Result<Json<JsonRpcResponse>, StatusCode> {
489 let state_guard = state.read().await;
490
491 if let Some(ref handler) = state_guard.request_handler {
492 let response_rx = handler(request);
493 drop(state_guard); match response_rx.await {
496 Ok(response) => Ok(Json(response)),
497 Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
498 }
499 } else {
500 let error_response = JsonRpcResponse {
501 jsonrpc: "2.0".to_string(),
502 id: request.id,
503 result: None,
504 error: Some(crate::protocol::types::JsonRpcError {
505 code: crate::protocol::types::METHOD_NOT_FOUND,
506 message: "No request handler configured".to_string(),
507 data: None,
508 }),
509 };
510 Ok(Json(error_response))
511 }
512}
513
514async fn handle_mcp_notification(Json(_notification): Json<JsonRpcNotification>) -> StatusCode {
516 StatusCode::OK
518}
519
#[cfg(all(feature = "tokio-stream", feature = "futures"))]
/// `GET /mcp/events`: streams every broadcast notification to the client as an SSE `data`
/// event containing the notification's JSON.
///
/// Only valid notifications are emitted. A subscriber that falls behind the broadcast
/// buffer, or a notification that fails to serialize, is logged and skipped; previously
/// each produced a bogus `{}` event. A `keep-alive` comment is sent every 30 seconds.
async fn handle_sse_events(
    State(state): State<Arc<RwLock<HttpServerState>>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    // The read guard is a temporary, so the lock is released at the end of this statement.
    let receiver = state.read().await.notification_sender.subscribe();

    let stream = BroadcastStream::new(receiver).filter_map(
        |result| -> Option<Result<Event, Infallible>> {
            match result {
                Ok(notification) => match serde_json::to_string(&notification) {
                    Ok(json) => Some(Ok(Event::default().data(json))),
                    Err(e) => {
                        tracing::error!("Failed to serialize notification: {}", e);
                        None
                    }
                },
                Err(e) => {
                    tracing::warn!("SSE subscriber lagged; notifications were dropped: {}", e);
                    None
                }
            }
        },
    );

    Sse::new(stream).keep_alive(
        axum::response::sse::KeepAlive::new()
            .interval(Duration::from_secs(30))
            .text("keep-alive"),
    )
}
548
549#[cfg(not(all(feature = "tokio-stream", feature = "futures")))]
551async fn handle_sse_events(_state: State<Arc<RwLock<HttpServerState>>>) -> StatusCode {
552 StatusCode::NOT_IMPLEMENTED
553}
554
555async fn handle_health_check() -> Json<Value> {
557 #[cfg(feature = "chrono")]
558 let timestamp = chrono::Utc::now().to_rfc3339();
559 #[cfg(not(feature = "chrono"))]
560 let timestamp = "unavailable";
561
562 Json(serde_json::json!({
563 "status": "healthy",
564 "transport": "http",
565 "timestamp": timestamp
566 }))
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_http_client_creation() {
        let transport = HttpClientTransport::new("http://localhost:3000", None)
            .await
            .expect("client transport should build without an SSE endpoint");

        assert!(transport.is_connected());
        assert_eq!(transport.base_url, "http://localhost:3000");
    }

    #[tokio::test]
    async fn test_http_server_creation() {
        let transport = HttpServerTransport::new("127.0.0.1:0");

        assert_eq!(transport.bind_addr, "127.0.0.1:0");
        assert!(
            !transport.is_running(),
            "a server that was never started must not report running"
        );
    }

    #[test]
    fn test_http_server_with_config() {
        let mut config = TransportConfig::default();
        config.compression = true;

        let transport = HttpServerTransport::with_config("0.0.0.0:8080", config);

        assert_eq!(transport.bind_addr, "0.0.0.0:8080");
        assert!(transport.config.compression);
    }

    #[tokio::test]
    async fn test_http_client_with_sse() {
        let transport = HttpClientTransport::new(
            "http://localhost:3000",
            Some("http://localhost:3000/events"),
        )
        .await
        .expect("client transport should build with an SSE endpoint");

        assert_eq!(
            transport.sse_url.as_deref(),
            Some("http://localhost:3000/events")
        );
    }
}