Skip to main content

async_wsocket/native/
mod.rs

1// Copyright (c) 2022-2024 Yuki Kishimoto
2// Distributed under the MIT software license
3
4//! Native
5
6#[cfg(feature = "socks")]
7use std::net::SocketAddr;
8
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio::net::TcpStream;
11use tokio_tungstenite::tungstenite::protocol::Role;
12pub use tokio_tungstenite::tungstenite::Message;
13use tokio_tungstenite::MaybeTlsStream;
14pub use tokio_tungstenite::WebSocketStream;
15use url::Url;
16
17mod error;
18#[cfg(feature = "socks")]
19mod socks;
20
21pub use self::error::Error;
22#[cfg(feature = "socks")]
23use self::socks::TcpSocks5Stream;
24use crate::socket::WebSocket;
25use crate::ConnectionMode;
26
27pub async fn connect(url: &Url, mode: &ConnectionMode) -> Result<WebSocket, Error> {
28    match mode {
29        ConnectionMode::Direct => connect_direct(url).await,
30        #[cfg(feature = "socks")]
31        ConnectionMode::Proxy(proxy) => connect_proxy(url, *proxy).await,
32    }
33}
34
35async fn connect_direct(url: &Url) -> Result<WebSocket, Error> {
36    let host: &str = url.host_str().ok_or_else(Error::empty_host)?;
37    let port: u16 = url
38        .port_or_known_default()
39        .ok_or_else(Error::invalid_port)?;
40
41    let host: String = format!("{}:{}", host, port);
42
43    let tcp_stream: TcpStream = tokio_happy_eyeballs::connect(host).await?;
44
45    connect_stream(url, tcp_stream).await
46}
47
/// Reach the host and port named in `url` through the SOCKS5 proxy listening at
/// `proxy`, then perform the WebSocket client handshake over the tunnelled stream.
///
/// The destination is handed to the proxy as a `host:port` string built from the URL.
/// The port falls back to the scheme's known default when the URL omits it.
///
/// # Errors
///
/// Returns an [`Error`] in any of these cases:
/// - the URL has no host, or no port can be determined;
/// - the SOCKS5 connection fails;
/// - the WebSocket handshake fails.
#[cfg(feature = "socks")]
async fn connect_proxy(url: &Url, proxy: SocketAddr) -> Result<WebSocket, Error> {
    let host = url.host_str().ok_or_else(Error::empty_host)?;
    let port = url.port_or_known_default().ok_or_else(Error::invalid_port)?;

    let tunnel: TcpStream = TcpSocks5Stream::connect(proxy, format!("{host}:{port}")).await?;

    connect_stream(url, tunnel).await
}
59
60async fn connect_stream(url: &Url, stream: TcpStream) -> Result<WebSocket, Error> {
61    let stream = client_async(url, stream).await?;
62    Ok(WebSocket::tokio(Box::new(stream)))
63}
64
// DO NOT REMOVE `Box::pin`!
// The handshake `Future` is large enough to overflow the stack on Windows targets.
// Boxing it moves that state onto the heap.
/// Client handshake used when at least one TLS backend feature is enabled.
///
/// Delegates to `tokio_tungstenite::client_async_tls`, which decides from `url` whether
/// to wrap `stream` in TLS. The handshake response is discarded.
#[cfg(any(
    feature = "native-tls",
    feature = "native-tls-vendored",
    feature = "rustls-tls-native-roots",
    feature = "rustls-tls-webpki-roots"
))]
async fn client_async(
    url: &Url,
    stream: TcpStream,
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, Error> {
    let handshake = Box::pin(tokio_tungstenite::client_async_tls(url.as_str(), stream));
    let (ws_stream, _) = handshake.await?;
    Ok(ws_stream)
}
80
81#[cfg(not(any(
82    feature = "native-tls",
83    feature = "native-tls-vendored",
84    feature = "rustls-tls-native-roots",
85    feature = "rustls-tls-webpki-roots"
86)))]
87async fn client_async(
88    url: &Url,
89    stream: TcpStream,
90) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, Error> {
91    if url.scheme() == "wss" {
92        return Err(tokio_tungstenite::tungstenite::Error::Url(
93            tokio_tungstenite::tungstenite::error::UrlError::TlsFeatureNotEnabled,
94        )
95        .into());
96    }
97
98    let (stream, _) = Box::pin(tokio_tungstenite::client_async(
99        url.as_str(),
100        MaybeTlsStream::Plain(stream),
101    ))
102    .await?;
103    Ok(stream)
104}
105
106#[inline]
107pub async fn accept<S>(raw_stream: S) -> Result<WebSocketStream<S>, Error>
108where
109    S: AsyncRead + AsyncWrite + Unpin,
110{
111    Ok(tokio_tungstenite::accept_async(raw_stream).await?)
112}
113
114/// Take an already upgraded websocket connection
115///
116/// Useful for when using [hyper] or [warp] or any other HTTP server
117#[inline]
118pub async fn take_upgraded<S>(raw_stream: S) -> WebSocketStream<S>
119where
120    S: AsyncRead + AsyncWrite + Unpin,
121{
122    WebSocketStream::from_raw_socket(raw_stream, Role::Server, None).await
123}