use super::Filter;
#[cfg(test)]
use super::pick_ephemeral_port;
use super::relay;
use anyhow::{Context, Result, bail};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tracing::{debug, warn};
/// Limit on how many request-head bytes are buffered from a client before the
/// request is rejected (8 KiB).
const MAX_REQUEST_BYTES: usize = 8192;
/// Deadline for a client to deliver its complete request head.
const REQUEST_READ_TIMEOUT: Duration = Duration::from_secs(10);
#[derive(Debug)]
pub struct Server {
listener: TcpListener,
filter: Filter,
port: u16,
upstream: crate::proxy::upstream::UpstreamConfig,
}
impl Server {
pub async fn bind(port: Option<u16>, filter: Filter) -> Result<Self> {
let bind_port = port.unwrap_or(0);
let listener = TcpListener::bind(("127.0.0.1", bind_port))
.await
.with_context(|| format!("bind built-in proxy on 127.0.0.1:{bind_port}"))?;
let actual = listener
.local_addr()
.context("read local_addr from listener")?
.port();
debug!(
"built-in proxy bound: port={} filter_size={}",
actual,
filter.len()
);
let upstream = crate::proxy::upstream::UpstreamConfig::from_env();
if !matches!(upstream, crate::proxy::upstream::UpstreamConfig::Direct) {
debug!("built-in proxy chaining upstream via {upstream:?}");
}
Ok(Self {
listener,
filter,
port: actual,
upstream,
})
}
pub fn port(&self) -> u16 {
self.port
}
pub fn with_upstream(mut self, upstream: crate::proxy::upstream::UpstreamConfig) -> Self {
self.upstream = upstream;
self
}
pub async fn serve(self) {
let filter = self.filter;
let upstream = std::sync::Arc::new(self.upstream);
loop {
let (sock, peer) = match self.listener.accept().await {
Ok(t) => t,
Err(e) => {
warn!("built-in proxy accept failed: {e}");
return;
}
};
let f = filter.clone();
let up = std::sync::Arc::clone(&upstream);
tokio::spawn(async move {
if let Err(e) = handle_one(sock, &f, &up).await {
debug!("proxy connection from {peer} ended: {e:#}");
}
});
}
}
}
/// Serves one proxy client end to end.
///
/// Reads and parses the request head, then answers:
/// * `405 Method Not Allowed` for any method other than `CONNECT`;
/// * `403 Forbidden` when `filter` rejects the target;
/// * `502 Bad Gateway` when the upstream leg cannot be opened via `upstream`;
/// * otherwise `200 Connection Established`, after which both sockets are handed to
///   `relay::relay_with_timeouts` with the default idle and total timeouts.
///
/// # Errors
///
/// Returns an error when the request cannot be read or parsed, when a status line
/// cannot be written, or when bytes the client pipelined after the request head
/// cannot be forwarded upstream. The relay's own result is ignored.
async fn handle_one(
    mut client: TcpStream,
    filter: &Filter,
    upstream: &crate::proxy::upstream::UpstreamConfig,
) -> Result<()> {
    let req = read_request(&mut client).await?;
    let (method, target) = parse_request_line(&req)?;
    if method != "CONNECT" {
        write_status(&mut client, 405, "Method Not Allowed").await?;
        return Ok(());
    }
    if !filter.allows(&target) {
        write_status(&mut client, 403, "Forbidden").await?;
        debug!("proxy: blocked CONNECT {target} (not in allowlist)");
        return Ok(());
    }
    let mut upstream_sock =
        match crate::proxy::upstream::connect_upstream(&target, upstream).await {
            Ok(s) => s,
            Err(e) => {
                warn!("proxy: upstream connect to {target} failed: {e:#}");
                write_status(&mut client, 502, "Bad Gateway").await?;
                return Ok(());
            }
        };
    write_status(&mut client, 200, "Connection Established").await?;
    // `read_request` reads in 1 KiB chunks and stops once it has seen the end of
    // the head, so `req` may already hold tunnel bytes the client sent eagerly
    // (e.g. a TLS ClientHello). Forward them first; otherwise they would be lost.
    let pipelined = req
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .map_or(&[][..], |end| &req[end + 4..]);
    if !pipelined.is_empty() {
        upstream_sock
            .write_all(pipelined)
            .await
            .context("forward pipelined client bytes upstream")?;
    }
    let _ = relay::relay_with_timeouts(
        client,
        upstream_sock,
        relay::DEFAULT_IDLE_TIMEOUT,
        relay::DEFAULT_TOTAL_TIMEOUT,
    )
    .await;
    Ok(())
}
async fn read_request(client: &mut TcpStream) -> Result<Vec<u8>> {
let mut buf = Vec::with_capacity(512);
let read = tokio::time::timeout(REQUEST_READ_TIMEOUT, async {
let mut chunk = [0u8; 1024];
loop {
let n = client.read(&mut chunk).await?;
if n == 0 {
bail!("client closed before sending request");
}
buf.extend_from_slice(&chunk[..n]);
if buf.windows(4).any(|w| w == b"\r\n\r\n") {
return Ok::<_, anyhow::Error>(());
}
if buf.len() > MAX_REQUEST_BYTES {
bail!("request headers exceed {} bytes", MAX_REQUEST_BYTES);
}
}
})
.await;
match read {
Ok(Ok(())) => Ok(buf),
Ok(Err(e)) => Err(e),
Err(_) => bail!("client request timed out after {REQUEST_READ_TIMEOUT:?}"),
}
}
/// Splits the request line of a raw HTTP request into `(method, target)`.
///
/// Only the bytes before the first CRLF are examined. They must be UTF-8, and the
/// remainder after the second space must begin with `HTTP/`.
///
/// # Errors
///
/// Returns a `malformed request: …` error when there is no CRLF, the line is not
/// UTF-8, the target is absent, or the version field does not start with `HTTP/`.
fn parse_request_line(req: &[u8]) -> Result<(String, String)> {
    let crlf_at = req
        .windows(2)
        .position(|pair| pair == b"\r\n")
        .context("malformed request: no CRLF after request line")?;
    let request_line = std::str::from_utf8(&req[..crlf_at])
        .context("malformed request: request line not UTF-8")?;
    let mut fields = request_line.splitn(3, ' ');
    let method = fields
        .next()
        .map(str::to_owned)
        .context("malformed request: missing method")?;
    let target = fields
        .next()
        .map(str::to_owned)
        .context("malformed request: missing target")?;
    match fields.next() {
        Some(version) if version.starts_with("HTTP/") => Ok((method, target)),
        other => {
            let version = other.unwrap_or("");
            bail!("malformed request: unexpected version {version:?}")
        }
    }
}
/// Writes a minimal HTTP/1.1 response (the status line followed by the terminating
/// blank line, with no headers), then flushes `client`.
///
/// # Errors
///
/// Fails only if the write itself fails; a flush error is deliberately ignored.
async fn write_status(client: &mut TcpStream, code: u16, reason: &str) -> Result<()> {
    let status_line = format!("HTTP/1.1 {code} {reason}\r\n\r\n");
    let written = client.write_all(status_line.as_bytes()).await;
    written.with_context(|| format!("write {code} response"))?;
    let _ = client.flush().await;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proxy::upstream::UpstreamConfig;
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};

    // `TcpListener` / `TcpStream` come from `super::*` and are tokio's async types.

    /// Spawns a loopback server that writes `body` to each accepted connection and
    /// then shuts it down. Returns the port it listens on.
    async fn fake_upstream(body: &'static str) -> u16 {
        let l = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = l.local_addr().unwrap().port();
        tokio::spawn(async move {
            while let Ok((mut sock, _)) = l.accept().await {
                let _ = sock.write_all(body.as_bytes()).await;
                let _ = sock.shutdown().await;
            }
        });
        port
    }

    /// Sends `CONNECT target` through the proxy and returns the response's status
    /// line together with every byte read after it until EOF.
    async fn do_connect(proxy_port: u16, target: &str) -> (String, Vec<u8>) {
        let mut sock = TcpStream::connect(("127.0.0.1", proxy_port)).await.unwrap();
        let req = format!("CONNECT {target} HTTP/1.1\r\nHost: {target}\r\n\r\n");
        sock.write_all(req.as_bytes()).await.unwrap();
        let (r, _w) = sock.split();
        let mut reader = BufReader::new(r);
        let mut status_line = String::new();
        reader.read_line(&mut status_line).await.unwrap();
        let mut rest = Vec::new();
        let _ = reader.read_to_end(&mut rest).await;
        (status_line, rest)
    }

    /// Binds a proxy that dials every target directly, so the outcome does not
    /// depend on whatever `UpstreamConfig::from_env` finds in the test environment.
    async fn direct_server(filter: Filter) -> Server {
        Server::bind(None, filter)
            .await
            .unwrap()
            .with_upstream(UpstreamConfig::Direct)
    }

    #[tokio::test]
    async fn server_bind_uses_ephemeral_port_when_none() {
        let s = Server::bind(None, Filter::default()).await.unwrap();
        let p = s.port();
        assert!(p > 0, "ephemeral port must be non-zero");
    }

    #[tokio::test]
    async fn rejects_non_connect_with_405() {
        let server = direct_server(Filter::new(["github.com"]).unwrap()).await;
        let port = server.port();
        tokio::spawn(server.serve());
        let mut sock = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
        sock.write_all(b"GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
            .await
            .unwrap();
        let mut buf = String::new();
        BufReader::new(sock).read_line(&mut buf).await.unwrap();
        assert!(buf.starts_with("HTTP/1.1 405"), "got: {buf:?}");
    }

    #[tokio::test]
    async fn rejects_disallowed_host_with_403() {
        let server = direct_server(Filter::new(["github.com"]).unwrap()).await;
        let port = server.port();
        tokio::spawn(server.serve());
        let (status, _) = do_connect(port, "evil.example.com:443").await;
        assert!(status.starts_with("HTTP/1.1 403"), "got: {status:?}");
    }

    #[tokio::test]
    async fn allows_listed_host_and_tunnels_payload() {
        let payload = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
        let up_port = fake_upstream(payload).await;
        let server = direct_server(Filter::new(["127.0.0.1"]).unwrap()).await;
        let proxy_port = server.port();
        tokio::spawn(server.serve());
        let (status, body) = do_connect(proxy_port, &format!("127.0.0.1:{up_port}")).await;
        assert!(status.starts_with("HTTP/1.1 200"), "got: {status:?}");
        let body_str = String::from_utf8_lossy(&body);
        let body_str = body_str.trim_start_matches("\r\n");
        assert_eq!(body_str, payload, "tunneled body mismatch");
    }

    #[tokio::test]
    async fn returns_502_when_upstream_unreachable() {
        let server = direct_server(Filter::new(["127.0.0.1"]).unwrap()).await;
        let proxy_port = server.port();
        tokio::spawn(server.serve());
        let dead = pick_ephemeral_port().unwrap();
        let (status, _) = do_connect(proxy_port, &format!("127.0.0.1:{dead}")).await;
        assert!(status.starts_with("HTTP/1.1 502"), "got: {status:?}");
    }

    #[tokio::test]
    async fn chains_through_upstream_http_proxy() {
        let payload = "PAYLOAD-FROM-REAL-UPSTREAM";
        let real_port = fake_upstream(payload).await;
        let connect_log = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
        let corp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let corp_port = corp_listener.local_addr().unwrap().port();
        let log_for_corp = std::sync::Arc::clone(&connect_log);
        tokio::spawn(async move {
            while let Ok((mut sock, _)) = corp_listener.accept().await {
                let log = std::sync::Arc::clone(&log_for_corp);
                tokio::spawn(async move {
                    let mut reader = BufReader::new(&mut sock);
                    let mut req = String::new();
                    reader.read_line(&mut req).await.unwrap();
                    loop {
                        let mut line = String::new();
                        let n = reader.read_line(&mut line).await.unwrap();
                        if n == 0 || line == "\r\n" || line == "\n" {
                            break;
                        }
                    }
                    let target = req.split_whitespace().nth(1).unwrap().to_string();
                    log.lock().unwrap().push(target.clone());
                    let mut up = TcpStream::connect(&target).await.unwrap();
                    sock.write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
                        .await
                        .unwrap();
                    let _ = tokio::io::copy_bidirectional(&mut sock, &mut up).await;
                });
            }
        });
        let server = Server::bind(None, Filter::new(["127.0.0.1"]).unwrap())
            .await
            .unwrap()
            .with_upstream(UpstreamConfig::HttpProxy {
                host: "127.0.0.1".into(),
                port: corp_port,
                no_proxy: vec![],
            });
        let proxy_port = server.port();
        tokio::spawn(server.serve());
        let target = format!("127.0.0.1:{real_port}");
        let (status, body) = do_connect(proxy_port, &target).await;
        assert!(status.starts_with("HTTP/1.1 200"), "got: {status:?}");
        let body_str = String::from_utf8_lossy(&body);
        let body_str = body_str.trim_start_matches("\r\n");
        assert_eq!(body_str, payload, "tunneled body mismatch");
        let log = connect_log.lock().unwrap();
        assert_eq!(*log, vec![target], "corp proxy must see original target");
    }

    #[tokio::test]
    async fn chains_skips_upstream_for_no_proxy_hosts() {
        let payload = "DIRECT-NOT-CHAINED";
        let real_port = fake_upstream(payload).await;
        let corp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let corp_port = corp_listener.local_addr().unwrap().port();
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<()>();
        tokio::spawn(async move {
            while corp_listener.accept().await.is_ok() {
                let _ = tx.send(());
            }
        });
        let server = Server::bind(None, Filter::new(["127.0.0.1"]).unwrap())
            .await
            .unwrap()
            .with_upstream(UpstreamConfig::HttpProxy {
                host: "127.0.0.1".into(),
                port: corp_port,
                no_proxy: vec!["127.0.0.1".into()],
            });
        let proxy_port = server.port();
        tokio::spawn(server.serve());
        let (status, body) = do_connect(proxy_port, &format!("127.0.0.1:{real_port}")).await;
        assert!(status.starts_with("HTTP/1.1 200"), "got: {status:?}");
        let body_str = String::from_utf8_lossy(&body);
        let body_str = body_str.trim_start_matches("\r\n");
        assert_eq!(body_str, payload);
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert!(
            rx.try_recv().is_err(),
            "NO_PROXY-listed host must bypass upstream proxy"
        );
    }

    #[tokio::test]
    async fn surfaces_upstream_407_as_502() {
        let corp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let corp_port = corp_listener.local_addr().unwrap().port();
        tokio::spawn(async move {
            while let Ok((mut sock, _)) = corp_listener.accept().await {
                tokio::spawn(async move {
                    let mut buf = [0u8; 1024];
                    let _ = sock.read(&mut buf).await;
                    let _ = sock
                        .write_all(
                            b"HTTP/1.1 407 Proxy Authentication Required\r\n\
                              Proxy-Authenticate: Basic\r\n\r\n",
                        )
                        .await;
                    let _ = sock.shutdown().await;
                });
            }
        });
        let server = Server::bind(None, Filter::new(["127.0.0.1"]).unwrap())
            .await
            .unwrap()
            .with_upstream(UpstreamConfig::HttpProxy {
                host: "127.0.0.1".into(),
                port: corp_port,
                no_proxy: vec![],
            });
        let proxy_port = server.port();
        tokio::spawn(server.serve());
        let (status, _) = do_connect(proxy_port, "127.0.0.1:1").await;
        assert!(
            status.starts_with("HTTP/1.1 502"),
            "upstream auth failure must surface as 502, got: {status:?}"
        );
    }

    #[test]
    fn parse_connect_request() {
        let req = b"CONNECT github.com:443 HTTP/1.1\r\nHost: github.com:443\r\n\r\n";
        let (m, t) = parse_request_line(req).unwrap();
        assert_eq!(m, "CONNECT");
        assert_eq!(t, "github.com:443");
    }

    #[test]
    fn parse_get_request() {
        let req = b"GET /index HTTP/1.1\r\n\r\n";
        let (m, t) = parse_request_line(req).unwrap();
        assert_eq!(m, "GET");
        assert_eq!(t, "/index");
    }

    #[test]
    fn parse_rejects_no_crlf() {
        let req = b"CONNECT github.com:443 HTTP/1.1";
        let err = parse_request_line(req).expect_err("must reject");
        assert!(err.to_string().contains("no CRLF"));
    }

    #[test]
    fn parse_rejects_missing_version() {
        let req = b"CONNECT github.com:443\r\n\r\n";
        let err = parse_request_line(req).expect_err("must reject");
        assert!(err.to_string().contains("unexpected version"));
    }

    #[test]
    fn parse_rejects_bad_version() {
        let req = b"CONNECT github.com:443 SPDY/1\r\n\r\n";
        let err = parse_request_line(req).expect_err("must reject");
        assert!(err.to_string().contains("unexpected version"));
    }

    #[test]
    fn parse_rejects_non_utf8() {
        let req = &[
            0xff, 0xff, b' ', b'/', b' ', b'H', b'T', b'T', b'P', b'\r', b'\n',
        ];
        let err = parse_request_line(req).expect_err("must reject");
        assert!(err.to_string().contains("not UTF-8"));
    }
}