async_wsocket/native/
mod.rs

1// Copyright (c) 2022-2024 Yuki Kishimoto
2// Distributed under the MIT software license
3
4//! Native
5
6#[cfg(feature = "socks")]
7use std::net::SocketAddr;
8#[cfg(feature = "tor")]
9use std::path::PathBuf;
10use std::time::Duration;
11
12#[cfg(feature = "tor")]
13use arti_client::DataStream;
14use tokio::io::{AsyncRead, AsyncWrite};
15#[cfg(feature = "socks")]
16use tokio::net::TcpStream;
17use tokio::time;
18use tokio_tungstenite::tungstenite::protocol::Role;
19pub use tokio_tungstenite::tungstenite::Message;
20pub use tokio_tungstenite::WebSocketStream;
21use url::Url;
22
23mod error;
24#[cfg(feature = "socks")]
25mod socks;
26#[cfg(feature = "tor")]
27pub mod tor;
28
29pub use self::error::Error;
30#[cfg(feature = "socks")]
31use self::socks::TcpSocks5Stream;
32use crate::socket::WebSocket;
33use crate::ConnectionMode;
34
35pub async fn connect(
36    url: &Url,
37    mode: &ConnectionMode,
38    timeout: Duration,
39) -> Result<WebSocket, Error> {
40    match mode {
41        ConnectionMode::Direct => connect_direct(url, timeout).await,
42        #[cfg(feature = "socks")]
43        ConnectionMode::Proxy(proxy) => connect_proxy(url, *proxy, timeout).await,
44        #[cfg(feature = "tor")]
45        ConnectionMode::Tor { custom_path } => {
46            connect_tor(url, timeout, custom_path.as_ref()).await
47        }
48    }
49}
50
51async fn connect_direct(url: &Url, timeout: Duration) -> Result<WebSocket, Error> {
52    // NOT REMOVE `Box::pin`!
53    // Use `Box::pin` to fix stack overflow on windows targets due to large `Future`
54    let (stream, _) = Box::pin(time::timeout(
55        timeout,
56        tokio_tungstenite::connect_async(url.as_str()),
57    ))
58    .await
59    .map_err(|_| Error::Timeout)??;
60    Ok(WebSocket::Tokio(stream))
61}
62
/// Connect to `url` through the SOCKS5 proxy listening at `proxy`.
///
/// The target `host:port` is taken from `url`. The port falls back to the
/// scheme's known default. The SOCKS5 negotiation with the proxy and the
/// following TLS/WebSocket handshake share a single `timeout` deadline, which
/// mirrors `connect_direct`, where the whole connection is bounded.
///
/// # Errors
///
/// * `url` has no host: error from `Error::empty_host`.
/// * No port and no known default: error from `Error::invalid_port`.
/// * The deadline elapses: [`Error::Timeout`].
/// * Proxy or handshake failure: the underlying error, converted into [`Error`].
#[cfg(feature = "socks")]
async fn connect_proxy(
    url: &Url,
    proxy: SocketAddr,
    timeout: Duration,
) -> Result<WebSocket, Error> {
    let host: &str = url.host_str().ok_or_else(Error::empty_host)?;
    let port: u16 = url
        .port_or_known_default()
        .ok_or_else(Error::invalid_port)?;
    let addr: String = format!("{host}:{port}");

    // NOT REMOVE `Box::pin`!
    // Use `Box::pin` to fix stack overflow on windows targets due to large `Future`
    let (stream, _) = Box::pin(time::timeout(timeout, async {
        // `TcpSocks5Stream::connect` is given no deadline of its own, so it runs
        // under `timeout` too. Otherwise a proxy that accepts the TCP connection
        // but never answers would stall the caller indefinitely.
        let conn: TcpStream = TcpSocks5Stream::connect(proxy, addr).await?;
        let handshake = tokio_tungstenite::client_async_tls(url.as_str(), conn).await?;
        Ok::<_, Error>(handshake)
    }))
    .await
    .map_err(|_| Error::Timeout)??;
    Ok(WebSocket::Tokio(stream))
}
86
/// Connect to `url` over the Tor network.
///
/// The target host and port come from `url`; the port falls back to the
/// scheme's known default. They are passed to [`tor::connect`] together with
/// the optional `custom_path`. The TLS/WebSocket handshake then runs over the
/// returned [`DataStream`] and is bounded by `timeout`.
///
/// NOTE(review): only the handshake is bounded by `timeout`. The
/// `tor::connect` step is awaited without a deadline. That step presumably
/// includes Tor client bootstrap; confirm this before putting it under the
/// same deadline.
///
/// # Errors
///
/// * `url` has no host: error from `Error::empty_host`.
/// * No port and no known default: error from `Error::invalid_port`.
/// * The handshake deadline elapses: [`Error::Timeout`].
/// * Tor or handshake failure: the underlying error, converted into [`Error`].
#[cfg(feature = "tor")]
async fn connect_tor(
    url: &Url,
    timeout: Duration,
    custom_path: Option<&PathBuf>,
) -> Result<WebSocket, Error> {
    let target_host: &str = url.host_str().ok_or_else(Error::empty_host)?;
    let target_port: u16 = url
        .port_or_known_default()
        .ok_or_else(Error::invalid_port)?;

    let tunnel: DataStream = tor::connect(target_host, target_port, custom_path).await?;

    // NOT REMOVE `Box::pin`!
    // Use `Box::pin` to fix stack overflow on windows targets due to large `Future`
    let outcome = Box::pin(time::timeout(
        timeout,
        tokio_tungstenite::client_async_tls(url.as_str(), tunnel),
    ))
    .await;

    match outcome {
        Ok(handshake) => {
            let (ws, _) = handshake?;
            Ok(WebSocket::Tor(ws))
        }
        Err(_elapsed) => Err(Error::Timeout),
    }
}
109
110#[inline]
111pub async fn accept<S>(raw_stream: S) -> Result<WebSocketStream<S>, Error>
112where
113    S: AsyncRead + AsyncWrite + Unpin,
114{
115    Ok(tokio_tungstenite::accept_async(raw_stream).await?)
116}
117
118/// Take an already upgraded websocket connection
119///
120/// Useful for when using [hyper] or [warp] or any other HTTP server
121#[inline]
122pub async fn take_upgraded<S>(raw_stream: S) -> WebSocketStream<S>
123where
124    S: AsyncRead + AsyncWrite + Unpin,
125{
126    WebSocketStream::from_raw_socket(raw_stream, Role::Server, None).await
127}