use anyhow::{Context, Result};
use std::io::{BufRead, BufReader, Write};
use std::net::{IpAddr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
use std::time::Duration;
use tracing::{debug, warn};
use crate::sandbox::rpc::generate_token;
/// Upper bound on concurrently served proxy connections; further clients are dropped.
const MAX_CONNECTIONS: usize = 16;
/// Maximum number of bytes accepted for the CONNECT request head (request line + headers).
const MAX_REQUEST_SIZE: usize = 8 << 10;
/// Per-read timeout applied while the request head is read, before authentication.
const AUTH_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Timeout for establishing the TCP connection to the upstream target.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// The only destination port a CONNECT request may target.
const ALLOWED_PORT: u16 = 443;
/// A bound but not yet running HTTP CONNECT proxy.
///
/// Create it with [`NetworkProxy::bind`], read [`NetworkProxy::port`] and
/// [`NetworkProxy::token`] to hand to the client, then start serving with
/// [`NetworkProxy::spawn`].
pub struct NetworkProxy {
    /// Local port of `listener`, captured at bind time.
    port: u16,
    /// Shared secret clients must present as the password of their proxy credentials.
    token: String,
    /// Domain patterns that CONNECT targets are checked against.
    allowed_domains: Vec<String>,
    /// Listening socket; connections are accepted once `spawn` runs.
    listener: TcpListener,
}
/// Returned by `NetworkProxy::spawn`; owns the join handle of the accept-loop thread.
///
/// Dropping a `JoinHandle` detaches its thread rather than stopping it, so dropping this
/// handle does not shut the proxy down.
pub struct ProxyHandle {
    // Kept only so the handle is owned; nothing joins it.
    _handle: thread::JoinHandle<()>,
}
impl NetworkProxy {
pub fn bind(allowed_domains: &[String]) -> Result<Self> {
let listener =
TcpListener::bind("0.0.0.0:0").context("Failed to bind network proxy listener")?;
let port = listener.local_addr()?.port();
let token = generate_token();
debug!(port, "network proxy bound");
Ok(Self {
listener,
port,
token,
allowed_domains: allowed_domains.to_vec(),
})
}
pub fn port(&self) -> u16 {
self.port
}
pub fn token(&self) -> &str {
&self.token
}
pub fn spawn(self) -> ProxyHandle {
let ctx = Arc::new(ProxyContext {
token: self.token,
allowed_domains: self.allowed_domains,
});
let active = Arc::new(AtomicUsize::new(0));
let handle = thread::spawn(move || {
for stream in self.listener.incoming() {
match stream {
Ok(stream) => {
let current = active.load(Ordering::Relaxed);
if current >= MAX_CONNECTIONS {
warn!(current, "proxy connection limit reached, dropping");
drop(stream);
continue;
}
active.fetch_add(1, Ordering::Relaxed);
let ctx = Arc::clone(&ctx);
let active = Arc::clone(&active);
thread::spawn(move || {
if let Err(e) = handle_proxy_connection(stream, &ctx) {
debug!(error = %e, "proxy connection ended");
}
active.fetch_sub(1, Ordering::Relaxed);
});
}
Err(e) => {
debug!(error = %e, "proxy accept error, shutting down");
break;
}
}
}
});
ProxyHandle { _handle: handle }
}
}
/// Immutable configuration shared through an `Arc` by every connection-handler thread.
struct ProxyContext {
    /// Domain patterns a CONNECT target must match.
    allowed_domains: Vec<String>,
    /// Secret expected as the password of the `workmux` Basic proxy credentials.
    token: String,
}
/// Reports whether `domain` matches an allowlist `pattern`, ignoring ASCII case.
///
/// A pattern starting with `*` matches any domain that ends with the rest of the pattern
/// (so `*.example.com` matches `a.example.com` but not `example.com`). Any other pattern
/// must equal the domain exactly.
fn domain_matches(domain: &str, pattern: &str) -> bool {
    match pattern.strip_prefix('*') {
        None => domain.eq_ignore_ascii_case(pattern),
        Some(suffix) => {
            // Compare raw bytes so slicing can never land inside a multi-byte character.
            let (host, tail) = (domain.as_bytes(), suffix.as_bytes());
            host.len() >= tail.len() && host[host.len() - tail.len()..].eq_ignore_ascii_case(tail)
        }
    }
}
/// Returns `true` if `addr` is not a publicly routable unicast destination, so the proxy
/// must refuse to dial it (SSRF / DNS-rebinding protection).
///
/// Blocked IPv4 ranges:
/// - RFC 1918 private, loopback, link-local, broadcast, unspecified, multicast
/// - documentation (TEST-NET-1/2/3)
/// - `0.0.0.0/8` ("this network")
/// - `100.64.0.0/10` (shared address space / CGNAT)
/// - `192.0.0.0/24` (IETF protocol assignments)
/// - `198.18.0.0/15` (benchmarking)
/// - `240.0.0.0/4` (reserved)
///
/// Blocked IPv6 ranges:
/// - loopback, unspecified, multicast
/// - unique-local `fc00::/7`, link-local `fe80::/10`, deprecated site-local `fec0::/10`
/// - documentation `2001:db8::/32`
/// - the deprecated IPv4-compatible block `::/96`
///
/// IPv6 forms that embed an IPv4 address are judged by that embedded address:
/// IPv4-mapped `::ffff:0:0/96`, NAT64 `64:ff9b::/96` and 6to4 `2002::/16`.
fn is_private_ip(addr: &IpAddr) -> bool {
    match addr {
        IpAddr::V4(ip) => {
            let [a, b, c, _] = ip.octets();
            ip.is_private()
                || ip.is_loopback()
                || ip.is_link_local()
                || ip.is_broadcast()
                || ip.is_unspecified()
                || ip.is_multicast()
                || ip.is_documentation()
                // 0.0.0.0/8: not a valid destination (RFC 1122), but some stacks route it.
                || a == 0
                // 100.64.0.0/10: shared address space / CGNAT (RFC 6598).
                || (a == 100 && (b & 0xC0) == 64)
                // 192.0.0.0/24: IETF protocol assignments (RFC 6890).
                || (a == 192 && b == 0 && c == 0)
                // 198.18.0.0/15: benchmarking (RFC 2544).
                || (a == 198 && (b & 0xFE) == 18)
                // 240.0.0.0/4: reserved (includes the limited broadcast address).
                || a >= 240
        }
        IpAddr::V6(ip) => {
            if let Some(mapped) = ip.to_ipv4_mapped() {
                return is_private_ip(&IpAddr::V4(mapped));
            }
            let seg = ip.segments();
            let o = ip.octets();
            // NAT64 well-known prefix: the low 32 bits are the IPv4 destination.
            if seg[..6] == [0x0064, 0xff9b, 0, 0, 0, 0] {
                let embedded = std::net::Ipv4Addr::new(o[12], o[13], o[14], o[15]);
                return is_private_ip(&IpAddr::V4(embedded));
            }
            // 6to4: bits 16..48 are the IPv4 address of the tunnel endpoint.
            if seg[0] == 0x2002 {
                let embedded = std::net::Ipv4Addr::new(o[2], o[3], o[4], o[5]);
                return is_private_ip(&IpAddr::V4(embedded));
            }
            ip.is_loopback()
                || ip.is_unspecified()
                || ip.is_multicast()
                // ::/96 IPv4-compatible (deprecated, RFC 4291); also covers :: and ::1.
                || seg[..6] == [0; 6]
                // fc00::/7 unique local.
                || (seg[0] & 0xfe00) == 0xfc00
                // fe80::/10 link-local.
                || (seg[0] & 0xffc0) == 0xfe80
                // fec0::/10 site-local (deprecated, RFC 3879).
                || (seg[0] & 0xffc0) == 0xfec0
                // 2001:db8::/32 documentation.
                || (seg[0] == 0x2001 && seg[1] == 0x0db8)
        }
    }
}
/// Compares two byte strings without exiting early at the first differing byte, so the
/// running time does not reveal where they differ. Lengths are compared first, and a
/// length mismatch returns `false` immediately.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let same_len = a.len() == b.len();
    // OR together the XOR of every byte pair; any difference leaves a non-zero bit.
    same_len && a.iter().zip(b).fold(0u8, |acc, (&x, &y)| acc | (x ^ y)) == 0
}
fn handle_proxy_connection(stream: TcpStream, ctx: &ProxyContext) -> Result<()> {
let peer = stream.peer_addr().ok();
debug!(?peer, "proxy connection accepted");
stream.set_read_timeout(Some(AUTH_TIMEOUT))?;
let mut reader = BufReader::new(&stream);
let mut writer = stream.try_clone().context("Failed to clone proxy stream")?;
let mut total_read = 0usize;
let mut request_line = String::new();
let mut proxy_auth: Option<String> = None;
let n = reader.read_line(&mut request_line)?;
debug!(
?peer,
request_line = request_line.trim(),
bytes = n,
"proxy request line"
);
total_read += n;
if total_read > MAX_REQUEST_SIZE {
write_error(&mut writer, 400, "Request too large")?;
return Ok(());
}
loop {
let mut header_line = String::new();
let n = reader.read_line(&mut header_line)?;
total_read += n;
if total_read > MAX_REQUEST_SIZE {
write_error(&mut writer, 400, "Request too large")?;
return Ok(());
}
let trimmed = header_line.trim();
if trimmed.is_empty() {
break;
}
if let Some((name, value)) = trimmed.split_once(':')
&& name.trim().eq_ignore_ascii_case("Proxy-Authorization")
{
proxy_auth = Some(value.trim().to_string());
}
}
let request_line = request_line.trim();
let parts: Vec<&str> = request_line.split_whitespace().collect();
if parts.len() < 2 || parts[0] != "CONNECT" {
write_error(&mut writer, 400, "Expected CONNECT method")?;
return Ok(());
}
let target = parts[1];
let (hostname, port) = parse_host_port(target)?;
debug!(hostname, port, "CONNECT request");
let expected = format!("Basic {}", base64_encode(&format!("workmux:{}", ctx.token)));
match proxy_auth {
None => {
debug!(hostname, "proxy auth missing");
write_error(&mut writer, 407, "Proxy authentication required")?;
return Ok(());
}
Some(ref auth) if !constant_time_eq(auth.as_bytes(), expected.as_bytes()) => {
debug!(hostname, "proxy auth failed");
write_error(&mut writer, 407, "Invalid proxy credentials")?;
return Ok(());
}
_ => {}
}
stream.set_read_timeout(None)?;
if port != ALLOWED_PORT {
warn!(hostname, port, "rejected: non-443 port");
write_error(&mut writer, 403, "Only port 443 is allowed")?;
return Ok(());
}
let hostname = hostname.to_ascii_lowercase();
let hostname = hostname.strip_suffix('.').unwrap_or(&hostname);
if hostname.parse::<IpAddr>().is_ok() {
warn!(hostname, "rejected: IP literal hostname");
write_error(&mut writer, 403, "IP literal hostnames not allowed")?;
return Ok(());
}
let allowed = ctx
.allowed_domains
.iter()
.any(|pattern| domain_matches(hostname, pattern));
if !allowed {
warn!(hostname, "rejected: domain not in allowlist");
write_error(&mut writer, 403, "Domain not allowed")?;
return Ok(());
}
let addrs: Vec<SocketAddr> = match format!("{}:{}", hostname, port).to_socket_addrs() {
Ok(addrs) => addrs.collect(),
Err(e) => {
warn!(hostname, error = %e, "DNS resolution failed");
write_error(&mut writer, 502, "DNS resolution failed")?;
return Ok(());
}
};
let public_addrs: Vec<&SocketAddr> = addrs.iter().filter(|a| !is_private_ip(&a.ip())).collect();
if public_addrs.is_empty() {
warn!(hostname, "rejected: all resolved IPs are private");
write_error(&mut writer, 403, "All resolved IPs are private")?;
return Ok(());
}
let target_addr = *public_addrs[0];
let target_stream = match TcpStream::connect_timeout(&target_addr, CONNECT_TIMEOUT) {
Ok(s) => s,
Err(e) => {
debug!(hostname, addr = %target_addr, error = %e, "connect failed");
write_error(&mut writer, 502, "Connection to target failed")?;
return Ok(());
}
};
debug!(hostname, addr = %target_addr, "tunnel established");
writer.write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")?;
writer.flush()?;
let buffered = reader.buffer();
if !buffered.is_empty() {
let mut target_ref = &target_stream;
target_ref
.write_all(buffered)
.context("Failed to forward buffered data to target")?;
}
tunnel(reader.into_inner(), &target_stream)?;
Ok(())
}
fn parse_host_port(target: &str) -> Result<(&str, u16)> {
if let Some(bracket_end) = target.find(']') {
let host = &target[..=bracket_end];
let port_str = target[bracket_end + 1..].strip_prefix(':').unwrap_or("443");
let port: u16 = port_str.parse().context("Invalid port")?;
return Ok((host, port));
}
match target.rsplit_once(':') {
Some((host, port_str)) => {
let port: u16 = port_str.parse().context("Invalid port")?;
Ok((host, port))
}
None => Ok((target, 443)),
}
}
/// Writes a minimal HTTP/1.1 error response whose body is `reason`, then flushes `writer`.
///
/// The response carries `Connection: close`, because callers drop the connection right
/// after an error. A `407` also carries a `Proxy-Authenticate: Basic` challenge. RFC 7235
/// §3.2 requires it, and challenge-driven clients wait for it before sending credentials.
///
/// # Errors
/// Propagates any I/O error from writing or flushing.
fn write_error(writer: &mut impl Write, code: u16, reason: &str) -> Result<()> {
    let challenge = if code == 407 {
        "Proxy-Authenticate: Basic realm=\"workmux\"\r\n"
    } else {
        ""
    };
    let response = format!(
        "HTTP/1.1 {code} {reason}\r\n{challenge}Connection: close\r\nContent-Length: {}\r\n\r\n{reason}",
        reason.len(),
    );
    writer.write_all(response.as_bytes())?;
    writer.flush()?;
    Ok(())
}
/// Encodes `input`'s UTF-8 bytes as standard, padded Base64 (RFC 4648 alphabet).
fn base64_encode(input: &str) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Maps the low six bits of `n` to its Base64 character.
    let sextet = |n: u32| char::from(ALPHABET[(n & 0x3F) as usize]);
    let bytes = input.as_bytes();
    let mut out = String::with_capacity(bytes.len().div_ceil(3) * 4);
    for group in bytes.chunks(3) {
        // Pack up to three bytes big-endian into the top 24 bits; missing bytes stay zero.
        let n = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &b)| acc | (u32::from(b) << (16 - 8 * i)));
        out.push(sextet(n >> 18));
        out.push(sextet(n >> 12));
        out.push(if group.len() > 1 { sextet(n >> 6) } else { '=' });
        out.push(if group.len() > 2 { sextet(n) } else { '=' });
    }
    out
}
/// Relays bytes in both directions between `client` and `target`, one thread per
/// direction, and returns once both directions have finished.
///
/// When one direction reaches EOF or errors, the write half of its destination is shut
/// down so the peer sees EOF. Copy and shutdown errors are deliberately ignored.
///
/// # Errors
/// Fails only if a socket handle cannot be cloned.
fn tunnel(client: &TcpStream, target: &TcpStream) -> Result<()> {
    // Copies `src` into `dst` on a new thread, then half-closes `dst`.
    fn pump(mut src: TcpStream, mut dst: TcpStream) -> thread::JoinHandle<()> {
        thread::spawn(move || {
            let _ = std::io::copy(&mut src, &mut dst);
            let _ = dst.shutdown(std::net::Shutdown::Write);
        })
    }

    let (client_read, target_write) = (client.try_clone()?, target.try_clone()?);
    let (target_read, client_write) = (target.try_clone()?, client.try_clone()?);
    let upstream = pump(client_read, target_write);
    let downstream = pump(target_read, client_write);
    for direction in [upstream, downstream] {
        let _ = direction.join();
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `Proxy-Authorization` value for user `workmux` with `token` as password.
    fn basic_auth(token: &str) -> String {
        format!("Basic {}", base64_encode(&format!("workmux:{}", token)))
    }

    /// Binds a proxy allowing `domains`, starts it, and returns its port, token and handle.
    ///
    /// No sleep is needed before connecting: `TcpListener::bind` already listens, so
    /// connections wait in the backlog until the accept thread picks them up.
    fn start_proxy(domains: &[&str]) -> (u16, String, ProxyHandle) {
        let domains: Vec<String> = domains.iter().map(|d| d.to_string()).collect();
        let proxy = NetworkProxy::bind(&domains).unwrap();
        let port = proxy.port();
        let token = proxy.token().to_string();
        (port, token, proxy.spawn())
    }

    /// Sends `request` to the proxy on 127.0.0.1:`port` and returns the response status line.
    fn status_line(port: u16, request: &str) -> String {
        let mut stream = TcpStream::connect(("127.0.0.1", port)).unwrap();
        stream.write_all(request.as_bytes()).unwrap();
        stream.flush().unwrap();
        let mut response = String::new();
        BufReader::new(&stream).read_line(&mut response).unwrap();
        response
    }

    /// A CONNECT request for `authority` with `auth` in a `Proxy-Authorization` header.
    fn authed_connect(authority: &str, auth: &str) -> String {
        format!("CONNECT {authority} HTTP/1.1\r\nProxy-Authorization: {auth}\r\n\r\n")
    }

    #[test]
    fn domain_exact_match() {
        assert!(domain_matches("example.com", "example.com"));
        assert!(domain_matches("api.anthropic.com", "api.anthropic.com"));
    }

    #[test]
    fn domain_case_insensitive() {
        assert!(domain_matches("Example.COM", "example.com"));
        assert!(domain_matches("example.com", "Example.COM"));
    }

    #[test]
    fn domain_wildcard_match() {
        assert!(domain_matches("foo.googleapis.com", "*.googleapis.com"));
        assert!(domain_matches("bar.baz.googleapis.com", "*.googleapis.com"));
    }

    #[test]
    fn domain_wildcard_does_not_match_base() {
        assert!(!domain_matches("example.com", "*.example.com"));
    }

    #[test]
    fn domain_no_match() {
        assert!(!domain_matches("evil.com", "example.com"));
        assert!(!domain_matches("notexample.com", "example.com"));
        assert!(!domain_matches("evil.com", "*.example.com"));
    }

    #[test]
    fn private_ip_rfc1918() {
        assert!(is_private_ip(&"10.0.0.1".parse().unwrap()));
        assert!(is_private_ip(&"172.16.0.1".parse().unwrap()));
        assert!(is_private_ip(&"192.168.1.1".parse().unwrap()));
    }

    #[test]
    fn private_ip_loopback() {
        assert!(is_private_ip(&"127.0.0.1".parse().unwrap()));
        assert!(is_private_ip(&"::1".parse().unwrap()));
    }

    #[test]
    fn private_ip_link_local() {
        assert!(is_private_ip(&"169.254.1.1".parse().unwrap()));
    }

    #[test]
    fn private_ip_cgnat() {
        assert!(is_private_ip(&"100.64.0.1".parse().unwrap()));
        assert!(is_private_ip(&"100.127.255.255".parse().unwrap()));
    }

    #[test]
    fn private_ip_multicast() {
        assert!(is_private_ip(&"224.0.0.1".parse().unwrap()));
    }

    #[test]
    fn private_ip_ipv6_ula() {
        assert!(is_private_ip(&"fc00::1".parse().unwrap()));
        assert!(is_private_ip(&"fd12::1".parse().unwrap()));
    }

    #[test]
    fn private_ip_ipv6_link_local() {
        assert!(is_private_ip(&"fe80::1".parse().unwrap()));
    }

    #[test]
    fn private_ip_v4_mapped_v6() {
        assert!(is_private_ip(&"::ffff:127.0.0.1".parse().unwrap()));
        assert!(is_private_ip(&"::ffff:10.0.0.1".parse().unwrap()));
        assert!(!is_private_ip(&"::ffff:8.8.8.8".parse().unwrap()));
    }

    #[test]
    fn public_ip_allowed() {
        assert!(!is_private_ip(&"8.8.8.8".parse().unwrap()));
        assert!(!is_private_ip(&"1.1.1.1".parse().unwrap()));
        assert!(!is_private_ip(&"2607:f8b0:4004:800::200e".parse().unwrap()));
    }

    #[test]
    fn not_private_ip_100_non_cgnat() {
        assert!(!is_private_ip(&"100.0.0.1".parse().unwrap()));
    }

    #[test]
    fn parse_host_port_standard() {
        let (host, port) = parse_host_port("example.com:443").unwrap();
        assert_eq!(host, "example.com");
        assert_eq!(port, 443);
    }

    #[test]
    fn parse_host_port_non_standard() {
        let (host, port) = parse_host_port("example.com:8443").unwrap();
        assert_eq!(host, "example.com");
        assert_eq!(port, 8443);
    }

    #[test]
    fn parse_host_port_no_port() {
        let (host, port) = parse_host_port("example.com").unwrap();
        assert_eq!(host, "example.com");
        assert_eq!(port, 443);
    }

    #[test]
    fn base64_encode_basic_auth() {
        assert_eq!(base64_encode("workmux:mytoken"), "d29ya211eDpteXRva2Vu");
        assert_eq!(base64_encode(""), "");
        assert_eq!(base64_encode("a"), "YQ==");
        assert_eq!(base64_encode("ab"), "YWI=");
        assert_eq!(base64_encode("abc"), "YWJj");
    }

    #[test]
    fn proxy_binds_to_random_port() {
        let proxy = NetworkProxy::bind(&["example.com".to_string()]).unwrap();
        assert!(proxy.port() > 0);
    }

    #[test]
    fn proxy_token_is_nonempty() {
        let proxy = NetworkProxy::bind(&[]).unwrap();
        assert!(!proxy.token().is_empty());
    }

    #[test]
    fn proxy_rejects_missing_auth() {
        let (port, _token, _handle) = start_proxy(&["example.com"]);
        let response = status_line(port, "CONNECT example.com:443 HTTP/1.1\r\n\r\n");
        assert!(response.contains("407"), "got: {}", response.trim());
    }

    #[test]
    fn proxy_rejects_wrong_auth() {
        let (port, _token, _handle) = start_proxy(&["example.com"]);
        let request = authed_connect("example.com:443", &basic_auth("wrong-token"));
        let response = status_line(port, &request);
        assert!(response.contains("407"), "got: {}", response.trim());
    }

    #[test]
    fn proxy_accepts_lowercase_auth_header() {
        // `localhost` resolves to loopback, so once authentication passes the proxy
        // rejects the target as private (or reports a DNS failure) without dialing out.
        // The previous version allowed example.com and opened a real internet connection.
        let (port, token, _handle) = start_proxy(&["localhost"]);
        let request = format!(
            "CONNECT localhost:443 HTTP/1.1\r\nproxy-authorization: {}\r\n\r\n",
            basic_auth(&token)
        );
        let response = status_line(port, &request);
        assert!(
            !response.contains("407"),
            "lowercase proxy-authorization should be accepted, got: {}",
            response.trim()
        );
    }

    #[test]
    fn proxy_rejects_non_443_port() {
        let (port, token, _handle) = start_proxy(&["example.com"]);
        let request = authed_connect("example.com:80", &basic_auth(&token));
        let response = status_line(port, &request);
        assert!(response.contains("403"), "got: {}", response.trim());
    }

    #[test]
    fn proxy_rejects_unlisted_domain() {
        let (port, token, _handle) = start_proxy(&["allowed.com"]);
        let request = authed_connect("denied.com:443", &basic_auth(&token));
        let response = status_line(port, &request);
        assert!(response.contains("403"), "got: {}", response.trim());
    }

    /// Checks the assumption `handle_proxy_connection` relies on: bytes pipelined after
    /// the request head stay in the `BufReader` buffer and can be forwarded verbatim.
    /// This exercises the buffering directly rather than going through the proxy.
    #[test]
    fn pipelined_data_forwarded_through_tunnel() {
        use std::io::Read;
        let target_listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let target_addr = target_listener.local_addr().unwrap();
        let target_handle = thread::spawn(move || {
            let (mut conn, _) = target_listener.accept().unwrap();
            conn.set_read_timeout(Some(Duration::from_secs(2))).unwrap();
            let mut buf = vec![0u8; 26];
            conn.read_exact(&mut buf).unwrap();
            buf
        });
        let proxy_listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let proxy_addr = proxy_listener.local_addr().unwrap();
        let extra_data = b"SIMULATED_TLS_CLIENT_HELLO";
        let mut pipelined = Vec::new();
        pipelined
            .extend_from_slice(b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com\r\n\r\n");
        pipelined.extend_from_slice(extra_data);
        let mut client = TcpStream::connect(proxy_addr).unwrap();
        client.write_all(&pipelined).unwrap();
        client.flush().unwrap();
        let (proxy_stream, _) = proxy_listener.accept().unwrap();
        // Give the whole pipelined write time to arrive so the first buffered read sees it.
        thread::sleep(Duration::from_millis(50));
        let mut reader = BufReader::new(&proxy_stream);
        loop {
            let mut line = String::new();
            reader.read_line(&mut line).unwrap();
            if line.trim().is_empty() {
                break;
            }
        }
        let mut target_stream = TcpStream::connect(target_addr).unwrap();
        let buffer = reader.buffer();
        assert!(
            !buffer.is_empty(),
            "BufReader should have buffered the pipelined data"
        );
        target_stream.write_all(buffer).unwrap();
        target_stream.flush().unwrap();
        drop(target_stream);
        let received = target_handle.join().unwrap();
        assert_eq!(received, extra_data);
    }

    #[test]
    fn proxy_rejects_ip_literal_hostname() {
        let (port, token, _handle) = start_proxy(&["8.8.8.8"]);
        let request = authed_connect("8.8.8.8:443", &basic_auth(&token));
        let response = status_line(port, &request);
        assert!(response.contains("403"), "got: {}", response.trim());
    }
}