use crate::proxy::{ProxyScheme, UpstreamProxy};
use anyhow::{Context, Result, bail};
use std::{net::IpAddr, str};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
sync::oneshot,
task::JoinHandle,
};
/// Upper bound (64 KiB) on how many bytes of a client's request head are buffered before the
/// request is rejected as too large.
const MAX_HEADER_BYTES: usize = 1 << 16;
/// Handle to a running loopback HTTP proxy that relays client traffic to an upstream proxy.
///
/// Dropping the handle drops the shutdown sender, which also ends the accept loop; call
/// `ProxyBridge::shutdown` to additionally wait for the loop and observe its result.
pub struct ProxyBridge {
    /// `http://` URL of the bound loopback listener, handed out to clients.
    local_url: String,
    /// One-shot stop signal for the accept loop; `None` once it has been taken.
    shutdown: Option<oneshot::Sender<()>>,
    /// Join handle of the spawned accept loop.
    task: JoinHandle<Result<()>>,
}
impl ProxyBridge {
    /// Binds a loopback listener and spawns the accept loop that relays clients to `upstream`.
    ///
    /// A `listen_port` of `None` binds port 0 so the OS picks a free port; the actual address
    /// is exposed through [`ProxyBridge::local_proxy_url`].
    ///
    /// # Errors
    /// Fails if the listener cannot be bound or its local address cannot be read.
    pub async fn start(upstream: UpstreamProxy, listen_port: Option<u16>) -> Result<Self> {
        let bind_addr = ("127.0.0.1", listen_port.unwrap_or(0));
        let listener = TcpListener::bind(bind_addr)
            .await
            .context("failed to bind local bridge listener")?;
        let bound_addr = listener.local_addr()?;
        let (stop_tx, stop_rx) = oneshot::channel();
        let task = tokio::spawn(run_server(listener, upstream, stop_rx));
        Ok(Self {
            local_url: format!("http://{bound_addr}"),
            shutdown: Some(stop_tx),
            task,
        })
    }

    /// Returns the `http://127.0.0.1:<port>` URL clients should use as their HTTP proxy.
    pub fn local_proxy_url(&self) -> String {
        self.local_url.to_owned()
    }

    /// Signals the accept loop to stop, waits for it, and returns the loop's own result.
    ///
    /// # Errors
    /// Returns the accept loop's error, or a join error if the task panicked or was cancelled.
    pub async fn shutdown(mut self) -> Result<()> {
        if let Some(stop_tx) = self.shutdown.take() {
            // A send failure only means the loop already exited; the join below reports why.
            stop_tx.send(()).ok();
        }
        self.task
            .await
            .context("local proxy bridge task failed to join")?
    }
}
/// Accept loop of the bridge: spawns one `handle_client` task per accepted connection until
/// `shutdown` resolves (a value is sent or the sender is dropped).
///
/// Per-connection failures are logged at debug level and never stop the loop. An accept error
/// that only concerns a single pending connection (aborted, reset or refused, or an interrupted
/// syscall) is logged and skipped. Otherwise one flaky client would take the whole bridge down.
/// Any other accept error ends the loop and is returned.
async fn run_server(
    listener: TcpListener,
    upstream: UpstreamProxy,
    mut shutdown: oneshot::Receiver<()>,
) -> Result<()> {
    loop {
        tokio::select! {
            result = listener.accept() => {
                match result {
                    Ok((client, peer)) => {
                        let upstream = upstream.clone();
                        tokio::spawn(async move {
                            if let Err(error) = handle_client(client, upstream).await {
                                tracing::debug!(
                                    "local proxy connection from {peer} failed: {error:#}"
                                );
                            }
                        });
                    }
                    Err(error) if is_transient_accept_error(&error) => {
                        tracing::debug!("ignoring transient local proxy accept error: {error}");
                    }
                    Err(error) => {
                        return Err(error).context("failed to accept local proxy connection");
                    }
                }
            }
            _ = &mut shutdown => {
                return Ok(());
            }
        }
    }
}

/// Returns `true` for accept errors caused by one pending connection going away, which should
/// not terminate the listener.
fn is_transient_accept_error(error: &std::io::Error) -> bool {
    matches!(
        error.kind(),
        std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
            | std::io::ErrorKind::ConnectionRefused
            | std::io::ErrorKind::Interrupted
    )
}
/// Serves one client connection accepted by the bridge.
///
/// The function reads the client's request head and then does one of two things:
/// - HTTP upstream: relays the head to the upstream proxy, adding the value from
///   `upstream.basic_proxy_authorization()` as `Proxy-Authorization` when one is present.
/// - SOCKS5 upstream: accepts only `CONNECT`, opens a tunnel to the target and replies
///   `200 Connection Established`.
///
/// In both cases bytes are then copied in both directions until either side closes.
///
/// If the request cannot be parsed or the upstream cannot be reached, a best-effort HTTP error
/// (400 / 502) is written first. The client then sees a status line instead of an abrupt
/// disconnect. The original error is still returned.
///
/// NOTE: only the first request head is rewritten. On a plain-HTTP (non-`CONNECT`) keep-alive
/// connection, later requests are relayed verbatim, without the injected `Proxy-Authorization`.
async fn handle_client(mut client: TcpStream, upstream: UpstreamProxy) -> Result<()> {
    let request_bytes = read_http_request_head(&mut client).await?;
    let header_end = find_header_end(&request_bytes).context("HTTP header terminator not found")?;
    // Bytes read past the header terminator belong to the body/tunnel and are forwarded as-is.
    let (head, leftover) = request_bytes.split_at(header_end);
    let parsed = parse_http_request(head);
    let request = reply_error_if_failed(&mut client, 400, "Bad Request", parsed).await?;
    match upstream.scheme() {
        ProxyScheme::Http => {
            let connected = TcpStream::connect((upstream.host(), upstream.port()))
                .await
                .with_context(|| {
                    format!(
                        "failed to connect upstream HTTP proxy {}",
                        upstream.authority()
                    )
                });
            let mut upstream_stream =
                reply_error_if_failed(&mut client, 502, "Bad Gateway", connected).await?;
            let outgoing =
                add_proxy_authorization(head, upstream.basic_proxy_authorization().as_deref());
            upstream_stream.write_all(&outgoing).await?;
            if !leftover.is_empty() {
                upstream_stream.write_all(leftover).await?;
            }
            tokio::io::copy_bidirectional(&mut client, &mut upstream_stream).await?;
        }
        ProxyScheme::Socks5 => {
            if !request.method.eq_ignore_ascii_case("CONNECT") {
                write_proxy_error(
                    &mut client,
                    501,
                    "Only CONNECT is supported for SOCKS upstreams",
                )
                .await?;
                bail!("non-CONNECT request is not supported for SOCKS upstreams");
            }
            let target = parse_host_port(&request.target);
            let (target_host, target_port) =
                reply_error_if_failed(&mut client, 400, "Bad Request", target).await?;
            let connected = connect_via_socks5(&upstream, &target_host, target_port).await;
            let mut upstream_stream =
                reply_error_if_failed(&mut client, 502, "Bad Gateway", connected).await?;
            client
                .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
                .await?;
            if !leftover.is_empty() {
                upstream_stream.write_all(leftover).await?;
            }
            tokio::io::copy_bidirectional(&mut client, &mut upstream_stream).await?;
        }
    }
    Ok(())
}

/// Writes a best-effort `code message` HTTP error to `client` when `result` is an error, then
/// hands `result` back unchanged.
async fn reply_error_if_failed<T>(
    client: &mut TcpStream,
    code: u16,
    message: &str,
    result: Result<T>,
) -> Result<T> {
    if result.is_err() {
        // The original failure is the one worth reporting; a failed error write adds nothing.
        let _ = write_proxy_error(client, code, message).await;
    }
    result
}
/// Reads from `stream` until the buffered bytes contain the `\r\n\r\n` header terminator.
///
/// Returns everything read so far. The result may include bytes past the terminator, which
/// the caller splits off.
///
/// # Errors
/// Fails in three cases:
/// - the peer closes before the terminator arrives;
/// - a read fails;
/// - the buffer grows past [`MAX_HEADER_BYTES`] without a terminator. The buffer may exceed
///   the limit by at most one read chunk before this check fires.
async fn read_http_request_head(stream: &mut TcpStream) -> Result<Vec<u8>> {
    let mut buffer = Vec::with_capacity(4096);
    let mut chunk = [0_u8; 2048];
    loop {
        let read = stream.read(&mut chunk).await?;
        if read == 0 {
            bail!("connection closed before HTTP header was complete");
        }
        // Earlier bytes were already checked and hold no terminator. Only the last 3 old bytes
        // plus the new chunk can complete one. Scanning just that tail keeps the total work
        // linear even when a client trickles the header in tiny reads.
        let scan_from = buffer.len().saturating_sub(3);
        buffer.extend_from_slice(&chunk[..read]);
        if find_header_end(&buffer[scan_from..]).is_some() {
            return Ok(buffer);
        }
        if buffer.len() > MAX_HEADER_BYTES {
            bail!("HTTP proxy request header is too large");
        }
    }
}
/// The parts of a client's HTTP request line that the bridge needs for routing.
#[derive(PartialEq, Eq, Debug)]
struct HttpRequest {
    /// Request method token exactly as sent (compared case-insensitively against `CONNECT`).
    method: String,
    /// Request-target exactly as it appeared on the request line.
    target: String,
}
/// Extracts the method and request-target from the request line (first line) of `head`.
///
/// # Errors
/// Fails if `head` is not UTF-8, is empty, or if the request line lacks a method, target or
/// version. It also fails if the version token does not start with `HTTP/`.
fn parse_http_request(head: &[u8]) -> Result<HttpRequest> {
    let text = str::from_utf8(head).context("HTTP request header is not valid UTF-8")?;
    let request_line = text.lines().next().context("HTTP request is empty")?;

    let mut tokens = request_line.split_whitespace();
    let method = tokens.next().context("HTTP request is missing method")?;
    let target = tokens.next().context("HTTP request is missing target")?;
    let version = tokens.next().context("HTTP request is missing version")?;

    match version.strip_prefix("HTTP/") {
        Some(_) => Ok(HttpRequest {
            method: method.to_owned(),
            target: target.to_owned(),
        }),
        None => bail!("invalid HTTP proxy request version: {version}"),
    }
}
/// Returns the index just past the first `\r\n\r\n` in `buffer`, which is the length of the
/// header block including its terminator. Returns `None` when no terminator is present.
fn find_header_end(buffer: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    let last_start = buffer.len().checked_sub(TERMINATOR.len())?;
    (0..=last_start)
        .find(|&start| buffer[start..].starts_with(TERMINATOR))
        .map(|start| start + TERMINATOR.len())
}
/// Returns a copy of the HTTP request `head` with `Proxy-Authorization: {authorization}`
/// inserted just before its last `\r\n\r\n` terminator.
///
/// The head is returned unchanged in three cases:
/// - `authorization` is `None`;
/// - a `Proxy-Authorization` header is already present (matched ASCII-case-insensitively);
/// - no terminator is found.
///
/// All scanning is done on raw bytes, so the insertion offset stays correct even when `head`
/// is not valid UTF-8. Offsets taken from a lossy string conversion would not match the bytes.
fn add_proxy_authorization(head: &[u8], authorization: Option<&str>) -> Vec<u8> {
    const EXISTING_HEADER: &[u8] = b"\r\nproxy-authorization:";
    const NEW_HEADER_PREFIX: &[u8] = b"\r\nProxy-Authorization: ";
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    let Some(authorization) = authorization else {
        return head.to_vec();
    };
    let already_present = head
        .windows(EXISTING_HEADER.len())
        .any(|window| window.eq_ignore_ascii_case(EXISTING_HEADER));
    if already_present {
        return head.to_vec();
    }
    let Some(insert_at) = head
        .windows(TERMINATOR.len())
        .rposition(|window| window == TERMINATOR)
    else {
        return head.to_vec();
    };
    let mut outgoing =
        Vec::with_capacity(head.len() + NEW_HEADER_PREFIX.len() + authorization.len());
    outgoing.extend_from_slice(&head[..insert_at]);
    outgoing.extend_from_slice(NEW_HEADER_PREFIX);
    outgoing.extend_from_slice(authorization.as_bytes());
    outgoing.extend_from_slice(&head[insert_at..]);
    outgoing
}
/// Splits a `CONNECT` authority-form target into `(host, port)`.
///
/// Bracketed IPv6 literals (`[::1]:443`) have their brackets stripped. Any other value is split
/// at its last `:`.
///
/// # Errors
/// Fails when the port is missing or not a valid `u16`, when a bracket is unterminated, or when
/// the host is empty. The empty-host check applies to both the bracketed and plain forms.
fn parse_host_port(value: &str) -> Result<(String, u16)> {
    let (host, port) = if let Some(rest) = value.strip_prefix('[') {
        let (host, tail) = rest
            .split_once(']')
            .context("invalid bracketed IPv6 CONNECT target")?;
        let port = tail
            .strip_prefix(':')
            .context("IPv6 CONNECT target is missing port")?;
        (host, port)
    } else {
        value
            .rsplit_once(':')
            .context("CONNECT target must be host:port")?
    };
    if host.is_empty() {
        bail!("CONNECT target host cannot be empty");
    }
    let port = port.parse().context("invalid CONNECT target port")?;
    Ok((host.to_string(), port))
}
/// Connects to the SOCKS5 `proxy` and asks it to `CONNECT` to `target_host:target_port`.
///
/// The greeting offers "no authentication" (0x00), plus username/password (0x02) when
/// `proxy.has_auth()` is true. If the proxy picks a method that was not offered, the call
/// fails instead of sending empty credentials. On success the reply, including the bound
/// address, has been fully consumed. The returned stream then carries only tunnel bytes.
///
/// # Errors
/// Fails on I/O errors, protocol violations, authentication failure, or a non-zero reply code.
/// The error message includes the RFC 1928 meaning of that code.
async fn connect_via_socks5(
    proxy: &UpstreamProxy,
    target_host: &str,
    target_port: u16,
) -> Result<TcpStream> {
    let mut stream = TcpStream::connect((proxy.host(), proxy.port()))
        .await
        .with_context(|| {
            format!(
                "failed to connect upstream SOCKS5 proxy {}",
                proxy.authority()
            )
        })?;

    let offers_password_auth = proxy.has_auth();
    // VER=5, NMETHODS, METHODS...
    let greeting: &[u8] = if offers_password_auth {
        &[0x05, 0x02, 0x00, 0x02]
    } else {
        &[0x05, 0x01, 0x00]
    };
    stream.write_all(greeting).await?;

    let mut method_response = [0_u8; 2];
    stream.read_exact(&mut method_response).await?;
    if method_response[0] != 0x05 {
        bail!("invalid SOCKS5 method response");
    }
    match method_response[1] {
        0x00 => {}
        0x02 if offers_password_auth => authenticate_socks5(proxy, &mut stream).await?,
        0x02 => bail!(
            "SOCKS5 proxy requested username/password authentication, but none was offered"
        ),
        0xff => bail!("SOCKS5 proxy rejected all authentication methods"),
        method => bail!("SOCKS5 proxy selected unsupported authentication method {method:#x}"),
    }

    let request = build_socks5_connect_request(target_host, target_port)?;
    stream.write_all(&request).await?;

    // Reply header: VER, REP, RSV, ATYP.
    let mut response = [0_u8; 4];
    stream.read_exact(&mut response).await?;
    if response[0] != 0x05 {
        bail!("invalid SOCKS5 connect response");
    }
    let reply = response[1];
    if reply != 0x00 {
        bail!(
            "SOCKS5 connect failed with code {reply:#x} ({})",
            socks5_reply_description(reply)
        );
    }
    read_socks5_bound_address(&mut stream, response[3]).await?;
    Ok(stream)
}

/// Human-readable meaning of a non-success SOCKS5 reply code (RFC 1928, section 6).
fn socks5_reply_description(code: u8) -> &'static str {
    match code {
        0x01 => "general SOCKS server failure",
        0x02 => "connection not allowed by ruleset",
        0x03 => "network unreachable",
        0x04 => "host unreachable",
        0x05 => "connection refused",
        0x06 => "TTL expired",
        0x07 => "command not supported",
        0x08 => "address type not supported",
        _ => "unassigned reply code",
    }
}
async fn authenticate_socks5(proxy: &UpstreamProxy, stream: &mut TcpStream) -> Result<()> {
let username = proxy.username().unwrap_or_default().as_bytes();
let password = proxy.password().unwrap_or_default().as_bytes();
if username.len() > u8::MAX as usize || password.len() > u8::MAX as usize {
bail!("SOCKS5 username and password must be at most 255 bytes");
}
let mut request = Vec::with_capacity(username.len() + password.len() + 3);
request.push(0x01);
request.push(username.len() as u8);
request.extend_from_slice(username);
request.push(password.len() as u8);
request.extend_from_slice(password);
stream.write_all(&request).await?;
let mut response = [0_u8; 2];
stream.read_exact(&mut response).await?;
if response != [0x01, 0x00] {
bail!("SOCKS5 username/password authentication failed");
}
Ok(())
}
fn build_socks5_connect_request(target_host: &str, target_port: u16) -> Result<Vec<u8>> {
let mut request = vec![0x05, 0x01, 0x00];
match target_host.parse::<IpAddr>() {
Ok(IpAddr::V4(address)) => {
request.push(0x01);
request.extend_from_slice(&address.octets());
}
Ok(IpAddr::V6(address)) => {
request.push(0x04);
request.extend_from_slice(&address.octets());
}
Err(_) => {
let host = target_host.as_bytes();
if host.len() > u8::MAX as usize {
bail!("SOCKS5 target host is too long");
}
request.push(0x03);
request.push(host.len() as u8);
request.extend_from_slice(host);
}
}
request.extend_from_slice(&target_port.to_be_bytes());
Ok(request)
}
/// Consumes and discards the BND.ADDR and BND.PORT fields that follow a SOCKS5 reply header.
///
/// `address_type` is the reply's ATYP byte:
/// - 0x01: 4-byte IPv4 address;
/// - 0x03: 1 length byte followed by the name;
/// - 0x04: 16-byte IPv6 address.
///
/// # Errors
/// Fails on any other address type or on I/O errors.
async fn read_socks5_bound_address(stream: &mut TcpStream, address_type: u8) -> Result<()> {
    let address_len = match address_type {
        0x01 => 4,
        0x04 => 16,
        0x03 => {
            let mut name_len = [0_u8; 1];
            stream.read_exact(&mut name_len).await?;
            usize::from(name_len[0])
        }
        other => bail!("invalid SOCKS5 address type {other:#x}"),
    };
    // Address bytes plus the 2-byte port; the bridge has no use for either.
    let mut discarded = vec![0_u8; address_len + 2];
    stream.read_exact(&mut discarded).await?;
    Ok(())
}
/// Writes a minimal `HTTP/1.1` error response to `stream`. `message` is used both as the reason
/// phrase and as the plain-text body.
///
/// The response advertises `Connection: close`; this function only writes it and leaves
/// closing the socket to the caller.
async fn write_proxy_error(stream: &mut TcpStream, code: u16, message: &str) -> Result<()> {
    let body_len = message.len();
    let response = format!(
        "HTTP/1.1 {code} {message}\r\nContent-Length: {body_len}\r\nConnection: close\r\n\r\n{message}"
    );
    stream
        .write_all(response.as_bytes())
        .await
        .map_err(Into::into)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_connect_targets() {
        let cases = [("discord.com:443", "discord.com"), ("[::1]:443", "::1")];
        for (input, expected_host) in cases {
            let parsed = parse_host_port(input).unwrap();
            assert_eq!(parsed, (expected_host.to_string(), 443), "input: {input}");
        }
    }

    #[test]
    fn injects_proxy_authorization_header() {
        let head: &[u8] = b"CONNECT discord.com:443 HTTP/1.1\r\nHost: discord.com:443\r\n\r\n";
        let rewritten =
            String::from_utf8(add_proxy_authorization(head, Some("Basic abc"))).unwrap();
        assert!(rewritten.contains("\r\nProxy-Authorization: Basic abc\r\n"));
        assert!(rewritten.ends_with("\r\n\r\n"));
    }

    #[test]
    fn does_not_duplicate_proxy_authorization_header() {
        let head: &[u8] =
            b"CONNECT discord.com:443 HTTP/1.1\r\nProxy-Authorization: Basic old\r\n\r\n";
        assert_eq!(add_proxy_authorization(head, Some("Basic new")), head);
    }

    #[test]
    fn builds_domain_socks_connect_request() {
        let request = build_socks5_connect_request("discord.com", 443).unwrap();
        let (header, rest) = request.split_at(5);
        let (host, port) = rest.split_at(11);
        assert_eq!(header, &[0x05_u8, 0x01, 0x00, 0x03, 11]);
        assert_eq!(host, b"discord.com");
        assert_eq!(port, &443_u16.to_be_bytes());
    }
}