discord-proxy 0.1.0

Windows-first Discord process-local proxy launcher
Documentation
use crate::proxy::{ProxyScheme, UpstreamProxy};
use anyhow::{Context, Result, bail};
use std::{net::IpAddr, str};
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    net::{TcpListener, TcpStream},
    sync::oneshot,
    task::JoinHandle,
};

/// Size limit (64 KiB) for a buffered request head: `read_http_request_head` gives up once
/// its buffer has grown past this many bytes without containing a header terminator.
const MAX_HEADER_BYTES: usize = 64 * 1024;

/// Handle to a running local proxy bridge: a loopback listener whose connections are
/// relayed through an upstream proxy by a background task.
pub struct ProxyBridge {
    /// `http://<addr>` URL of the bound local listener.
    local_url: String,
    /// One-shot sender used to ask the server task to stop; taken (set to `None`) on shutdown.
    shutdown: Option<oneshot::Sender<()>>,
    /// Join handle of the spawned accept loop, carrying that loop's final result.
    task: JoinHandle<Result<()>>,
}

impl ProxyBridge {
    /// Binds a listener on `127.0.0.1` and spawns the accept loop that relays connections
    /// through `upstream`.
    ///
    /// `listen_port` selects the local port; `None` binds port `0`, letting the operating
    /// system choose one.
    ///
    /// # Errors
    ///
    /// Fails when the listener cannot be bound or its local address cannot be read.
    pub async fn start(upstream: UpstreamProxy, listen_port: Option<u16>) -> Result<Self> {
        let port = listen_port.unwrap_or(0);
        let listener = TcpListener::bind(("127.0.0.1", port))
            .await
            .context("failed to bind local bridge listener")?;

        // Ask the socket for its address so the URL carries the port actually bound.
        let local_addr = listener.local_addr()?;
        let local_url = format!("http://{local_addr}");

        let (stop_sender, stop_receiver) = oneshot::channel();
        let server_task = tokio::spawn(run_server(listener, upstream, stop_receiver));

        Ok(Self {
            local_url,
            shutdown: Some(stop_sender),
            task: server_task,
        })
    }

    /// Returns an owned copy of the `http://<addr>` URL of the local listener.
    pub fn local_proxy_url(&self) -> String {
        self.local_url.to_owned()
    }

    /// Signals the accept loop to stop and waits for its task to finish.
    ///
    /// # Errors
    ///
    /// Returns an error when the task cannot be joined, or forwards the error the accept
    /// loop itself ended with.
    pub async fn shutdown(mut self) -> Result<()> {
        if let Some(stop_sender) = self.shutdown.take() {
            // A failed send is deliberately ignored; joining the task below still reports
            // how the loop ended.
            let _ = stop_sender.send(());
        }

        let server_result = self
            .task
            .await
            .context("local proxy bridge task failed to join")?;
        server_result
    }
}

async fn run_server(
    listener: TcpListener,
    upstream: UpstreamProxy,
    mut shutdown: oneshot::Receiver<()>,
) -> Result<()> {
    loop {
        tokio::select! {
            result = listener.accept() => {
                let (client, peer) = result.context("failed to accept local proxy connection")?;
                let upstream = upstream.clone();
                tokio::spawn(async move {
                    if let Err(error) = handle_client(client, upstream).await {
                        tracing::debug!("local proxy connection from {peer} failed: {error:#}");
                    }
                });
            }
            _ = &mut shutdown => {
                return Ok(());
            }
        }
    }
}

/// Serves one local client connection by relaying it through the upstream proxy.
///
/// The request head is read and parsed first. What happens next depends on the upstream
/// scheme:
///
/// * `ProxyScheme::Http`: the head is forwarded to the upstream proxy (with a
///   `Proxy-Authorization` header added when `basic_proxy_authorization()` yields a value),
///   followed by any bytes already read past the head, and then both directions are relayed.
/// * `ProxyScheme::Socks5`: only `CONNECT` is accepted (anything else gets a `501`). The
///   target is dialled through the SOCKS5 proxy, the client receives
///   `200 Connection Established`, bytes already read past the head are forwarded, and both
///   directions are relayed.
///
/// When the CONNECT target is malformed or the upstream cannot be reached, the client is
/// sent a `400` / `502` response before the error is returned, so it sees a proxy error
/// instead of a silently closed socket.
///
/// # Errors
///
/// Returns an error when the request head cannot be read or parsed, when the upstream
/// connection or SOCKS5 negotiation fails, for non-`CONNECT` requests on a SOCKS5 upstream,
/// or when relaying fails.
async fn handle_client(mut client: TcpStream, upstream: UpstreamProxy) -> Result<()> {
    let request_bytes = read_http_request_head(&mut client).await?;
    let header_end = find_header_end(&request_bytes).context("HTTP header terminator not found")?;
    // `leftover` holds bytes the client sent after the head in the same reads; they must be
    // forwarded before the bidirectional relay starts.
    let (head, leftover) = request_bytes.split_at(header_end);
    let request = parse_http_request(head)?;

    match upstream.scheme() {
        ProxyScheme::Http => {
            let connected = TcpStream::connect((upstream.host(), upstream.port()))
                .await
                .with_context(|| {
                    format!(
                        "failed to connect upstream HTTP proxy {}",
                        upstream.authority()
                    )
                });
            let mut upstream_stream =
                reply_on_error(&mut client, connected, 502, "Upstream proxy is unreachable")
                    .await?;

            let outgoing =
                add_proxy_authorization(head, upstream.basic_proxy_authorization().as_deref());
            upstream_stream.write_all(&outgoing).await?;
            if !leftover.is_empty() {
                upstream_stream.write_all(leftover).await?;
            }
            tokio::io::copy_bidirectional(&mut client, &mut upstream_stream).await?;
        }
        ProxyScheme::Socks5 => {
            if !request.method.eq_ignore_ascii_case("CONNECT") {
                write_proxy_error(
                    &mut client,
                    501,
                    "Only CONNECT is supported for SOCKS upstreams",
                )
                .await?;
                bail!("non-CONNECT request is not supported for SOCKS upstreams");
            }

            let (target_host, target_port) = reply_on_error(
                &mut client,
                parse_host_port(&request.target),
                400,
                "Invalid CONNECT target",
            )
            .await?;

            let connected = connect_via_socks5(&upstream, &target_host, target_port).await;
            let mut upstream_stream = reply_on_error(
                &mut client,
                connected,
                502,
                "Failed to reach target through upstream proxy",
            )
            .await?;

            // Only report success once the tunnel really exists.
            client
                .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
                .await?;
            if !leftover.is_empty() {
                upstream_stream.write_all(leftover).await?;
            }
            tokio::io::copy_bidirectional(&mut client, &mut upstream_stream).await?;
        }
    }

    Ok(())
}

/// Passes `result` through unchanged; when it is an error, first sends the client an HTTP
/// error response with `code` and `message`.
async fn reply_on_error<T>(
    client: &mut TcpStream,
    result: Result<T>,
    code: u16,
    message: &str,
) -> Result<T> {
    if result.is_err() {
        // Best effort: a failure to deliver the error page must not mask the original error.
        let _ = write_proxy_error(client, code, message).await;
    }
    result
}

/// Reads from `stream` until the buffered bytes contain a complete HTTP header block.
///
/// The returned buffer holds everything read so far, so it may extend past the header
/// terminator.
///
/// # Errors
///
/// Fails when a read fails, when the peer closes the connection before the terminator
/// arrives, or when more than `MAX_HEADER_BYTES` have been buffered without a terminator.
async fn read_http_request_head(stream: &mut TcpStream) -> Result<Vec<u8>> {
    let mut received = Vec::with_capacity(4096);
    let mut scratch = [0_u8; 2048];

    loop {
        let fresh = match stream.read(&mut scratch).await? {
            // A zero-length read means the peer closed its side.
            0 => bail!("connection closed before HTTP header was complete"),
            count => &scratch[..count],
        };
        received.extend_from_slice(fresh);

        match find_header_end(&received) {
            Some(_) => return Ok(received),
            None if received.len() > MAX_HEADER_BYTES => {
                bail!("HTTP proxy request header is too large")
            }
            None => {}
        }
    }
}

/// The parts of an HTTP request line that the bridge uses.
#[derive(Debug, Eq, PartialEq)]
struct HttpRequest {
    /// First token of the request line (e.g. `CONNECT`), stored as received.
    method: String,
    /// Second token of the request line; for `CONNECT` requests it is parsed as `host:port`.
    target: String,
}

/// Extracts the method and target from the request line of an HTTP request head.
///
/// # Errors
///
/// Fails when `head` is not valid UTF-8, has no first line, when the request line lacks a
/// method, target or version token, or when the version token does not start with `HTTP/`.
fn parse_http_request(head: &[u8]) -> Result<HttpRequest> {
    let decoded = str::from_utf8(head).context("HTTP request header is not valid UTF-8")?;
    let request_line = decoded.lines().next().context("HTTP request is empty")?;

    // The request line is `<method> <target> <version>`, separated by whitespace.
    let mut tokens = request_line.split_whitespace();
    let (method, target, version) = (
        tokens.next().context("HTTP request is missing method")?,
        tokens.next().context("HTTP request is missing target")?,
        tokens.next().context("HTTP request is missing version")?,
    );

    if version.starts_with("HTTP/") {
        return Ok(HttpRequest {
            method: method.to_owned(),
            target: target.to_owned(),
        });
    }
    bail!("invalid HTTP proxy request version: {version}")
}

/// Returns the index just past the first `\r\n\r\n` in `buffer`, or `None` when the buffer
/// contains no header terminator.
fn find_header_end(buffer: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    let start = buffer
        .windows(TERMINATOR.len())
        .position(|candidate| candidate == TERMINATOR)?;
    Some(start + TERMINATOR.len())
}

/// Returns a copy of the request head with a `Proxy-Authorization` header appended as the
/// last header field.
///
/// The head is returned unchanged when `authorization` is `None`, when it already carries a
/// `Proxy-Authorization` header (matched case-insensitively), or when it contains no
/// `\r\n\r\n` terminator.
///
/// All searching is done on the raw bytes, so the insertion index always refers to `head`
/// itself, even when header values are not valid UTF-8.
fn add_proxy_authorization(head: &[u8], authorization: Option<&str>) -> Vec<u8> {
    const EXISTING_HEADER: &[u8] = b"\r\nproxy-authorization:";
    const HEADER_PREFIX: &[u8] = b"\r\nProxy-Authorization: ";
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    let Some(authorization) = authorization else {
        return head.to_vec();
    };

    // Never override credentials the client supplied itself.
    let already_present = head
        .windows(EXISTING_HEADER.len())
        .any(|window| window.eq_ignore_ascii_case(EXISTING_HEADER));
    if already_present {
        return head.to_vec();
    }

    // Insert in front of the last terminator so the new field ends the header block.
    let Some(insert_at) = head
        .windows(TERMINATOR.len())
        .rposition(|window| window == TERMINATOR)
    else {
        return head.to_vec();
    };
    let (fields, terminator) = head.split_at(insert_at);

    let mut outgoing =
        Vec::with_capacity(head.len() + HEADER_PREFIX.len() + authorization.len());
    outgoing.extend_from_slice(fields);
    outgoing.extend_from_slice(HEADER_PREFIX);
    outgoing.extend_from_slice(authorization.as_bytes());
    outgoing.extend_from_slice(terminator);
    outgoing
}

/// Splits a `CONNECT` request target into host and port.
///
/// Two forms are accepted: `[host]:port`, where the brackets are stripped from the returned
/// host, and `host:port`, which is split at the last `:`.
///
/// # Errors
///
/// Fails when the bracket is unterminated, the port separator is missing, the host is empty
/// (in either form), or the port is not a valid `u16`.
fn parse_host_port(value: &str) -> Result<(String, u16)> {
    if let Some(rest) = value.strip_prefix('[') {
        let (host, tail) = rest
            .split_once(']')
            .context("invalid bracketed IPv6 CONNECT target")?;
        // Keep both forms consistent: `[]:443` is as meaningless as `:443`.
        if host.is_empty() {
            bail!("CONNECT target host cannot be empty");
        }
        let port = tail
            .strip_prefix(':')
            .context("IPv6 CONNECT target is missing port")?
            .parse()
            .context("invalid CONNECT target port")?;
        return Ok((host.to_string(), port));
    }

    // Split at the last colon so everything before it is treated as the host.
    let (host, port) = value
        .rsplit_once(':')
        .context("CONNECT target must be host:port")?;
    if host.is_empty() {
        bail!("CONNECT target host cannot be empty");
    }

    Ok((
        host.to_string(),
        port.parse().context("invalid CONNECT target port")?,
    ))
}

/// Connects to the SOCKS5 proxy and asks it to open a tunnel to `target_host:target_port`.
///
/// Steps, in order: TCP connect to the proxy, method negotiation (offering username/password
/// in addition to "no authentication" when `proxy.has_auth()`), optional authentication, the
/// CONNECT request, and consumption of the proxy's reply including its bound address.
///
/// # Errors
///
/// Fails on any I/O error, on a reply with the wrong protocol version, when the proxy
/// rejects or picks an unsupported authentication method, when authentication fails, or
/// when the proxy answers the CONNECT request with a non-zero status.
async fn connect_via_socks5(
    proxy: &UpstreamProxy,
    target_host: &str,
    target_port: u16,
) -> Result<TcpStream> {
    let mut stream = TcpStream::connect((proxy.host(), proxy.port()))
        .await
        .with_context(|| {
            format!(
                "failed to connect upstream SOCKS5 proxy {}",
                proxy.authority()
            )
        })?;

    // Greeting: version 5, method count, then the offered method ids
    // (0x00 = no authentication, 0x02 = username/password).
    let greeting: &[u8] = if proxy.has_auth() {
        &[0x05, 0x02, 0x00, 0x02]
    } else {
        &[0x05, 0x01, 0x00]
    };
    stream.write_all(greeting).await?;

    let mut selection = [0_u8; 2];
    stream.read_exact(&mut selection).await?;
    let [version, method] = selection;
    if version != 0x05 {
        bail!("invalid SOCKS5 method response");
    }
    match method {
        0x00 => {}
        0x02 => authenticate_socks5(proxy, &mut stream).await?,
        0xff => bail!("SOCKS5 proxy rejected all authentication methods"),
        _ => bail!("SOCKS5 proxy selected unsupported authentication method {method:#x}"),
    }

    let connect_request = build_socks5_connect_request(target_host, target_port)?;
    stream.write_all(&connect_request).await?;

    // Reply header: version, status, reserved byte, bound-address type.
    let mut reply = [0_u8; 4];
    stream.read_exact(&mut reply).await?;
    let [version, status, _reserved, address_type] = reply;
    if version != 0x05 {
        bail!("invalid SOCKS5 connect response");
    }
    if status != 0x00 {
        bail!("SOCKS5 connect failed with code {:#x}", status);
    }

    // The bound address is not used, but it must be drained so the stream is positioned at
    // the start of the tunnelled data.
    read_socks5_bound_address(&mut stream, address_type).await?;
    Ok(stream)
}

/// Performs the SOCKS5 username/password sub-negotiation on `stream`.
///
/// A missing username or password is sent as an empty field.
///
/// # Errors
///
/// Fails when either credential is longer than 255 bytes (its length must fit in one byte),
/// on I/O errors, or when the proxy's two-byte reply is anything but `[0x01, 0x00]`.
async fn authenticate_socks5(proxy: &UpstreamProxy, stream: &mut TcpStream) -> Result<()> {
    let username = proxy.username().unwrap_or_default().as_bytes();
    let password = proxy.password().unwrap_or_default().as_bytes();

    let longest = username.len().max(password.len());
    if longest > usize::from(u8::MAX) {
        bail!("SOCKS5 username and password must be at most 255 bytes");
    }

    // Layout: sub-negotiation version 1, then length-prefixed username and password.
    // The casts cannot truncate: both lengths were checked against `u8::MAX` above.
    let negotiation = [
        &[0x01, username.len() as u8][..],
        username,
        &[password.len() as u8][..],
        password,
    ]
    .concat();
    stream.write_all(&negotiation).await?;

    let mut reply = [0_u8; 2];
    stream.read_exact(&mut reply).await?;
    match reply {
        [0x01, 0x00] => Ok(()),
        _ => bail!("SOCKS5 username/password authentication failed"),
    }
}

/// Encodes a SOCKS5 CONNECT request for `target_host:target_port`.
///
/// The host is sent as an IPv4 address (type `0x01`) or IPv6 address (type `0x04`) when it
/// parses as one, otherwise as a length-prefixed domain name (type `0x03`). The port follows
/// in big-endian byte order.
///
/// # Errors
///
/// Fails when a domain-name host is longer than 255 bytes.
fn build_socks5_connect_request(target_host: &str, target_port: u16) -> Result<Vec<u8>> {
    // Fixed prefix: version 5, command 1 (CONNECT), reserved 0.
    let mut packet: Vec<u8> = Vec::from([0x05, 0x01, 0x00]);

    if let Ok(ip) = target_host.parse::<IpAddr>() {
        match ip {
            IpAddr::V4(v4) => {
                packet.push(0x01);
                packet.extend_from_slice(&v4.octets());
            }
            IpAddr::V6(v6) => {
                packet.push(0x04);
                packet.extend_from_slice(&v6.octets());
            }
        }
    } else {
        let name = target_host.as_bytes();
        // The domain length is encoded in a single byte.
        if name.len() > usize::from(u8::MAX) {
            bail!("SOCKS5 target host is too long");
        }
        packet.push(0x03);
        packet.push(name.len() as u8);
        packet.extend_from_slice(name);
    }

    packet.extend_from_slice(&target_port.to_be_bytes());
    Ok(packet)
}

/// Reads and discards the bound address and port that follow a SOCKS5 reply header.
///
/// `address_type` is the type byte from the reply: `0x01` (4 address bytes), `0x03` (one
/// length byte followed by that many bytes) or `0x04` (16 address bytes). Two port bytes
/// follow the address in every case.
///
/// # Errors
///
/// Fails on an unknown address type or when the stream ends before all bytes are read.
async fn read_socks5_bound_address(stream: &mut TcpStream, address_type: u8) -> Result<()> {
    const PORT_BYTES: usize = 2;

    let remaining = match address_type {
        0x01 => 4 + PORT_BYTES,
        0x03 => {
            // Domain form: the first byte states how long the name is.
            let mut name_length = [0_u8; 1];
            stream.read_exact(&mut name_length).await?;
            usize::from(name_length[0]) + PORT_BYTES
        }
        0x04 => 16 + PORT_BYTES,
        other => bail!("invalid SOCKS5 address type {other:#x}"),
    };

    let mut discarded = vec![0_u8; remaining];
    stream.read_exact(&mut discarded).await?;
    Ok(())
}

/// Writes a minimal HTTP/1.1 error response to `stream`.
///
/// `message` is used both as the reason phrase of the status line and as the response body;
/// the response carries a matching `Content-Length` and `Connection: close`.
///
/// # Errors
///
/// Fails when writing to the stream fails.
async fn write_proxy_error(stream: &mut TcpStream, code: u16, message: &str) -> Result<()> {
    let body_length = message.len();
    let payload = format!(
        "HTTP/1.1 {code} {message}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{message}",
        body_length
    );
    stream
        .write_all(payload.as_bytes())
        .await
        .map_err(Into::into)
}

/// Unit tests for the pure (non-I/O) helpers of the bridge.
#[cfg(test)]
mod tests {
    use super::*;

    /// Both the plain `host:port` form and the bracketed `[ipv6]:port` form are accepted;
    /// the brackets are not part of the returned host.
    #[test]
    fn parses_connect_targets() {
        assert_eq!(
            parse_host_port("discord.com:443").unwrap(),
            ("discord.com".to_string(), 443)
        );
        assert_eq!(
            parse_host_port("[::1]:443").unwrap(),
            ("::1".to_string(), 443)
        );
    }

    /// The header is added as its own line and the head still ends with the blank line.
    #[test]
    fn injects_proxy_authorization_header() {
        let head = b"CONNECT discord.com:443 HTTP/1.1\r\nHost: discord.com:443\r\n\r\n";

        let outgoing = add_proxy_authorization(head, Some("Basic abc"));
        let text = String::from_utf8(outgoing).unwrap();

        assert!(text.contains("\r\nProxy-Authorization: Basic abc\r\n"));
        assert!(text.ends_with("\r\n\r\n"));
    }

    /// A `Proxy-Authorization` header already present in the head is kept as is.
    #[test]
    fn does_not_duplicate_proxy_authorization_header() {
        let head = b"CONNECT discord.com:443 HTTP/1.1\r\nProxy-Authorization: Basic old\r\n\r\n";

        let outgoing = add_proxy_authorization(head, Some("Basic new"));

        assert_eq!(outgoing, head);
    }

    /// A host that is not an IP address is encoded as a length-prefixed domain name
    /// (address type `0x03`) followed by the big-endian port.
    #[test]
    fn builds_domain_socks_connect_request() {
        let request = build_socks5_connect_request("discord.com", 443).unwrap();

        assert_eq!(&request[..5], &[0x05, 0x01, 0x00, 0x03, 11]);
        assert_eq!(&request[5..16], b"discord.com");
        assert_eq!(&request[16..], &443_u16.to_be_bytes());
    }
}