edge_ws/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![allow(async_fn_in_trait)]
3#![warn(clippy::large_futures)]
4#![allow(clippy::uninlined_format_args)]
5#![allow(unknown_lints)]
6
/// Flag carried by [`FrameType::Text`] and [`FrameType::Binary`]: `true` when the
/// frame's FIN bit is clear, i.e. the message continues in later `Continue` frames.
pub type Fragmented = bool;

/// Flag carried by [`FrameType::Continue`]: `true` when the frame's FIN bit is set,
/// i.e. this continuation is the last fragment of its message.
pub type Final = bool;
9
10#[allow(unused)]
11#[cfg(feature = "embedded-svc")]
12pub use embedded_svc_compat::*;
13
14// This mod MUST go first, so that the others see its macros.
15pub(crate) mod fmt;
16
17#[cfg(feature = "io")]
18pub mod io;
19
20#[derive(Copy, Clone, PartialEq, Eq, Debug)]
21pub enum FrameType {
22    Text(Fragmented),
23    Binary(Fragmented),
24    Ping,
25    Pong,
26    Close,
27    Continue(Final),
28}
29
30impl FrameType {
31    pub fn is_fragmented(&self) -> bool {
32        match self {
33            Self::Text(fragmented) | Self::Binary(fragmented) => *fragmented,
34            Self::Continue(_) => true,
35            _ => false,
36        }
37    }
38
39    pub fn is_final(&self) -> bool {
40        match self {
41            Self::Text(fragmented) | Self::Binary(fragmented) => !*fragmented,
42            Self::Continue(final_) => *final_,
43            _ => true,
44        }
45    }
46}
47
48impl core::fmt::Display for FrameType {
49    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
50        match self {
51            Self::Text(fragmented) => {
52                write!(f, "Text{}", if *fragmented { " (fragmented)" } else { "" })
53            }
54            Self::Binary(fragmented) => write!(
55                f,
56                "Binary{}",
57                if *fragmented { " (fragmented)" } else { "" }
58            ),
59            Self::Ping => write!(f, "Ping"),
60            Self::Pong => write!(f, "Pong"),
61            Self::Close => write!(f, "Close"),
62            Self::Continue(ffinal) => {
63                write!(f, "Continue{}", if *ffinal { " (final)" } else { "" })
64            }
65        }
66    }
67}
68
/// defmt rendering of the frame kind, mirroring the `Display` output.
#[cfg(feature = "defmt")]
impl defmt::Format for FrameType {
    fn format(&self, f: defmt::Formatter<'_>) {
        let fragmented_suffix = |fragmented: bool| if fragmented { " (fragmented)" } else { "" };

        match *self {
            Self::Text(fragmented) => defmt::write!(f, "Text{}", fragmented_suffix(fragmented)),
            Self::Binary(fragmented) => {
                defmt::write!(f, "Binary{}", fragmented_suffix(fragmented))
            }
            Self::Ping => defmt::write!(f, "Ping"),
            Self::Pong => defmt::write!(f, "Pong"),
            Self::Close => defmt::write!(f, "Close"),
            Self::Continue(is_final) => {
                defmt::write!(f, "Continue{}", if is_final { " (final)" } else { "" })
            }
        }
    }
}
90
/// Errors produced while parsing or serializing WebSocket frame headers.
///
/// `E` is the error type of an underlying I/O layer, carried by [`Error::Io`].
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum Error<E> {
    /// The input ended early; the value is the number of bytes still missing.
    Incomplete(usize),
    /// The data is not a valid frame header (e.g. reserved bits or opcode in use).
    Invalid,
    /// NOTE(review): not produced in this file; presumably raised by I/O helpers when a
    /// caller-supplied buffer cannot hold the data — confirm against its users.
    BufferOverflow,
    /// A supplied buffer is too short for the operation (e.g. serializing a header).
    InvalidLen,
    /// Error reported by the underlying I/O layer.
    Io(E),
}

impl Error<()> {
    /// Converts an I/O-free error into an `Error<E>` for any I/O error type `E`.
    ///
    /// # Panics
    ///
    /// Panics if `self` is `Error::Io(())`: there is no `E` value to carry over.
    pub fn recast<E>(self) -> Error<E> {
        match self {
            Self::Incomplete(missing) => Error::Incomplete(missing),
            Self::Invalid => Error::Invalid,
            Self::BufferOverflow => Error::BufferOverflow,
            Self::InvalidLen => Error::InvalidLen,
            Self::Io(()) => panic!(
                "cannot recast `Error::Io(())`: no value of the target I/O error type is available"
            ),
        }
    }
}
111
112impl<E> core::fmt::Display for Error<E>
113where
114    E: core::fmt::Display,
115{
116    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
117        match self {
118            Self::Incomplete(size) => write!(f, "Incomplete: {} bytes missing", size),
119            Self::Invalid => write!(f, "Invalid"),
120            Self::BufferOverflow => write!(f, "Buffer overflow"),
121            Self::InvalidLen => write!(f, "Invalid length"),
122            Self::Io(err) => write!(f, "IO error: {}", err),
123        }
124    }
125}
126
/// defmt rendering of the error, mirroring the `Display` output.
#[cfg(feature = "defmt")]
impl<E: defmt::Format> defmt::Format for Error<E> {
    fn format(&self, f: defmt::Formatter<'_>) {
        match *self {
            Self::Incomplete(missing) => {
                defmt::write!(f, "Incomplete: {} bytes missing", missing)
            }
            Self::Invalid => defmt::write!(f, "Invalid"),
            Self::BufferOverflow => defmt::write!(f, "Buffer overflow"),
            Self::InvalidLen => defmt::write!(f, "Invalid length"),
            Self::Io(ref io_err) => defmt::write!(f, "IO error: {}", io_err),
        }
    }
}
142
// With `std`, the codec error is a standard error whenever the I/O error type is one.
#[cfg(feature = "std")]
impl<E: std::error::Error> std::error::Error for Error<E> {}
145
146#[derive(Clone, Debug)]
147pub struct FrameHeader {
148    pub frame_type: FrameType,
149    pub payload_len: u64,
150    pub mask_key: Option<u32>,
151}
152
153impl FrameHeader {
154    pub const MIN_LEN: usize = 2;
155    pub const MAX_LEN: usize = FrameHeader {
156        frame_type: FrameType::Binary(false),
157        payload_len: 65536,
158        mask_key: Some(0),
159    }
160    .serialized_len();
161
162    pub fn deserialize(buf: &[u8]) -> Result<(Self, usize), Error<()>> {
163        let mut expected_len = 2_usize;
164
165        if buf.len() < expected_len {
166            Err(Error::Incomplete(expected_len - buf.len()))
167        } else {
168            let final_frame = buf[0] & 0x80 != 0;
169
170            let rsv = buf[0] & 0x70;
171            if rsv != 0 {
172                return Err(Error::Invalid);
173            }
174
175            let opcode = buf[0] & 0x0f;
176            if (3..=7).contains(&opcode) || opcode >= 11 {
177                return Err(Error::Invalid);
178            }
179
180            let mut payload_len = (buf[1] & 0x7f) as u64;
181            let mut payload_offset = 2;
182
183            if payload_len == 126 {
184                expected_len += 2;
185
186                if buf.len() < expected_len {
187                    return Err(Error::Incomplete(expected_len - buf.len()));
188                } else {
189                    payload_len = u16::from_be_bytes([buf[2], buf[3]]) as _;
190                    payload_offset += 2;
191                }
192            } else if payload_len == 127 {
193                expected_len += 8;
194
195                if buf.len() < expected_len {
196                    return Err(Error::Incomplete(expected_len - buf.len()));
197                } else {
198                    payload_len = u64::from_be_bytes([
199                        buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], buf[8], buf[9],
200                    ]);
201                    payload_offset += 8;
202                }
203            }
204
205            let masked = buf[1] & 0x80 != 0;
206            let mask_key = if masked {
207                expected_len += 4;
208                if buf.len() < expected_len {
209                    return Err(Error::Incomplete(expected_len - buf.len()));
210                } else {
211                    let mask_key = Some(u32::from_be_bytes([
212                        buf[payload_offset],
213                        buf[payload_offset + 1],
214                        buf[payload_offset + 2],
215                        buf[payload_offset + 3],
216                    ]));
217                    payload_offset += 4;
218
219                    mask_key
220                }
221            } else {
222                None
223            };
224
225            let frame_type = match opcode {
226                0 => FrameType::Continue(final_frame),
227                1 => FrameType::Text(!final_frame),
228                2 => FrameType::Binary(!final_frame),
229                8 => FrameType::Close,
230                9 => FrameType::Ping,
231                10 => FrameType::Pong,
232                _ => unreachable!(),
233            };
234
235            let frame_header = FrameHeader {
236                frame_type,
237                payload_len,
238                mask_key,
239            };
240
241            Ok((frame_header, payload_offset))
242        }
243    }
244
245    pub const fn serialized_len(&self) -> usize {
246        let payload_len_len = if self.payload_len >= 65536 {
247            8
248        } else if self.payload_len >= 126 {
249            2
250        } else {
251            0
252        };
253
254        2 + if self.mask_key.is_some() { 4 } else { 0 } + payload_len_len
255    }
256
257    pub fn serialize(&self, buf: &mut [u8]) -> Result<usize, Error<()>> {
258        if buf.len() < self.serialized_len() {
259            return Err(Error::InvalidLen);
260        }
261
262        buf[0] = 0;
263        buf[1] = 0;
264
265        if self.frame_type.is_final() {
266            buf[0] |= 0x80;
267        }
268
269        let opcode = match self.frame_type {
270            FrameType::Text(_) => 1,
271            FrameType::Binary(_) => 2,
272            FrameType::Close => 8,
273            FrameType::Ping => 9,
274            FrameType::Pong => 10,
275            _ => 0,
276        };
277
278        buf[0] |= opcode;
279
280        let mut payload_offset = 2;
281
282        if self.payload_len < 126 {
283            buf[1] |= self.payload_len as u8;
284        } else {
285            let payload_len_bytes = self.payload_len.to_be_bytes();
286            if self.payload_len >= 126 && self.payload_len < 65536 {
287                buf[1] |= 126;
288                buf[2] = payload_len_bytes[6];
289                buf[3] = payload_len_bytes[7];
290
291                payload_offset += 2;
292            } else {
293                buf[1] |= 127;
294                buf[2] = payload_len_bytes[0];
295                buf[3] = payload_len_bytes[1];
296                buf[4] = payload_len_bytes[2];
297                buf[5] = payload_len_bytes[3];
298                buf[6] = payload_len_bytes[4];
299                buf[7] = payload_len_bytes[5];
300                buf[8] = payload_len_bytes[6];
301                buf[9] = payload_len_bytes[7];
302
303                payload_offset += 8;
304            }
305        }
306
307        if let Some(mask_key) = self.mask_key {
308            buf[1] |= 0x80;
309
310            let mask_key_bytes = mask_key.to_be_bytes();
311
312            buf[payload_offset] = mask_key_bytes[0];
313            buf[payload_offset + 1] = mask_key_bytes[1];
314            buf[payload_offset + 2] = mask_key_bytes[2];
315            buf[payload_offset + 3] = mask_key_bytes[3];
316
317            payload_offset += 4;
318        }
319
320        Ok(payload_offset)
321    }
322
323    pub fn mask(&self, buf: &mut [u8], payload_offset: usize) {
324        Self::mask_with(buf, self.mask_key, payload_offset)
325    }
326
327    pub fn mask_with(buf: &mut [u8], mask_key: Option<u32>, payload_offset: usize) {
328        if let Some(mask_key) = mask_key {
329            let mask_bytes = mask_key.to_be_bytes();
330
331            for (offset, byte) in buf.iter_mut().enumerate() {
332                *byte ^= mask_bytes[(payload_offset + offset) % 4];
333            }
334        }
335    }
336}
337
338impl core::fmt::Display for FrameHeader {
339    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
340        write!(
341            f,
342            "Frame {{ {}, payload len {}, mask {:?} }}",
343            self.frame_type, self.payload_len, self.mask_key
344        )
345    }
346}
347
/// defmt rendering of the header, mirroring the `Display` output.
#[cfg(feature = "defmt")]
impl defmt::Format for FrameHeader {
    fn format(&self, f: defmt::Formatter<'_>) {
        let FrameHeader {
            frame_type,
            payload_len,
            mask_key,
        } = *self;

        defmt::write!(
            f,
            "Frame {{ {}, payload len {}, mask {:?} }}",
            frame_type,
            payload_len,
            mask_key
        )
    }
}
360
/// Conversions between this crate's [`FrameType`](super::FrameType) and
/// `embedded_svc::ws::FrameType`.
#[cfg(feature = "embedded-svc")]
mod embedded_svc_compat {
    use core::convert::TryFrom;

    use embedded_svc::ws::FrameType;

    // Every local frame kind has a direct `embedded_svc` counterpart.
    impl From<super::FrameType> for FrameType {
        fn from(local: super::FrameType) -> Self {
            match local {
                super::FrameType::Text(fragmented) => Self::Text(fragmented),
                super::FrameType::Binary(fragmented) => Self::Binary(fragmented),
                super::FrameType::Ping => Self::Ping,
                super::FrameType::Pong => Self::Pong,
                super::FrameType::Close => Self::Close,
                super::FrameType::Continue(is_final) => Self::Continue(is_final),
            }
        }
    }

    // `SocketClose` has no local equivalent and is handed back unchanged as the error.
    impl TryFrom<FrameType> for super::FrameType {
        type Error = FrameType;

        fn try_from(remote: FrameType) -> Result<Self, Self::Error> {
            match remote {
                FrameType::Text(fragmented) => Ok(Self::Text(fragmented)),
                FrameType::Binary(fragmented) => Ok(Self::Binary(fragmented)),
                FrameType::Ping => Ok(Self::Ping),
                FrameType::Pong => Ok(Self::Pong),
                FrameType::Close => Ok(Self::Close),
                FrameType::SocketClose => Err(FrameType::SocketClose),
                FrameType::Continue(is_final) => Ok(Self::Continue(is_final)),
            }
        }
    }
}