blitz_ws/protocol/websocket.rs
1//! WebSocket handler
2
3use std::{
4 io::{self, Read, Write},
5 mem::replace,
6};
7
8use crate::{
9 error::{CapacityError, Error, ProtocolError, Result},
10 protocol::{
11 config::WebSocketConfig,
12 frame::{
13 codec::{CloseCode, Control, Data, OpCode},
14 core::FrameCodec,
15 CloseFrame, Frame, Utf8Bytes,
16 },
17 message::{IncompleteMessage, IncompleteMessageType, Message},
18 },
19 MAX_CONTROL_FRAME_PAYLOAD,
20};
21
/// Which side of the WebSocket connection this endpoint plays.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OperationMode {
    /// Client side of the connection; frames sent in this mode are masked.
    Client,
    /// Server side of the connection; reports `ConnectionClosed` as soon as the
    /// close handshake has completed.
    Server,
}
30
31/// WebSocket input-output stream.
32///
33/// This is THE structure you want to create to be able to speak the WebSocket protocol.
34/// It may be created by calling `connect`, `accept` or `client` functions.
35///
36/// Use [`WebSocket::read`], [`WebSocket::send`] to received and send messages.
37#[derive(Debug)]
38pub struct WebSocket<T> {
39 stream: T,
40 context: WebSocketContext,
41}
42
43impl<T: Read + Write> WebSocket<T> {
44 /// Convert a raw socket into a WebSocket without performing a handshake.
45 ///
46 /// Call this function if you're using Tungstenite as a part of a web framework
47 /// or together with an existing one. If you need an initial handshake, use
48 /// `connect()` or `accept()` functions of the crate to construct a websocket.
49 ///
50 /// # Panics
51 /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
52 pub fn new(stream: T, mode: OperationMode, config: Option<WebSocketConfig>) -> Self {
53 WebSocket { stream, context: WebSocketContext::new(mode, config) }
54 }
55
56 /// Convert a raw socket into a WebSocket without performing a handshake.
57 ///
58 /// Call this function if you're using Tungstenite as a part of a web framework
59 /// or together with an existing one. If you need an initial handshake, use
60 /// `connect()` or `accept()` functions of the crate to construct a websocket.
61 ///
62 /// # Panics
63 /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
64 pub fn from_partially_read(
65 stream: T,
66 part: Vec<u8>,
67 mode: OperationMode,
68 config: Option<WebSocketConfig>,
69 ) -> Self {
70 WebSocket { stream, context: WebSocketContext::from_partially_read(part, mode, config) }
71 }
72
73 /// Returns a shared reference to the stream
74 pub fn get_ref(&self) -> &T {
75 &self.stream
76 }
77
78 /// Returns a mutable reference to the stream
79 pub fn get_mut(&mut self) -> &mut T {
80 &mut self.stream
81 }
82
83 /// Returns the inner instance of the stream
84 pub fn into_inner(self) -> T {
85 self.stream
86 }
87
88 /// Change the configuration.
89 ///
90 /// # Panics
91 /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
92 pub fn set_config(&mut self, func: impl FnOnce(&mut WebSocketConfig)) {
93 self.context.set_config(func);
94 }
95
96 /// Read the configuration.
97 pub fn get_config(&self) -> &WebSocketConfig {
98 self.context.get_config()
99 }
100
101 /// Check if it is possible to read messages.
102 ///
103 /// Reading is impossible after receiving `Message::Close`. It is still possible after
104 /// sending close frame since the peer still may send some data before confirming close.
105 pub fn can_read(&self) -> bool {
106 self.context.can_read()
107 }
108
109 /// Check if it is possible to write messages.
110 ///
111 /// Writing gets impossible immediately after sending or receiving `Message::Close`.
112 pub fn can_write(&self) -> bool {
113 self.context.can_write()
114 }
115
116 /// Check if it is possible to read messages.
117 ///
118 /// Reading is impossible after receiving `Message::Close`. It is still possible after
119 /// sending close frame since the peer still may send some data before confirming close.
120 pub fn read(&mut self) -> Result<Message> {
121 self.context.read(&mut self.stream)
122 }
123
124 /// Writes and immediately flushes a message.
125 /// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
126 pub fn send(&mut self, msg: Message) -> Result<()> {
127 self.write(msg)?;
128 self.flush()
129 }
130
131 /// Write a message to the provided stream, if possible.
132 ///
133 /// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
134 ///
135 /// In the event of stream write failure the message frame will be stored
136 /// in the write buffer and will try again on the next call to [`write`](Self::write)
137 /// or [`flush`](Self::flush).
138 ///
139 /// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`]
140 /// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned.
141 ///
142 /// This call will generally not flush. However, if there are queued automatic messages
143 /// they will be written and eagerly flushed.
144 ///
145 /// For example, upon receiving ping messages tungstenite queues pong replies automatically.
146 /// The next call to [`read`](Self::read), [`write`](Self::write) or [`flush`](Self::flush)
147 /// will write & flush the pong reply. This means you should not respond to ping frames manually.
148 ///
149 /// You can however send pong frames manually in order to indicate a unidirectional heartbeat
150 /// as described in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.5.3). Note that
151 /// if [`read`](Self::read) returns a ping, you should [`flush`](Self::flush) before passing
152 /// a custom pong to [`write`](Self::write), otherwise the automatic queued response to the
153 /// ping will not be sent as it will be replaced by your custom pong message.
154 ///
155 /// # Errors
156 /// - If the WebSocket's write buffer is full, [`Error::WriteBufferFull`] will be returned
157 /// along with the equivalent passed message frame.
158 /// - If the connection is closed and should be dropped, this will return [`Error::ConnectionClosed`].
159 /// - If you try again after [`Error::ConnectionClosed`] was returned either from here or from
160 /// [`read`](Self::read), [`Error::AlreadyClosed`] will be returned. This indicates a program
161 /// error on your part.
162 /// - [`Error::Io`] is returned if the underlying connection returns an error
163 /// (consider these fatal except for WouldBlock).
164 /// - [`Error::Capacity`] if your message size is bigger than the configured max message size.
165 pub fn write(&mut self, msg: Message) -> Result<()> {
166 self.context.write(&mut self.stream, msg)
167 }
168
169 /// Flush writes.
170 ///
171 /// Ensures all messages previously passed to [`write`](Self::write) and automatic
172 /// queued pong responses are written & flushed into the underlying stream.
173 pub fn flush(&mut self) -> Result<()> {
174 self.context.flush(&mut self.stream)
175 }
176
177 /// Close the connection.
178 ///
179 /// This function guarantees that the close frame will be queued.
180 /// There is no need to call it again. Calling this function is
181 /// the same as calling `write(Message::Close(..))`.
182 ///
183 /// After queuing the close frame you should continue calling [`read`](Self::read) or
184 /// [`flush`](Self::flush) to drive the close handshake to completion.
185 ///
186 /// The websocket RFC defines that the underlying connection should be closed
187 /// by the server. Tungstenite takes care of this asymmetry for you.
188 ///
189 /// When the close handshake is finished (we have both sent and received
190 /// a close message), [`read`](Self::read) or [`flush`](Self::flush) will return
191 /// [Error::ConnectionClosed] if this endpoint is the server.
192 ///
193 /// If this endpoint is a client, [Error::ConnectionClosed] will only be
194 /// returned after the server has closed the underlying connection.
195 ///
196 /// It is thus safe to drop the underlying connection as soon as [Error::ConnectionClosed]
197 /// is returned from [`read`](Self::read) or [`flush`](Self::flush).
198 pub fn close(&mut self, code: Option<CloseFrame>) -> Result<()> {
199 self.context.close(&mut self.stream, code)
200 }
201}
202
203/// A context for managing WebSocket stream.
204#[derive(Debug)]
205pub struct WebSocketContext {
206 /// Server or client?
207 mode: OperationMode,
208 /// encoder / decoder of frame.
209 frame: FrameCodec,
210 /// The state of processing, either "active" or "closing".
211 state: WebSocketState,
212 /// Receive: an incomplete message being processed.
213 incomplete: Option<IncompleteMessage>,
214 /// Send in addition to regular messages E.g. "pong" or "close".
215 additional_send: Option<Frame>,
216 /// True indicates there is an additional message (like a pong)
217 /// that failed to flush previously and we should try again.
218 unflushed_additional: bool,
219 /// The configuration for the websocket session.
220 config: WebSocketConfig,
221}
222
223impl WebSocketContext {
224 /// Create a WebSocket context that manages a post-handshake stream.
225 ///
226 /// # Panics
227 /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
228 pub fn new(mode: OperationMode, config: Option<WebSocketConfig>) -> Self {
229 let configuration = config.unwrap_or_default();
230 Self::_new(mode, FrameCodec::new(configuration.read_buffer_size), configuration)
231 }
232
233 /// Create a WebSocket context that manages an post-handshake stream.
234 ///
235 /// # Panics
236 /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
237 pub fn from_partially_read(
238 part: Vec<u8>,
239 mode: OperationMode,
240 config: Option<WebSocketConfig>,
241 ) -> Self {
242 let configuration = config.unwrap_or_default();
243 Self::_new(
244 mode,
245 FrameCodec::from_partially_read(part, configuration.read_buffer_size),
246 configuration,
247 )
248 }
249
250 fn _new(mode: OperationMode, mut frame: FrameCodec, config: WebSocketConfig) -> Self {
251 config.asset_valid();
252
253 frame.max_out_buffer_len(config.max_write_buffer_size);
254 frame.out_buffer_write_len(config.write_buffer_size);
255
256 Self {
257 mode,
258 frame,
259 state: WebSocketState::Active,
260 incomplete: None,
261 additional_send: None,
262 unflushed_additional: false,
263 config,
264 }
265 }
266
267 /// Change the configuration.
268 ///
269 /// # Panics
270 /// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
271 pub fn set_config(&mut self, func: impl FnOnce(&mut WebSocketConfig)) {
272 func(&mut self.config);
273
274 self.config.asset_valid();
275 self.frame.max_out_buffer_len(self.config.max_write_buffer_size);
276 self.frame.out_buffer_write_len(self.config.write_buffer_size);
277 }
278
279 /// Read the configuration.
280 pub fn get_config(&self) -> &WebSocketConfig {
281 &self.config
282 }
283
284 /// Check if it is possible to read messages.
285 ///
286 /// Reading is impossible after receiving `Message::Close`. It is still possible after
287 /// sending close frame since the peer still may send some data before confirming close.
288 pub fn can_read(&self) -> bool {
289 self.state.can_read()
290 }
291
292 /// Check if it is possible to write messages.
293 ///
294 /// Writing gets impossible immediately after sending or receiving `Message::Close`.
295 pub fn can_write(&self) -> bool {
296 self.state.is_active()
297 }
298
299 /// Read a message from the provided stream, if possible.
300 ///
301 /// This function sends pong and close responses automatically.
302 /// However, it never blocks on write.
303 pub fn read<T: Read + Write>(&mut self, stream: &mut T) -> Result<Message> {
304 self.state.check_if_terminated()?;
305
306 loop {
307 if self.additional_send.is_some() || self.unflushed_additional {
308 match self.flush(stream) {
309 Ok(_) => {}
310 Err(Error::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => {
311 self.unflushed_additional = true
312 }
313 Err(e) => return Err(e),
314 }
315 } else if self.mode == OperationMode::Server && !self.state.can_read() {
316 self.state = WebSocketState::Terminated;
317 return Err(Error::ConnectionClosed);
318 }
319
320 if let Some(msg) = self._read(stream)? {
321 return Ok(msg);
322 }
323 }
324 }
325
326 /// Write a message to the provided stream.
327 ///
328 /// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
329 ///
330 /// In the event of stream write failure the message frame will be stored
331 /// in the write buffer and will try again on the next call to [`write`](Self::write)
332 /// or [`flush`](Self::flush).
333 ///
334 /// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`]
335 /// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned.
336 pub fn write<T: Read + Write>(&mut self, stream: &mut T, msg: Message) -> Result<()> {
337 self.state.check_if_terminated()?;
338
339 if !self.state.is_active() {
340 return Err(Error::Protocol(ProtocolError::SendAfterClose));
341 }
342
343 let frame = match msg {
344 Message::Text(data) => Frame::new_data(data, OpCode::Data(Data::Text), true),
345 Message::Binary(data) => Frame::new_data(data, OpCode::Data(Data::Binary), true),
346 Message::Ping(data) => Frame::new_ping(data),
347 Message::Pong(data) => {
348 self.set_additional(Frame::new_pong(data));
349 return self._write(stream, None).map(|_| ());
350 }
351 Message::Close(code) => return self.close(stream, code),
352 Message::Frame(f) => f,
353 };
354
355 let should_flush = self._write(stream, Some(frame))?;
356 if should_flush {
357 self.flush(stream)?;
358 }
359
360 Ok(())
361 }
362
363 /// Flush writes.
364 ///
365 /// Ensures all messages previously passed to [`write`](Self::write) and automatically
366 /// queued pong responses are written & flushed into the `stream`.
367 #[inline]
368 pub fn flush<T: Read + Write>(&mut self, stream: &mut T) -> Result<()> {
369 self._write(stream, None)?;
370 self.frame.write_out(stream)?;
371
372 stream.flush()?;
373
374 self.unflushed_additional = false;
375
376 Ok(())
377 }
378
379 /// Close the connection.
380 ///
381 /// This function guarantees that the close frame will be queued.
382 /// There is no need to call it again. Calling this function is
383 /// the same as calling `send(Message::Close(..))`.
384 pub fn close<T: Read + Write>(
385 &mut self,
386 stream: &mut T,
387 code: Option<CloseFrame>,
388 ) -> Result<()> {
389 if let WebSocketState::Active = self.state {
390 self.state = WebSocketState::ClosedByServer;
391
392 let frame = Frame::new_close(code);
393
394 self._write(stream, Some(frame))?;
395 }
396
397 self.flush(stream)
398 }
399
400 fn _read<T: Read>(&mut self, stream: &mut T) -> Result<Option<Message>> {
401 if let Some(frame) = self
402 .frame
403 .read(
404 stream,
405 self.config.max_frame_size,
406 matches!(self.mode, OperationMode::Server),
407 self.config.accept_unmasked_frames,
408 )
409 .check_connection_reset(self.state)?
410 {
411 if !self.state.can_read() {
412 return Err(Error::Protocol(ProtocolError::ReceiveAfterClose));
413 }
414
415 let header = frame.header();
416 if header.rsv1 || header.rsv2 || header.rsv3 {
417 return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
418 }
419
420 if self.mode == OperationMode::Client && frame.is_masked() {
421 return Err(Error::Protocol(ProtocolError::MaskedFrameFromServer));
422 }
423
424 match frame.header().opcode {
425 OpCode::Control(ctrl) => match ctrl {
426 _ if !frame.header().fin => {
427 Err(Error::Protocol(ProtocolError::FragmentedControlFrame))
428 }
429 _ if frame.payload().len() > MAX_CONTROL_FRAME_PAYLOAD => {
430 Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
431 }
432 Control::Close => Ok(self.try_close(frame.into_close()?).map(Message::Close)),
433 Control::Reserved(code) => {
434 Err(Error::Protocol(ProtocolError::UnknownControlOpCode(code)))
435 }
436 Control::Ping => {
437 let data = frame.into_payload();
438 if self.state.is_active() {
439 self.set_additional(Frame::new_pong(data.clone()));
440 }
441
442 Ok(Some(Message::Ping(data)))
443 }
444 Control::Pong => Ok(Some(Message::Pong(frame.into_payload()))),
445 },
446 OpCode::Data(data) => {
447 let fin = frame.header().fin;
448
449 match data {
450 Data::Continuation => {
451 if let Some(ref mut msg) = self.incomplete {
452 msg.extend(frame.into_payload(), self.config.max_message_size)?;
453 } else {
454 return Err(Error::Protocol(ProtocolError::UnexpectedContinue));
455 }
456
457 if fin {
458 Ok(Some(self.incomplete.take().unwrap().complete()?))
459 } else {
460 Ok(None)
461 }
462 }
463 data_frag if self.incomplete.is_some() => {
464 Err(Error::Protocol(ProtocolError::ExpectedFragment(data_frag)))
465 }
466 Data::Text if fin => {
467 check_max_size(frame.payload().len(), self.config.max_message_size)?;
468 Ok(Some(Message::Text(frame.into_text()?)))
469 }
470 Data::Binary if fin => {
471 check_max_size(frame.payload().len(), self.config.max_message_size)?;
472 Ok(Some(Message::Binary(frame.into_payload())))
473 }
474 Data::Text | Data::Binary => {
475 let msg_type = match data {
476 Data::Text => IncompleteMessageType::Text,
477 Data::Binary => IncompleteMessageType::Binary,
478 _ => panic!("Bug: message is neither text not binary"),
479 };
480
481 let mut incomplete = IncompleteMessage::new(msg_type);
482 incomplete
483 .extend(frame.into_payload(), self.config.max_message_size)?;
484
485 self.incomplete = Some(incomplete);
486
487 Ok(None)
488 }
489 Data::Reserved(code) => {
490 Err(Error::Protocol(ProtocolError::UnknownDataOpCode(code)))
491 }
492 }
493 }
494 }
495 } else {
496 match replace(&mut self.state, WebSocketState::Terminated) {
497 WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
498 Err(Error::ConnectionClosed)
499 }
500 _ => Err(Error::Protocol(ProtocolError::ResetWithoutClosing)),
501 }
502 }
503 }
504
505 fn _write<T: Read + Write>(&mut self, stream: &mut T, data: Option<Frame>) -> Result<bool> {
506 if let Some(data) = data {
507 self.buffer_frame(stream, data)?;
508 }
509
510 let should_flush = if let Some(msg) = self.additional_send.take() {
511 match self.buffer_frame(stream, msg.clone()) {
512 Err(Error::WriteBufferFull) => {
513 self.set_additional(msg);
514 false
515 }
516 Err(e) => return Err(e),
517 Ok(_) => true,
518 }
519 } else {
520 self.unflushed_additional
521 };
522
523 if self.mode == OperationMode::Server && !self.state.can_read() {
524 self.frame.write_out(stream)?;
525 self.state = WebSocketState::Terminated;
526
527 Err(Error::ConnectionClosed)
528 } else {
529 Ok(should_flush)
530 }
531 }
532
533 /// Received a close frame. Tells if we need to return a close frame to the user.
534 #[allow(clippy::option_option)]
535 fn try_close(&mut self, close: Option<CloseFrame>) -> Option<Option<CloseFrame>> {
536 match self.state {
537 WebSocketState::Active => {
538 self.state = WebSocketState::ClosedByPeer;
539
540 let close = close.map(|frame| {
541 if !frame.code.allowed() {
542 CloseFrame {
543 code: CloseCode::Protocol,
544 reason: Utf8Bytes::from_static("Protocol violatoin"),
545 }
546 } else {
547 frame
548 }
549 });
550
551 let reply = Frame::new_close(close.clone());
552 self.set_additional(reply);
553
554 Some(close)
555 }
556 WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => None,
557 WebSocketState::ClosedByServer => {
558 self.state = WebSocketState::CloseAcknowledged;
559 Some(close)
560 }
561 WebSocketState::Terminated => unreachable!(),
562 }
563 }
564
565 /// Write a single frame into the write-buffer.
566 fn buffer_frame<T>(&mut self, stream: &mut T, mut frame: Frame) -> Result<()>
567 where
568 T: Read + Write,
569 {
570 match self.mode {
571 OperationMode::Server => {}
572 OperationMode::Client => frame.set_random_mask(),
573 }
574
575 self.frame.write(stream, frame).check_connection_reset(self.state)
576 }
577
578 /// Replace `additional_send` if it is currently a `Pong` message.
579 fn set_additional(&mut self, additional: Frame) {
580 let empty_or_pong = self
581 .additional_send
582 .as_ref()
583 .map_or(true, |f| f.header().opcode == OpCode::Control(Control::Pong));
584
585 if empty_or_pong {
586 self.additional_send.replace(additional);
587 }
588 }
589}
590
591fn check_max_size(size: usize, max: Option<usize>) -> Result<()> {
592 if let Some(max) = max {
593 if size > max {
594 return Err(Error::Capacity(CapacityError::MessageTooLarge { size, max }));
595 }
596 }
597
598 Ok(())
599}
600
/// Lifecycle of a WebSocket connection, from open through the close handshake.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum WebSocketState {
    /// Open: normal messages may be sent and received.
    Active,
    /// This endpoint sent the first close frame and awaits the peer's reply.
    /// Despite the name, it is used in both client and server mode.
    ClosedByServer,
    /// The peer sent the first close frame.
    ClosedByPeer,
    /// The peer answered the close frame we sent first.
    CloseAcknowledged,
    /// The connection is gone; any further use is an error.
    Terminated,
}
615
616impl WebSocketState {
617 /// Tell if we're allowed to process normal messages.
618 fn is_active(self) -> bool {
619 matches!(self, Self::Active)
620 }
621
622 /// Tell if we should process incoming data. Note that if we send a close frame
623 /// but the remote hasn't confirmed, they might have sent data before they receive our
624 /// close frame, so we should still pass those to client code, hence ClosedByUs is valid.
625 fn can_read(self) -> bool {
626 matches!(self, Self::Active | Self::ClosedByServer)
627 }
628
629 /// Check if the state is active, return error if not.
630 fn check_if_terminated(self) -> Result<()> {
631 match self {
632 WebSocketState::Terminated => Err(Error::AlreadyClosed),
633 _ => Ok(()),
634 }
635 }
636}
637
638/// Translate "Connection reset by peer" into `ConnectionClosed` if appropriate.
639trait CheckConnectionReset {
640 fn check_connection_reset(self, state: WebSocketState) -> Self;
641}
642
643impl<T> CheckConnectionReset for Result<T> {
644 fn check_connection_reset(self, state: WebSocketState) -> Self {
645 match self {
646 Err(Error::Io(e)) => Err({
647 if !state.can_read() && e.kind() == io::ErrorKind::ConnectionReset {
648 Error::ConnectionClosed
649 } else {
650 Error::Io(e)
651 }
652 }),
653 other => other,
654 }
655 }
656}