use anyhow::{Context, Result, anyhow};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::warn;
/// How outbound connections reach their target: straight to the target, or
/// tunnelled through an HTTP proxy with a list of hosts exempt from proxying.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UpstreamConfig {
/// Connect to the target address directly.
Direct,
/// Tunnel through an HTTP proxy using `CONNECT`, except for targets whose
/// host matches an entry in `no_proxy`.
HttpProxy {
/// Proxy host, as extracted from the proxy URL.
host: String,
/// Proxy TCP port.
port: u16,
/// Host entries for which the proxy is skipped and a direct connection
/// is made instead.
no_proxy: Vec<String>,
},
}
impl UpstreamConfig {
    /// Builds the upstream configuration from the process environment.
    ///
    /// Reads `HTTPS_PROXY`, falling back to `https_proxy` when the upper-case
    /// variable cannot be read. An unset, blank, or unparseable value yields
    /// [`UpstreamConfig::Direct`]; a parseable value yields
    /// [`UpstreamConfig::HttpProxy`] with the exemption list taken from
    /// `parse_no_proxy()`.
    ///
    /// An unparseable value is reported with a warning. Any `userinfo@` part
    /// of the value is masked first, because a URL rejected for carrying
    /// credentials is exactly the one that would otherwise leak them to logs.
    pub fn from_env() -> Self {
        let raw = std::env::var("HTTPS_PROXY")
            .ok()
            .or_else(|| std::env::var("https_proxy").ok());
        let Some(raw) = raw else {
            return Self::Direct;
        };
        let trimmed = raw.trim();
        if trimmed.is_empty() {
            return Self::Direct;
        }
        match parse_proxy_url(trimmed) {
            Some((host, port)) => Self::HttpProxy {
                host,
                port,
                no_proxy: parse_no_proxy(),
            },
            None => {
                // Never log the raw value: it may contain `user:password@`.
                let shown = redact_proxy_credentials(&raw);
                warn!(
                    "HTTPS_PROXY={shown:?} is not parseable (expected http://host:port, no auth) — \
                     falling back to direct upstream"
                );
                Self::Direct
            }
        }
    }
}

/// Replaces everything between the scheme and the last `@` of `raw` with
/// `<redacted>`; returns `raw` unchanged when it contains no `@`.
fn redact_proxy_credentials(raw: &str) -> String {
    let Some((before_at, after_at)) = raw.rsplit_once('@') else {
        return raw.to_string();
    };
    match before_at.split_once("://") {
        Some((scheme, _userinfo)) => format!("{scheme}://<redacted>@{after_at}"),
        None => format!("<redacted>@{after_at}"),
    }
}
/// Parses a proxy URL of the form `[http://|https://]host:port[/]` into
/// `(host, port)`.
///
/// Returns `None` when the value:
/// - contains credentials (`@`), whitespace, or a path component,
/// - has no port, a port that is not made only of ASCII digits, or port `0`,
/// - has an empty host, or a host containing `:` that is not a bracketed
///   IPv6 literal such as `[::1]`.
///
/// The scheme is matched case-insensitively, so `HTTP://proxy:8080` is
/// accepted just like `http://proxy:8080`.
pub fn parse_proxy_url(s: &str) -> Option<(String, u16)> {
    let s = s.strip_suffix('/').unwrap_or(s);
    let s = strip_scheme_ignore_case(s, "https://")
        .or_else(|| strip_scheme_ignore_case(s, "http://"))
        .unwrap_or(s);
    if s.contains('@') || s.contains(char::is_whitespace) || s.contains('/') {
        return None;
    }
    let (host, port_str) = s.rsplit_once(':')?;
    if host.is_empty() {
        return None;
    }
    // A colon inside the host is only legitimate in a bracketed IPv6 literal;
    // otherwise `a:1:2` would be accepted with the bogus host `a:1`.
    let bracketed = host.starts_with('[') && host.ends_with(']');
    if host.contains(':') && !bracketed {
        return None;
    }
    // `u16::from_str` tolerates a leading `+`, which is not a valid port.
    if !port_str.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    let port: u16 = port_str.parse().ok()?;
    if port == 0 {
        return None;
    }
    Some((host.to_string(), port))
}

/// Returns the remainder of `s` after `scheme` when `s` starts with it,
/// comparing ASCII case-insensitively.
fn strip_scheme_ignore_case<'a>(s: &'a str, scheme: &str) -> Option<&'a str> {
    let head = s.get(..scheme.len())?;
    if head.eq_ignore_ascii_case(scheme) {
        s.get(scheme.len()..)
    } else {
        None
    }
}
/// Reads the proxy exemption list from `NO_PROXY` (falling back to
/// `no_proxy` when the upper-case variable cannot be read).
///
/// The value is split on commas. Each entry is trimmed, a leading `*.` and
/// any leading dots are removed (so `*.example.com`, `.example.com` and
/// `example.com` all normalise to `example.com`), and empty entries are
/// dropped. A lone `*` is kept as-is. Returns an empty list when neither
/// variable is readable.
pub fn parse_no_proxy() -> Vec<String> {
    let raw = std::env::var("NO_PROXY")
        .ok()
        .or_else(|| std::env::var("no_proxy").ok())
        .unwrap_or_default();
    raw.split(',')
        .filter_map(|entry| {
            let entry = entry.trim();
            // `*.example.com` is a common spelling of "example.com and its
            // subdomains"; left verbatim it could never match a real host.
            let entry = entry.strip_prefix("*.").unwrap_or(entry);
            let entry = entry.trim_start_matches('.');
            if entry.is_empty() {
                None
            } else {
                Some(entry.to_string())
            }
        })
        .collect()
}
/// Reports whether `target_host` is exempt from proxying.
///
/// An entry matches when it is `*`, when it equals `target_host`, or when
/// `target_host` ends with `.` followed by the entry (i.e. is a subdomain of
/// it). Comparisons ignore ASCII case because host names are
/// case-insensitive. An empty list matches nothing.
pub fn bypasses_proxy(target_host: &str, no_proxy: &[String]) -> bool {
    let host = target_host.as_bytes();
    no_proxy.iter().any(|entry| {
        if entry == "*" {
            return true;
        }
        let suffix = entry.as_bytes();
        // `split` is where the entry would start inside the host.
        match host.len().checked_sub(suffix.len()) {
            // Host is shorter than the entry: cannot match.
            None => false,
            // Same length: must be the same name.
            Some(0) => host.eq_ignore_ascii_case(suffix),
            // Longer: must be `<something>.<entry>`; the dot check keeps
            // `evilexample.com` from matching `example.com`.
            Some(split) => {
                host[split - 1] == b'.' && host[split..].eq_ignore_ascii_case(suffix)
            }
        }
    })
}
/// Ten-second limit applied to each upstream step in this file: dialling the
/// target or proxy, sending the CONNECT request, and reading the proxy's reply.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Opens a TCP stream to `target` according to `cfg`.
///
/// With [`UpstreamConfig::Direct`] the target is dialled directly. With
/// [`UpstreamConfig::HttpProxy`] the host part of `target` (everything before
/// the last `:`, or the whole string when there is none) is checked against
/// the exemption list: exempt hosts are dialled directly, all others are
/// tunnelled through the proxy.
///
/// # Errors
/// Propagates whatever the chosen connect path returns.
pub async fn connect_upstream(target: &str, cfg: &UpstreamConfig) -> Result<TcpStream> {
    // Pull the proxy settings out first; the direct case has nothing more to do.
    let (proxy_host, proxy_port, exemptions) = match cfg {
        UpstreamConfig::Direct => return connect_direct(target).await,
        UpstreamConfig::HttpProxy {
            host,
            port,
            no_proxy,
        } => (host, *port, no_proxy),
    };

    let hostname = match target.rsplit_once(':') {
        Some((before_port, _)) => before_port,
        None => target,
    };

    if bypasses_proxy(hostname, exemptions) {
        connect_direct(target).await
    } else {
        connect_via_http_proxy(target, proxy_host, proxy_port).await
    }
}
/// Dials `target` directly, giving up after `UPSTREAM_CONNECT_TIMEOUT`.
///
/// # Errors
/// Returns an error when the dial times out or when the connection attempt
/// itself fails; each case carries its own context message naming `target`.
async fn connect_direct(target: &str) -> Result<TcpStream> {
    let dial = TcpStream::connect(target);
    let within_deadline = tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, dial).await;
    // Outer layer: did the timer fire? Inner layer: did the dial succeed?
    let dialled =
        within_deadline.with_context(|| format!("upstream connect to {target} timed out"))?;
    dialled.with_context(|| format!("upstream connect to {target} failed"))
}
/// Establishes a tunnel to `target` through the HTTP proxy at
/// `proxy_host:proxy_port` and returns the stream once the proxy has accepted.
///
/// Steps: dial the proxy, send `CONNECT {target}`, read the proxy's response
/// head, and require a 2xx status. The dial and the send are each bounded by
/// `UPSTREAM_CONNECT_TIMEOUT`.
///
/// # Errors
/// - `target` is empty or contains whitespace or control characters,
/// - dialling the proxy or sending the request fails or times out,
/// - the response cannot be read or parsed,
/// - the proxy answers with a non-2xx status.
async fn connect_via_http_proxy(
    target: &str,
    proxy_host: &str,
    proxy_port: u16,
) -> Result<TcpStream> {
    // `target` is written verbatim into the request line and the Host header.
    // A space or CR/LF in it would let the caller of this function smuggle
    // extra request text or headers to the proxy, so refuse it up front.
    if target.is_empty() || target.chars().any(|c| c.is_whitespace() || c.is_control()) {
        return Err(anyhow!("refusing CONNECT to malformed target {target:?}"));
    }

    let proxy_addr = format!("{proxy_host}:{proxy_port}");
    let mut stream =
        tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&proxy_addr))
            .await
            .with_context(|| format!("dial upstream proxy {proxy_addr} timed out"))?
            .with_context(|| format!("dial upstream proxy {proxy_addr} failed"))?;

    let req = format!(
        "CONNECT {target} HTTP/1.1\r\nHost: {target}\r\nProxy-Connection: keep-alive\r\n\r\n"
    );
    tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, stream.write_all(req.as_bytes()))
        .await
        .with_context(|| format!("send CONNECT to upstream proxy {proxy_addr} timed out"))?
        .with_context(|| format!("send CONNECT to upstream proxy {proxy_addr} failed"))?;

    let (status, _headers) = read_proxy_response(&mut stream)
        .await
        .with_context(|| format!("read CONNECT response from upstream proxy {proxy_addr}"))?;
    // Any 2xx status means the tunnel is open.
    if !(200..300).contains(&status) {
        return Err(anyhow!(
            "upstream proxy {proxy_addr} refused CONNECT to {target} with status {status}"
        ));
    }
    Ok(stream)
}
/// Reads the proxy's response head from `stream` and extracts the status code.
///
/// Bytes are read one at a time until the buffer ends with `\r\n\r\n` or
/// `\n\n` (presumably so that nothing beyond the header terminator is
/// consumed from the stream — confirm against callers). The whole read shares
/// a single deadline of `UPSTREAM_CONNECT_TIMEOUT` from the moment of the call.
///
/// Returns the status code (second whitespace-separated field of the first
/// line) and every byte read, status line included.
///
/// # Errors
/// - 8192 bytes are buffered without seeing the terminator,
/// - the deadline passes, a read fails, or the proxy closes the connection,
/// - the status line is not UTF-8, lacks a version or code, or the code is
///   not a valid `u16`.
async fn read_proxy_response(stream: &mut TcpStream) -> Result<(u16, Vec<u8>)> {
    const MAX_RESPONSE: usize = 8192;

    let give_up_at = tokio::time::Instant::now() + UPSTREAM_CONNECT_TIMEOUT;
    let mut head: Vec<u8> = Vec::with_capacity(256);

    // Keep pulling single bytes until the blank line that ends the head.
    while !(head.ends_with(b"\r\n\r\n") || head.ends_with(b"\n\n")) {
        if head.len() >= MAX_RESPONSE {
            return Err(anyhow!(
                "proxy response exceeded {MAX_RESPONSE} bytes before CRLF CRLF"
            ));
        }
        let mut cell = [0u8; 1];
        let got = tokio::time::timeout_at(give_up_at, stream.read(&mut cell))
            .await
            .context("read proxy response timed out")?
            .context("read proxy response failed")?;
        if got == 0 {
            return Err(anyhow!("proxy closed connection mid-response"));
        }
        head.push(cell[0]);
    }

    // The status line is everything up to the first line feed.
    let newline_at = match head.iter().position(|&b| b == b'\n') {
        Some(idx) => idx,
        None => return Err(anyhow!("proxy response missing newline after status line")),
    };
    let status_line = std::str::from_utf8(&head[..newline_at])
        .context("proxy status line is not UTF-8")?
        .trim_end_matches('\r');

    let mut fields = status_line.split_whitespace();
    // First field is the HTTP version; it only has to be present.
    fields
        .next()
        .context("missing HTTP version in status line")?;
    let code_field = fields
        .next()
        .context("missing status code in status line")?;
    let code = code_field
        .parse::<u16>()
        .with_context(|| format!("non-numeric status code {code_field:?}"))?;

    Ok((code, head))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The value a successful `parse_proxy_url` call is expected to return.
    fn expect_proxy(host: &str, port: u16) -> Option<(String, u16)> {
        Some((host.to_string(), port))
    }

    /// Builds an owned exemption list from string literals.
    fn entries(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    // ---- parse_proxy_url: accepted forms ----

    #[test]
    fn parse_bare_host_port() {
        let got = parse_proxy_url("proxy.corp:8080");
        assert_eq!(got, expect_proxy("proxy.corp", 8080));
    }

    #[test]
    fn parse_strips_http_scheme() {
        let got = parse_proxy_url("http://proxy.corp:8080");
        assert_eq!(got, expect_proxy("proxy.corp", 8080));
    }

    #[test]
    fn parse_strips_https_scheme() {
        let got = parse_proxy_url("https://proxy.corp:8080");
        assert_eq!(got, expect_proxy("proxy.corp", 8080));
    }

    #[test]
    fn parse_strips_trailing_slash() {
        let got = parse_proxy_url("http://proxy.corp:8080/");
        assert_eq!(got, expect_proxy("proxy.corp", 8080));
    }

    // ---- parse_proxy_url: rejected forms ----

    #[test]
    fn parse_rejects_auth() {
        assert!(parse_proxy_url("http://user:pass@proxy.corp:8080").is_none());
    }

    #[test]
    fn parse_rejects_missing_port() {
        for input in ["proxy.corp", "http://proxy.corp"] {
            assert_eq!(parse_proxy_url(input), None);
        }
    }

    #[test]
    fn parse_rejects_zero_port() {
        assert!(parse_proxy_url("proxy.corp:0").is_none());
    }

    #[test]
    fn parse_rejects_non_numeric_port() {
        assert!(parse_proxy_url("proxy.corp:eight").is_none());
    }

    #[test]
    fn parse_rejects_path_component() {
        assert!(parse_proxy_url("http://proxy.corp:8080/some/path").is_none());
    }

    // ---- bypasses_proxy ----

    #[test]
    fn no_proxy_exact_match() {
        let no_proxy = entries(&["localhost", "internal.corp"]);
        for listed in ["localhost", "internal.corp"] {
            assert!(bypasses_proxy(listed, &no_proxy));
        }
        assert!(!bypasses_proxy("example.com", &no_proxy));
    }

    #[test]
    fn no_proxy_subdomain_match() {
        let no_proxy = entries(&["walmart.com"]);
        for covered in ["pypi.ci.artifacts.walmart.com", "walmart.com"] {
            assert!(bypasses_proxy(covered, &no_proxy));
        }
        // Shares the suffix but is not a subdomain.
        assert!(!bypasses_proxy("evilwalmart.com", &no_proxy));
    }

    #[test]
    fn no_proxy_star_matches_everything() {
        let no_proxy = entries(&["*"]);
        for host in ["api.openai.com", "anything.at.all"] {
            assert!(bypasses_proxy(host, &no_proxy));
        }
    }

    #[test]
    fn no_proxy_empty_list_bypasses_nothing() {
        let nothing: [String; 0] = [];
        assert!(!bypasses_proxy("localhost", &nothing));
    }

    // Exercises the split / trim / drop-empty steps on a literal string
    // instead of calling `parse_no_proxy`, which reads the process environment.
    #[test]
    fn parse_no_proxy_strips_dots_and_whitespace() {
        let raw = " .example.com, internal.corp ,, *";
        let mut cleaned_entries: Vec<String> = Vec::new();
        for piece in raw.split(',') {
            let cleaned = piece.trim().trim_start_matches('.');
            if !cleaned.is_empty() {
                cleaned_entries.push(cleaned.to_string());
            }
        }
        assert_eq!(cleaned_entries, vec!["example.com", "internal.corp", "*"]);
    }
}