Skip to main content

async_wsocket/
socket.rs

1// Copyright (c) 2022-2024 Yuki Kishimoto
2// Distributed under the MIT software license
3
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use futures_util::{Sink, Stream};
8#[cfg(not(target_arch = "wasm32"))]
9use tokio::net::TcpStream;
10#[cfg(not(target_arch = "wasm32"))]
11use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
12use url::Url;
13
14#[cfg(target_arch = "wasm32")]
15use crate::wasm::WsStream;
16use crate::{ConnectionMode, Error, Message};
17
/// Native WebSocket stream type: a tokio-tungstenite `WebSocketStream` running over a
/// `MaybeTlsStream<T>` transport (the transport may or may not be TLS-wrapped).
///
/// Only compiled for non-`wasm32` targets; on `wasm32` the name `WsStream` instead refers to
/// `crate::wasm::WsStream`.
#[cfg(not(target_arch = "wasm32"))]
type WsStream<T> = WebSocketStream<MaybeTlsStream<T>>;
20
/// Platform-specific backing stream of a [`WebSocket`].
///
/// Exactly one variant exists per compilation target, selected by `target_arch`.
enum InnerWebSocket {
    /// Native target: a tokio-tungstenite stream over a `TcpStream`.
    /// Boxed, presumably to keep the enum small — TODO confirm the motivation.
    #[cfg(not(target_arch = "wasm32"))]
    Tokio(Box<WsStream<TcpStream>>),
    /// `wasm32` target: the crate's own wasm stream wrapper.
    #[cfg(target_arch = "wasm32")]
    Wasm(WsStream),
}
27
/// Cross-platform WebSocket handle.
///
/// Sends messages through its [`Sink<Message>`](futures_util::Sink) implementation and
/// yields `Result<Message, Error>` items through its [`Stream`](futures_util::Stream)
/// implementation, delegating to the native or wasm backend chosen at compile time.
pub struct WebSocket {
    /// The platform-specific stream that all I/O is forwarded to.
    inner: InnerWebSocket,
}
31
32impl WebSocket {
33    #[inline]
34    fn new(inner: InnerWebSocket) -> Self {
35        Self { inner }
36    }
37
38    #[inline]
39    #[cfg(not(target_arch = "wasm32"))]
40    pub(crate) fn tokio(inner: Box<WsStream<TcpStream>>) -> Self {
41        Self::new(InnerWebSocket::Tokio(inner))
42    }
43    #[inline]
44    #[cfg(target_arch = "wasm32")]
45    pub(crate) fn wasm(inner: WsStream) -> Self {
46        Self::new(InnerWebSocket::Wasm(inner))
47    }
48
49    pub async fn connect(url: &Url, _mode: &ConnectionMode) -> Result<Self, Error> {
50        #[cfg(not(target_arch = "wasm32"))]
51        let socket: WebSocket = crate::native::connect(url, _mode).await?;
52
53        #[cfg(target_arch = "wasm32")]
54        let socket: WebSocket = crate::wasm::connect(url).await?;
55
56        Ok(socket)
57    }
58}
59
60impl Sink<Message> for WebSocket {
61    type Error = Error;
62
63    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
64        match &mut self.inner {
65            #[cfg(not(target_arch = "wasm32"))]
66            InnerWebSocket::Tokio(s) => Pin::new(s.as_mut()).poll_ready(cx).map_err(Into::into),
67            #[cfg(target_arch = "wasm32")]
68            InnerWebSocket::Wasm(s) => Pin::new(s).poll_ready(cx),
69        }
70    }
71
72    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
73        match &mut self.inner {
74            #[cfg(not(target_arch = "wasm32"))]
75            InnerWebSocket::Tokio(s) => Pin::new(s.as_mut())
76                .start_send(item.into())
77                .map_err(Into::into),
78            #[cfg(target_arch = "wasm32")]
79            InnerWebSocket::Wasm(s) => Pin::new(s).start_send(item),
80        }
81    }
82
83    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84        match &mut self.inner {
85            #[cfg(not(target_arch = "wasm32"))]
86            InnerWebSocket::Tokio(s) => Pin::new(s.as_mut()).poll_flush(cx).map_err(Into::into),
87            #[cfg(target_arch = "wasm32")]
88            InnerWebSocket::Wasm(s) => Pin::new(s).poll_flush(cx),
89        }
90    }
91
92    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
93        match &mut self.inner {
94            #[cfg(not(target_arch = "wasm32"))]
95            InnerWebSocket::Tokio(s) => Pin::new(s.as_mut()).poll_close(cx).map_err(Into::into),
96            #[cfg(target_arch = "wasm32")]
97            InnerWebSocket::Wasm(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
98        }
99    }
100}
101
102impl Stream for WebSocket {
103    type Item = Result<Message, Error>;
104
105    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106        match &mut self.inner {
107            #[cfg(not(target_arch = "wasm32"))]
108            InnerWebSocket::Tokio(s) => Pin::new(s)
109                .poll_next(cx)
110                .map(|i| i.map(|res| res.map(Message::from_native)))
111                .map_err(Into::into),
112            #[cfg(target_arch = "wasm32")]
113            InnerWebSocket::Wasm(s) => Pin::new(s).poll_next(cx).map_err(Into::into),
114        }
115    }
116
117    fn size_hint(&self) -> (usize, Option<usize>) {
118        match &self.inner {
119            #[cfg(not(target_arch = "wasm32"))]
120            InnerWebSocket::Tokio(s) => s.size_hint(),
121            #[cfg(target_arch = "wasm32")]
122            InnerWebSocket::Wasm(s) => s.size_hint(),
123        }
124    }
125}