blitz_ws/handshake/machine.rs

1//! WebSocket handshake machine
2
3use std::io::{Cursor, Read, Write};
4
5use bytes::Buf;
6
7use crate::{
8    error::{Error, ProtocolError, Result},
9    util::NonBlockingResult,
10    ReadBuffer,
11};
12
13/// A generic handshake state machine
14#[derive(Debug)]
15pub struct HandshakeMachine<Stream> {
16    stream: Stream,
17    state: HandshakeState,
18}
19
20impl<Stream> HandshakeMachine<Stream> {
21    /// Start reading data from the peer
22    pub fn start_read(stream: Stream) -> Self {
23        Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) }
24    }
25
26    /// Start writing data to the peer
27    pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
28        HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) }
29    }
30
31    /// Returns a shared reference to the internal stream
32    pub fn get_ref(&self) -> &Stream {
33        &self.stream
34    }
35
36    /// Returns a mutable reference to the internal stream
37    pub fn get_mut(&mut self) -> &mut Stream {
38        &mut self.stream
39    }
40}
41
42impl<Stream: Read + Write> HandshakeMachine<Stream> {
43    /// Performs a single-round handshake
44    pub fn single_round<Object: TryParse>(mut self) -> Result<RoundResult<Object, Stream>> {
45        match self.state {
46            HandshakeState::Reading(mut buf, mut attack_check) => {
47                let read = buf.read_from(&mut self.stream).no_block()?;
48                match read {
49                    Some(0) => Err(Error::Protocol(ProtocolError::IncompleteHandshake)),
50                    Some(count) => {
51                        attack_check.check_incoming_packet(count)?;
52                        if let Some((size, obj)) = Object::try_parse(Buf::chunk(&buf))? {
53                            buf.advance(size);
54
55                            Ok(RoundResult::StageFinished(StageResult::DoneReading {
56                                result: obj,
57                                stream: self.stream,
58                                tail: buf.into_vec(),
59                            }))
60                        } else {
61                            Ok(RoundResult::Incomplete(HandshakeMachine {
62                                state: HandshakeState::Reading(buf, attack_check),
63                                ..self
64                            }))
65                        }
66                    }
67                    None => Ok(RoundResult::WouldBlock(HandshakeMachine {
68                        state: HandshakeState::Reading(buf, attack_check),
69                        ..self
70                    })),
71                }
72            }
73            HandshakeState::Writing(mut buf) => {
74                assert!(buf.has_remaining());
75
76                if let Some(size) = self.stream.write(Buf::chunk(&buf)).no_block()? {
77                    assert!(size > 0);
78
79                    buf.advance(size);
80
81                    Ok(if buf.has_remaining() {
82                        RoundResult::Incomplete(HandshakeMachine {
83                            state: HandshakeState::Writing(buf),
84                            ..self
85                        })
86                    } else {
87                        RoundResult::Incomplete(HandshakeMachine {
88                            state: HandshakeState::Flushing,
89                            ..self
90                        })
91                    })
92                } else {
93                    Ok(RoundResult::WouldBlock(HandshakeMachine {
94                        state: HandshakeState::Writing(buf),
95                        ..self
96                    }))
97                }
98            }
99            HandshakeState::Flushing => match self.stream.flush().no_block()? {
100                Some(()) => Ok(RoundResult::StageFinished(StageResult::DoneWriting(self.stream))),
101                None => Ok(RoundResult::WouldBlock(HandshakeMachine {
102                    state: HandshakeState::Flushing,
103                    ..self
104                })),
105            },
106        }
107    }
108}
109
110/// The result of the Round
111#[derive(Debug)]
112pub enum RoundResult<Object, Stream> {
113    /// Round not done, I/O would block
114    WouldBlock(HandshakeMachine<Stream>),
115    /// Round done, stage unchanged
116    Incomplete(HandshakeMachine<Stream>),
117    /// Stage complete
118    StageFinished(StageResult<Object, Stream>),
119}
120
/// The result of a completed handshake stage.
#[derive(Debug)]
pub enum StageResult<Object, Stream> {
    /// The reading stage finished: a complete object was parsed from the peer's data.
    DoneReading {
        /// The object parsed from the received bytes.
        result: Object,
        /// The underlying stream, handed back to the caller.
        stream: Stream,
        /// Bytes that were still buffered after the parsed object.
        tail: Vec<u8>,
    },
    /// The writing stage finished: all data was written and the stream was flushed.
    DoneWriting(Stream),
}
130
131/// A parse-able object
132pub trait TryParse: Sized {
133    /// Returns Ok(None) if incomplete, Err on syntax errors
134    fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>;
135}
136
137/// The handshake state
138#[derive(Debug)]
139enum HandshakeState {
140    /// Reading data from peer
141    Reading(ReadBuffer, AttackCheck),
142    /// Sending data to peer
143    Writing(Cursor<Vec<u8>>),
144    /// Flushing data to ensure that all intermediaries reach their destinations
145    Flushing,
146}
147
/// Mitigation against denial-of-service attacks during the handshake.
///
/// Keeps running totals of how many reads the peer needed and how many bytes it sent.
#[derive(Debug)]
pub(crate) struct AttackCheck {
    /// Number of successful reads of HTTP header data (TCP packets).
    packets: usize,
    /// Total number of HTTP header bytes received so far.
    bytes: usize,
}
156
157impl AttackCheck {
158    /// Initialize attack checking for incoming buffer
159    fn new() -> Self {
160        Self { packets: 0, bytes: 0 }
161    }
162
163    /// Check the size of an incoming packet. To be called immediately after `read()`
164    /// passing its returned bytes count as `size`
165    fn check_incoming_packet(&mut self, size: usize) -> Result<()> {
166        self.packets += 1;
167        self.bytes += size;
168
169        const MAX_BYTES: usize = 65536;
170        const MAX_PACKETS: usize = 512;
171        const MIN_PACKET_SIZE: usize = 128;
172        const MIN_PACKET_CHECK_THRESHOLD: usize = 64;
173
174        if self.bytes > MAX_BYTES
175            || self.packets > MAX_PACKETS
176            || (self.packets > MIN_PACKET_CHECK_THRESHOLD
177                && self.packets * MIN_PACKET_SIZE > self.bytes)
178        {
179            return Err(Error::AttackAttempt);
180        }
181
182        Ok(())
183    }
184}