Skip to main content

deboa_extras/ws/io/socket.rs

1use std::{
2    future::Future,
3    io,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use pin_project_lite::pin_project;
9
10use hyper::upgrade::Upgraded;
11#[cfg(feature = "tokio-rt")]
12use hyper_util::rt::TokioIo;
13#[cfg(feature = "smol-rt")]
14use smol_hyper::rt::FuturesIo;
15
16#[cfg(feature = "smol-rt")]
17use smol::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
18#[cfg(feature = "tokio-rt")]
19use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
20
21use ws_framer::{WsFrame, WsRxFramer, WsTxFramer};
22
23use crate::{
24    errors::{DeboaExtrasError, WebSocketError},
25    ws::protocol::Message,
26};
27
/// Upgraded HTTP connection wrapped for the smol runtime (`smol-rt` feature).
#[cfg(feature = "smol-rt")]
pub type UpgradedIo = FuturesIo<Upgraded>;

/// Upgraded HTTP connection wrapped for the Tokio runtime (`tokio-rt` feature).
///
/// NOTE(review): enabling both `tokio-rt` and `smol-rt` defines `UpgradedIo`
/// twice and fails to compile; the features appear to be mutually exclusive.
#[cfg(feature = "tokio-rt")]
pub type UpgradedIo = TokioIo<Upgraded>;
33
34pub trait DeboaWebSocket {
35    type Stream;
36
37    fn new(stream: Self::Stream) -> Self;
38    fn read_message(&mut self) -> impl Future<Output = Result<Option<Message>, DeboaExtrasError>>;
39    fn write_message(
40        &mut self,
41        message: Message,
42    ) -> impl Future<Output = Result<(), DeboaExtrasError>>;
43    fn send_close(
44        &mut self,
45        code: u16,
46        reason: &str,
47    ) -> impl Future<Output = Result<(), DeboaExtrasError>>;
48    fn send_text(&mut self, message: &str) -> impl Future<Output = Result<(), DeboaExtrasError>>;
49    fn send_binary(&mut self, message: &[u8])
50        -> impl Future<Output = Result<(), DeboaExtrasError>>;
51    fn send_ping(&mut self, message: &[u8]) -> impl Future<Output = Result<(), DeboaExtrasError>>;
52    fn send_pong(&mut self, message: &[u8]) -> impl Future<Output = Result<(), DeboaExtrasError>>;
53}
54
55pin_project! {
56    /// WebSocket struct
57    pub struct WebSocket<T>
58    {
59        #[pin]
60        stream: T,
61    }
62}
63
64impl DeboaWebSocket for WebSocket<UpgradedIo> {
65    type Stream = UpgradedIo;
66
67    /// new method
68    ///
69    /// # Arguments
70    ///
71    /// * `stream` - A string slice that holds the stream data.
72    ///
73    /// # Returns
74    ///
75    /// A WebSocket struct.
76    ///
77    fn new(stream: Self::Stream) -> Self {
78        Self { stream }
79    }
80
81    /// Reads a message from the WebSocket.
82    ///
83    /// # Returns
84    ///
85    /// A Result containing an Option<Message> or a DeboaExtrasError.
86    ///
87    /// # Examples
88    ///
89    /// ```rust, compile_fail
90    /// while let Some(message) = websocket.read_message().await {
91    ///     println!("message: {}", message);
92    /// }
93    /// ```
94    ///
95    /// # Panics
96    ///
97    /// This function may panic if the WebSocket frame processing fails.
98    ///
99    async fn read_message(&mut self) -> Result<Option<Message>, DeboaExtrasError> {
100        let mut rx_buf = vec![0; 10240];
101        let mut rx_framer = WsRxFramer::new(&mut rx_buf);
102
103        let bytes_read = self
104            .stream
105            .read(rx_framer.mut_buf())
106            .await;
107        if bytes_read.is_err() {
108            return Err(DeboaExtrasError::WebSocket(WebSocketError::ReceiveMessage {
109                message: "Failed to read message".to_string(),
110            }));
111        }
112
113        let bytes_read = bytes_read.unwrap();
114        rx_framer.revolve_write_offset(bytes_read);
115        let res = rx_framer.process_data();
116        let message = if let Some(frame) = res {
117            #[allow(clippy::collapsible_match)]
118            match frame {
119                WsFrame::Text(data) => Some(Message::Text(data.to_string())),
120                WsFrame::Binary(data) => Some(Message::Binary(data.to_vec())),
121                WsFrame::Close(code, reason) => Some(Message::Close(code, reason.to_string())),
122                WsFrame::Ping(data) => Some(Message::Ping(data.to_vec())),
123                _ => None,
124            }
125        } else {
126            None
127        };
128
129        Ok(message)
130    }
131
132    /// Writes a message to the WebSocket.
133    ///
134    /// # Arguments
135    ///
136    /// * `message` - The message to write.
137    ///
138    /// # Returns
139    ///
140    /// A Result indicating success or a DeboaExtrasError.
141    ///
142    /// # Examples
143    ///
144    /// ```rust, compile_fail
145    /// let result = websocket
146    ///   .write_message(protocol::Message::Text(message.to_string()))
147    ///   .await;
148    /// if result.is_err() {
149    ///     output.send(Event::Disconnected).await;
150    ///     break;
151    /// }
152    /// ```
153    ///
154    /// # Panics
155    ///
156    /// This function may panic if the WebSocket frame processing fails.
157    ///
158    ///
159    async fn write_message(&mut self, message: Message) -> Result<(), DeboaExtrasError> {
160        let mut tx_buf = vec![0; 10240];
161        let mut tx_framer = WsTxFramer::new(true, &mut tx_buf);
162
163        let result = match message {
164            Message::Text(data) => {
165                self.write_all(tx_framer.frame(WsFrame::Text(&data)))
166                    .await
167            }
168            Message::Binary(data) => {
169                self.write_all(tx_framer.frame(WsFrame::Binary(&data)))
170                    .await
171            }
172            Message::Close(code, reason) => {
173                self.write_all(tx_framer.frame(WsFrame::Close(code, &reason)))
174                    .await
175            }
176            Message::Ping(data) => {
177                self.write_all(tx_framer.frame(WsFrame::Ping(&data)))
178                    .await
179            }
180            _ => Ok(()),
181        };
182
183        if result.is_err() {
184            return Err(DeboaExtrasError::WebSocket(WebSocketError::SendMessage {
185                message: "Failed to send frame".to_string(),
186            }));
187        }
188
189        Ok(())
190    }
191
192    /// Sends a close frame to the WebSocket.
193    ///
194    /// # Arguments
195    ///
196    /// * `code` - The close code.
197    /// * `reason` - The close reason.
198    ///
199    /// # Returns
200    ///
201    /// A Result indicating success or a DeboaExtrasError.
202    ///
203    /// # Examples
204    ///
205    /// ```rust, compile_fail
206    /// let result = websocket.send_close(1000, "Goodbye").await;
207    /// if result.is_err() {
208    ///     output.send(Event::Disconnected).await;
209    ///     break;
210    /// }
211    /// ```
212    ///
213    /// # Panics
214    ///
215    /// This function may panic if the WebSocket frame processing fails.
216    ///
217    async fn send_close(&mut self, code: u16, reason: &str) -> Result<(), DeboaExtrasError> {
218        self.write_message(Message::Close(code, reason.to_string()))
219            .await
220    }
221
222    /// Sends a text frame to the WebSocket.
223    ///
224    /// # Arguments
225    ///
226    /// * `message` - The text message to send.
227    ///
228    /// # Returns
229    ///
230    /// A Result indicating success or a DeboaExtrasError.
231    ///
232    /// # Examples
233    ///
234    /// ```rust, compile_fail
235    /// let result = websocket.send_text("Hello").await;
236    /// if result.is_err() {
237    ///     output.send(Event::Disconnected).await;
238    ///     break;
239    /// }
240    /// ```
241    ///
242    /// # Panics
243    ///
244    /// This function may panic if the WebSocket frame processing fails.
245    ///
246    async fn send_text(&mut self, message: &str) -> Result<(), DeboaExtrasError> {
247        self.write_message(Message::Text(message.to_string()))
248            .await
249    }
250
251    /// Sends a binary frame to the WebSocket.
252    ///
253    /// # Arguments
254    ///
255    /// * `message` - The binary message to send.
256    ///
257    /// # Returns
258    ///
259    /// A Result indicating success or a DeboaExtrasError.
260    ///
261    /// # Examples
262    ///
263    /// ```rust, compile_fail
264    /// let result = websocket.send_binary(&[0x00, 0x01, 0x02]).await;
265    /// if result.is_err() {
266    ///     output.send(Event::Disconnected).await;
267    ///     break;
268    /// }
269    /// ```
270    ///
271    /// # Panics
272    ///
273    /// This function may panic if the WebSocket frame processing fails.
274    ///
275    async fn send_binary(&mut self, message: &[u8]) -> Result<(), DeboaExtrasError> {
276        self.write_message(Message::Binary(message.to_vec()))
277            .await
278    }
279
280    /// Sends a ping frame to the WebSocket.
281    ///
282    /// # Arguments
283    ///
284    /// * `message` - The ping message to send.
285    ///
286    /// # Returns
287    ///
288    /// A Result indicating success or a DeboaExtrasError.
289    ///
290    /// # Examples
291    ///
292    /// ```rust, compile_fail
293    /// let result = websocket.send_ping(&[0x00, 0x01, 0x02]).await;
294    /// if result.is_err() {
295    ///     output.send(Event::Disconnected).await;
296    ///     break;
297    /// }
298    /// ```
299    ///
300    /// # Panics
301    ///
302    /// This function may panic if the WebSocket frame processing fails.
303    ///
304    async fn send_ping(&mut self, message: &[u8]) -> Result<(), DeboaExtrasError> {
305        self.write_message(Message::Ping(message.to_vec()))
306            .await
307    }
308
309    /// Sends a pong frame to the WebSocket.
310    ///
311    /// # Arguments
312    ///
313    /// * `message` - The pong message to send.
314    ///
315    /// # Returns
316    ///
317    /// A Result indicating success or a DeboaExtrasError.
318    ///
319    /// # Examples
320    ///
321    /// ```rust, compile_fail
322    /// let result = websocket.send_pong(&[0x00, 0x01, 0x02]).await;
323    /// if result.is_err() {
324    ///     output.send(Event::Disconnected).await;
325    ///     break;
326    /// }
327    /// ```
328    ///
329    /// # Panics
330    ///
331    /// This function may panic if the WebSocket frame processing fails.
332    ///
333    async fn send_pong(&mut self, message: &[u8]) -> Result<(), DeboaExtrasError> {
334        self.write_message(Message::Pong(message.to_vec()))
335            .await
336    }
337}
338
#[cfg(feature = "tokio-rt")]
impl AsyncRead for WebSocket<UpgradedIo> {
    /// Forwards reads directly to the wrapped upgraded stream, so a
    /// `WebSocket` can be used wherever a Tokio `AsyncRead` is expected.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let projected = self.project();
        AsyncRead::poll_read(projected.stream, cx, buf)
    }
}
351
#[cfg(feature = "tokio-rt")]
impl AsyncWrite for WebSocket<UpgradedIo> {
    /// Writes `buf` to the wrapped upgraded stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.project()
            .stream
            .poll_write(cx, buf)
    }

    /// Flushes the wrapped upgraded stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.project()
            .stream
            .poll_flush(cx)
    }

    /// Shuts down the write half of the wrapped upgraded stream.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.project()
            .stream
            .poll_shutdown(cx)
    }

    /// Forwards vectored writes to the wrapped stream.
    ///
    /// This agrees with `is_write_vectored`, which reports the stream's own
    /// capability. Previously only the first non-empty slice was written even
    /// when the stream advertised vectored support.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        self.project()
            .stream
            .poll_write_vectored(cx, bufs)
    }

    /// Reports whether the wrapped stream has an efficient vectored write.
    fn is_write_vectored(&self) -> bool {
        self.stream
            .is_write_vectored()
    }
}
401
#[cfg(feature = "smol-rt")]
impl<T> AsyncRead for WebSocket<FuturesIo<T>>
where
    T: hyper::rt::Read,
{
    /// Reads from the inner hyper stream into `buf`.
    ///
    /// Previously this always returned `Ok(0)`, so every read looked like end
    /// of stream. It now drives `hyper::rt::Read` on the inner stream (the
    /// same way the smol `AsyncWrite` impl drives `hyper::rt::Write`) and
    /// reports how many bytes were filled.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let inner = self
            .project()
            .stream
            .get_pin_mut();
        let mut read_buf = hyper::rt::ReadBuf::new(buf);
        let polled = hyper::rt::Read::poll_read(inner, cx, read_buf.unfilled());
        polled.map_ok(|()| read_buf.filled().len())
    }
}
415
#[cfg(feature = "smol-rt")]
impl<T> AsyncWrite for WebSocket<FuturesIo<T>>
where
    T: hyper::rt::Write,
{
    /// Writes `buf` through the inner stream's `hyper::rt::Write` impl.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let inner = self.project().stream.get_pin_mut();
        hyper::rt::Write::poll_write(inner, cx, buf)
    }

    /// Flushes the inner stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = self.project().stream.get_pin_mut();
        hyper::rt::Write::poll_flush(inner, cx)
    }

    /// Closes the inner stream; the futures-style `close` maps onto hyper's
    /// `poll_shutdown`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = self.project().stream.get_pin_mut();
        hyper::rt::Write::poll_shutdown(inner, cx)
    }

    /// Forwards vectored writes to the inner stream unchanged.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let inner = self.project().stream.get_pin_mut();
        hyper::rt::Write::poll_write_vectored(inner, cx, bufs)
    }
}