async_wsocket/
socket.rs

1// Copyright (c) 2022-2024 Yuki Kishimoto
2// Distributed under the MIT software license
3
4use std::ops::DerefMut;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::time::Duration;
8
9#[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
10use arti_client::DataStream;
11use futures_util::{Sink, Stream};
12#[cfg(not(target_arch = "wasm32"))]
13use tokio::net::TcpStream;
14#[cfg(not(target_arch = "wasm32"))]
15use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
16use url::Url;
17
18#[cfg(target_arch = "wasm32")]
19use crate::wasm::WsStream;
20use crate::{ConnectionMode, Error, Message};
21
/// Native (non-wasm32) stream alias: a `tokio_tungstenite` [`WebSocketStream`]
/// wrapping a [`MaybeTlsStream`] over the transport `T`.
///
/// On wasm32 the name `WsStream` is instead imported from `crate::wasm`.
#[cfg(not(target_arch = "wasm32"))]
type WsStream<T> = WebSocketStream<MaybeTlsStream<T>>;
24
/// A WebSocket connection whose concrete backend is selected at compile time.
///
/// Which variants exist depends on the target architecture and on the `tor`
/// cargo feature (see the `cfg` attribute on each variant).
pub enum WebSocket {
    /// Stream over a tokio [`TcpStream`]; available on every non-wasm32 target.
    #[cfg(not(target_arch = "wasm32"))]
    Tokio(WsStream<TcpStream>),
    /// Stream over an arti [`DataStream`]; requires the `tor` feature on a
    /// non-wasm32 target.
    #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
    Tor(WsStream<DataStream>),
    /// Stream type provided by `crate::wasm`; available only on wasm32.
    #[cfg(target_arch = "wasm32")]
    Wasm(WsStream),
}
33
34impl WebSocket {
35    pub async fn connect(
36        url: &Url,
37        _mode: &ConnectionMode,
38        timeout: Duration,
39    ) -> Result<Self, Error> {
40        #[cfg(not(target_arch = "wasm32"))]
41        let socket: WebSocket = crate::native::connect(url, _mode, timeout).await?;
42
43        #[cfg(target_arch = "wasm32")]
44        let socket: WebSocket = crate::wasm::connect(url, timeout).await?;
45
46        Ok(socket)
47    }
48}
49
50impl Sink<Message> for WebSocket {
51    type Error = Error;
52
53    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
54        match self.deref_mut() {
55            #[cfg(not(target_arch = "wasm32"))]
56            Self::Tokio(s) => Pin::new(s).poll_ready(cx).map_err(Into::into),
57            #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
58            Self::Tor(s) => Pin::new(s).poll_ready(cx).map_err(Into::into),
59            #[cfg(target_arch = "wasm32")]
60            Self::Wasm(s) => Pin::new(s).poll_ready(cx),
61        }
62    }
63
64    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
65        match self.deref_mut() {
66            #[cfg(not(target_arch = "wasm32"))]
67            Self::Tokio(s) => Pin::new(s).start_send(item.into()).map_err(Into::into),
68            #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
69            Self::Tor(s) => Pin::new(s).start_send(item.into()).map_err(Into::into),
70            #[cfg(target_arch = "wasm32")]
71            Self::Wasm(s) => Pin::new(s).start_send(item),
72        }
73    }
74
75    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76        match self.deref_mut() {
77            #[cfg(not(target_arch = "wasm32"))]
78            Self::Tokio(s) => Pin::new(s).poll_flush(cx).map_err(Into::into),
79            #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
80            Self::Tor(s) => Pin::new(s).poll_flush(cx).map_err(Into::into),
81            #[cfg(target_arch = "wasm32")]
82            Self::Wasm(s) => Pin::new(s).poll_flush(cx),
83        }
84    }
85
86    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87        match self.deref_mut() {
88            #[cfg(not(target_arch = "wasm32"))]
89            Self::Tokio(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
90            #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
91            Self::Tor(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
92            #[cfg(target_arch = "wasm32")]
93            Self::Wasm(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
94        }
95    }
96}
97
98impl Stream for WebSocket {
99    type Item = Result<Message, Error>;
100
101    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
102        match self.deref_mut() {
103            #[cfg(not(target_arch = "wasm32"))]
104            Self::Tokio(s) => Pin::new(s)
105                .poll_next(cx)
106                .map(|i| i.map(|res| res.map(Message::from_native)))
107                .map_err(Into::into),
108            #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
109            Self::Tor(s) => Pin::new(s)
110                .poll_next(cx)
111                .map(|i| i.map(|res| res.map(Message::from_native)))
112                .map_err(Into::into),
113            #[cfg(target_arch = "wasm32")]
114            Self::Wasm(s) => Pin::new(s).poll_next(cx).map_err(Into::into),
115        }
116    }
117
118    fn size_hint(&self) -> (usize, Option<usize>) {
119        match self {
120            #[cfg(not(target_arch = "wasm32"))]
121            Self::Tokio(s) => s.size_hint(),
122            #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
123            Self::Tor(s) => s.size_hint(),
124            #[cfg(target_arch = "wasm32")]
125            Self::Wasm(s) => s.size_hint(),
126        }
127    }
128}