rama_ws/protocol/
mod.rs

1//! Generic WebSocket message stream.
2
3use rama_core::extensions::{Extensions, ExtensionsMut, ExtensionsRef};
4use rama_core::telemetry::tracing;
5use rama_core::{
6    error::OpaqueError,
7    telemetry::tracing::{debug, trace},
8};
9use std::{
10    fmt,
11    io::{self, Read, Write},
12};
13
14#[cfg(feature = "compression")]
15use rama_http::headers::sec_websocket_extensions;
16
17pub mod frame;
18
19mod error;
20mod message;
21
22#[cfg(feature = "compression")]
23mod per_message_deflate;
24
25pub use error::ProtocolError;
26
27#[cfg(test)]
28mod tests;
29
30use crate::protocol::{
31    frame::{
32        Frame, FrameCodec, Utf8Bytes,
33        coding::{CloseCode, OpCode, OpCodeControl, OpCodeData},
34    },
35    message::{IncompleteMessage, IncompleteMessageType},
36};
37
38#[cfg(feature = "compression")]
39use self::per_message_deflate::PerMessageDeflateState;
40
41pub use self::{frame::CloseFrame, message::Message};
42
/// The side of the connection a WebSocket endpoint plays.
///
/// Several protocol rules differ per side, e.g. masking requirements and which
/// side is expected to close the underlying connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Role {
    /// The endpoint is the server side of the connection.
    Server,
    /// The endpoint is the client side of the connection.
    Client,
}
51
/// Tunable settings for a WebSocket connection.
///
/// Start from [`WebSocketConfig::default`] and override only what you need
/// through the generated `with_*` / `set_*` methods.
///
/// # Example
/// ```
/// # use rama_ws::protocol::WebSocketConfig;
///
/// let conf = WebSocketConfig::default()
///     .with_read_buffer_size(256 * 1024)
///     .with_write_buffer_size(256 * 1024);
/// ```
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct WebSocketConfig {
    /// Capacity of the read buffer. The buffer is allocated up front and used to
    /// receive incoming messages.
    ///
    /// A larger buffer (e.g. 128 KiB) helps throughput under heavy read load.
    /// A smaller one (e.g. 4 KiB) keeps the total memory footprint down when many
    /// connections are open and read performance is not critical.
    ///
    /// Defaults to 128 KiB.
    pub read_buffer_size: usize,

    /// Number of bytes the write buffer should accumulate before its contents are
    /// written to the underlying stream.
    ///
    /// Defaults to 128 KiB.
    ///
    /// A value of `0` writes every message to the stream right away. Letting a few
    /// messages accumulate is usually more efficient, hence the default.
    ///
    /// Note: [`flush`](WebSocket::flush) always writes out the whole buffer,
    /// whatever this value is.
    pub write_buffer_size: usize,

    /// Upper bound, in bytes, on the size of the write buffer. Use it to apply
    /// backpressure when the buffer keeps growing because writes are failing.
    ///
    /// Defaults to unlimited.
    ///
    /// Note: the buffer only grows beyond [`write_buffer_size`](Self::write_buffer_size)
    /// while writes to the underlying stream are failing. So **the write buffer cannot
    /// fill up as long as you observe no write errors, even if you never flush**.
    ///
    /// Note: must be at least [`write_buffer_size + 1 message`](Self::write_buffer_size),
    /// and probably somewhat more depending on your error-handling strategy.
    pub max_write_buffer_size: usize,

    /// Maximum size of an incoming message, or `None` for no limit.
    ///
    /// Defaults to 64 MiB. That is large enough for all normal use-cases, yet small
    /// enough to keep a malicious peer from eating memory.
    pub max_message_size: Option<usize>,

    /// Maximum payload size of a single incoming frame, or `None` for no limit.
    /// The frame header does not count towards this limit.
    ///
    /// Defaults to 16 MiB. That is large enough for all normal use-cases, yet small
    /// enough to keep a malicious peer from eating memory.
    pub max_frame_size: Option<usize>,

    /// Whether a server accepts and handles unmasked frames sent by a client.
    ///
    /// RFC 6455 requires the server to close the connection in that case. However,
    /// some popular client libraries send unmasked frames anyway.
    ///
    /// Defaults to `false`, i.e. behaviour according to RFC 6455.
    pub accept_unmasked_frames: bool,

    #[cfg(feature = "compression")]
    #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
    /// Per-message-deflate settings. Setting this enables per-message
    /// (de)compression with the Deflate algorithm, as specified by [`RFC7692`].
    ///
    /// Defaults to `None` (compression disabled).
    ///
    /// [`RFC7692`]: https://datatracker.ietf.org/doc/html/rfc7692
    pub per_message_deflate: Option<PerMessageDeflateConfig>,
}
126
#[cfg(feature = "compression")]
#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
/// Parameters of the permessage-deflate extension, as specified in [`RFC7692`].
///
/// [`RFC7692`]: https://datatracker.ietf.org/doc/html/rfc7692
#[derive(Debug, Clone, Copy)]
pub struct PerMessageDeflateConfig {
    /// `server_no_context_takeover`: disables context takeover on the server side.
    ///
    /// A client uses this to ask the server not to reuse its LZ77 sliding window
    /// between messages. The client then does not have to keep that window in
    /// memory.
    ///
    /// If the client leaves it out, it can decompress messages even when the
    /// server uses context takeover.
    ///
    /// Servers should support this parameter and confirm it by echoing it in their
    /// response. They may include it even when the client did not ask for it.
    pub server_no_context_takeover: bool,

    /// `client_no_context_takeover`: disables context takeover on the client side.
    ///
    /// A client uses this to tell the server that it will not use context
    /// takeover, even if the server does not answer with the same parameter.
    ///
    /// A server receiving it may ignore it, or include `client_no_context_takeover`
    /// in its response. Including it forbids the client from using context
    /// takeover, which saves the server memory. If the response omits the
    /// parameter, the server can decompress messages for which the client did use
    /// context takeover.
    ///
    /// Clients are required to support this parameter in a server's response.
    pub client_no_context_takeover: bool,

    /// `server_max_window_bits`: limits the server's LZ77 window size.
    ///
    /// A client uses this to propose the largest sliding window the server may
    /// use when compressing messages. The value is a base-2 logarithm (8-15).
    /// This reduces the client's memory requirements.
    ///
    /// If the client omits it, the client can handle messages compressed with a
    /// window of up to 32,768 bytes.
    ///
    /// A server accepts by echoing the parameter with an equal or smaller value;
    /// otherwise it declines. A server may also suggest a window size the client
    /// never proposed.
    pub server_max_window_bits: Option<u8>,

    /// `client_max_window_bits`: limits the client's LZ77 window size.
    ///
    /// A client sends this, optionally with a value between 8 and 15 (base-2
    /// logarithm). Sending it signals that the client supports the parameter in
    /// responses. A value hints that the client will not exceed that window for
    /// its own compression, whatever the server answers.
    ///
    /// The server may then include `client_max_window_bits` in its response with
    /// an equal or smaller value. This limits the client's window and reduces the
    /// server's own decompression memory.
    ///
    /// If the response omits the parameter, the server can decompress messages
    /// compressed with a client window of up to 32,768 bytes.
    ///
    /// Servers must not include this parameter in their response unless the
    /// client's initial offer contained it.
    pub client_max_window_bits: Option<u8>,
}
202
#[cfg(feature = "compression")]
impl From<&sec_websocket_extensions::PerMessageDeflateConfig> for PerMessageDeflateConfig {
    /// Copies the four deflate parameters out of the header representation.
    fn from(header_cfg: &sec_websocket_extensions::PerMessageDeflateConfig) -> Self {
        // All picked fields are `Copy`; anything else in the header type is ignored.
        let sec_websocket_extensions::PerMessageDeflateConfig {
            server_no_context_takeover,
            client_no_context_takeover,
            server_max_window_bits,
            client_max_window_bits,
            ..
        } = *header_cfg;

        Self {
            server_no_context_takeover,
            client_no_context_takeover,
            server_max_window_bits,
            client_max_window_bits,
        }
    }
}
214
#[cfg(feature = "compression")]
impl From<sec_websocket_extensions::PerMessageDeflateConfig> for PerMessageDeflateConfig {
    /// By-value conversion; delegates to the by-reference one.
    #[inline]
    fn from(header_cfg: sec_websocket_extensions::PerMessageDeflateConfig) -> Self {
        let borrowed = &header_cfg;
        borrowed.into()
    }
}
222
#[cfg(feature = "compression")]
impl From<&PerMessageDeflateConfig> for sec_websocket_extensions::PerMessageDeflateConfig {
    /// Builds the header representation. Header fields without a counterpart in
    /// [`PerMessageDeflateConfig`] keep their `Default` value.
    fn from(cfg: &PerMessageDeflateConfig) -> Self {
        let PerMessageDeflateConfig {
            server_no_context_takeover,
            client_no_context_takeover,
            server_max_window_bits,
            client_max_window_bits,
        } = *cfg;

        Self {
            server_no_context_takeover,
            client_no_context_takeover,
            server_max_window_bits,
            client_max_window_bits,
            ..Default::default()
        }
    }
}
235
#[cfg(feature = "compression")]
impl From<PerMessageDeflateConfig> for sec_websocket_extensions::PerMessageDeflateConfig {
    /// By-value conversion; delegates to the by-reference one.
    #[inline]
    fn from(cfg: PerMessageDeflateConfig) -> Self {
        <Self as From<&PerMessageDeflateConfig>>::from(&cfg)
    }
}
243
#[cfg(feature = "compression")]
// Written by hand instead of `#[derive(Default)]` so the chosen defaults are
// spelled out, together with what they mean, right next to the type.
#[allow(clippy::derivable_impls)]
impl Default for PerMessageDeflateConfig {
    /// Context takeover allowed in both directions and no window-size limits.
    fn default() -> Self {
        // Unset window bits: the default 15-bit window (32,768 bytes) applies.
        let (server_max_window_bits, client_max_window_bits) = (None, None);

        Self {
            server_no_context_takeover: false,
            client_no_context_takeover: false,
            server_max_window_bits,
            client_max_window_bits,
        }
    }
}
259
260impl Default for WebSocketConfig {
261    fn default() -> Self {
262        Self {
263            read_buffer_size: 128 * 1024,
264            write_buffer_size: 128 * 1024,
265            max_write_buffer_size: usize::MAX,
266            max_message_size: Some(64 << 20),
267            max_frame_size: Some(16 << 20),
268            accept_unmasked_frames: false,
269            #[cfg(feature = "compression")]
270            per_message_deflate: None,
271        }
272    }
273}
274
275impl WebSocketConfig {
276    rama_utils::macros::generate_set_and_with! {
277        /// Set [`Self::read_buffer_size`].
278        #[must_use]
279        pub fn read_buffer_size(mut self, read_buffer_size: usize) -> Self {
280            self.read_buffer_size = read_buffer_size;
281            self
282        }
283    }
284
285    rama_utils::macros::generate_set_and_with! {
286        /// Set [`Self::write_buffer_size`].
287        #[must_use]
288        pub fn write_buffer_size(mut self, write_buffer_size: usize) -> Self {
289            self.write_buffer_size = write_buffer_size;
290            self
291        }
292    }
293
294    rama_utils::macros::generate_set_and_with! {
295        /// Set [`Self::max_write_buffer_size`].
296        #[must_use]
297        pub fn max_write_buffer_size(mut self, max_write_buffer_size: usize) -> Self {
298            self.max_write_buffer_size = max_write_buffer_size;
299            self
300        }
301    }
302
303    rama_utils::macros::generate_set_and_with! {
304        /// Set [`Self::max_message_size`].
305        #[must_use]
306        pub fn max_message_size(mut self, max_message_size: Option<usize>) -> Self {
307            self.max_message_size = max_message_size;
308            self
309        }
310    }
311
312    rama_utils::macros::generate_set_and_with! {
313        /// Set [`Self::max_frame_size`].
314        #[must_use]
315        pub fn max_frame_size(mut self, max_frame_size: Option<usize>) -> Self {
316            self.max_frame_size = max_frame_size;
317            self
318        }
319    }
320
321    rama_utils::macros::generate_set_and_with! {
322        /// Set [`Self::accept_unmasked_frames`].
323        #[must_use]
324        pub fn accept_unmasked_frames(mut self, accept_unmasked_frames: bool) -> Self {
325            self.accept_unmasked_frames = accept_unmasked_frames;
326            self
327        }
328    }
329
330    #[cfg(feature = "compression")]
331    rama_utils::macros::generate_set_and_with! {
332        /// Set [`Self::per_message_deflate`] with the default config..
333        #[must_use]
334        #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
335        pub fn per_message_deflate_default(mut self) -> Self {
336            self.per_message_deflate = Some(Default::default());
337            self
338        }
339    }
340
341    #[cfg(feature = "compression")]
342    rama_utils::macros::generate_set_and_with! {
343        /// Set [`Self::per_message_deflate`].
344        #[must_use]
345        #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
346        pub fn per_message_deflate(mut self, per_message_deflate: Option<PerMessageDeflateConfig>) -> Self {
347            self.per_message_deflate = per_message_deflate;
348            self
349        }
350    }
351
352    /// Panic if values are invalid.
353    pub(crate) fn assert_valid(&self) {
354        assert!(
355            self.max_write_buffer_size > self.write_buffer_size,
356            "WebSocketConfig::max_write_buffer_size must be greater than write_buffer_size, \
357            see WebSocketConfig docs`"
358        );
359    }
360}
361
362/// WebSocket input-output stream.
363///
364/// This is THE structure you want to create to be able to speak the WebSocket protocol.
365/// It may be created by calling `connect`, `accept` or `client` functions.
366///
367/// Use [`WebSocket::read`], [`WebSocket::send`] to received and send messages.
368pub struct WebSocket<Stream> {
369    /// The underlying socket.
370    socket: Stream,
371    /// The context for managing a WebSocket.
372    context: WebSocketContext,
373}
374
375impl<Stream: fmt::Debug> fmt::Debug for WebSocket<Stream> {
376    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
377        f.debug_struct("WebSocket")
378            .field("socket", &self.socket)
379            .field("context", &self.context)
380            .finish()
381    }
382}
383
384impl<Stream> WebSocket<Stream> {
385    /// Convert a raw socket into a WebSocket without performing a handshake.
386    ///
387    /// # Panics
388    /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
389    pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
390        Self {
391            socket: stream,
392            context: WebSocketContext::new(role, config),
393        }
394    }
395
396    /// Convert a raw socket into a WebSocket without performing a handshake.
397    ///
398    /// # Panics
399    /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
400    pub fn from_partially_read(
401        stream: Stream,
402        part: Vec<u8>,
403        role: Role,
404        config: Option<WebSocketConfig>,
405    ) -> Self {
406        Self {
407            socket: stream,
408            context: WebSocketContext::from_partially_read(part, role, config),
409        }
410    }
411
412    /// Consumes the `WebSocket` and returns the underlying stream.
413    pub(crate) fn into_inner(self) -> Stream {
414        self.socket
415    }
416
417    /// Returns a shared reference to the inner stream.
418    pub fn get_ref(&self) -> &Stream {
419        &self.socket
420    }
421    /// Returns a mutable reference to the inner stream.
422    pub fn get_mut(&mut self) -> &mut Stream {
423        &mut self.socket
424    }
425
426    /// Change the configuration.
427    ///
428    /// # Panics
429    /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
430    pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
431        self.context.set_config(set_func);
432    }
433
434    /// Read the configuration.
435    pub fn get_config(&self) -> &WebSocketConfig {
436        self.context.get_config()
437    }
438
439    /// Check if it is possible to read messages.
440    ///
441    /// Reading is impossible after receiving `Message::Close`. It is still possible after
442    /// sending close frame since the peer still may send some data before confirming close.
443    pub fn can_read(&self) -> bool {
444        self.context.can_read()
445    }
446
447    /// Check if it is possible to write messages.
448    ///
449    /// Writing gets impossible immediately after sending or receiving `Message::Close`.
450    pub fn can_write(&self) -> bool {
451        self.context.can_write()
452    }
453}
454
455impl<Stream: Read + Write> WebSocket<Stream> {
456    /// Read a message from stream, if possible.
457    ///
458    /// This will also queue responses to ping and close messages. These responses
459    /// will be written and flushed on the next call to [`read`](Self::read),
460    /// [`write`](Self::write) or [`flush`](Self::flush).
461    ///
462    /// # Closing the connection
463    /// When the remote endpoint decides to close the connection this will return
464    /// the close message with an optional close frame.
465    ///
466    /// You should continue calling [`read`](Self::read), [`write`](Self::write) or
467    /// [`flush`](Self::flush) to drive the reply to the close frame until [`Error::ConnectionClosed`]
468    /// is returned. Once that happens it is safe to drop the underlying connection.
469    pub fn read(&mut self) -> Result<Message, ProtocolError> {
470        self.context.read(&mut self.socket)
471    }
472
473    /// Writes and immediately flushes a message.
474    /// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
475    pub fn send(&mut self, message: Message) -> Result<(), ProtocolError> {
476        self.write(message)?;
477        self.flush()
478    }
479
480    /// Write a message to the provided stream, if possible.
481    ///
482    /// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
483    ///
484    /// In the event of stream write failure the message frame will be stored
485    /// in the write buffer and will try again on the next call to [`write`](Self::write)
486    /// or [`flush`](Self::flush).
487    ///
488    /// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`]
489    /// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned.
490    ///
491    /// This call will generally not flush. However, if there are queued automatic messages
492    /// they will be written and eagerly flushed.
493    ///
494    /// For example, upon receiving ping messages this crate queues pong replies automatically.
495    /// The next call to [`read`](Self::read), [`write`](Self::write) or [`flush`](Self::flush)
496    /// will write & flush the pong reply. This means you should not respond to ping frames manually.
497    ///
498    /// You can however send pong frames manually in order to indicate a unidirectional heartbeat
499    /// as described in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.5.3). Note that
500    /// if [`read`](Self::read) returns a ping, you should [`flush`](Self::flush) before passing
501    /// a custom pong to [`write`](Self::write), otherwise the automatic queued response to the
502    /// ping will not be sent as it will be replaced by your custom pong message.
503    ///
504    /// # Errors
505    /// - If the WebSocket's write buffer is full, [`Error::WriteBufferFull`] will be returned
506    ///   along with the equivalent passed message frame.
507    /// - If the connection is closed and should be dropped, this will return [`Error::ConnectionClosed`].
508    /// - If you try again after [`Error::ConnectionClosed`] was returned either from here or from
509    ///   [`read`](Self::read), [`Error::AlreadyClosed`] will be returned. This indicates a program
510    ///   error on your part.
511    /// - [`Error::Io`] is returned if the underlying connection returns an error
512    ///   (consider these fatal except for WouldBlock).
513    /// - [`Error::Capacity`] if your message size is bigger than the configured max message size.
514    pub fn write(&mut self, message: Message) -> Result<(), ProtocolError> {
515        self.context.write(&mut self.socket, message)
516    }
517
518    /// Flush writes.
519    ///
520    /// Ensures all messages previously passed to [`write`](Self::write) and automatic
521    /// queued pong responses are written & flushed into the underlying stream.
522    pub fn flush(&mut self) -> Result<(), ProtocolError> {
523        self.context.flush(&mut self.socket)
524    }
525
526    /// Close the connection.
527    ///
528    /// This function guarantees that the close frame will be queued.
529    /// There is no need to call it again. Calling this function is
530    /// the same as calling `write(Message::Close(..))`.
531    ///
532    /// After queuing the close frame you should continue calling [`read`](Self::read) or
533    /// [`flush`](Self::flush) to drive the close handshake to completion.
534    ///
535    /// The websocket RFC defines that the underlying connection should be closed
536    /// by the server. This crate takes care of this asymmetry for you.
537    ///
538    /// When the close handshake is finished (we have both sent and received
539    /// a close message), [`read`](Self::read) or [`flush`](Self::flush) will return
540    /// [Error::ConnectionClosed] if this endpoint is the server.
541    ///
542    /// If this endpoint is a client, [Error::ConnectionClosed] will only be
543    /// returned after the server has closed the underlying connection.
544    ///
545    /// It is thus safe to drop the underlying connection as soon as [Error::ConnectionClosed]
546    /// is returned from [`read`](Self::read) or [`flush`](Self::flush).
547    pub fn close(&mut self, code: Option<CloseFrame>) -> Result<(), ProtocolError> {
548        self.context.close(&mut self.socket, code)
549    }
550}
551
552impl<Stream: ExtensionsRef> ExtensionsRef for WebSocket<Stream> {
553    fn extensions(&self) -> &Extensions {
554        self.socket.extensions()
555    }
556}
557
558impl<Stream: ExtensionsMut> ExtensionsMut for WebSocket<Stream> {
559    fn extensions_mut(&mut self) -> &mut Extensions {
560        self.socket.extensions_mut()
561    }
562}
563
564/// A context for managing WebSocket stream.
565#[derive(Debug)]
566pub struct WebSocketContext {
567    /// Server or client?
568    role: Role,
569    /// encoder/decoder of frame.
570    frame: FrameCodec,
571    /// The state of processing, either "active" or "closing".
572    state: WebSocketState,
573    #[cfg(feature = "compression")]
574    /// The state used in function per-message compression,
575    /// only set in case the extension is enabled.
576    per_message_deflate_state: Option<PerMessageDeflateState>,
577    /// Receive: an incomplete message being processed.
578    incomplete: Option<IncompleteMessage>,
579    /// Send in addition to regular messages E.g. "pong" or "close".
580    additional_send: Option<Frame>,
581    /// True indicates there is an additional message (like a pong)
582    /// that failed to flush previously and we should try again.
583    unflushed_additional: bool,
584    /// The configuration for the websocket session.
585    config: WebSocketConfig,
586}
587
588impl WebSocketContext {
589    /// Create a WebSocket context that manages a post-handshake stream.
590    ///
591    /// # Panics
592    /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
593    #[must_use]
594    pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
595        let conf = config.unwrap_or_default();
596        Self::_new(role, FrameCodec::new(conf.read_buffer_size), conf)
597    }
598
599    /// Create a WebSocket context that manages an post-handshake stream.
600    ///
601    /// # Panics
602    /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
603    #[must_use]
604    pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
605        let conf = config.unwrap_or_default();
606        Self::_new(
607            role,
608            FrameCodec::from_partially_read(part, conf.read_buffer_size),
609            conf,
610        )
611    }
612
613    fn _new(role: Role, mut frame: FrameCodec, config: WebSocketConfig) -> Self {
614        config.assert_valid();
615        frame.set_max_out_buffer_len(config.max_write_buffer_size);
616        frame.set_out_buffer_write_len(config.write_buffer_size);
617        Self {
618            role,
619            frame,
620            state: WebSocketState::Active,
621            #[cfg(feature = "compression")]
622            per_message_deflate_state: config
623                .per_message_deflate
624                .map(|cfg| PerMessageDeflateState::new(role, cfg)),
625            incomplete: None,
626            additional_send: None,
627            unflushed_additional: false,
628            config,
629        }
630    }
631
632    /// Change the configuration.
633    ///
634    /// # Panics
635    /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
636    pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
637        set_func(&mut self.config);
638        self.config.assert_valid();
639        self.frame
640            .set_max_out_buffer_len(self.config.max_write_buffer_size);
641        self.frame
642            .set_out_buffer_write_len(self.config.write_buffer_size);
643    }
644
645    /// Read the configuration.
646    pub fn get_config(&self) -> &WebSocketConfig {
647        &self.config
648    }
649
650    /// Check if it is possible to read messages.
651    ///
652    /// Reading is impossible after receiving `Message::Close`. It is still possible after
653    /// sending close frame since the peer still may send some data before confirming close.
654    pub fn can_read(&self) -> bool {
655        self.state.can_read()
656    }
657
658    /// Check if it is possible to write messages.
659    ///
660    /// Writing gets impossible immediately after sending or receiving `Message::Close`.
661    pub fn can_write(&self) -> bool {
662        self.state.is_active()
663    }
664
665    /// Read a message from the provided stream, if possible.
666    ///
667    /// This function sends pong and close responses automatically.
668    /// However, it never blocks on write.
669    pub fn read<Stream>(&mut self, stream: &mut Stream) -> Result<Message, ProtocolError>
670    where
671        Stream: Read + Write,
672    {
673        // Do not read from already closed connections.
674        self.state.check_not_terminated()?;
675
676        loop {
677            if self.additional_send.is_some() || self.unflushed_additional {
678                // Since we may get ping or close, we need to reply to the messages even during read.
679                match self.flush(stream) {
680                    Ok(_) => {}
681                    Err(ProtocolError::Io(err)) if err.kind() == io::ErrorKind::WouldBlock => {
682                        // If blocked continue reading, but try again later
683                        self.unflushed_additional = true;
684                    }
685                    Err(err) => return Err(err),
686                }
687            } else if self.role == Role::Server && !self.state.can_read() {
688                self.state = WebSocketState::Terminated;
689                return Err(ProtocolError::Io(io::Error::new(
690                    io::ErrorKind::ConnectionAborted,
691                    OpaqueError::from_display("Connection closed normally by me-the-server"),
692                )));
693            }
694
695            // If we get here, either write blocks or we have nothing to write.
696            // Thus if read blocks, just let it return WouldBlock.
697            if let Some(message) = self.read_message_frame(stream)? {
698                trace!("Received message {message}");
699                return Ok(message);
700            }
701        }
702    }
703
704    /// Write a message to the provided stream.
705    ///
706    /// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
707    ///
708    /// In the event of stream write failure the message frame will be stored
709    /// in the write buffer and will try again on the next call to [`write`](Self::write)
710    /// or [`flush`](Self::flush).
711    ///
712    /// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`]
713    /// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned.
714    pub fn write<Stream>(
715        &mut self,
716        stream: &mut Stream,
717        message: Message,
718    ) -> Result<(), ProtocolError>
719    where
720        Stream: Read + Write,
721    {
722        // When terminated, return AlreadyClosed.
723        self.state.check_not_terminated()?;
724
725        // Do not write after sending a close frame.
726        if !self.state.is_active() {
727            return Err(ProtocolError::SendAfterClosing);
728        }
729
730        let frame = match message {
731            Message::Text(data) => {
732                #[cfg(feature = "compression")]
733                match self.per_message_deflate_state.as_mut() {
734                    Some(deflate_state) => {
735                        let data = match deflate_state.encoder.encode(data.as_bytes()) {
736                            Ok(data) => data,
737                            Err(err) => return Err(ProtocolError::DeflateError(err)),
738                        };
739                        let mut msg = Frame::message(data, OpCode::Data(OpCodeData::Text), true);
740                        msg.header_mut().rsv1 = true;
741                        msg
742                    }
743                    None => Frame::message(data, OpCode::Data(OpCodeData::Text), true),
744                }
745                #[cfg(not(feature = "compression"))]
746                Frame::message(data, OpCode::Data(OpCodeData::Text), true)
747            }
748            Message::Binary(data) => {
749                #[cfg(feature = "compression")]
750                match self.per_message_deflate_state.as_mut() {
751                    Some(deflate_state) => {
752                        let data = match deflate_state.encoder.encode(data.as_ref()) {
753                            Ok(data) => data,
754                            Err(err) => return Err(ProtocolError::DeflateError(err)),
755                        };
756                        let mut msg = Frame::message(data, OpCode::Data(OpCodeData::Binary), true);
757                        msg.header_mut().rsv1 = true;
758                        msg
759                    }
760                    None => Frame::message(data, OpCode::Data(OpCodeData::Binary), true),
761                }
762                #[cfg(not(feature = "compression"))]
763                Frame::message(data, OpCode::Data(OpCodeData::Binary), true)
764            }
765            Message::Ping(data) => Frame::ping(data),
766            Message::Pong(data) => {
767                self.set_additional(Frame::pong(data));
768                // Note: user pongs can be user flushed so no need to flush here
769                return self._write(stream, None).map(|_| ());
770            }
771            Message::Close(code) => return self.close(stream, code),
772            Message::Frame(f) => f,
773        };
774
775        let should_flush = self._write(stream, Some(frame))?;
776        if should_flush {
777            self.flush(stream)?;
778        }
779        Ok(())
780    }
781
782    /// Flush writes.
783    ///
784    /// Ensures all messages previously passed to [`write`](Self::write) and automatically
785    /// queued pong responses are written & flushed into the `stream`.
786    #[inline]
787    pub fn flush<Stream>(&mut self, stream: &mut Stream) -> Result<(), ProtocolError>
788    where
789        Stream: Read + Write,
790    {
791        self._write(stream, None)?;
792        self.frame.write_out_buffer(stream)?;
793        stream.flush()?;
794        self.unflushed_additional = false;
795        Ok(())
796    }
797
798    /// Writes any data in the out_buffer, `additional_send` and given `data`.
799    ///
800    /// Does **not** flush.
801    ///
802    /// Returns true if the write contents indicate we should flush immediately.
803    fn _write<Stream>(
804        &mut self,
805        stream: &mut Stream,
806        data: Option<Frame>,
807    ) -> Result<bool, ProtocolError>
808    where
809        Stream: Read + Write,
810    {
811        if let Some(data) = data {
812            self.buffer_frame(stream, data)?;
813        }
814
815        // Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in
816        // response, unless it already received a Close frame. It SHOULD
817        // respond with Pong frame as soon as is practical. (RFC 6455)
818        let should_flush = if let Some(msg) = self.additional_send.take() {
819            trace!("Sending pong/close");
820            match self.buffer_frame(stream, msg) {
821                Err(ProtocolError::WriteBufferFull(msg)) => {
822                    // if an system message would exceed the buffer put it back in
823                    // `additional_send` for retry. Otherwise returning this error
824                    // may not make sense to the user, e.g. calling `flush`.
825                    if let Message::Frame(msg) = msg {
826                        self.set_additional(msg);
827                        false
828                    } else {
829                        unreachable!();
830                    }
831                }
832                Err(err) => return Err(err),
833                Ok(_) => true,
834            }
835        } else {
836            self.unflushed_additional
837        };
838
839        // If we're closing and there is nothing to send anymore, we should close the connection.
840        if self.role == Role::Server && !self.state.can_read() {
841            // The underlying TCP connection, in most normal cases, SHOULD be closed
842            // first by the server, so that it holds the TIME_WAIT state and not the
843            // client (as this would prevent it from re-opening the connection for 2
844            // maximum segment lifetimes (2MSL), while there is no corresponding
845            // server impact as a TIME_WAIT connection is immediately reopened upon
846            // a new SYN with a higher seq number). (RFC 6455)
847            self.frame.write_out_buffer(stream)?;
848            self.state = WebSocketState::Terminated;
849            Err(ProtocolError::Io(io::Error::new(
850                io::ErrorKind::ConnectionAborted,
851                OpaqueError::from_display("Connection closed normally by me-the-server (EOF)"),
852            )))
853        } else {
854            Ok(should_flush)
855        }
856    }
857
858    /// Close the connection.
859    ///
860    /// This function guarantees that the close frame will be queued.
861    /// There is no need to call it again. Calling this function is
862    /// the same as calling `send(Message::Close(..))`.
863    pub fn close<Stream>(
864        &mut self,
865        stream: &mut Stream,
866        code: Option<CloseFrame>,
867    ) -> Result<(), ProtocolError>
868    where
869        Stream: Read + Write,
870    {
871        if self.state == WebSocketState::Active {
872            self.state = WebSocketState::ClosedByUs;
873            let frame = Frame::close(code);
874            self._write(stream, Some(frame))?;
875        }
876        self.flush(stream)
877    }
878
879    /// Try to decode one message frame. May return None.
880    fn read_message_frame(
881        &mut self,
882        stream: &mut impl Read,
883    ) -> Result<Option<Message>, ProtocolError> {
884        let Some(frame) = self.frame.read_frame(
885            stream,
886            self.config.max_frame_size,
887            matches!(self.role, Role::Server),
888            self.config.accept_unmasked_frames,
889        )?
890        else {
891            // Connection closed by peer
892            return match std::mem::replace(&mut self.state, WebSocketState::Terminated) {
893                WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
894                    Err(ProtocolError::Io(io::Error::new(
895                        io::ErrorKind::ConnectionAborted,
896                        OpaqueError::from_display("Connection closed normally by peer"),
897                    )))
898                }
899                WebSocketState::Active
900                | WebSocketState::ClosedByUs
901                | WebSocketState::Terminated => Err(ProtocolError::ResetWithoutClosingHandshake),
902            };
903        };
904
905        if !self.state.can_read() {
906            return Err(ProtocolError::ReceivedAfterClosing);
907        }
908
909        #[cfg(feature = "compression")]
910        // to ensure that this is valid in later branches,
911        // as this is not always true despite an extension active that supports it
912        let mut rsv1_set = false;
913
914        // MUST be 0 unless an extension is negotiated that defines meanings
915        // for non-zero values.  If a nonzero value is received and none of
916        // the negotiated extensions defines the meaning of such a nonzero
917        // value, the receiving endpoint MUST _Fail the WebSocket
918        // Connection_.
919        {
920            let hdr = frame.header();
921            if hdr.rsv1 {
922                #[cfg(feature = "compression")]
923                {
924                    rsv1_set = true;
925                    if self.per_message_deflate_state.is_none() {
926                        tracing::debug!(
927                            "rsv1 bit is set but PMD state is none: no use case for it"
928                        );
929                        return Err(ProtocolError::NonZeroReservedBits);
930                    }
931                }
932                #[cfg(not(feature = "compression"))]
933                {
934                    tracing::debug!("rsv1 bit is set but compression feature no enabled");
935                    return Err(ProtocolError::NonZeroReservedBits);
936                }
937            } else if hdr.rsv2 || hdr.rsv3 {
938                tracing::debug!("rsv2 or rsv3 bit set: not expected ever");
939                return Err(ProtocolError::NonZeroReservedBits);
940            }
941        }
942
943        if self.role == Role::Client && frame.is_masked() {
944            // A client MUST close a connection if it detects a masked frame. (RFC 6455)
945            return Err(ProtocolError::MaskedFrameFromServer);
946        }
947
948        match frame.header().opcode {
949            OpCode::Control(ctl) => {
950                #[cfg(feature = "compression")]
951                if rsv1_set {
952                    tracing::debug!("rsv1 bit set in control frame: not expected");
953                    return Err(ProtocolError::NonZeroReservedBits);
954                }
955
956                match ctl {
957                    // All control frames MUST have a payload length of 125 bytes or less
958                    // and MUST NOT be fragmented. (RFC 6455)
959                    _ if !frame.header().is_final => Err(ProtocolError::FragmentedControlFrame),
960                    _ if frame.payload().len() > 125 => Err(ProtocolError::ControlFrameTooBig),
961                    OpCodeControl::Close => {
962                        Ok(self.do_close(frame.into_close()?).map(Message::Close))
963                    }
964                    OpCodeControl::Reserved(i) => Err(ProtocolError::UnknownControlFrameType(i)),
965                    OpCodeControl::Ping => {
966                        let data = frame.into_payload();
967                        // No ping processing after we sent a close frame.
968                        if self.state.is_active() {
969                            self.set_additional(Frame::pong(data.clone()));
970                        }
971                        Ok(Some(Message::Ping(data)))
972                    }
973                    OpCodeControl::Pong => Ok(Some(Message::Pong(frame.into_payload()))),
974                }
975            }
976
977            OpCode::Data(data) => {
978                let fin = frame.header().is_final;
979
980                #[cfg(feature = "compression")]
981                if matches!(data, OpCodeData::Continue) && rsv1_set {
982                    tracing::debug!("rsv1 bit set in CONTINUE frame: not expected");
983                    return Err(ProtocolError::NonZeroReservedBits);
984                }
985
986                let payload = match (data, self.incomplete.as_mut()) {
987                    (OpCodeData::Continue, None) => {
988                        #[cfg(feature = "compression")]
989                        if let Some(deflate_state) = self.per_message_deflate_state.as_mut() {
990                            if fin {
991                                let (compressed_data, msg_type) =
992                                    deflate_state.decompress_incomplete_msg.fin_buffer(
993                                        frame.into_payload(),
994                                        self.config.max_message_size,
995                                    )?;
996                                return match deflate_state.decoder.decode(compressed_data.as_ref())
997                                {
998                                    Ok(raw_data) => match msg_type {
999                                        IncompleteMessageType::Text => {
1000                                            Ok(Some(Message::Text(Utf8Bytes::try_from(raw_data)?)))
1001                                        }
1002                                        IncompleteMessageType::Binary => {
1003                                            Ok(Some(Message::Binary(raw_data.into())))
1004                                        }
1005                                    },
1006                                    Err(err) => Err(ProtocolError::DeflateError(err)),
1007                                };
1008                            }
1009
1010                            deflate_state
1011                                .decompress_incomplete_msg
1012                                .extend(frame.into_payload(), self.config.max_message_size)?;
1013                            Ok(None)
1014                        } else {
1015                            return Err(ProtocolError::UnexpectedContinueFrame);
1016                        }
1017
1018                        #[cfg(not(feature = "compression"))]
1019                        return Err(ProtocolError::UnexpectedContinueFrame);
1020                    }
1021                    (OpCodeData::Continue, Some(incomplete)) => {
1022                        incomplete.extend(frame.into_payload(), self.config.max_message_size)?;
1023                        Ok(None)
1024                    }
1025                    (_, Some(_)) => Err(ProtocolError::ExpectedFragment(data)),
1026                    (OpCodeData::Text, _) => {
1027                        Ok(Some((frame.into_payload(), IncompleteMessageType::Text)))
1028                    }
1029                    (OpCodeData::Binary, _) => {
1030                        Ok(Some((frame.into_payload(), IncompleteMessageType::Binary)))
1031                    }
1032                    (OpCodeData::Reserved(i), _) => Err(ProtocolError::UnknownDataFrameType(i)),
1033                }?;
1034
1035                match (payload, fin) {
1036                    (None, true) =>
1037                    {
1038                        #[allow(
1039                            clippy::expect_used,
1040                            reason = "we can only reach here if incomplete is Some"
1041                        )]
1042                        Ok(Some(
1043                            self.incomplete
1044                                .take()
1045                                .expect("incomplete to be there")
1046                                .complete()?,
1047                        ))
1048                    }
1049                    (None, false) => Ok(None),
1050                    (Some((payload, t)), true) => {
1051                        check_max_size(payload.len(), self.config.max_message_size)?;
1052
1053                        #[cfg(feature = "compression")]
1054                        if rsv1_set {
1055                            if let Some(deflate_state) = self.per_message_deflate_state.as_mut() {
1056                                let compressed_data = payload;
1057                                let raw_data = deflate_state
1058                                    .decoder
1059                                    .decode(&compressed_data)
1060                                    .map_err(ProtocolError::DeflateError)?;
1061                                match t {
1062                                    IncompleteMessageType::Text => {
1063                                        Ok(Some(Message::Text(Utf8Bytes::try_from(raw_data)?)))
1064                                    }
1065                                    IncompleteMessageType::Binary => {
1066                                        Ok(Some(Message::Binary(raw_data.into())))
1067                                    }
1068                                }
1069                            } else {
1070                                tracing::debug!(
1071                                    "rsv1 bit set in text frame but deflate state is none"
1072                                );
1073                                Err(ProtocolError::NonZeroReservedBits)
1074                            }
1075                        } else {
1076                            match t {
1077                                IncompleteMessageType::Text => {
1078                                    Ok(Some(Message::Text(payload.try_into()?)))
1079                                }
1080                                IncompleteMessageType::Binary => Ok(Some(Message::Binary(payload))),
1081                            }
1082                        }
1083
1084                        #[cfg(not(feature = "compression"))]
1085                        match t {
1086                            IncompleteMessageType::Text => {
1087                                Ok(Some(Message::Text(payload.try_into()?)))
1088                            }
1089                            IncompleteMessageType::Binary => Ok(Some(Message::Binary(payload))),
1090                        }
1091                    }
1092                    (Some((payload, t)), false) => {
1093                        #[cfg(feature = "compression")]
1094                        if rsv1_set {
1095                            if let Some(deflate_state) = self.per_message_deflate_state.as_mut() {
1096                                deflate_state.decompress_incomplete_msg.reset(t);
1097                                deflate_state
1098                                    .decompress_incomplete_msg
1099                                    .extend(payload, self.config.max_message_size)?;
1100                                Ok(None)
1101                            } else {
1102                                tracing::debug!(
1103                                    "rsv1 bit set in non-fin bin/text frame but deflate state is none"
1104                                );
1105                                Err(ProtocolError::NonZeroReservedBits)
1106                            }
1107                        } else {
1108                            let mut incomplete = IncompleteMessage::new(t);
1109                            incomplete.extend(payload, self.config.max_message_size)?;
1110                            self.incomplete = Some(incomplete);
1111                            Ok(None)
1112                        }
1113                        #[cfg(not(feature = "compression"))]
1114                        {
1115                            let mut incomplete = IncompleteMessage::new(t);
1116                            incomplete.extend(payload, self.config.max_message_size)?;
1117                            self.incomplete = Some(incomplete);
1118                            Ok(None)
1119                        }
1120                    }
1121                }
1122            }
1123        } // match opcode
1124    }
1125
1126    /// Received a close frame. Tells if we need to return a close frame to the user.
1127    #[allow(clippy::option_option)]
1128    fn do_close(&mut self, close: Option<CloseFrame>) -> Option<Option<CloseFrame>> {
1129        rama_core::telemetry::tracing::trace!("Received close frame: {close:?}");
1130        match self.state {
1131            WebSocketState::Active => {
1132                self.state = WebSocketState::ClosedByPeer;
1133
1134                let close = close.map(|frame| {
1135                    if !frame.code.is_allowed() {
1136                        CloseFrame {
1137                            code: CloseCode::Protocol,
1138                            reason: Utf8Bytes::from_static("Protocol violation"),
1139                        }
1140                    } else {
1141                        frame
1142                    }
1143                });
1144
1145                let reply = Frame::close(close.clone());
1146                debug!("Replying to close with {reply:?}");
1147                self.set_additional(reply);
1148
1149                Some(close)
1150            }
1151            WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
1152                // It is already closed, just ignore.
1153                None
1154            }
1155            WebSocketState::ClosedByUs => {
1156                // We received a reply.
1157                self.state = WebSocketState::CloseAcknowledged;
1158                Some(close)
1159            }
1160            WebSocketState::Terminated => unreachable!(),
1161        }
1162    }
1163
1164    /// Write a single frame into the write-buffer.
1165    fn buffer_frame<Stream>(
1166        &mut self,
1167        stream: &mut Stream,
1168        mut frame: Frame,
1169    ) -> Result<(), ProtocolError>
1170    where
1171        Stream: Read + Write,
1172    {
1173        match self.role {
1174            Role::Server => {}
1175            Role::Client => {
1176                // 5.  If the data is being sent by the client, the frame(s) MUST be
1177                // masked as defined in Section 5.3. (RFC 6455)
1178                frame.set_random_mask();
1179            }
1180        }
1181
1182        trace!("Sending frame: {frame:?}");
1183        self.frame.buffer_frame(stream, frame)
1184    }
1185
1186    /// Replace `additional_send` if it is currently a `Pong` message.
1187    fn set_additional(&mut self, add: Frame) {
1188        let empty_or_pong = self
1189            .additional_send
1190            .as_ref()
1191            .is_none_or(|f| f.header().opcode == OpCode::Control(OpCodeControl::Pong));
1192        if empty_or_pong {
1193            self.additional_send.replace(add);
1194        }
1195    }
1196}
1197
1198fn check_max_size(size: usize, max_size: Option<usize>) -> Result<(), ProtocolError> {
1199    if let Some(max_size) = max_size
1200        && size > max_size
1201    {
1202        return Err(ProtocolError::MessageTooLong { size, max_size });
1203    }
1204    Ok(())
1205}
1206
/// Lifecycle state of the connection, from normal operation through the
/// closing handshake to termination.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum WebSocketState {
    /// Normal operation: messages may be both sent and received.
    Active,
    /// We sent a close frame and are waiting for the peer's reply.
    ClosedByUs,
    /// The peer sent a close frame first, initiating the handshake.
    ClosedByPeer,
    /// The peer answered the close frame we sent.
    CloseAcknowledged,
    /// The connection is gone; any further use is reported as an error.
    Terminated,
}
1221
1222impl WebSocketState {
1223    /// Tell if we're allowed to process normal messages.
1224    fn is_active(self) -> bool {
1225        matches!(self, Self::Active)
1226    }
1227
1228    /// Tell if we should process incoming data. Note that if we send a close frame
1229    /// but the remote hasn't confirmed, they might have sent data before they receive our
1230    /// close frame, so we should still pass those to client code, hence ClosedByUs is valid.
1231    fn can_read(self) -> bool {
1232        matches!(self, Self::Active | Self::ClosedByUs)
1233    }
1234
1235    /// Check if the state is active, return error if not.
1236    fn check_not_terminated(self) -> Result<(), ProtocolError> {
1237        match self {
1238            Self::Terminated => Err(ProtocolError::Io(io::Error::new(
1239                io::ErrorKind::NotConnected,
1240                OpaqueError::from_display("Trying to work with closed connection"),
1241            ))),
1242            Self::Active | Self::CloseAcknowledged | Self::ClosedByPeer | Self::ClosedByUs => {
1243                Ok(())
1244            }
1245        }
1246    }
1247}