rama_ws/protocol/mod.rs
1//! Generic WebSocket message stream.
2
3use rama_core::extensions::{Extensions, ExtensionsMut, ExtensionsRef};
4use rama_core::telemetry::tracing;
5use rama_core::{
6 error::OpaqueError,
7 telemetry::tracing::{debug, trace},
8};
9use std::{
10 fmt,
11 io::{self, Read, Write},
12};
13
14#[cfg(feature = "compression")]
15use rama_http::headers::sec_websocket_extensions;
16
17pub mod frame;
18
19mod error;
20mod message;
21
22#[cfg(feature = "compression")]
23mod per_message_deflate;
24
25pub use error::ProtocolError;
26
27#[cfg(test)]
28mod tests;
29
30use crate::protocol::{
31 frame::{
32 Frame, FrameCodec, Utf8Bytes,
33 coding::{CloseCode, OpCode, OpCodeControl, OpCodeData},
34 },
35 message::{IncompleteMessage, IncompleteMessageType},
36};
37
38#[cfg(feature = "compression")]
39use self::per_message_deflate::PerMessageDeflateState;
40
41pub use self::{frame::CloseFrame, message::Message};
42
/// Which side of the WebSocket connection this endpoint plays.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Role {
    /// The endpoint acts as the server.
    Server,
    /// The endpoint acts as the client.
    Client,
}
51
/// Tunables for a single WebSocket connection.
///
/// # Example
/// ```
/// # use rama_ws::protocol::WebSocketConfig;
///
/// let conf = WebSocketConfig::default()
///     .with_read_buffer_size(256 * 1024)
///     .with_write_buffer_size(256 * 1024);
/// ```
#[non_exhaustive]
#[derive(Clone, Copy, Debug)]
pub struct WebSocketConfig {
    /// Capacity of the read buffer, which is allocated eagerly and used to
    /// receive messages.
    ///
    /// A larger buffer (e.g. 128 KiB) improves performance under high read load.
    /// When many connections are expected and read throughput matters less,
    /// a smaller buffer (e.g. 4 KiB) lowers the total memory usage.
    ///
    /// Defaults to 128 KiB.
    pub read_buffer_size: usize,

    /// Target minimum amount of buffered outgoing data before it is written
    /// to the underlying stream.
    ///
    /// Defaults to 128 KiB. With `0` every message is written eagerly to the
    /// underlying stream; letting messages buffer a little is often more
    /// optimal, hence the default.
    ///
    /// Note: [`flush`](WebSocket::flush) always writes out the whole buffer regardless.
    pub write_buffer_size: usize,

    /// Upper bound, in bytes, of the write buffer. Setting it can provide
    /// backpressure when the write buffer fills up because of write errors.
    ///
    /// Defaults to unlimited.
    ///
    /// Note: the write buffer only grows beyond
    /// [`write_buffer_size`](Self::write_buffer_size) while writes to the
    /// underlying stream are failing. So the **write buffer can not fill up if
    /// you are not observing write errors even if not flushing**.
    ///
    /// Note: should always be at least
    /// [`write_buffer_size + 1 message`](Self::write_buffer_size), and probably
    /// a little more depending on the error handling strategy.
    pub max_write_buffer_size: usize,

    /// Largest accepted incoming message; `None` disables the limit.
    ///
    /// Defaults to 64 MiB, which should be big enough for all normal use-cases
    /// yet small enough to keep a malicious peer from eating memory.
    pub max_message_size: Option<usize>,

    /// Largest accepted single incoming frame; `None` disables the limit.
    /// The limit applies to the frame payload and does NOT include the frame header.
    ///
    /// Defaults to 16 MiB, which should be big enough for all normal use-cases
    /// yet small enough to keep a malicious peer from eating memory.
    pub max_frame_size: Option<usize>,

    /// Whether a server accepts and handles unmasked frames sent by a client.
    ///
    /// RFC 6455 requires the server to close the connection in that case, but
    /// some popular libraries send unmasked frames anyway. Defaults to `false`,
    /// i.e. the behaviour mandated by RFC 6455.
    pub accept_unmasked_frames: bool,

    /// Per-message-deflate configuration. Specify it to enable per-message
    /// (de)compression using the Deflate algorithm as specified by [`RFC7692`].
    ///
    /// [`RFC7692`]: https://datatracker.ietf.org/doc/html/rfc7692
    #[cfg(feature = "compression")]
    #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
    pub per_message_deflate: Option<PerMessageDeflateConfig>,
}
126
/// Negotiable parameters of the per-message-deflate extension, see [`RFC7692`].
///
/// [`RFC7692`]: https://datatracker.ietf.org/doc/html/rfc7692
#[cfg(feature = "compression")]
#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
#[derive(Clone, Copy, Debug)]
pub struct PerMessageDeflateConfig {
    /// `server_no_context_takeover`: prevents server context takeover.
    ///
    /// With this parameter a client asks the server to forgo context takeover,
    /// so the client no longer has to retain memory for the LZ77 sliding window
    /// between messages.
    ///
    /// A client that omits it indicates that it is able to decompress messages
    /// even if the server uses context takeover.
    ///
    /// Servers should support the parameter and confirm acceptance by including
    /// it in their response; they may include it even when the client did not
    /// explicitly ask for it.
    pub server_no_context_takeover: bool,

    /// `client_no_context_takeover`: manages client context takeover.
    ///
    /// With this parameter a client tells the server that it does not intend to
    /// use context takeover, even if the server does not answer with the same
    /// parameter.
    ///
    /// A server receiving it can either ignore it or include
    /// `client_no_context_takeover` in its response, which prevents the client
    /// from using context takeover and helps the server conserve memory.
    /// A response without the parameter signals that the server is able to
    /// decompress messages for which the client does use context takeover.
    ///
    /// Clients are required to support this parameter in a server's response.
    pub client_no_context_takeover: bool,

    /// `server_max_window_bits`: limits the server window size.
    ///
    /// With this parameter a client proposes the maximum LZ77 sliding window
    /// size the server may use when compressing messages, expressed as
    /// a base-2 logarithm (8-15). This helps the client reduce its memory
    /// requirements.
    ///
    /// A client that omits it signals that it can handle messages compressed
    /// with a window of up to 32,768 bytes.
    ///
    /// A server accepts by echoing the parameter with an equal or smaller
    /// value; otherwise, it declines. Notably, a server may suggest a window
    /// size even if the client did not initially propose one.
    pub server_max_window_bits: Option<u8>,

    /// `client_max_window_bits`: adjusts the client window size.
    ///
    /// With this parameter a client proposes, optionally with a value between
    /// 8 and 15 (base-2 logarithm), the maximum LZ77 sliding window size it
    /// will use for compression.
    ///
    /// It signals to the server that the client supports this parameter in
    /// responses and, when a value is given, hints that the client will not
    /// exceed that window size for its own compression, regardless of the
    /// server's response.
    ///
    /// A server can then include `client_max_window_bits` in its response with
    /// an equal or smaller value, thereby limiting the client's window size and
    /// reducing its own memory overhead for decompression.
    ///
    /// A response without the parameter signifies that the server is able to
    /// decompress messages compressed with a client window of up to 32,768 bytes.
    ///
    /// Servers must not include this parameter in their response if the
    /// client's initial offer did not contain it.
    pub client_max_window_bits: Option<u8>,
}
202
#[cfg(feature = "compression")]
impl From<&sec_websocket_extensions::PerMessageDeflateConfig> for PerMessageDeflateConfig {
    /// Copies the four negotiable parameters out of the header-level
    /// configuration; any other field of the header type is ignored.
    fn from(value: &sec_websocket_extensions::PerMessageDeflateConfig) -> Self {
        let sec_websocket_extensions::PerMessageDeflateConfig {
            server_no_context_takeover,
            client_no_context_takeover,
            server_max_window_bits,
            client_max_window_bits,
            ..
        } = value;

        Self {
            server_no_context_takeover: *server_no_context_takeover,
            client_no_context_takeover: *client_no_context_takeover,
            server_max_window_bits: *server_max_window_bits,
            client_max_window_bits: *client_max_window_bits,
        }
    }
}
214
#[cfg(feature = "compression")]
impl From<sec_websocket_extensions::PerMessageDeflateConfig> for PerMessageDeflateConfig {
    /// By-value convenience conversion; defers to the by-reference conversion.
    #[inline]
    fn from(value: sec_websocket_extensions::PerMessageDeflateConfig) -> Self {
        (&value).into()
    }
}
222
#[cfg(feature = "compression")]
impl From<&PerMessageDeflateConfig> for sec_websocket_extensions::PerMessageDeflateConfig {
    /// Builds the header-level configuration from the protocol-level one.
    ///
    /// The four negotiable parameters are copied over; everything else in the
    /// header type keeps its `Default` value.
    fn from(value: &PerMessageDeflateConfig) -> Self {
        let PerMessageDeflateConfig {
            server_no_context_takeover,
            client_no_context_takeover,
            server_max_window_bits,
            client_max_window_bits,
        } = *value;

        Self {
            server_no_context_takeover,
            client_no_context_takeover,
            server_max_window_bits,
            client_max_window_bits,
            ..Default::default()
        }
    }
}
235
#[cfg(feature = "compression")]
impl From<PerMessageDeflateConfig> for sec_websocket_extensions::PerMessageDeflateConfig {
    /// By-value convenience conversion; defers to the by-reference conversion.
    #[inline]
    fn from(value: PerMessageDeflateConfig) -> Self {
        (&value).into()
    }
}
243
#[cfg(feature = "compression")]
// Written out by hand (instead of derived) so that each default is spelled out.
#[allow(clippy::derivable_impls)]
impl Default for PerMessageDeflateConfig {
    /// Context takeover allowed in both directions and no explicit window limit.
    fn default() -> Self {
        // `false` keeps context takeover enabled for the respective direction.
        let no_context_takeover = false;
        // `None` means no limit, i.e. the default 15-bit window (32768 bytes).
        let max_window_bits: Option<u8> = None;

        Self {
            server_no_context_takeover: no_context_takeover,
            client_no_context_takeover: no_context_takeover,
            server_max_window_bits: max_window_bits,
            client_max_window_bits: max_window_bits,
        }
    }
}
259
260impl Default for WebSocketConfig {
261 fn default() -> Self {
262 Self {
263 read_buffer_size: 128 * 1024,
264 write_buffer_size: 128 * 1024,
265 max_write_buffer_size: usize::MAX,
266 max_message_size: Some(64 << 20),
267 max_frame_size: Some(16 << 20),
268 accept_unmasked_frames: false,
269 #[cfg(feature = "compression")]
270 per_message_deflate: None,
271 }
272 }
273}
274
275impl WebSocketConfig {
276 rama_utils::macros::generate_set_and_with! {
277 /// Set [`Self::read_buffer_size`].
278 #[must_use]
279 pub fn read_buffer_size(mut self, read_buffer_size: usize) -> Self {
280 self.read_buffer_size = read_buffer_size;
281 self
282 }
283 }
284
285 rama_utils::macros::generate_set_and_with! {
286 /// Set [`Self::write_buffer_size`].
287 #[must_use]
288 pub fn write_buffer_size(mut self, write_buffer_size: usize) -> Self {
289 self.write_buffer_size = write_buffer_size;
290 self
291 }
292 }
293
294 rama_utils::macros::generate_set_and_with! {
295 /// Set [`Self::max_write_buffer_size`].
296 #[must_use]
297 pub fn max_write_buffer_size(mut self, max_write_buffer_size: usize) -> Self {
298 self.max_write_buffer_size = max_write_buffer_size;
299 self
300 }
301 }
302
303 rama_utils::macros::generate_set_and_with! {
304 /// Set [`Self::max_message_size`].
305 #[must_use]
306 pub fn max_message_size(mut self, max_message_size: Option<usize>) -> Self {
307 self.max_message_size = max_message_size;
308 self
309 }
310 }
311
312 rama_utils::macros::generate_set_and_with! {
313 /// Set [`Self::max_frame_size`].
314 #[must_use]
315 pub fn max_frame_size(mut self, max_frame_size: Option<usize>) -> Self {
316 self.max_frame_size = max_frame_size;
317 self
318 }
319 }
320
321 rama_utils::macros::generate_set_and_with! {
322 /// Set [`Self::accept_unmasked_frames`].
323 #[must_use]
324 pub fn accept_unmasked_frames(mut self, accept_unmasked_frames: bool) -> Self {
325 self.accept_unmasked_frames = accept_unmasked_frames;
326 self
327 }
328 }
329
330 #[cfg(feature = "compression")]
331 rama_utils::macros::generate_set_and_with! {
332 /// Set [`Self::per_message_deflate`] with the default config..
333 #[must_use]
334 #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
335 pub fn per_message_deflate_default(mut self) -> Self {
336 self.per_message_deflate = Some(Default::default());
337 self
338 }
339 }
340
341 #[cfg(feature = "compression")]
342 rama_utils::macros::generate_set_and_with! {
343 /// Set [`Self::per_message_deflate`].
344 #[must_use]
345 #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
346 pub fn per_message_deflate(mut self, per_message_deflate: Option<PerMessageDeflateConfig>) -> Self {
347 self.per_message_deflate = per_message_deflate;
348 self
349 }
350 }
351
352 /// Panic if values are invalid.
353 pub(crate) fn assert_valid(&self) {
354 assert!(
355 self.max_write_buffer_size > self.write_buffer_size,
356 "WebSocketConfig::max_write_buffer_size must be greater than write_buffer_size, \
357 see WebSocketConfig docs`"
358 );
359 }
360}
361
362/// WebSocket input-output stream.
363///
364/// This is THE structure you want to create to be able to speak the WebSocket protocol.
365/// It may be created by calling `connect`, `accept` or `client` functions.
366///
367/// Use [`WebSocket::read`], [`WebSocket::send`] to received and send messages.
368pub struct WebSocket<Stream> {
369 /// The underlying socket.
370 socket: Stream,
371 /// The context for managing a WebSocket.
372 context: WebSocketContext,
373}
374
375impl<Stream: fmt::Debug> fmt::Debug for WebSocket<Stream> {
376 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
377 f.debug_struct("WebSocket")
378 .field("socket", &self.socket)
379 .field("context", &self.context)
380 .finish()
381 }
382}
383
384impl<Stream> WebSocket<Stream> {
385 /// Convert a raw socket into a WebSocket without performing a handshake.
386 ///
387 /// # Panics
388 /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
389 pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
390 Self {
391 socket: stream,
392 context: WebSocketContext::new(role, config),
393 }
394 }
395
396 /// Convert a raw socket into a WebSocket without performing a handshake.
397 ///
398 /// # Panics
399 /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
400 pub fn from_partially_read(
401 stream: Stream,
402 part: Vec<u8>,
403 role: Role,
404 config: Option<WebSocketConfig>,
405 ) -> Self {
406 Self {
407 socket: stream,
408 context: WebSocketContext::from_partially_read(part, role, config),
409 }
410 }
411
412 /// Consumes the `WebSocket` and returns the underlying stream.
413 pub(crate) fn into_inner(self) -> Stream {
414 self.socket
415 }
416
417 /// Returns a shared reference to the inner stream.
418 pub fn get_ref(&self) -> &Stream {
419 &self.socket
420 }
421 /// Returns a mutable reference to the inner stream.
422 pub fn get_mut(&mut self) -> &mut Stream {
423 &mut self.socket
424 }
425
426 /// Change the configuration.
427 ///
428 /// # Panics
429 /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
430 pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
431 self.context.set_config(set_func);
432 }
433
434 /// Read the configuration.
435 pub fn get_config(&self) -> &WebSocketConfig {
436 self.context.get_config()
437 }
438
439 /// Check if it is possible to read messages.
440 ///
441 /// Reading is impossible after receiving `Message::Close`. It is still possible after
442 /// sending close frame since the peer still may send some data before confirming close.
443 pub fn can_read(&self) -> bool {
444 self.context.can_read()
445 }
446
447 /// Check if it is possible to write messages.
448 ///
449 /// Writing gets impossible immediately after sending or receiving `Message::Close`.
450 pub fn can_write(&self) -> bool {
451 self.context.can_write()
452 }
453}
454
455impl<Stream: Read + Write> WebSocket<Stream> {
456 /// Read a message from stream, if possible.
457 ///
458 /// This will also queue responses to ping and close messages. These responses
459 /// will be written and flushed on the next call to [`read`](Self::read),
460 /// [`write`](Self::write) or [`flush`](Self::flush).
461 ///
462 /// # Closing the connection
463 /// When the remote endpoint decides to close the connection this will return
464 /// the close message with an optional close frame.
465 ///
466 /// You should continue calling [`read`](Self::read), [`write`](Self::write) or
467 /// [`flush`](Self::flush) to drive the reply to the close frame until [`Error::ConnectionClosed`]
468 /// is returned. Once that happens it is safe to drop the underlying connection.
469 pub fn read(&mut self) -> Result<Message, ProtocolError> {
470 self.context.read(&mut self.socket)
471 }
472
473 /// Writes and immediately flushes a message.
474 /// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
475 pub fn send(&mut self, message: Message) -> Result<(), ProtocolError> {
476 self.write(message)?;
477 self.flush()
478 }
479
480 /// Write a message to the provided stream, if possible.
481 ///
482 /// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
483 ///
484 /// In the event of stream write failure the message frame will be stored
485 /// in the write buffer and will try again on the next call to [`write`](Self::write)
486 /// or [`flush`](Self::flush).
487 ///
488 /// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`]
489 /// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned.
490 ///
491 /// This call will generally not flush. However, if there are queued automatic messages
492 /// they will be written and eagerly flushed.
493 ///
494 /// For example, upon receiving ping messages this crate queues pong replies automatically.
495 /// The next call to [`read`](Self::read), [`write`](Self::write) or [`flush`](Self::flush)
496 /// will write & flush the pong reply. This means you should not respond to ping frames manually.
497 ///
498 /// You can however send pong frames manually in order to indicate a unidirectional heartbeat
499 /// as described in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.5.3). Note that
500 /// if [`read`](Self::read) returns a ping, you should [`flush`](Self::flush) before passing
501 /// a custom pong to [`write`](Self::write), otherwise the automatic queued response to the
502 /// ping will not be sent as it will be replaced by your custom pong message.
503 ///
504 /// # Errors
505 /// - If the WebSocket's write buffer is full, [`Error::WriteBufferFull`] will be returned
506 /// along with the equivalent passed message frame.
507 /// - If the connection is closed and should be dropped, this will return [`Error::ConnectionClosed`].
508 /// - If you try again after [`Error::ConnectionClosed`] was returned either from here or from
509 /// [`read`](Self::read), [`Error::AlreadyClosed`] will be returned. This indicates a program
510 /// error on your part.
511 /// - [`Error::Io`] is returned if the underlying connection returns an error
512 /// (consider these fatal except for WouldBlock).
513 /// - [`Error::Capacity`] if your message size is bigger than the configured max message size.
514 pub fn write(&mut self, message: Message) -> Result<(), ProtocolError> {
515 self.context.write(&mut self.socket, message)
516 }
517
518 /// Flush writes.
519 ///
520 /// Ensures all messages previously passed to [`write`](Self::write) and automatic
521 /// queued pong responses are written & flushed into the underlying stream.
522 pub fn flush(&mut self) -> Result<(), ProtocolError> {
523 self.context.flush(&mut self.socket)
524 }
525
526 /// Close the connection.
527 ///
528 /// This function guarantees that the close frame will be queued.
529 /// There is no need to call it again. Calling this function is
530 /// the same as calling `write(Message::Close(..))`.
531 ///
532 /// After queuing the close frame you should continue calling [`read`](Self::read) or
533 /// [`flush`](Self::flush) to drive the close handshake to completion.
534 ///
535 /// The websocket RFC defines that the underlying connection should be closed
536 /// by the server. This crate takes care of this asymmetry for you.
537 ///
538 /// When the close handshake is finished (we have both sent and received
539 /// a close message), [`read`](Self::read) or [`flush`](Self::flush) will return
540 /// [Error::ConnectionClosed] if this endpoint is the server.
541 ///
542 /// If this endpoint is a client, [Error::ConnectionClosed] will only be
543 /// returned after the server has closed the underlying connection.
544 ///
545 /// It is thus safe to drop the underlying connection as soon as [Error::ConnectionClosed]
546 /// is returned from [`read`](Self::read) or [`flush`](Self::flush).
547 pub fn close(&mut self, code: Option<CloseFrame>) -> Result<(), ProtocolError> {
548 self.context.close(&mut self.socket, code)
549 }
550}
551
552impl<Stream: ExtensionsRef> ExtensionsRef for WebSocket<Stream> {
553 fn extensions(&self) -> &Extensions {
554 self.socket.extensions()
555 }
556}
557
558impl<Stream: ExtensionsMut> ExtensionsMut for WebSocket<Stream> {
559 fn extensions_mut(&mut self) -> &mut Extensions {
560 self.socket.extensions_mut()
561 }
562}
563
564/// A context for managing WebSocket stream.
565#[derive(Debug)]
566pub struct WebSocketContext {
567 /// Server or client?
568 role: Role,
569 /// encoder/decoder of frame.
570 frame: FrameCodec,
571 /// The state of processing, either "active" or "closing".
572 state: WebSocketState,
573 #[cfg(feature = "compression")]
574 /// The state used in function per-message compression,
575 /// only set in case the extension is enabled.
576 per_message_deflate_state: Option<PerMessageDeflateState>,
577 /// Receive: an incomplete message being processed.
578 incomplete: Option<IncompleteMessage>,
579 /// Send in addition to regular messages E.g. "pong" or "close".
580 additional_send: Option<Frame>,
581 /// True indicates there is an additional message (like a pong)
582 /// that failed to flush previously and we should try again.
583 unflushed_additional: bool,
584 /// The configuration for the websocket session.
585 config: WebSocketConfig,
586}
587
588impl WebSocketContext {
589 /// Create a WebSocket context that manages a post-handshake stream.
590 ///
591 /// # Panics
592 /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
593 #[must_use]
594 pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
595 let conf = config.unwrap_or_default();
596 Self::_new(role, FrameCodec::new(conf.read_buffer_size), conf)
597 }
598
599 /// Create a WebSocket context that manages an post-handshake stream.
600 ///
601 /// # Panics
602 /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
603 #[must_use]
604 pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
605 let conf = config.unwrap_or_default();
606 Self::_new(
607 role,
608 FrameCodec::from_partially_read(part, conf.read_buffer_size),
609 conf,
610 )
611 }
612
613 fn _new(role: Role, mut frame: FrameCodec, config: WebSocketConfig) -> Self {
614 config.assert_valid();
615 frame.set_max_out_buffer_len(config.max_write_buffer_size);
616 frame.set_out_buffer_write_len(config.write_buffer_size);
617 Self {
618 role,
619 frame,
620 state: WebSocketState::Active,
621 #[cfg(feature = "compression")]
622 per_message_deflate_state: config
623 .per_message_deflate
624 .map(|cfg| PerMessageDeflateState::new(role, cfg)),
625 incomplete: None,
626 additional_send: None,
627 unflushed_additional: false,
628 config,
629 }
630 }
631
632 /// Change the configuration.
633 ///
634 /// # Panics
635 /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
636 pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
637 set_func(&mut self.config);
638 self.config.assert_valid();
639 self.frame
640 .set_max_out_buffer_len(self.config.max_write_buffer_size);
641 self.frame
642 .set_out_buffer_write_len(self.config.write_buffer_size);
643 }
644
645 /// Read the configuration.
646 pub fn get_config(&self) -> &WebSocketConfig {
647 &self.config
648 }
649
650 /// Check if it is possible to read messages.
651 ///
652 /// Reading is impossible after receiving `Message::Close`. It is still possible after
653 /// sending close frame since the peer still may send some data before confirming close.
654 pub fn can_read(&self) -> bool {
655 self.state.can_read()
656 }
657
658 /// Check if it is possible to write messages.
659 ///
660 /// Writing gets impossible immediately after sending or receiving `Message::Close`.
661 pub fn can_write(&self) -> bool {
662 self.state.is_active()
663 }
664
665 /// Read a message from the provided stream, if possible.
666 ///
667 /// This function sends pong and close responses automatically.
668 /// However, it never blocks on write.
669 pub fn read<Stream>(&mut self, stream: &mut Stream) -> Result<Message, ProtocolError>
670 where
671 Stream: Read + Write,
672 {
673 // Do not read from already closed connections.
674 self.state.check_not_terminated()?;
675
676 loop {
677 if self.additional_send.is_some() || self.unflushed_additional {
678 // Since we may get ping or close, we need to reply to the messages even during read.
679 match self.flush(stream) {
680 Ok(_) => {}
681 Err(ProtocolError::Io(err)) if err.kind() == io::ErrorKind::WouldBlock => {
682 // If blocked continue reading, but try again later
683 self.unflushed_additional = true;
684 }
685 Err(err) => return Err(err),
686 }
687 } else if self.role == Role::Server && !self.state.can_read() {
688 self.state = WebSocketState::Terminated;
689 return Err(ProtocolError::Io(io::Error::new(
690 io::ErrorKind::ConnectionAborted,
691 OpaqueError::from_display("Connection closed normally by me-the-server"),
692 )));
693 }
694
695 // If we get here, either write blocks or we have nothing to write.
696 // Thus if read blocks, just let it return WouldBlock.
697 if let Some(message) = self.read_message_frame(stream)? {
698 trace!("Received message {message}");
699 return Ok(message);
700 }
701 }
702 }
703
704 /// Write a message to the provided stream.
705 ///
706 /// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
707 ///
708 /// In the event of stream write failure the message frame will be stored
709 /// in the write buffer and will try again on the next call to [`write`](Self::write)
710 /// or [`flush`](Self::flush).
711 ///
712 /// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`]
713 /// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned.
714 pub fn write<Stream>(
715 &mut self,
716 stream: &mut Stream,
717 message: Message,
718 ) -> Result<(), ProtocolError>
719 where
720 Stream: Read + Write,
721 {
722 // When terminated, return AlreadyClosed.
723 self.state.check_not_terminated()?;
724
725 // Do not write after sending a close frame.
726 if !self.state.is_active() {
727 return Err(ProtocolError::SendAfterClosing);
728 }
729
730 let frame = match message {
731 Message::Text(data) => {
732 #[cfg(feature = "compression")]
733 match self.per_message_deflate_state.as_mut() {
734 Some(deflate_state) => {
735 let data = match deflate_state.encoder.encode(data.as_bytes()) {
736 Ok(data) => data,
737 Err(err) => return Err(ProtocolError::DeflateError(err)),
738 };
739 let mut msg = Frame::message(data, OpCode::Data(OpCodeData::Text), true);
740 msg.header_mut().rsv1 = true;
741 msg
742 }
743 None => Frame::message(data, OpCode::Data(OpCodeData::Text), true),
744 }
745 #[cfg(not(feature = "compression"))]
746 Frame::message(data, OpCode::Data(OpCodeData::Text), true)
747 }
748 Message::Binary(data) => {
749 #[cfg(feature = "compression")]
750 match self.per_message_deflate_state.as_mut() {
751 Some(deflate_state) => {
752 let data = match deflate_state.encoder.encode(data.as_ref()) {
753 Ok(data) => data,
754 Err(err) => return Err(ProtocolError::DeflateError(err)),
755 };
756 let mut msg = Frame::message(data, OpCode::Data(OpCodeData::Binary), true);
757 msg.header_mut().rsv1 = true;
758 msg
759 }
760 None => Frame::message(data, OpCode::Data(OpCodeData::Binary), true),
761 }
762 #[cfg(not(feature = "compression"))]
763 Frame::message(data, OpCode::Data(OpCodeData::Binary), true)
764 }
765 Message::Ping(data) => Frame::ping(data),
766 Message::Pong(data) => {
767 self.set_additional(Frame::pong(data));
768 // Note: user pongs can be user flushed so no need to flush here
769 return self._write(stream, None).map(|_| ());
770 }
771 Message::Close(code) => return self.close(stream, code),
772 Message::Frame(f) => f,
773 };
774
775 let should_flush = self._write(stream, Some(frame))?;
776 if should_flush {
777 self.flush(stream)?;
778 }
779 Ok(())
780 }
781
782 /// Flush writes.
783 ///
784 /// Ensures all messages previously passed to [`write`](Self::write) and automatically
785 /// queued pong responses are written & flushed into the `stream`.
786 #[inline]
787 pub fn flush<Stream>(&mut self, stream: &mut Stream) -> Result<(), ProtocolError>
788 where
789 Stream: Read + Write,
790 {
791 self._write(stream, None)?;
792 self.frame.write_out_buffer(stream)?;
793 stream.flush()?;
794 self.unflushed_additional = false;
795 Ok(())
796 }
797
798 /// Writes any data in the out_buffer, `additional_send` and given `data`.
799 ///
800 /// Does **not** flush.
801 ///
802 /// Returns true if the write contents indicate we should flush immediately.
803 fn _write<Stream>(
804 &mut self,
805 stream: &mut Stream,
806 data: Option<Frame>,
807 ) -> Result<bool, ProtocolError>
808 where
809 Stream: Read + Write,
810 {
811 if let Some(data) = data {
812 self.buffer_frame(stream, data)?;
813 }
814
815 // Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in
816 // response, unless it already received a Close frame. It SHOULD
817 // respond with Pong frame as soon as is practical. (RFC 6455)
818 let should_flush = if let Some(msg) = self.additional_send.take() {
819 trace!("Sending pong/close");
820 match self.buffer_frame(stream, msg) {
821 Err(ProtocolError::WriteBufferFull(msg)) => {
822 // if an system message would exceed the buffer put it back in
823 // `additional_send` for retry. Otherwise returning this error
824 // may not make sense to the user, e.g. calling `flush`.
825 if let Message::Frame(msg) = msg {
826 self.set_additional(msg);
827 false
828 } else {
829 unreachable!();
830 }
831 }
832 Err(err) => return Err(err),
833 Ok(_) => true,
834 }
835 } else {
836 self.unflushed_additional
837 };
838
839 // If we're closing and there is nothing to send anymore, we should close the connection.
840 if self.role == Role::Server && !self.state.can_read() {
841 // The underlying TCP connection, in most normal cases, SHOULD be closed
842 // first by the server, so that it holds the TIME_WAIT state and not the
843 // client (as this would prevent it from re-opening the connection for 2
844 // maximum segment lifetimes (2MSL), while there is no corresponding
845 // server impact as a TIME_WAIT connection is immediately reopened upon
846 // a new SYN with a higher seq number). (RFC 6455)
847 self.frame.write_out_buffer(stream)?;
848 self.state = WebSocketState::Terminated;
849 Err(ProtocolError::Io(io::Error::new(
850 io::ErrorKind::ConnectionAborted,
851 OpaqueError::from_display("Connection closed normally by me-the-server (EOF)"),
852 )))
853 } else {
854 Ok(should_flush)
855 }
856 }
857
858 /// Close the connection.
859 ///
860 /// This function guarantees that the close frame will be queued.
861 /// There is no need to call it again. Calling this function is
862 /// the same as calling `send(Message::Close(..))`.
863 pub fn close<Stream>(
864 &mut self,
865 stream: &mut Stream,
866 code: Option<CloseFrame>,
867 ) -> Result<(), ProtocolError>
868 where
869 Stream: Read + Write,
870 {
871 if self.state == WebSocketState::Active {
872 self.state = WebSocketState::ClosedByUs;
873 let frame = Frame::close(code);
874 self._write(stream, Some(frame))?;
875 }
876 self.flush(stream)
877 }
878
879 /// Try to decode one message frame. May return None.
880 fn read_message_frame(
881 &mut self,
882 stream: &mut impl Read,
883 ) -> Result<Option<Message>, ProtocolError> {
884 let Some(frame) = self.frame.read_frame(
885 stream,
886 self.config.max_frame_size,
887 matches!(self.role, Role::Server),
888 self.config.accept_unmasked_frames,
889 )?
890 else {
891 // Connection closed by peer
892 return match std::mem::replace(&mut self.state, WebSocketState::Terminated) {
893 WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
894 Err(ProtocolError::Io(io::Error::new(
895 io::ErrorKind::ConnectionAborted,
896 OpaqueError::from_display("Connection closed normally by peer"),
897 )))
898 }
899 WebSocketState::Active
900 | WebSocketState::ClosedByUs
901 | WebSocketState::Terminated => Err(ProtocolError::ResetWithoutClosingHandshake),
902 };
903 };
904
905 if !self.state.can_read() {
906 return Err(ProtocolError::ReceivedAfterClosing);
907 }
908
909 #[cfg(feature = "compression")]
910 // to ensure that this is valid in later branches,
911 // as this is not always true despite an extension active that supports it
912 let mut rsv1_set = false;
913
914 // MUST be 0 unless an extension is negotiated that defines meanings
915 // for non-zero values. If a nonzero value is received and none of
916 // the negotiated extensions defines the meaning of such a nonzero
917 // value, the receiving endpoint MUST _Fail the WebSocket
918 // Connection_.
919 {
920 let hdr = frame.header();
921 if hdr.rsv1 {
922 #[cfg(feature = "compression")]
923 {
924 rsv1_set = true;
925 if self.per_message_deflate_state.is_none() {
926 tracing::debug!(
927 "rsv1 bit is set but PMD state is none: no use case for it"
928 );
929 return Err(ProtocolError::NonZeroReservedBits);
930 }
931 }
932 #[cfg(not(feature = "compression"))]
933 {
934 tracing::debug!("rsv1 bit is set but compression feature no enabled");
935 return Err(ProtocolError::NonZeroReservedBits);
936 }
937 } else if hdr.rsv2 || hdr.rsv3 {
938 tracing::debug!("rsv2 or rsv3 bit set: not expected ever");
939 return Err(ProtocolError::NonZeroReservedBits);
940 }
941 }
942
943 if self.role == Role::Client && frame.is_masked() {
944 // A client MUST close a connection if it detects a masked frame. (RFC 6455)
945 return Err(ProtocolError::MaskedFrameFromServer);
946 }
947
948 match frame.header().opcode {
949 OpCode::Control(ctl) => {
950 #[cfg(feature = "compression")]
951 if rsv1_set {
952 tracing::debug!("rsv1 bit set in control frame: not expected");
953 return Err(ProtocolError::NonZeroReservedBits);
954 }
955
956 match ctl {
957 // All control frames MUST have a payload length of 125 bytes or less
958 // and MUST NOT be fragmented. (RFC 6455)
959 _ if !frame.header().is_final => Err(ProtocolError::FragmentedControlFrame),
960 _ if frame.payload().len() > 125 => Err(ProtocolError::ControlFrameTooBig),
961 OpCodeControl::Close => {
962 Ok(self.do_close(frame.into_close()?).map(Message::Close))
963 }
964 OpCodeControl::Reserved(i) => Err(ProtocolError::UnknownControlFrameType(i)),
965 OpCodeControl::Ping => {
966 let data = frame.into_payload();
967 // No ping processing after we sent a close frame.
968 if self.state.is_active() {
969 self.set_additional(Frame::pong(data.clone()));
970 }
971 Ok(Some(Message::Ping(data)))
972 }
973 OpCodeControl::Pong => Ok(Some(Message::Pong(frame.into_payload()))),
974 }
975 }
976
977 OpCode::Data(data) => {
978 let fin = frame.header().is_final;
979
980 #[cfg(feature = "compression")]
981 if matches!(data, OpCodeData::Continue) && rsv1_set {
982 tracing::debug!("rsv1 bit set in CONTINUE frame: not expected");
983 return Err(ProtocolError::NonZeroReservedBits);
984 }
985
986 let payload = match (data, self.incomplete.as_mut()) {
987 (OpCodeData::Continue, None) => {
988 #[cfg(feature = "compression")]
989 if let Some(deflate_state) = self.per_message_deflate_state.as_mut() {
990 if fin {
991 let (compressed_data, msg_type) =
992 deflate_state.decompress_incomplete_msg.fin_buffer(
993 frame.into_payload(),
994 self.config.max_message_size,
995 )?;
996 return match deflate_state.decoder.decode(compressed_data.as_ref())
997 {
998 Ok(raw_data) => match msg_type {
999 IncompleteMessageType::Text => {
1000 Ok(Some(Message::Text(Utf8Bytes::try_from(raw_data)?)))
1001 }
1002 IncompleteMessageType::Binary => {
1003 Ok(Some(Message::Binary(raw_data.into())))
1004 }
1005 },
1006 Err(err) => Err(ProtocolError::DeflateError(err)),
1007 };
1008 }
1009
1010 deflate_state
1011 .decompress_incomplete_msg
1012 .extend(frame.into_payload(), self.config.max_message_size)?;
1013 Ok(None)
1014 } else {
1015 return Err(ProtocolError::UnexpectedContinueFrame);
1016 }
1017
1018 #[cfg(not(feature = "compression"))]
1019 return Err(ProtocolError::UnexpectedContinueFrame);
1020 }
1021 (OpCodeData::Continue, Some(incomplete)) => {
1022 incomplete.extend(frame.into_payload(), self.config.max_message_size)?;
1023 Ok(None)
1024 }
1025 (_, Some(_)) => Err(ProtocolError::ExpectedFragment(data)),
1026 (OpCodeData::Text, _) => {
1027 Ok(Some((frame.into_payload(), IncompleteMessageType::Text)))
1028 }
1029 (OpCodeData::Binary, _) => {
1030 Ok(Some((frame.into_payload(), IncompleteMessageType::Binary)))
1031 }
1032 (OpCodeData::Reserved(i), _) => Err(ProtocolError::UnknownDataFrameType(i)),
1033 }?;
1034
1035 match (payload, fin) {
1036 (None, true) =>
1037 {
1038 #[allow(
1039 clippy::expect_used,
1040 reason = "we can only reach here if incomplete is Some"
1041 )]
1042 Ok(Some(
1043 self.incomplete
1044 .take()
1045 .expect("incomplete to be there")
1046 .complete()?,
1047 ))
1048 }
1049 (None, false) => Ok(None),
1050 (Some((payload, t)), true) => {
1051 check_max_size(payload.len(), self.config.max_message_size)?;
1052
1053 #[cfg(feature = "compression")]
1054 if rsv1_set {
1055 if let Some(deflate_state) = self.per_message_deflate_state.as_mut() {
1056 let compressed_data = payload;
1057 let raw_data = deflate_state
1058 .decoder
1059 .decode(&compressed_data)
1060 .map_err(ProtocolError::DeflateError)?;
1061 match t {
1062 IncompleteMessageType::Text => {
1063 Ok(Some(Message::Text(Utf8Bytes::try_from(raw_data)?)))
1064 }
1065 IncompleteMessageType::Binary => {
1066 Ok(Some(Message::Binary(raw_data.into())))
1067 }
1068 }
1069 } else {
1070 tracing::debug!(
1071 "rsv1 bit set in text frame but deflate state is none"
1072 );
1073 Err(ProtocolError::NonZeroReservedBits)
1074 }
1075 } else {
1076 match t {
1077 IncompleteMessageType::Text => {
1078 Ok(Some(Message::Text(payload.try_into()?)))
1079 }
1080 IncompleteMessageType::Binary => Ok(Some(Message::Binary(payload))),
1081 }
1082 }
1083
1084 #[cfg(not(feature = "compression"))]
1085 match t {
1086 IncompleteMessageType::Text => {
1087 Ok(Some(Message::Text(payload.try_into()?)))
1088 }
1089 IncompleteMessageType::Binary => Ok(Some(Message::Binary(payload))),
1090 }
1091 }
1092 (Some((payload, t)), false) => {
1093 #[cfg(feature = "compression")]
1094 if rsv1_set {
1095 if let Some(deflate_state) = self.per_message_deflate_state.as_mut() {
1096 deflate_state.decompress_incomplete_msg.reset(t);
1097 deflate_state
1098 .decompress_incomplete_msg
1099 .extend(payload, self.config.max_message_size)?;
1100 Ok(None)
1101 } else {
1102 tracing::debug!(
1103 "rsv1 bit set in non-fin bin/text frame but deflate state is none"
1104 );
1105 Err(ProtocolError::NonZeroReservedBits)
1106 }
1107 } else {
1108 let mut incomplete = IncompleteMessage::new(t);
1109 incomplete.extend(payload, self.config.max_message_size)?;
1110 self.incomplete = Some(incomplete);
1111 Ok(None)
1112 }
1113 #[cfg(not(feature = "compression"))]
1114 {
1115 let mut incomplete = IncompleteMessage::new(t);
1116 incomplete.extend(payload, self.config.max_message_size)?;
1117 self.incomplete = Some(incomplete);
1118 Ok(None)
1119 }
1120 }
1121 }
1122 }
1123 } // match opcode
1124 }
1125
1126 /// Received a close frame. Tells if we need to return a close frame to the user.
1127 #[allow(clippy::option_option)]
1128 fn do_close(&mut self, close: Option<CloseFrame>) -> Option<Option<CloseFrame>> {
1129 rama_core::telemetry::tracing::trace!("Received close frame: {close:?}");
1130 match self.state {
1131 WebSocketState::Active => {
1132 self.state = WebSocketState::ClosedByPeer;
1133
1134 let close = close.map(|frame| {
1135 if !frame.code.is_allowed() {
1136 CloseFrame {
1137 code: CloseCode::Protocol,
1138 reason: Utf8Bytes::from_static("Protocol violation"),
1139 }
1140 } else {
1141 frame
1142 }
1143 });
1144
1145 let reply = Frame::close(close.clone());
1146 debug!("Replying to close with {reply:?}");
1147 self.set_additional(reply);
1148
1149 Some(close)
1150 }
1151 WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
1152 // It is already closed, just ignore.
1153 None
1154 }
1155 WebSocketState::ClosedByUs => {
1156 // We received a reply.
1157 self.state = WebSocketState::CloseAcknowledged;
1158 Some(close)
1159 }
1160 WebSocketState::Terminated => unreachable!(),
1161 }
1162 }
1163
1164 /// Write a single frame into the write-buffer.
1165 fn buffer_frame<Stream>(
1166 &mut self,
1167 stream: &mut Stream,
1168 mut frame: Frame,
1169 ) -> Result<(), ProtocolError>
1170 where
1171 Stream: Read + Write,
1172 {
1173 match self.role {
1174 Role::Server => {}
1175 Role::Client => {
1176 // 5. If the data is being sent by the client, the frame(s) MUST be
1177 // masked as defined in Section 5.3. (RFC 6455)
1178 frame.set_random_mask();
1179 }
1180 }
1181
1182 trace!("Sending frame: {frame:?}");
1183 self.frame.buffer_frame(stream, frame)
1184 }
1185
1186 /// Replace `additional_send` if it is currently a `Pong` message.
1187 fn set_additional(&mut self, add: Frame) {
1188 let empty_or_pong = self
1189 .additional_send
1190 .as_ref()
1191 .is_none_or(|f| f.header().opcode == OpCode::Control(OpCodeControl::Pong));
1192 if empty_or_pong {
1193 self.additional_send.replace(add);
1194 }
1195 }
1196}
1197
1198fn check_max_size(size: usize, max_size: Option<usize>) -> Result<(), ProtocolError> {
1199 if let Some(max_size) = max_size
1200 && size > max_size
1201 {
1202 return Err(ProtocolError::MessageTooLong { size, max_size });
1203 }
1204 Ok(())
1205}
1206
/// Where the connection stands with respect to the closing handshake.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WebSocketState {
    /// Normal operation: no close frame has been sent or received.
    Active,
    /// We started the closing handshake; the peer has not answered yet.
    ClosedByUs,
    /// The peer started the closing handshake.
    ClosedByPeer,
    /// The peer answered the closing handshake we started.
    CloseAcknowledged,
    /// The connection is gone; nothing more can be done with it.
    Terminated,
}
1221
1222impl WebSocketState {
1223 /// Tell if we're allowed to process normal messages.
1224 fn is_active(self) -> bool {
1225 matches!(self, Self::Active)
1226 }
1227
1228 /// Tell if we should process incoming data. Note that if we send a close frame
1229 /// but the remote hasn't confirmed, they might have sent data before they receive our
1230 /// close frame, so we should still pass those to client code, hence ClosedByUs is valid.
1231 fn can_read(self) -> bool {
1232 matches!(self, Self::Active | Self::ClosedByUs)
1233 }
1234
1235 /// Check if the state is active, return error if not.
1236 fn check_not_terminated(self) -> Result<(), ProtocolError> {
1237 match self {
1238 Self::Terminated => Err(ProtocolError::Io(io::Error::new(
1239 io::ErrorKind::NotConnected,
1240 OpaqueError::from_display("Trying to work with closed connection"),
1241 ))),
1242 Self::Active | Self::CloseAcknowledged | Self::ClosedByPeer | Self::ClosedByUs => {
1243 Ok(())
1244 }
1245 }
1246 }
1247}