1#![cfg_attr(not(feature = "std"), no_std)]
2#![allow(async_fn_in_trait)]
3#![warn(clippy::large_futures)]
4#![allow(clippy::uninlined_format_args)]
5#![allow(unknown_lints)]
6
/// Flag carried by [`FrameType::Text`] and [`FrameType::Binary`]: `true` when the
/// frame does not complete its message (its FIN bit is clear), i.e. more
/// fragments follow as [`FrameType::Continue`] frames.
pub type Fragmented = bool;
/// Flag carried by [`FrameType::Continue`]: `true` when this continuation frame
/// is the last fragment of its message (its FIN bit is set).
pub type Final = bool;
9
10#[allow(unused)]
11#[cfg(feature = "embedded-svc")]
12pub use embedded_svc_compat::*;
13
14pub(crate) mod fmt;
16
17#[cfg(feature = "io")]
18pub mod io;
19
/// The kind of a frame, as encoded by the opcode and FIN bit of its header.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum FrameType {
    /// Text data frame (opcode 1); the flag is `true` when more fragments follow.
    Text(Fragmented),
    /// Binary data frame (opcode 2); the flag is `true` when more fragments follow.
    Binary(Fragmented),
    /// Ping control frame (opcode 9).
    Ping,
    /// Pong control frame (opcode 10).
    Pong,
    /// Close control frame (opcode 8).
    Close,
    /// Continuation frame (opcode 0); the flag is `true` for the final fragment.
    Continue(Final),
}
29
30impl FrameType {
31 pub fn is_fragmented(&self) -> bool {
32 match self {
33 Self::Text(fragmented) | Self::Binary(fragmented) => *fragmented,
34 Self::Continue(_) => true,
35 _ => false,
36 }
37 }
38
39 pub fn is_final(&self) -> bool {
40 match self {
41 Self::Text(fragmented) | Self::Binary(fragmented) => !*fragmented,
42 Self::Continue(final_) => *final_,
43 _ => true,
44 }
45 }
46}
47
48impl core::fmt::Display for FrameType {
49 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
50 match self {
51 Self::Text(fragmented) => {
52 write!(f, "Text{}", if *fragmented { " (fragmented)" } else { "" })
53 }
54 Self::Binary(fragmented) => write!(
55 f,
56 "Binary{}",
57 if *fragmented { " (fragmented)" } else { "" }
58 ),
59 Self::Ping => write!(f, "Ping"),
60 Self::Pong => write!(f, "Pong"),
61 Self::Close => write!(f, "Close"),
62 Self::Continue(ffinal) => {
63 write!(f, "Continue{}", if *ffinal { " (final)" } else { "" })
64 }
65 }
66 }
67}
68
/// `defmt` rendering mirroring the `Display` output, e.g. `Text (fragmented)`.
#[cfg(feature = "defmt")]
impl defmt::Format for FrameType {
    fn format(&self, f: defmt::Formatter<'_>) {
        // Yields the qualifier text only when its flag is set.
        fn qualifier(flag: bool, text: &'static str) -> &'static str {
            if flag {
                text
            } else {
                ""
            }
        }

        match *self {
            Self::Ping => defmt::write!(f, "Ping"),
            Self::Pong => defmt::write!(f, "Pong"),
            Self::Close => defmt::write!(f, "Close"),
            Self::Text(fragmented) => {
                defmt::write!(f, "Text{}", qualifier(fragmented, " (fragmented)"))
            }
            Self::Binary(fragmented) => {
                defmt::write!(f, "Binary{}", qualifier(fragmented, " (fragmented)"))
            }
            Self::Continue(last) => defmt::write!(f, "Continue{}", qualifier(last, " (final)")),
        }
    }
}
90
/// Errors reported by this crate.
///
/// `E` is the error type carried by the `Io` variant. The frame-header routines
/// in this module perform no I/O and therefore report `Error<()>`; see
/// [`Error::recast`] for widening those to another `E`.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum Error<E> {
    /// The input ended too early; the value is how many more bytes are needed.
    Incomplete(usize),
    /// The data violates the framing rules (e.g. reserved bits or a reserved opcode).
    Invalid,
    /// NOTE(review): not constructed anywhere in this file. Presumably signals that
    /// data did not fit into a buffer; confirm against the `io` module.
    BufferOverflow,
    /// A caller-supplied buffer is too short (returned by `FrameHeader::serialize`).
    InvalidLen,
    /// Wraps an I/O error of type `E` (displayed as `IO error: ...`).
    Io(E),
}
99
100impl Error<()> {
101 pub fn recast<E>(self) -> Error<E> {
102 match self {
103 Self::Incomplete(v) => Error::Incomplete(v),
104 Self::Invalid => Error::Invalid,
105 Self::BufferOverflow => Error::BufferOverflow,
106 Self::InvalidLen => Error::InvalidLen,
107 Self::Io(_) => panic!(),
108 }
109 }
110}
111
112impl<E> core::fmt::Display for Error<E>
113where
114 E: core::fmt::Display,
115{
116 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
117 match self {
118 Self::Incomplete(size) => write!(f, "Incomplete: {} bytes missing", size),
119 Self::Invalid => write!(f, "Invalid"),
120 Self::BufferOverflow => write!(f, "Buffer overflow"),
121 Self::InvalidLen => write!(f, "Invalid length"),
122 Self::Io(err) => write!(f, "IO error: {}", err),
123 }
124 }
125}
126
/// `defmt` rendering mirroring the `Display` output.
#[cfg(feature = "defmt")]
impl<E: defmt::Format> defmt::Format for Error<E> {
    fn format(&self, f: defmt::Formatter<'_>) {
        match self {
            Self::Io(inner) => defmt::write!(f, "IO error: {}", inner),
            Self::Incomplete(missing) => {
                defmt::write!(f, "Incomplete: {} bytes missing", missing)
            }
            Self::InvalidLen => defmt::write!(f, "Invalid length"),
            Self::BufferOverflow => defmt::write!(f, "Buffer overflow"),
            Self::Invalid => defmt::write!(f, "Invalid"),
        }
    }
}
142
/// With the `std` feature, [`Error`] is a `std::error::Error` whenever the
/// wrapped I/O error type `E` is one.
#[cfg(feature = "std")]
impl<E: std::error::Error> std::error::Error for Error<E> {}
145
146#[derive(Clone, Debug)]
147pub struct FrameHeader {
148 pub frame_type: FrameType,
149 pub payload_len: u64,
150 pub mask_key: Option<u32>,
151}
152
153impl FrameHeader {
154 pub const MIN_LEN: usize = 2;
155 pub const MAX_LEN: usize = FrameHeader {
156 frame_type: FrameType::Binary(false),
157 payload_len: 65536,
158 mask_key: Some(0),
159 }
160 .serialized_len();
161
162 pub fn deserialize(buf: &[u8]) -> Result<(Self, usize), Error<()>> {
163 let mut expected_len = 2_usize;
164
165 if buf.len() < expected_len {
166 Err(Error::Incomplete(expected_len - buf.len()))
167 } else {
168 let final_frame = buf[0] & 0x80 != 0;
169
170 let rsv = buf[0] & 0x70;
171 if rsv != 0 {
172 return Err(Error::Invalid);
173 }
174
175 let opcode = buf[0] & 0x0f;
176 if (3..=7).contains(&opcode) || opcode >= 11 {
177 return Err(Error::Invalid);
178 }
179
180 let mut payload_len = (buf[1] & 0x7f) as u64;
181 let mut payload_offset = 2;
182
183 if payload_len == 126 {
184 expected_len += 2;
185
186 if buf.len() < expected_len {
187 return Err(Error::Incomplete(expected_len - buf.len()));
188 } else {
189 payload_len = u16::from_be_bytes([buf[2], buf[3]]) as _;
190 payload_offset += 2;
191 }
192 } else if payload_len == 127 {
193 expected_len += 8;
194
195 if buf.len() < expected_len {
196 return Err(Error::Incomplete(expected_len - buf.len()));
197 } else {
198 payload_len = u64::from_be_bytes([
199 buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], buf[8], buf[9],
200 ]);
201 payload_offset += 8;
202 }
203 }
204
205 let masked = buf[1] & 0x80 != 0;
206 let mask_key = if masked {
207 expected_len += 4;
208 if buf.len() < expected_len {
209 return Err(Error::Incomplete(expected_len - buf.len()));
210 } else {
211 let mask_key = Some(u32::from_be_bytes([
212 buf[payload_offset],
213 buf[payload_offset + 1],
214 buf[payload_offset + 2],
215 buf[payload_offset + 3],
216 ]));
217 payload_offset += 4;
218
219 mask_key
220 }
221 } else {
222 None
223 };
224
225 let frame_type = match opcode {
226 0 => FrameType::Continue(final_frame),
227 1 => FrameType::Text(!final_frame),
228 2 => FrameType::Binary(!final_frame),
229 8 => FrameType::Close,
230 9 => FrameType::Ping,
231 10 => FrameType::Pong,
232 _ => unreachable!(),
233 };
234
235 let frame_header = FrameHeader {
236 frame_type,
237 payload_len,
238 mask_key,
239 };
240
241 Ok((frame_header, payload_offset))
242 }
243 }
244
245 pub const fn serialized_len(&self) -> usize {
246 let payload_len_len = if self.payload_len >= 65536 {
247 8
248 } else if self.payload_len >= 126 {
249 2
250 } else {
251 0
252 };
253
254 2 + if self.mask_key.is_some() { 4 } else { 0 } + payload_len_len
255 }
256
257 pub fn serialize(&self, buf: &mut [u8]) -> Result<usize, Error<()>> {
258 if buf.len() < self.serialized_len() {
259 return Err(Error::InvalidLen);
260 }
261
262 buf[0] = 0;
263 buf[1] = 0;
264
265 if self.frame_type.is_final() {
266 buf[0] |= 0x80;
267 }
268
269 let opcode = match self.frame_type {
270 FrameType::Text(_) => 1,
271 FrameType::Binary(_) => 2,
272 FrameType::Close => 8,
273 FrameType::Ping => 9,
274 FrameType::Pong => 10,
275 _ => 0,
276 };
277
278 buf[0] |= opcode;
279
280 let mut payload_offset = 2;
281
282 if self.payload_len < 126 {
283 buf[1] |= self.payload_len as u8;
284 } else {
285 let payload_len_bytes = self.payload_len.to_be_bytes();
286 if self.payload_len >= 126 && self.payload_len < 65536 {
287 buf[1] |= 126;
288 buf[2] = payload_len_bytes[6];
289 buf[3] = payload_len_bytes[7];
290
291 payload_offset += 2;
292 } else {
293 buf[1] |= 127;
294 buf[2] = payload_len_bytes[0];
295 buf[3] = payload_len_bytes[1];
296 buf[4] = payload_len_bytes[2];
297 buf[5] = payload_len_bytes[3];
298 buf[6] = payload_len_bytes[4];
299 buf[7] = payload_len_bytes[5];
300 buf[8] = payload_len_bytes[6];
301 buf[9] = payload_len_bytes[7];
302
303 payload_offset += 8;
304 }
305 }
306
307 if let Some(mask_key) = self.mask_key {
308 buf[1] |= 0x80;
309
310 let mask_key_bytes = mask_key.to_be_bytes();
311
312 buf[payload_offset] = mask_key_bytes[0];
313 buf[payload_offset + 1] = mask_key_bytes[1];
314 buf[payload_offset + 2] = mask_key_bytes[2];
315 buf[payload_offset + 3] = mask_key_bytes[3];
316
317 payload_offset += 4;
318 }
319
320 Ok(payload_offset)
321 }
322
323 pub fn mask(&self, buf: &mut [u8], payload_offset: usize) {
324 Self::mask_with(buf, self.mask_key, payload_offset)
325 }
326
327 pub fn mask_with(buf: &mut [u8], mask_key: Option<u32>, payload_offset: usize) {
328 if let Some(mask_key) = mask_key {
329 let mask_bytes = mask_key.to_be_bytes();
330
331 for (offset, byte) in buf.iter_mut().enumerate() {
332 *byte ^= mask_bytes[(payload_offset + offset) % 4];
333 }
334 }
335 }
336}
337
338impl core::fmt::Display for FrameHeader {
339 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
340 write!(
341 f,
342 "Frame {{ {}, payload len {}, mask {:?} }}",
343 self.frame_type, self.payload_len, self.mask_key
344 )
345 }
346}
347
/// `defmt` rendering mirroring the `Display` output.
#[cfg(feature = "defmt")]
impl defmt::Format for FrameHeader {
    fn format(&self, f: defmt::Formatter<'_>) {
        let Self {
            frame_type,
            payload_len,
            mask_key,
        } = self;

        defmt::write!(
            f,
            "Frame {{ {}, payload len {}, mask {:?} }}",
            frame_type,
            payload_len,
            mask_key
        )
    }
}
360
#[cfg(feature = "embedded-svc")]
mod embedded_svc_compat {
    //! Conversions between this crate's [`super::FrameType`] and
    //! `embedded_svc::ws::FrameType`.

    use core::convert::TryFrom;

    use embedded_svc::ws::FrameType;

    /// Every local frame type has a direct `embedded_svc` counterpart, so this
    /// conversion is infallible.
    impl From<super::FrameType> for FrameType {
        fn from(frame_type: super::FrameType) -> Self {
            match frame_type {
                super::FrameType::Text(v) => Self::Text(v),
                super::FrameType::Binary(v) => Self::Binary(v),
                super::FrameType::Ping => Self::Ping,
                super::FrameType::Pong => Self::Pong,
                super::FrameType::Close => Self::Close,
                super::FrameType::Continue(v) => Self::Continue(v),
            }
        }
    }

    /// `embedded_svc` additionally has `SocketClose`, which has no counterpart in
    /// this crate's `FrameType`. Converting it fails and returns the original
    /// value as the error.
    impl TryFrom<FrameType> for super::FrameType {
        type Error = FrameType;

        fn try_from(frame_type: FrameType) -> Result<Self, Self::Error> {
            let converted = match frame_type {
                FrameType::Text(v) => Self::Text(v),
                FrameType::Binary(v) => Self::Binary(v),
                FrameType::Ping => Self::Ping,
                FrameType::Pong => Self::Pong,
                FrameType::Close => Self::Close,
                // An explicit early return states the intent; `Err(..)?` relied on
                // an identity `From` conversion to do the same thing.
                FrameType::SocketClose => return Err(FrameType::SocketClose),
                FrameType::Continue(v) => Self::Continue(v),
            };

            Ok(converted)
        }
    }
}