blitz_ws/protocol/
websocket.rs

1//! WebSocket handler
2
3use std::{
4    io::{self, Read, Write},
5    mem::replace,
6};
7
8use crate::{
9    error::{CapacityError, Error, ProtocolError, Result},
10    protocol::{
11        config::WebSocketConfig,
12        frame::{
13            codec::{CloseCode, Control, Data, OpCode},
14            core::FrameCodec,
15            CloseFrame, Frame, Utf8Bytes,
16        },
17        message::{IncompleteMessage, IncompleteMessageType, Message},
18    },
19    MAX_CONTROL_FRAME_PAYLOAD,
20};
21
/// Which side of the WebSocket connection this endpoint plays.
///
/// The mode decides, among other things, whether outgoing frames get masked and which
/// side drops the underlying connection once the close handshake has finished.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OperationMode {
    /// Act as the client end of the connection.
    Client,
    /// Act as the server end of the connection.
    Server,
}
30
31/// WebSocket input-output stream.
32///
33/// This is THE structure you want to create to be able to speak the WebSocket protocol.
34/// It may be created by calling `connect`, `accept` or `client` functions.
35///
36/// Use [`WebSocket::read`], [`WebSocket::send`] to received and send messages.
37#[derive(Debug)]
38pub struct WebSocket<T> {
39    stream: T,
40    context: WebSocketContext,
41}
42
43impl<T: Read + Write> WebSocket<T> {
44    /// Convert a raw socket into a WebSocket without performing a handshake.
45    ///
46    /// Call this function if you're using Tungstenite as a part of a web framework
47    /// or together with an existing one. If you need an initial handshake, use
48    /// `connect()` or `accept()` functions of the crate to construct a websocket.
49    ///
50    /// # Panics
51    /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
52    pub fn new(stream: T, mode: OperationMode, config: Option<WebSocketConfig>) -> Self {
53        WebSocket { stream, context: WebSocketContext::new(mode, config) }
54    }
55
56    /// Convert a raw socket into a WebSocket without performing a handshake.
57    ///
58    /// Call this function if you're using Tungstenite as a part of a web framework
59    /// or together with an existing one. If you need an initial handshake, use
60    /// `connect()` or `accept()` functions of the crate to construct a websocket.
61    ///
62    /// # Panics
63    /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
64    pub fn from_partially_read(
65        stream: T,
66        part: Vec<u8>,
67        mode: OperationMode,
68        config: Option<WebSocketConfig>,
69    ) -> Self {
70        WebSocket { stream, context: WebSocketContext::from_partially_read(part, mode, config) }
71    }
72
73    /// Returns a shared reference to the stream
74    pub fn get_ref(&self) -> &T {
75        &self.stream
76    }
77
78    /// Returns a mutable reference to the stream
79    pub fn get_mut(&mut self) -> &mut T {
80        &mut self.stream
81    }
82
83    /// Returns the inner instance of the stream
84    pub fn into_inner(self) -> T {
85        self.stream
86    }
87
88    /// Change the configuration.
89    ///
90    /// # Panics
91    /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
92    pub fn set_config(&mut self, func: impl FnOnce(&mut WebSocketConfig)) {
93        self.context.set_config(func);
94    }
95
96    /// Read the configuration.
97    pub fn get_config(&self) -> &WebSocketConfig {
98        self.context.get_config()
99    }
100
101    /// Check if it is possible to read messages.
102    ///
103    /// Reading is impossible after receiving `Message::Close`. It is still possible after
104    /// sending close frame since the peer still may send some data before confirming close.
105    pub fn can_read(&self) -> bool {
106        self.context.can_read()
107    }
108
109    /// Check if it is possible to write messages.
110    ///
111    /// Writing gets impossible immediately after sending or receiving `Message::Close`.
112    pub fn can_write(&self) -> bool {
113        self.context.can_write()
114    }
115
116    /// Check if it is possible to read messages.
117    ///
118    /// Reading is impossible after receiving `Message::Close`. It is still possible after
119    /// sending close frame since the peer still may send some data before confirming close.
120    pub fn read(&mut self) -> Result<Message> {
121        self.context.read(&mut self.stream)
122    }
123
124    /// Writes and immediately flushes a message.
125    /// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
126    pub fn send(&mut self, msg: Message) -> Result<()> {
127        self.write(msg)?;
128        self.flush()
129    }
130
131    /// Write a message to the provided stream, if possible.
132    ///
133    /// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
134    ///
135    /// In the event of stream write failure the message frame will be stored
136    /// in the write buffer and will try again on the next call to [`write`](Self::write)
137    /// or [`flush`](Self::flush).
138    ///
139    /// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`]
140    /// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned.
141    ///
142    /// This call will generally not flush. However, if there are queued automatic messages
143    /// they will be written and eagerly flushed.
144    ///
145    /// For example, upon receiving ping messages tungstenite queues pong replies automatically.
146    /// The next call to [`read`](Self::read), [`write`](Self::write) or [`flush`](Self::flush)
147    /// will write & flush the pong reply. This means you should not respond to ping frames manually.
148    ///
149    /// You can however send pong frames manually in order to indicate a unidirectional heartbeat
150    /// as described in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.5.3). Note that
151    /// if [`read`](Self::read) returns a ping, you should [`flush`](Self::flush) before passing
152    /// a custom pong to [`write`](Self::write), otherwise the automatic queued response to the
153    /// ping will not be sent as it will be replaced by your custom pong message.
154    ///
155    /// # Errors
156    /// - If the WebSocket's write buffer is full, [`Error::WriteBufferFull`] will be returned
157    ///   along with the equivalent passed message frame.
158    /// - If the connection is closed and should be dropped, this will return [`Error::ConnectionClosed`].
159    /// - If you try again after [`Error::ConnectionClosed`] was returned either from here or from
160    ///   [`read`](Self::read), [`Error::AlreadyClosed`] will be returned. This indicates a program
161    ///   error on your part.
162    /// - [`Error::Io`] is returned if the underlying connection returns an error
163    ///   (consider these fatal except for WouldBlock).
164    /// - [`Error::Capacity`] if your message size is bigger than the configured max message size.
165    pub fn write(&mut self, msg: Message) -> Result<()> {
166        self.context.write(&mut self.stream, msg)
167    }
168
169    /// Flush writes.
170    ///
171    /// Ensures all messages previously passed to [`write`](Self::write) and automatic
172    /// queued pong responses are written & flushed into the underlying stream.
173    pub fn flush(&mut self) -> Result<()> {
174        self.context.flush(&mut self.stream)
175    }
176
177    /// Close the connection.
178    ///
179    /// This function guarantees that the close frame will be queued.
180    /// There is no need to call it again. Calling this function is
181    /// the same as calling `write(Message::Close(..))`.
182    ///
183    /// After queuing the close frame you should continue calling [`read`](Self::read) or
184    /// [`flush`](Self::flush) to drive the close handshake to completion.
185    ///
186    /// The websocket RFC defines that the underlying connection should be closed
187    /// by the server. Tungstenite takes care of this asymmetry for you.
188    ///
189    /// When the close handshake is finished (we have both sent and received
190    /// a close message), [`read`](Self::read) or [`flush`](Self::flush) will return
191    /// [Error::ConnectionClosed] if this endpoint is the server.
192    ///
193    /// If this endpoint is a client, [Error::ConnectionClosed] will only be
194    /// returned after the server has closed the underlying connection.
195    ///
196    /// It is thus safe to drop the underlying connection as soon as [Error::ConnectionClosed]
197    /// is returned from [`read`](Self::read) or [`flush`](Self::flush).
198    pub fn close(&mut self, code: Option<CloseFrame>) -> Result<()> {
199        self.context.close(&mut self.stream, code)
200    }
201}
202
203/// A context for managing WebSocket stream.
204#[derive(Debug)]
205pub struct WebSocketContext {
206    /// Server or client?
207    mode: OperationMode,
208    /// encoder / decoder of frame.
209    frame: FrameCodec,
210    /// The state of processing, either "active" or "closing".
211    state: WebSocketState,
212    /// Receive: an incomplete message being processed.
213    incomplete: Option<IncompleteMessage>,
214    /// Send in addition to regular messages E.g. "pong" or "close".
215    additional_send: Option<Frame>,
216    /// True indicates there is an additional message (like a pong)
217    /// that failed to flush previously and we should try again.
218    unflushed_additional: bool,
219    /// The configuration for the websocket session.
220    config: WebSocketConfig,
221}
222
223impl WebSocketContext {
224    /// Create a WebSocket context that manages a post-handshake stream.
225    ///
226    /// # Panics
227    /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
228    pub fn new(mode: OperationMode, config: Option<WebSocketConfig>) -> Self {
229        let configuration = config.unwrap_or_default();
230        Self::_new(mode, FrameCodec::new(configuration.read_buffer_size), configuration)
231    }
232
233    /// Create a WebSocket context that manages an post-handshake stream.
234    ///
235    /// # Panics
236    /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
237    pub fn from_partially_read(
238        part: Vec<u8>,
239        mode: OperationMode,
240        config: Option<WebSocketConfig>,
241    ) -> Self {
242        let configuration = config.unwrap_or_default();
243        Self::_new(
244            mode,
245            FrameCodec::from_partially_read(part, configuration.read_buffer_size),
246            configuration,
247        )
248    }
249
250    fn _new(mode: OperationMode, mut frame: FrameCodec, config: WebSocketConfig) -> Self {
251        config.asset_valid();
252
253        frame.max_out_buffer_len(config.max_write_buffer_size);
254        frame.out_buffer_write_len(config.write_buffer_size);
255
256        Self {
257            mode,
258            frame,
259            state: WebSocketState::Active,
260            incomplete: None,
261            additional_send: None,
262            unflushed_additional: false,
263            config,
264        }
265    }
266
267    /// Change the configuration.
268    ///
269    /// # Panics
270    /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
271    pub fn set_config(&mut self, func: impl FnOnce(&mut WebSocketConfig)) {
272        func(&mut self.config);
273
274        self.config.asset_valid();
275        self.frame.max_out_buffer_len(self.config.max_write_buffer_size);
276        self.frame.out_buffer_write_len(self.config.write_buffer_size);
277    }
278
279    /// Read the configuration.
280    pub fn get_config(&self) -> &WebSocketConfig {
281        &self.config
282    }
283
284    /// Check if it is possible to read messages.
285    ///
286    /// Reading is impossible after receiving `Message::Close`. It is still possible after
287    /// sending close frame since the peer still may send some data before confirming close.
288    pub fn can_read(&self) -> bool {
289        self.state.can_read()
290    }
291
292    /// Check if it is possible to write messages.
293    ///
294    /// Writing gets impossible immediately after sending or receiving `Message::Close`.
295    pub fn can_write(&self) -> bool {
296        self.state.is_active()
297    }
298
299    /// Read a message from the provided stream, if possible.
300    ///
301    /// This function sends pong and close responses automatically.
302    /// However, it never blocks on write.
303    pub fn read<T: Read + Write>(&mut self, stream: &mut T) -> Result<Message> {
304        self.state.check_if_terminated()?;
305
306        loop {
307            if self.additional_send.is_some() || self.unflushed_additional {
308                match self.flush(stream) {
309                    Ok(_) => {}
310                    Err(Error::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => {
311                        self.unflushed_additional = true
312                    }
313                    Err(e) => return Err(e),
314                }
315            } else if self.mode == OperationMode::Server && !self.state.can_read() {
316                self.state = WebSocketState::Terminated;
317                return Err(Error::ConnectionClosed);
318            }
319
320            if let Some(msg) = self._read(stream)? {
321                return Ok(msg);
322            }
323        }
324    }
325
326    /// Write a message to the provided stream.
327    ///
328    /// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
329    ///
330    /// In the event of stream write failure the message frame will be stored
331    /// in the write buffer and will try again on the next call to [`write`](Self::write)
332    /// or [`flush`](Self::flush).
333    ///
334    /// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`]
335    /// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned.
336    pub fn write<T: Read + Write>(&mut self, stream: &mut T, msg: Message) -> Result<()> {
337        self.state.check_if_terminated()?;
338
339        if !self.state.is_active() {
340            return Err(Error::Protocol(ProtocolError::SendAfterClose));
341        }
342
343        let frame = match msg {
344            Message::Text(data) => Frame::new_data(data, OpCode::Data(Data::Text), true),
345            Message::Binary(data) => Frame::new_data(data, OpCode::Data(Data::Binary), true),
346            Message::Ping(data) => Frame::new_ping(data),
347            Message::Pong(data) => {
348                self.set_additional(Frame::new_pong(data));
349                return self._write(stream, None).map(|_| ());
350            }
351            Message::Close(code) => return self.close(stream, code),
352            Message::Frame(f) => f,
353        };
354
355        let should_flush = self._write(stream, Some(frame))?;
356        if should_flush {
357            self.flush(stream)?;
358        }
359
360        Ok(())
361    }
362
363    /// Flush writes.
364    ///
365    /// Ensures all messages previously passed to [`write`](Self::write) and automatically
366    /// queued pong responses are written & flushed into the `stream`.
367    #[inline]
368    pub fn flush<T: Read + Write>(&mut self, stream: &mut T) -> Result<()> {
369        self._write(stream, None)?;
370        self.frame.write_out(stream)?;
371
372        stream.flush()?;
373
374        self.unflushed_additional = false;
375
376        Ok(())
377    }
378
379    /// Close the connection.
380    ///
381    /// This function guarantees that the close frame will be queued.
382    /// There is no need to call it again. Calling this function is
383    /// the same as calling `send(Message::Close(..))`.
384    pub fn close<T: Read + Write>(
385        &mut self,
386        stream: &mut T,
387        code: Option<CloseFrame>,
388    ) -> Result<()> {
389        if let WebSocketState::Active = self.state {
390            self.state = WebSocketState::ClosedByServer;
391
392            let frame = Frame::new_close(code);
393
394            self._write(stream, Some(frame))?;
395        }
396
397        self.flush(stream)
398    }
399
400    fn _read<T: Read>(&mut self, stream: &mut T) -> Result<Option<Message>> {
401        if let Some(frame) = self
402            .frame
403            .read(
404                stream,
405                self.config.max_frame_size,
406                matches!(self.mode, OperationMode::Server),
407                self.config.accept_unmasked_frames,
408            )
409            .check_connection_reset(self.state)?
410        {
411            if !self.state.can_read() {
412                return Err(Error::Protocol(ProtocolError::ReceiveAfterClose));
413            }
414
415            let header = frame.header();
416            if header.rsv1 || header.rsv2 || header.rsv3 {
417                return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
418            }
419
420            if self.mode == OperationMode::Client && frame.is_masked() {
421                return Err(Error::Protocol(ProtocolError::MaskedFrameFromServer));
422            }
423
424            match frame.header().opcode {
425                OpCode::Control(ctrl) => match ctrl {
426                    _ if !frame.header().fin => {
427                        Err(Error::Protocol(ProtocolError::FragmentedControlFrame))
428                    }
429                    _ if frame.payload().len() > MAX_CONTROL_FRAME_PAYLOAD => {
430                        Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
431                    }
432                    Control::Close => Ok(self.try_close(frame.into_close()?).map(Message::Close)),
433                    Control::Reserved(code) => {
434                        Err(Error::Protocol(ProtocolError::UnknownControlOpCode(code)))
435                    }
436                    Control::Ping => {
437                        let data = frame.into_payload();
438                        if self.state.is_active() {
439                            self.set_additional(Frame::new_pong(data.clone()));
440                        }
441
442                        Ok(Some(Message::Ping(data)))
443                    }
444                    Control::Pong => Ok(Some(Message::Pong(frame.into_payload()))),
445                },
446                OpCode::Data(data) => {
447                    let fin = frame.header().fin;
448
449                    match data {
450                        Data::Continuation => {
451                            if let Some(ref mut msg) = self.incomplete {
452                                msg.extend(frame.into_payload(), self.config.max_message_size)?;
453                            } else {
454                                return Err(Error::Protocol(ProtocolError::UnexpectedContinue));
455                            }
456
457                            if fin {
458                                Ok(Some(self.incomplete.take().unwrap().complete()?))
459                            } else {
460                                Ok(None)
461                            }
462                        }
463                        data_frag if self.incomplete.is_some() => {
464                            Err(Error::Protocol(ProtocolError::ExpectedFragment(data_frag)))
465                        }
466                        Data::Text if fin => {
467                            check_max_size(frame.payload().len(), self.config.max_message_size)?;
468                            Ok(Some(Message::Text(frame.into_text()?)))
469                        }
470                        Data::Binary if fin => {
471                            check_max_size(frame.payload().len(), self.config.max_message_size)?;
472                            Ok(Some(Message::Binary(frame.into_payload())))
473                        }
474                        Data::Text | Data::Binary => {
475                            let msg_type = match data {
476                                Data::Text => IncompleteMessageType::Text,
477                                Data::Binary => IncompleteMessageType::Binary,
478                                _ => panic!("Bug: message is neither text not binary"),
479                            };
480
481                            let mut incomplete = IncompleteMessage::new(msg_type);
482                            incomplete
483                                .extend(frame.into_payload(), self.config.max_message_size)?;
484
485                            self.incomplete = Some(incomplete);
486
487                            Ok(None)
488                        }
489                        Data::Reserved(code) => {
490                            Err(Error::Protocol(ProtocolError::UnknownDataOpCode(code)))
491                        }
492                    }
493                }
494            }
495        } else {
496            match replace(&mut self.state, WebSocketState::Terminated) {
497                WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
498                    Err(Error::ConnectionClosed)
499                }
500                _ => Err(Error::Protocol(ProtocolError::ResetWithoutClosing)),
501            }
502        }
503    }
504
505    fn _write<T: Read + Write>(&mut self, stream: &mut T, data: Option<Frame>) -> Result<bool> {
506        if let Some(data) = data {
507            self.buffer_frame(stream, data)?;
508        }
509
510        let should_flush = if let Some(msg) = self.additional_send.take() {
511            match self.buffer_frame(stream, msg.clone()) {
512                Err(Error::WriteBufferFull) => {
513                    self.set_additional(msg);
514                    false
515                }
516                Err(e) => return Err(e),
517                Ok(_) => true,
518            }
519        } else {
520            self.unflushed_additional
521        };
522
523        if self.mode == OperationMode::Server && !self.state.can_read() {
524            self.frame.write_out(stream)?;
525            self.state = WebSocketState::Terminated;
526
527            Err(Error::ConnectionClosed)
528        } else {
529            Ok(should_flush)
530        }
531    }
532
533    /// Received a close frame. Tells if we need to return a close frame to the user.
534    #[allow(clippy::option_option)]
535    fn try_close(&mut self, close: Option<CloseFrame>) -> Option<Option<CloseFrame>> {
536        match self.state {
537            WebSocketState::Active => {
538                self.state = WebSocketState::ClosedByPeer;
539
540                let close = close.map(|frame| {
541                    if !frame.code.allowed() {
542                        CloseFrame {
543                            code: CloseCode::Protocol,
544                            reason: Utf8Bytes::from_static("Protocol violatoin"),
545                        }
546                    } else {
547                        frame
548                    }
549                });
550
551                let reply = Frame::new_close(close.clone());
552                self.set_additional(reply);
553
554                Some(close)
555            }
556            WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => None,
557            WebSocketState::ClosedByServer => {
558                self.state = WebSocketState::CloseAcknowledged;
559                Some(close)
560            }
561            WebSocketState::Terminated => unreachable!(),
562        }
563    }
564
565    /// Write a single frame into the write-buffer.
566    fn buffer_frame<T>(&mut self, stream: &mut T, mut frame: Frame) -> Result<()>
567    where
568        T: Read + Write,
569    {
570        match self.mode {
571            OperationMode::Server => {}
572            OperationMode::Client => frame.set_random_mask(),
573        }
574
575        self.frame.write(stream, frame).check_connection_reset(self.state)
576    }
577
578    /// Replace `additional_send` if it is currently a `Pong` message.
579    fn set_additional(&mut self, additional: Frame) {
580        let empty_or_pong = self
581            .additional_send
582            .as_ref()
583            .map_or(true, |f| f.header().opcode == OpCode::Control(Control::Pong));
584
585        if empty_or_pong {
586            self.additional_send.replace(additional);
587        }
588    }
589}
590
591fn check_max_size(size: usize, max: Option<usize>) -> Result<()> {
592    if let Some(max) = max {
593        if size > max {
594            return Err(Error::Capacity(CapacityError::MessageTooLarge { size, max }));
595        }
596    }
597
598    Ok(())
599}
600
601/// The current connection state.
602#[derive(Debug, PartialEq, Eq, Clone, Copy)]
603enum WebSocketState {
604    /// The connection is active.
605    Active,
606    /// We initiated a close handshake.
607    ClosedByServer,
608    /// The peer initiated a close handshake.
609    ClosedByPeer,
610    /// The peer replied to our close handshake.
611    CloseAcknowledged,
612    /// The connection does not exist anymore.
613    Terminated,
614}
615
616impl WebSocketState {
617    /// Tell if we're allowed to process normal messages.
618    fn is_active(self) -> bool {
619        matches!(self, Self::Active)
620    }
621
622    /// Tell if we should process incoming data. Note that if we send a close frame
623    /// but the remote hasn't confirmed, they might have sent data before they receive our
624    /// close frame, so we should still pass those to client code, hence ClosedByUs is valid.
625    fn can_read(self) -> bool {
626        matches!(self, Self::Active | Self::ClosedByServer)
627    }
628
629    /// Check if the state is active, return error if not.
630    fn check_if_terminated(self) -> Result<()> {
631        match self {
632            WebSocketState::Terminated => Err(Error::AlreadyClosed),
633            _ => Ok(()),
634        }
635    }
636}
637
638/// Translate "Connection reset by peer" into `ConnectionClosed` if appropriate.
639trait CheckConnectionReset {
640    fn check_connection_reset(self, state: WebSocketState) -> Self;
641}
642
643impl<T> CheckConnectionReset for Result<T> {
644    fn check_connection_reset(self, state: WebSocketState) -> Self {
645        match self {
646            Err(Error::Io(e)) => Err({
647                if !state.can_read() && e.kind() == io::ErrorKind::ConnectionReset {
648                    Error::ConnectionClosed
649                } else {
650                    Error::Io(e)
651                }
652            }),
653            other => other,
654        }
655    }
656}