Skip to main content

async_wsocket/native/
mod.rs

1// Copyright (c) 2022-2024 Yuki Kishimoto
2// Distributed under the MIT software license
3
4//! Native
5
6use std::net::SocketAddr;
7#[cfg(feature = "tor")]
8use std::path::PathBuf;
9use std::time::Duration;
10
11#[cfg(feature = "tor")]
12use arti_client::DataStream;
13use tokio::io::{AsyncRead, AsyncWrite};
14use tokio::net::TcpStream;
15use tokio::time;
16use tokio_tungstenite::tungstenite::protocol::Role;
17pub use tokio_tungstenite::tungstenite::Message;
18pub use tokio_tungstenite::WebSocketStream;
19use url::Url;
20
21mod error;
22#[cfg(feature = "socks")]
23mod socks;
24#[cfg(feature = "tor")]
25pub mod tor;
26
27pub use self::error::Error;
28#[cfg(feature = "socks")]
29use self::socks::TcpSocks5Stream;
30use crate::socket::WebSocket;
31use crate::ConnectionMode;
32
33pub async fn connect(
34    url: &Url,
35    mode: &ConnectionMode,
36    timeout: Duration,
37) -> Result<WebSocket, Error> {
38    match mode {
39        ConnectionMode::Direct => connect_direct(url, timeout).await,
40        #[cfg(feature = "socks")]
41        ConnectionMode::Proxy(proxy) => connect_proxy(url, *proxy, timeout).await,
42        #[cfg(feature = "tor")]
43        ConnectionMode::Tor { custom_path } => {
44            connect_tor(url, timeout, custom_path.as_ref()).await
45        }
46    }
47}
48
49/// Happy Eyeballs connection delay (RFC 8305).
50const HAPPY_EYEBALLS_DELAY: Duration = Duration::from_millis(250);
51
52async fn connect_direct(url: &Url, timeout: Duration) -> Result<WebSocket, Error> {
53    let host: &str = url.host_str().ok_or_else(Error::empty_host)?;
54    let port: u16 = url
55        .port_or_known_default()
56        .ok_or_else(Error::invalid_port)?;
57
58    let conn_fut = async {
59        let tcp_stream = happy_eyeballs_connect(host, port).await?;
60
61        // NOT REMOVE `Box::pin`!
62        // Use `Box::pin` to fix stack overflow on windows targets due to large `Future`
63        let (stream, _) = Box::pin(tokio_tungstenite::client_async_tls(
64            url.as_str(),
65            tcp_stream,
66        ))
67        .await?;
68
69        Ok::<_, Error>(stream)
70    };
71
72    let stream = time::timeout(timeout, conn_fut)
73        .await
74        .map_err(|_| Error::Timeout)??;
75
76    Ok(WebSocket::Tokio(stream))
77}
78
79/// Connect to a host using the Happy Eyeballs algorithm (RFC 8305).
80///
81/// When DNS returns both IPv6 and IPv4 addresses, tries the preferred family
82/// first and starts the other family after a 250ms delay if the first hasn't
83/// connected yet. Uses whichever connection succeeds first.
84async fn happy_eyeballs_connect(host: &str, port: u16) -> Result<TcpStream, Error> {
85    let addrs: Vec<SocketAddr> = tokio::net::lookup_host((host, port)).await?.collect();
86
87    if addrs.is_empty() {
88        return Err(std::io::Error::new(
89            std::io::ErrorKind::AddrNotAvailable,
90            "DNS resolution returned no addresses",
91        )
92        .into());
93    }
94
95    // Separate into IPv6 and IPv4, preserving order within each group
96    let mut ipv6: Vec<SocketAddr> = Vec::new();
97    let mut ipv4: Vec<SocketAddr> = Vec::new();
98    for addr in addrs {
99        if addr.is_ipv6() {
100            ipv6.push(addr);
101        } else {
102            ipv4.push(addr);
103        }
104    }
105
106    // If only one family, try addresses sequentially
107    if ipv4.is_empty() {
108        return try_addrs_sequential(&ipv6).await;
109    }
110    if ipv6.is_empty() {
111        return try_addrs_sequential(&ipv4).await;
112    }
113
114    // Both families available: Happy Eyeballs
115    // Try first IPv6 address, after delay start first IPv4 in parallel
116    let ipv6_first = ipv6[0];
117    let ipv4_first = ipv4[0];
118
119    // Pin the IPv6 future so it survives across select boundaries
120    let ipv6_fut = TcpStream::connect(ipv6_first);
121    tokio::pin!(ipv6_fut);
122
123    // Phase 1: Give IPv6 a 250ms head start
124    tokio::select! {
125        result = &mut ipv6_fut => {
126            return match result {
127                Ok(stream) => Ok(stream),
128                // IPv6 failed fast, try IPv4 directly
129                Err(_) => try_addrs_sequential(&ipv4).await,
130            }
131        }
132        _ = tokio::time::sleep(HAPPY_EYEBALLS_DELAY) => {
133            // Timer fired, IPv6 still pending. Start IPv4 and race both.
134        }
135    }
136
137    // Phase 2: Race the still-pending IPv6 against a new IPv4 attempt.
138    // Use a loop so that if one fails, we keep waiting for the other.
139    let ipv4_fut = TcpStream::connect(ipv4_first);
140    tokio::pin!(ipv4_fut);
141
142    let mut ipv6_done = false;
143    let mut ipv4_done = false;
144
145    loop {
146        tokio::select! {
147            result = &mut ipv6_fut, if !ipv6_done => {
148                match result {
149                    Ok(stream) => return Ok(stream),
150                    Err(_) => { ipv6_done = true; }
151                }
152            }
153            result = &mut ipv4_fut, if !ipv4_done => {
154                match result {
155                    Ok(stream) => return Ok(stream),
156                    Err(_) => { ipv4_done = true; }
157                }
158            }
159        }
160        if ipv6_done && ipv4_done {
161            break;
162        }
163    }
164
165    // Both initial attempts failed, try remaining addresses sequentially
166    // Interleave remaining IPv4 and IPv6 per RFC 8305
167    let mut remaining = Vec::new();
168    let ipv4_remaining = ipv4.iter().skip(1);
169    let ipv6_remaining = ipv6.iter().skip(1);
170
171    let mut ipv4_iter = ipv4_remaining.peekable();
172    let mut ipv6_iter = ipv6_remaining.peekable();
173
174    // Interleave: take one from ipv4, then one from ipv6, alternating
175    while ipv4_iter.peek().is_some() || ipv6_iter.peek().is_some() {
176        if let Some(addr) = ipv4_iter.next() {
177            remaining.push(*addr);
178        }
179        if let Some(addr) = ipv6_iter.next() {
180            remaining.push(*addr);
181        }
182    }
183
184    try_addrs_sequential(&remaining).await
185}
186
187/// Try connecting to addresses sequentially, returning the first success.
188async fn try_addrs_sequential(addrs: &[SocketAddr]) -> Result<TcpStream, Error> {
189    let mut last_err = None;
190    for addr in addrs {
191        match TcpStream::connect(addr).await {
192            Ok(stream) => return Ok(stream),
193            Err(e) => last_err = Some(e),
194        }
195    }
196    Err(last_err
197        .unwrap_or_else(|| {
198            std::io::Error::new(
199                std::io::ErrorKind::AddrNotAvailable,
200                "no addresses to connect to",
201            )
202        })
203        .into())
204}
205
/// Open a WebSocket connection to `url` through the SOCKS5 proxy at `proxy`.
///
/// The proxy is asked to connect to `host:port` taken from the URL. The port
/// falls back to the scheme's well-known default.
///
/// `timeout` bounds the whole operation, consistent with `connect_direct`:
/// - the connection through the proxy;
/// - the handshake(s) performed by `tokio_tungstenite::client_async_tls`.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the URL has no host or no known port;
/// - the proxy connection or the handshake fails;
/// - `timeout` elapses first ([`Error::Timeout`]).
#[cfg(feature = "socks")]
async fn connect_proxy(
    url: &Url,
    proxy: SocketAddr,
    timeout: Duration,
) -> Result<WebSocket, Error> {
    let host: &str = url.host_str().ok_or_else(Error::empty_host)?;
    let port: u16 = url
        .port_or_known_default()
        .ok_or_else(Error::invalid_port)?;
    let addr: String = format!("{host}:{port}");

    // The proxy connection is inside the timed future, so an unreachable or
    // stalled proxy cannot block past `timeout`.
    let conn_fut = async {
        let conn: TcpStream = TcpSocks5Stream::connect(proxy, addr).await?;

        // NOT REMOVE `Box::pin`!
        // Use `Box::pin` to fix stack overflow on windows targets due to large `Future`
        let (stream, _) = Box::pin(tokio_tungstenite::client_async_tls(
            url.as_str(),
            conn,
        ))
        .await?;

        Ok::<_, Error>(stream)
    };

    let stream = time::timeout(timeout, conn_fut)
        .await
        .map_err(|_| Error::Timeout)??;
    Ok(WebSocket::Tokio(stream))
}
229
/// Open a WebSocket connection to `url` routed through Tor.
///
/// `custom_path`, when given, is forwarded to `tor::connect`. Presumably it
/// points at the Tor client's storage location — TODO confirm in the `tor`
/// module.
///
/// Note: only the handshake performed by `tokio_tungstenite::client_async_tls`
/// is bounded by `timeout`. Establishing the Tor stream (`tor::connect`) is
/// awaited without a deadline.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the URL has no host or no known port;
/// - `tor::connect` or the handshake fails;
/// - the handshake exceeds `timeout` ([`Error::Timeout`]).
#[cfg(feature = "tor")]
async fn connect_tor(
    url: &Url,
    timeout: Duration,
    custom_path: Option<&PathBuf>,
) -> Result<WebSocket, Error> {
    let host: &str = url.host_str().ok_or_else(Error::empty_host)?;
    let port: u16 = url
        .port_or_known_default()
        .ok_or_else(Error::invalid_port)?;

    let tor_stream: DataStream = tor::connect(host, port, custom_path).await?;

    // Keep the handshake future on the heap via `Box::pin` (do NOT remove it):
    // without it the large future overflows the stack on Windows targets. The
    // future is built directly inside `Box::pin` rather than bound to a local
    // first.
    let handshake = Box::pin(time::timeout(
        timeout,
        tokio_tungstenite::client_async_tls(url.as_str(), tor_stream),
    ));
    let (ws_stream, _response) = match handshake.await {
        Ok(outcome) => outcome?,
        Err(_elapsed) => return Err(Error::Timeout),
    };

    Ok(WebSocket::Tor(ws_stream))
}
252
253#[inline]
254pub async fn accept<S>(raw_stream: S) -> Result<WebSocketStream<S>, Error>
255where
256    S: AsyncRead + AsyncWrite + Unpin,
257{
258    Ok(tokio_tungstenite::accept_async(raw_stream).await?)
259}
260
261/// Take an already upgraded websocket connection
262///
263/// Useful for when using [hyper] or [warp] or any other HTTP server
264#[inline]
265pub async fn take_upgraded<S>(raw_stream: S) -> WebSocketStream<S>
266where
267    S: AsyncRead + AsyncWrite + Unpin,
268{
269    WebSocketStream::from_raw_socket(raw_stream, Role::Server, None).await
270}