childflow 0.1.0

Forces DNS/proxy/interface for a child process tree and captures only its packets
use std::io::{Read, Write};
use std::net::{Shutdown, SocketAddr, SocketAddrV4, TcpListener, TcpStream, ToSocketAddrs};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::Duration;

use anyhow::{anyhow, bail, Context, Result};
use socket2::{Domain, Protocol, Socket, Type};

use crate::cli::ProxyType;

/// Address of the upstream proxy that intercepted connections are tunnelled through.
///
/// `host` is passed verbatim to the system resolver for every outbound
/// connection, so it may be a hostname or an IP literal.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ProxyServer {
    /// Hostname or IP address of the proxy.
    pub host: String,
    /// TCP port of the proxy.
    pub port: u16,
}

/// Everything needed to open a tunnel to an intercepted destination through
/// the upstream proxy.
#[derive(Clone)]
pub struct ProxyUpstreamConfig {
    /// Proxy endpoint to dial for every intercepted connection.
    pub server: ProxyServer,
    /// Handshake used to request the tunnel (HTTP `CONNECT` or SOCKS5 `CONNECT`).
    pub kind: ProxyType,
    /// When set, outbound proxy sockets are pinned to this network interface
    /// (via `Socket::bind_device`) before connecting.
    pub bind_interface: Option<String>,
}

pub struct TproxyHandle {
    stop: Arc<AtomicBool>,
    join: Option<JoinHandle<Result<()>>>,
    listen_port: u16,
}

impl TproxyHandle {
    pub fn start(upstream: ProxyUpstreamConfig) -> Result<Self> {
        let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))
            .context("failed to create transparent listener socket")?;
        socket.set_reuse_address(true).ok();
        socket.set_reuse_port(true).ok();
        socket
            .set_ip_transparent_v4(true)
            .context("failed to enable IP_TRANSPARENT on listener socket")?;
        let bind_addr = SocketAddr::from(([0, 0, 0, 0], 0));
        socket
            .bind(&bind_addr.into())
            .context("failed to bind transparent listener")?;
        socket
            .listen(1024)
            .context("failed to listen on transparent socket")?;

        let listener: TcpListener = socket.into();
        listener
            .set_nonblocking(true)
            .context("failed to set transparent listener nonblocking")?;
        let listen_port = listener
            .local_addr()
            .context("failed to query transparent listener local address")?
            .port();

        let stop = Arc::new(AtomicBool::new(false));
        let stop_for_thread = Arc::clone(&stop);
        let join = thread::spawn(move || accept_loop(listener, upstream, stop_for_thread));

        Ok(Self {
            stop,
            join: Some(join),
            listen_port,
        })
    }

    pub fn listen_port(&self) -> u16 {
        self.listen_port
    }

    fn stop_and_join(&mut self) {
        self.stop.store(true, Ordering::Relaxed);
        if let Some(join) = self.join.take() {
            let _ = join.join();
        }
    }
}

impl Drop for TproxyHandle {
    fn drop(&mut self) {
        self.stop_and_join();
    }
}

/// Accepts intercepted connections until `stop` is set, relaying each one on
/// its own detached thread.
///
/// The listener is non-blocking so the loop can observe `stop`. When no
/// connection is pending, it sleeps 100 ms before polling again, which bounds
/// how long shutdown waits.
///
/// # Errors
/// Returns an error only for accept failures that are not transient. An
/// interrupted syscall, or a client that aborted/reset before being accepted,
/// is skipped instead of tearing down the listener.
fn accept_loop(
    listener: TcpListener,
    upstream: ProxyUpstreamConfig,
    stop: Arc<AtomicBool>,
) -> Result<()> {
    use std::io::ErrorKind;

    while !stop.load(Ordering::Relaxed) {
        match listener.accept() {
            Ok((stream, _peer)) => {
                // Some platforms (e.g. BSD-derived ones) propagate the listener's
                // non-blocking flag to accepted sockets. The relay uses blocking
                // `io::copy`, so force blocking mode; drop the connection if we can't.
                if stream.set_nonblocking(false).is_err() {
                    continue;
                }
                let upstream = upstream.clone();
                thread::spawn(move || {
                    // Per-connection failures are best-effort. Dropping `stream`
                    // closes the client side.
                    let _ = handle_connection(stream, &upstream);
                });
            }
            Err(err) if err.kind() == ErrorKind::WouldBlock => {
                thread::sleep(Duration::from_millis(100));
            }
            // Transient conditions that concern a single pending connection (or
            // none at all). They must not stop interception for everyone else.
            Err(err)
                if matches!(
                    err.kind(),
                    ErrorKind::Interrupted
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::ConnectionReset
                ) => {}
            Err(err) => {
                return Err(err).context("transparent listener accept failed");
            }
        }
    }

    Ok(())
}

fn handle_connection(mut inbound: TcpStream, upstream: &ProxyUpstreamConfig) -> Result<()> {
    let original_dst = inbound
        .local_addr()
        .context("failed to query original destination from transparent socket")?;

    let target = match original_dst {
        SocketAddr::V4(v4) => v4,
        SocketAddr::V6(_) => bail!("IPv6 interception is not implemented in this PoC"),
    };

    let mut outbound = connect_upstream_proxy(upstream, target)?;
    relay_bidirectional(&mut inbound, &mut outbound)?;
    Ok(())
}

fn connect_upstream_proxy(
    upstream: &ProxyUpstreamConfig,
    target: SocketAddrV4,
) -> Result<TcpStream> {
    let upstream_addr = resolve_proxy_server(&upstream.server)?;
    let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))
        .context("failed to create outbound proxy socket")?;

    if let Some(iface) = &upstream.bind_interface {
        socket
            .bind_device(Some(iface.as_bytes()))
            .with_context(|| {
                format!("failed to bind outbound proxy socket to interface {iface}")
            })?;
    }

    socket
        .connect(&SocketAddr::V4(upstream_addr).into())
        .with_context(|| {
            format!(
                "failed to connect to upstream proxy {}",
                render_proxy_server(&upstream.server)
            )
        })?;

    let mut stream: TcpStream = socket.into();
    stream
        .set_nodelay(true)
        .context("failed to enable TCP_NODELAY on outbound proxy socket")?;

    match upstream.kind {
        ProxyType::Http => negotiate_http_connect(&mut stream, target)?,
        ProxyType::Socks5 => negotiate_socks5_connect(&mut stream, target)?,
    }

    Ok(stream)
}

fn resolve_proxy_server(server: &ProxyServer) -> Result<SocketAddrV4> {
    let addrs = (server.host.as_str(), server.port)
        .to_socket_addrs()
        .with_context(|| {
            format!(
                "failed to resolve upstream proxy {}",
                render_proxy_server(server)
            )
        })?;

    for addr in addrs {
        if let SocketAddr::V4(v4) = addr {
            return Ok(v4);
        }
    }

    bail!(
        "upstream proxy {} did not resolve to an IPv4 address",
        render_proxy_server(server)
    );
}

/// Formats a proxy endpoint as `host:port` for diagnostics.
///
/// IPv6 literals are wrapped in brackets (`[::1]:1080`) so the port cannot be
/// mistaken for the last address group. Hosts that are already bracketed, and
/// hostnames or IPv4 addresses, are printed unchanged.
fn render_proxy_server(server: &ProxyServer) -> String {
    let host = server.host.as_str();
    if host.contains(':') && !host.starts_with('[') {
        format!("[{host}]:{}", server.port)
    } else {
        format!("{host}:{}", server.port)
    }
}

/// Performs an HTTP `CONNECT` handshake for `target` on an already-connected
/// proxy stream.
///
/// On success the stream is positioned at the first tunnelled byte (the
/// response header section has been consumed).
///
/// # Errors
/// Fails on I/O errors or if the proxy's status line is not an HTTP/1.x 2xx.
fn negotiate_http_connect(stream: &mut TcpStream, target: SocketAddrV4) -> Result<()> {
    // `SocketAddrV4` displays as `a.b.c.d:port`, the authority-form CONNECT expects.
    let authority = target.to_string();
    let request = format!(
        "CONNECT {authority} HTTP/1.1\r\nHost: {authority}\r\nProxy-Connection: Keep-Alive\r\n\r\n"
    );
    stream
        .write_all(request.as_bytes())
        .context("failed to write HTTP CONNECT request")?;
    stream
        .flush()
        .context("failed to flush HTTP CONNECT request")?;

    let response = read_headers(stream)?;
    let first_line = response.lines().next().unwrap_or_default();
    if !is_successful_connect_status(first_line) {
        bail!("upstream HTTP proxy rejected CONNECT: {first_line}");
    }

    Ok(())
}

/// Returns `true` when `status_line` is an HTTP/1.x status line with a
/// three-digit 2xx code.
///
/// Per RFC 9110 §9.3.6, any 2xx response to CONNECT means the tunnel is
/// established, not only `200`.
fn is_successful_connect_status(status_line: &str) -> bool {
    let mut parts = status_line.split_whitespace();
    let version_ok = matches!(parts.next(), Some(version) if version.starts_with("HTTP/1."));
    let code = parts.next().unwrap_or_default();
    version_ok
        && code.len() == 3
        && code.starts_with('2')
        && code.bytes().all(|b| b.is_ascii_digit())
}

/// Performs an unauthenticated SOCKS5 `CONNECT` (RFC 1928) for `target` on an
/// already-connected proxy stream.
///
/// On success the bound address in the reply has been drained, so the stream
/// is positioned at the first tunnelled byte.
///
/// # Errors
/// Fails on I/O errors, if the peer does not answer as SOCKS5, if it refuses
/// the no-authentication method, if CONNECT is rejected, or if the reply uses
/// an unknown address type.
fn negotiate_socks5_connect(stream: &mut TcpStream, target: SocketAddrV4) -> Result<()> {
    const VERSION: u8 = 0x05;
    const METHOD_NO_AUTH: u8 = 0x00;
    const CMD_CONNECT: u8 = 0x01;
    const REPLY_SUCCEEDED: u8 = 0x00;
    const ATYP_IPV4: u8 = 0x01;
    const ATYP_DOMAIN: u8 = 0x03;
    const ATYP_IPV6: u8 = 0x04;

    // Greeting: VER, NMETHODS = 1, METHODS = [no authentication].
    stream
        .write_all(&[VERSION, 0x01, METHOD_NO_AUTH])
        .context("failed to write SOCKS5 greeting")?;
    let mut method_reply = [0_u8; 2];
    stream
        .read_exact(&mut method_reply)
        .context("failed to read SOCKS5 method reply")?;
    if method_reply[0] != VERSION {
        bail!(
            "upstream proxy is not speaking SOCKS5 (method reply version 0x{:02x})",
            method_reply[0]
        );
    }
    if method_reply[1] != METHOD_NO_AUTH {
        bail!("SOCKS5 proxy does not allow unauthenticated CONNECT");
    }

    // Request: VER, CMD, RSV, ATYP = IPv4, DST.ADDR (4 bytes), DST.PORT (network order).
    let [a, b, c, d] = target.ip().octets();
    let [port_hi, port_lo] = target.port().to_be_bytes();
    let request = [VERSION, CMD_CONNECT, 0x00, ATYP_IPV4, a, b, c, d, port_hi, port_lo];
    stream
        .write_all(&request)
        .context("failed to write SOCKS5 CONNECT request")?;

    // Reply header: VER, REP, RSV, ATYP.
    let mut header = [0_u8; 4];
    stream
        .read_exact(&mut header)
        .context("failed to read SOCKS5 CONNECT reply header")?;
    if header[0] != VERSION {
        bail!(
            "SOCKS5 CONNECT reply has unexpected version 0x{:02x}",
            header[0]
        );
    }
    if header[1] != REPLY_SUCCEEDED {
        bail!(
            "SOCKS5 CONNECT failed with reply code 0x{:02x} ({})",
            header[1],
            socks5_reply_description(header[1])
        );
    }

    // Drain BND.ADDR + BND.PORT so the relay starts at the first tunnelled byte.
    match header[3] {
        ATYP_IPV4 => {
            let mut rest = [0_u8; 4 + 2];
            stream
                .read_exact(&mut rest)
                .context("failed to read SOCKS5 IPv4 reply tail")?;
        }
        ATYP_DOMAIN => {
            let mut len = [0_u8; 1];
            stream
                .read_exact(&mut len)
                .context("failed to read SOCKS5 domain length")?;
            let mut rest = vec![0_u8; usize::from(len[0]) + 2];
            stream
                .read_exact(&mut rest)
                .context("failed to read SOCKS5 domain reply tail")?;
        }
        ATYP_IPV6 => {
            let mut rest = [0_u8; 16 + 2];
            stream
                .read_exact(&mut rest)
                .context("failed to read SOCKS5 IPv6 reply tail")?;
        }
        other => bail!("unsupported SOCKS5 ATYP in reply: 0x{other:02x}"),
    }

    Ok(())
}

/// Human-readable meaning of a SOCKS5 reply code (RFC 1928 §6).
fn socks5_reply_description(code: u8) -> &'static str {
    match code {
        0x00 => "succeeded",
        0x01 => "general SOCKS server failure",
        0x02 => "connection not allowed by ruleset",
        0x03 => "network unreachable",
        0x04 => "host unreachable",
        0x05 => "connection refused",
        0x06 => "TTL expired",
        0x07 => "command not supported",
        0x08 => "address type not supported",
        _ => "unassigned",
    }
}

fn read_headers(stream: &mut TcpStream) -> Result<String> {
    let mut buf = Vec::with_capacity(1024);
    let mut chunk = [0_u8; 256];
    loop {
        let n = stream
            .read(&mut chunk)
            .context("failed to read proxy response")?;
        if n == 0 {
            bail!("proxy closed connection while reading response headers");
        }
        buf.extend_from_slice(&chunk[..n]);
        if buf.windows(4).any(|window| window == b"\r\n\r\n") {
            break;
        }
        if buf.len() > 16 * 1024 {
            bail!("proxy response headers are too large");
        }
    }

    String::from_utf8(buf).map_err(|err| anyhow!("proxy response is not valid UTF-8: {err}"))
}

fn relay_bidirectional(left: &mut TcpStream, right: &mut TcpStream) -> Result<()> {
    let mut left_reader = left.try_clone().context("failed to clone inbound stream")?;
    let mut left_writer = left
        .try_clone()
        .context("failed to clone inbound stream writer")?;
    let mut right_reader = right
        .try_clone()
        .context("failed to clone outbound stream")?;
    let mut right_writer = right
        .try_clone()
        .context("failed to clone outbound stream writer")?;

    let client_to_proxy = thread::spawn(move || -> std::io::Result<u64> {
        let copied = std::io::copy(&mut left_reader, &mut right_writer)?;
        let _ = right_writer.shutdown(Shutdown::Write);
        Ok(copied)
    });
    let proxy_to_client = thread::spawn(move || -> std::io::Result<u64> {
        let copied = std::io::copy(&mut right_reader, &mut left_writer)?;
        let _ = left_writer.shutdown(Shutdown::Write);
        Ok(copied)
    });

    client_to_proxy
        .join()
        .map_err(|_| anyhow!("client->proxy relay thread panicked"))?
        .context("client->proxy relay failed")?;
    proxy_to_client
        .join()
        .map_err(|_| anyhow!("proxy->client relay thread panicked"))?
        .context("proxy->client relay failed")?;

    Ok(())
}