1#![cfg_attr(not(feature = "std"), no_std)]
2#![allow(async_fn_in_trait)]
3#![warn(clippy::large_futures)]
4#![allow(clippy::uninlined_format_args)]
5#![allow(unknown_lints)]
6
/// Flag carried by [`FrameType::Text`] and [`FrameType::Binary`]: `true` when the frame is
/// not final, i.e. its FIN bit is clear and further fragments of the message follow.
pub type Fragmented = bool;
/// Flag carried by [`FrameType::Continue`]: `true` when the continuation frame has its FIN
/// bit set and therefore ends the fragmented message.
pub type Final = bool;
9
10pub(crate) mod fmt;
12
13#[cfg(feature = "io")]
14pub mod io;
15
16#[derive(Copy, Clone, PartialEq, Eq, Debug)]
17pub enum FrameType {
18 Text(Fragmented),
19 Binary(Fragmented),
20 Ping,
21 Pong,
22 Close,
23 Continue(Final),
24}
25
26impl FrameType {
27 pub fn is_fragmented(&self) -> bool {
28 match self {
29 Self::Text(fragmented) | Self::Binary(fragmented) => *fragmented,
30 Self::Continue(_) => true,
31 _ => false,
32 }
33 }
34
35 pub fn is_final(&self) -> bool {
36 match self {
37 Self::Text(fragmented) | Self::Binary(fragmented) => !*fragmented,
38 Self::Continue(final_) => *final_,
39 _ => true,
40 }
41 }
42}
43
44impl core::fmt::Display for FrameType {
45 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
46 match self {
47 Self::Text(fragmented) => {
48 write!(f, "Text{}", if *fragmented { " (fragmented)" } else { "" })
49 }
50 Self::Binary(fragmented) => write!(
51 f,
52 "Binary{}",
53 if *fragmented { " (fragmented)" } else { "" }
54 ),
55 Self::Ping => write!(f, "Ping"),
56 Self::Pong => write!(f, "Pong"),
57 Self::Close => write!(f, "Close"),
58 Self::Continue(ffinal) => {
59 write!(f, "Continue{}", if *ffinal { " (final)" } else { "" })
60 }
61 }
62 }
63}
64
/// `defmt` rendering of a frame type; produces the same text as the `Display` impl.
#[cfg(feature = "defmt")]
impl defmt::Format for FrameType {
    fn format(&self, f: defmt::Formatter<'_>) {
        /// Yields `label` when `flag` is set and the empty string otherwise.
        fn suffix_if(flag: bool, label: &'static str) -> &'static str {
            if flag {
                label
            } else {
                ""
            }
        }

        match *self {
            Self::Text(fragmented) => {
                defmt::write!(f, "Text{}", suffix_if(fragmented, " (fragmented)"))
            }
            Self::Binary(fragmented) => {
                defmt::write!(f, "Binary{}", suffix_if(fragmented, " (fragmented)"))
            }
            Self::Ping => defmt::write!(f, "Ping"),
            Self::Pong => defmt::write!(f, "Pong"),
            Self::Close => defmt::write!(f, "Close"),
            Self::Continue(is_last) => {
                defmt::write!(f, "Continue{}", suffix_if(is_last, " (final)"))
            }
        }
    }
}
86
/// Errors reported while encoding or decoding frames.
///
/// `E` is the error type of the underlying transport; header-only operations use `()`.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum Error<E> {
    /// More input is required; the value is the number of additional bytes needed.
    Incomplete(usize),
    /// The data is malformed, e.g. reserved header bits are set or the opcode is unknown.
    Invalid,
    /// A buffer was too small for the data.
    // NOTE(review): not produced in this file; presumably raised by the `io` module - confirm.
    BufferOverflow,
    /// A length is unacceptable, e.g. the output buffer is shorter than the serialized header.
    InvalidLen,
    /// An error of the underlying transport.
    Io(E),
}

impl Error<()> {
    /// Converts a transport-less error into an `Error<E>` for any transport error type `E`.
    ///
    /// All variants except `Io` carry no `E` and are mapped one to one.
    ///
    /// # Panics
    ///
    /// Panics if `self` is `Error::Io(())`: there is no `E` value to put into the result.
    pub fn recast<E>(self) -> Error<E> {
        match self {
            Self::Incomplete(missing) => Error::Incomplete(missing),
            Self::Invalid => Error::Invalid,
            Self::BufferOverflow => Error::BufferOverflow,
            Self::InvalidLen => Error::InvalidLen,
            Self::Io(()) => {
                panic!("`Error::<()>::Io` carries no I/O error and cannot be recast")
            }
        }
    }
}
107
108impl<E> core::fmt::Display for Error<E>
109where
110 E: core::fmt::Display,
111{
112 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
113 match self {
114 Self::Incomplete(size) => write!(f, "Incomplete: {} bytes missing", size),
115 Self::Invalid => write!(f, "Invalid"),
116 Self::BufferOverflow => write!(f, "Buffer overflow"),
117 Self::InvalidLen => write!(f, "Invalid length"),
118 Self::Io(err) => write!(f, "IO error: {}", err),
119 }
120 }
121}
122
/// `defmt` rendering of the error; produces the same text as the `Display` impl.
#[cfg(feature = "defmt")]
impl<E: defmt::Format> defmt::Format for Error<E> {
    fn format(&self, f: defmt::Formatter<'_>) {
        match self {
            Self::Invalid => defmt::write!(f, "Invalid"),
            Self::BufferOverflow => defmt::write!(f, "Buffer overflow"),
            Self::InvalidLen => defmt::write!(f, "Invalid length"),
            Self::Incomplete(missing) => {
                defmt::write!(f, "Incomplete: {} bytes missing", missing)
            }
            Self::Io(cause) => defmt::write!(f, "IO error: {}", cause),
        }
    }
}
138
139impl<E> core::error::Error for Error<E> where E: core::error::Error {}
140
141#[derive(Clone, Debug)]
142pub struct FrameHeader {
143 pub frame_type: FrameType,
144 pub payload_len: u64,
145 pub mask_key: Option<u32>,
146}
147
148impl FrameHeader {
149 pub const MIN_LEN: usize = 2;
150 pub const MAX_LEN: usize = FrameHeader {
151 frame_type: FrameType::Binary(false),
152 payload_len: 65536,
153 mask_key: Some(0),
154 }
155 .serialized_len();
156
157 pub fn deserialize(buf: &[u8]) -> Result<(Self, usize), Error<()>> {
158 let mut expected_len = 2_usize;
159
160 if buf.len() < expected_len {
161 Err(Error::Incomplete(expected_len - buf.len()))
162 } else {
163 let final_frame = buf[0] & 0x80 != 0;
164
165 let rsv = buf[0] & 0x70;
166 if rsv != 0 {
167 return Err(Error::Invalid);
168 }
169
170 let opcode = buf[0] & 0x0f;
171 if (3..=7).contains(&opcode) || opcode >= 11 {
172 return Err(Error::Invalid);
173 }
174
175 let mut payload_len = (buf[1] & 0x7f) as u64;
176 let mut payload_offset = 2;
177
178 if payload_len == 126 {
179 expected_len += 2;
180
181 if buf.len() < expected_len {
182 return Err(Error::Incomplete(expected_len - buf.len()));
183 } else {
184 payload_len = u16::from_be_bytes([buf[2], buf[3]]) as _;
185 payload_offset += 2;
186 }
187 } else if payload_len == 127 {
188 expected_len += 8;
189
190 if buf.len() < expected_len {
191 return Err(Error::Incomplete(expected_len - buf.len()));
192 } else {
193 payload_len = u64::from_be_bytes([
194 buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], buf[8], buf[9],
195 ]);
196 payload_offset += 8;
197 }
198 }
199
200 let masked = buf[1] & 0x80 != 0;
201 let mask_key = if masked {
202 expected_len += 4;
203 if buf.len() < expected_len {
204 return Err(Error::Incomplete(expected_len - buf.len()));
205 } else {
206 let mask_key = Some(u32::from_be_bytes([
207 buf[payload_offset],
208 buf[payload_offset + 1],
209 buf[payload_offset + 2],
210 buf[payload_offset + 3],
211 ]));
212 payload_offset += 4;
213
214 mask_key
215 }
216 } else {
217 None
218 };
219
220 let frame_type = match opcode {
221 0 => FrameType::Continue(final_frame),
222 1 => FrameType::Text(!final_frame),
223 2 => FrameType::Binary(!final_frame),
224 8 => FrameType::Close,
225 9 => FrameType::Ping,
226 10 => FrameType::Pong,
227 _ => unreachable!(),
228 };
229
230 let frame_header = FrameHeader {
231 frame_type,
232 payload_len,
233 mask_key,
234 };
235
236 Ok((frame_header, payload_offset))
237 }
238 }
239
240 pub const fn serialized_len(&self) -> usize {
241 let payload_len_len = if self.payload_len >= 65536 {
242 8
243 } else if self.payload_len >= 126 {
244 2
245 } else {
246 0
247 };
248
249 2 + if self.mask_key.is_some() { 4 } else { 0 } + payload_len_len
250 }
251
252 pub fn serialize(&self, buf: &mut [u8]) -> Result<usize, Error<()>> {
253 if buf.len() < self.serialized_len() {
254 return Err(Error::InvalidLen);
255 }
256
257 buf[0] = 0;
258 buf[1] = 0;
259
260 if self.frame_type.is_final() {
261 buf[0] |= 0x80;
262 }
263
264 let opcode = match self.frame_type {
265 FrameType::Text(_) => 1,
266 FrameType::Binary(_) => 2,
267 FrameType::Close => 8,
268 FrameType::Ping => 9,
269 FrameType::Pong => 10,
270 _ => 0,
271 };
272
273 buf[0] |= opcode;
274
275 let mut payload_offset = 2;
276
277 if self.payload_len < 126 {
278 buf[1] |= self.payload_len as u8;
279 } else {
280 let payload_len_bytes = self.payload_len.to_be_bytes();
281 if self.payload_len >= 126 && self.payload_len < 65536 {
282 buf[1] |= 126;
283 buf[2] = payload_len_bytes[6];
284 buf[3] = payload_len_bytes[7];
285
286 payload_offset += 2;
287 } else {
288 buf[1] |= 127;
289 buf[2] = payload_len_bytes[0];
290 buf[3] = payload_len_bytes[1];
291 buf[4] = payload_len_bytes[2];
292 buf[5] = payload_len_bytes[3];
293 buf[6] = payload_len_bytes[4];
294 buf[7] = payload_len_bytes[5];
295 buf[8] = payload_len_bytes[6];
296 buf[9] = payload_len_bytes[7];
297
298 payload_offset += 8;
299 }
300 }
301
302 if let Some(mask_key) = self.mask_key {
303 buf[1] |= 0x80;
304
305 let mask_key_bytes = mask_key.to_be_bytes();
306
307 buf[payload_offset] = mask_key_bytes[0];
308 buf[payload_offset + 1] = mask_key_bytes[1];
309 buf[payload_offset + 2] = mask_key_bytes[2];
310 buf[payload_offset + 3] = mask_key_bytes[3];
311
312 payload_offset += 4;
313 }
314
315 Ok(payload_offset)
316 }
317
318 pub fn mask(&self, buf: &mut [u8], payload_offset: usize) {
319 Self::mask_with(buf, self.mask_key, payload_offset)
320 }
321
322 pub fn mask_with(buf: &mut [u8], mask_key: Option<u32>, payload_offset: usize) {
323 if let Some(mask_key) = mask_key {
324 let mask_bytes = mask_key.to_be_bytes();
325
326 for (offset, byte) in buf.iter_mut().enumerate() {
327 *byte ^= mask_bytes[(payload_offset + offset) % 4];
328 }
329 }
330 }
331}
332
333impl core::fmt::Display for FrameHeader {
334 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
335 write!(
336 f,
337 "Frame {{ {}, payload len {}, mask {:?} }}",
338 self.frame_type, self.payload_len, self.mask_key
339 )
340 }
341}
342
/// `defmt` rendering of the header, using the same layout as the `Display` impl.
#[cfg(feature = "defmt")]
impl defmt::Format for FrameHeader {
    fn format(&self, f: defmt::Formatter<'_>) {
        let Self {
            frame_type,
            payload_len,
            mask_key,
        } = self;

        defmt::write!(
            f,
            "Frame {{ {}, payload len {}, mask {:?} }}",
            frame_type,
            payload_len,
            mask_key
        )
    }
}