use crate::audit;
use crate::error::{ProxyError, Result};
use crate::filter::ProxyFilter;
use crate::token;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tracing::debug;
use zeroize::Zeroizing;
/// Maximum time to wait for a TCP connection to a single resolved upstream address.
///
/// `connect_to_resolved` applies this bound to each address attempt separately, so
/// the total wait for a host that resolved to several unreachable addresses can be a
/// multiple of this value.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Handles an HTTP `CONNECT` request by opening a TCP tunnel to the requested target.
///
/// `first_line` is the request line (e.g. `CONNECT host:443 HTTP/1.1`) and
/// `remaining_header` holds the header bytes that followed it.
///
/// Possible outcomes for the client:
/// - `403 Forbidden`: `filter` rejects the target. The denial is also audited.
/// - `502`: the filter's resolution produced no addresses, or no resolved address
///   accepted a connection.
/// - `200 Connection Established`: otherwise. Bytes are then relayed in both
///   directions until either side closes.
///
/// # Errors
///
/// - [`ProxyError::HostDenied`] if the filter rejects the target.
/// - [`ProxyError::UpstreamConnect`] if resolution returned no addresses or every
///   connection attempt failed.
/// - Any error from parsing the request line, from `filter.check_host`, or from
///   writing the response to `stream`.
///
/// Once the tunnel is established, relay errors are only logged at debug level and
/// the function returns `Ok(())`.
pub async fn handle_connect(
    first_line: &str,
    stream: &mut TcpStream,
    filter: &ProxyFilter,
    session_token: &Zeroizing<String>,
    remaining_header: &[u8],
    audit_log: Option<&audit::SharedAuditLog>,
) -> Result<()> {
    let (host, port) = parse_connect_target(first_line)?;
    debug!("CONNECT request to {}:{}", host, port);

    // NOTE(review): a missing or invalid Proxy-Authorization header is only logged;
    // the tunnel proceeds regardless, so `session_token` is not enforced for CONNECT.
    // This looks deliberate (presumably some clients omit credentials on CONNECT).
    // Confirm it is intended, since any client that can reach the proxy can open
    // tunnels to allowed hosts.
    if let Err(e) = validate_proxy_auth(remaining_header, session_token) {
        debug!("CONNECT auth skipped: {}", e);
    }

    let check = filter.check_host(&host, port).await?;
    if !check.result.is_allowed() {
        let reason = check.result.reason();
        audit::log_denied(audit_log, audit::ProxyMode::Connect, &host, port, &reason);
        send_response(stream, 403, &format!("Forbidden: {}", reason)).await?;
        return Err(ProxyError::HostDenied { host, reason });
    }

    let resolved = &check.resolved_addrs;
    if resolved.is_empty() {
        let reason = "DNS resolution returned no addresses".to_string();
        audit::log_denied(audit_log, audit::ProxyMode::Connect, &host, port, &reason);
        send_response(stream, 502, "DNS resolution failed").await?;
        return Err(ProxyError::UpstreamConnect { host, reason });
    }

    let mut upstream = match connect_to_resolved(resolved, &host).await {
        Ok(upstream) => upstream,
        Err(err) => {
            // Tell the client the tunnel could not be opened instead of silently
            // dropping the connection. This is best-effort: the upstream error is the
            // one worth reporting, so a failure to write this response is ignored.
            let _ = send_response(stream, 502, "Bad Gateway").await;
            return Err(err);
        }
    };

    send_response(stream, 200, "Connection Established").await?;
    audit::log_allowed(audit_log, audit::ProxyMode::Connect, &host, port, "CONNECT");

    // Relay until either side closes. Errors here (resets, half-closes) are routine
    // for tunnels, so they are logged rather than returned.
    let result = tokio::io::copy_bidirectional(stream, &mut upstream).await;
    debug!("CONNECT tunnel closed for {}:{}: {:?}", host, port, result);
    Ok(())
}
/// Opens a TCP connection to the first address in `addrs` that accepts one.
///
/// Addresses are tried in order, and each attempt is bounded by
/// `UPSTREAM_CONNECT_TIMEOUT`. `host` is used only to label the returned error.
///
/// # Errors
///
/// Returns [`ProxyError::UpstreamConnect`] when no attempt succeeds. Its `reason` is
/// the message from the last failed attempt, or `"no addresses to connect to"` when
/// `addrs` is empty.
async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
    let mut failure: Option<String> = None;

    for candidate in addrs.iter() {
        let attempt = tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(candidate));
        let message = match attempt.await {
            Ok(Ok(connected)) => return Ok(connected),
            Ok(Err(io_err)) => {
                debug!("Connect to {} failed: {}", candidate, io_err);
                io_err.to_string()
            }
            Err(_elapsed) => {
                debug!("Connect to {} timed out", candidate);
                String::from("connection timed out")
            }
        };
        failure = Some(message);
    }

    let reason = match failure {
        Some(last) => last,
        None => String::from("no addresses to connect to"),
    };
    Err(ProxyError::UpstreamConnect {
        host: host.to_owned(),
        reason,
    })
}
/// Parses the request line of an HTTP `CONNECT` request into `(host, port)`.
///
/// Accepts `CONNECT <authority> [HTTP-version]`, where `<authority>` is one of:
/// - `host:port`: a registered name or IPv4 literal with a port;
/// - `[ipv6]:port`: a bracketed IPv6 literal (RFC 3986). The brackets are stripped
///   from the returned host so it can be parsed or resolved directly;
/// - `host` or `[ipv6]` with no port, in which case the port defaults to 443.
///
/// # Errors
///
/// Returns [`ProxyError::HttpParse`] if:
/// - the line is not a `CONNECT` request with a target;
/// - the host is empty or contains an unbracketed `:` (an ambiguous IPv6 literal);
/// - a bracketed literal is unterminated or is not a valid IPv6 address;
/// - the port is not a plain decimal number in `1..=65535`.
fn parse_connect_target(line: &str) -> Result<(String, u16)> {
    let mut parts = line.split_whitespace();
    let authority = match (parts.next(), parts.next()) {
        (Some("CONNECT"), Some(authority)) => authority,
        _ => {
            return Err(ProxyError::HttpParse(format!(
                "malformed CONNECT line: {}",
                line
            )));
        }
    };

    // Both closures only capture `authority` by shared reference, so they are `Copy`
    // and can be passed to `ok_or_else` more than once.
    let invalid_host =
        || ProxyError::HttpParse(format!("invalid host in CONNECT: {}", authority));
    let invalid_port =
        || ProxyError::HttpParse(format!("invalid port in CONNECT: {}", authority));

    let (host, port_str) = if let Some(bracketed) = authority.strip_prefix('[') {
        // IPv6 literals must be bracketed because their own colons would otherwise be
        // indistinguishable from the port separator.
        let (addr, after) = bracketed.split_once(']').ok_or_else(invalid_host)?;
        if addr.parse::<std::net::Ipv6Addr>().is_err() {
            return Err(invalid_host());
        }
        let port_str = if after.is_empty() {
            None
        } else {
            Some(after.strip_prefix(':').ok_or_else(invalid_host)?)
        };
        (addr, port_str)
    } else {
        let (host, port_str) = match authority.rsplit_once(':') {
            Some((host, port_str)) => (host, Some(port_str)),
            None => (authority, None),
        };
        // A remaining ':' means an unbracketed IPv6 literal (or garbage): reject it
        // rather than guessing where the port starts.
        if host.is_empty() || host.contains(':') {
            return Err(invalid_host());
        }
        (host, port_str)
    };

    let port = match port_str {
        None => 443,
        // `u16::from_str` accepts a leading '+', which is not valid in an authority.
        Some(digits) if !digits.bytes().all(|b| b.is_ascii_digit()) => {
            return Err(invalid_port());
        }
        Some(digits) => match digits.parse::<u16>() {
            Ok(0) | Err(_) => return Err(invalid_port()),
            Ok(port) => port,
        },
    };

    Ok((host.to_string(), port))
}
fn validate_proxy_auth(header_bytes: &[u8], session_token: &Zeroizing<String>) -> Result<()> {
token::validate_proxy_auth(header_bytes, session_token)
}
/// Writes a bare HTTP/1.1 response head to `stream` and flushes it.
///
/// The head is the status line `HTTP/1.1 <status> <reason>`, followed by the blank
/// line that ends the header section. No headers or body are sent. `reason` is
/// written verbatim and is not sanitized for CR/LF.
///
/// # Errors
///
/// Returns an error if writing to or flushing the stream fails.
async fn send_response(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
    let head = format!("HTTP/1.1 {status} {reason}\r\n\r\n");
    stream.write_all(head.as_bytes()).await?;
    stream.flush().await.map_err(Into::into)
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Session token shared by the `validate_proxy_auth` tests.
    fn session_token() -> Zeroizing<String> {
        Zeroizing::new(String::from("abc123"))
    }

    #[test]
    fn test_parse_connect_with_port() {
        let target = parse_connect_target("CONNECT api.openai.com:443 HTTP/1.1").unwrap();
        assert_eq!(target, (String::from("api.openai.com"), 443));
    }

    #[test]
    fn test_parse_connect_without_port() {
        // No explicit port: the parser falls back to 443.
        let target = parse_connect_target("CONNECT example.com HTTP/1.1").unwrap();
        assert_eq!(target, (String::from("example.com"), 443));
    }

    #[test]
    fn test_parse_connect_custom_port() {
        let target = parse_connect_target("CONNECT internal:8443 HTTP/1.1").unwrap();
        assert_eq!(target, (String::from("internal"), 8443));
    }

    #[test]
    fn test_parse_connect_malformed() {
        for line in ["GET /", ""] {
            assert!(parse_connect_target(line).is_err(), "expected error for {line:?}");
        }
    }

    #[test]
    fn test_validate_proxy_auth_valid() {
        let header: &[u8] = b"Proxy-Authorization: Bearer abc123\r\n\r\n";
        assert!(validate_proxy_auth(header, &session_token()).is_ok());
    }

    #[test]
    fn test_validate_proxy_auth_invalid() {
        let header: &[u8] = b"Proxy-Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_proxy_auth(header, &session_token()).is_err());
    }

    #[test]
    fn test_validate_proxy_auth_missing() {
        let header: &[u8] = b"Host: example.com\r\n\r\n";
        assert!(validate_proxy_auth(header, &session_token()).is_err());
    }
}