1use crate::{
4 RequestHandler, Transport, TransportError,
5 batch::{JsonRpcMessage, process_batch},
6 validation::validate_message_string,
7};
8use async_trait::async_trait;
9use axum::response::sse::{Event, KeepAlive};
10use axum::{
11 Router,
12 extract::{Query, State},
13 http::{
14 HeaderMap, StatusCode,
15 header::{AUTHORIZATION, ORIGIN},
16 },
17 response::{IntoResponse, Response as AxumResponse, Sse},
18 routing::{get, post},
19};
20use serde::Deserialize;
23use serde_json;
24use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
25use tokio::sync::{Mutex, RwLock, broadcast};
26use tower::ServiceBuilder;
27use tower_http::cors::CorsLayer;
28use tracing::{debug, error, info, warn};
29use uuid::Uuid;
30
/// Settings for the HTTP transport: listener address, message limits, CORS,
/// origin allow-list, session lifetime and bearer-token authentication.
#[derive(Debug, Clone)]
pub struct HttpConfig {
    /// TCP port the server binds.
    pub port: u16,
    /// Host/IP the server binds; combined with `port` as `host:port`.
    pub host: String,
    /// Size limit (bytes) handed to message validation; only applied when
    /// `validate_messages` is set.
    pub max_message_size: usize,
    /// CORS toggle for the HTTP router (see `Transport::start` for how it is applied).
    pub enable_cors: bool,
    /// `None` disables the `Origin` check; `Some(list)` requires an `Origin`
    /// header exactly equal to one of the entries.
    pub allowed_origins: Option<Vec<String>>,
    /// Whether incoming messages are run through `validate_message_string`.
    pub validate_messages: bool,
    /// Idle time in seconds after which the periodic sweep removes a session.
    pub session_timeout_secs: u64,
    /// Whether requests must carry `Authorization: Bearer <token>`.
    pub require_auth: bool,
    /// Tokens accepted when `require_auth` is set.
    pub valid_tokens: Vec<String>,
}

impl Default for HttpConfig {
    /// Loopback-only defaults: `127.0.0.1:3000`, a 10 MiB message limit, CORS
    /// enabled, no origin allow-list, validation on, a 300 s session timeout and
    /// authentication off.
    fn default() -> Self {
        const DEFAULT_MAX_MESSAGE_BYTES: usize = 10 * 1024 * 1024;
        const DEFAULT_SESSION_TIMEOUT_SECS: u64 = 300;

        Self {
            host: String::from("127.0.0.1"),
            port: 3000,
            max_message_size: DEFAULT_MAX_MESSAGE_BYTES,
            session_timeout_secs: DEFAULT_SESSION_TIMEOUT_SECS,
            enable_cors: true,
            validate_messages: true,
            allowed_origins: None,
            require_auth: false,
            valid_tokens: Vec::new(),
        }
    }
}
69
/// Per-session record for one MCP client.
#[derive(Clone)]
struct SessionInfo {
    /// Session identifier; `ensure_session` stores the same string as the map key.
    // Not read by the transport code in this file outside of tests.
    #[allow(dead_code)]
    id: String,
    /// When the session record was created.
    #[allow(dead_code)]
    created_at: std::time::Instant,
    /// Last time a request touched this session; `cleanup_sessions` removes
    /// sessions idle for longer than `HttpConfig::session_timeout_secs`.
    last_activity: std::time::Instant,
    /// Broadcast channel feeding every SSE stream subscribed to this session.
    event_sender: broadcast::Sender<String>,
    /// Receiver created together with `event_sender` and never read.
    // NOTE(review): presumably kept so `event_sender.send` does not fail while no
    // SSE stream is subscribed (tokio's broadcast `send` errors with zero
    // receivers) — confirm this is the intent.
    #[allow(dead_code)]
    _keepalive_receiver: Arc<Mutex<broadcast::Receiver<String>>>,
}
83
/// State shared by all axum handlers of one running transport.
#[derive(Clone)]
struct HttpState {
    /// Request handler passed to `process_batch` for every incoming message.
    handler: Arc<RequestHandler>,
    /// Copy of the transport configuration taken when the server was started.
    config: HttpConfig,
    /// Active sessions keyed by session ID.
    sessions: Arc<RwLock<HashMap<String, SessionInfo>>>,
}
91
92#[derive(Debug, Deserialize)]
94struct SseQuery {
95 #[serde(rename = "sessionId")]
96 session_id: Option<String>,
97 #[serde(rename = "lastEventId")]
98 #[allow(dead_code)]
99 last_event_id: Option<String>,
100 #[serde(rename = "transportType")]
101 #[allow(dead_code)]
102 transport_type: Option<String>,
103 #[allow(dead_code)]
104 url: Option<String>,
105}
106
/// MCP transport over HTTP: JSON-RPC on `POST /messages`, Server-Sent Events on
/// `GET /sse`, and a liveness probe on `GET /health`.
pub struct HttpTransport {
    /// Listener address and request-handling options.
    config: HttpConfig,
    /// Shared handler state; set by `start`, cleared by `stop`.
    state: Option<HttpState>,
    /// Handle of the spawned server task; set by `start`, taken by `stop`.
    server_handle: Option<tokio::task::JoinHandle<()>>,
}
120
121impl HttpTransport {
122 pub fn new(port: u16) -> Self {
124 let config = HttpConfig {
125 port,
126 ..Default::default()
127 };
128
129 Self {
130 config,
131 state: None,
132 server_handle: None,
133 }
134 }
135
136 pub fn config(&self) -> &HttpConfig {
138 &self.config
139 }
140
141 pub fn is_initialized(&self) -> bool {
143 self.state.is_some()
144 }
145
146 pub fn is_running(&self) -> bool {
148 self.server_handle.is_some()
149 }
150
151 pub async fn broadcast_message(&self, message: &str) -> Result<(), TransportError> {
153 if let Some(ref state) = self.state {
154 let sessions = state.sessions.read().await;
155 for (session_id, session) in sessions.iter() {
156 if let Err(e) = session.event_sender.send(message.to_string()) {
157 debug!("Failed to send to session {}: {}", session_id, e);
158 }
159 }
160 Ok(())
161 } else {
162 Err(TransportError::Connection(
163 "Transport not started".to_string(),
164 ))
165 }
166 }
167
168 pub fn with_config(config: HttpConfig) -> Self {
170 Self {
171 config,
172 state: None,
173 server_handle: None,
174 }
175 }
176
177 async fn ensure_session(state: Arc<HttpState>, session_id: Option<String>) -> String {
179 if let Some(id) = session_id {
180 let sessions = state.sessions.read().await;
182 if sessions.contains_key(&id) {
183 return id;
184 }
185 drop(sessions);
187 let (tx, keepalive_rx) = broadcast::channel(1024);
188 let session_info = SessionInfo {
189 id: id.clone(),
190 created_at: std::time::Instant::now(),
191 last_activity: std::time::Instant::now(),
192 event_sender: tx,
193 _keepalive_receiver: Arc::new(Mutex::new(keepalive_rx)),
194 };
195 let mut sessions = state.sessions.write().await;
196 sessions.insert(id.clone(), session_info);
197 info!("Created session with provided ID: {}", id);
198 return id;
199 }
200
201 let session_id = Uuid::new_v4().to_string();
203 let (tx, keepalive_rx) = broadcast::channel(1024);
204 let session_info = SessionInfo {
205 id: session_id.clone(),
206 created_at: std::time::Instant::now(),
207 last_activity: std::time::Instant::now(),
208 event_sender: tx,
209 _keepalive_receiver: Arc::new(Mutex::new(keepalive_rx)),
210 };
211
212 {
213 let mut sessions = state.sessions.write().await;
214 sessions.insert(session_id.clone(), session_info);
215 }
216
217 debug!("Created new session: {}", session_id);
218 session_id
219 }
220
221 async fn update_session_activity(state: Arc<HttpState>, session_id: &str) {
223 let mut sessions = state.sessions.write().await;
224 if let Some(session) = sessions.get_mut(session_id) {
225 session.last_activity = std::time::Instant::now();
226 }
227 }
228
229 async fn cleanup_sessions(state: Arc<HttpState>) {
231 let timeout = Duration::from_secs(state.config.session_timeout_secs);
232 let now = std::time::Instant::now();
233
234 let mut sessions = state.sessions.write().await;
235 sessions.retain(|id, session| {
236 let expired = now.duration_since(session.last_activity) > timeout;
237 if expired {
238 debug!("Removing expired session: {}", id);
239 }
240 !expired
241 });
242 }
243
244 pub fn validate_origin(config: &HttpConfig, headers: &HeaderMap) -> Result<(), TransportError> {
246 if let Some(allowed_origins) = &config.allowed_origins {
247 if let Some(origin) = headers.get(ORIGIN) {
248 let origin_str = origin
249 .to_str()
250 .map_err(|_| TransportError::Protocol("Invalid Origin header".to_string()))?;
251
252 if !allowed_origins.contains(&origin_str.to_string()) {
253 return Err(TransportError::Protocol(format!(
254 "Origin not allowed: {origin_str}"
255 )));
256 }
257 } else {
258 return Err(TransportError::Protocol(
259 "Missing Origin header".to_string(),
260 ));
261 }
262 }
263
264 Ok(())
265 }
266
267 pub fn validate_auth(config: &HttpConfig, headers: &HeaderMap) -> Result<(), TransportError> {
269 if !config.require_auth {
270 return Ok(());
271 }
272
273 let auth_header = headers
274 .get(AUTHORIZATION)
275 .ok_or_else(|| TransportError::Protocol("Missing Authorization header".to_string()))?;
276
277 let auth_str = auth_header
278 .to_str()
279 .map_err(|_| TransportError::Protocol("Invalid Authorization header".to_string()))?;
280
281 if let Some(token) = auth_str.strip_prefix("Bearer ") {
282 if config.valid_tokens.contains(&token.to_string()) {
283 Ok(())
284 } else {
285 Err(TransportError::Protocol("Invalid bearer token".to_string()))
286 }
287 } else {
288 Err(TransportError::Protocol(
289 "Invalid Authorization format, expected Bearer token".to_string(),
290 ))
291 }
292 }
293}
294
/// Query parameters accepted by `POST /messages`.
#[derive(Debug, Clone, Deserialize)]
struct PostQuery {
    /// Target session; accepted as either `session_id` or `sessionId`.
    #[serde(alias = "sessionId")]
    session_id: Option<String>,
}
301
302async fn handle_post(
304 State(state): State<Arc<HttpState>>,
305 Query(query): Query<PostQuery>,
306 headers: HeaderMap,
307 body: String,
308) -> Result<AxumResponse<String>, StatusCode> {
309 info!("Received POST request with session query: {:?}", query);
310 debug!("Raw request body: {}", body);
311
312 let request_value: serde_json::Value = match serde_json::from_str(&body) {
314 Ok(v) => v,
315 Err(e) => {
316 warn!("Failed to parse JSON: {}", e);
317 return Err(StatusCode::BAD_REQUEST);
318 }
319 };
320
321 let message = if let Some(wrapped_message) = request_value.get("message") {
323 wrapped_message.clone()
325 } else if request_value.get("jsonrpc").is_some() {
326 request_value
328 } else {
329 warn!("Invalid request format - no 'message' field and no 'jsonrpc' field");
330 return Err(StatusCode::BAD_REQUEST);
331 };
332
333 info!("Request message: {:?}", message);
334
335 if let Err(e) = HttpTransport::validate_origin(&state.config, &headers) {
337 warn!("Origin validation failed: {}", e);
338 return Err(StatusCode::FORBIDDEN);
339 }
340
341 if let Err(e) = HttpTransport::validate_auth(&state.config, &headers) {
343 warn!("Authentication failed: {}", e);
344 return Err(StatusCode::UNAUTHORIZED);
345 }
346
347 let session_id_from_request = query.session_id.or_else(|| {
349 headers
350 .get("Mcp-Session-Id")
351 .and_then(|v| v.to_str().ok())
352 .map(|s| s.to_string())
353 });
354
355 let session_id = HttpTransport::ensure_session(state.clone(), session_id_from_request).await;
357
358 let message_json = serde_json::to_string(&message).map_err(|_| StatusCode::BAD_REQUEST)?;
360
361 if state.config.validate_messages {
362 if let Err(e) = validate_message_string(&message_json, Some(state.config.max_message_size))
363 {
364 warn!("Message validation failed: {}", e);
365 return Err(StatusCode::BAD_REQUEST);
366 }
367 }
368
369 let message = JsonRpcMessage::parse(&message_json).map_err(|_| StatusCode::BAD_REQUEST)?;
371
372 if let Err(e) = message.validate() {
374 warn!("JSON-RPC validation failed: {}", e);
375 return Err(StatusCode::BAD_REQUEST);
376 }
377
378 {
380 HttpTransport::update_session_activity(state.clone(), &session_id).await;
381 }
382
383 match process_batch(message, &state.handler).await {
385 Ok(Some(response_message)) => {
386 let response_json = response_message
387 .to_string()
388 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
389
390 let accept_header = headers
392 .get("accept")
393 .and_then(|v| v.to_str().ok())
394 .unwrap_or("");
395
396 debug!("Received Accept header: '{}'", accept_header);
401
402 let wants_json_response = if accept_header.contains("application/json") {
403 true
406 } else {
407 false
409 };
410
411 debug!(
412 "Transport mode selected: {}",
413 if wants_json_response {
414 "streamable-http"
415 } else {
416 "sse"
417 }
418 );
419
420 if wants_json_response {
421 info!(
423 "Using Streamable HTTP transport, returning response directly for session: {}, Accept: {}",
424 session_id, accept_header
425 );
426 debug!("Direct response: {}", response_json);
427 Ok(AxumResponse::builder()
428 .status(StatusCode::OK)
429 .header("Content-Type", "application/json")
430 .header("Mcp-Session-Id", session_id)
431 .body(response_json)
432 .unwrap())
433 } else {
434 info!(
436 "Using legacy HTTP+SSE transport for session: {}, Accept: {}",
437 session_id, accept_header
438 );
439 debug!("Response to send through SSE: {}", response_json);
440
441 let sessions = state.sessions.read().await;
442 info!("Active sessions: {}", sessions.len());
443
444 if let Some(session) = sessions.get(&session_id) {
445 info!("Found session {}, sending response", session_id);
446 match session.event_sender.send(response_json.clone()) {
447 Ok(num_receivers) => {
448 info!(
449 "Response sent successfully to {} receivers on session: {}",
450 num_receivers, session_id
451 );
452 }
453 Err(e) => {
454 warn!("Failed to send response through SSE: {}", e);
455 }
456 }
457 } else {
458 warn!(
459 "Session {} not found for response, trying any active session",
460 session_id
461 );
462 let mut sent = false;
464 for (sid, session) in sessions.iter() {
465 match session.event_sender.send(response_json.clone()) {
466 Ok(num_receivers) => {
467 info!(
468 "Response sent successfully to {} receivers on fallback session: {}",
469 num_receivers, sid
470 );
471 sent = true;
472 break;
473 }
474 Err(e) => {
475 debug!("Failed to send to session {}: {}", sid, e);
476 }
477 }
478 }
479 if !sent {
480 warn!("No active sessions available to send response");
481 }
482 }
483
484 Ok(AxumResponse::builder()
486 .status(StatusCode::NO_CONTENT)
487 .header("Mcp-Session-Id", session_id)
488 .body("".to_string())
489 .unwrap())
490 }
491 }
492 Ok(None) => {
493 Ok(AxumResponse::builder()
495 .status(StatusCode::NO_CONTENT)
496 .body("".to_string())
497 .unwrap())
498 }
499 Err(e) => {
500 error!("Failed to process message: {}", e);
501
502 let error_response = pulseengine_mcp_protocol::Response {
504 jsonrpc: "2.0".to_string(),
505 id: None,
506 result: None,
507 error: Some(pulseengine_mcp_protocol::Error::internal_error(
508 e.to_string(),
509 )),
510 };
511
512 if let Ok(error_json) = serde_json::to_string(&error_response) {
513 let accept_header = headers
515 .get("accept")
516 .and_then(|v| v.to_str().ok())
517 .unwrap_or("");
518 let wants_json_response = accept_header.contains("application/json");
519
520 if wants_json_response {
521 debug!(
523 "Using Streamable HTTP transport, returning error directly: {}",
524 error_json
525 );
526 Ok(AxumResponse::builder()
527 .status(StatusCode::OK)
528 .header("Content-Type", "application/json")
529 .header("Mcp-Session-Id", session_id)
530 .body(error_json)
531 .unwrap())
532 } else {
533 debug!(
535 "Using legacy HTTP+SSE transport, sending error through SSE: {}",
536 error_json
537 );
538 let sessions = state.sessions.read().await;
539 if let Some(session) = sessions.get(&session_id) {
540 if let Err(e) = session.event_sender.send(error_json.clone()) {
541 warn!("Failed to send error through SSE: {}", e);
542 } else {
543 debug!(
544 "Error response sent successfully to session: {}",
545 session_id
546 );
547 }
548 } else {
549 warn!(
550 "Session {} not found for error response, trying any active session",
551 session_id
552 );
553 let mut sent = false;
555 for (sid, session) in sessions.iter() {
556 if session.event_sender.send(error_json.clone()).is_ok() {
557 debug!(
558 "Error response sent successfully to fallback session: {}",
559 sid
560 );
561 sent = true;
562 break;
563 }
564 }
565 if !sent {
566 warn!("No active sessions available to send error response");
567 }
568 }
569
570 Ok(AxumResponse::builder()
572 .status(StatusCode::NO_CONTENT)
573 .body("".to_string())
574 .unwrap())
575 }
576 } else {
577 Err(StatusCode::INTERNAL_SERVER_ERROR)
579 }
580 }
581 }
582}
583
584async fn handle_sse(
586 uri: axum::http::Uri,
587 State(state): State<Arc<HttpState>>,
588 headers: HeaderMap,
589 Query(query): Query<SseQuery>,
590) -> Result<axum::response::Response, StatusCode> {
591 info!(
592 "Received SSE request - URI: {}, query string: {:?}, parsed query: {:?}",
593 uri,
594 uri.query(),
595 query
596 );
597 info!("Headers: {:?}", headers);
598
599 if let Err(e) = HttpTransport::validate_origin(&state.config, &headers) {
601 warn!("Origin validation failed: {}", e);
602 return Err(StatusCode::FORBIDDEN);
603 }
604
605 if let Err(e) = HttpTransport::validate_auth(&state.config, &headers) {
607 warn!("Authentication failed: {}", e);
608 return Err(StatusCode::UNAUTHORIZED);
609 }
610
611 let session_id = HttpTransport::ensure_session(state.clone(), query.session_id).await;
613
614 info!("Creating MCP-compliant SSE stream with endpoint event");
616
617 let receiver = {
619 let sessions = state.sessions.read().await;
620 sessions
621 .get(&session_id)
622 .map(|session| session.event_sender.subscribe())
623 .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?
624 };
625
626 info!("Starting SSE stream for session: {}", session_id);
627
628 let session_id_for_header = session_id.clone();
630
631 let stream = async_stream::stream! {
633 let mut event_counter = 0u64;
634
635 let endpoint_url = format!("/messages?sessionId={session_id}");
638 info!("Sending 'endpoint' event for session: {} with URL: {}", session_id, endpoint_url);
639 event_counter += 1;
640 yield Ok::<_, axum::Error>(Event::default()
641 .id(event_counter.to_string())
642 .event("endpoint")
643 .data(endpoint_url));
644
645 let mut receiver = receiver;
647 loop {
648 tokio::select! {
649 Ok(data) = receiver.recv() => {
650 event_counter += 1;
651 yield Ok::<_, axum::Error>(Event::default()
652 .id(event_counter.to_string())
653 .event("message")
654 .data(data));
655 }
656 _ = tokio::time::sleep(Duration::from_secs(30)) => {
657 event_counter += 1;
659 yield Ok::<_, axum::Error>(Event::default()
660 .id(event_counter.to_string())
661 .event("ping")
662 .data(serde_json::json!({
663 "type": "ping",
664 "timestamp": chrono::Utc::now().to_rfc3339()
665 }).to_string()));
666 }
667 }
668 }
669 };
670
671 let sse = Sse::new(stream).keep_alive(
673 KeepAlive::new()
674 .interval(Duration::from_secs(15))
675 .text("keep-alive"),
676 );
677
678 let mut response = sse.into_response();
680 response.headers_mut().insert(
681 axum::http::header::CACHE_CONTROL,
682 "no-cache".parse().unwrap(),
683 );
684 response.headers_mut().insert(
685 axum::http::header::CONNECTION,
686 "keep-alive".parse().unwrap(),
687 );
688 response
689 .headers_mut()
690 .insert("X-Accel-Buffering", "no".parse().unwrap());
691
692 response
694 .headers_mut()
695 .insert("Mcp-Session-Id", session_id_for_header.parse().unwrap());
696
697 Ok(response)
698}
699
/// Liveness probe served on `GET /health`.
///
/// Always answers with the plain-text body `OK`; no transport or session
/// state is consulted.
async fn handle_health() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}
704
705#[async_trait]
706impl Transport for HttpTransport {
707 async fn start(&mut self, handler: RequestHandler) -> Result<(), TransportError> {
708 info!(
709 "Starting HTTP transport on {}:{}",
710 self.config.host, self.config.port
711 );
712
713 let state = Arc::new(HttpState {
714 handler: Arc::new(handler),
715 config: self.config.clone(),
716 sessions: Arc::new(RwLock::new(HashMap::new())),
717 });
718
719 let cors = CorsLayer::very_permissive().expose_headers(vec![
721 axum::http::header::HeaderName::from_static("mcp-session-id"),
722 axum::http::header::HeaderName::from_static("content-type"),
723 ]);
724
725 let app = Router::new()
727 .route("/messages", post(handle_post))
728 .route("/sse", get(handle_sse))
729 .route("/health", get(handle_health))
730 .layer(ServiceBuilder::new().layer(cors))
731 .with_state(state.clone());
732
733 let cleanup_state = state.clone();
735 tokio::spawn(async move {
736 let mut interval = tokio::time::interval(Duration::from_secs(60));
737 loop {
738 interval.tick().await;
739 HttpTransport::cleanup_sessions(cleanup_state.clone()).await;
740 }
741 });
742
743 let addr: SocketAddr = format!("{}:{}", self.config.host, self.config.port)
745 .parse()
746 .map_err(|e| TransportError::Config(format!("Invalid address: {e}")))?;
747
748 let listener = tokio::net::TcpListener::bind(addr)
749 .await
750 .map_err(|e| TransportError::Connection(format!("Failed to bind to {addr}: {e}")))?;
751
752 info!("HTTP transport listening on {}", addr);
753 info!("Endpoints:");
754 info!(" POST http://{}/messages - MCP messages", addr);
755 info!(" GET http://{}/sse - Server-Sent Events", addr);
756 info!(" GET http://{}/health - Health check", addr);
757
758 let server_handle = tokio::spawn(async move {
759 if let Err(e) = axum::serve(listener, app).await {
760 error!("HTTP server error: {}", e);
761 }
762 });
763
764 self.state = Some(HttpState {
765 handler: state.handler.clone(),
766 config: state.config.clone(),
767 sessions: state.sessions.clone(),
768 });
769 self.server_handle = Some(server_handle);
770
771 Ok(())
772 }
773
774 async fn stop(&mut self) -> Result<(), TransportError> {
775 info!("Stopping HTTP transport");
776
777 if let Some(handle) = self.server_handle.take() {
778 handle.abort();
779 }
780
781 self.state = None;
782 Ok(())
783 }
784
785 async fn health_check(&self) -> Result<(), TransportError> {
786 if self.state.is_some() {
787 Ok(())
788 } else {
789 Err(TransportError::Connection(
790 "HTTP transport not running".to_string(),
791 ))
792 }
793 }
794}
795
796#[cfg(test)]
797mod tests {
798 use super::*;
799 use axum::extract::{Query, State};
800 use axum::http::{HeaderMap, HeaderValue, StatusCode};
801 use pulseengine_mcp_protocol::{Error as McpError, Response};
802 use serde_json::json;
803 use std::sync::Arc;
804 use tokio::sync::RwLock;
805
806 fn mock_handler(
808 request: pulseengine_mcp_protocol::Request,
809 ) -> std::pin::Pin<
810 Box<dyn std::future::Future<Output = pulseengine_mcp_protocol::Response> + Send>,
811 > {
812 Box::pin(async move {
813 Response {
814 jsonrpc: "2.0".to_string(),
815 id: request.id,
816 result: Some(json!({"echo": request.method})),
817 error: None,
818 }
819 })
820 }
821
822 fn mock_error_handler(
824 request: pulseengine_mcp_protocol::Request,
825 ) -> std::pin::Pin<
826 Box<dyn std::future::Future<Output = pulseengine_mcp_protocol::Response> + Send>,
827 > {
828 Box::pin(async move {
829 Response {
830 jsonrpc: "2.0".to_string(),
831 id: request.id,
832 result: None,
833 error: Some(McpError::method_not_found(format!(
834 "Method '{}' not supported",
835 request.method
836 ))),
837 }
838 })
839 }
840
841 fn mock_notification_handler(
843 _request: pulseengine_mcp_protocol::Request,
844 ) -> std::pin::Pin<
845 Box<dyn std::future::Future<Output = pulseengine_mcp_protocol::Response> + Send>,
846 > {
847 Box::pin(async move {
848 Response {
849 jsonrpc: "2.0".to_string(),
850 id: None,
851 result: None,
852 error: None,
853 }
854 })
855 }
856
857 fn create_test_state() -> Arc<HttpState> {
858 let config = HttpConfig::default();
859 Arc::new(HttpState {
860 handler: Arc::new(Box::new(mock_handler)),
861 config,
862 sessions: Arc::new(RwLock::new(HashMap::new())),
863 })
864 }
865
866 fn create_test_headers() -> HeaderMap {
867 let mut headers = HeaderMap::new();
868 headers.insert("Content-Type", "application/json".parse().unwrap());
869 headers
870 }
871
872 #[test]
875 fn test_http_config_default() {
876 let config = HttpConfig::default();
877 assert_eq!(config.port, 3000);
878 assert_eq!(config.host, "127.0.0.1");
879 assert_eq!(config.max_message_size, 10 * 1024 * 1024);
880 assert!(config.enable_cors);
881 assert!(config.allowed_origins.is_none());
882 assert!(config.validate_messages);
883 assert_eq!(config.session_timeout_secs, 300);
884 assert!(!config.require_auth);
885 assert!(config.valid_tokens.is_empty());
886 }
887
888 #[test]
889 fn test_http_config_custom() {
890 let config = HttpConfig {
891 port: 8080,
892 host: "0.0.0.0".to_string(),
893 max_message_size: 1024,
894 enable_cors: false,
895 allowed_origins: Some(vec!["http://localhost:3000".to_string()]),
896 validate_messages: true,
897 session_timeout_secs: 600,
898 require_auth: true,
899 valid_tokens: vec!["test-token".to_string()],
900 };
901
902 let transport = HttpTransport::with_config(config.clone());
903 assert_eq!(transport.config.port, 8080);
904 assert_eq!(transport.config.host, "0.0.0.0");
905 assert_eq!(transport.config.max_message_size, 1024);
906 assert!(!transport.config.enable_cors);
907 assert_eq!(
908 transport.config.allowed_origins,
909 Some(vec!["http://localhost:3000".to_string()])
910 );
911 assert!(transport.config.validate_messages);
912 assert_eq!(transport.config.session_timeout_secs, 600);
913 assert!(transport.config.require_auth);
914 assert_eq!(transport.config.valid_tokens, vec!["test-token"]);
915 }
916
917 #[test]
920 fn test_http_transport_new() {
921 let transport = HttpTransport::new(8080);
922 assert_eq!(transport.config.port, 8080);
923 assert_eq!(transport.config.host, "127.0.0.1");
924 assert!(!transport.is_initialized());
925 assert!(!transport.is_running());
926 }
927
928 #[test]
929 fn test_http_transport_with_config() {
930 let config = HttpConfig {
931 port: 9000,
932 host: "192.168.1.1".to_string(),
933 ..Default::default()
934 };
935 let transport = HttpTransport::with_config(config);
936 assert_eq!(transport.config.port, 9000);
937 assert_eq!(transport.config.host, "192.168.1.1");
938 assert!(!transport.is_initialized());
939 assert!(!transport.is_running());
940 }
941
942 #[test]
943 fn test_http_transport_config_access() {
944 let transport = HttpTransport::new(4000);
945 let config = transport.config();
946 assert_eq!(config.port, 4000);
947 }
948
949 #[test]
952 fn test_session_info_creation() {
953 let (tx, rx) = broadcast::channel(1024);
954 let session = SessionInfo {
955 id: "test-session".to_string(),
956 created_at: std::time::Instant::now(),
957 last_activity: std::time::Instant::now(),
958 event_sender: tx,
959 _keepalive_receiver: Arc::new(Mutex::new(rx)),
960 };
961
962 assert_eq!(session.id, "test-session");
963 }
964
965 #[test]
968 fn test_query_deserialization() {
969 let query = PostQuery {
971 session_id: Some("test123".to_string()),
972 };
973 assert_eq!(query.session_id, Some("test123".to_string()));
974 }
975
976 #[test]
979 fn test_validate_origin_no_restrictions() {
980 let config = HttpConfig {
981 allowed_origins: None,
982 ..Default::default()
983 };
984
985 let headers = HeaderMap::new();
986 assert!(HttpTransport::validate_origin(&config, &headers).is_ok());
987
988 let mut headers = HeaderMap::new();
989 headers.insert(ORIGIN, "http://any-origin.com".parse().unwrap());
990 assert!(HttpTransport::validate_origin(&config, &headers).is_ok());
991 }
992
993 #[test]
994 fn test_validate_origin_with_allowed_origins() {
995 let config = HttpConfig {
996 allowed_origins: Some(vec![
997 "http://localhost:3000".to_string(),
998 "https://example.com".to_string(),
999 ]),
1000 ..Default::default()
1001 };
1002
1003 let mut headers = HeaderMap::new();
1004 headers.insert(ORIGIN, "http://localhost:3000".parse().unwrap());
1005 assert!(HttpTransport::validate_origin(&config, &headers).is_ok());
1006
1007 headers.insert(ORIGIN, "https://example.com".parse().unwrap());
1008 assert!(HttpTransport::validate_origin(&config, &headers).is_ok());
1009
1010 headers.insert(ORIGIN, "http://evil.com".parse().unwrap());
1011 assert!(HttpTransport::validate_origin(&config, &headers).is_err());
1012 }
1013
1014 #[test]
1015 fn test_validate_origin_missing_header() {
1016 let config = HttpConfig {
1017 allowed_origins: Some(vec!["http://localhost:3000".to_string()]),
1018 ..Default::default()
1019 };
1020
1021 let headers = HeaderMap::new();
1022 let result = HttpTransport::validate_origin(&config, &headers);
1023 assert!(result.is_err());
1024 assert!(
1025 result
1026 .unwrap_err()
1027 .to_string()
1028 .contains("Missing Origin header")
1029 );
1030 }
1031
1032 #[test]
1033 fn test_validate_origin_invalid_header() {
1034 let config = HttpConfig {
1035 allowed_origins: Some(vec!["http://localhost:3000".to_string()]),
1036 ..Default::default()
1037 };
1038
1039 let mut headers = HeaderMap::new();
1040 headers.insert(ORIGIN, HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap());
1042 let result = HttpTransport::validate_origin(&config, &headers);
1043 assert!(result.is_err());
1044 assert!(
1045 result
1046 .unwrap_err()
1047 .to_string()
1048 .contains("Invalid Origin header")
1049 );
1050 }
1051
1052 #[test]
1055 fn test_validate_auth_disabled() {
1056 let config = HttpConfig {
1057 require_auth: false,
1058 ..Default::default()
1059 };
1060
1061 let headers = HeaderMap::new();
1062 assert!(HttpTransport::validate_auth(&config, &headers).is_ok());
1063 }
1064
1065 #[test]
1066 fn test_validate_auth_missing_header() {
1067 let config = HttpConfig {
1068 require_auth: true,
1069 valid_tokens: vec!["valid-token".to_string()],
1070 ..Default::default()
1071 };
1072
1073 let headers = HeaderMap::new();
1074 let result = HttpTransport::validate_auth(&config, &headers);
1075 assert!(result.is_err());
1076 assert!(
1077 result
1078 .unwrap_err()
1079 .to_string()
1080 .contains("Missing Authorization header")
1081 );
1082 }
1083
1084 #[test]
1085 fn test_validate_auth_invalid_header() {
1086 let config = HttpConfig {
1087 require_auth: true,
1088 valid_tokens: vec!["valid-token".to_string()],
1089 ..Default::default()
1090 };
1091
1092 let mut headers = HeaderMap::new();
1093 headers.insert(
1094 AUTHORIZATION,
1095 HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap(),
1096 );
1097 let result = HttpTransport::validate_auth(&config, &headers);
1098 assert!(result.is_err());
1099 assert!(
1100 result
1101 .unwrap_err()
1102 .to_string()
1103 .contains("Invalid Authorization header")
1104 );
1105 }
1106
1107 #[test]
1108 fn test_validate_auth_valid_bearer_token() {
1109 let config = HttpConfig {
1110 require_auth: true,
1111 valid_tokens: vec!["valid-token".to_string(), "another-token".to_string()],
1112 ..Default::default()
1113 };
1114
1115 let mut headers = HeaderMap::new();
1116 headers.insert(AUTHORIZATION, "Bearer valid-token".parse().unwrap());
1117 assert!(HttpTransport::validate_auth(&config, &headers).is_ok());
1118
1119 headers.insert(AUTHORIZATION, "Bearer another-token".parse().unwrap());
1120 assert!(HttpTransport::validate_auth(&config, &headers).is_ok());
1121 }
1122
1123 #[test]
1124 fn test_validate_auth_invalid_bearer_token() {
1125 let config = HttpConfig {
1126 require_auth: true,
1127 valid_tokens: vec!["valid-token".to_string()],
1128 ..Default::default()
1129 };
1130
1131 let mut headers = HeaderMap::new();
1132 headers.insert(AUTHORIZATION, "Bearer invalid-token".parse().unwrap());
1133 let result = HttpTransport::validate_auth(&config, &headers);
1134 assert!(result.is_err());
1135 assert!(
1136 result
1137 .unwrap_err()
1138 .to_string()
1139 .contains("Invalid bearer token")
1140 );
1141 }
1142
1143 #[test]
1144 fn test_validate_auth_invalid_format() {
1145 let config = HttpConfig {
1146 require_auth: true,
1147 valid_tokens: vec!["valid-token".to_string()],
1148 ..Default::default()
1149 };
1150
1151 let mut headers = HeaderMap::new();
1152 headers.insert(AUTHORIZATION, "Basic dXNlcjpwYXNz".parse().unwrap());
1153 let result = HttpTransport::validate_auth(&config, &headers);
1154 assert!(result.is_err());
1155 assert!(
1156 result
1157 .unwrap_err()
1158 .to_string()
1159 .contains("Invalid Authorization format")
1160 );
1161
1162 headers.insert(AUTHORIZATION, "just-a-token".parse().unwrap());
1163 let result = HttpTransport::validate_auth(&config, &headers);
1164 assert!(result.is_err());
1165 assert!(
1166 result
1167 .unwrap_err()
1168 .to_string()
1169 .contains("Invalid Authorization format")
1170 );
1171 }
1172
1173 #[tokio::test]
1176 async fn test_ensure_session_new() {
1177 let state = create_test_state();
1178
1179 let session_id = HttpTransport::ensure_session(state.clone(), None).await;
1181 assert!(!session_id.is_empty());
1182
1183 let sessions = state.sessions.read().await;
1185 assert!(sessions.contains_key(&session_id));
1186 assert_eq!(sessions.len(), 1);
1187 }
1188
1189 #[tokio::test]
1190 async fn test_ensure_session_with_provided_id() {
1191 let state = create_test_state();
1192
1193 let provided_id = "my-session-123".to_string();
1195 let session_id =
1196 HttpTransport::ensure_session(state.clone(), Some(provided_id.clone())).await;
1197 assert_eq!(session_id, provided_id);
1198
1199 let sessions = state.sessions.read().await;
1201 assert!(sessions.contains_key(&session_id));
1202 assert_eq!(sessions.len(), 1);
1203 }
1204
1205 #[tokio::test]
1206 async fn test_ensure_session_existing() {
1207 let state = create_test_state();
1208
1209 let session_id = "existing-session".to_string();
1211 let result1 = HttpTransport::ensure_session(state.clone(), Some(session_id.clone())).await;
1212 assert_eq!(result1, session_id);
1213
1214 let result2 = HttpTransport::ensure_session(state.clone(), Some(session_id.clone())).await;
1216 assert_eq!(result2, session_id);
1217
1218 let sessions = state.sessions.read().await;
1220 assert_eq!(sessions.len(), 1);
1221 }
1222
1223 #[tokio::test]
1224 async fn test_update_session_activity() {
1225 let state = create_test_state();
1226
1227 let session_id = HttpTransport::ensure_session(state.clone(), None).await;
1229
1230 let initial_activity = {
1232 let sessions = state.sessions.read().await;
1233 sessions.get(&session_id).unwrap().last_activity
1234 };
1235
1236 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
1238 HttpTransport::update_session_activity(state.clone(), &session_id).await;
1239
1240 let updated_activity = {
1242 let sessions = state.sessions.read().await;
1243 sessions.get(&session_id).unwrap().last_activity
1244 };
1245
1246 assert!(updated_activity > initial_activity);
1247 }
1248
1249 #[tokio::test]
1250 async fn test_update_session_activity_nonexistent() {
1251 let state = create_test_state();
1252
1253 HttpTransport::update_session_activity(state.clone(), "nonexistent").await;
1255
1256 let sessions = state.sessions.read().await;
1258 assert_eq!(sessions.len(), 0);
1259 }
1260
1261 #[tokio::test]
1262 async fn test_cleanup_sessions() {
1263 let config = HttpConfig {
1264 session_timeout_secs: 1, ..Default::default()
1266 };
1267
1268 let state = Arc::new(HttpState {
1269 handler: Arc::new(Box::new(mock_handler)),
1270 config,
1271 sessions: Arc::new(RwLock::new(HashMap::new())),
1272 });
1273
1274 let session_id = HttpTransport::ensure_session(state.clone(), None).await;
1276
1277 {
1279 let mut sessions = state.sessions.write().await;
1280 if let Some(session) = sessions.get_mut(&session_id) {
1281 session.last_activity =
1282 std::time::Instant::now() - std::time::Duration::from_secs(2);
1283 }
1284 }
1285
1286 HttpTransport::cleanup_sessions(state.clone()).await;
1288
1289 let sessions = state.sessions.read().await;
1291 assert_eq!(sessions.len(), 0);
1292 }
1293
1294 #[tokio::test]
1295 async fn test_cleanup_sessions_keeps_active() {
1296 let state = create_test_state();
1297
1298 let session_id1 = HttpTransport::ensure_session(state.clone(), None).await;
1300 let session_id2 = HttpTransport::ensure_session(state.clone(), None).await;
1301
1302 HttpTransport::cleanup_sessions(state.clone()).await;
1304
1305 let sessions = state.sessions.read().await;
1307 assert_eq!(sessions.len(), 2);
1308 assert!(sessions.contains_key(&session_id1));
1309 assert!(sessions.contains_key(&session_id2));
1310 }
1311
1312 #[tokio::test]
1315 async fn test_broadcast_message_not_initialized() {
1316 let transport = HttpTransport::new(3000);
1317 let result = transport.broadcast_message("test message").await;
1318 assert!(result.is_err());
1319 assert!(
1320 result
1321 .unwrap_err()
1322 .to_string()
1323 .contains("Transport not started")
1324 );
1325 }
1326
1327 #[tokio::test]
1328 async fn test_broadcast_message_with_sessions() {
1329 let state = create_test_state();
1330
1331 let transport = HttpTransport {
1333 config: HttpConfig::default(),
1334 state: Some((*state).clone()),
1335 server_handle: None,
1336 };
1337
1338 let session_id = HttpTransport::ensure_session(state.clone(), None).await;
1340
1341 let mut receiver = {
1343 let sessions = state.sessions.read().await;
1344 sessions.get(&session_id).unwrap().event_sender.subscribe()
1345 };
1346
1347 let result = transport.broadcast_message("test broadcast").await;
1349 assert!(result.is_ok());
1350
1351 let received = receiver.recv().await.unwrap();
1353 assert_eq!(received, "test broadcast");
1354 }
1355
1356 #[tokio::test]
1357 async fn test_broadcast_message_no_sessions() {
1358 let state = create_test_state();
1359
1360 let transport = HttpTransport {
1362 config: HttpConfig::default(),
1363 state: Some((*state).clone()),
1364 server_handle: None,
1365 };
1366
1367 let result = transport.broadcast_message("test broadcast").await;
1369 assert!(result.is_ok());
1370 }
1371
1372 #[tokio::test]
1375 async fn test_handle_post_valid_wrapped_message() {
1376 let state = create_test_state();
1377 let query = PostQuery {
1378 session_id: Some("test-session".to_string()),
1379 };
1380 let headers = create_test_headers();
1381 let body = json!({
1382 "message": {
1383 "jsonrpc": "2.0",
1384 "method": "ping",
1385 "params": {},
1386 "id": 1
1387 }
1388 })
1389 .to_string();
1390
1391 let result = handle_post(State(state), Query(query), headers, body).await;
1392 assert!(result.is_ok());
1393 }
1394
1395 #[tokio::test]
1396 async fn test_handle_post_valid_direct_message() {
1397 let state = create_test_state();
1398 let query = PostQuery { session_id: None };
1399 let headers = create_test_headers();
1400 let body = json!({
1401 "jsonrpc": "2.0",
1402 "method": "ping",
1403 "params": {},
1404 "id": 1
1405 })
1406 .to_string();
1407
1408 let result = handle_post(State(state), Query(query), headers, body).await;
1409 assert!(result.is_ok());
1410 }
1411
1412 #[tokio::test]
1413 async fn test_handle_post_invalid_json() {
1414 let state = create_test_state();
1415 let query = PostQuery { session_id: None };
1416 let headers = create_test_headers();
1417 let body = "invalid json".to_string();
1418
1419 let result = handle_post(State(state), Query(query), headers, body).await;
1420 assert!(result.is_err());
1421 assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST);
1422 }
1423
1424 #[tokio::test]
1425 async fn test_handle_post_invalid_format() {
1426 let state = create_test_state();
1427 let query = PostQuery { session_id: None };
1428 let headers = create_test_headers();
1429 let body = json!({
1430 "not_jsonrpc": "data"
1431 })
1432 .to_string();
1433
1434 let result = handle_post(State(state), Query(query), headers, body).await;
1435 assert!(result.is_err());
1436 assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST);
1437 }
1438
1439 #[tokio::test]
1440 async fn test_handle_post_origin_validation_failure() {
1441 let config = HttpConfig {
1442 allowed_origins: Some(vec!["http://allowed.com".to_string()]),
1443 ..Default::default()
1444 };
1445
1446 let state = Arc::new(HttpState {
1447 handler: Arc::new(Box::new(mock_handler)),
1448 config,
1449 sessions: Arc::new(RwLock::new(HashMap::new())),
1450 });
1451
1452 let query = PostQuery { session_id: None };
1453 let mut headers = create_test_headers();
1454 headers.insert(ORIGIN, "http://evil.com".parse().unwrap());
1455 let body = json!({
1456 "jsonrpc": "2.0",
1457 "method": "ping",
1458 "params": {},
1459 "id": 1
1460 })
1461 .to_string();
1462
1463 let result = handle_post(State(state), Query(query), headers, body).await;
1464 assert!(result.is_err());
1465 assert_eq!(result.unwrap_err(), StatusCode::FORBIDDEN);
1466 }
1467
1468 #[tokio::test]
1469 async fn test_handle_post_auth_failure() {
1470 let config = HttpConfig {
1471 require_auth: true,
1472 valid_tokens: vec!["valid-token".to_string()],
1473 ..Default::default()
1474 };
1475
1476 let state = Arc::new(HttpState {
1477 handler: Arc::new(Box::new(mock_handler)),
1478 config,
1479 sessions: Arc::new(RwLock::new(HashMap::new())),
1480 });
1481
1482 let query = PostQuery { session_id: None };
1483 let mut headers = create_test_headers();
1484 headers.insert(AUTHORIZATION, "Bearer invalid-token".parse().unwrap());
1485 let body = json!({
1486 "jsonrpc": "2.0",
1487 "method": "ping",
1488 "params": {},
1489 "id": 1
1490 })
1491 .to_string();
1492
1493 let result = handle_post(State(state), Query(query), headers, body).await;
1494 assert!(result.is_err());
1495 assert_eq!(result.unwrap_err(), StatusCode::UNAUTHORIZED);
1496 }
1497
1498 #[tokio::test]
1499 async fn test_handle_post_message_validation_failure() {
1500 let config = HttpConfig {
1501 validate_messages: true,
1502 max_message_size: 10, ..Default::default()
1504 };
1505
1506 let state = Arc::new(HttpState {
1507 handler: Arc::new(Box::new(mock_handler)),
1508 config,
1509 sessions: Arc::new(RwLock::new(HashMap::new())),
1510 });
1511
1512 let query = PostQuery { session_id: None };
1513 let headers = create_test_headers();
1514 let body = json!({
1515 "jsonrpc": "2.0",
1516 "method": "this_is_a_very_long_method_name_that_exceeds_the_size_limit",
1517 "params": {},
1518 "id": 1
1519 })
1520 .to_string();
1521
1522 let result = handle_post(State(state), Query(query), headers, body).await;
1523 assert!(result.is_err());
1524 assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST);
1525 }
1526
1527 #[tokio::test]
1528 async fn test_handle_post_streamable_http_mode() {
1529 let state = create_test_state();
1530 let query = PostQuery { session_id: None };
1531 let mut headers = create_test_headers();
1532 headers.insert(
1533 "accept",
1534 "text/event-stream, application/json".parse().unwrap(),
1535 );
1536 let body = json!({
1537 "jsonrpc": "2.0",
1538 "method": "ping",
1539 "params": {},
1540 "id": 1
1541 })
1542 .to_string();
1543
1544 let result = handle_post(State(state), Query(query), headers, body).await;
1545 assert!(result.is_ok());
1546
1547 let response = result.unwrap();
1548 assert_eq!(response.status(), StatusCode::OK);
1549 assert!(
1550 response
1551 .headers()
1552 .get("Content-Type")
1553 .unwrap()
1554 .to_str()
1555 .unwrap()
1556 .contains("application/json")
1557 );
1558 assert!(response.headers().contains_key("Mcp-Session-Id"));
1559 }
1560
1561 #[tokio::test]
1562 async fn test_handle_post_sse_mode() {
1563 let state = create_test_state();
1564 let query = PostQuery { session_id: None };
1565 let mut headers = create_test_headers();
1566 headers.insert("accept", "text/event-stream".parse().unwrap());
1567 let body = json!({
1568 "jsonrpc": "2.0",
1569 "method": "ping",
1570 "params": {},
1571 "id": 1
1572 })
1573 .to_string();
1574
1575 let result = handle_post(State(state), Query(query), headers, body).await;
1576 assert!(result.is_ok());
1577
1578 let response = result.unwrap();
1579 assert_eq!(response.status(), StatusCode::NO_CONTENT);
1580 assert!(response.headers().contains_key("Mcp-Session-Id"));
1581 }
1582
1583 #[tokio::test]
1584 async fn test_handle_post_notification_response() {
1585 let state = Arc::new(HttpState {
1586 handler: Arc::new(Box::new(mock_notification_handler)),
1587 config: HttpConfig::default(),
1588 sessions: Arc::new(RwLock::new(HashMap::new())),
1589 });
1590
1591 let query = PostQuery { session_id: None };
1592 let headers = create_test_headers();
1593 let body = json!({
1594 "jsonrpc": "2.0",
1595 "method": "notification",
1596 "params": {}
1597 })
1598 .to_string();
1599
1600 let result = handle_post(State(state), Query(query), headers, body).await;
1601 assert!(result.is_ok());
1602
1603 let response = result.unwrap();
1604 assert_eq!(response.status(), StatusCode::NO_CONTENT);
1605 }
1606
1607 #[tokio::test]
1608 async fn test_handle_post_processing_error() {
1609 let state = Arc::new(HttpState {
1610 handler: Arc::new(Box::new(mock_error_handler)),
1611 config: HttpConfig::default(),
1612 sessions: Arc::new(RwLock::new(HashMap::new())),
1613 });
1614
1615 let query = PostQuery { session_id: None };
1616 let mut headers = create_test_headers();
1617 headers.insert("accept", "application/json".parse().unwrap());
1618 let body = json!({
1619 "jsonrpc": "2.0",
1620 "method": "unknown_method",
1621 "params": {},
1622 "id": 1
1623 })
1624 .to_string();
1625
1626 let result = handle_post(State(state), Query(query), headers, body).await;
1627 assert!(result.is_ok());
1628
1629 let response = result.unwrap();
1630 assert_eq!(response.status(), StatusCode::OK);
1631
1632 let body_str = response.body();
1634 assert!(body_str.contains("error"));
1635 }
1636
1637 #[tokio::test]
1638 async fn test_handle_post_session_id_from_header() {
1639 let state = create_test_state();
1640 let query = PostQuery { session_id: None };
1641 let mut headers = create_test_headers();
1642 headers.insert("Mcp-Session-Id", "header-session-123".parse().unwrap());
1643 let body = json!({
1644 "jsonrpc": "2.0",
1645 "method": "ping",
1646 "params": {},
1647 "id": 1
1648 })
1649 .to_string();
1650
1651 let result = handle_post(State(state.clone()), Query(query), headers, body).await;
1652 assert!(result.is_ok());
1653
1654 let sessions = state.sessions.read().await;
1656 assert!(sessions.contains_key("header-session-123"));
1657 }
1658
1659 #[tokio::test]
1662 async fn test_handle_sse_basic() {
1663 let state = create_test_state();
1664 let query = SseQuery {
1665 session_id: Some("sse-test-session".to_string()),
1666 last_event_id: None,
1667 transport_type: None,
1668 url: None,
1669 };
1670 let headers = create_test_headers();
1671 let uri = "http://localhost:3000/sse?sessionId=sse-test-session"
1672 .parse()
1673 .unwrap();
1674
1675 let result = handle_sse(uri, State(state.clone()), headers, Query(query)).await;
1676 assert!(result.is_ok());
1677
1678 let response = result.unwrap();
1679 assert_eq!(response.status(), StatusCode::OK);
1680 assert!(response.headers().contains_key("Mcp-Session-Id"));
1681 assert_eq!(
1682 response
1683 .headers()
1684 .get("content-type")
1685 .unwrap()
1686 .to_str()
1687 .unwrap(),
1688 "text/event-stream"
1689 );
1690
1691 let sessions = state.sessions.read().await;
1693 assert!(sessions.contains_key("sse-test-session"));
1694 }
1695
1696 #[tokio::test]
1697 async fn test_handle_sse_origin_validation_failure() {
1698 let config = HttpConfig {
1699 allowed_origins: Some(vec!["http://allowed.com".to_string()]),
1700 ..Default::default()
1701 };
1702
1703 let state = Arc::new(HttpState {
1704 handler: Arc::new(Box::new(mock_handler)),
1705 config,
1706 sessions: Arc::new(RwLock::new(HashMap::new())),
1707 });
1708
1709 let query = SseQuery {
1710 session_id: None,
1711 last_event_id: None,
1712 transport_type: None,
1713 url: None,
1714 };
1715 let mut headers = create_test_headers();
1716 headers.insert(ORIGIN, "http://evil.com".parse().unwrap());
1717 let uri = "http://localhost:3000/sse".parse().unwrap();
1718
1719 let result = handle_sse(uri, State(state), headers, Query(query)).await;
1720 assert!(result.is_err());
1721 assert_eq!(result.unwrap_err(), StatusCode::FORBIDDEN);
1722 }
1723
1724 #[tokio::test]
1725 async fn test_handle_sse_auth_failure() {
1726 let config = HttpConfig {
1727 require_auth: true,
1728 valid_tokens: vec!["valid-token".to_string()],
1729 ..Default::default()
1730 };
1731
1732 let state = Arc::new(HttpState {
1733 handler: Arc::new(Box::new(mock_handler)),
1734 config,
1735 sessions: Arc::new(RwLock::new(HashMap::new())),
1736 });
1737
1738 let query = SseQuery {
1739 session_id: None,
1740 last_event_id: None,
1741 transport_type: None,
1742 url: None,
1743 };
1744 let headers = create_test_headers();
1745 let uri = "http://localhost:3000/sse".parse().unwrap();
1746
1747 let result = handle_sse(uri, State(state), headers, Query(query)).await;
1748 assert!(result.is_err());
1749 assert_eq!(result.unwrap_err(), StatusCode::UNAUTHORIZED);
1750 }
1751
1752 #[tokio::test]
1755 async fn test_handle_health() {
1756 let result = handle_health().await;
1757 assert_eq!(result, "OK");
1758 }
1759
1760 #[tokio::test]
1763 async fn test_transport_start_invalid_address() {
1764 let config = HttpConfig {
1765 host: "invalid-host-name-that-does-not-exist".to_string(),
1766 port: 0,
1767 ..Default::default()
1768 };
1769 let mut transport = HttpTransport::with_config(config);
1770
1771 let result = transport.start(Box::new(mock_handler)).await;
1772 assert!(result.is_err());
1773 assert!(result.unwrap_err().to_string().contains("Invalid address"));
1774 }
1775
1776 #[tokio::test]
1777 async fn test_transport_start_and_stop() {
1778 let mut transport = HttpTransport::new(0); assert!(!transport.is_initialized());
1782 assert!(!transport.is_running());
1783
1784 let result = transport.start(Box::new(mock_handler)).await;
1786 assert!(result.is_ok());
1787 assert!(transport.is_initialized());
1788 assert!(transport.is_running());
1789
1790 assert!(transport.health_check().await.is_ok());
1792
1793 let result = transport.stop().await;
1795 assert!(result.is_ok());
1796 assert!(!transport.is_running());
1797 }
1798
1799 #[tokio::test]
1800 async fn test_transport_health_check_not_running() {
1801 let transport = HttpTransport::new(3000);
1802 let result = transport.health_check().await;
1803 assert!(result.is_err());
1804 assert!(
1805 result
1806 .unwrap_err()
1807 .to_string()
1808 .contains("HTTP transport not running")
1809 );
1810 }
1811
1812 #[tokio::test]
1815 async fn test_full_session_lifecycle() {
1816 let state = create_test_state();
1817
1818 let session_id = HttpTransport::ensure_session(state.clone(), None).await;
1820 assert!(!session_id.is_empty());
1821
1822 HttpTransport::update_session_activity(state.clone(), &session_id).await;
1824
1825 let message = "test message";
1827 {
1828 let sessions = state.sessions.read().await;
1829 let session = sessions.get(&session_id).unwrap();
1830 let result = session.event_sender.send(message.to_string());
1831 assert!(result.is_ok());
1832 }
1833
1834 HttpTransport::cleanup_sessions(state.clone()).await;
1836 {
1837 let sessions = state.sessions.read().await;
1838 assert!(sessions.contains_key(&session_id));
1839 }
1840 }
1841
1842 #[tokio::test]
1843 async fn test_multiple_sessions() {
1844 let state = create_test_state();
1845
1846 let session1 =
1848 HttpTransport::ensure_session(state.clone(), Some("session-1".to_string())).await;
1849 let session2 =
1850 HttpTransport::ensure_session(state.clone(), Some("session-2".to_string())).await;
1851 let session3 = HttpTransport::ensure_session(state.clone(), None).await;
1852
1853 assert_eq!(session1, "session-1");
1854 assert_eq!(session2, "session-2");
1855 assert!(!session3.is_empty());
1856 assert_ne!(session3, session1);
1857 assert_ne!(session3, session2);
1858
1859 let sessions = state.sessions.read().await;
1861 assert_eq!(sessions.len(), 3);
1862 assert!(sessions.contains_key(&session1));
1863 assert!(sessions.contains_key(&session2));
1864 assert!(sessions.contains_key(&session3));
1865 }
1866
1867 #[tokio::test]
1868 async fn test_message_format_variations() {
1869 let state = create_test_state();
1870 let query = PostQuery { session_id: None };
1871 let headers = create_test_headers();
1872
1873 let wrapped_body = json!({
1875 "message": {
1876 "jsonrpc": "2.0",
1877 "method": "test",
1878 "params": {"key": "value"},
1879 "id": 1
1880 }
1881 })
1882 .to_string();
1883
1884 let result = handle_post(
1885 State(state.clone()),
1886 Query(query.clone()),
1887 headers.clone(),
1888 wrapped_body,
1889 )
1890 .await;
1891 assert!(result.is_ok());
1892
1893 let direct_body = json!({
1895 "jsonrpc": "2.0",
1896 "method": "test",
1897 "params": {"key": "value"},
1898 "id": 2
1899 })
1900 .to_string();
1901
1902 let result = handle_post(State(state), Query(query), headers, direct_body).await;
1903 assert!(result.is_ok());
1904 }
1905
1906 #[tokio::test]
1907 async fn test_error_handling_edge_cases() {
1908 let state = create_test_state();
1909 let query = PostQuery { session_id: None };
1910 let headers = create_test_headers();
1911
1912 let invalid_jsonrpc = json!({
1914 "jsonrpc": "1.0", "method": "test"
1916 })
1918 .to_string();
1919
1920 let result = handle_post(State(state), Query(query), headers, invalid_jsonrpc).await;
1921 assert!(result.is_err());
1922 assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST);
1923 }
1924
1925 #[test]
1928 fn test_config_extreme_values() {
1929 let config = HttpConfig {
1930 port: 65535,
1931 host: "0.0.0.0".to_string(),
1932 max_message_size: 0,
1933 enable_cors: true,
1934 allowed_origins: Some(vec![]),
1935 validate_messages: false,
1936 session_timeout_secs: 0,
1937 require_auth: false,
1938 valid_tokens: vec![],
1939 };
1940
1941 let transport = HttpTransport::with_config(config);
1942 assert_eq!(transport.config.port, 65535);
1943 assert_eq!(transport.config.max_message_size, 0);
1944 assert_eq!(transport.config.session_timeout_secs, 0);
1945 assert!(
1946 transport
1947 .config
1948 .allowed_origins
1949 .as_ref()
1950 .unwrap()
1951 .is_empty()
1952 );
1953 }
1954
1955 #[test]
1956 fn test_session_info_timing() {
1957 let now = std::time::Instant::now();
1958 let (tx, rx) = broadcast::channel(1024);
1959
1960 let session = SessionInfo {
1961 id: "timing-test".to_string(),
1962 created_at: now,
1963 last_activity: now,
1964 event_sender: tx,
1965 _keepalive_receiver: Arc::new(Mutex::new(rx)),
1966 };
1967
1968 assert!(session.created_at <= std::time::Instant::now());
1969 assert!(session.last_activity <= std::time::Instant::now());
1970 }
1971
1972 #[tokio::test]
1975 async fn test_broadcast_channel_receiver_drop() {
1976 let state = create_test_state();
1977
1978 let session_id = HttpTransport::ensure_session(state.clone(), None).await;
1980 let receiver = {
1981 let sessions = state.sessions.read().await;
1982 sessions.get(&session_id).unwrap().event_sender.subscribe()
1983 };
1984
1985 drop(receiver);
1987
1988 let transport = HttpTransport {
1990 config: HttpConfig::default(),
1991 state: Some((*state).clone()),
1992 server_handle: None,
1993 };
1994
1995 let result = transport.broadcast_message("test after drop").await;
1997 assert!(result.is_ok());
1998 }
1999
2000 #[tokio::test]
2001 async fn test_session_channel_capacity() {
2002 let state = create_test_state();
2003 let session_id = HttpTransport::ensure_session(state.clone(), None).await;
2004
2005 let sender = {
2007 let sessions = state.sessions.read().await;
2008 sessions.get(&session_id).unwrap().event_sender.clone()
2009 };
2010
2011 for i in 0..2000 {
2013 let _ = sender.send(format!("message-{i}"));
2015 }
2016
2017 }
2020}