1use std::ops::DerefMut;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::time::Duration;
8
9#[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
10use arti_client::DataStream;
11use futures_util::{Sink, Stream};
12#[cfg(not(target_arch = "wasm32"))]
13use tokio::net::TcpStream;
14#[cfg(not(target_arch = "wasm32"))]
15use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
16use url::Url;
17
18#[cfg(target_arch = "wasm32")]
19use crate::wasm::WsStream;
20use crate::{ConnectionMode, Error, Message};
21
/// Native WebSocket stream carried over the transport `T`, wrapped in
/// `MaybeTlsStream` so the same alias covers plain and TLS connections.
/// On wasm32 this alias is not defined; `crate::wasm::WsStream` is imported instead.
#[cfg(not(target_arch = "wasm32"))]
type WsStream<T> = WebSocketStream<MaybeTlsStream<T>>;
24
/// A WebSocket connection backed by whichever transport the current build supports.
///
/// The set of available variants is fixed at compile time by `cfg` attributes:
/// native targets get `Tokio` (and `Tor` with the `tor` feature), wasm32 gets `Wasm`.
pub enum WebSocket {
    /// Native connection over a `TcpStream` (wrapped in `MaybeTlsStream`).
    #[cfg(not(target_arch = "wasm32"))]
    Tokio(WsStream<TcpStream>),
    /// Native connection over an `arti_client::DataStream`; requires the `tor` feature.
    #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
    Tor(WsStream<DataStream>),
    /// Connection using the crate's wasm32 `WsStream` implementation.
    #[cfg(target_arch = "wasm32")]
    Wasm(WsStream),
}
33
34impl WebSocket {
35 pub async fn connect(
36 url: &Url,
37 _mode: &ConnectionMode,
38 timeout: Duration,
39 ) -> Result<Self, Error> {
40 #[cfg(not(target_arch = "wasm32"))]
41 let socket: WebSocket = crate::native::connect(url, _mode, timeout).await?;
42
43 #[cfg(target_arch = "wasm32")]
44 let socket: WebSocket = crate::wasm::connect(url, timeout).await?;
45
46 Ok(socket)
47 }
48}
49
50impl Sink<Message> for WebSocket {
51 type Error = Error;
52
53 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
54 match self.deref_mut() {
55 #[cfg(not(target_arch = "wasm32"))]
56 Self::Tokio(s) => Pin::new(s).poll_ready(cx).map_err(Into::into),
57 #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
58 Self::Tor(s) => Pin::new(s).poll_ready(cx).map_err(Into::into),
59 #[cfg(target_arch = "wasm32")]
60 Self::Wasm(s) => Pin::new(s).poll_ready(cx),
61 }
62 }
63
64 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
65 match self.deref_mut() {
66 #[cfg(not(target_arch = "wasm32"))]
67 Self::Tokio(s) => Pin::new(s).start_send(item.into()).map_err(Into::into),
68 #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
69 Self::Tor(s) => Pin::new(s).start_send(item.into()).map_err(Into::into),
70 #[cfg(target_arch = "wasm32")]
71 Self::Wasm(s) => Pin::new(s).start_send(item),
72 }
73 }
74
75 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76 match self.deref_mut() {
77 #[cfg(not(target_arch = "wasm32"))]
78 Self::Tokio(s) => Pin::new(s).poll_flush(cx).map_err(Into::into),
79 #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
80 Self::Tor(s) => Pin::new(s).poll_flush(cx).map_err(Into::into),
81 #[cfg(target_arch = "wasm32")]
82 Self::Wasm(s) => Pin::new(s).poll_flush(cx),
83 }
84 }
85
86 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87 match self.deref_mut() {
88 #[cfg(not(target_arch = "wasm32"))]
89 Self::Tokio(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
90 #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
91 Self::Tor(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
92 #[cfg(target_arch = "wasm32")]
93 Self::Wasm(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
94 }
95 }
96}
97
98impl Stream for WebSocket {
99 type Item = Result<Message, Error>;
100
101 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
102 match self.deref_mut() {
103 #[cfg(not(target_arch = "wasm32"))]
104 Self::Tokio(s) => Pin::new(s)
105 .poll_next(cx)
106 .map(|i| i.map(|res| res.map(Message::from_native)))
107 .map_err(Into::into),
108 #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
109 Self::Tor(s) => Pin::new(s)
110 .poll_next(cx)
111 .map(|i| i.map(|res| res.map(Message::from_native)))
112 .map_err(Into::into),
113 #[cfg(target_arch = "wasm32")]
114 Self::Wasm(s) => Pin::new(s).poll_next(cx).map_err(Into::into),
115 }
116 }
117
118 fn size_hint(&self) -> (usize, Option<usize>) {
119 match self {
120 #[cfg(not(target_arch = "wasm32"))]
121 Self::Tokio(s) => s.size_hint(),
122 #[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
123 Self::Tor(s) => s.size_hint(),
124 #[cfg(target_arch = "wasm32")]
125 Self::Wasm(s) => s.size_hint(),
126 }
127 }
128}