Skip to main content

async_wsocket/
socket.rs

1// Copyright (c) 2022-2024 Yuki Kishimoto
2// Distributed under the MIT software license
3
4use std::ops::DerefMut;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::time::Duration;
8
9#[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
10use arti_client::DataStream;
11use futures_util::{Sink, Stream};
12#[cfg(not(target_arch = "wasm32"))]
13use tokio::net::TcpStream;
14#[cfg(not(target_arch = "wasm32"))]
15use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
16use url::Url;
17
18#[cfg(target_arch = "wasm32")]
19use crate::wasm::WsStream;
20use crate::{ConnectionMode, Error, Message};
21
/// Native (non-wasm32) stream type: a `WebSocketStream` layered over a `MaybeTlsStream`
/// wrapping the transport `T`.
///
/// On `wasm32` this alias is not compiled; `WsStream` is imported from `crate::wasm` instead.
#[cfg(not(target_arch = "wasm32"))]
type WsStream<T> = WebSocketStream<MaybeTlsStream<T>>;
24
/// A WebSocket connection with one variant per transport backend.
///
/// The set of variants depends on the compilation target and on the `tor` feature.
// NOTE(review): the lint is silenced presumably because the wrapped stream types differ
// substantially in size — confirm before boxing any variant.
#[allow(clippy::large_enum_variant)]
pub enum WebSocket {
    /// Stream over a `TcpStream` (every target except `wasm32`).
    #[cfg(not(target_arch = "wasm32"))]
    Tokio(WsStream<TcpStream>),
    /// Stream over an `arti_client` `DataStream` (`tor` feature, non-`wasm32` only).
    #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
    Tor(WsStream<DataStream>),
    /// Stream type supplied by `crate::wasm` (`wasm32` only).
    #[cfg(target_arch = "wasm32")]
    Wasm(WsStream),
}
34
35impl WebSocket {
36    pub async fn connect(
37        url: &Url,
38        _mode: &ConnectionMode,
39        timeout: Duration,
40    ) -> Result<Self, Error> {
41        #[cfg(not(target_arch = "wasm32"))]
42        let socket: WebSocket = crate::native::connect(url, _mode, timeout).await?;
43
44        #[cfg(target_arch = "wasm32")]
45        let socket: WebSocket = crate::wasm::connect(url, timeout).await?;
46
47        Ok(socket)
48    }
49}
50
51impl Sink<Message> for WebSocket {
52    type Error = Error;
53
54    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
55        match self.deref_mut() {
56            #[cfg(not(target_arch = "wasm32"))]
57            Self::Tokio(s) => Pin::new(s).poll_ready(cx).map_err(Into::into),
58            #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
59            Self::Tor(s) => Pin::new(s).poll_ready(cx).map_err(Into::into),
60            #[cfg(target_arch = "wasm32")]
61            Self::Wasm(s) => Pin::new(s).poll_ready(cx),
62        }
63    }
64
65    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
66        match self.deref_mut() {
67            #[cfg(not(target_arch = "wasm32"))]
68            Self::Tokio(s) => Pin::new(s).start_send(item.into()).map_err(Into::into),
69            #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
70            Self::Tor(s) => Pin::new(s).start_send(item.into()).map_err(Into::into),
71            #[cfg(target_arch = "wasm32")]
72            Self::Wasm(s) => Pin::new(s).start_send(item),
73        }
74    }
75
76    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
77        match self.deref_mut() {
78            #[cfg(not(target_arch = "wasm32"))]
79            Self::Tokio(s) => Pin::new(s).poll_flush(cx).map_err(Into::into),
80            #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
81            Self::Tor(s) => Pin::new(s).poll_flush(cx).map_err(Into::into),
82            #[cfg(target_arch = "wasm32")]
83            Self::Wasm(s) => Pin::new(s).poll_flush(cx),
84        }
85    }
86
87    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
88        match self.deref_mut() {
89            #[cfg(not(target_arch = "wasm32"))]
90            Self::Tokio(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
91            #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
92            Self::Tor(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
93            #[cfg(target_arch = "wasm32")]
94            Self::Wasm(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
95        }
96    }
97}
98
99impl Stream for WebSocket {
100    type Item = Result<Message, Error>;
101
102    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
103        match self.deref_mut() {
104            #[cfg(not(target_arch = "wasm32"))]
105            Self::Tokio(s) => Pin::new(s)
106                .poll_next(cx)
107                .map(|i| i.map(|res| res.map(Message::from_native)))
108                .map_err(Into::into),
109            #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
110            Self::Tor(s) => Pin::new(s)
111                .poll_next(cx)
112                .map(|i| i.map(|res| res.map(Message::from_native)))
113                .map_err(Into::into),
114            #[cfg(target_arch = "wasm32")]
115            Self::Wasm(s) => Pin::new(s).poll_next(cx).map_err(Into::into),
116        }
117    }
118
119    fn size_hint(&self) -> (usize, Option<usize>) {
120        match self {
121            #[cfg(not(target_arch = "wasm32"))]
122            Self::Tokio(s) => s.size_hint(),
123            #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
124            Self::Tor(s) => s.size_hint(),
125            #[cfg(target_arch = "wasm32")]
126            Self::Wasm(s) => s.size_hint(),
127        }
128    }
129}