async_wsocket/native/
mod.rs1use std::net::SocketAddr;
7#[cfg(feature = "tor")]
8use std::path::PathBuf;
9use std::time::Duration;
10
11#[cfg(feature = "tor")]
12use arti_client::DataStream;
13use tokio::io::{AsyncRead, AsyncWrite};
14use tokio::net::TcpStream;
15use tokio::time;
16use tokio_tungstenite::tungstenite::protocol::Role;
17pub use tokio_tungstenite::tungstenite::Message;
18pub use tokio_tungstenite::WebSocketStream;
19use url::Url;
20
21mod error;
22#[cfg(feature = "socks")]
23mod socks;
24#[cfg(feature = "tor")]
25pub mod tor;
26
27pub use self::error::Error;
28#[cfg(feature = "socks")]
29use self::socks::TcpSocks5Stream;
30use crate::socket::WebSocket;
31use crate::ConnectionMode;
32
33pub async fn connect(
34 url: &Url,
35 mode: &ConnectionMode,
36 timeout: Duration,
37) -> Result<WebSocket, Error> {
38 match mode {
39 ConnectionMode::Direct => connect_direct(url, timeout).await,
40 #[cfg(feature = "socks")]
41 ConnectionMode::Proxy(proxy) => connect_proxy(url, *proxy, timeout).await,
42 #[cfg(feature = "tor")]
43 ConnectionMode::Tor { custom_path } => {
44 connect_tor(url, timeout, custom_path.as_ref()).await
45 }
46 }
47}
48
/// Head start given to the first IPv6 connection attempt before a competing IPv4
/// attempt is started in `happy_eyeballs_connect` (250 ms).
const HAPPY_EYEBALLS_DELAY: Duration = Duration::from_micros(250_000);
51
52async fn connect_direct(url: &Url, timeout: Duration) -> Result<WebSocket, Error> {
53 let host: &str = url.host_str().ok_or_else(Error::empty_host)?;
54 let port: u16 = url
55 .port_or_known_default()
56 .ok_or_else(Error::invalid_port)?;
57
58 let conn_fut = async {
59 let tcp_stream = happy_eyeballs_connect(host, port).await?;
60
61 let (stream, _) = Box::pin(tokio_tungstenite::client_async_tls(
64 url.as_str(),
65 tcp_stream,
66 ))
67 .await?;
68
69 Ok::<_, Error>(stream)
70 };
71
72 let stream = time::timeout(timeout, conn_fut)
73 .await
74 .map_err(|_| Error::Timeout)??;
75
76 Ok(WebSocket::Tokio(stream))
77}
78
79async fn happy_eyeballs_connect(host: &str, port: u16) -> Result<TcpStream, Error> {
85 let addrs: Vec<SocketAddr> = tokio::net::lookup_host((host, port)).await?.collect();
86
87 if addrs.is_empty() {
88 return Err(std::io::Error::new(
89 std::io::ErrorKind::AddrNotAvailable,
90 "DNS resolution returned no addresses",
91 )
92 .into());
93 }
94
95 let mut ipv6: Vec<SocketAddr> = Vec::new();
97 let mut ipv4: Vec<SocketAddr> = Vec::new();
98 for addr in addrs {
99 if addr.is_ipv6() {
100 ipv6.push(addr);
101 } else {
102 ipv4.push(addr);
103 }
104 }
105
106 if ipv4.is_empty() {
108 return try_addrs_sequential(&ipv6).await;
109 }
110 if ipv6.is_empty() {
111 return try_addrs_sequential(&ipv4).await;
112 }
113
114 let ipv6_first = ipv6[0];
117 let ipv4_first = ipv4[0];
118
119 let ipv6_fut = TcpStream::connect(ipv6_first);
121 tokio::pin!(ipv6_fut);
122
123 tokio::select! {
125 result = &mut ipv6_fut => {
126 return match result {
127 Ok(stream) => Ok(stream),
128 Err(_) => try_addrs_sequential(&ipv4).await,
130 }
131 }
132 _ = tokio::time::sleep(HAPPY_EYEBALLS_DELAY) => {
133 }
135 }
136
137 let ipv4_fut = TcpStream::connect(ipv4_first);
140 tokio::pin!(ipv4_fut);
141
142 let mut ipv6_done = false;
143 let mut ipv4_done = false;
144
145 loop {
146 tokio::select! {
147 result = &mut ipv6_fut, if !ipv6_done => {
148 match result {
149 Ok(stream) => return Ok(stream),
150 Err(_) => { ipv6_done = true; }
151 }
152 }
153 result = &mut ipv4_fut, if !ipv4_done => {
154 match result {
155 Ok(stream) => return Ok(stream),
156 Err(_) => { ipv4_done = true; }
157 }
158 }
159 }
160 if ipv6_done && ipv4_done {
161 break;
162 }
163 }
164
165 let mut remaining = Vec::new();
168 let ipv4_remaining = ipv4.iter().skip(1);
169 let ipv6_remaining = ipv6.iter().skip(1);
170
171 let mut ipv4_iter = ipv4_remaining.peekable();
172 let mut ipv6_iter = ipv6_remaining.peekable();
173
174 while ipv4_iter.peek().is_some() || ipv6_iter.peek().is_some() {
176 if let Some(addr) = ipv4_iter.next() {
177 remaining.push(*addr);
178 }
179 if let Some(addr) = ipv6_iter.next() {
180 remaining.push(*addr);
181 }
182 }
183
184 try_addrs_sequential(&remaining).await
185}
186
187async fn try_addrs_sequential(addrs: &[SocketAddr]) -> Result<TcpStream, Error> {
189 let mut last_err = None;
190 for addr in addrs {
191 match TcpStream::connect(addr).await {
192 Ok(stream) => return Ok(stream),
193 Err(e) => last_err = Some(e),
194 }
195 }
196 Err(last_err
197 .unwrap_or_else(|| {
198 std::io::Error::new(
199 std::io::ErrorKind::AddrNotAvailable,
200 "no addresses to connect to",
201 )
202 })
203 .into())
204}
205
/// Connects to the WebSocket endpoint at `url` through the SOCKS5 proxy at `proxy`.
///
/// The target is passed to the proxy as a `"host:port"` string taken from `url`. The
/// WebSocket handshake then runs over the proxied stream via
/// `tokio_tungstenite::client_async_tls`. `timeout` bounds the whole sequence, including
/// the SOCKS5 connection, consistent with the direct connection path. This prevents an
/// unresponsive proxy from hanging the caller indefinitely.
///
/// # Errors
///
/// - The error from `Error::empty_host` if `url` has no host.
/// - The error from `Error::invalid_port` if `url` has no port and its scheme has no
///   known default.
/// - `Error::Timeout` if the proxy connection plus handshake exceeds `timeout`.
/// - Any proxy, TLS or handshake error otherwise.
#[cfg(feature = "socks")]
async fn connect_proxy(
    url: &Url,
    proxy: SocketAddr,
    timeout: Duration,
) -> Result<WebSocket, Error> {
    let host: &str = url.host_str().ok_or_else(Error::empty_host)?;
    let port: u16 = url
        .port_or_known_default()
        .ok_or_else(Error::invalid_port)?;
    // `host_str` keeps IPv6 literals bracketed, which is the form "host:port" needs here.
    let addr: String = format!("{host}:{port}");

    let conn_fut = async {
        let conn: TcpStream = TcpSocks5Stream::connect(proxy, addr).await?;
        let (stream, _) =
            Box::pin(tokio_tungstenite::client_async_tls(url.as_str(), conn)).await?;
        Ok::<_, Error>(stream)
    };

    let stream = time::timeout(timeout, conn_fut)
        .await
        .map_err(|_| Error::Timeout)??;

    Ok(WebSocket::Tokio(stream))
}
229
/// Connects to the WebSocket endpoint at `url` through Tor.
///
/// Opens a Tor data stream to the URL's host and port with `tor::connect`, passing
/// `custom_path` through. The WebSocket handshake then runs over that stream via
/// `tokio_tungstenite::client_async_tls`.
///
/// NOTE(review): `timeout` only bounds the handshake, not `tor::connect` itself. Confirm
/// whether this is intentional (e.g. to leave room for Tor bootstrap).
///
/// # Errors
///
/// - The error from `Error::empty_host` if `url` has no host.
/// - The error from `Error::invalid_port` if `url` has no port and its scheme has no
///   known default.
/// - `Error::Timeout` if the handshake exceeds `timeout`.
/// - Any Tor, TLS or handshake error otherwise.
#[cfg(feature = "tor")]
async fn connect_tor(
    url: &Url,
    timeout: Duration,
    custom_path: Option<&PathBuf>,
) -> Result<WebSocket, Error> {
    let host: &str = url.host_str().ok_or_else(Error::empty_host)?;
    let port: u16 = url.port_or_known_default().ok_or_else(Error::invalid_port)?;

    let tor_stream: DataStream = tor::connect(host, port, custom_path).await?;

    let handshake = time::timeout(
        timeout,
        tokio_tungstenite::client_async_tls(url.as_str(), tor_stream),
    );
    let (ws_stream, _) = Box::pin(handshake)
        .await
        .map_err(|_| Error::Timeout)??;

    Ok(WebSocket::Tor(ws_stream))
}
252
253#[inline]
254pub async fn accept<S>(raw_stream: S) -> Result<WebSocketStream<S>, Error>
255where
256 S: AsyncRead + AsyncWrite + Unpin,
257{
258 Ok(tokio_tungstenite::accept_async(raw_stream).await?)
259}
260
261#[inline]
265pub async fn take_upgraded<S>(raw_stream: S) -> WebSocketStream<S>
266where
267 S: AsyncRead + AsyncWrite + Unpin,
268{
269 WebSocketStream::from_raw_socket(raw_stream, Role::Server, None).await
270}