Skip to main content

fastmcp_transport/
websocket.rs

1//! WebSocket transport for MCP.
2//!
3//! This module provides WebSocket-based transport for bidirectional MCP
4//! communication. Unlike SSE (server-push only), WebSocket allows both
5//! client and server to send messages at any time.
6//!
7//! # Wire Format
8//!
9//! MCP over WebSocket uses:
10//! - Text frames for JSON-RPC messages (one message per frame)
11//! - Standard JSON-RPC request/response format
12//! - Optional ping/pong for keep-alive
13//!
14//! # Architecture
15//!
16//! This implementation provides low-level WebSocket message framing.
17//! It does NOT include HTTP upgrade handling - that should be done by
18//! your HTTP server (e.g., hyper, axum, warp) before handing off the
19//! upgraded connection to this transport.
20//!
21//! # Example
22//!
23//! ```ignore
24//! use fastmcp_transport::websocket::{WsTransport, WsFrame};
25//!
26//! // After HTTP upgrade, you have a bidirectional byte stream
//! let mut transport = WsTransport::new(reader, writer);
28//!
29//! // Receive a message
30//! let msg = transport.recv(&cx)?;
31//!
32//! // Send a response
33//! transport.send(&cx, &response)?;
34//! ```
35//!
36//! # Cancel-Safety
37//!
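//! `send` and `recv` check `cx.is_cancel_requested()` before blocking I/O
//! (`recv` re-checks before every frame it reads) and return
//! `TransportError::Cancelled` when cancellation has been requested. An I/O
//! call that is already blocked is not interrupted.
//! The transport integrates with asupersync's structured concurrency.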
40
41use std::io::{BufReader, Read, Write};
42
43use asupersync::Cx;
44
45use crate::{Codec, Transport, TransportError};
46use fastmcp_protocol::{JsonRpcMessage, JsonRpcRequest, JsonRpcResponse};
47
/// Kinds of WebSocket frame, one per opcode defined by RFC 6455 Section 5.2.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsFrameType {
    /// Continues a fragmented message started by a Text or Binary frame.
    Continuation,
    /// UTF-8 text data; MCP carries its JSON-RPC messages in these.
    Text,
    /// Arbitrary binary data.
    Binary,
    /// Connection close request.
    Close,
    /// Keep-alive probe.
    Ping,
    /// Keep-alive reply to a Ping.
    Pong,
}

impl WsFrameType {
    /// Every variant, used to invert `opcode` in `from_opcode`.
    const ALL: [WsFrameType; 6] = [
        WsFrameType::Continuation,
        WsFrameType::Text,
        WsFrameType::Binary,
        WsFrameType::Close,
        WsFrameType::Ping,
        WsFrameType::Pong,
    ];

    /// Wire opcode (low nibble of the first header byte) for this frame type.
    fn opcode(&self) -> u8 {
        match *self {
            Self::Continuation => 0x0,
            Self::Text => 0x1,
            Self::Binary => 0x2,
            Self::Close => 0x8,
            Self::Ping => 0x9,
            Self::Pong => 0xA,
        }
    }

    /// Maps a wire opcode back to its frame type.
    ///
    /// Returns `None` for opcodes this module does not define (reserved or
    /// unknown values).
    fn from_opcode(opcode: u8) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|kind| kind.opcode() == opcode)
    }
}
91
92/// A WebSocket frame.
93#[derive(Debug, Clone)]
94pub struct WsFrame {
95    /// Frame type.
96    pub frame_type: WsFrameType,
97    /// Frame payload.
98    pub payload: Vec<u8>,
99    /// Whether this is the final frame in a message.
100    pub fin: bool,
101}
102
103impl WsFrame {
104    /// Creates a new text frame with the given payload.
105    #[must_use]
106    pub fn text(payload: impl Into<String>) -> Self {
107        Self {
108            frame_type: WsFrameType::Text,
109            payload: payload.into().into_bytes(),
110            fin: true,
111        }
112    }
113
114    /// Creates a new close frame.
115    #[must_use]
116    pub fn close() -> Self {
117        Self {
118            frame_type: WsFrameType::Close,
119            payload: Vec::new(),
120            fin: true,
121        }
122    }
123
124    /// Creates a new ping frame.
125    #[must_use]
126    pub fn ping(payload: Vec<u8>) -> Self {
127        Self {
128            frame_type: WsFrameType::Ping,
129            payload,
130            fin: true,
131        }
132    }
133
134    /// Creates a new pong frame.
135    #[must_use]
136    pub fn pong(payload: Vec<u8>) -> Self {
137        Self {
138            frame_type: WsFrameType::Pong,
139            payload,
140            fin: true,
141        }
142    }
143
144    /// Returns the payload as a UTF-8 string if this is a text frame.
145    pub fn as_text(&self) -> Result<&str, std::str::Utf8Error> {
146        std::str::from_utf8(&self.payload)
147    }
148}
149
150/// WebSocket frame reader.
151///
152/// Reads WebSocket frames from an underlying byte stream.
153/// Handles frame parsing according to RFC 6455.
154pub struct WsReader<R> {
155    reader: BufReader<R>,
156    max_frame_size: usize,
157    /// Whether to require masking (true for server-side, false for client-side).
158    /// Per RFC 6455: servers MUST reject unmasked client frames.
159    require_mask: bool,
160}
161
162impl<R: Read> WsReader<R> {
163    /// Creates a new WebSocket reader for server-side use.
164    ///
165    /// Per RFC 6455, servers MUST reject unmasked frames from clients.
166    pub fn new(reader: R) -> Self {
167        Self::with_config(reader, true)
168    }
169
170    /// Creates a new WebSocket reader for client-side use.
171    ///
172    /// Clients receive unmasked frames from servers per RFC 6455.
173    pub fn new_client(reader: R) -> Self {
174        Self::with_config(reader, false)
175    }
176
177    /// Creates a new WebSocket reader with explicit mask requirement.
178    fn with_config(reader: R, require_mask: bool) -> Self {
179        Self {
180            reader: BufReader::new(reader),
181            max_frame_size: 10 * 1024 * 1024,
182            require_mask,
183        }
184    }
185
186    /// Reads the next WebSocket frame.
187    ///
188    /// # Errors
189    ///
190    /// Returns an error if the frame is malformed or I/O fails.
191    pub fn read_frame(&mut self) -> Result<WsFrame, TransportError> {
192        // Read first two bytes (header)
193        let mut header = [0u8; 2];
194        self.reader.read_exact(&mut header)?;
195
196        let fin = (header[0] & 0x80) != 0;
197        let rsv = header[0] & 0x70;
198        let opcode = header[0] & 0x0F;
199        let masked = (header[1] & 0x80) != 0;
200        let mut payload_len = (header[1] & 0x7F) as u64;
201
202        if rsv != 0 {
203            return Err(TransportError::Io(std::io::Error::new(
204                std::io::ErrorKind::InvalidData,
205                "WebSocket RSV bits set but no extensions are supported",
206            )));
207        }
208
209        // Extended payload length
210        if payload_len == 126 {
211            let mut ext = [0u8; 2];
212            self.reader.read_exact(&mut ext)?;
213            payload_len = u16::from_be_bytes(ext) as u64;
214        } else if payload_len == 127 {
215            let mut ext = [0u8; 8];
216            self.reader.read_exact(&mut ext)?;
217            payload_len = u64::from_be_bytes(ext);
218        }
219
220        let is_control = matches!(opcode, 0x08..=0x0A);
221        if is_control && !fin {
222            return Err(TransportError::Io(std::io::Error::new(
223                std::io::ErrorKind::InvalidData,
224                "Fragmented control frames are not allowed",
225            )));
226        }
227        if is_control && payload_len > 125 {
228            return Err(TransportError::Io(std::io::Error::new(
229                std::io::ErrorKind::InvalidData,
230                "Control frame payload too large",
231            )));
232        }
233
234        let max_frame_size = self.max_frame_size as u64;
235        if payload_len > max_frame_size {
236            return Err(TransportError::Io(std::io::Error::new(
237                std::io::ErrorKind::InvalidData,
238                format!("WebSocket frame too large: {payload_len} bytes"),
239            )));
240        }
241        if payload_len > usize::MAX as u64 {
242            return Err(TransportError::Io(std::io::Error::new(
243                std::io::ErrorKind::InvalidData,
244                "WebSocket frame length exceeds platform limits",
245            )));
246        }
247
248        // RFC 6455 Section 5.1: Server MUST close connection if client frame is unmasked
249        if self.require_mask && !masked {
250            return Err(TransportError::Io(std::io::Error::new(
251                std::io::ErrorKind::InvalidData,
252                "Client frames MUST be masked per RFC 6455",
253            )));
254        }
255
256        // Read masking key if present (client -> server frames are masked)
257        let mask_key = if masked {
258            let mut key = [0u8; 4];
259            self.reader.read_exact(&mut key)?;
260            Some(key)
261        } else {
262            None
263        };
264
265        // Read payload
266        let mut payload = vec![0u8; payload_len as usize];
267        self.reader.read_exact(&mut payload)?;
268
269        // Unmask if necessary
270        if let Some(key) = mask_key {
271            for (i, byte) in payload.iter_mut().enumerate() {
272                *byte ^= key[i % 4];
273            }
274        }
275
276        let frame_type = WsFrameType::from_opcode(opcode).ok_or_else(|| {
277            TransportError::Io(std::io::Error::new(
278                std::io::ErrorKind::InvalidData,
279                format!("Unknown WebSocket opcode: {opcode}"),
280            ))
281        })?;
282
283        Ok(WsFrame {
284            frame_type,
285            payload,
286            fin,
287        })
288    }
289}
290
291/// WebSocket frame writer.
292///
293/// Writes WebSocket frames to an underlying byte stream.
294/// Server frames are unmasked per RFC 6455.
295pub struct WsWriter<W> {
296    writer: W,
297}
298
299impl<W: Write> WsWriter<W> {
300    /// Creates a new WebSocket writer.
301    pub fn new(writer: W) -> Self {
302        Self { writer }
303    }
304
305    /// Writes a WebSocket frame.
306    ///
307    /// # Errors
308    ///
309    /// Returns an error if I/O fails.
310    pub fn write_frame(&mut self, frame: &WsFrame) -> Result<(), TransportError> {
311        // First byte: FIN + opcode
312        let byte1 = if frame.fin { 0x80 } else { 0x00 } | frame.frame_type.opcode();
313
314        // Second byte: mask bit (0 for server) + payload length
315        let payload_len = frame.payload.len();
316
317        if payload_len < 126 {
318            self.writer.write_all(&[byte1, payload_len as u8])?;
319        } else if payload_len < 65536 {
320            self.writer.write_all(&[byte1, 126])?;
321            self.writer.write_all(&(payload_len as u16).to_be_bytes())?;
322        } else {
323            self.writer.write_all(&[byte1, 127])?;
324            self.writer.write_all(&(payload_len as u64).to_be_bytes())?;
325        }
326
327        // Write payload (unmasked for server -> client)
328        self.writer.write_all(&frame.payload)?;
329        self.writer.flush()?;
330
331        Ok(())
332    }
333}
334
335/// WebSocket transport for MCP.
336///
337/// Provides bidirectional message passing over WebSocket.
338/// Messages are JSON-RPC encoded as text frames.
339///
340/// # Example
341///
342/// ```ignore
343/// let transport = WsTransport::new(tcp_read, tcp_write);
344///
345/// // Receive a message
346/// match transport.recv(&cx)? {
347///     JsonRpcMessage::Request(req) => {
348///         // Handle request and send response
349///         let response = handle_request(req);
350///         transport.send(&cx, &JsonRpcMessage::Response(response))?;
351///     }
352///     _ => {}
353/// }
354/// ```
355pub struct WsTransport<R, W> {
356    reader: WsReader<R>,
357    writer: WsWriter<W>,
358    codec: Codec,
359    fragment_buffer: Vec<u8>,
360    max_message_size: usize,
361}
362
363impl<R: Read, W: Write> WsTransport<R, W> {
364    /// Creates a new WebSocket transport.
365    pub fn new(reader: R, writer: W) -> Self {
366        Self {
367            reader: WsReader::new(reader),
368            writer: WsWriter::new(writer),
369            codec: Codec::new(),
370            fragment_buffer: Vec::new(),
371            max_message_size: 10 * 1024 * 1024,
372        }
373    }
374
375    /// Sends a JSON-RPC message over the WebSocket.
376    ///
377    /// # Cancel-Safety
378    ///
379    /// Checks for cancellation before sending.
380    ///
381    /// # Errors
382    ///
383    /// Returns an error if cancelled, the connection is closed, or I/O fails.
384    pub fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
385        // Check cancellation
386        if cx.is_cancel_requested() {
387            return Err(TransportError::Cancelled);
388        }
389
390        // Encode message
391        let bytes = match message {
392            JsonRpcMessage::Request(req) => self.codec.encode_request(req)?,
393            JsonRpcMessage::Response(resp) => self.codec.encode_response(resp)?,
394        };
395
396        // Convert to string (strip trailing newline from NDJSON format)
397        let text = String::from_utf8(bytes).map_err(|e| {
398            TransportError::Io(std::io::Error::new(
399                std::io::ErrorKind::InvalidData,
400                format!("Invalid UTF-8 in message: {e}"),
401            ))
402        })?;
403        let text = text.trim_end();
404
405        // Send as text frame
406        let frame = WsFrame::text(text);
407        self.writer.write_frame(&frame)?;
408
409        Ok(())
410    }
411
412    /// Receives the next JSON-RPC message from the WebSocket.
413    ///
414    /// Handles control frames (ping/pong) automatically.
415    /// Handles message fragmentation (Continuation frames).
416    ///
417    /// # Cancel-Safety
418    ///
419    /// Checks for cancellation before blocking.
420    ///
421    /// # Errors
422    ///
423    /// Returns an error if cancelled, the connection is closed, or parsing fails.
424    pub fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
425        loop {
426            // Check cancellation
427            if cx.is_cancel_requested() {
428                return Err(TransportError::Cancelled);
429            }
430
431            // Read next frame
432            let frame = self.reader.read_frame()?;
433
434            match frame.frame_type {
435                WsFrameType::Text => {
436                    if !self.fragment_buffer.is_empty() {
437                        return Err(TransportError::Io(std::io::Error::new(
438                            std::io::ErrorKind::InvalidData,
439                            "Received Text frame while inside fragmented message",
440                        )));
441                    }
442
443                    if frame.fin {
444                        // Complete message in single frame
445                        return self.decode_message(frame.payload);
446                    }
447
448                    // Start of fragmented message
449                    let next_len = self
450                        .fragment_buffer
451                        .len()
452                        .saturating_add(frame.payload.len());
453                    if next_len > self.max_message_size {
454                        self.fragment_buffer.clear();
455                        return Err(TransportError::Io(std::io::Error::new(
456                            std::io::ErrorKind::InvalidData,
457                            "Fragmented message exceeds size limit",
458                        )));
459                    }
460                    self.fragment_buffer.extend(frame.payload);
461                    continue;
462                }
463                WsFrameType::Continuation => {
464                    if self.fragment_buffer.is_empty() {
465                        return Err(TransportError::Io(std::io::Error::new(
466                            std::io::ErrorKind::InvalidData,
467                            "Received Continuation frame without start frame",
468                        )));
469                    }
470
471                    let next_len = self
472                        .fragment_buffer
473                        .len()
474                        .saturating_add(frame.payload.len());
475                    if next_len > self.max_message_size {
476                        self.fragment_buffer.clear();
477                        return Err(TransportError::Io(std::io::Error::new(
478                            std::io::ErrorKind::InvalidData,
479                            "Fragmented message exceeds size limit",
480                        )));
481                    }
482                    self.fragment_buffer.extend(frame.payload);
483
484                    if frame.fin {
485                        // End of fragmented message
486                        let payload = std::mem::take(&mut self.fragment_buffer);
487                        return self.decode_message(payload);
488                    }
489
490                    // More fragments to come
491                    continue;
492                }
493                WsFrameType::Binary => {
494                    // Per RFC 6455 Section 5.4, data frames MUST NOT be interleaved
495                    // during fragmentation. Reject if we're inside a fragmented message.
496                    if !self.fragment_buffer.is_empty() {
497                        return Err(TransportError::Io(std::io::Error::new(
498                            std::io::ErrorKind::InvalidData,
499                            "Received Binary frame while inside fragmented message",
500                        )));
501                    }
502                    // Binary frames not used by MCP, skip otherwise
503                    continue;
504                }
505                WsFrameType::Close => {
506                    return Err(TransportError::Closed);
507                }
508                WsFrameType::Ping => {
509                    // Auto-respond with pong
510                    let pong = WsFrame::pong(frame.payload);
511                    self.writer.write_frame(&pong)?;
512                    continue;
513                }
514                WsFrameType::Pong => {
515                    // Ignore pong frames
516                    continue;
517                }
518            }
519        }
520    }
521
522    /// Decodes a payload into a JSON-RPC message.
523    fn decode_message(&mut self, payload: Vec<u8>) -> Result<JsonRpcMessage, TransportError> {
524        // Parse JSON-RPC message
525        let text = String::from_utf8(payload).map_err(|e| {
526            TransportError::Io(std::io::Error::new(
527                std::io::ErrorKind::InvalidData,
528                format!("Invalid UTF-8: {e}"),
529            ))
530        })?;
531
532        // Add newline for codec and decode
533        let mut input = text.as_bytes().to_vec();
534        input.push(b'\n');
535
536        let messages = self.codec.decode(&input)?;
537        if let Some(msg) = messages.into_iter().next() {
538            return Ok(msg);
539        }
540
541        // This shouldn't happen for a complete message unless it was empty or just whitespace
542        Err(TransportError::Io(std::io::Error::new(
543            std::io::ErrorKind::InvalidData,
544            "Received empty message",
545        )))
546    }
547
548    /// Sends a close frame and shuts down the connection.
549    ///
550    /// # Errors
551    ///
552    /// Returns an error if I/O fails.
553    pub fn close(&mut self) -> Result<(), TransportError> {
554        let frame = WsFrame::close();
555        self.writer.write_frame(&frame)?;
556        Ok(())
557    }
558
559    /// Sends a request through this transport.
560    ///
561    /// Convenience method that wraps a request in a message.
562    pub fn send_request(
563        &mut self,
564        cx: &Cx,
565        request: &JsonRpcRequest,
566    ) -> Result<(), TransportError> {
567        self.send(cx, &JsonRpcMessage::Request(request.clone()))
568    }
569
570    /// Sends a response through this transport.
571    ///
572    /// Convenience method that wraps a response in a message.
573    pub fn send_response(
574        &mut self,
575        cx: &Cx,
576        response: &JsonRpcResponse,
577    ) -> Result<(), TransportError> {
578        self.send(cx, &JsonRpcMessage::Response(response.clone()))
579    }
580
581    /// Sends a ping frame.
582    ///
583    /// # Errors
584    ///
585    /// Returns an error if I/O fails.
586    pub fn ping(&mut self) -> Result<(), TransportError> {
587        let frame = WsFrame::ping(Vec::new());
588        self.writer.write_frame(&frame)?;
589        Ok(())
590    }
591}
592
593impl<R: Read, W: Write> Transport for WsTransport<R, W> {
594    fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
595        WsTransport::send(self, cx, message)
596    }
597
598    fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
599        WsTransport::recv(self, cx)
600    }
601
602    fn close(&mut self) -> Result<(), TransportError> {
603        WsTransport::close(self)
604    }
605}
606
607/// Client-side WebSocket mask generation.
608///
609/// Clients must mask frames per RFC 6455. This struct provides
610/// frame writing with proper masking using cryptographically secure
611/// random mask keys.
612pub struct WsClientWriter<W> {
613    writer: W,
614}
615
616impl<W: Write> WsClientWriter<W> {
617    /// Creates a new client WebSocket writer.
618    pub fn new(writer: W) -> Self {
619        Self { writer }
620    }
621
622    /// Generates a cryptographically secure mask key.
623    ///
624    /// RFC 6455 Section 5.3: The masking key MUST be unpredictable.
625    fn generate_mask() -> Result<[u8; 4], TransportError> {
626        let mut mask = [0u8; 4];
627        // Use CSPRNG for unpredictable mask keys per RFC 6455
628        getrandom::fill(&mut mask).map_err(|e| {
629            TransportError::Io(std::io::Error::new(
630                std::io::ErrorKind::Other,
631                format!("getrandom failed: {e}"),
632            ))
633        })?;
634        Ok(mask)
635    }
636
637    /// Writes a WebSocket frame with client masking.
638    ///
639    /// # Errors
640    ///
641    /// Returns an error if I/O fails.
642    pub fn write_frame(&mut self, frame: &WsFrame) -> Result<(), TransportError> {
643        // First byte: FIN + opcode
644        let byte1 = if frame.fin { 0x80 } else { 0x00 } | frame.frame_type.opcode();
645
646        // Second byte: mask bit (1 for client) + payload length
647        let payload_len = frame.payload.len();
648        let mask_bit = 0x80u8;
649
650        if payload_len < 126 {
651            self.writer
652                .write_all(&[byte1, mask_bit | payload_len as u8])?;
653        } else if payload_len < 65536 {
654            self.writer.write_all(&[byte1, mask_bit | 126])?;
655            self.writer.write_all(&(payload_len as u16).to_be_bytes())?;
656        } else {
657            self.writer.write_all(&[byte1, mask_bit | 127])?;
658            self.writer.write_all(&(payload_len as u64).to_be_bytes())?;
659        }
660
661        // Write mask key (cryptographically random per RFC 6455)
662        let mask = Self::generate_mask()?;
663        self.writer.write_all(&mask)?;
664
665        // Write masked payload
666        let masked: Vec<u8> = frame
667            .payload
668            .iter()
669            .enumerate()
670            .map(|(i, b)| b ^ mask[i % 4])
671            .collect();
672        self.writer.write_all(&masked)?;
673        self.writer.flush()?;
674
675        Ok(())
676    }
677}
678
679/// Client-side WebSocket transport.
680///
681/// Similar to `WsTransport` but masks outgoing frames as required
682/// for client-to-server communication per RFC 6455.
683pub struct WsClientTransport<R, W> {
684    reader: WsReader<R>,
685    writer: WsClientWriter<W>,
686    codec: Codec,
687    fragment_buffer: Vec<u8>,
688    max_message_size: usize,
689}
690
691impl<R: Read, W: Write> WsClientTransport<R, W> {
692    /// Creates a new client WebSocket transport.
693    pub fn new(reader: R, writer: W) -> Self {
694        Self {
695            // Client receives unmasked frames from server
696            reader: WsReader::new_client(reader),
697            writer: WsClientWriter::new(writer),
698            codec: Codec::new(),
699            fragment_buffer: Vec::new(),
700            max_message_size: 10 * 1024 * 1024,
701        }
702    }
703
704    /// Sends a JSON-RPC message over the WebSocket.
705    ///
706    /// The frame will be masked as required for clients.
707    ///
708    /// # Cancel-Safety
709    ///
710    /// Checks for cancellation before sending.
711    ///
712    /// # Errors
713    ///
714    /// Returns an error if cancelled, the connection is closed, or I/O fails.
715    pub fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
716        // Check cancellation
717        if cx.is_cancel_requested() {
718            return Err(TransportError::Cancelled);
719        }
720
721        // Encode message
722        let bytes = match message {
723            JsonRpcMessage::Request(req) => self.codec.encode_request(req)?,
724            JsonRpcMessage::Response(resp) => self.codec.encode_response(resp)?,
725        };
726
727        // Convert to string (strip trailing newline from NDJSON format)
728        let text = String::from_utf8(bytes).map_err(|e| {
729            TransportError::Io(std::io::Error::new(
730                std::io::ErrorKind::InvalidData,
731                format!("Invalid UTF-8 in message: {e}"),
732            ))
733        })?;
734        let text = text.trim_end();
735
736        // Send as text frame (masked)
737        let frame = WsFrame::text(text);
738        self.writer.write_frame(&frame)?;
739
740        Ok(())
741    }
742
743    /// Receives the next JSON-RPC message from the WebSocket.
744    ///
745    /// Handles control frames (ping/pong) automatically.
746    ///
747    /// # Cancel-Safety
748    ///
749    /// Checks for cancellation before blocking.
750    ///
751    /// # Errors
752    ///
753    /// Returns an error if cancelled, the connection is closed, or parsing fails.
754    pub fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
755        loop {
756            // Check cancellation
757            if cx.is_cancel_requested() {
758                return Err(TransportError::Cancelled);
759            }
760
761            // Read next frame
762            let frame = self.reader.read_frame()?;
763
764            match frame.frame_type {
765                WsFrameType::Text => {
766                    if !self.fragment_buffer.is_empty() {
767                        return Err(TransportError::Io(std::io::Error::new(
768                            std::io::ErrorKind::InvalidData,
769                            "Received Text frame while inside fragmented message",
770                        )));
771                    }
772
773                    if frame.fin {
774                        // Complete message in single frame
775                        return self.decode_message(frame.payload);
776                    }
777
778                    // Start of fragmented message
779                    let next_len = self
780                        .fragment_buffer
781                        .len()
782                        .saturating_add(frame.payload.len());
783                    if next_len > self.max_message_size {
784                        self.fragment_buffer.clear();
785                        return Err(TransportError::Io(std::io::Error::new(
786                            std::io::ErrorKind::InvalidData,
787                            "Fragmented message exceeds size limit",
788                        )));
789                    }
790                    self.fragment_buffer.extend(frame.payload);
791                    continue;
792                }
793                WsFrameType::Continuation => {
794                    if self.fragment_buffer.is_empty() {
795                        return Err(TransportError::Io(std::io::Error::new(
796                            std::io::ErrorKind::InvalidData,
797                            "Received Continuation frame without start frame",
798                        )));
799                    }
800
801                    let next_len = self
802                        .fragment_buffer
803                        .len()
804                        .saturating_add(frame.payload.len());
805                    if next_len > self.max_message_size {
806                        self.fragment_buffer.clear();
807                        return Err(TransportError::Io(std::io::Error::new(
808                            std::io::ErrorKind::InvalidData,
809                            "Fragmented message exceeds size limit",
810                        )));
811                    }
812                    self.fragment_buffer.extend(frame.payload);
813
814                    if frame.fin {
815                        // End of fragmented message
816                        let payload = std::mem::take(&mut self.fragment_buffer);
817                        return self.decode_message(payload);
818                    }
819
820                    // More fragments to come
821                    continue;
822                }
823                WsFrameType::Binary => {
824                    // Per RFC 6455 Section 5.4, data frames MUST NOT be interleaved
825                    // during fragmentation. Reject if we're inside a fragmented message.
826                    if !self.fragment_buffer.is_empty() {
827                        return Err(TransportError::Io(std::io::Error::new(
828                            std::io::ErrorKind::InvalidData,
829                            "Received Binary frame while inside fragmented message",
830                        )));
831                    }
832                    // Binary frames not used by MCP, skip otherwise
833                    continue;
834                }
835                WsFrameType::Close => {
836                    return Err(TransportError::Closed);
837                }
838                WsFrameType::Ping => {
839                    // Respond with pong (masked)
840                    let pong = WsFrame::pong(frame.payload);
841                    self.writer.write_frame(&pong)?;
842                    continue;
843                }
844                WsFrameType::Pong => {
845                    continue;
846                }
847            }
848        }
849    }
850
851    /// Decodes a payload into a JSON-RPC message.
852    fn decode_message(&mut self, payload: Vec<u8>) -> Result<JsonRpcMessage, TransportError> {
853        let text = String::from_utf8(payload).map_err(|e| {
854            TransportError::Io(std::io::Error::new(
855                std::io::ErrorKind::InvalidData,
856                format!("Invalid UTF-8: {e}"),
857            ))
858        })?;
859
860        let mut input = text.as_bytes().to_vec();
861        input.push(b'\n');
862
863        let messages = self.codec.decode(&input)?;
864        if let Some(msg) = messages.into_iter().next() {
865            return Ok(msg);
866        }
867
868        Err(TransportError::Io(std::io::Error::new(
869            std::io::ErrorKind::InvalidData,
870            "Received empty message",
871        )))
872    }
873
874    /// Sends a close frame.
875    ///
876    /// # Errors
877    ///
878    /// Returns an error if I/O fails.
879    pub fn close(&mut self) -> Result<(), TransportError> {
880        let frame = WsFrame::close();
881        self.writer.write_frame(&frame)?;
882        Ok(())
883    }
884}
885
886impl<R: Read, W: Write> Transport for WsClientTransport<R, W> {
887    fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
888        WsClientTransport::send(self, cx, message)
889    }
890
891    fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
892        WsClientTransport::recv(self, cx)
893    }
894
895    fn close(&mut self) -> Result<(), TransportError> {
896        WsClientTransport::close(self)
897    }
898}
899
/// Unit and end-to-end tests for WebSocket framing and the server/client
/// transports.
///
/// Frames travelling client -> server are built masked (see
/// `build_masked_frame`) because the server-side `WsReader::new` rejects
/// unmasked frames; frames travelling server -> client are written unmasked
/// by `WsWriter` and read with `WsReader::new_client`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_frame_type_opcode_roundtrip() {
        // Every defined frame type must survive opcode() -> from_opcode().
        for frame_type in [
            WsFrameType::Continuation,
            WsFrameType::Text,
            WsFrameType::Binary,
            WsFrameType::Close,
            WsFrameType::Ping,
            WsFrameType::Pong,
        ] {
            let opcode = frame_type.opcode();
            let parsed = WsFrameType::from_opcode(opcode);
            assert_eq!(parsed, Some(frame_type));
        }
    }

    #[test]
    fn test_frame_text() {
        let frame = WsFrame::text("hello");
        assert_eq!(frame.frame_type, WsFrameType::Text);
        assert_eq!(frame.as_text().unwrap(), "hello");
        assert!(frame.fin);
    }

    #[test]
    fn test_frame_close() {
        let frame = WsFrame::close();
        assert_eq!(frame.frame_type, WsFrameType::Close);
        assert!(frame.payload.is_empty());
        assert!(frame.fin);
    }

    #[test]
    fn test_frame_ping_pong() {
        let ping = WsFrame::ping(vec![1, 2, 3]);
        assert_eq!(ping.frame_type, WsFrameType::Ping);
        assert_eq!(ping.payload, vec![1, 2, 3]);

        let pong = WsFrame::pong(vec![1, 2, 3]);
        assert_eq!(pong.frame_type, WsFrameType::Pong);
        assert_eq!(pong.payload, vec![1, 2, 3]);
    }

    #[test]
    fn test_write_read_small_frame() {
        let mut buffer = Vec::new();

        // Write frame (server-side, unmasked)
        {
            let mut writer = WsWriter::new(&mut buffer);
            let frame = WsFrame::text("hello");
            writer.write_frame(&frame).unwrap();
        }

        // Read frame back (client-side, accepts unmasked)
        let mut reader = WsReader::new_client(Cursor::new(buffer));
        let frame = reader.read_frame().unwrap();

        assert_eq!(frame.frame_type, WsFrameType::Text);
        assert_eq!(frame.as_text().unwrap(), "hello");
        assert!(frame.fin);
    }

    #[test]
    fn test_write_read_medium_frame() {
        // 200 bytes - uses extended length (126)
        let payload = "x".repeat(200);
        let mut buffer = Vec::new();

        {
            let mut writer = WsWriter::new(&mut buffer);
            let frame = WsFrame::text(&payload);
            writer.write_frame(&frame).unwrap();
        }

        // Client-side reader accepts unmasked frames
        let mut reader = WsReader::new_client(Cursor::new(buffer));
        let frame = reader.read_frame().unwrap();

        assert_eq!(frame.as_text().unwrap(), payload);
    }

    #[test]
    fn test_write_read_large_frame() {
        // 70000 bytes - uses extended length (127)
        let payload = "x".repeat(70000);
        let mut buffer = Vec::new();

        {
            let mut writer = WsWriter::new(&mut buffer);
            let frame = WsFrame::text(&payload);
            writer.write_frame(&frame).unwrap();
        }

        // Client-side reader accepts unmasked frames
        let mut reader = WsReader::new_client(Cursor::new(buffer));
        let frame = reader.read_frame().unwrap();

        assert_eq!(frame.as_text().unwrap(), payload);
    }

    #[test]
    fn test_client_writer_masks_frames() {
        let mut buffer = Vec::new();

        {
            let mut writer = WsClientWriter::new(&mut buffer);
            let frame = WsFrame::text("hi");
            writer.write_frame(&frame).unwrap();
        }

        // Check that mask bit is set (second byte has 0x80 bit) and that the
        // 7-bit length field still carries the real payload length.
        assert!(buffer.len() >= 2);
        assert_ne!(buffer[1] & 0x80, 0, "Mask bit should be set for client");
        assert_eq!(buffer[1] & 0x7F, 2, "Payload length should be 2");

        // The masked frame must decode back to the original text server-side.
        let mut reader = WsReader::new(Cursor::new(buffer));
        let frame = reader.read_frame().unwrap();
        assert_eq!(frame.as_text().unwrap(), "hi");
    }

    #[test]
    fn test_read_masked_frame() {
        // Build a masked frame manually
        let payload = b"test";
        let mask = [0x12, 0x34, 0x56, 0x78];
        let masked_payload: Vec<u8> = payload
            .iter()
            .enumerate()
            .map(|(i, b)| b ^ mask[i % 4])
            .collect();

        let mut buffer = Vec::new();
        buffer.push(0x81); // FIN + Text opcode
        buffer.push(0x80 | payload.len() as u8); // Mask bit + length
        buffer.extend_from_slice(&mask);
        buffer.extend_from_slice(&masked_payload);

        let mut reader = WsReader::new(Cursor::new(buffer));
        let frame = reader.read_frame().unwrap();

        assert_eq!(frame.as_text().unwrap(), "test");
    }

    #[test]
    fn test_reader_rejects_oversized_frame() {
        // Build a masked frame for server-side testing
        let mask = [0x12, 0x34, 0x56, 0x78];
        let payload = b"hey";
        let masked: Vec<u8> = payload
            .iter()
            .enumerate()
            .map(|(i, b)| b ^ mask[i % 4])
            .collect();

        let mut buffer = Vec::new();
        buffer.push(0x81); // FIN + Text opcode
        buffer.push(0x80 | 0x03); // Mask bit + 3-byte payload
        buffer.extend_from_slice(&mask);
        buffer.extend_from_slice(&masked);

        let mut reader = WsReader::new(Cursor::new(buffer));
        reader.max_frame_size = 2;

        let err = reader.read_frame().unwrap_err();
        assert!(matches!(
            err,
            TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
        ));
    }

    #[test]
    fn test_reader_rejects_control_frame_over_125() {
        let mut buffer = Vec::new();
        buffer.push(0x89); // FIN + Ping opcode
        buffer.push(0x80 | 126); // Mask bit + Extended length (not allowed for control frames)
        buffer.extend_from_slice(&126u16.to_be_bytes());
        buffer.extend_from_slice(&[0, 0, 0, 0]); // Mask key

        let mut reader = WsReader::new(Cursor::new(buffer));
        let err = reader.read_frame().unwrap_err();
        assert!(matches!(
            err,
            TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
        ));
    }

    #[test]
    fn test_reader_rejects_fragmented_control_frame() {
        let mut buffer = Vec::new();
        buffer.push(0x09); // FIN=0 + Ping opcode
        buffer.push(0x80); // Mask bit + Zero payload
        buffer.extend_from_slice(&[0, 0, 0, 0]); // Mask key

        let mut reader = WsReader::new(Cursor::new(buffer));
        let err = reader.read_frame().unwrap_err();
        assert!(matches!(
            err,
            TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
        ));
    }

    #[test]
    fn test_reader_rejects_rsv_bits() {
        let mut buffer = Vec::new();
        buffer.push(0xC1); // FIN + RSV1 + Text opcode
        buffer.push(0x80); // Mask bit set + Zero payload
        buffer.extend_from_slice(&[0, 0, 0, 0]); // Mask key

        let mut reader = WsReader::new(Cursor::new(buffer));
        let err = reader.read_frame().unwrap_err();
        assert!(matches!(
            err,
            TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
        ));
    }

    #[test]
    fn test_server_rejects_unmasked_client_frames() {
        // RFC 6455 Section 5.1: Server MUST reject unmasked frames
        let mut buffer = Vec::new();
        buffer.push(0x81); // FIN + Text opcode
        buffer.push(0x05); // NO mask bit + 5-byte payload
        buffer.extend_from_slice(b"hello");

        // Server-side reader (requires masking)
        let mut reader = WsReader::new(Cursor::new(buffer));
        let err = reader.read_frame().unwrap_err();
        assert!(matches!(
            err,
            TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
        ));
    }

    #[test]
    fn test_client_accepts_unmasked_server_frames() {
        // Clients receive unmasked frames from servers
        let mut buffer = Vec::new();
        buffer.push(0x81); // FIN + Text opcode
        buffer.push(0x05); // NO mask bit + 5-byte payload
        buffer.extend_from_slice(b"hello");

        // Client-side reader (does not require masking)
        let mut reader = WsReader::new_client(Cursor::new(buffer));
        let frame = reader.read_frame().unwrap();
        assert_eq!(frame.as_text().unwrap(), "hello");
    }

    /// Builds a masked WebSocket frame, as a client would send it, for
    /// exercising server-side code.
    ///
    /// Uses a fixed masking key and picks the 7-bit, 16-bit or 64-bit length
    /// encoding based on `payload.len()`.
    fn build_masked_frame(opcode: u8, fin: bool, payload: &[u8]) -> Vec<u8> {
        let mask = [0x12, 0x34, 0x56, 0x78];
        let masked: Vec<u8> = payload
            .iter()
            .enumerate()
            .map(|(i, b)| b ^ mask[i % 4])
            .collect();

        let mut frame = Vec::new();
        let byte1 = if fin { 0x80 } else { 0x00 } | opcode;
        frame.push(byte1);

        // Mask bit + payload length (including extended length encodings).
        let payload_len = payload.len();
        if payload_len < 126 {
            frame.push(0x80 | payload_len as u8);
        } else if payload_len < 65536 {
            frame.push(0x80 | 126);
            frame.extend_from_slice(&(payload_len as u16).to_be_bytes());
        } else {
            frame.push(0x80 | 127);
            frame.extend_from_slice(&(payload_len as u64).to_be_bytes());
        }

        frame.extend_from_slice(&mask);
        frame.extend_from_slice(&masked);
        frame
    }

    #[test]
    fn test_fragmented_message_size_limit() {
        // Build masked frames (client -> server)
        let mut buffer = Vec::new();
        // Text frame start (FIN=0, opcode=Text)
        buffer.extend(build_masked_frame(0x01, false, b"hello"));
        // Continuation frame end (FIN=1, opcode=Continuation)
        buffer.extend(build_masked_frame(0x00, true, b"world"));

        let cx = Cx::for_testing();
        let writer: Vec<u8> = Vec::new();
        let mut transport = WsTransport::new(Cursor::new(buffer), writer);
        transport.max_message_size = 8;

        let err = transport.recv(&cx).unwrap_err();
        assert!(matches!(
            err,
            TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
        ));
    }

    #[test]
    fn test_rejects_interleaved_binary_during_fragmentation() {
        // RFC 6455 Section 5.4: Data frames MUST NOT be interleaved
        // Build masked frames (client -> server)
        let mut buffer = Vec::new();
        // Text frame start (FIN=0, opcode=Text)
        buffer.extend(build_masked_frame(0x01, false, b"hello"));
        // Binary frame (interleaved - MUST be rejected)
        buffer.extend(build_masked_frame(0x02, true, b"bad"));

        let cx = Cx::for_testing();
        let writer: Vec<u8> = Vec::new();
        let mut transport = WsTransport::new(Cursor::new(buffer), writer);

        let err = transport.recv(&cx).unwrap_err();
        assert!(matches!(
            err,
            TransportError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
        ));
    }

    #[test]
    fn test_transport_roundtrip() {
        use fastmcp_protocol::RequestId;

        // Create a pipe using in-memory buffers
        // Simulating server -> client: server writes unmasked, client reads unmasked
        let mut write_buf = Vec::new();

        // Server writes a request (unmasked)
        {
            let cx = Cx::for_testing();
            let reader: &[u8] = &[];
            let mut transport = WsTransport::new(reader, &mut write_buf);

            let request = JsonRpcRequest {
                jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
                id: Some(RequestId::Number(1)),
                method: "test".to_string(),
                params: None,
            };

            transport.send_request(&cx, &request).unwrap();
        }

        // Client reads the request (accepts unmasked frames from server)
        {
            let cx = Cx::for_testing();
            let writer: Vec<u8> = Vec::new();
            let mut transport = WsClientTransport::new(Cursor::new(write_buf), writer);

            let msg = transport.recv(&cx).unwrap();
            let JsonRpcMessage::Request(req) = msg else {
                panic!("Expected request");
            };
            assert_eq!(req.method, "test");
            assert_eq!(req.id, Some(RequestId::Number(1)));
        }
    }

    #[test]
    fn test_close_frame_returns_closed_error() {
        // Build a masked close frame (simulating client -> server)
        let mut buffer = Vec::new();
        buffer.push(0x88); // FIN + Close opcode
        buffer.push(0x80); // Masked, 0 payload length
        buffer.extend_from_slice(&[0u8; 4]); // Mask key (zeros work for empty payload)

        let cx = Cx::for_testing();
        let writer: Vec<u8> = Vec::new();
        let mut transport = WsTransport::new(Cursor::new(buffer), writer);

        let result = transport.recv(&cx);
        assert!(matches!(result, Err(TransportError::Closed)));
    }

    #[test]
    fn test_ping_auto_pong() {
        // Build masked frames (simulating client -> server)
        let mut buffer = Vec::new();

        // Ping frame (masked, opcode 0x09)
        buffer.extend(build_masked_frame(0x09, true, b"ping"));

        // Text frame with JSON-RPC (masked, opcode 0x01)
        let text = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        buffer.extend(build_masked_frame(0x01, true, text.as_bytes()));

        let mut response_buf = Vec::new();

        let cx = Cx::for_testing();
        let mut transport = WsTransport::new(Cursor::new(buffer), &mut response_buf);

        // Should skip ping (auto-pong) and return the text message
        let msg = transport.recv(&cx).unwrap();
        let JsonRpcMessage::Request(req) = msg else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "test");
        drop(transport);

        // Check that exactly one pong was written and that it echoes the ping
        // payload (RFC 6455 Section 5.5.3).
        assert!(!response_buf.is_empty());
        assert_eq!(response_buf[0] & 0x0F, 0x0A); // Pong opcode
        let mut reader = WsReader::new_client(Cursor::new(response_buf));
        let pong = reader.read_frame().unwrap();
        assert_eq!(pong.frame_type, WsFrameType::Pong);
        assert_eq!(pong.payload, &b"ping"[..]);
        assert!(reader.read_frame().is_err(), "Expected a single pong frame");
    }

    // =========================================================================
    // E2E WebSocket Tests (bd-2kv / bd-2gyv)
    // =========================================================================

    #[test]
    fn e2e_ws_bidirectional_message_flow() {
        use fastmcp_protocol::RequestId;

        // Simulate a full bidirectional message flow:
        // the server receives several requests and answers each one in turn.

        let mut request_buffer = Vec::new();

        // Build multiple masked requests (client -> server)
        let req1 = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let req2 = r#"{"jsonrpc":"2.0","method":"tools/list","id":2}"#;
        let req3 = r#"{"jsonrpc":"2.0","method":"tools/call","params":{"name":"test"},"id":3}"#;

        request_buffer.extend(build_masked_frame(0x01, true, req1.as_bytes()));
        request_buffer.extend(build_masked_frame(0x01, true, req2.as_bytes()));
        request_buffer.extend(build_masked_frame(0x01, true, req3.as_bytes()));

        let mut response_buffer = Vec::new();
        let cx = Cx::for_testing();

        {
            let mut transport = WsTransport::new(Cursor::new(request_buffer), &mut response_buffer);

            // Receive and process each request
            for expected_id in 1..=3 {
                let msg = transport.recv(&cx).unwrap();
                let JsonRpcMessage::Request(req) = msg else {
                    panic!("Expected request");
                };

                assert_eq!(req.id, Some(RequestId::Number(expected_id)));

                // Send response
                let response = JsonRpcResponse {
                    jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
                    result: Some(serde_json::json!({"ok": true})),
                    error: None,
                    id: req.id,
                };
                transport.send_response(&cx, &response).unwrap();
            }
        }

        // Decode the responses as a client would (server -> client frames are
        // unmasked) and check each one answers the matching request, in order.
        let mut client = WsClientTransport::new(Cursor::new(response_buffer), Vec::<u8>::new());
        for expected_id in 1..=3 {
            let msg = client.recv(&cx).unwrap();
            let JsonRpcMessage::Response(resp) = msg else {
                panic!("Expected response");
            };
            assert_eq!(resp.id, Some(RequestId::Number(expected_id)));
            assert!(resp.result.is_some());
        }
        // Nothing left to read: exactly one frame per response was written.
        assert!(
            client.recv(&cx).is_err(),
            "Expected exactly 3 response frames"
        );
    }

    #[test]
    fn e2e_ws_fragmented_message_assembly() {
        // Test receiving a fragmented JSON-RPC message
        let full_msg =
            r#"{"jsonrpc":"2.0","method":"test","params":{"data":"hello world"},"id":1}"#;
        let mid = full_msg.len() / 2;

        let mut buffer = Vec::new();
        // First fragment (FIN=0, opcode=Text)
        buffer.extend(build_masked_frame(0x01, false, &full_msg.as_bytes()[..mid]));
        // Continuation fragment (FIN=1, opcode=Continuation)
        buffer.extend(build_masked_frame(0x00, true, &full_msg.as_bytes()[mid..]));

        let cx = Cx::for_testing();
        let writer: Vec<u8> = Vec::new();
        let mut transport = WsTransport::new(Cursor::new(buffer), writer);

        let msg = transport.recv(&cx).unwrap();
        let JsonRpcMessage::Request(req) = msg else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "test");
        let params = req.params.unwrap();
        assert_eq!(params.get("data").unwrap(), "hello world");
    }

    #[test]
    fn e2e_ws_interleaved_ping_during_operation() {
        // Test that ping/pong doesn't disrupt normal message flow
        let mut buffer = Vec::new();

        // Message 1
        buffer.extend(build_masked_frame(
            0x01,
            true,
            r#"{"jsonrpc":"2.0","method":"msg1","id":1}"#.as_bytes(),
        ));
        // Ping (should be handled automatically)
        buffer.extend(build_masked_frame(0x09, true, b"keepalive"));
        // Message 2
        buffer.extend(build_masked_frame(
            0x01,
            true,
            r#"{"jsonrpc":"2.0","method":"msg2","id":2}"#.as_bytes(),
        ));
        // Another ping
        buffer.extend(build_masked_frame(0x09, true, b"alive"));
        // Message 3
        buffer.extend(build_masked_frame(
            0x01,
            true,
            r#"{"jsonrpc":"2.0","method":"msg3","id":3}"#.as_bytes(),
        ));

        let mut response_buffer = Vec::new();
        let cx = Cx::for_testing();
        let mut transport = WsTransport::new(Cursor::new(buffer), &mut response_buffer);

        // Should receive all 3 messages, with pings handled automatically
        for i in 1..=3 {
            let msg = transport.recv(&cx).unwrap();
            let JsonRpcMessage::Request(req) = msg else {
                panic!("Expected request");
            };
            assert_eq!(req.method, format!("msg{i}"));
        }
        drop(transport);

        // Each ping must have been answered by exactly one pong echoing its
        // payload, in the order the pings arrived (RFC 6455 Section 5.5.3).
        let mut reader = WsReader::new_client(Cursor::new(response_buffer));
        for expected in [&b"keepalive"[..], &b"alive"[..]] {
            let pong = reader.read_frame().unwrap();
            assert_eq!(pong.frame_type, WsFrameType::Pong);
            assert_eq!(pong.payload, expected);
        }
        assert!(
            reader.read_frame().is_err(),
            "Expected exactly 2 pong frames"
        );
    }

    #[test]
    fn e2e_ws_graceful_close() {
        // A Close frame arriving after a message ends the stream: the message
        // is still delivered, then recv reports the connection as closed.
        let mut buffer = Vec::new();
        // Message followed by close
        buffer.extend(build_masked_frame(
            0x01,
            true,
            r#"{"jsonrpc":"2.0","method":"last","id":1}"#.as_bytes(),
        ));
        buffer.extend(build_masked_frame(0x08, true, &[])); // Close frame

        let mut response_buffer = Vec::new();
        let cx = Cx::for_testing();
        let mut transport = WsTransport::new(Cursor::new(buffer), &mut response_buffer);

        // Receive the message
        let msg = transport.recv(&cx).unwrap();
        assert!(matches!(msg, JsonRpcMessage::Request(_)));

        // Next recv should return Closed
        let result = transport.recv(&cx);
        assert!(matches!(result, Err(TransportError::Closed)));
    }

    #[test]
    fn e2e_ws_cancellation_respected() {
        let buffer = build_masked_frame(
            0x01,
            true,
            r#"{"jsonrpc":"2.0","method":"test","id":1}"#.as_bytes(),
        );

        let cx = Cx::for_testing();
        cx.set_cancel_requested(true);

        let writer: Vec<u8> = Vec::new();
        let mut transport = WsTransport::new(Cursor::new(buffer), writer);

        // Recv should respect cancellation
        let result = transport.recv(&cx);
        assert!(matches!(result, Err(TransportError::Cancelled)));
    }

    #[test]
    fn e2e_ws_send_cancellation_respected() {
        let cx = Cx::for_testing();
        cx.set_cancel_requested(true);

        let reader: &[u8] = &[];
        let mut writer = Vec::new();
        let mut transport = WsTransport::new(reader, &mut writer);

        let request = JsonRpcRequest::new("test", None, 1i64);
        let result = transport.send_request(&cx, &request);
        assert!(matches!(result, Err(TransportError::Cancelled)));

        // Nothing should be written
        assert!(writer.is_empty());
    }

    #[test]
    fn e2e_ws_unicode_in_messages() {
        // Test Unicode handling in WebSocket text frames
        let unicode_msg =
            r#"{"jsonrpc":"2.0","method":"test","params":{"text":"Hello 世界 👋 éèê"},"id":1}"#;
        let buffer = build_masked_frame(0x01, true, unicode_msg.as_bytes());

        let cx = Cx::for_testing();
        let writer: Vec<u8> = Vec::new();
        let mut transport = WsTransport::new(Cursor::new(buffer), writer);

        let msg = transport.recv(&cx).unwrap();
        let JsonRpcMessage::Request(req) = msg else {
            panic!("Expected request");
        };
        let params = req.params.unwrap();
        let text = params.get("text").unwrap().as_str().unwrap();
        assert!(text.contains("世界"));
        assert!(text.contains("👋"));
        assert!(text.contains("éèê"));
    }

    #[test]
    fn e2e_ws_client_server_full_flow() {
        use fastmcp_protocol::RequestId;

        // Full client-server flow with proper masking

        // 1. Client sends masked request
        let mut client_to_server = Vec::new();
        {
            let mut writer = WsClientWriter::new(&mut client_to_server);
            let request = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
            writer.write_frame(&WsFrame::text(request)).unwrap();
        }

        // 2. Server receives and processes
        let mut server_response = Vec::new();
        {
            let cx = Cx::for_testing();
            let mut transport =
                WsTransport::new(Cursor::new(client_to_server), &mut server_response);

            let msg = transport.recv(&cx).unwrap();
            let JsonRpcMessage::Request(req) = msg else {
                panic!("Expected request");
            };
            assert_eq!(req.method, "initialize");

            // Send response (unmasked)
            let response = JsonRpcResponse {
                jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
                result: Some(serde_json::json!({"capabilities": {}})),
                error: None,
                id: Some(RequestId::Number(1)),
            };
            transport.send_response(&cx, &response).unwrap();
        }

        // 3. Client receives response
        {
            let cx = Cx::for_testing();
            let mut transport =
                WsClientTransport::new(Cursor::new(server_response), Vec::<u8>::new());

            let msg = transport.recv(&cx).unwrap();
            let JsonRpcMessage::Response(resp) = msg else {
                panic!("Expected response");
            };
            assert_eq!(resp.id, Some(RequestId::Number(1)));
            assert!(resp.result.is_some());
        }
    }

    #[test]
    fn ws_continuation_opcode_roundtrip() {
        let ft = WsFrameType::Continuation;
        assert_eq!(WsFrameType::from_opcode(ft.opcode()), Some(ft));
    }

    #[test]
    fn ws_unknown_opcode_returns_none() {
        assert_eq!(WsFrameType::from_opcode(0x03), None);
        assert_eq!(WsFrameType::from_opcode(0x0F), None);
    }

    #[test]
    fn ws_frame_as_text_non_utf8_returns_error() {
        let frame = WsFrame {
            frame_type: WsFrameType::Text,
            payload: vec![0xFF, 0xFE],
            fin: true,
        };
        assert!(frame.as_text().is_err());
    }

    #[test]
    fn ws_transport_close_sends_close_frame() {
        let reader: &[u8] = &[];
        let mut output = Vec::new();
        let mut transport = WsTransport::new(reader, &mut output);
        transport.close().unwrap();

        // Close frame: FIN + opcode 0x08 = 0x88, payload length 0
        assert!(output.len() >= 2);
        assert_eq!(output[0], 0x88);
        assert_eq!(output[1], 0x00);
    }

    #[test]
    fn ws_transport_ping_sends_ping_frame() {
        let reader: &[u8] = &[];
        let mut output = Vec::new();
        let mut transport = WsTransport::new(reader, &mut output);
        transport.ping().unwrap();

        // Ping frame: FIN + opcode 0x09 = 0x89, payload length 0
        assert!(output.len() >= 2);
        assert_eq!(output[0], 0x89);
        assert_eq!(output[1], 0x00);
    }

    #[test]
    fn ws_client_transport_send_cancelled() {
        let cx = Cx::for_testing();
        cx.set_cancel_requested(true);

        let reader: &[u8] = &[];
        let mut writer = Vec::new();
        let mut transport = WsClientTransport::new(reader, &mut writer);

        let request = JsonRpcRequest::new("test", None, 1i64);
        let result = transport.send(&cx, &JsonRpcMessage::Request(request));
        assert!(matches!(result, Err(TransportError::Cancelled)));
        assert!(writer.is_empty());
    }

    #[test]
    fn ws_binary_frame_skipped_outside_fragmentation() {
        // Binary frame (non-fragmented) followed by a text message
        let mut buffer = Vec::new();
        buffer.extend(build_masked_frame(0x02, true, b"binary-data"));
        buffer.extend(build_masked_frame(
            0x01,
            true,
            r#"{"jsonrpc":"2.0","method":"after_binary","id":1}"#.as_bytes(),
        ));

        let cx = Cx::for_testing();
        let writer: Vec<u8> = Vec::new();
        let mut transport = WsTransport::new(Cursor::new(buffer), writer);

        let msg = transport.recv(&cx).unwrap();
        let JsonRpcMessage::Request(req) = msg else {
            panic!("expected request");
        };
        assert_eq!(req.method, "after_binary");
    }

    #[test]
    fn ws_pong_frame_skipped() {
        // Pong frame followed by a text message
        let mut buffer = Vec::new();
        buffer.extend(build_masked_frame(0x0A, true, b"pong-payload"));
        buffer.extend(build_masked_frame(
            0x01,
            true,
            r#"{"jsonrpc":"2.0","method":"after_pong","id":2}"#.as_bytes(),
        ));

        let cx = Cx::for_testing();
        let writer: Vec<u8> = Vec::new();
        let mut transport = WsTransport::new(Cursor::new(buffer), writer);

        let msg = transport.recv(&cx).unwrap();
        let JsonRpcMessage::Request(req) = msg else {
            panic!("expected request");
        };
        assert_eq!(req.method, "after_pong");
    }
}