blitz_ws/handshake/
machine.rs1use std::io::{Cursor, Read, Write};
4
5use bytes::Buf;
6
7use crate::{
8 error::{Error, ProtocolError, Result},
9 util::NonBlockingResult,
10 ReadBuffer,
11};
12
13#[derive(Debug)]
15pub struct HandshakeMachine<Stream> {
16 stream: Stream,
17 state: HandshakeState,
18}
19
20impl<Stream> HandshakeMachine<Stream> {
21 pub fn start_read(stream: Stream) -> Self {
23 Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) }
24 }
25
26 pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
28 HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) }
29 }
30
31 pub fn get_ref(&self) -> &Stream {
33 &self.stream
34 }
35
36 pub fn get_mut(&mut self) -> &mut Stream {
38 &mut self.stream
39 }
40}
41
42impl<Stream: Read + Write> HandshakeMachine<Stream> {
43 pub fn single_round<Object: TryParse>(mut self) -> Result<RoundResult<Object, Stream>> {
45 match self.state {
46 HandshakeState::Reading(mut buf, mut attack_check) => {
47 let read = buf.read_from(&mut self.stream).no_block()?;
48 match read {
49 Some(0) => Err(Error::Protocol(ProtocolError::IncompleteHandshake)),
50 Some(count) => {
51 attack_check.check_incoming_packet(count)?;
52 if let Some((size, obj)) = Object::try_parse(Buf::chunk(&buf))? {
53 buf.advance(size);
54
55 Ok(RoundResult::StageFinished(StageResult::DoneReading {
56 result: obj,
57 stream: self.stream,
58 tail: buf.into_vec(),
59 }))
60 } else {
61 Ok(RoundResult::Incomplete(HandshakeMachine {
62 state: HandshakeState::Reading(buf, attack_check),
63 ..self
64 }))
65 }
66 }
67 None => Ok(RoundResult::WouldBlock(HandshakeMachine {
68 state: HandshakeState::Reading(buf, attack_check),
69 ..self
70 })),
71 }
72 }
73 HandshakeState::Writing(mut buf) => {
74 assert!(buf.has_remaining());
75
76 if let Some(size) = self.stream.write(Buf::chunk(&buf)).no_block()? {
77 assert!(size > 0);
78
79 buf.advance(size);
80
81 Ok(if buf.has_remaining() {
82 RoundResult::Incomplete(HandshakeMachine {
83 state: HandshakeState::Writing(buf),
84 ..self
85 })
86 } else {
87 RoundResult::Incomplete(HandshakeMachine {
88 state: HandshakeState::Flushing,
89 ..self
90 })
91 })
92 } else {
93 Ok(RoundResult::WouldBlock(HandshakeMachine {
94 state: HandshakeState::Writing(buf),
95 ..self
96 }))
97 }
98 }
99 HandshakeState::Flushing => match self.stream.flush().no_block()? {
100 Some(()) => Ok(RoundResult::StageFinished(StageResult::DoneWriting(self.stream))),
101 None => Ok(RoundResult::WouldBlock(HandshakeMachine {
102 state: HandshakeState::Flushing,
103 ..self
104 })),
105 },
106 }
107 }
108}
109
110#[derive(Debug)]
112pub enum RoundResult<Object, Stream> {
113 WouldBlock(HandshakeMachine<Stream>),
115 Incomplete(HandshakeMachine<Stream>),
117 StageFinished(StageResult<Object, Stream>),
119}
120
/// Value produced when a handshake stage completes.
#[derive(Debug)]
pub enum StageResult<Object, Stream> {
    /// Reading finished and an `Object` was parsed from the received bytes.
    DoneReading {
        /// The parsed object.
        result: Object,
        /// The underlying stream, handed back to the caller.
        stream: Stream,
        /// Bytes left in the read buffer after the parsed object was consumed (data the
        /// peer sent beyond the handshake message, if any).
        tail: Vec<u8>,
    },
    /// All pending data was written and the stream flushed; the stream is handed back.
    DoneWriting(Stream),
}
130
131pub trait TryParse: Sized {
133 fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>;
135}
136
137#[derive(Debug)]
139enum HandshakeState {
140 Reading(ReadBuffer, AttackCheck),
142 Writing(Cursor<Vec<u8>>),
144 Flushing,
146}
147
/// Heuristic guard against peers that abuse the handshake with too much data,
/// too many reads, or a flood of tiny reads.
#[derive(Debug)]
pub(crate) struct AttackCheck {
    /// Number of non-empty reads recorded so far.
    packets: usize,
    /// Total number of bytes received across those reads.
    bytes: usize,
}
156
157impl AttackCheck {
158 fn new() -> Self {
160 Self { packets: 0, bytes: 0 }
161 }
162
163 fn check_incoming_packet(&mut self, size: usize) -> Result<()> {
166 self.packets += 1;
167 self.bytes += size;
168
169 const MAX_BYTES: usize = 65536;
170 const MAX_PACKETS: usize = 512;
171 const MIN_PACKET_SIZE: usize = 128;
172 const MIN_PACKET_CHECK_THRESHOLD: usize = 64;
173
174 if self.bytes > MAX_BYTES
175 || self.packets > MAX_PACKETS
176 || (self.packets > MIN_PACKET_CHECK_THRESHOLD
177 && self.packets * MIN_PACKET_SIZE > self.bytes)
178 {
179 return Err(Error::AttackAttempt);
180 }
181
182 Ok(())
183 }
184}