use std::fmt::Write as _;
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use rustls::{ClientConfig, RootCertStore, pki_types::ServerName};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio_rustls::{TlsConnector, client::TlsStream};
use url::Url;
use crate::{net::TcpStream, transport::TransportError};
/// Size cap, in bytes, for a proxy's `CONNECT` response head.
///
/// `read_connect_response` stops with a handshake error once the head grows past this
/// without its terminating blank line (`\r\n\r\n`). This prevents a misbehaving proxy from
/// making the client buffer without bound.
const MAX_PROXY_RESPONSE_BYTES: usize = 16 * 1024;
/// A stream tunnelled through an HTTP `CONNECT` proxy, tagged by which hops use TLS.
///
/// Produced by `tunnel_via_proxy`. The TLS-wrapped variants are boxed, presumably to keep the
/// enum's size close to that of a bare `TcpStream`.
#[derive(Debug)]
pub enum ProxiedStream {
    /// Plain TCP to an `http://` proxy; plain bytes to a `ws://` target through the tunnel.
    Plain(TcpStream),
    /// TLS to an `https://` proxy; plain bytes to a `ws://` target inside that TLS session.
    PlainOverTlsProxy(Box<TlsStream<TcpStream>>),
    /// Plain TCP to an `http://` proxy; TLS to a `wss://` target through the tunnel.
    Tls(Box<TlsStream<TcpStream>>),
    /// TLS to an `https://` proxy, with a second TLS session to a `wss://` target nested inside.
    TlsOverTlsProxy(Box<TlsStream<TlsStream<TcpStream>>>),
}
/// Connection parameters extracted from a `ws://` or `wss://` URL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsTarget {
    /// Target hostname or IP literal, with any IPv6 brackets removed.
    pub host: String,
    /// Target port; defaults to 80 for `ws://` and 443 for `wss://`.
    pub port: u16,
    /// `true` for `wss://`, meaning TLS must be negotiated with the target through the tunnel.
    pub is_tls: bool,
}

impl WsTarget {
    /// Parses a WebSocket URL into the host, port and TLS flag needed to open a tunnel.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::InvalidUrl`] if `url` does not parse, its scheme is neither
    /// `ws` nor `wss`, or it carries no host.
    pub fn parse(url: &str) -> Result<Self, TransportError> {
        let parsed = Url::parse(url)
            .map_err(|err| TransportError::InvalidUrl(format!("{url}: {err}")))?;

        let is_tls = match parsed.scheme() {
            "wss" => true,
            "ws" => false,
            other => {
                return Err(TransportError::InvalidUrl(format!(
                    "expected ws:// or wss:// scheme, was {other}"
                )));
            }
        };

        let Some(raw_host) = parsed.host_str() else {
            return Err(TransportError::InvalidUrl("missing hostname".to_string()));
        };
        // `host_str` keeps the brackets around IPv6 literals; downstream code wants the bare
        // address (it re-adds brackets itself where the syntax requires them).
        let host = raw_host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(raw_host)
            .to_owned();

        let default_port = if is_tls { 443 } else { 80 };
        let port = parsed.port().unwrap_or(default_port);

        Ok(Self { host, port, is_tls })
    }
}
/// Classification of a user-supplied proxy URL by scheme.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyKind {
    /// An `http://` or `https://` proxy that can be tunnelled through with `CONNECT`.
    Http(ProxyTarget),
    /// A recognised SOCKS scheme (`socks5`, `socks5h`, `socks4`, `socks4a`) with a host, which
    /// this module cannot tunnel through.
    Unsupported {
        /// The URL scheme as parsed, e.g. `"socks5"`.
        scheme: String,
    },
}

impl ProxyKind {
    /// Parses `url` and classifies it as an HTTP(S) proxy or a recognised-but-unsupported one.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::InvalidUrl`] if `url` does not parse, uses an unknown scheme,
    /// is a SOCKS URL without a host, or is an HTTP(S) URL rejected by [`ProxyTarget::parse`].
    pub fn parse(url: &str) -> Result<Self, TransportError> {
        let parsed = Url::parse(url)
            .map_err(|err| TransportError::InvalidUrl(format!("{url}: {err}")))?;
        let scheme = parsed.scheme();

        if matches!(scheme, "http" | "https") {
            return ProxyTarget::parse(url).map(Self::Http);
        }

        if !matches!(scheme, "socks5" | "socks5h" | "socks4" | "socks4a") {
            return Err(TransportError::InvalidUrl(format!(
                "unsupported proxy scheme '{scheme}'; expected http:// or https://"
            )));
        }

        // Catches forms like `socks5:127.0.0.1:1080`, where the missing `//` leaves no host.
        let has_host = parsed.host_str().is_some_and(|host| !host.is_empty());
        if !has_host {
            return Err(TransportError::InvalidUrl(format!(
                "proxy URL '{url}' is missing a host (did you mean {scheme}://...)?"
            )));
        }

        Ok(Self::Unsupported {
            scheme: scheme.to_owned(),
        })
    }
}
/// An HTTP(S) proxy endpoint parsed from a proxy URL.
#[derive(Clone, PartialEq, Eq)]
pub struct ProxyTarget {
    /// Proxy hostname or IP literal, with any IPv6 brackets removed.
    pub host: String,
    /// Proxy port; defaults to 80 for `http://` and 443 for `https://`.
    pub port: u16,
    /// Whether the hop to the proxy itself is TLS (`https://`).
    pub is_tls: bool,
    /// Ready-to-send `Proxy-Authorization` value (`Basic <base64>`) when the URL carried
    /// credentials.
    pub auth_header: Option<String>,
}

// Hand-written instead of derived: `auth_header` is only base64 of `user:password`, so a
// derived `Debug` would leak proxy credentials into any `{:?}` log line (including via
// `ProxyKind`'s derived `Debug`).
impl std::fmt::Debug for ProxyTarget {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProxyTarget")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("is_tls", &self.is_tls)
            .field(
                "auth_header",
                &self.auth_header.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}

impl ProxyTarget {
    /// Parses an `http://` or `https://` proxy URL, including optional `user:password`
    /// credentials.
    ///
    /// Credentials are percent-decoded and encoded into a `Basic` `Proxy-Authorization` value.
    /// A URL carrying only a password (`http://:secret@host`) still yields a header, with an
    /// empty user-id.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::InvalidUrl`] if `url` does not parse, uses a SOCKS scheme
    /// (not yet supported), uses any other non-HTTP scheme, or has no host.
    pub fn parse(url: &str) -> Result<Self, TransportError> {
        let parsed =
            Url::parse(url).map_err(|e| TransportError::InvalidUrl(format!("{url}: {e}")))?;

        let is_tls = match parsed.scheme() {
            "http" => false,
            "https" => true,
            "socks5" | "socks5h" | "socks4" | "socks4a" => {
                return Err(TransportError::InvalidUrl(format!(
                    "SOCKS proxy scheme '{}' is not yet supported for WebSocket connections; \
                     use an http:// or https:// proxy",
                    parsed.scheme()
                )));
            }
            other => {
                return Err(TransportError::InvalidUrl(format!(
                    "unsupported proxy scheme '{other}'; expected http:// or https://"
                )));
            }
        };

        let raw_host = parsed
            .host_str()
            .ok_or_else(|| TransportError::InvalidUrl("proxy URL missing hostname".to_string()))?;
        // `host_str` keeps brackets around IPv6 literals; store the bare address.
        let host = if raw_host.starts_with('[') && raw_host.ends_with(']') {
            raw_host[1..raw_host.len() - 1].to_string()
        } else {
            raw_host.to_string()
        };
        let port = parsed.port().unwrap_or(if is_tls { 443 } else { 80 });

        let username = parsed.username();
        let password = parsed.password();
        // Skip the header only when the URL has no userinfo at all. A password with an empty
        // user-id is still a credential the proxy expects to receive.
        let auth_header = if username.is_empty() && password.is_none() {
            None
        } else {
            let credentials = format!(
                "{}:{}",
                decode_userinfo(username),
                decode_userinfo(password.unwrap_or(""))
            );
            Some(format!("Basic {}", BASE64.encode(credentials)))
        };

        Ok(Self {
            host,
            port,
            is_tls,
            auth_header,
        })
    }
}
/// Percent-decodes one URL userinfo component (user name or password).
///
/// Decoded bytes that are not valid UTF-8 are replaced with U+FFFD rather than rejected.
fn decode_userinfo(value: &str) -> String {
    String::from_utf8_lossy(&nautilus_core::string::urlencoding::decode_bytes(value.as_bytes()))
        .into_owned()
}
/// Connects to `proxy` over TCP and opens an HTTP `CONNECT` tunnel to `target`, adding TLS on
/// whichever hops require it.
///
/// The returned [`ProxiedStream`] variant records the layering: TLS to the proxy when
/// `proxy.is_tls`, and TLS to the target (inside the tunnel) when `target.is_tls`.
///
/// # Errors
///
/// Returns [`TransportError::Io`] if the TCP connection fails. Also propagates any error from
/// the `CONNECT` exchange or from either TLS handshake.
pub async fn tunnel_via_proxy(
    target: &WsTarget,
    proxy: &ProxyTarget,
) -> Result<ProxiedStream, TransportError> {
    let tcp = TcpStream::connect((proxy.host.as_str(), proxy.port))
        .await
        .map_err(TransportError::Io)?;

    // Best effort: the connection still works if disabling Nagle fails, so only warn.
    if let Err(err) = tcp.set_nodelay(true) {
        log::warn!("Failed to enable TCP_NODELAY on proxy connection: {err:?}");
    }

    if !proxy.is_tls {
        let tunneled = send_connect(tcp, target, proxy).await?;
        return if target.is_tls {
            wrap_tls(tunneled, &target.host)
                .await
                .map(|upstream| ProxiedStream::Tls(Box::new(upstream)))
        } else {
            Ok(ProxiedStream::Plain(tunneled))
        };
    }

    let proxy_tls = wrap_tls(tcp, &proxy.host).await?;
    let tunneled = send_connect(proxy_tls, target, proxy).await?;
    if target.is_tls {
        let upstream = wrap_tls(tunneled, &target.host).await?;
        Ok(ProxiedStream::TlsOverTlsProxy(Box::new(upstream)))
    } else {
        Ok(ProxiedStream::PlainOverTlsProxy(Box::new(tunneled)))
    }
}
/// Writes an HTTP/1.1 `CONNECT` request for `target` to `stream`, then validates the proxy's
/// reply with [`read_connect_response`].
///
/// On success the same stream is handed back, positioned at the first byte after the proxy's
/// response head, i.e. ready to carry tunnelled traffic.
///
/// # Errors
///
/// Returns [`TransportError::Io`] if writing or flushing fails. Also propagates anything
/// `read_connect_response` reports.
async fn send_connect<S>(
    mut stream: S,
    target: &WsTarget,
    proxy: &ProxyTarget,
) -> Result<S, TransportError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // The CONNECT request-target is authority-form (`host:port`); reuse it for `Host`.
    let authority = format_host_header(&target.host, target.port);

    let mut request = format!(
        "CONNECT {authority} HTTP/1.1\r\nHost: {authority}\r\nProxy-Connection: Keep-Alive\r\n"
    );
    if let Some(credentials) = proxy.auth_header.as_deref() {
        write!(request, "Proxy-Authorization: {credentials}\r\n")
            .expect("writing to String never fails");
    }
    // Blank line terminating the request head.
    request.push_str("\r\n");

    stream
        .write_all(request.as_bytes())
        .await
        .map_err(TransportError::Io)?;
    stream.flush().await.map_err(TransportError::Io)?;

    read_connect_response(&mut stream).await?;
    Ok(stream)
}
/// Formats `host:port` for an HTTP authority, bracketing bare IPv6 literals.
///
/// A host containing `:` that is not already wrapped in `[...]` is treated as an IPv6 address
/// and gets brackets; any other host is used verbatim.
fn format_host_header(host: &str, port: u16) -> String {
    let already_bracketed = host.starts_with('[') && host.ends_with(']');
    let needs_brackets = !already_bracketed && host.contains(':');
    if needs_brackets {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    }
}
/// Reads a proxy's reply to a `CONNECT` request and checks that the tunnel was accepted.
///
/// The response head is consumed one byte at a time so that nothing past the terminating
/// `\r\n\r\n` is taken off `stream`. Bytes the proxy sends after the head belong to the
/// tunnelled connection and must remain readable by the caller.
///
/// Only the status line is interpreted. Header fields are skipped without being decoded,
/// because HTTP permits non-UTF-8 (`obs-text`) bytes in field values.
///
/// # Errors
///
/// Returns [`TransportError::Io`] if a read fails. Returns [`TransportError::Handshake`] if:
/// - the proxy closes the connection before finishing the head;
/// - the head reaches [`MAX_PROXY_RESPONSE_BYTES`] without its terminator;
/// - the status line is not `HTTP/<version> <code> ...`;
/// - the status code is outside `200..300`.
async fn read_connect_response<S>(stream: &mut S) -> Result<(), TransportError>
where
    S: AsyncRead + Unpin,
{
    let mut buf: Vec<u8> = Vec::with_capacity(512);
    let mut byte = [0u8; 1];
    while !buf.ends_with(b"\r\n\r\n") {
        // Checked before each read so no more than MAX_PROXY_RESPONSE_BYTES are ever buffered.
        if buf.len() >= MAX_PROXY_RESPONSE_BYTES {
            return Err(TransportError::Handshake(format!(
                "proxy CONNECT response exceeded {MAX_PROXY_RESPONSE_BYTES} bytes without terminator"
            )));
        }
        let n = stream.read(&mut byte).await.map_err(TransportError::Io)?;
        if n == 0 {
            return Err(TransportError::Handshake(
                "proxy closed connection before sending CONNECT response".to_string(),
            ));
        }
        buf.push(byte[0]);
    }

    // Status line = bytes up to the first LF, minus a trailing CR (the same split `str::lines`
    // performs). Decoded lossily: the reason phrase may contain `obs-text`, and only the ASCII
    // version and status code are interpreted.
    let first_line = buf.split(|&b| b == b'\n').next().unwrap_or_default();
    let first_line = first_line.strip_suffix(b"\r").unwrap_or(first_line);
    let status_line = String::from_utf8_lossy(first_line);
    let malformed = || TransportError::Handshake(format!("malformed status line: {status_line}"));

    let mut parts = status_line.splitn(3, ' ');
    let version = parts.next().unwrap_or_default();
    if !version.starts_with("HTTP/") {
        return Err(malformed());
    }
    let status_code = parts
        .next()
        .ok_or_else(malformed)?
        .parse::<u16>()
        .map_err(|_| TransportError::Handshake(format!("non-numeric status: {status_line}")))?;

    // Any 2xx reply means the proxy has switched the connection into tunnel mode.
    if !(200..300).contains(&status_code) {
        return Err(TransportError::Handshake(format!(
            "proxy refused CONNECT: {status_line}"
        )));
    }
    Ok(())
}
async fn wrap_tls<S>(stream: S, server_name: &str) -> Result<TlsStream<S>, TransportError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let mut root_store = RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let config = ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
let connector = TlsConnector::from(std::sync::Arc::new(config));
let domain = ServerName::try_from(server_name.to_string())
.map_err(|e| TransportError::Tls(format!("invalid DNS name '{server_name}': {e}")))?;
connector
.connect(domain, stream)
.await
.map_err(TransportError::Io)
}
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
mod tests {
    use std::net::SocketAddr;

    use rstest::rstest;
    use tokio::net::TcpListener;

    use super::*;

    #[rstest]
    fn ws_target_parses_wss() {
        let target = WsTarget::parse("wss://stream.binance.com:9443/ws/btcusdt@trade").unwrap();
        assert_eq!(target.host, "stream.binance.com");
        assert_eq!(target.port, 9443);
        assert!(target.is_tls);
    }

    #[rstest]
    fn ws_target_default_ports() {
        let plain = WsTarget::parse("ws://example.com/path").unwrap();
        assert_eq!(plain.port, 80);
        assert!(!plain.is_tls);

        let tls = WsTarget::parse("wss://example.com/path").unwrap();
        assert_eq!(tls.port, 443);
        assert!(tls.is_tls);
    }

    #[rstest]
    fn ws_target_strips_ipv6_brackets() {
        let target = WsTarget::parse("wss://[::1]:9443/ws").unwrap();
        assert_eq!(target.host, "::1");
        assert_eq!(target.port, 9443);
    }

    #[rstest]
    fn ws_target_rejects_non_ws_scheme() {
        let err = WsTarget::parse("https://example.com").unwrap_err();
        assert!(matches!(err, TransportError::InvalidUrl(_)));
    }

    #[rstest]
    fn proxy_target_parses_http() {
        let proxy = ProxyTarget::parse("http://127.0.0.1:9999").unwrap();
        assert_eq!(proxy.host, "127.0.0.1");
        assert_eq!(proxy.port, 9999);
        assert!(!proxy.is_tls);
        assert!(proxy.auth_header.is_none());
    }

    #[rstest]
    fn proxy_target_default_ports() {
        let plain = ProxyTarget::parse("http://proxy.example.com").unwrap();
        assert_eq!(plain.port, 80);

        let tls = ProxyTarget::parse("https://proxy.example.com").unwrap();
        assert_eq!(tls.port, 443);
        assert!(tls.is_tls);
    }

    #[rstest]
    fn proxy_target_basic_auth() {
        let proxy =
            ProxyTarget::parse("http://proxytest:fixture42@proxy.example.com:8080").unwrap();
        assert_eq!(
            proxy.auth_header.unwrap(),
            "Basic cHJveHl0ZXN0OmZpeHR1cmU0Mg=="
        );
    }

    #[rstest]
    fn proxy_target_basic_auth_decodes_percent_encoded() {
        let proxy = ProxyTarget::parse("http://us%2Fer:p%40ss@proxy.example.com:8080").unwrap();
        let header = proxy.auth_header.unwrap();
        assert_eq!(header, "Basic dXMvZXI6cEBzcw==");
    }

    #[rstest]
    fn proxy_target_strips_ipv6_brackets() {
        let proxy = ProxyTarget::parse("http://[::1]:8080").unwrap();
        assert_eq!(proxy.host, "::1");
        assert_eq!(proxy.port, 8080);
    }

    #[rstest]
    fn proxy_target_rejects_socks() {
        let err = ProxyTarget::parse("socks5://127.0.0.1:1080").unwrap_err();
        let TransportError::InvalidUrl(msg) = err else {
            panic!("expected InvalidUrl");
        };
        assert!(msg.contains("SOCKS"));
    }

    #[rstest]
    fn proxy_kind_classifies_http() {
        let kind = ProxyKind::parse("http://127.0.0.1:9999").unwrap();
        assert!(matches!(kind, ProxyKind::Http(_)));
    }

    #[rstest]
    fn proxy_kind_classifies_socks_as_unsupported() {
        let kind = ProxyKind::parse("socks5://127.0.0.1:1080").unwrap();
        let ProxyKind::Unsupported { scheme } = kind else {
            panic!("expected Unsupported");
        };
        assert_eq!(scheme, "socks5");
    }

    #[rstest]
    fn proxy_kind_rejects_garbage() {
        assert!(ProxyKind::parse("ftp://x").is_err());
        assert!(ProxyKind::parse("").is_err());
    }

    #[rstest]
    fn proxy_kind_rejects_socks_without_authority() {
        let err = ProxyKind::parse("socks5:127.0.0.1:1080").unwrap_err();
        assert!(matches!(err, TransportError::InvalidUrl(_)));
    }

    #[rstest]
    fn proxy_target_rejects_unknown_scheme() {
        let err = ProxyTarget::parse("ftp://proxy.example.com").unwrap_err();
        assert!(matches!(err, TransportError::InvalidUrl(_)));
    }

    #[rstest]
    fn proxy_target_rejects_empty() {
        let err = ProxyTarget::parse("").unwrap_err();
        assert!(matches!(err, TransportError::InvalidUrl(_)));
    }

    #[rstest]
    fn host_header_brackets_ipv6() {
        assert_eq!(format_host_header("example.com", 443), "example.com:443");
        assert_eq!(format_host_header("::1", 443), "[::1]:443");
        assert_eq!(format_host_header("[::1]", 443), "[::1]:443");
    }

    /// Spawns a one-shot fake proxy on loopback. It waits for a complete request head, replies
    /// with `response`, then drops the connection.
    async fn spawn_fake_proxy(response: &'static [u8]) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut request: Vec<u8> = Vec::new();
            let mut chunk = vec![0u8; 1024];
            // Accumulate across reads so a terminator split between two reads is still found.
            while !request.windows(4).any(|w| w == b"\r\n\r\n") {
                let n = AsyncReadExt::read(&mut stream, &mut chunk).await.unwrap();
                if n == 0 {
                    break;
                }
                request.extend_from_slice(&chunk[..n]);
            }
            stream.write_all(response).await.unwrap();
            stream.flush().await.unwrap();
        });
        addr
    }

    /// Connects to the fake proxy at `addr` and sends a minimal `CONNECT` request head.
    async fn connect_and_send_request(addr: SocketAddr) -> TcpStream {
        let mut stream = TcpStream::connect(addr).await.unwrap();
        stream
            .write_all(b"CONNECT host:443 HTTP/1.1\r\nHost: host:443\r\n\r\n")
            .await
            .unwrap();
        stream.flush().await.unwrap();
        stream
    }

    #[tokio::test]
    async fn read_connect_response_accepts_2xx() {
        let addr = spawn_fake_proxy(b"HTTP/1.1 200 Connection established\r\n\r\n").await;
        let mut stream = connect_and_send_request(addr).await;
        read_connect_response(&mut stream).await.unwrap();
    }

    #[tokio::test]
    async fn read_connect_response_rejects_403() {
        let addr = spawn_fake_proxy(b"HTTP/1.1 403 Forbidden\r\n\r\n").await;
        let mut stream = connect_and_send_request(addr).await;
        let err = read_connect_response(&mut stream).await.unwrap_err();
        let TransportError::Handshake(msg) = err else {
            panic!("expected Handshake error");
        };
        assert!(msg.contains("403"));
    }

    #[rstest]
    #[case::status_300(&b"HTTP/1.1 300 Multiple Choices\r\n\r\n"[..], "300")]
    #[case::status_407(
        &b"HTTP/1.1 407 Proxy Authentication Required\r\nProxy-Authenticate: Basic\r\n\r\n"[..],
        "407",
    )]
    #[case::malformed_status(&b"HTTP/1.1 abc Boom\r\n\r\n"[..], "non-numeric")]
    #[tokio::test]
    async fn read_connect_response_rejects_non_2xx(
        #[case] response: &'static [u8],
        #[case] expected_msg_substring: &'static str,
    ) {
        let addr = spawn_fake_proxy(response).await;
        let mut stream = connect_and_send_request(addr).await;
        let err = read_connect_response(&mut stream).await.unwrap_err();
        let TransportError::Handshake(msg) = err else {
            panic!("expected Handshake error, was {err:?}");
        };
        assert!(
            msg.contains(expected_msg_substring),
            "expected error message to contain {expected_msg_substring:?}, was {msg:?}"
        );
    }

    #[tokio::test]
    async fn read_connect_response_rejects_eof_before_terminator() {
        let addr = spawn_fake_proxy(b"HTTP/1.1 200 OK\r\n").await;
        let mut stream = connect_and_send_request(addr).await;
        let err = read_connect_response(&mut stream).await.unwrap_err();
        let TransportError::Handshake(msg) = err else {
            panic!("expected Handshake error, was {err:?}");
        };
        assert!(
            msg.contains("closed connection"),
            "unexpected handshake error: {msg}"
        );
    }

    #[tokio::test]
    async fn read_connect_response_rejects_oversize_headers() {
        let mut response = b"HTTP/1.1 200 OK\r\n".to_vec();
        while response.len() <= MAX_PROXY_RESPONSE_BYTES {
            response.extend_from_slice(b"X-Pad: aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\r\n");
        }
        // The fake proxy needs a 'static response; leaking one buffer in a test is fine.
        let leaked: &'static [u8] = response.leak();
        let addr = spawn_fake_proxy(leaked).await;
        let mut stream = connect_and_send_request(addr).await;
        let err = read_connect_response(&mut stream).await.unwrap_err();
        let TransportError::Handshake(msg) = err else {
            panic!("expected Handshake error, was {err:?}");
        };
        assert!(
            msg.contains("exceeded"),
            "unexpected handshake error: {msg}"
        );
    }

    #[tokio::test]
    async fn read_connect_response_preserves_trailing_bytes() {
        let addr = spawn_fake_proxy(b"HTTP/1.1 200 Connection established\r\n\r\nLEFTOVER").await;
        let mut stream = connect_and_send_request(addr).await;
        read_connect_response(&mut stream).await.unwrap();

        let mut tail = [0u8; b"LEFTOVER".len()];
        AsyncReadExt::read_exact(&mut stream, &mut tail)
            .await
            .unwrap();
        assert_eq!(&tail, b"LEFTOVER");
    }
}