1#[cfg(not(feature = "embedded-io-async"))]
2use core::ops::Deref;
3use core::{fmt::Debug, str::Utf8Error};
4#[cfg(feature = "embedded-io-async")]
5use embedded_io_async::{ErrorType, Read, Write};
6#[cfg(not(feature = "embedded-io-async"))]
7use futures::{Sink, SinkExt, Stream, StreamExt};
8use rand_core::RngCore;
9
10use crate::{
11 WebSocket, WebSocketCloseStatusCode, WebSocketOptions, WebSocketReceiveMessageType,
12 WebSocketSendMessageType, WebSocketSubProtocol, WebSocketType,
13};
14
15pub struct CloseMessage<'a> {
16 pub status_code: WebSocketCloseStatusCode,
17 pub reason: &'a [u8],
18}
19
20pub enum ReadResult<'a> {
21 Binary(&'a [u8]),
22 Text(&'a str),
23 Pong(&'a [u8]),
26 Ping(&'a [u8]),
30 Close(CloseMessage<'a>),
34}
35
36#[derive(Debug)]
37pub enum FramerError<E> {
38 Io(E),
39 FrameTooLarge(usize),
40 Utf8(Utf8Error),
41 HttpHeader(httparse::Error),
42 WebSocket(crate::Error),
43 Disconnected,
44 RxBufferTooSmall(usize),
45}
46
47pub struct Framer<TRng, TWebSocketType>
48where
49 TRng: RngCore,
50 TWebSocketType: WebSocketType,
51{
52 websocket: WebSocket<TRng, TWebSocketType>,
53 frame_cursor: usize,
54 rx_remainder_len: usize,
55}
56
57#[cfg(not(feature = "embedded-io-async"))]
58impl<TRng> Framer<TRng, crate::Client>
59where
60 TRng: RngCore,
61{
62 pub async fn connect<'a, B, E>(
63 &mut self,
64 stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
65 buffer: &'a mut [u8],
66 websocket_options: &WebSocketOptions<'_>,
67 ) -> Result<Option<WebSocketSubProtocol>, FramerError<E>>
68 where
69 B: AsRef<[u8]>,
70 {
71 let (tx_len, web_socket_key) = self
72 .websocket
73 .client_connect(websocket_options, buffer)
74 .map_err(FramerError::WebSocket)?;
75
76 let (tx_buf, rx_buf) = buffer.split_at_mut(tx_len);
77 stream.send(tx_buf).await.map_err(FramerError::Io)?;
78 stream.flush().await.map_err(FramerError::Io)?;
79
80 loop {
81 match stream.next().await {
82 Some(buf) => {
83 let buf = buf.map_err(FramerError::Io)?;
84 let buf = buf.as_ref();
85
86 match self.websocket.client_accept(&web_socket_key, buf) {
87 Ok((len, sub_protocol)) => {
88 let from = len;
93 let to = buf.len();
94 let remaining_len = to - from;
95
96 if remaining_len > 0 {
97 let rx_start = rx_buf.len() - remaining_len;
98 rx_buf[rx_start..].copy_from_slice(&buf[from..to]);
99 self.rx_remainder_len = remaining_len;
100 }
101
102 return Ok(sub_protocol);
103 }
104 Err(crate::Error::HttpHeaderIncomplete) => {
105 panic!("oh no");
107 }
108 Err(e) => {
109 return Err(FramerError::WebSocket(e));
110 }
111 }
112 }
113 None => return Err(FramerError::Disconnected),
114 }
115 }
116 }
117}
118
119#[cfg(not(feature = "embedded-io-async"))]
120impl<TRng, TWebSocketType> Framer<TRng, TWebSocketType>
121where
122 TRng: RngCore,
123 TWebSocketType: WebSocketType,
124{
125 pub fn new(websocket: WebSocket<TRng, TWebSocketType>) -> Self {
126 Self {
127 websocket,
128 frame_cursor: 0,
129 rx_remainder_len: 0,
130 }
131 }
132
133 pub fn encode<E>(
134 &mut self,
135 message_type: WebSocketSendMessageType,
136 end_of_message: bool,
137 from: &[u8],
138 to: &mut [u8],
139 ) -> Result<usize, FramerError<E>> {
140 let len = self
141 .websocket
142 .write(message_type, end_of_message, from, to)
143 .map_err(FramerError::WebSocket)?;
144
145 Ok(len)
146 }
147
148 pub async fn write<'b, E>(
149 &mut self,
150 tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
151 tx_buf: &'b mut [u8],
152 message_type: WebSocketSendMessageType,
153 end_of_message: bool,
154 frame_buf: &[u8],
155 ) -> Result<(), FramerError<E>>
156 where
157 E: Debug,
158 {
159 let len = self
160 .websocket
161 .write(message_type, end_of_message, frame_buf, tx_buf)
162 .map_err(FramerError::WebSocket)?;
163
164 tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
165 tx.flush().await.map_err(FramerError::Io)?;
166 Ok(())
167 }
168
169 pub async fn close<'b, E>(
170 &mut self,
171 tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
172 tx_buf: &'b mut [u8],
173 close_status: WebSocketCloseStatusCode,
174 status_description: Option<&str>,
175 ) -> Result<(), FramerError<E>>
176 where
177 E: Debug,
178 {
179 let len = self
180 .websocket
181 .close(close_status, status_description, tx_buf)
182 .map_err(FramerError::WebSocket)?;
183
184 tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
185 tx.flush().await.map_err(FramerError::Io)?;
186 Ok(())
187 }
188
189 pub async fn read<'a, B: Deref<Target = [u8]>, E>(
194 &mut self,
195 stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
196 buffer: &'a mut [u8],
197 ) -> Option<Result<ReadResult<'a>, FramerError<E>>>
198 where
199 E: Debug,
200 {
201 if self.rx_remainder_len == 0 {
202 match stream.next().await {
203 Some(Ok(input)) => {
204 if buffer.len() < input.len() {
205 return Some(Err(FramerError::RxBufferTooSmall(input.len())));
206 }
207
208 let rx_start = buffer.len() - input.len();
209
210 buffer[rx_start..].copy_from_slice(&input);
212 self.rx_remainder_len = input.len()
213 }
214 Some(Err(e)) => {
215 return Some(Err(FramerError::Io(e)));
216 }
217 None => return None,
218 }
219 }
220
221 let rx_start = buffer.len() - self.rx_remainder_len;
222 let (frame_buf, rx_buf) = buffer.split_at_mut(rx_start);
223
224 let ws_result = match self.websocket.read(rx_buf, frame_buf) {
225 Ok(ws_result) => ws_result,
226 Err(e) => return Some(Err(FramerError::WebSocket(e))),
227 };
228
229 self.rx_remainder_len -= ws_result.len_from;
230
231 match ws_result.message_type {
232 WebSocketReceiveMessageType::Binary => {
233 self.frame_cursor += ws_result.len_to;
234 if ws_result.end_of_message {
235 let range = 0..self.frame_cursor;
236 self.frame_cursor = 0;
237 return Some(Ok(ReadResult::Binary(&frame_buf[range])));
238 }
239 }
240 WebSocketReceiveMessageType::Text => {
241 self.frame_cursor += ws_result.len_to;
242 if ws_result.end_of_message {
243 let range = 0..self.frame_cursor;
244 self.frame_cursor = 0;
245 match core::str::from_utf8(&frame_buf[range]) {
246 Ok(text) => return Some(Ok(ReadResult::Text(text))),
247 Err(e) => return Some(Err(FramerError::Utf8(e))),
248 }
249 }
250 }
251 WebSocketReceiveMessageType::CloseMustReply => {
252 let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
253
254 let tx_buf_len = ws_result.len_to + 14; let split_at = frame_buf.len() - tx_buf_len;
257 let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
258
259 match self.websocket.write(
260 WebSocketSendMessageType::CloseReply,
261 true,
262 &frame_buf[range.start..range.end],
263 tx_buf,
264 ) {
265 Ok(len) => match stream.send(&tx_buf[..len]).await {
266 Ok(()) => {
267 self.frame_cursor = 0;
268 let status_code = ws_result
269 .close_status
270 .expect("close message must have code");
271 let reason = &frame_buf[range];
272 return Some(Ok(ReadResult::Close(CloseMessage {
273 status_code,
274 reason,
275 })));
276 }
277 Err(e) => return Some(Err(FramerError::Io(e))),
278 },
279 Err(e) => return Some(Err(FramerError::WebSocket(e))),
280 }
281 }
282 WebSocketReceiveMessageType::CloseCompleted => return None,
283 WebSocketReceiveMessageType::Pong => {
284 let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
285 return Some(Ok(ReadResult::Pong(&frame_buf[range])));
286 }
287 WebSocketReceiveMessageType::Ping => {
288 let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
289
290 let tx_buf_len = ws_result.len_to + 14; let split_at = frame_buf.len() - tx_buf_len;
293 let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
294
295 match self.websocket.write(
296 WebSocketSendMessageType::Pong,
297 true,
298 &frame_buf[range.start..range.end],
299 tx_buf,
300 ) {
301 Ok(len) => match stream.send(&tx_buf[..len]).await {
302 Ok(()) => {
303 return Some(Ok(ReadResult::Ping(&frame_buf[range])));
304 }
305 Err(e) => return Some(Err(FramerError::Io(e))),
306 },
307 Err(e) => return Some(Err(FramerError::WebSocket(e))),
308 }
309 }
310 }
311
312 None
313 }
314}
315
#[cfg(feature = "embedded-io-async")]
impl<TRng> Framer<TRng, crate::Client>
where
    TRng: RngCore,
{
    /// Performs the client side of the websocket opening handshake over `stream`.
    ///
    /// The HTTP upgrade request is encoded into `buffer` and written out in full. The response is
    /// then read into `buffer` until it parses as a complete header, however many reads that
    /// takes. Any bytes that followed the header are moved to the end of `buffer`, where `read`
    /// expects unread input, so pass a buffer of the same length to `read` afterwards.
    ///
    /// # Errors
    /// * [`FramerError::Io`] if writing or reading fails.
    /// * [`FramerError::Disconnected`] if the transport reaches EOF before the header is complete.
    /// * [`FramerError::RxBufferTooSmall`] if the response header does not fit in `buffer`.
    /// * [`FramerError::WebSocket`] if the request cannot be built or the response is rejected.
    pub async fn connect<'a, S>(
        &mut self,
        stream: &mut S,
        buffer: &'a mut [u8],
        websocket_options: &WebSocketOptions<'_>,
    ) -> Result<Option<WebSocketSubProtocol>, FramerError<<S as ErrorType>::Error>>
    where
        S: Read + Write + Unpin,
    {
        let (tx_len, web_socket_key) = self
            .websocket
            .client_connect(websocket_options, buffer)
            .map_err(FramerError::WebSocket)?;

        write_all_and_flush(stream, &buffer[..tx_len]).await?;

        // Number of response bytes accumulated at the start of `buffer` so far. The request bytes
        // may be overwritten because they have already been sent.
        let mut received = 0;
        loop {
            if received == buffer.len() {
                return Err(FramerError::RxBufferTooSmall(received));
            }
            let read_len = stream
                .read(&mut buffer[received..])
                .await
                .map_err(FramerError::Io)?;
            if read_len == 0 {
                // EOF before the response header was complete.
                return Err(FramerError::Disconnected);
            }
            received += read_len;

            // Only parse the bytes actually received; the rest of `buffer` holds stale data.
            match self
                .websocket
                .client_accept(&web_socket_key, &buffer[..received])
            {
                Ok((header_len, sub_protocol)) => {
                    // Bytes after the header already belong to the first websocket frames; keep
                    // them at the end of the buffer for `read`.
                    let remaining_len = received - header_len;
                    if remaining_len > 0 {
                        let rx_start = buffer.len() - remaining_len;
                        buffer.copy_within(header_len..received, rx_start);
                        self.rx_remainder_len = remaining_len;
                    }
                    return Ok(sub_protocol);
                }
                // The header is split across reads: keep accumulating.
                Err(crate::Error::HttpHeaderIncomplete) => {}
                Err(e) => return Err(FramerError::WebSocket(e)),
            }
        }
    }
}

/// Writes every byte of `bytes` to `tx`, then flushes it.
///
/// `Write::write` may accept only part of the buffer, so this loops until everything has been
/// accepted. A writer that accepts zero bytes is reported as [`FramerError::Disconnected`]
/// rather than being retried forever.
#[cfg(feature = "embedded-io-async")]
async fn write_all_and_flush<T>(
    tx: &mut T,
    mut bytes: &[u8],
) -> Result<(), FramerError<<T as ErrorType>::Error>>
where
    T: Write,
{
    while !bytes.is_empty() {
        let written = tx.write(bytes).await.map_err(FramerError::Io)?;
        if written == 0 {
            return Err(FramerError::Disconnected);
        }
        bytes = &bytes[written..];
    }
    tx.flush().await.map_err(FramerError::Io)
}

#[cfg(feature = "embedded-io-async")]
impl<TRng, TWebSocketType> Framer<TRng, TWebSocketType>
where
    TRng: RngCore,
    TWebSocketType: WebSocketType,
{
    /// Wraps `websocket` in a framer with no buffered input and no partially assembled message.
    pub fn new(websocket: WebSocket<TRng, TWebSocketType>) -> Self {
        Self {
            websocket,
            frame_cursor: 0,
            rx_remainder_len: 0,
        }
    }

    /// Encodes `from` as a websocket frame into `to` without touching any transport.
    ///
    /// Returns the number of bytes written to `to`.
    ///
    /// # Errors
    /// [`FramerError::WebSocket`] if the websocket layer rejects the write.
    pub fn encode<E>(
        &mut self,
        message_type: WebSocketSendMessageType,
        end_of_message: bool,
        from: &[u8],
        to: &mut [u8],
    ) -> Result<usize, FramerError<E>> {
        let len = self
            .websocket
            .write(message_type, end_of_message, from, to)
            .map_err(FramerError::WebSocket)?;

        Ok(len)
    }

    /// Encodes `frame_buf` as one websocket frame into `tx_buf`, writes all of it to `tx` and
    /// flushes.
    ///
    /// # Errors
    /// * [`FramerError::WebSocket`] if encoding fails.
    /// * [`FramerError::Io`] or [`FramerError::Disconnected`] if the transport fails; these are
    ///   returned rather than panicking.
    pub async fn write<'b, T>(
        &mut self,
        tx: &mut T,
        tx_buf: &'b mut [u8],
        message_type: WebSocketSendMessageType,
        end_of_message: bool,
        frame_buf: &[u8],
    ) -> Result<(), FramerError<<T as ErrorType>::Error>>
    where
        T: Write + Unpin,
    {
        let len = self
            .websocket
            .write(message_type, end_of_message, frame_buf, tx_buf)
            .map_err(FramerError::WebSocket)?;

        write_all_and_flush(tx, &tx_buf[..len]).await
    }

    /// Encodes a close frame (status plus optional description) into `tx_buf`, writes all of it
    /// to `tx` and flushes.
    ///
    /// # Errors
    /// * [`FramerError::WebSocket`] if encoding fails.
    /// * [`FramerError::Io`] or [`FramerError::Disconnected`] if the transport fails.
    pub async fn close<'b, T>(
        &mut self,
        tx: &mut T,
        tx_buf: &'b mut [u8],
        close_status: WebSocketCloseStatusCode,
        status_description: Option<&str>,
    ) -> Result<(), FramerError<<T as ErrorType>::Error>>
    where
        T: Write + Unpin,
    {
        let len = self
            .websocket
            .close(close_status, status_description, tx_buf)
            .map_err(FramerError::WebSocket)?;

        write_all_and_flush(tx, &tx_buf[..len]).await
    }

    /// Reads from `stream` until a complete data message or a control frame is available.
    ///
    /// `buffer` holds three regions:
    /// * unread transport bytes at its end;
    /// * decoded payload assembled from its start;
    /// * ping/close replies, encoded just before the unread bytes.
    ///
    /// Pass the same buffer, with its contents intact, on every call, because unread input and
    /// partially assembled messages are kept there between calls.
    ///
    /// Fragmented data messages are reassembled across as many reads as needed. Pings are answered
    /// with a pong, and close requests with a close reply, before they are returned.
    ///
    /// Returns `None` when the transport reaches EOF or the closing handshake has completed.
    pub async fn read<'a, S>(
        &mut self,
        stream: &mut S,
        buffer: &'a mut [u8],
    ) -> Option<Result<ReadResult<'a>, FramerError<<S as ErrorType>::Error>>>
    where
        S: Read + Write + Unpin,
    {
        // Worst-case frame header: 2 base bytes + 8 extended-length bytes + 4 mask bytes.
        const MAX_FRAME_HEADER_LEN: usize = 14;

        // Phase 1: decode using short-lived borrows of `buffer`. Nothing borrowed here escapes the
        // loop, which is what allows it to iterate; the `'a` borrow backing the returned slices
        // is taken only after the loop.
        let ws_result = loop {
            if self.rx_remainder_len == 0 {
                // Read into the space after the partially assembled message, then move the new
                // bytes to the end of the buffer where the decoder expects unread input.
                if self.frame_cursor >= buffer.len() {
                    return Some(Err(FramerError::FrameTooLarge(self.frame_cursor)));
                }
                let read_len = match stream.read(&mut buffer[self.frame_cursor..]).await {
                    // EOF on a non-empty buffer.
                    Ok(0) => return None,
                    Ok(read_len) => read_len,
                    Err(error) => return Some(Err(FramerError::Io(error))),
                };
                let rx_start = buffer.len() - read_len;
                buffer.copy_within(self.frame_cursor..self.frame_cursor + read_len, rx_start);
                self.rx_remainder_len = read_len;
            }

            let rx_start = buffer.len() - self.rx_remainder_len;
            let (frame_buf, rx_buf) = buffer.split_at_mut(rx_start);
            // Decode after the payload already assembled for the current message.
            let frame_tail = match frame_buf.get_mut(self.frame_cursor..) {
                Some(frame_tail) => frame_tail,
                None => return Some(Err(FramerError::FrameTooLarge(self.frame_cursor))),
            };
            let ws_result = match self.websocket.read(rx_buf, frame_tail) {
                Ok(ws_result) => ws_result,
                Err(e) => return Some(Err(FramerError::WebSocket(e))),
            };
            self.rx_remainder_len -= ws_result.len_from;

            let is_data = matches!(
                ws_result.message_type,
                WebSocketReceiveMessageType::Binary | WebSocketReceiveMessageType::Text
            );
            if !is_data || ws_result.end_of_message {
                break ws_result;
            }

            // Non-final fragment: remember how much payload is assembled and keep reading.
            self.frame_cursor += ws_result.len_to;
            // Nothing consumed although input remains: the decoder cannot make progress
            // (presumably no room left for payload), so fail instead of spinning forever.
            if ws_result.len_from == 0 && self.rx_remainder_len > 0 {
                return Some(Err(FramerError::FrameTooLarge(self.frame_cursor)));
            }
        };

        // Phase 2: everything in front of the still-unread input is free to use.
        let frame_end = buffer.len() - self.rx_remainder_len;
        let (frame_buf, _unread) = buffer.split_at_mut(frame_end);

        match ws_result.message_type {
            WebSocketReceiveMessageType::Binary => {
                let range = 0..self.frame_cursor + ws_result.len_to;
                self.frame_cursor = 0;
                Some(Ok(ReadResult::Binary(&frame_buf[range])))
            }
            WebSocketReceiveMessageType::Text => {
                let range = 0..self.frame_cursor + ws_result.len_to;
                self.frame_cursor = 0;
                match core::str::from_utf8(&frame_buf[range]) {
                    Ok(text) => Some(Ok(ReadResult::Text(text))),
                    Err(e) => Some(Err(FramerError::Utf8(e))),
                }
            }
            WebSocketReceiveMessageType::CloseMustReply => {
                let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
                // The reply is encoded at the end of the frame region, after the payload it echoes.
                let tx_len = ws_result.len_to + MAX_FRAME_HEADER_LEN;
                let split_at = match frame_buf.len().checked_sub(tx_len) {
                    Some(split_at) if split_at >= range.end => split_at,
                    _ => {
                        let required = range.end + tx_len + self.rx_remainder_len;
                        return Some(Err(FramerError::RxBufferTooSmall(required)));
                    }
                };
                let (payload_buf, tx_buf) = frame_buf.split_at_mut(split_at);

                let len = match self.websocket.write(
                    WebSocketSendMessageType::CloseReply,
                    true,
                    &payload_buf[range.start..range.end],
                    tx_buf,
                ) {
                    Ok(len) => len,
                    Err(e) => return Some(Err(FramerError::WebSocket(e))),
                };
                if let Err(e) = write_all_and_flush(stream, &tx_buf[..len]).await {
                    return Some(Err(e));
                }

                self.frame_cursor = 0;
                // NOTE(review): panics if `close_status` is `None`, presumably when the peer's
                // close frame carries no status code (which RFC 6455 permits). Confirm against
                // `WebSocket::read`.
                let status_code = ws_result
                    .close_status
                    .expect("close message must have code");
                Some(Ok(ReadResult::Close(CloseMessage {
                    status_code,
                    reason: &payload_buf[range],
                })))
            }
            WebSocketReceiveMessageType::CloseCompleted => None,
            WebSocketReceiveMessageType::Pong => {
                let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
                Some(Ok(ReadResult::Pong(&frame_buf[range])))
            }
            WebSocketReceiveMessageType::Ping => {
                let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
                // The pong is encoded at the end of the frame region, after the payload it echoes.
                let tx_len = ws_result.len_to + MAX_FRAME_HEADER_LEN;
                let split_at = match frame_buf.len().checked_sub(tx_len) {
                    Some(split_at) if split_at >= range.end => split_at,
                    _ => {
                        let required = range.end + tx_len + self.rx_remainder_len;
                        return Some(Err(FramerError::RxBufferTooSmall(required)));
                    }
                };
                let (payload_buf, tx_buf) = frame_buf.split_at_mut(split_at);

                let len = match self.websocket.write(
                    WebSocketSendMessageType::Pong,
                    true,
                    &payload_buf[range.start..range.end],
                    tx_buf,
                ) {
                    Ok(len) => len,
                    Err(e) => return Some(Err(FramerError::WebSocket(e))),
                };
                if let Err(e) = write_all_and_flush(stream, &tx_buf[..len]).await {
                    return Some(Err(e));
                }

                Some(Ok(ReadResult::Ping(&payload_buf[range])))
            }
        }
    }
}