1use std::collections::HashMap;
52use std::io::{Read, Write};
53use std::sync::{Arc, Mutex};
54use std::time::{Duration, Instant};
55
56use asupersync::Cx;
57use fastmcp_protocol::{JsonRpcMessage, JsonRpcRequest, JsonRpcResponse};
58
59use crate::{Codec, CodecError, Transport, TransportError};
60
/// HTTP request methods recognised by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HttpMethod {
    /// `GET`
    Get,
    /// `POST`
    Post,
    /// `PUT`
    Put,
    /// `DELETE`
    Delete,
    /// `OPTIONS`
    Options,
    /// `HEAD`
    Head,
    /// `PATCH`
    Patch,
}

impl HttpMethod {
    /// Every variant, used for token lookup in [`HttpMethod::parse`].
    const ALL: [Self; 7] = [
        Self::Get,
        Self::Post,
        Self::Put,
        Self::Delete,
        Self::Options,
        Self::Head,
        Self::Patch,
    ];

    /// Parses a method token after upper-casing it (so `"get"` yields
    /// [`HttpMethod::Get`]); returns `None` for unknown tokens.
    #[must_use]
    pub fn parse(s: &str) -> Option<Self> {
        let wanted = s.to_uppercase();
        Self::ALL
            .iter()
            .copied()
            .find(|method| method.as_str() == wanted)
    }

    /// Returns the canonical upper-case token for this method.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            HttpMethod::Get => "GET",
            HttpMethod::Post => "POST",
            HttpMethod::Put => "PUT",
            HttpMethod::Delete => "DELETE",
            HttpMethod::Options => "OPTIONS",
            HttpMethod::Head => "HEAD",
            HttpMethod::Patch => "PATCH",
        }
    }
}
107
/// An HTTP status code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HttpStatus(pub u16);

impl HttpStatus {
    /// 200 OK.
    pub const OK: Self = Self(200);
    /// 202 Accepted.
    pub const ACCEPTED: Self = Self(202);
    /// 400 Bad Request.
    pub const BAD_REQUEST: Self = Self(400);
    /// 401 Unauthorized.
    pub const UNAUTHORIZED: Self = Self(401);
    /// 403 Forbidden.
    pub const FORBIDDEN: Self = Self(403);
    /// 404 Not Found.
    pub const NOT_FOUND: Self = Self(404);
    /// 405 Method Not Allowed.
    pub const METHOD_NOT_ALLOWED: Self = Self(405);
    /// 500 Internal Server Error.
    pub const INTERNAL_SERVER_ERROR: Self = Self(500);
    /// 503 Service Unavailable.
    pub const SERVICE_UNAVAILABLE: Self = Self(503);

    /// Hundreds digit of the code (2 for 2xx, 4 for 4xx, ...).
    fn class(&self) -> u16 {
        self.0 / 100
    }

    /// Returns `true` for 2xx codes.
    #[must_use]
    pub fn is_success(&self) -> bool {
        self.class() == 2
    }

    /// Returns `true` for 4xx codes.
    #[must_use]
    pub fn is_client_error(&self) -> bool {
        self.class() == 4
    }

    /// Returns `true` for 5xx codes.
    #[must_use]
    pub fn is_server_error(&self) -> bool {
        self.class() == 5
    }
}
141
142#[derive(Debug, Clone)]
144pub struct HttpRequest {
145 pub method: HttpMethod,
147 pub path: String,
149 pub headers: HashMap<String, String>,
151 pub body: Vec<u8>,
153 pub query: HashMap<String, String>,
155}
156
157impl HttpRequest {
158 #[must_use]
160 pub fn new(method: HttpMethod, path: impl Into<String>) -> Self {
161 Self {
162 method,
163 path: path.into(),
164 headers: HashMap::new(),
165 body: Vec::new(),
166 query: HashMap::new(),
167 }
168 }
169
170 #[must_use]
172 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
173 self.headers
174 .insert(name.into().to_lowercase(), value.into());
175 self
176 }
177
178 #[must_use]
180 pub fn with_body(mut self, body: impl Into<Vec<u8>>) -> Self {
181 self.body = body.into();
182 self
183 }
184
185 #[must_use]
187 pub fn with_query(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
188 self.query.insert(name.into(), value.into());
189 self
190 }
191
192 #[must_use]
194 pub fn header(&self, name: &str) -> Option<&str> {
195 self.headers.get(&name.to_lowercase()).map(String::as_str)
196 }
197
198 #[must_use]
200 pub fn content_type(&self) -> Option<&str> {
201 self.header("content-type")
202 }
203
204 #[must_use]
206 pub fn authorization(&self) -> Option<&str> {
207 self.header("authorization")
208 }
209
210 pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
212 serde_json::from_slice(&self.body)
213 }
214}
215
216#[derive(Debug, Clone)]
218pub struct HttpResponse {
219 pub status: HttpStatus,
221 pub headers: HashMap<String, String>,
223 pub body: Vec<u8>,
225}
226
227impl HttpResponse {
228 #[must_use]
230 pub fn new(status: HttpStatus) -> Self {
231 let mut headers = HashMap::new();
232 headers.insert("content-type".to_string(), "application/json".to_string());
233 Self {
234 status,
235 headers,
236 body: Vec::new(),
237 }
238 }
239
240 #[must_use]
242 pub fn ok() -> Self {
243 Self::new(HttpStatus::OK)
244 }
245
246 #[must_use]
248 pub fn bad_request() -> Self {
249 Self::new(HttpStatus::BAD_REQUEST)
250 }
251
252 #[must_use]
254 pub fn internal_error() -> Self {
255 Self::new(HttpStatus::INTERNAL_SERVER_ERROR)
256 }
257
258 #[must_use]
260 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
261 self.headers
262 .insert(name.into().to_lowercase(), value.into());
263 self
264 }
265
266 #[must_use]
268 pub fn with_body(mut self, body: impl Into<Vec<u8>>) -> Self {
269 self.body = body.into();
270 self
271 }
272
273 #[must_use]
275 pub fn with_json<T: serde::Serialize>(mut self, value: &T) -> Self {
276 self.body = serde_json::to_vec(value).unwrap_or_default();
277 self.headers
278 .insert("content-type".to_string(), "application/json".to_string());
279 self
280 }
281
282 #[must_use]
284 pub fn with_cors(mut self, origin: &str) -> Self {
285 self.headers.insert(
286 "access-control-allow-origin".to_string(),
287 origin.to_string(),
288 );
289 self.headers.insert(
290 "access-control-allow-methods".to_string(),
291 "GET, POST, OPTIONS".to_string(),
292 );
293 self.headers.insert(
294 "access-control-allow-headers".to_string(),
295 "Content-Type, Authorization".to_string(),
296 );
297 self
298 }
299}
300
301#[derive(Debug)]
307pub enum HttpError {
308 InvalidMethod(String),
310 InvalidContentType(String),
312 JsonError(serde_json::Error),
314 CodecError(CodecError),
316 Timeout,
318 Closed,
320 Transport(TransportError),
322}
323
324impl std::fmt::Display for HttpError {
325 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326 match self {
327 Self::InvalidMethod(m) => write!(f, "invalid HTTP method: {}", m),
328 Self::InvalidContentType(ct) => write!(f, "invalid content type: {}", ct),
329 Self::JsonError(e) => write!(f, "JSON error: {}", e),
330 Self::CodecError(e) => write!(f, "codec error: {}", e),
331 Self::Timeout => write!(f, "request timeout"),
332 Self::Closed => write!(f, "connection closed"),
333 Self::Transport(e) => write!(f, "transport error: {}", e),
334 }
335 }
336}
337
338impl std::error::Error for HttpError {}
339
340impl From<serde_json::Error> for HttpError {
341 fn from(err: serde_json::Error) -> Self {
342 Self::JsonError(err)
343 }
344}
345
346impl From<CodecError> for HttpError {
347 fn from(err: CodecError) -> Self {
348 Self::CodecError(err)
349 }
350}
351
352impl From<TransportError> for HttpError {
353 fn from(err: TransportError) -> Self {
354 Self::Transport(err)
355 }
356}
357
/// Configuration for [`HttpRequestHandler`].
#[derive(Debug, Clone)]
pub struct HttpHandlerConfig {
    /// Base path of the MCP endpoint (default `/mcp/v1`).
    pub base_path: String,
    /// Whether CORS preflight requests and CORS headers are enabled.
    pub allow_cors: bool,
    /// Allowed CORS origins; an entry of `*` allows any origin.
    pub cors_origins: Vec<String>,
    /// Request timeout (default 30 seconds).
    pub timeout: Duration,
    /// Maximum accepted request body size in bytes (default 10 MiB).
    pub max_body_size: usize,
}

impl Default for HttpHandlerConfig {
    fn default() -> Self {
        const TEN_MIB: usize = 10 * 1024 * 1024;
        Self {
            base_path: String::from("/mcp/v1"),
            allow_cors: true,
            cors_origins: vec![String::from("*")],
            timeout: Duration::from_secs(30),
            max_body_size: TEN_MIB,
        }
    }
}
388
389pub struct HttpRequestHandler {
395 config: HttpHandlerConfig,
396 codec: Codec,
397}
398
399impl HttpRequestHandler {
400 #[must_use]
402 pub fn new() -> Self {
403 Self::with_config(HttpHandlerConfig::default())
404 }
405
406 #[must_use]
408 pub fn with_config(config: HttpHandlerConfig) -> Self {
409 Self {
410 config,
411 codec: Codec::new(),
412 }
413 }
414
415 #[must_use]
417 pub fn config(&self) -> &HttpHandlerConfig {
418 &self.config
419 }
420
421 #[must_use]
423 pub fn handle_options(&self, request: &HttpRequest) -> HttpResponse {
424 if !self.config.allow_cors {
425 return HttpResponse::new(HttpStatus::METHOD_NOT_ALLOWED);
426 }
427
428 let origin = request.header("origin").unwrap_or("*");
429 let allowed = self.is_origin_allowed(origin);
430
431 if !allowed {
432 return HttpResponse::new(HttpStatus::FORBIDDEN);
433 }
434
435 HttpResponse::new(HttpStatus::OK)
436 .with_cors(origin)
437 .with_header("access-control-max-age", "86400")
438 }
439
440 #[must_use]
442 pub fn is_origin_allowed(&self, origin: &str) -> bool {
443 self.config
444 .cors_origins
445 .iter()
446 .any(|o| o == "*" || o == origin)
447 }
448
449 pub fn parse_request(&self, request: &HttpRequest) -> Result<JsonRpcRequest, HttpError> {
451 if request.method != HttpMethod::Post {
453 return Err(HttpError::InvalidMethod(
454 request.method.as_str().to_string(),
455 ));
456 }
457
458 let content_type = request.content_type().unwrap_or("");
460 if !content_type.starts_with("application/json") {
461 return Err(HttpError::InvalidContentType(content_type.to_string()));
462 }
463
464 if request.body.len() > self.config.max_body_size {
466 return Err(HttpError::InvalidContentType(format!(
467 "body size {} exceeds limit {}",
468 request.body.len(),
469 self.config.max_body_size
470 )));
471 }
472
473 let json_rpc: JsonRpcRequest = serde_json::from_slice(&request.body)?;
475 Ok(json_rpc)
476 }
477
478 #[must_use]
480 pub fn create_response(
481 &self,
482 response: &JsonRpcResponse,
483 origin: Option<&str>,
484 ) -> HttpResponse {
485 let body = self.codec.encode_response(response).unwrap_or_default();
486
487 let mut http_response = HttpResponse::ok()
488 .with_body(body)
489 .with_header("content-type", "application/json");
490
491 if self.config.allow_cors {
492 if let Some(origin) = origin {
493 if self.is_origin_allowed(origin) {
494 http_response = http_response.with_cors(origin);
495 }
496 }
497 }
498
499 http_response
500 }
501
502 #[must_use]
504 pub fn error_response(&self, status: HttpStatus, message: &str) -> HttpResponse {
505 let error = serde_json::json!({
506 "error": {
507 "code": -32600,
508 "message": message
509 }
510 });
511
512 HttpResponse::new(status).with_json(&error)
513 }
514}
515
516impl Default for HttpRequestHandler {
517 fn default() -> Self {
518 Self::new()
519 }
520}
521
522pub struct HttpTransport<R, W> {
532 reader: R,
533 writer: W,
534 codec: Codec,
535 closed: bool,
536 pending_responses: Vec<JsonRpcResponse>,
537}
538
539impl<R: Read, W: Write> HttpTransport<R, W> {
540 #[must_use]
542 pub fn new(reader: R, writer: W) -> Self {
543 Self {
544 reader,
545 writer,
546 codec: Codec::new(),
547 closed: false,
548 pending_responses: Vec::new(),
549 }
550 }
551
552 pub fn read_request(&mut self) -> Result<HttpRequest, HttpError> {
557 let mut buffer = Vec::new();
558 let mut byte = [0u8; 1];
559
560 loop {
562 if self
563 .reader
564 .read(&mut byte)
565 .map_err(|e| HttpError::Transport(e.into()))?
566 == 0
567 {
568 return Err(HttpError::Closed);
569 }
570 buffer.push(byte[0]);
571
572 if buffer.ends_with(b"\r\n\r\n") {
573 break;
574 }
575
576 if buffer.len() > 64 * 1024 {
578 return Err(HttpError::InvalidContentType(
579 "headers too large".to_string(),
580 ));
581 }
582 }
583
584 let header_str = String::from_utf8_lossy(&buffer);
585 let mut lines = header_str.lines();
586
587 let request_line = lines
589 .next()
590 .ok_or_else(|| HttpError::InvalidMethod("missing request line".to_string()))?;
591
592 let parts: Vec<&str> = request_line.split_whitespace().collect();
593 if parts.len() < 2 {
594 return Err(HttpError::InvalidMethod("invalid request line".to_string()));
595 }
596
597 let method = HttpMethod::parse(parts[0])
598 .ok_or_else(|| HttpError::InvalidMethod(parts[0].to_string()))?;
599
600 let path = parts[1].to_string();
601
602 let mut headers = HashMap::new();
604 for line in lines {
605 if line.is_empty() {
606 break;
607 }
608 if let Some((name, value)) = line.split_once(':') {
609 headers.insert(name.trim().to_lowercase(), value.trim().to_string());
610 }
611 }
612
613 let content_length: usize = headers
615 .get("content-length")
616 .and_then(|s| s.parse().ok())
617 .unwrap_or(0);
618
619 let mut body = vec![0u8; content_length];
620 if content_length > 0 {
621 self.reader
622 .read_exact(&mut body)
623 .map_err(|e| HttpError::Transport(e.into()))?;
624 }
625
626 Ok(HttpRequest {
627 method,
628 path,
629 headers,
630 body,
631 query: HashMap::new(),
632 })
633 }
634
635 pub fn write_response(&mut self, response: &HttpResponse) -> Result<(), HttpError> {
637 let status_text = match response.status.0 {
638 200 => "OK",
639 400 => "Bad Request",
640 401 => "Unauthorized",
641 403 => "Forbidden",
642 404 => "Not Found",
643 500 => "Internal Server Error",
644 _ => "Unknown",
645 };
646
647 write!(
649 self.writer,
650 "HTTP/1.1 {} {}\r\n",
651 response.status.0, status_text
652 )
653 .map_err(|e| HttpError::Transport(e.into()))?;
654
655 for (name, value) in &response.headers {
657 write!(self.writer, "{}: {}\r\n", name, value)
658 .map_err(|e| HttpError::Transport(e.into()))?;
659 }
660
661 if !response.headers.contains_key("content-length") {
663 write!(self.writer, "content-length: {}\r\n", response.body.len())
664 .map_err(|e| HttpError::Transport(e.into()))?;
665 }
666
667 write!(self.writer, "\r\n").map_err(|e| HttpError::Transport(e.into()))?;
669
670 self.writer
672 .write_all(&response.body)
673 .map_err(|e| HttpError::Transport(e.into()))?;
674 self.writer
675 .flush()
676 .map_err(|e| HttpError::Transport(e.into()))?;
677
678 Ok(())
679 }
680
681 pub fn queue_response(&mut self, response: JsonRpcResponse) {
683 self.pending_responses.push(response);
684 }
685}
686
687impl<R: Read, W: Write> Transport for HttpTransport<R, W> {
688 fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
689 if cx.is_cancel_requested() {
690 return Err(TransportError::Cancelled);
691 }
692
693 if self.closed {
694 return Err(TransportError::Closed);
695 }
696
697 let response = match message {
698 JsonRpcMessage::Response(r) => r.clone(),
699 JsonRpcMessage::Request(r) => {
700 let _ = r;
704 return Ok(());
705 }
706 };
707
708 let http_response = HttpResponse::ok().with_json(&response);
709
710 self.write_response(&http_response)
711 .map_err(|_| TransportError::Io(std::io::Error::other("write error")))?;
712
713 Ok(())
714 }
715
716 fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
717 if cx.is_cancel_requested() {
718 return Err(TransportError::Cancelled);
719 }
720
721 if self.closed {
722 return Err(TransportError::Closed);
723 }
724
725 let http_request = self.read_request().map_err(|e| match e {
726 HttpError::Closed => TransportError::Closed,
727 HttpError::Timeout => TransportError::Timeout,
728 _ => TransportError::Io(std::io::Error::other(e.to_string())),
729 })?;
730
731 let json_rpc: JsonRpcRequest = serde_json::from_slice(&http_request.body)
733 .map_err(|e| TransportError::Codec(CodecError::Json(e)))?;
734
735 Ok(JsonRpcMessage::Request(json_rpc))
736 }
737
738 fn close(&mut self) -> Result<(), TransportError> {
739 self.closed = true;
740 Ok(())
741 }
742}
743
744pub struct StreamableHttpTransport {
754 requests: Arc<Mutex<Vec<JsonRpcRequest>>>,
756 responses: Arc<Mutex<Vec<JsonRpcResponse>>>,
758 codec: Codec,
760 closed: bool,
762 poll_interval: Duration,
764}
765
766impl StreamableHttpTransport {
767 #[must_use]
769 pub fn new() -> Self {
770 Self {
771 requests: Arc::new(Mutex::new(Vec::new())),
772 responses: Arc::new(Mutex::new(Vec::new())),
773 codec: Codec::new(),
774 closed: false,
775 poll_interval: Duration::from_millis(10),
776 }
777 }
778
779 pub fn push_request(&self, request: JsonRpcRequest) {
781 if let Ok(mut guard) = self.requests.lock() {
782 guard.push(request);
783 }
784 }
785
786 #[must_use]
788 pub fn pop_response(&self) -> Option<JsonRpcResponse> {
789 self.responses.lock().ok()?.pop()
790 }
791
792 #[must_use]
794 pub fn has_responses(&self) -> bool {
795 self.responses
796 .lock()
797 .map(|g| !g.is_empty())
798 .unwrap_or(false)
799 }
800
801 #[must_use]
803 pub fn request_queue(&self) -> Arc<Mutex<Vec<JsonRpcRequest>>> {
804 Arc::clone(&self.requests)
805 }
806
807 #[must_use]
809 pub fn response_queue(&self) -> Arc<Mutex<Vec<JsonRpcResponse>>> {
810 Arc::clone(&self.responses)
811 }
812}
813
814impl Default for StreamableHttpTransport {
815 fn default() -> Self {
816 Self::new()
817 }
818}
819
820impl Transport for StreamableHttpTransport {
821 fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
822 if cx.is_cancel_requested() {
823 return Err(TransportError::Cancelled);
824 }
825
826 if self.closed {
827 return Err(TransportError::Closed);
828 }
829
830 if let JsonRpcMessage::Response(response) = message {
831 if let Ok(mut guard) = self.responses.lock() {
832 guard.push(response.clone());
833 }
834 }
835
836 Ok(())
837 }
838
839 fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
840 if cx.is_cancel_requested() {
841 return Err(TransportError::Cancelled);
842 }
843
844 if self.closed {
845 return Err(TransportError::Closed);
846 }
847
848 loop {
850 if cx.is_cancel_requested() {
851 return Err(TransportError::Cancelled);
852 }
853
854 if let Ok(mut guard) = self.requests.lock() {
855 if let Some(request) = guard.pop() {
856 return Ok(JsonRpcMessage::Request(request));
857 }
858 }
859
860 std::thread::sleep(self.poll_interval);
862 }
863 }
864
865 fn close(&mut self) -> Result<(), TransportError> {
866 self.closed = true;
867 Ok(())
868 }
869}
870
871#[derive(Debug, Clone)]
877pub struct HttpSession {
878 pub id: String,
880 pub created_at: Instant,
882 pub last_activity: Instant,
884 pub data: HashMap<String, serde_json::Value>,
886}
887
888impl HttpSession {
889 #[must_use]
891 pub fn new(id: impl Into<String>) -> Self {
892 let now = Instant::now();
893 Self {
894 id: id.into(),
895 created_at: now,
896 last_activity: now,
897 data: HashMap::new(),
898 }
899 }
900
901 pub fn touch(&mut self) {
903 self.last_activity = Instant::now();
904 }
905
906 #[must_use]
908 pub fn is_expired(&self, timeout: Duration) -> bool {
909 self.last_activity.elapsed() > timeout
910 }
911
912 #[must_use]
914 pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
915 self.data.get(key)
916 }
917
918 pub fn set(&mut self, key: impl Into<String>, value: serde_json::Value) {
920 self.data.insert(key.into(), value);
921 self.touch();
922 }
923
924 pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
926 self.touch();
927 self.data.remove(key)
928 }
929}
930
931#[derive(Debug, Default)]
933pub struct SessionStore {
934 sessions: Mutex<HashMap<String, HttpSession>>,
935 timeout: Duration,
936}
937
938impl SessionStore {
939 #[must_use]
941 pub fn new(timeout: Duration) -> Self {
942 Self {
943 sessions: Mutex::new(HashMap::new()),
944 timeout,
945 }
946 }
947
948 #[must_use]
950 pub fn with_defaults() -> Self {
951 Self::new(Duration::from_secs(3600))
952 }
953
954 #[must_use]
956 pub fn create(&self) -> String {
957 let id = generate_session_id();
958 let session = HttpSession::new(&id);
959
960 if let Ok(mut guard) = self.sessions.lock() {
961 guard.insert(id.clone(), session);
962 }
963
964 id
965 }
966
967 #[must_use]
969 pub fn get(&self, id: &str) -> Option<HttpSession> {
970 let mut guard = self.sessions.lock().ok()?;
971 let session = guard.get_mut(id)?;
972
973 if session.is_expired(self.timeout) {
974 guard.remove(id);
975 return None;
976 }
977
978 session.touch();
979 Some(session.clone())
980 }
981
982 pub fn update(&self, session: HttpSession) {
984 if let Ok(mut guard) = self.sessions.lock() {
985 guard.insert(session.id.clone(), session);
986 }
987 }
988
989 pub fn remove(&self, id: &str) {
991 if let Ok(mut guard) = self.sessions.lock() {
992 guard.remove(id);
993 }
994 }
995
996 pub fn cleanup(&self) {
998 if let Ok(mut guard) = self.sessions.lock() {
999 guard.retain(|_, s| !s.is_expired(self.timeout));
1000 }
1001 }
1002
1003 #[must_use]
1005 pub fn count(&self) -> usize {
1006 self.sessions.lock().map(|g| g.len()).unwrap_or(0)
1007 }
1008}
1009
/// Produces a 16-hex-digit session identifier.
///
/// The current wall-clock time in nanoseconds is hashed with a fresh
/// `RandomState` hasher. A clock before the Unix epoch falls back to 0.
///
/// NOTE(review): this is 64 bits of SipHash output, not a CSPRNG token;
/// confirm that is acceptable for session secrets.
fn generate_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let mut hasher = RandomState::new().build_hasher();

    let nanos_since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or(0);
    hasher.write_u128(nanos_since_epoch);

    let digest = hasher.finish();
    format!("{:016x}", digest)
}
1027
#[cfg(test)]
mod tests {
    // Unit tests for the value types, plus end-to-end (`e2e_*`) tests that
    // drive the handler, transports, and session store together.
    use super::*;

    #[test]
    fn test_http_method_parse() {
        assert_eq!(HttpMethod::parse("GET"), Some(HttpMethod::Get));
        assert_eq!(HttpMethod::parse("POST"), Some(HttpMethod::Post));
        // Parsing upper-cases the input first, so lower-case tokens are accepted.
        assert_eq!(HttpMethod::parse("get"), Some(HttpMethod::Get));
        assert_eq!(HttpMethod::parse("INVALID"), None);
    }

    #[test]
    fn test_http_status() {
        assert!(HttpStatus::OK.is_success());
        assert!(HttpStatus::BAD_REQUEST.is_client_error());
        assert!(HttpStatus::INTERNAL_SERVER_ERROR.is_server_error());
    }

    #[test]
    fn test_http_request_builder() {
        let request = HttpRequest::new(HttpMethod::Post, "/api/mcp")
            .with_header("Content-Type", "application/json")
            .with_body(b"{\"test\": true}".to_vec())
            .with_query("version", "1");

        assert_eq!(request.method, HttpMethod::Post);
        assert_eq!(request.path, "/api/mcp");
        // `with_header` lower-cases names, so the lower-case lookup matches.
        assert_eq!(request.header("content-type"), Some("application/json"));
        assert_eq!(request.query.get("version"), Some(&"1".to_string()));
    }

    #[test]
    fn test_http_response_builder() {
        let response = HttpResponse::ok()
            .with_header("X-Custom", "value")
            .with_body(b"Hello".to_vec());

        assert_eq!(response.status, HttpStatus::OK);
        // Raw map access: the key was stored lower-cased.
        assert_eq!(response.headers.get("x-custom"), Some(&"value".to_string()));
        assert_eq!(response.body, b"Hello");
    }

    #[test]
    fn test_http_response_json() {
        let data = serde_json::json!({"result": "ok"});
        let response = HttpResponse::ok().with_json(&data);

        assert!(!response.body.is_empty());
        assert_eq!(
            response.headers.get("content-type"),
            Some(&"application/json".to_string())
        );
    }

    #[test]
    fn test_http_response_cors() {
        let response = HttpResponse::ok().with_cors("https://example.com");

        assert_eq!(
            response.headers.get("access-control-allow-origin"),
            Some(&"https://example.com".to_string())
        );
    }

    #[test]
    fn test_http_handler_options() {
        // The default config allows `*`, so any origin passes the preflight.
        let handler = HttpRequestHandler::new();
        let request = HttpRequest::new(HttpMethod::Options, "/mcp/v1")
            .with_header("Origin", "https://example.com");

        let response = handler.handle_options(&request);
        assert_eq!(response.status, HttpStatus::OK);
    }

    #[test]
    fn test_http_handler_parse_request() {
        let handler = HttpRequestHandler::new();

        // A well-formed POST with a JSON-RPC body parses.
        let json_rpc = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "test",
            "id": 1
        });
        let request = HttpRequest::new(HttpMethod::Post, "/mcp/v1")
            .with_header("Content-Type", "application/json")
            .with_body(serde_json::to_vec(&json_rpc).unwrap());

        let result = handler.parse_request(&request);
        assert!(result.is_ok());

        // Non-POST methods are rejected.
        let request = HttpRequest::new(HttpMethod::Get, "/mcp/v1");
        assert!(handler.parse_request(&request).is_err());

        // Non-JSON content types are rejected.
        let request =
            HttpRequest::new(HttpMethod::Post, "/mcp/v1").with_header("Content-Type", "text/plain");
        assert!(handler.parse_request(&request).is_err());
    }

    #[test]
    fn test_http_session() {
        let mut session = HttpSession::new("test-session");
        assert_eq!(session.id, "test-session");

        session.set("key", serde_json::json!("value"));
        assert_eq!(session.get("key"), Some(&serde_json::json!("value")));

        session.remove("key");
        assert!(session.get("key").is_none());

        assert!(!session.is_expired(Duration::from_secs(3600)));
    }

    #[test]
    fn test_session_store() {
        let store = SessionStore::with_defaults();

        let id = store.create();
        assert!(!id.is_empty());

        let session = store.get(&id);
        assert!(session.is_some());

        store.remove(&id);
        assert!(store.get(&id).is_none());
    }

    #[test]
    fn test_streamable_transport() {
        let transport = StreamableHttpTransport::new();

        let request = JsonRpcRequest::new("test", None, 1i64);
        transport.push_request(request);

        // Inspects the private queue directly (allowed from this child module).
        let guard = transport.requests.lock().unwrap();
        assert_eq!(guard.len(), 1);
    }

    #[test]
    fn test_http_error_display() {
        let err = HttpError::InvalidMethod("PATCH".to_string());
        assert!(err.to_string().contains("PATCH"));

        let err = HttpError::Timeout;
        assert!(err.to_string().contains("timeout"));
    }

    #[test]
    fn test_generate_session_id() {
        let id1 = generate_session_id();
        let id2 = generate_session_id();

        assert_ne!(id1, id2);
        // `{:016x}` of a u64 is always 16 hex digits.
        assert_eq!(id1.len(), 16);
    }

    #[test]
    fn e2e_http_request_response_flow() {
        // Feeds a raw HTTP/1.1 request through `HttpTransport::recv`, sends a
        // response back, and checks the bytes written to the output buffer.
        use fastmcp_protocol::RequestId;
        use std::io::Cursor;

        let json_rpc_request = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "tools/list",
            "id": 1
        });
        let body = serde_json::to_vec(&json_rpc_request).unwrap();

        // Trailing `\` continues the literal and skips the next line's
        // leading whitespace, so the header lines carry no indentation.
        let http_request = format!(
            "POST /mcp/v1 HTTP/1.1\r\n\
             Content-Type: application/json\r\n\
             Content-Length: {}\r\n\
             \r\n",
            body.len()
        );

        let mut input = http_request.into_bytes();
        input.extend(body);

        let reader = Cursor::new(input);
        let mut output = Vec::new();

        let cx = Cx::for_testing();

        // Scope the transport so its `&mut output` borrow ends before the
        // output is inspected.
        {
            let mut transport = HttpTransport::new(reader, &mut output);

            let msg = transport.recv(&cx).unwrap();
            match msg {
                JsonRpcMessage::Request(req) => {
                    assert_eq!(req.method, "tools/list");
                    assert_eq!(req.id, Some(RequestId::Number(1)));

                    let response = JsonRpcResponse {
                        jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
                        result: Some(serde_json::json!({"tools": []})),
                        error: None,
                        id: Some(RequestId::Number(1)),
                    };
                    transport
                        .send(&cx, &JsonRpcMessage::Response(response))
                        .unwrap();
                }
                _ => panic!("Expected request"),
            }
        }

        let response_str = String::from_utf8(output).unwrap();
        assert!(response_str.starts_with("HTTP/1.1 200 OK\r\n"));
        assert!(response_str.contains("content-type: application/json"));
        assert!(response_str.contains("\"tools\":[]"));
    }

    #[test]
    fn e2e_http_error_status_codes() {
        let handler = HttpRequestHandler::new();

        // Wrong method maps to `InvalidMethod`, even with a JSON content type.
        let request = HttpRequest::new(HttpMethod::Get, "/mcp/v1")
            .with_header("Content-Type", "application/json");
        let result = handler.parse_request(&request);
        assert!(matches!(result, Err(HttpError::InvalidMethod(_))));

        // Wrong content type maps to `InvalidContentType`.
        let request =
            HttpRequest::new(HttpMethod::Post, "/mcp/v1").with_header("Content-Type", "text/xml");
        let result = handler.parse_request(&request);
        assert!(matches!(result, Err(HttpError::InvalidContentType(_))));

        let response = handler.error_response(HttpStatus::BAD_REQUEST, "Invalid request format");
        assert_eq!(response.status, HttpStatus::BAD_REQUEST);
        let body_str = String::from_utf8(response.body).unwrap();
        assert!(body_str.contains("\"error\""));
    }

    #[test]
    fn e2e_http_content_type_handling() {
        let handler = HttpRequestHandler::new();

        let request = HttpRequest::new(HttpMethod::Post, "/mcp/v1")
            .with_header("Content-Type", "application/json")
            .with_body(r#"{"jsonrpc":"2.0","method":"test","id":1}"#);
        assert!(handler.parse_request(&request).is_ok());

        // Media-type parameters such as `charset` are accepted.
        let request = HttpRequest::new(HttpMethod::Post, "/mcp/v1")
            .with_header("Content-Type", "application/json; charset=utf-8")
            .with_body(r#"{"jsonrpc":"2.0","method":"test","id":1}"#);
        assert!(handler.parse_request(&request).is_ok());

        let response = JsonRpcResponse {
            jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
            result: Some(serde_json::json!({})),
            error: None,
            id: Some(fastmcp_protocol::RequestId::Number(1)),
        };
        let http_response = handler.create_response(&response, None);
        assert_eq!(
            http_response.headers.get("content-type"),
            Some(&"application/json".to_string())
        );
    }

    #[test]
    fn e2e_http_cors_handling() {
        // An explicit allow-list (no `*`) restricts which origins pass.
        let config = HttpHandlerConfig {
            allow_cors: true,
            cors_origins: vec!["https://allowed.com".to_string()],
            ..Default::default()
        };
        let handler = HttpRequestHandler::with_config(config);

        assert!(handler.is_origin_allowed("https://allowed.com"));

        assert!(!handler.is_origin_allowed("https://evil.com"));

        // Allowed origin: 200, and the origin is echoed back.
        let request = HttpRequest::new(HttpMethod::Options, "/mcp/v1")
            .with_header("Origin", "https://allowed.com");
        let response = handler.handle_options(&request);
        assert_eq!(response.status, HttpStatus::OK);
        assert_eq!(
            response.headers.get("access-control-allow-origin"),
            Some(&"https://allowed.com".to_string())
        );

        // Disallowed origin: 403.
        let request = HttpRequest::new(HttpMethod::Options, "/mcp/v1")
            .with_header("Origin", "https://evil.com");
        let response = handler.handle_options(&request);
        assert_eq!(response.status, HttpStatus::FORBIDDEN);
    }

    #[test]
    fn e2e_http_streaming_transport() {
        use fastmcp_protocol::RequestId;

        let mut transport = StreamableHttpTransport::new();
        let cx = Cx::for_testing();

        let req1 = JsonRpcRequest::new("method1", None, 1i64);
        let req2 = JsonRpcRequest::new("method2", None, 2i64);
        transport.push_request(req1);
        transport.push_request(req2);

        let msg = transport.recv(&cx).unwrap();
        if let JsonRpcMessage::Request(req) = msg {
            // The queue is a `Vec` drained with `pop()`, so the most recently
            // pushed request comes out first.
            assert_eq!(req.method, "method2");
        }

        let response = JsonRpcResponse {
            jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
            result: Some(serde_json::json!({})),
            error: None,
            id: Some(RequestId::Number(2)),
        };
        transport
            .send(&cx, &JsonRpcMessage::Response(response))
            .unwrap();

        assert!(transport.has_responses());
        let resp = transport.pop_response().unwrap();
        assert_eq!(resp.id, Some(RequestId::Number(2)));
    }

    #[test]
    fn e2e_http_session_lifecycle() {
        // Short timeout so expiry can be observed within the test.
        let store = SessionStore::new(Duration::from_millis(100));

        let id = store.create();
        assert_eq!(store.count(), 1);

        // `get` returns a copy; changes persist only after `update`.
        let mut session = store.get(&id).unwrap();
        session.set("user_id", serde_json::json!(42));
        store.update(session);

        let session = store.get(&id).unwrap();
        assert_eq!(session.get("user_id"), Some(&serde_json::json!(42)));

        // Outlive the 100 ms timeout.
        std::thread::sleep(Duration::from_millis(150));

        assert!(store.get(&id).is_none());
    }

    #[test]
    fn e2e_http_transport_cancellation() {
        use std::io::Cursor;

        let reader = Cursor::new(Vec::<u8>::new());
        let mut output = Vec::new();

        let cx = Cx::for_testing();
        cx.set_cancel_requested(true);

        let mut transport = HttpTransport::new(reader, &mut output);

        let response = JsonRpcResponse {
            jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
            result: None,
            error: None,
            id: None,
        };
        let result = transport.send(&cx, &JsonRpcMessage::Response(response));
        assert!(matches!(result, Err(TransportError::Cancelled)));

        // Cancellation is checked before anything is written.
        assert!(output.is_empty());
    }

    #[test]
    fn e2e_http_transport_close() {
        use std::io::Cursor;

        let reader = Cursor::new(Vec::<u8>::new());
        let mut output = Vec::new();

        let cx = Cx::for_testing();
        let mut transport = HttpTransport::new(reader, &mut output);

        transport.close().unwrap();

        let response = JsonRpcResponse {
            jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
            result: None,
            error: None,
            id: None,
        };
        let result = transport.send(&cx, &JsonRpcMessage::Response(response));
        assert!(matches!(result, Err(TransportError::Closed)));
    }

    #[test]
    fn e2e_http_body_size_limit() {
        let config = HttpHandlerConfig {
            max_body_size: 100,
            ..Default::default()
        };
        let handler = HttpRequestHandler::with_config(config);

        let large_body = vec![b'x'; 200];
        let request = HttpRequest::new(HttpMethod::Post, "/mcp/v1")
            .with_header("Content-Type", "application/json")
            .with_body(large_body);

        // Oversized bodies are reported through `InvalidContentType`.
        let result = handler.parse_request(&request);
        assert!(matches!(result, Err(HttpError::InvalidContentType(_))));
    }
}