1use std::io::{BufReader, Read, Write};
42
43use asupersync::Cx;
44
45use crate::{Codec, Transport, TransportError};
46use fastmcp_protocol::{JsonRpcMessage, JsonRpcRequest, JsonRpcResponse};
47
/// Kind of a WebSocket frame: one variant per RFC 6455 opcode this module handles.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsFrameType {
    /// Continuation of a fragmented message (opcode `0x0`).
    Continuation,
    /// UTF-8 text data (opcode `0x1`); the transports send JSON-RPC messages in these.
    Text,
    /// Binary data (opcode `0x2`).
    Binary,
    /// Connection close (opcode `0x8`).
    Close,
    /// Ping (opcode `0x9`).
    Ping,
    /// Pong (opcode `0xA`).
    Pong,
}
64
65impl WsFrameType {
66 fn opcode(&self) -> u8 {
68 match self {
69 WsFrameType::Continuation => 0x00,
70 WsFrameType::Text => 0x01,
71 WsFrameType::Binary => 0x02,
72 WsFrameType::Close => 0x08,
73 WsFrameType::Ping => 0x09,
74 WsFrameType::Pong => 0x0A,
75 }
76 }
77
78 fn from_opcode(opcode: u8) -> Option<Self> {
80 match opcode {
81 0x00 => Some(WsFrameType::Continuation),
82 0x01 => Some(WsFrameType::Text),
83 0x02 => Some(WsFrameType::Binary),
84 0x08 => Some(WsFrameType::Close),
85 0x09 => Some(WsFrameType::Ping),
86 0x0A => Some(WsFrameType::Pong),
87 _ => None,
88 }
89 }
90}
91
/// A single WebSocket frame as seen by this module, with its payload unmasked.
#[derive(Debug, Clone)]
pub struct WsFrame {
    /// Frame kind, derived from the opcode.
    pub frame_type: WsFrameType,
    /// Payload bytes. `WsReader` unmasks them on read; `WsClientWriter` masks
    /// them only on the wire.
    pub payload: Vec<u8>,
    /// FIN bit: `true` when this is the final fragment of a message.
    pub fin: bool,
}
102
103impl WsFrame {
104 #[must_use]
106 pub fn text(payload: impl Into<String>) -> Self {
107 Self {
108 frame_type: WsFrameType::Text,
109 payload: payload.into().into_bytes(),
110 fin: true,
111 }
112 }
113
114 #[must_use]
116 pub fn close() -> Self {
117 Self {
118 frame_type: WsFrameType::Close,
119 payload: Vec::new(),
120 fin: true,
121 }
122 }
123
124 #[must_use]
126 pub fn ping(payload: Vec<u8>) -> Self {
127 Self {
128 frame_type: WsFrameType::Ping,
129 payload,
130 fin: true,
131 }
132 }
133
134 #[must_use]
136 pub fn pong(payload: Vec<u8>) -> Self {
137 Self {
138 frame_type: WsFrameType::Pong,
139 payload,
140 fin: true,
141 }
142 }
143
144 pub fn as_text(&self) -> Result<&str, std::str::Utf8Error> {
146 std::str::from_utf8(&self.payload)
147 }
148}
149
/// Reads WebSocket frames from a byte stream.
pub struct WsReader<R> {
    /// Buffered source; frame headers are read in small pieces.
    reader: BufReader<R>,
    /// Largest payload, in bytes, accepted in one frame (10 MiB as constructed).
    max_frame_size: usize,
    /// When `true`, unmasked frames are rejected (server side, RFC 6455 §5.1).
    require_mask: bool,
}
161
162impl<R: Read> WsReader<R> {
163 pub fn new(reader: R) -> Self {
167 Self::with_config(reader, true)
168 }
169
170 pub fn new_client(reader: R) -> Self {
174 Self::with_config(reader, false)
175 }
176
177 fn with_config(reader: R, require_mask: bool) -> Self {
179 Self {
180 reader: BufReader::new(reader),
181 max_frame_size: 10 * 1024 * 1024,
182 require_mask,
183 }
184 }
185
186 pub fn read_frame(&mut self) -> Result<WsFrame, TransportError> {
192 let mut header = [0u8; 2];
194 self.reader.read_exact(&mut header)?;
195
196 let fin = (header[0] & 0x80) != 0;
197 let rsv = header[0] & 0x70;
198 let opcode = header[0] & 0x0F;
199 let masked = (header[1] & 0x80) != 0;
200 let mut payload_len = (header[1] & 0x7F) as u64;
201
202 if rsv != 0 {
203 return Err(TransportError::Io(std::io::Error::new(
204 std::io::ErrorKind::InvalidData,
205 "WebSocket RSV bits set but no extensions are supported",
206 )));
207 }
208
209 if payload_len == 126 {
211 let mut ext = [0u8; 2];
212 self.reader.read_exact(&mut ext)?;
213 payload_len = u16::from_be_bytes(ext) as u64;
214 } else if payload_len == 127 {
215 let mut ext = [0u8; 8];
216 self.reader.read_exact(&mut ext)?;
217 payload_len = u64::from_be_bytes(ext);
218 }
219
220 let is_control = matches!(opcode, 0x08..=0x0A);
221 if is_control && !fin {
222 return Err(TransportError::Io(std::io::Error::new(
223 std::io::ErrorKind::InvalidData,
224 "Fragmented control frames are not allowed",
225 )));
226 }
227 if is_control && payload_len > 125 {
228 return Err(TransportError::Io(std::io::Error::new(
229 std::io::ErrorKind::InvalidData,
230 "Control frame payload too large",
231 )));
232 }
233
234 let max_frame_size = self.max_frame_size as u64;
235 if payload_len > max_frame_size {
236 return Err(TransportError::Io(std::io::Error::new(
237 std::io::ErrorKind::InvalidData,
238 format!("WebSocket frame too large: {payload_len} bytes"),
239 )));
240 }
241 if payload_len > usize::MAX as u64 {
242 return Err(TransportError::Io(std::io::Error::new(
243 std::io::ErrorKind::InvalidData,
244 "WebSocket frame length exceeds platform limits",
245 )));
246 }
247
248 if self.require_mask && !masked {
250 return Err(TransportError::Io(std::io::Error::new(
251 std::io::ErrorKind::InvalidData,
252 "Client frames MUST be masked per RFC 6455",
253 )));
254 }
255
256 let mask_key = if masked {
258 let mut key = [0u8; 4];
259 self.reader.read_exact(&mut key)?;
260 Some(key)
261 } else {
262 None
263 };
264
265 let mut payload = vec![0u8; payload_len as usize];
267 self.reader.read_exact(&mut payload)?;
268
269 if let Some(key) = mask_key {
271 for (i, byte) in payload.iter_mut().enumerate() {
272 *byte ^= key[i % 4];
273 }
274 }
275
276 let frame_type = WsFrameType::from_opcode(opcode).ok_or_else(|| {
277 TransportError::Io(std::io::Error::new(
278 std::io::ErrorKind::InvalidData,
279 format!("Unknown WebSocket opcode: {opcode}"),
280 ))
281 })?;
282
283 Ok(WsFrame {
284 frame_type,
285 payload,
286 fin,
287 })
288 }
289}
290
/// Writes unmasked WebSocket frames, as a server sends them to a client.
pub struct WsWriter<W> {
    /// Destination stream; `write_frame` flushes it after every frame.
    writer: W,
}
298
299impl<W: Write> WsWriter<W> {
300 pub fn new(writer: W) -> Self {
302 Self { writer }
303 }
304
305 pub fn write_frame(&mut self, frame: &WsFrame) -> Result<(), TransportError> {
311 let byte1 = if frame.fin { 0x80 } else { 0x00 } | frame.frame_type.opcode();
313
314 let payload_len = frame.payload.len();
316
317 if payload_len < 126 {
318 self.writer.write_all(&[byte1, payload_len as u8])?;
319 } else if payload_len < 65536 {
320 self.writer.write_all(&[byte1, 126])?;
321 self.writer.write_all(&(payload_len as u16).to_be_bytes())?;
322 } else {
323 self.writer.write_all(&[byte1, 127])?;
324 self.writer.write_all(&(payload_len as u64).to_be_bytes())?;
325 }
326
327 self.writer.write_all(&frame.payload)?;
329 self.writer.flush()?;
330
331 Ok(())
332 }
333}
334
335pub struct WsTransport<R, W> {
356 reader: WsReader<R>,
357 writer: WsWriter<W>,
358 codec: Codec,
359 fragment_buffer: Vec<u8>,
360 max_message_size: usize,
361}
362
363impl<R: Read, W: Write> WsTransport<R, W> {
364 pub fn new(reader: R, writer: W) -> Self {
366 Self {
367 reader: WsReader::new(reader),
368 writer: WsWriter::new(writer),
369 codec: Codec::new(),
370 fragment_buffer: Vec::new(),
371 max_message_size: 10 * 1024 * 1024,
372 }
373 }
374
375 pub fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
385 if cx.is_cancel_requested() {
387 return Err(TransportError::Cancelled);
388 }
389
390 let bytes = match message {
392 JsonRpcMessage::Request(req) => self.codec.encode_request(req)?,
393 JsonRpcMessage::Response(resp) => self.codec.encode_response(resp)?,
394 };
395
396 let text = String::from_utf8(bytes).map_err(|e| {
398 TransportError::Io(std::io::Error::new(
399 std::io::ErrorKind::InvalidData,
400 format!("Invalid UTF-8 in message: {e}"),
401 ))
402 })?;
403 let text = text.trim_end();
404
405 let frame = WsFrame::text(text);
407 self.writer.write_frame(&frame)?;
408
409 Ok(())
410 }
411
412 pub fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
425 loop {
426 if cx.is_cancel_requested() {
428 return Err(TransportError::Cancelled);
429 }
430
431 let frame = self.reader.read_frame()?;
433
434 match frame.frame_type {
435 WsFrameType::Text => {
436 if !self.fragment_buffer.is_empty() {
437 return Err(TransportError::Io(std::io::Error::new(
438 std::io::ErrorKind::InvalidData,
439 "Received Text frame while inside fragmented message",
440 )));
441 }
442
443 if frame.fin {
444 return self.decode_message(frame.payload);
446 }
447
448 let next_len = self
450 .fragment_buffer
451 .len()
452 .saturating_add(frame.payload.len());
453 if next_len > self.max_message_size {
454 self.fragment_buffer.clear();
455 return Err(TransportError::Io(std::io::Error::new(
456 std::io::ErrorKind::InvalidData,
457 "Fragmented message exceeds size limit",
458 )));
459 }
460 self.fragment_buffer.extend(frame.payload);
461 continue;
462 }
463 WsFrameType::Continuation => {
464 if self.fragment_buffer.is_empty() {
465 return Err(TransportError::Io(std::io::Error::new(
466 std::io::ErrorKind::InvalidData,
467 "Received Continuation frame without start frame",
468 )));
469 }
470
471 let next_len = self
472 .fragment_buffer
473 .len()
474 .saturating_add(frame.payload.len());
475 if next_len > self.max_message_size {
476 self.fragment_buffer.clear();
477 return Err(TransportError::Io(std::io::Error::new(
478 std::io::ErrorKind::InvalidData,
479 "Fragmented message exceeds size limit",
480 )));
481 }
482 self.fragment_buffer.extend(frame.payload);
483
484 if frame.fin {
485 let payload = std::mem::take(&mut self.fragment_buffer);
487 return self.decode_message(payload);
488 }
489
490 continue;
492 }
493 WsFrameType::Binary => {
494 if !self.fragment_buffer.is_empty() {
497 return Err(TransportError::Io(std::io::Error::new(
498 std::io::ErrorKind::InvalidData,
499 "Received Binary frame while inside fragmented message",
500 )));
501 }
502 continue;
504 }
505 WsFrameType::Close => {
506 return Err(TransportError::Closed);
507 }
508 WsFrameType::Ping => {
509 let pong = WsFrame::pong(frame.payload);
511 self.writer.write_frame(&pong)?;
512 continue;
513 }
514 WsFrameType::Pong => {
515 continue;
517 }
518 }
519 }
520 }
521
522 fn decode_message(&mut self, payload: Vec<u8>) -> Result<JsonRpcMessage, TransportError> {
524 let text = String::from_utf8(payload).map_err(|e| {
526 TransportError::Io(std::io::Error::new(
527 std::io::ErrorKind::InvalidData,
528 format!("Invalid UTF-8: {e}"),
529 ))
530 })?;
531
532 let mut input = text.as_bytes().to_vec();
534 input.push(b'\n');
535
536 let messages = self.codec.decode(&input)?;
537 if let Some(msg) = messages.into_iter().next() {
538 return Ok(msg);
539 }
540
541 Err(TransportError::Io(std::io::Error::new(
543 std::io::ErrorKind::InvalidData,
544 "Received empty message",
545 )))
546 }
547
548 pub fn close(&mut self) -> Result<(), TransportError> {
554 let frame = WsFrame::close();
555 self.writer.write_frame(&frame)?;
556 Ok(())
557 }
558
559 pub fn send_request(
563 &mut self,
564 cx: &Cx,
565 request: &JsonRpcRequest,
566 ) -> Result<(), TransportError> {
567 self.send(cx, &JsonRpcMessage::Request(request.clone()))
568 }
569
570 pub fn send_response(
574 &mut self,
575 cx: &Cx,
576 response: &JsonRpcResponse,
577 ) -> Result<(), TransportError> {
578 self.send(cx, &JsonRpcMessage::Response(response.clone()))
579 }
580
581 pub fn ping(&mut self) -> Result<(), TransportError> {
587 let frame = WsFrame::ping(Vec::new());
588 self.writer.write_frame(&frame)?;
589 Ok(())
590 }
591}
592
593impl<R: Read, W: Write> Transport for WsTransport<R, W> {
594 fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
595 WsTransport::send(self, cx, message)
596 }
597
598 fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
599 WsTransport::recv(self, cx)
600 }
601
602 fn close(&mut self) -> Result<(), TransportError> {
603 WsTransport::close(self)
604 }
605}
606
/// Writes masked WebSocket frames, as a client must send them to a server.
pub struct WsClientWriter<W> {
    /// Destination stream; `write_frame` flushes it after every frame.
    writer: W,
}
615
616impl<W: Write> WsClientWriter<W> {
617 pub fn new(writer: W) -> Self {
619 Self { writer }
620 }
621
622 fn generate_mask() -> Result<[u8; 4], TransportError> {
626 let mut mask = [0u8; 4];
627 getrandom::fill(&mut mask).map_err(|e| {
629 TransportError::Io(std::io::Error::new(
630 std::io::ErrorKind::Other,
631 format!("getrandom failed: {e}"),
632 ))
633 })?;
634 Ok(mask)
635 }
636
637 pub fn write_frame(&mut self, frame: &WsFrame) -> Result<(), TransportError> {
643 let byte1 = if frame.fin { 0x80 } else { 0x00 } | frame.frame_type.opcode();
645
646 let payload_len = frame.payload.len();
648 let mask_bit = 0x80u8;
649
650 if payload_len < 126 {
651 self.writer
652 .write_all(&[byte1, mask_bit | payload_len as u8])?;
653 } else if payload_len < 65536 {
654 self.writer.write_all(&[byte1, mask_bit | 126])?;
655 self.writer.write_all(&(payload_len as u16).to_be_bytes())?;
656 } else {
657 self.writer.write_all(&[byte1, mask_bit | 127])?;
658 self.writer.write_all(&(payload_len as u64).to_be_bytes())?;
659 }
660
661 let mask = Self::generate_mask()?;
663 self.writer.write_all(&mask)?;
664
665 let masked: Vec<u8> = frame
667 .payload
668 .iter()
669 .enumerate()
670 .map(|(i, b)| b ^ mask[i % 4])
671 .collect();
672 self.writer.write_all(&masked)?;
673 self.writer.flush()?;
674
675 Ok(())
676 }
677}
678
679pub struct WsClientTransport<R, W> {
684 reader: WsReader<R>,
685 writer: WsClientWriter<W>,
686 codec: Codec,
687 fragment_buffer: Vec<u8>,
688 max_message_size: usize,
689}
690
691impl<R: Read, W: Write> WsClientTransport<R, W> {
692 pub fn new(reader: R, writer: W) -> Self {
694 Self {
695 reader: WsReader::new_client(reader),
697 writer: WsClientWriter::new(writer),
698 codec: Codec::new(),
699 fragment_buffer: Vec::new(),
700 max_message_size: 10 * 1024 * 1024,
701 }
702 }
703
704 pub fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
716 if cx.is_cancel_requested() {
718 return Err(TransportError::Cancelled);
719 }
720
721 let bytes = match message {
723 JsonRpcMessage::Request(req) => self.codec.encode_request(req)?,
724 JsonRpcMessage::Response(resp) => self.codec.encode_response(resp)?,
725 };
726
727 let text = String::from_utf8(bytes).map_err(|e| {
729 TransportError::Io(std::io::Error::new(
730 std::io::ErrorKind::InvalidData,
731 format!("Invalid UTF-8 in message: {e}"),
732 ))
733 })?;
734 let text = text.trim_end();
735
736 let frame = WsFrame::text(text);
738 self.writer.write_frame(&frame)?;
739
740 Ok(())
741 }
742
743 pub fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
755 loop {
756 if cx.is_cancel_requested() {
758 return Err(TransportError::Cancelled);
759 }
760
761 let frame = self.reader.read_frame()?;
763
764 match frame.frame_type {
765 WsFrameType::Text => {
766 if !self.fragment_buffer.is_empty() {
767 return Err(TransportError::Io(std::io::Error::new(
768 std::io::ErrorKind::InvalidData,
769 "Received Text frame while inside fragmented message",
770 )));
771 }
772
773 if frame.fin {
774 return self.decode_message(frame.payload);
776 }
777
778 let next_len = self
780 .fragment_buffer
781 .len()
782 .saturating_add(frame.payload.len());
783 if next_len > self.max_message_size {
784 self.fragment_buffer.clear();
785 return Err(TransportError::Io(std::io::Error::new(
786 std::io::ErrorKind::InvalidData,
787 "Fragmented message exceeds size limit",
788 )));
789 }
790 self.fragment_buffer.extend(frame.payload);
791 continue;
792 }
793 WsFrameType::Continuation => {
794 if self.fragment_buffer.is_empty() {
795 return Err(TransportError::Io(std::io::Error::new(
796 std::io::ErrorKind::InvalidData,
797 "Received Continuation frame without start frame",
798 )));
799 }
800
801 let next_len = self
802 .fragment_buffer
803 .len()
804 .saturating_add(frame.payload.len());
805 if next_len > self.max_message_size {
806 self.fragment_buffer.clear();
807 return Err(TransportError::Io(std::io::Error::new(
808 std::io::ErrorKind::InvalidData,
809 "Fragmented message exceeds size limit",
810 )));
811 }
812 self.fragment_buffer.extend(frame.payload);
813
814 if frame.fin {
815 let payload = std::mem::take(&mut self.fragment_buffer);
817 return self.decode_message(payload);
818 }
819
820 continue;
822 }
823 WsFrameType::Binary => {
824 if !self.fragment_buffer.is_empty() {
827 return Err(TransportError::Io(std::io::Error::new(
828 std::io::ErrorKind::InvalidData,
829 "Received Binary frame while inside fragmented message",
830 )));
831 }
832 continue;
834 }
835 WsFrameType::Close => {
836 return Err(TransportError::Closed);
837 }
838 WsFrameType::Ping => {
839 let pong = WsFrame::pong(frame.payload);
841 self.writer.write_frame(&pong)?;
842 continue;
843 }
844 WsFrameType::Pong => {
845 continue;
846 }
847 }
848 }
849 }
850
851 fn decode_message(&mut self, payload: Vec<u8>) -> Result<JsonRpcMessage, TransportError> {
853 let text = String::from_utf8(payload).map_err(|e| {
854 TransportError::Io(std::io::Error::new(
855 std::io::ErrorKind::InvalidData,
856 format!("Invalid UTF-8: {e}"),
857 ))
858 })?;
859
860 let mut input = text.as_bytes().to_vec();
861 input.push(b'\n');
862
863 let messages = self.codec.decode(&input)?;
864 if let Some(msg) = messages.into_iter().next() {
865 return Ok(msg);
866 }
867
868 Err(TransportError::Io(std::io::Error::new(
869 std::io::ErrorKind::InvalidData,
870 "Received empty message",
871 )))
872 }
873
874 pub fn close(&mut self) -> Result<(), TransportError> {
880 let frame = WsFrame::close();
881 self.writer.write_frame(&frame)?;
882 Ok(())
883 }
884}
885
886impl<R: Read, W: Write> Transport for WsClientTransport<R, W> {
887 fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
888 WsClientTransport::send(self, cx, message)
889 }
890
891 fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
892 WsClientTransport::recv(self, cx)
893 }
894
895 fn close(&mut self) -> Result<(), TransportError> {
896 WsClientTransport::close(self)
897 }
898}
899
900#[cfg(test)]
901mod tests {
902 use super::*;
903 use std::io::Cursor;
904
905 #[test]
906 fn test_frame_type_opcode_roundtrip() {
907 for frame_type in [
908 WsFrameType::Text,
909 WsFrameType::Binary,
910 WsFrameType::Close,
911 WsFrameType::Ping,
912 WsFrameType::Pong,
913 ] {
914 let opcode = frame_type.opcode();
915 let parsed = WsFrameType::from_opcode(opcode);
916 assert_eq!(parsed, Some(frame_type));
917 }
918 }
919
920 #[test]
921 fn test_frame_text() {
922 let frame = WsFrame::text("hello");
923 assert_eq!(frame.frame_type, WsFrameType::Text);
924 assert_eq!(frame.as_text().unwrap(), "hello");
925 assert!(frame.fin);
926 }
927
928 #[test]
929 fn test_frame_close() {
930 let frame = WsFrame::close();
931 assert_eq!(frame.frame_type, WsFrameType::Close);
932 assert!(frame.payload.is_empty());
933 assert!(frame.fin);
934 }
935
936 #[test]
937 fn test_frame_ping_pong() {
938 let ping = WsFrame::ping(vec![1, 2, 3]);
939 assert_eq!(ping.frame_type, WsFrameType::Ping);
940 assert_eq!(ping.payload, vec![1, 2, 3]);
941
942 let pong = WsFrame::pong(vec![1, 2, 3]);
943 assert_eq!(pong.frame_type, WsFrameType::Pong);
944 assert_eq!(pong.payload, vec![1, 2, 3]);
945 }
946
947 #[test]
948 fn test_write_read_small_frame() {
949 let mut buffer = Vec::new();
950
951 {
953 let mut writer = WsWriter::new(&mut buffer);
954 let frame = WsFrame::text("hello");
955 writer.write_frame(&frame).unwrap();
956 }
957
958 let mut reader = WsReader::new_client(Cursor::new(buffer));
960 let frame = reader.read_frame().unwrap();
961
962 assert_eq!(frame.frame_type, WsFrameType::Text);
963 assert_eq!(frame.as_text().unwrap(), "hello");
964 assert!(frame.fin);
965 }
966
967 #[test]
968 fn test_write_read_medium_frame() {
969 let payload = "x".repeat(200);
971 let mut buffer = Vec::new();
972
973 {
974 let mut writer = WsWriter::new(&mut buffer);
975 let frame = WsFrame::text(&payload);
976 writer.write_frame(&frame).unwrap();
977 }
978
979 let mut reader = WsReader::new_client(Cursor::new(buffer));
981 let frame = reader.read_frame().unwrap();
982
983 assert_eq!(frame.as_text().unwrap(), payload);
984 }
985
986 #[test]
987 fn test_write_read_large_frame() {
988 let payload = "x".repeat(70000);
990 let mut buffer = Vec::new();
991
992 {
993 let mut writer = WsWriter::new(&mut buffer);
994 let frame = WsFrame::text(&payload);
995 writer.write_frame(&frame).unwrap();
996 }
997
998 let mut reader = WsReader::new_client(Cursor::new(buffer));
1000 let frame = reader.read_frame().unwrap();
1001
1002 assert_eq!(frame.as_text().unwrap(), payload);
1003 }
1004
1005 #[test]
1006 fn test_client_writer_masks_frames() {
1007 let mut buffer = Vec::new();
1008
1009 {
1010 let mut writer = WsClientWriter::new(&mut buffer);
1011 let frame = WsFrame::text("hi");
1012 writer.write_frame(&frame).unwrap();
1013 }
1014
1015 assert!(buffer.len() >= 2);
1017 assert_ne!(buffer[1] & 0x80, 0, "Mask bit should be set for client");
1018 }
1019
1020 #[test]
1021 fn test_read_masked_frame() {
1022 let payload = b"test";
1024 let mask = [0x12, 0x34, 0x56, 0x78];
1025 let masked_payload: Vec<u8> = payload
1026 .iter()
1027 .enumerate()
1028 .map(|(i, b)| b ^ mask[i % 4])
1029 .collect();
1030
1031 let mut buffer = Vec::new();
1032 buffer.push(0x81); buffer.push(0x80 | payload.len() as u8); buffer.extend_from_slice(&mask);
1035 buffer.extend_from_slice(&masked_payload);
1036
1037 let mut reader = WsReader::new(Cursor::new(buffer));
1038 let frame = reader.read_frame().unwrap();
1039
1040 assert_eq!(frame.as_text().unwrap(), "test");
1041 }
1042
1043 #[test]
1044 fn test_reader_rejects_oversized_frame() {
1045 let mask = [0x12, 0x34, 0x56, 0x78];
1047 let payload = b"hey";
1048 let masked: Vec<u8> = payload
1049 .iter()
1050 .enumerate()
1051 .map(|(i, b)| b ^ mask[i % 4])
1052 .collect();
1053
1054 let mut buffer = Vec::new();
1055 buffer.push(0x81); buffer.push(0x80 | 0x03); buffer.extend_from_slice(&mask);
1058 buffer.extend_from_slice(&masked);
1059
1060 let mut reader = WsReader::new(Cursor::new(buffer));
1061 reader.max_frame_size = 2;
1062
1063 let err = reader.read_frame().unwrap_err();
1064 assert!(matches!(
1065 err,
1066 TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
1067 ));
1068 }
1069
1070 #[test]
1071 fn test_reader_rejects_control_frame_over_125() {
1072 let mut buffer = Vec::new();
1073 buffer.push(0x89); buffer.push(0x80 | 126); buffer.extend_from_slice(&126u16.to_be_bytes());
1076 buffer.extend_from_slice(&[0, 0, 0, 0]); let mut reader = WsReader::new(Cursor::new(buffer));
1079 let err = reader.read_frame().unwrap_err();
1080 assert!(matches!(
1081 err,
1082 TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
1083 ));
1084 }
1085
1086 #[test]
1087 fn test_reader_rejects_fragmented_control_frame() {
1088 let mut buffer = Vec::new();
1089 buffer.push(0x09); buffer.push(0x80); buffer.extend_from_slice(&[0, 0, 0, 0]); let mut reader = WsReader::new(Cursor::new(buffer));
1094 let err = reader.read_frame().unwrap_err();
1095 assert!(matches!(
1096 err,
1097 TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
1098 ));
1099 }
1100
1101 #[test]
1102 fn test_reader_rejects_rsv_bits() {
1103 let mut buffer = Vec::new();
1104 buffer.push(0xC1); buffer.push(0x80); buffer.extend_from_slice(&[0, 0, 0, 0]); let mut reader = WsReader::new(Cursor::new(buffer));
1109 let err = reader.read_frame().unwrap_err();
1110 assert!(matches!(
1111 err,
1112 TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
1113 ));
1114 }
1115
1116 #[test]
1117 fn test_server_rejects_unmasked_client_frames() {
1118 let mut buffer = Vec::new();
1120 buffer.push(0x81); buffer.push(0x05); buffer.extend_from_slice(b"hello");
1123
1124 let mut reader = WsReader::new(Cursor::new(buffer));
1126 let err = reader.read_frame().unwrap_err();
1127 assert!(matches!(
1128 err,
1129 TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
1130 ));
1131 }
1132
1133 #[test]
1134 fn test_client_accepts_unmasked_server_frames() {
1135 let mut buffer = Vec::new();
1137 buffer.push(0x81); buffer.push(0x05); buffer.extend_from_slice(b"hello");
1140
1141 let mut reader = WsReader::new_client(Cursor::new(buffer));
1143 let frame = reader.read_frame().unwrap();
1144 assert_eq!(frame.as_text().unwrap(), "hello");
1145 }
1146
1147 fn build_masked_frame(opcode: u8, fin: bool, payload: &[u8]) -> Vec<u8> {
1150 let mask = [0x12, 0x34, 0x56, 0x78];
1151 let masked: Vec<u8> = payload
1152 .iter()
1153 .enumerate()
1154 .map(|(i, b)| b ^ mask[i % 4])
1155 .collect();
1156
1157 let mut frame = Vec::new();
1158 let byte1 = if fin { 0x80 } else { 0x00 } | opcode;
1159 frame.push(byte1);
1160
1161 let payload_len = payload.len();
1163 if payload_len < 126 {
1164 frame.push(0x80 | payload_len as u8);
1165 } else if payload_len < 65536 {
1166 frame.push(0x80 | 126);
1167 frame.extend_from_slice(&(payload_len as u16).to_be_bytes());
1168 } else {
1169 frame.push(0x80 | 127);
1170 frame.extend_from_slice(&(payload_len as u64).to_be_bytes());
1171 }
1172
1173 frame.extend_from_slice(&mask);
1174 frame.extend_from_slice(&masked);
1175 frame
1176 }
1177
1178 #[test]
1179 fn test_fragmented_message_size_limit() {
1180 let mut buffer = Vec::new();
1182 buffer.extend(build_masked_frame(0x01, false, b"hello"));
1184 buffer.extend(build_masked_frame(0x00, true, b"world"));
1186
1187 let cx = Cx::for_testing();
1188 let writer: Vec<u8> = Vec::new();
1189 let mut transport = WsTransport::new(Cursor::new(buffer), writer);
1190 transport.max_message_size = 8;
1191
1192 let err = transport.recv(&cx).unwrap_err();
1193 assert!(matches!(
1194 err,
1195 TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
1196 ));
1197 }
1198
1199 #[test]
1200 fn test_rejects_interleaved_binary_during_fragmentation() {
1201 let mut buffer = Vec::new();
1204 buffer.extend(build_masked_frame(0x01, false, b"hello"));
1206 buffer.extend(build_masked_frame(0x02, true, b"bad"));
1208
1209 let cx = Cx::for_testing();
1210 let writer: Vec<u8> = Vec::new();
1211 let mut transport = WsTransport::new(Cursor::new(buffer), writer);
1212
1213 let err = transport.recv(&cx).unwrap_err();
1214 assert!(matches!(
1215 err,
1216 TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
1217 ));
1218 }
1219
1220 #[test]
1221 fn test_transport_roundtrip() {
1222 use fastmcp_protocol::RequestId;
1223
1224 let mut write_buf = Vec::new();
1227
1228 {
1230 let cx = Cx::for_testing();
1231 let reader: &[u8] = &[];
1232 let mut transport = WsTransport::new(reader, &mut write_buf);
1233
1234 let request = JsonRpcRequest {
1235 jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
1236 id: Some(RequestId::Number(1)),
1237 method: "test".to_string(),
1238 params: None,
1239 };
1240
1241 transport.send_request(&cx, &request).unwrap();
1242 }
1243
1244 {
1246 let cx = Cx::for_testing();
1247 let writer: Vec<u8> = Vec::new();
1248 let mut transport = WsClientTransport::new(Cursor::new(write_buf), writer);
1249
1250 let msg = transport.recv(&cx).unwrap();
1251 assert!(
1252 matches!(msg, JsonRpcMessage::Request(_)),
1253 "Expected request"
1254 );
1255 if let JsonRpcMessage::Request(req) = msg {
1256 assert_eq!(req.method, "test");
1257 assert_eq!(req.id, Some(RequestId::Number(1)));
1258 }
1259 }
1260 }
1261
1262 #[test]
1263 fn test_close_frame_returns_closed_error() {
1264 let mut buffer = Vec::new();
1266 buffer.push(0x88); buffer.push(0x80); buffer.extend_from_slice(&[0u8; 4]); let cx = Cx::for_testing();
1271 let writer: Vec<u8> = Vec::new();
1272 let mut transport = WsTransport::new(Cursor::new(buffer), writer);
1273
1274 let result = transport.recv(&cx);
1275 assert!(matches!(result, Err(TransportError::Closed)));
1276 }
1277
1278 #[test]
1279 fn test_ping_auto_pong() {
1280 let mut buffer = Vec::new();
1282
1283 buffer.extend(build_masked_frame(0x09, true, b"ping"));
1285
1286 let text = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
1288 buffer.extend(build_masked_frame(0x01, true, text.as_bytes()));
1289
1290 let mut response_buf = Vec::new();
1291
1292 let cx = Cx::for_testing();
1293 let mut transport = WsTransport::new(Cursor::new(buffer), &mut response_buf);
1294
1295 let msg = transport.recv(&cx).unwrap();
1297 assert!(
1298 matches!(msg, JsonRpcMessage::Request(_)),
1299 "Expected request"
1300 );
1301 if let JsonRpcMessage::Request(req) = msg {
1302 assert_eq!(req.method, "test");
1303 }
1304
1305 assert!(!response_buf.is_empty());
1307 assert_eq!(response_buf[0] & 0x0F, 0x0A); }
1309
1310 #[test]
1315 fn e2e_ws_bidirectional_message_flow() {
1316 use fastmcp_protocol::RequestId;
1317
1318 let mut request_buffer = Vec::new();
1322
1323 let req1 = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
1325 let req2 = r#"{"jsonrpc":"2.0","method":"tools/list","id":2}"#;
1326 let req3 = r#"{"jsonrpc":"2.0","method":"tools/call","params":{"name":"test"},"id":3}"#;
1327
1328 request_buffer.extend(build_masked_frame(0x01, true, req1.as_bytes()));
1329 request_buffer.extend(build_masked_frame(0x01, true, req2.as_bytes()));
1330 request_buffer.extend(build_masked_frame(0x01, true, req3.as_bytes()));
1331
1332 let mut response_buffer = Vec::new();
1333 let cx = Cx::for_testing();
1334
1335 {
1336 let mut transport = WsTransport::new(Cursor::new(request_buffer), &mut response_buffer);
1337
1338 for expected_id in 1..=3 {
1340 let msg = transport.recv(&cx).unwrap();
1341 assert!(
1342 matches!(msg, JsonRpcMessage::Request(_)),
1343 "Expected request"
1344 );
1345 let JsonRpcMessage::Request(req) = msg else {
1346 return;
1347 };
1348
1349 assert_eq!(req.id, Some(RequestId::Number(expected_id)));
1350
1351 let response = JsonRpcResponse {
1353 jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
1354 result: Some(serde_json::json!({"ok": true})),
1355 error: None,
1356 id: req.id,
1357 };
1358 transport.send_response(&cx, &response).unwrap();
1359 }
1360 }
1361
1362 assert!(!response_buffer.is_empty());
1364 #[allow(clippy::naive_bytecount)]
1366 let frame_count = response_buffer
1367 .iter()
1368 .filter(|&&b| b == 0x81) .count();
1370 assert_eq!(frame_count, 3, "Expected 3 response frames");
1371 }
1372
1373 #[test]
1374 fn e2e_ws_fragmented_message_assembly() {
1375 let full_msg =
1377 r#"{"jsonrpc":"2.0","method":"test","params":{"data":"hello world"},"id":1}"#;
1378 let mid = full_msg.len() / 2;
1379
1380 let mut buffer = Vec::new();
1381 buffer.extend(build_masked_frame(0x01, false, &full_msg.as_bytes()[..mid]));
1383 buffer.extend(build_masked_frame(0x00, true, &full_msg.as_bytes()[mid..]));
1385
1386 let cx = Cx::for_testing();
1387 let writer: Vec<u8> = Vec::new();
1388 let mut transport = WsTransport::new(Cursor::new(buffer), writer);
1389
1390 let msg = transport.recv(&cx).unwrap();
1391 assert!(
1392 matches!(msg, JsonRpcMessage::Request(_)),
1393 "Expected request"
1394 );
1395 let JsonRpcMessage::Request(req) = msg else {
1396 return;
1397 };
1398 assert_eq!(req.method, "test");
1399 let params = req.params.unwrap();
1400 assert_eq!(params.get("data").unwrap(), "hello world");
1401 }
1402
1403 #[test]
1404 fn e2e_ws_interleaved_ping_during_operation() {
1405 let mut buffer = Vec::new();
1407
1408 buffer.extend(build_masked_frame(
1410 0x01,
1411 true,
1412 r#"{"jsonrpc":"2.0","method":"msg1","id":1}"#.as_bytes(),
1413 ));
1414 buffer.extend(build_masked_frame(0x09, true, b"keepalive"));
1416 buffer.extend(build_masked_frame(
1418 0x01,
1419 true,
1420 r#"{"jsonrpc":"2.0","method":"msg2","id":2}"#.as_bytes(),
1421 ));
1422 buffer.extend(build_masked_frame(0x09, true, b"alive"));
1424 buffer.extend(build_masked_frame(
1426 0x01,
1427 true,
1428 r#"{"jsonrpc":"2.0","method":"msg3","id":3}"#.as_bytes(),
1429 ));
1430
1431 let mut response_buffer = Vec::new();
1432 let cx = Cx::for_testing();
1433 let mut transport = WsTransport::new(Cursor::new(buffer), &mut response_buffer);
1434
1435 for i in 1..=3 {
1437 let msg = transport.recv(&cx).unwrap();
1438 assert!(
1439 matches!(msg, JsonRpcMessage::Request(_)),
1440 "Expected request"
1441 );
1442 let JsonRpcMessage::Request(req) = msg else {
1443 return;
1444 };
1445 assert_eq!(req.method, format!("msg{i}"));
1446 }
1447
1448 assert!(
1452 !response_buffer.is_empty(),
1453 "Expected pong responses to be written"
1454 );
1455 }
1456
#[test]
fn e2e_ws_graceful_close() {
    // A request followed by a Close frame: the request is delivered first,
    // then the next `recv` reports the peer-initiated close.
    let mut buffer = Vec::new();
    buffer.extend(build_masked_frame(
        0x01,
        true,
        r#"{"jsonrpc":"2.0","method":"last","id":1}"#.as_bytes(),
    ));
    buffer.extend(build_masked_frame(0x08, true, &[]));

    let mut response_buffer = Vec::new();
    let cx = Cx::for_testing();
    let mut transport = WsTransport::new(Cursor::new(buffer), &mut response_buffer);

    let msg = transport.recv(&cx).unwrap();
    assert!(
        matches!(msg, JsonRpcMessage::Request(_)),
        "Expected the pending request before the close"
    );

    let result = transport.recv(&cx);
    assert!(
        matches!(result, Err(TransportError::Closed)),
        "Expected Closed after the close frame"
    );
}
1481
#[test]
fn e2e_ws_cancellation_respected() {
    // A complete, valid frame is waiting, yet a context that is already
    // cancelled must make `recv` bail out with `Cancelled`.
    let payload = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
    let frame_bytes = build_masked_frame(0x01, true, payload.as_bytes());

    let cx = Cx::for_testing();
    cx.set_cancel_requested(true);

    let mut transport = WsTransport::new(Cursor::new(frame_bytes), Vec::<u8>::new());

    let outcome = transport.recv(&cx);
    assert!(
        matches!(outcome, Err(TransportError::Cancelled)),
        "recv must honour a cancelled context"
    );
}
1500
#[test]
fn e2e_ws_send_cancellation_respected() {
    // With cancellation requested up front, `send_request` must refuse and
    // leave the output stream untouched.
    let cx = Cx::for_testing();
    cx.set_cancel_requested(true);

    let empty_input: &[u8] = &[];
    let mut sink: Vec<u8> = Vec::new();
    let mut transport = WsTransport::new(empty_input, &mut sink);

    let outgoing = JsonRpcRequest::new("test", None, 1i64);
    let outcome = transport.send_request(&cx, &outgoing);
    assert!(matches!(outcome, Err(TransportError::Cancelled)));

    assert!(sink.is_empty(), "no bytes may be written after cancellation");
}
1517
#[test]
fn e2e_ws_unicode_in_messages() {
    // Multi-byte UTF-8 (CJK, emoji, accented Latin) must survive frame decoding
    // and JSON parsing byte-for-byte.
    let unicode_msg =
        r#"{"jsonrpc":"2.0","method":"test","params":{"text":"Hello 世界 👋 éèê"},"id":1}"#;
    let buffer = build_masked_frame(0x01, true, unicode_msg.as_bytes());

    let cx = Cx::for_testing();
    let writer: Vec<u8> = Vec::new();
    let mut transport = WsTransport::new(Cursor::new(buffer), writer);

    let msg = transport.recv(&cx).unwrap();
    // A single pattern check; a non-request fails the test instead of returning.
    let JsonRpcMessage::Request(req) = msg else {
        panic!("Expected request");
    };
    let params = req.params.expect("request should carry params");
    let text = params
        .get("text")
        .expect("params should contain `text`")
        .as_str()
        .expect("`text` should be a JSON string");
    // Exact equality catches truncation or mojibake that `contains` checks could miss.
    assert_eq!(text, "Hello 世界 👋 éèê");
}
1543
#[test]
fn e2e_ws_client_server_full_flow() {
    use fastmcp_protocol::RequestId;

    // Client side: encode an `initialize` request as a (masked) text frame.
    let mut client_to_server = Vec::new();
    {
        let mut writer = WsClientWriter::new(&mut client_to_server);
        let request = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        writer.write_frame(&WsFrame::text(request)).unwrap();
    }

    // Server side: decode the request and answer it over the same transport.
    let mut server_response = Vec::new();
    {
        let cx = Cx::for_testing();
        // The client bytes are not needed afterwards, so move them in.
        let mut transport = WsTransport::new(Cursor::new(client_to_server), &mut server_response);

        let msg = transport.recv(&cx).unwrap();
        // The original `if let` silently skipped the reply on a mismatch.
        let JsonRpcMessage::Request(req) = msg else {
            panic!("Server expected a request");
        };
        assert_eq!(req.method, "initialize");

        let response = JsonRpcResponse {
            jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
            result: Some(serde_json::json!({"capabilities": {}})),
            error: None,
            id: Some(RequestId::Number(1)),
        };
        transport.send_response(&cx, &response).unwrap();
    }

    // Client side again: decode the server's (unmasked) reply.
    {
        let cx = Cx::for_testing();
        let mut transport =
            WsClientTransport::new(Cursor::new(server_response), Vec::<u8>::new());

        let msg = transport.recv(&cx).unwrap();
        let JsonRpcMessage::Response(resp) = msg else {
            panic!("Expected response");
        };
        assert_eq!(resp.id, Some(RequestId::Number(1)));
        assert!(resp.result.is_some());
    }
}
1598
#[test]
fn ws_continuation_opcode_roundtrip() {
    // Encoding then decoding the Continuation frame type must give it back unchanged.
    let original = WsFrameType::Continuation;
    let encoded = original.opcode();
    let decoded = WsFrameType::from_opcode(encoded);
    assert_eq!(decoded, Some(original));
}
1604
#[test]
fn ws_unknown_opcode_returns_none() {
    // RFC 6455 §5.2 reserves 0x3–0x7 (non-control) and 0xB–0xF (control).
    // Every reserved opcode must be rejected, not only a couple of samples.
    for opcode in (0x03..=0x07).chain(0x0B..=0x0F) {
        assert_eq!(
            WsFrameType::from_opcode(opcode),
            None,
            "reserved opcode {opcode:#04x} must not decode to a frame type"
        );
    }
}
1610
#[test]
fn ws_frame_as_text_non_utf8_returns_error() {
    // 0xFF and 0xFE never occur in well-formed UTF-8, so text decoding must fail.
    let invalid_utf8 = vec![0xFF, 0xFE];
    let frame = WsFrame {
        fin: true,
        frame_type: WsFrameType::Text,
        payload: invalid_utf8,
    };
    let decoded = frame.as_text();
    assert!(decoded.is_err());
}
1620
#[test]
fn ws_transport_close_sends_close_frame() {
    let empty_input: &[u8] = &[];
    let mut sink: Vec<u8> = Vec::new();
    let mut transport = WsTransport::new(empty_input, &mut sink);
    transport.close().unwrap();

    // Header bytes: FIN | Close opcode (0x88), then unmasked zero length (0x00).
    // `get(..2)` is `None` when fewer than two bytes were written.
    assert_eq!(sink.get(..2), Some(&[0x88_u8, 0x00][..]));
}
1633
#[test]
fn ws_transport_ping_sends_ping_frame() {
    let empty_input: &[u8] = &[];
    let mut sink: Vec<u8> = Vec::new();
    let mut transport = WsTransport::new(empty_input, &mut sink);
    transport.ping().unwrap();

    // Header bytes: FIN | Ping opcode (0x89), then unmasked zero length (0x00).
    // `get(..2)` is `None` when fewer than two bytes were written.
    assert_eq!(sink.get(..2), Some(&[0x89_u8, 0x00][..]));
}
1646
#[test]
fn ws_client_transport_send_cancelled() {
    // The client transport must also refuse to send once cancellation is
    // requested, and it must not write any bytes.
    let cx = Cx::for_testing();
    cx.set_cancel_requested(true);

    let empty_input: &[u8] = &[];
    let mut sink: Vec<u8> = Vec::new();
    let mut transport = WsClientTransport::new(empty_input, &mut sink);

    let message = JsonRpcMessage::Request(JsonRpcRequest::new("test", None, 1i64));
    let outcome = transport.send(&cx, &message);

    assert!(matches!(outcome, Err(TransportError::Cancelled)));
    assert!(sink.is_empty());
}
1661
#[test]
fn ws_binary_frame_skipped_outside_fragmentation() {
    // A standalone binary frame is ignored; the text frame after it is delivered.
    let request_json = r#"{"jsonrpc":"2.0","method":"after_binary","id":1}"#;
    let mut stream = Vec::new();
    for frame in [
        build_masked_frame(0x02, true, b"binary-data"),
        build_masked_frame(0x01, true, request_json.as_bytes()),
    ] {
        stream.extend(frame);
    }

    let cx = Cx::for_testing();
    let mut transport = WsTransport::new(Cursor::new(stream), Vec::<u8>::new());

    match transport.recv(&cx).unwrap() {
        JsonRpcMessage::Request(req) => assert_eq!(req.method, "after_binary"),
        _ => panic!("expected request"),
    }
}
1683
#[test]
fn ws_pong_frame_skipped() {
    // An unsolicited pong is discarded; the request that follows it is delivered.
    let request_json = r#"{"jsonrpc":"2.0","method":"after_pong","id":2}"#;
    let mut stream = Vec::new();
    for frame in [
        build_masked_frame(0x0A, true, b"pong-payload"),
        build_masked_frame(0x01, true, request_json.as_bytes()),
    ] {
        stream.extend(frame);
    }

    let cx = Cx::for_testing();
    let mut transport = WsTransport::new(Cursor::new(stream), Vec::<u8>::new());

    match transport.recv(&cx).unwrap() {
        JsonRpcMessage::Request(req) => assert_eq!(req.method, "after_pong"),
        _ => panic!("expected request"),
    }
}
1705}