edge_ws/
lib.rs

//! WebSocket (RFC 6455) frame-header codec: parsing and serializing frame headers
//! and (un)masking payload bytes. Builds as `no_std` unless the `std` feature is on.

// Build as `no_std` unless the `std` feature is enabled.
#![cfg_attr(not(feature = "std"), no_std)]
// NOTE(review): presumably needed because `async fn` is used in public traits
// (likely in the `io` module) — confirm.
#![allow(async_fn_in_trait)]
// Flag futures that are large by value (clippy `large_futures`).
#![warn(clippy::large_futures)]
// The formatting code in this crate uses positional `{}` arguments.
#![allow(clippy::uninlined_format_args)]
// Presumably so toolchains that don't know one of the lints named above
// do not warn about it.
#![allow(unknown_lints)]

/// Flag carried by `FrameType::Text` / `FrameType::Binary`: `true` when the FIN bit
/// is clear, i.e. further fragments of the message follow in `FrameType::Continue`.
pub type Fragmented = bool;
/// Flag carried by `FrameType::Continue`: `true` when the FIN bit is set, i.e. this
/// is the last fragment of the message.
pub type Final = bool;

// This mod MUST go first, so that the others see its macros.
pub(crate) mod fmt;

// I/O layer; compiled only when the `io` feature is enabled.
#[cfg(feature = "io")]
pub mod io;
15
16#[derive(Copy, Clone, PartialEq, Eq, Debug)]
17pub enum FrameType {
18    Text(Fragmented),
19    Binary(Fragmented),
20    Ping,
21    Pong,
22    Close,
23    Continue(Final),
24}
25
26impl FrameType {
27    pub fn is_fragmented(&self) -> bool {
28        match self {
29            Self::Text(fragmented) | Self::Binary(fragmented) => *fragmented,
30            Self::Continue(_) => true,
31            _ => false,
32        }
33    }
34
35    pub fn is_final(&self) -> bool {
36        match self {
37            Self::Text(fragmented) | Self::Binary(fragmented) => !*fragmented,
38            Self::Continue(final_) => *final_,
39            _ => true,
40        }
41    }
42}
43
44impl core::fmt::Display for FrameType {
45    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
46        match self {
47            Self::Text(fragmented) => {
48                write!(f, "Text{}", if *fragmented { " (fragmented)" } else { "" })
49            }
50            Self::Binary(fragmented) => write!(
51                f,
52                "Binary{}",
53                if *fragmented { " (fragmented)" } else { "" }
54            ),
55            Self::Ping => write!(f, "Ping"),
56            Self::Pong => write!(f, "Pong"),
57            Self::Close => write!(f, "Close"),
58            Self::Continue(ffinal) => {
59                write!(f, "Continue{}", if *ffinal { " (final)" } else { "" })
60            }
61        }
62    }
63}
64
/// `defmt` rendering of the frame kind, worded like the `Display` output.
#[cfg(feature = "defmt")]
impl defmt::Format for FrameType {
    fn format(&self, f: defmt::Formatter<'_>) {
        // Yields `text` when the flag is set and an empty suffix otherwise.
        let note = |set: bool, text: &'static str| if set { text } else { "" };

        match *self {
            Self::Text(split) => defmt::write!(f, "Text{}", note(split, " (fragmented)")),
            Self::Binary(split) => defmt::write!(f, "Binary{}", note(split, " (fragmented)")),
            Self::Ping => defmt::write!(f, "Ping"),
            Self::Pong => defmt::write!(f, "Pong"),
            Self::Close => defmt::write!(f, "Close"),
            Self::Continue(last) => defmt::write!(f, "Continue{}", note(last, " (final)")),
        }
    }
}
86
/// Errors from frame-header encoding/decoding, generic over an I/O error type `E`.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum Error<E> {
    /// More input is needed; the value is the number of bytes still missing.
    Incomplete(usize),
    /// The input is malformed (e.g. reserved bits set or an unknown opcode).
    Invalid,
    /// Buffer overflow. NOTE(review): not produced by the header codec itself;
    /// presumably raised by the I/O layer — confirm.
    BufferOverflow,
    /// A length is invalid, e.g. the output buffer passed to
    /// `FrameHeader::serialize` is too short.
    InvalidLen,
    /// Error reported by the underlying I/O layer.
    Io(E),
}

impl Error<()> {
    /// Converts an I/O-free error (`Error<()>`) into an `Error<E>` for any `E`.
    ///
    /// # Panics
    ///
    /// Panics on `Error::Io(())`, because there is no `E` value to carry over.
    /// `FrameHeader::serialize` / `FrameHeader::deserialize` never return that variant.
    #[track_caller]
    pub fn recast<E>(self) -> Error<E> {
        match self {
            Self::Incomplete(missing) => Error::Incomplete(missing),
            Self::Invalid => Error::Invalid,
            Self::BufferOverflow => Error::BufferOverflow,
            Self::InvalidLen => Error::InvalidLen,
            Self::Io(()) => {
                panic!("Error::Io(()) cannot be recast: there is no I/O error value to convert")
            }
        }
    }
}
107
108impl<E> core::fmt::Display for Error<E>
109where
110    E: core::fmt::Display,
111{
112    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
113        match self {
114            Self::Incomplete(size) => write!(f, "Incomplete: {} bytes missing", size),
115            Self::Invalid => write!(f, "Invalid"),
116            Self::BufferOverflow => write!(f, "Buffer overflow"),
117            Self::InvalidLen => write!(f, "Invalid length"),
118            Self::Io(err) => write!(f, "IO error: {}", err),
119        }
120    }
121}
122
/// `defmt` rendering of the error, worded like the `Display` output.
#[cfg(feature = "defmt")]
impl<E: defmt::Format> defmt::Format for Error<E> {
    fn format(&self, f: defmt::Formatter<'_>) {
        match self {
            Self::Io(inner) => defmt::write!(f, "IO error: {}", inner),
            Self::Incomplete(missing) => {
                defmt::write!(f, "Incomplete: {} bytes missing", missing)
            }
            Self::Invalid => defmt::write!(f, "Invalid"),
            Self::BufferOverflow => defmt::write!(f, "Buffer overflow"),
            Self::InvalidLen => defmt::write!(f, "Invalid length"),
        }
    }
}
138
139impl<E> core::error::Error for Error<E> where E: core::error::Error {}
140
141#[derive(Clone, Debug)]
142pub struct FrameHeader {
143    pub frame_type: FrameType,
144    pub payload_len: u64,
145    pub mask_key: Option<u32>,
146}
147
148impl FrameHeader {
149    pub const MIN_LEN: usize = 2;
150    pub const MAX_LEN: usize = FrameHeader {
151        frame_type: FrameType::Binary(false),
152        payload_len: 65536,
153        mask_key: Some(0),
154    }
155    .serialized_len();
156
157    pub fn deserialize(buf: &[u8]) -> Result<(Self, usize), Error<()>> {
158        let mut expected_len = 2_usize;
159
160        if buf.len() < expected_len {
161            Err(Error::Incomplete(expected_len - buf.len()))
162        } else {
163            let final_frame = buf[0] & 0x80 != 0;
164
165            let rsv = buf[0] & 0x70;
166            if rsv != 0 {
167                return Err(Error::Invalid);
168            }
169
170            let opcode = buf[0] & 0x0f;
171            if (3..=7).contains(&opcode) || opcode >= 11 {
172                return Err(Error::Invalid);
173            }
174
175            let mut payload_len = (buf[1] & 0x7f) as u64;
176            let mut payload_offset = 2;
177
178            if payload_len == 126 {
179                expected_len += 2;
180
181                if buf.len() < expected_len {
182                    return Err(Error::Incomplete(expected_len - buf.len()));
183                } else {
184                    payload_len = u16::from_be_bytes([buf[2], buf[3]]) as _;
185                    payload_offset += 2;
186                }
187            } else if payload_len == 127 {
188                expected_len += 8;
189
190                if buf.len() < expected_len {
191                    return Err(Error::Incomplete(expected_len - buf.len()));
192                } else {
193                    payload_len = u64::from_be_bytes([
194                        buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], buf[8], buf[9],
195                    ]);
196                    payload_offset += 8;
197                }
198            }
199
200            let masked = buf[1] & 0x80 != 0;
201            let mask_key = if masked {
202                expected_len += 4;
203                if buf.len() < expected_len {
204                    return Err(Error::Incomplete(expected_len - buf.len()));
205                } else {
206                    let mask_key = Some(u32::from_be_bytes([
207                        buf[payload_offset],
208                        buf[payload_offset + 1],
209                        buf[payload_offset + 2],
210                        buf[payload_offset + 3],
211                    ]));
212                    payload_offset += 4;
213
214                    mask_key
215                }
216            } else {
217                None
218            };
219
220            let frame_type = match opcode {
221                0 => FrameType::Continue(final_frame),
222                1 => FrameType::Text(!final_frame),
223                2 => FrameType::Binary(!final_frame),
224                8 => FrameType::Close,
225                9 => FrameType::Ping,
226                10 => FrameType::Pong,
227                _ => unreachable!(),
228            };
229
230            let frame_header = FrameHeader {
231                frame_type,
232                payload_len,
233                mask_key,
234            };
235
236            Ok((frame_header, payload_offset))
237        }
238    }
239
240    pub const fn serialized_len(&self) -> usize {
241        let payload_len_len = if self.payload_len >= 65536 {
242            8
243        } else if self.payload_len >= 126 {
244            2
245        } else {
246            0
247        };
248
249        2 + if self.mask_key.is_some() { 4 } else { 0 } + payload_len_len
250    }
251
252    pub fn serialize(&self, buf: &mut [u8]) -> Result<usize, Error<()>> {
253        if buf.len() < self.serialized_len() {
254            return Err(Error::InvalidLen);
255        }
256
257        buf[0] = 0;
258        buf[1] = 0;
259
260        if self.frame_type.is_final() {
261            buf[0] |= 0x80;
262        }
263
264        let opcode = match self.frame_type {
265            FrameType::Text(_) => 1,
266            FrameType::Binary(_) => 2,
267            FrameType::Close => 8,
268            FrameType::Ping => 9,
269            FrameType::Pong => 10,
270            _ => 0,
271        };
272
273        buf[0] |= opcode;
274
275        let mut payload_offset = 2;
276
277        if self.payload_len < 126 {
278            buf[1] |= self.payload_len as u8;
279        } else {
280            let payload_len_bytes = self.payload_len.to_be_bytes();
281            if self.payload_len >= 126 && self.payload_len < 65536 {
282                buf[1] |= 126;
283                buf[2] = payload_len_bytes[6];
284                buf[3] = payload_len_bytes[7];
285
286                payload_offset += 2;
287            } else {
288                buf[1] |= 127;
289                buf[2] = payload_len_bytes[0];
290                buf[3] = payload_len_bytes[1];
291                buf[4] = payload_len_bytes[2];
292                buf[5] = payload_len_bytes[3];
293                buf[6] = payload_len_bytes[4];
294                buf[7] = payload_len_bytes[5];
295                buf[8] = payload_len_bytes[6];
296                buf[9] = payload_len_bytes[7];
297
298                payload_offset += 8;
299            }
300        }
301
302        if let Some(mask_key) = self.mask_key {
303            buf[1] |= 0x80;
304
305            let mask_key_bytes = mask_key.to_be_bytes();
306
307            buf[payload_offset] = mask_key_bytes[0];
308            buf[payload_offset + 1] = mask_key_bytes[1];
309            buf[payload_offset + 2] = mask_key_bytes[2];
310            buf[payload_offset + 3] = mask_key_bytes[3];
311
312            payload_offset += 4;
313        }
314
315        Ok(payload_offset)
316    }
317
318    pub fn mask(&self, buf: &mut [u8], payload_offset: usize) {
319        Self::mask_with(buf, self.mask_key, payload_offset)
320    }
321
322    pub fn mask_with(buf: &mut [u8], mask_key: Option<u32>, payload_offset: usize) {
323        if let Some(mask_key) = mask_key {
324            let mask_bytes = mask_key.to_be_bytes();
325
326            for (offset, byte) in buf.iter_mut().enumerate() {
327                *byte ^= mask_bytes[(payload_offset + offset) % 4];
328            }
329        }
330    }
331}
332
333impl core::fmt::Display for FrameHeader {
334    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
335        write!(
336            f,
337            "Frame {{ {}, payload len {}, mask {:?} }}",
338            self.frame_type, self.payload_len, self.mask_key
339        )
340    }
341}
342
/// `defmt` rendering of the header, worded like the `Display` output.
#[cfg(feature = "defmt")]
impl defmt::Format for FrameHeader {
    fn format(&self, f: defmt::Formatter<'_>) {
        // All fields are `Copy`, so destructuring `*self` copies them out.
        let Self {
            frame_type,
            payload_len,
            mask_key,
        } = *self;

        defmt::write!(
            f,
            "Frame {{ {}, payload len {}, mask {:?} }}",
            frame_type,
            payload_len,
            mask_key
        )
    }
}