edge_ws/
io.rs

1use core::cmp::min;
2
3use embedded_io_async::{self, Read, ReadExactError, Write};
4
5use super::*;
6
7#[cfg(feature = "embedded-svc")]
8pub use embedded_svc_compat::*;
9
10pub type Error<E> = super::Error<E>;
11
12impl<E> Error<E>
13where
14    E: embedded_io_async::Error,
15{
16    pub fn erase(&self) -> Error<embedded_io_async::ErrorKind> {
17        match self {
18            Self::Incomplete(size) => Error::Incomplete(*size),
19            Self::Invalid => Error::Invalid,
20            Self::BufferOverflow => Error::BufferOverflow,
21            Self::InvalidLen => Error::InvalidLen,
22            Self::Io(e) => Error::Io(e.kind()),
23        }
24    }
25}
26
27impl<E> From<ReadExactError<E>> for Error<E> {
28    fn from(e: ReadExactError<E>) -> Self {
29        match e {
30            ReadExactError::UnexpectedEof => Error::Invalid,
31            ReadExactError::Other(e) => Error::Io(e),
32        }
33    }
34}
35
36impl FrameHeader {
37    pub async fn recv<R>(mut read: R) -> Result<Self, Error<R::Error>>
38    where
39        R: Read,
40    {
41        let mut header_buf = [0; FrameHeader::MAX_LEN];
42        let mut read_offset = 0;
43        let mut read_end = FrameHeader::MIN_LEN;
44
45        loop {
46            read.read_exact(&mut header_buf[read_offset..read_end])
47                .await
48                .map_err(Error::from)?;
49
50            match FrameHeader::deserialize(&header_buf[..read_end]) {
51                Ok((header, _)) => return Ok(header),
52                Err(Error::Incomplete(more)) => {
53                    read_offset = read_end;
54                    read_end += more;
55                }
56                Err(e) => return Err(e.recast()),
57            }
58        }
59    }
60
61    pub async fn send<W>(&self, mut write: W) -> Result<(), Error<W::Error>>
62    where
63        W: Write,
64    {
65        let mut header_buf = [0; FrameHeader::MAX_LEN];
66        let header_len = unwrap!(self.serialize(&mut header_buf));
67
68        write
69            .write_all(&header_buf[..header_len])
70            .await
71            .map_err(Error::Io)
72    }
73
74    pub async fn recv_payload<'a, R>(
75        &self,
76        mut read: R,
77        payload_buf: &'a mut [u8],
78    ) -> Result<&'a [u8], Error<R::Error>>
79    where
80        R: Read,
81    {
82        if (payload_buf.len() as u64) < self.payload_len {
83            Err(Error::BufferOverflow)
84        } else if self.payload_len == 0 {
85            Ok(&[])
86        } else {
87            let payload = &mut payload_buf[..self.payload_len as _];
88
89            read.read_exact(payload).await.map_err(Error::from)?;
90
91            self.mask(payload, 0);
92
93            Ok(payload)
94        }
95    }
96
97    pub async fn send_payload<'a, W>(
98        &'a self,
99        mut write: W,
100        payload: &'a [u8],
101    ) -> Result<(), Error<W::Error>>
102    where
103        W: Write,
104    {
105        let payload_buf_len = payload.len() as u64;
106
107        if payload_buf_len != self.payload_len {
108            Err(Error::InvalidLen)
109        } else if payload.is_empty() {
110            Ok(())
111        } else if self.mask_key.is_none() {
112            write.write_all(payload).await.map_err(Error::Io)
113        } else {
114            let mut buf = [0_u8; 32];
115
116            let mut offset = 0;
117
118            while offset < payload.len() {
119                let len = min(buf.len(), payload.len() - offset);
120
121                let buf = &mut buf[..len];
122
123                buf.copy_from_slice(&payload[offset..offset + len]);
124
125                self.mask(buf, offset);
126
127                write.write_all(buf).await.map_err(Error::Io)?;
128
129                offset += len;
130            }
131
132            Ok(())
133        }
134    }
135}
136
137pub async fn recv<R>(
138    mut read: R,
139    frame_data_buf: &mut [u8],
140) -> Result<(FrameType, usize), Error<R::Error>>
141where
142    R: Read,
143{
144    let header = FrameHeader::recv(&mut read).await?;
145    header.recv_payload(read, frame_data_buf).await?;
146
147    Ok((header.frame_type, header.payload_len as _))
148}
149
150pub async fn send<W>(
151    mut write: W,
152    frame_type: FrameType,
153    mask_key: Option<u32>,
154    frame_data_buf: &[u8],
155) -> Result<(), Error<W::Error>>
156where
157    W: Write,
158{
159    let header = FrameHeader {
160        frame_type,
161        payload_len: frame_data_buf.len() as _,
162        mask_key,
163    };
164
165    header.send(&mut write).await?;
166    header.send_payload(write, frame_data_buf).await
167}
168
#[cfg(feature = "embedded-svc")]
mod embedded_svc_compat {
    //! Adapter implementing the `embedded_svc` asynchronous WebSocket `Sender` and
    //! `Receiver` traits on top of this module's free `send` / `recv` functions.

    use core::convert::TryInto;

    use embedded_io_async::{Read, Write};
    use embedded_svc::io::ErrorType as IoErrorType;
    use embedded_svc::ws::asynch::{Receiver, Sender};
    use embedded_svc::ws::{ErrorType, FrameType};

    use super::Error;

    /// Pairs a connection (`T`) with a mask-key generator (`M`).
    ///
    /// Field 0 is the connection used for frame I/O; field 1 is called once per
    /// sent frame to obtain that frame's optional mask key.
    pub struct WsConnection<T, M>(T, M);

    impl<T, M> WsConnection<T, M> {
        /// Wraps `connection` together with the `mask_gen` mask-key generator.
        pub const fn new(connection: T, mask_gen: M) -> Self {
            Self(connection, mask_gen)
        }
    }

    impl<T, M> ErrorType for WsConnection<T, M>
    where
        T: IoErrorType,
    {
        type Error = Error<T::Error>;
    }

    impl<T, M> Receiver for WsConnection<T, M>
    where
        T: Read,
    {
        /// Receives one frame through `super::recv` and converts its frame type
        /// into the `embedded_svc` `FrameType`.
        async fn recv(
            &mut self,
            frame_data_buf: &mut [u8],
        ) -> Result<(FrameType, usize), Self::Error> {
            match super::recv(&mut self.0, frame_data_buf).await {
                Ok((frame_type, payload_len)) => Ok((frame_type.into(), payload_len)),
                Err(err) => Err(err),
            }
        }
    }

    impl<T, M> Sender for WsConnection<T, M>
    where
        T: Write,
        M: Fn() -> Option<u32>,
    {
        /// Sends one frame through `super::send`, using a mask key freshly obtained
        /// from the generator.
        ///
        /// Panics (via `unwrap!`) if `frame_type` cannot be converted into this
        /// crate's frame type.
        async fn send(
            &mut self,
            frame_type: FrameType,
            frame_data: &[u8],
        ) -> Result<(), Self::Error> {
            // Arguments are evaluated in order: the frame type is converted before
            // the mask generator is invoked.
            let pending = super::send(
                &mut self.0,
                unwrap!(frame_type.try_into(), "Invalid frame type"),
                (self.1)(),
                frame_data,
            );

            pending.await
        }
    }
}