1use std::ops::DerefMut;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::time::Duration;
8
9#[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
10use arti_client::DataStream;
11use futures_util::{Sink, Stream};
12#[cfg(not(target_arch = "wasm32"))]
13use tokio::net::TcpStream;
14#[cfg(not(target_arch = "wasm32"))]
15use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
16use url::Url;
17
18#[cfg(target_arch = "wasm32")]
19use crate::wasm::WsStream;
20use crate::{ConnectionMode, Error, Message};
21
/// Native (non-wasm) WebSocket stream: a `tokio_tungstenite` `WebSocketStream` layered over a
/// `MaybeTlsStream` that wraps the underlying transport `T`.
#[cfg(not(target_arch = "wasm32"))]
type WsStream<T> = WebSocketStream<MaybeTlsStream<T>>;
24
/// A WebSocket connection over one of the supported transports.
///
/// Which variants exist depends on the compilation target and on the `tor` feature.
// NOTE(review): the size lint is silenced instead of boxing a variant; presumably the size
// difference between the transports is considered acceptable here — confirm.
#[allow(clippy::large_enum_variant)]
pub enum WebSocket {
    /// Native stream over a tokio `TcpStream` (non-wasm targets only).
    #[cfg(not(target_arch = "wasm32"))]
    Tokio(WsStream<TcpStream>),
    /// Native stream over an `arti_client` `DataStream` (`tor` feature, non-wasm targets only).
    #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
    Tor(WsStream<DataStream>),
    /// Stream type provided by `crate::wasm` (wasm32 targets only).
    #[cfg(target_arch = "wasm32")]
    Wasm(WsStream),
}
34
35impl WebSocket {
36 pub async fn connect(
37 url: &Url,
38 _mode: &ConnectionMode,
39 timeout: Duration,
40 ) -> Result<Self, Error> {
41 #[cfg(not(target_arch = "wasm32"))]
42 let socket: WebSocket = crate::native::connect(url, _mode, timeout).await?;
43
44 #[cfg(target_arch = "wasm32")]
45 let socket: WebSocket = crate::wasm::connect(url, timeout).await?;
46
47 Ok(socket)
48 }
49}
50
51impl Sink<Message> for WebSocket {
52 type Error = Error;
53
54 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
55 match self.deref_mut() {
56 #[cfg(not(target_arch = "wasm32"))]
57 Self::Tokio(s) => Pin::new(s).poll_ready(cx).map_err(Into::into),
58 #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
59 Self::Tor(s) => Pin::new(s).poll_ready(cx).map_err(Into::into),
60 #[cfg(target_arch = "wasm32")]
61 Self::Wasm(s) => Pin::new(s).poll_ready(cx),
62 }
63 }
64
65 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
66 match self.deref_mut() {
67 #[cfg(not(target_arch = "wasm32"))]
68 Self::Tokio(s) => Pin::new(s).start_send(item.into()).map_err(Into::into),
69 #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
70 Self::Tor(s) => Pin::new(s).start_send(item.into()).map_err(Into::into),
71 #[cfg(target_arch = "wasm32")]
72 Self::Wasm(s) => Pin::new(s).start_send(item),
73 }
74 }
75
76 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
77 match self.deref_mut() {
78 #[cfg(not(target_arch = "wasm32"))]
79 Self::Tokio(s) => Pin::new(s).poll_flush(cx).map_err(Into::into),
80 #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
81 Self::Tor(s) => Pin::new(s).poll_flush(cx).map_err(Into::into),
82 #[cfg(target_arch = "wasm32")]
83 Self::Wasm(s) => Pin::new(s).poll_flush(cx),
84 }
85 }
86
87 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
88 match self.deref_mut() {
89 #[cfg(not(target_arch = "wasm32"))]
90 Self::Tokio(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
91 #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
92 Self::Tor(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
93 #[cfg(target_arch = "wasm32")]
94 Self::Wasm(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
95 }
96 }
97}
98
99impl Stream for WebSocket {
100 type Item = Result<Message, Error>;
101
102 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
103 match self.deref_mut() {
104 #[cfg(not(target_arch = "wasm32"))]
105 Self::Tokio(s) => Pin::new(s)
106 .poll_next(cx)
107 .map(|i| i.map(|res| res.map(Message::from_native)))
108 .map_err(Into::into),
109 #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
110 Self::Tor(s) => Pin::new(s)
111 .poll_next(cx)
112 .map(|i| i.map(|res| res.map(Message::from_native)))
113 .map_err(Into::into),
114 #[cfg(target_arch = "wasm32")]
115 Self::Wasm(s) => Pin::new(s).poll_next(cx).map_err(Into::into),
116 }
117 }
118
119 fn size_hint(&self) -> (usize, Option<usize>) {
120 match self {
121 #[cfg(not(target_arch = "wasm32"))]
122 Self::Tokio(s) => s.size_hint(),
123 #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
124 Self::Tor(s) => s.size_hint(),
125 #[cfg(target_arch = "wasm32")]
126 Self::Wasm(s) => s.size_hint(),
127 }
128 }
129}