1use std::{
2 future::Future,
3 io,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use pin_project_lite::pin_project;
9
10use hyper::upgrade::Upgraded;
11#[cfg(feature = "tokio-rt")]
12use hyper_util::rt::TokioIo;
13#[cfg(feature = "smol-rt")]
14use smol_hyper::rt::FuturesIo;
15
16#[cfg(feature = "smol-rt")]
17use smol::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
18#[cfg(feature = "tokio-rt")]
19use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
20
21use ws_framer::{WsFrame, WsRxFramer, WsTxFramer};
22
23use crate::{
24 errors::{DeboaExtrasError, WebSocketError},
25 ws::protocol::Message,
26};
27
/// Upgraded HTTP connection wrapped for tokio's I/O traits (`tokio-rt` feature).
///
/// NOTE(review): this alias and the `smol-rt` one share a name, so enabling both
/// features at once makes the module fail to compile.
#[cfg(feature = "tokio-rt")]
pub type UpgradedIo = TokioIo<Upgraded>;

/// Upgraded HTTP connection wrapped for the futures-io / smol I/O traits (`smol-rt` feature).
#[cfg(feature = "smol-rt")]
pub type UpgradedIo = FuturesIo<Upgraded>;
33
34pub trait DeboaWebSocket {
35 type Stream;
36
37 fn new(stream: Self::Stream) -> Self;
38 fn read_message(&mut self) -> impl Future<Output = Result<Option<Message>, DeboaExtrasError>>;
39 fn write_message(
40 &mut self,
41 message: Message,
42 ) -> impl Future<Output = Result<(), DeboaExtrasError>>;
43 fn send_close(
44 &mut self,
45 code: u16,
46 reason: &str,
47 ) -> impl Future<Output = Result<(), DeboaExtrasError>>;
48 fn send_text(&mut self, message: &str) -> impl Future<Output = Result<(), DeboaExtrasError>>;
49 fn send_binary(&mut self, message: &[u8])
50 -> impl Future<Output = Result<(), DeboaExtrasError>>;
51 fn send_ping(&mut self, message: &[u8]) -> impl Future<Output = Result<(), DeboaExtrasError>>;
52 fn send_pong(&mut self, message: &[u8]) -> impl Future<Output = Result<(), DeboaExtrasError>>;
53}
54
pin_project! {
    /// WebSocket connection wrapper around an I/O stream `T`.
    ///
    /// `stream` is structurally pinned so the `AsyncRead`/`AsyncWrite` impls in this
    /// module can project `Pin<&mut Self>` down to `Pin<&mut T>` via `project()`.
    pub struct WebSocket<T>
    {
        // The wrapped connection that frames are read from and written to.
        #[pin]
        stream: T,
    }
}
63
64impl DeboaWebSocket for WebSocket<UpgradedIo> {
65 type Stream = UpgradedIo;
66
67 fn new(stream: Self::Stream) -> Self {
78 Self { stream }
79 }
80
81 async fn read_message(&mut self) -> Result<Option<Message>, DeboaExtrasError> {
100 let mut rx_buf = vec![0; 10240];
101 let mut rx_framer = WsRxFramer::new(&mut rx_buf);
102
103 let bytes_read = self
104 .stream
105 .read(rx_framer.mut_buf())
106 .await;
107 if bytes_read.is_err() {
108 return Err(DeboaExtrasError::WebSocket(WebSocketError::ReceiveMessage {
109 message: "Failed to read message".to_string(),
110 }));
111 }
112
113 let bytes_read = bytes_read.unwrap();
114 rx_framer.revolve_write_offset(bytes_read);
115 let res = rx_framer.process_data();
116 let message = if let Some(frame) = res {
117 #[allow(clippy::collapsible_match)]
118 match frame {
119 WsFrame::Text(data) => Some(Message::Text(data.to_string())),
120 WsFrame::Binary(data) => Some(Message::Binary(data.to_vec())),
121 WsFrame::Close(code, reason) => Some(Message::Close(code, reason.to_string())),
122 WsFrame::Ping(data) => Some(Message::Ping(data.to_vec())),
123 _ => None,
124 }
125 } else {
126 None
127 };
128
129 Ok(message)
130 }
131
132 async fn write_message(&mut self, message: Message) -> Result<(), DeboaExtrasError> {
160 let mut tx_buf = vec![0; 10240];
161 let mut tx_framer = WsTxFramer::new(true, &mut tx_buf);
162
163 let result = match message {
164 Message::Text(data) => {
165 self.write_all(tx_framer.frame(WsFrame::Text(&data)))
166 .await
167 }
168 Message::Binary(data) => {
169 self.write_all(tx_framer.frame(WsFrame::Binary(&data)))
170 .await
171 }
172 Message::Close(code, reason) => {
173 self.write_all(tx_framer.frame(WsFrame::Close(code, &reason)))
174 .await
175 }
176 Message::Ping(data) => {
177 self.write_all(tx_framer.frame(WsFrame::Ping(&data)))
178 .await
179 }
180 _ => Ok(()),
181 };
182
183 if result.is_err() {
184 return Err(DeboaExtrasError::WebSocket(WebSocketError::SendMessage {
185 message: "Failed to send frame".to_string(),
186 }));
187 }
188
189 Ok(())
190 }
191
192 async fn send_close(&mut self, code: u16, reason: &str) -> Result<(), DeboaExtrasError> {
218 self.write_message(Message::Close(code, reason.to_string()))
219 .await
220 }
221
222 async fn send_text(&mut self, message: &str) -> Result<(), DeboaExtrasError> {
247 self.write_message(Message::Text(message.to_string()))
248 .await
249 }
250
251 async fn send_binary(&mut self, message: &[u8]) -> Result<(), DeboaExtrasError> {
276 self.write_message(Message::Binary(message.to_vec()))
277 .await
278 }
279
280 async fn send_ping(&mut self, message: &[u8]) -> Result<(), DeboaExtrasError> {
305 self.write_message(Message::Ping(message.to_vec()))
306 .await
307 }
308
309 async fn send_pong(&mut self, message: &[u8]) -> Result<(), DeboaExtrasError> {
334 self.write_message(Message::Pong(message.to_vec()))
335 .await
336 }
337}
338
#[cfg(feature = "tokio-rt")]
impl AsyncRead for WebSocket<UpgradedIo> {
    /// Reads straight from the wrapped upgraded connection into `buf`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.project();
        this.stream.poll_read(cx, buf)
    }
}
351
#[cfg(feature = "tokio-rt")]
impl AsyncWrite for WebSocket<UpgradedIo> {
    /// Writes `buf` to the wrapped upgraded connection.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.project().stream.poll_write(cx, buf)
    }

    /// Flushes the wrapped connection.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.project().stream.poll_flush(cx)
    }

    /// Shuts down the write half of the wrapped connection.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.project().stream.poll_shutdown(cx)
    }

    /// Forwards all slices to the inner stream's vectored write.
    ///
    /// `is_write_vectored` reports the inner stream's capability. Delegating here keeps
    /// the two consistent, instead of writing only the first non-empty slice while
    /// claiming vectored support.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        self.project().stream.poll_write_vectored(cx, bufs)
    }

    /// Reports whether the wrapped connection supports efficient vectored writes.
    fn is_write_vectored(&self) -> bool {
        self.stream.is_write_vectored()
    }
}
401
#[cfg(feature = "smol-rt")]
impl<T> AsyncRead for WebSocket<FuturesIo<T>>
where
    T: hyper::rt::Read,
{
    /// Reads bytes from the wrapped hyper connection into `buf`.
    ///
    /// Bridges hyper's cursor-based `Read::poll_read` to the futures-io contract, which
    /// returns the number of bytes read. Previously this always returned `Ok(0)`, which
    /// every reader interprets as end of stream. Per hyper's `Read` contract, `Ok(0)` now
    /// only occurs when the inner stream filled nothing (EOF) or `buf` is empty.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let mut read_buf = hyper::rt::ReadBuf::new(buf);
        let inner = self.project().stream.get_pin_mut();
        let polled = hyper::rt::Read::poll_read(inner, cx, read_buf.unfilled());

        match polled {
            Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())),
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Pending => Poll::Pending,
        }
    }
}
415
#[cfg(feature = "smol-rt")]
impl<T> AsyncWrite for WebSocket<FuturesIo<T>>
where
    T: hyper::rt::Write,
{
    /// Writes `buf` through hyper's `Write` impl of the wrapped connection.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let inner = self.project().stream.get_pin_mut();
        hyper::rt::Write::poll_write(inner, cx, buf)
    }

    /// Flushes the wrapped connection.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = self.project().stream.get_pin_mut();
        hyper::rt::Write::poll_flush(inner, cx)
    }

    /// futures-io's "close" maps onto hyper's `poll_shutdown`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = self.project().stream.get_pin_mut();
        hyper::rt::Write::poll_shutdown(inner, cx)
    }

    /// Forwards all slices to hyper's vectored write on the wrapped connection.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let inner = self.project().stream.get_pin_mut();
        hyper::rt::Write::poll_write_vectored(inner, cx, bufs)
    }
}