1use core::cmp::min;
2
3use embedded_io_async::{self, Read, ReadExactError, Write};
4
5use super::*;
6
7#[cfg(feature = "embedded-svc")]
8pub use embedded_svc_compat::*;
9
10pub type Error<E> = super::Error<E>;
11
12impl<E> Error<E>
13where
14 E: embedded_io_async::Error,
15{
16 pub fn erase(&self) -> Error<embedded_io_async::ErrorKind> {
17 match self {
18 Self::Incomplete(size) => Error::Incomplete(*size),
19 Self::Invalid => Error::Invalid,
20 Self::BufferOverflow => Error::BufferOverflow,
21 Self::InvalidLen => Error::InvalidLen,
22 Self::Io(e) => Error::Io(e.kind()),
23 }
24 }
25}
26
27impl<E> From<ReadExactError<E>> for Error<E> {
28 fn from(e: ReadExactError<E>) -> Self {
29 match e {
30 ReadExactError::UnexpectedEof => Error::Invalid,
31 ReadExactError::Other(e) => Error::Io(e),
32 }
33 }
34}
35
36impl FrameHeader {
37 pub async fn recv<R>(mut read: R) -> Result<Self, Error<R::Error>>
38 where
39 R: Read,
40 {
41 let mut header_buf = [0; FrameHeader::MAX_LEN];
42 let mut read_offset = 0;
43 let mut read_end = FrameHeader::MIN_LEN;
44
45 loop {
46 read.read_exact(&mut header_buf[read_offset..read_end])
47 .await
48 .map_err(Error::from)?;
49
50 match FrameHeader::deserialize(&header_buf[..read_end]) {
51 Ok((header, _)) => return Ok(header),
52 Err(Error::Incomplete(more)) => {
53 read_offset = read_end;
54 read_end += more;
55 }
56 Err(e) => return Err(e.recast()),
57 }
58 }
59 }
60
61 pub async fn send<W>(&self, mut write: W) -> Result<(), Error<W::Error>>
62 where
63 W: Write,
64 {
65 let mut header_buf = [0; FrameHeader::MAX_LEN];
66 let header_len = unwrap!(self.serialize(&mut header_buf));
67
68 write
69 .write_all(&header_buf[..header_len])
70 .await
71 .map_err(Error::Io)
72 }
73
74 pub async fn recv_payload<'a, R>(
75 &self,
76 mut read: R,
77 payload_buf: &'a mut [u8],
78 ) -> Result<&'a [u8], Error<R::Error>>
79 where
80 R: Read,
81 {
82 if (payload_buf.len() as u64) < self.payload_len {
83 Err(Error::BufferOverflow)
84 } else if self.payload_len == 0 {
85 Ok(&[])
86 } else {
87 let payload = &mut payload_buf[..self.payload_len as _];
88
89 read.read_exact(payload).await.map_err(Error::from)?;
90
91 self.mask(payload, 0);
92
93 Ok(payload)
94 }
95 }
96
97 pub async fn send_payload<'a, W>(
98 &'a self,
99 mut write: W,
100 payload: &'a [u8],
101 ) -> Result<(), Error<W::Error>>
102 where
103 W: Write,
104 {
105 let payload_buf_len = payload.len() as u64;
106
107 if payload_buf_len != self.payload_len {
108 Err(Error::InvalidLen)
109 } else if payload.is_empty() {
110 Ok(())
111 } else if self.mask_key.is_none() {
112 write.write_all(payload).await.map_err(Error::Io)
113 } else {
114 let mut buf = [0_u8; 32];
115
116 let mut offset = 0;
117
118 while offset < payload.len() {
119 let len = min(buf.len(), payload.len() - offset);
120
121 let buf = &mut buf[..len];
122
123 buf.copy_from_slice(&payload[offset..offset + len]);
124
125 self.mask(buf, offset);
126
127 write.write_all(buf).await.map_err(Error::Io)?;
128
129 offset += len;
130 }
131
132 Ok(())
133 }
134 }
135}
136
137pub async fn recv<R>(
138 mut read: R,
139 frame_data_buf: &mut [u8],
140) -> Result<(FrameType, usize), Error<R::Error>>
141where
142 R: Read,
143{
144 let header = FrameHeader::recv(&mut read).await?;
145 header.recv_payload(read, frame_data_buf).await?;
146
147 Ok((header.frame_type, header.payload_len as _))
148}
149
150pub async fn send<W>(
151 mut write: W,
152 frame_type: FrameType,
153 mask_key: Option<u32>,
154 frame_data_buf: &[u8],
155) -> Result<(), Error<W::Error>>
156where
157 W: Write,
158{
159 let header = FrameHeader {
160 frame_type,
161 payload_len: frame_data_buf.len() as _,
162 mask_key,
163 };
164
165 header.send(&mut write).await?;
166 header.send_payload(write, frame_data_buf).await
167}
168
#[cfg(feature = "embedded-svc")]
mod embedded_svc_compat {
    use core::convert::TryInto;

    use embedded_io_async::{Read, Write};
    use embedded_svc::io::ErrorType as IoErrorType;
    use embedded_svc::ws::asynch::Sender;
    use embedded_svc::ws::ErrorType;
    use embedded_svc::ws::{asynch::Receiver, FrameType};

    use super::Error;

    /// Adapts an async byte connection to the `embedded_svc` WebSocket
    /// `Receiver` / `Sender` traits.
    ///
    /// - `T` is the underlying connection that frames are read from and written to.
    /// - `M` produces the optional mask key for outgoing frames. It is called once
    ///   per `Sender::send`.
    pub struct WsConnection<T, M>(T, M);

    impl<T, M> WsConnection<T, M> {
        /// Wraps `connection`; `mask_gen` supplies the mask key of every sent frame.
        pub const fn new(connection: T, mask_gen: M) -> Self {
            Self(connection, mask_gen)
        }
    }

    impl<T, M> ErrorType for WsConnection<T, M>
    where
        T: IoErrorType,
    {
        type Error = Error<T::Error>;
    }

    impl<T, M> Receiver for WsConnection<T, M>
    where
        T: Read,
    {
        /// Receives one frame into `frame_data_buf`.
        ///
        /// Returns the frame type, converted to `embedded_svc`'s representation, and
        /// the payload length.
        async fn recv(
            &mut self,
            frame_data_buf: &mut [u8],
        ) -> Result<(FrameType, usize), Self::Error> {
            super::recv(&mut self.0, frame_data_buf)
                .await
                .map(|(frame_type, payload_len)| (frame_type.into(), payload_len))
        }
    }

    impl<T, M> Sender for WsConnection<T, M>
    where
        T: Write,
        M: Fn() -> Option<u32>,
    {
        /// Sends `frame_data` as one frame of type `frame_type`, masked with the
        /// key returned by the mask generator.
        ///
        /// # Errors
        /// - `Error::Invalid` if `frame_type` cannot be converted to this crate's
        ///   frame type.
        /// - otherwise, the errors of the underlying `send`.
        async fn send(
            &mut self,
            frame_type: FrameType,
            frame_data: &[u8],
        ) -> Result<(), Self::Error> {
            // The conversion is fallible and its input comes from the caller.
            // Report an unconvertible frame type as an error instead of panicking
            // inside the library.
            let frame_type = match frame_type.try_into() {
                Ok(frame_type) => frame_type,
                Err(_) => return Err(Error::Invalid),
            };

            super::send(&mut self.0, frame_type, (self.1)(), frame_data).await
        }
    }
}