use crate::audit;
use crate::error::{ProxyError, Result};
use crate::filter::ProxyFilter;
use crate::token;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::debug;
use zeroize::Zeroizing;
/// Upper bound on a single upstream TCP connect attempt.
///
/// The limit is applied per resolved address, so a target that resolves to several
/// unreachable addresses can take a multiple of this before the CONNECT fails.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Handle an HTTP `CONNECT` request by opening a raw TCP tunnel to the requested target.
///
/// `first_line` is the already-read request line (e.g. `CONNECT host:443 HTTP/1.1`) and
/// `remaining_header` holds the rest of the request header bytes, which are checked for a
/// proxy credential matching `session_token`.
///
/// The target is checked against `filter`:
/// - a denied host gets a `403` status line and a `HostDenied` audit entry;
/// - an empty resolution result or a failed upstream connect gets a `502` status line and an
///   `UpstreamConnectFailed` audit entry;
/// - otherwise the client gets `200 Connection Established`, an "allowed" audit entry is
///   recorded, and bytes are copied in both directions until either side closes.
///
/// # Errors
///
/// Returns an error if the request line cannot be parsed, the host check fails or denies the
/// target, the upstream connection cannot be established, or writing the `403`/`200` status
/// line to the client fails. Errors that end the established tunnel are only logged.
pub async fn handle_connect(
    first_line: &str,
    stream: &mut TcpStream,
    filter: &ProxyFilter,
    session_token: &Zeroizing<String>,
    remaining_header: &[u8],
    audit_log: Option<&audit::SharedAuditLog>,
) -> Result<()> {
    let (host, port) = parse_connect_target(first_line)?;
    debug!("CONNECT request to {}:{}", host, port);

    // NOTE(review): a missing or wrong proxy credential is only logged and the request still
    // proceeds, so `filter` is the only gate on CONNECT. Confirm this is the intended policy
    // (e.g. for clients that omit credentials on CONNECT) rather than an auth bypass.
    if let Err(e) = validate_proxy_auth(remaining_header, session_token) {
        debug!("CONNECT auth skipped: {}", e);
    }

    let check = filter.check_host(&host, port).await?;
    if !check.result.is_allowed() {
        let reason = check.result.reason();
        audit::log_denied(
            audit_log,
            audit::ProxyMode::Connect,
            &audit::EventContext {
                denial_category: Some(nono::undo::NetworkAuditDenialCategory::HostDenied),
                ..audit::EventContext::default()
            },
            &host,
            port,
            &reason,
        );
        send_response(stream, 403, &format!("Forbidden: {}", reason)).await?;
        return Err(ProxyError::HostDenied { host, reason });
    }

    let resolved = &check.resolved_addrs;
    if resolved.is_empty() {
        let reason = "DNS resolution returned no addresses".to_string();
        // Best effort: failing to notify the client must not mask the connect error.
        let _ = write_upstream_failure(stream, audit_log, &host, port, &reason).await;
        // `host` is not needed after this point, so move it instead of cloning.
        return Err(ProxyError::UpstreamConnect { host, reason });
    }

    let mut upstream = match connect_to_resolved(resolved, &host).await {
        Ok(upstream) => upstream,
        Err(err) => {
            // For connect failures report only the underlying reason (not the error's full
            // Display text) on the 502 status line and in the audit entry.
            let reason = match &err {
                ProxyError::UpstreamConnect { reason, .. } => reason.clone(),
                other => other.to_string(),
            };
            // Best effort, as above: the original error is what the caller receives.
            let _ = write_upstream_failure(stream, audit_log, &host, port, &reason).await;
            return Err(err);
        }
    };

    send_response(stream, 200, "Connection Established").await?;
    audit::log_allowed(
        audit_log,
        audit::ProxyMode::Connect,
        &audit::EventContext::default(),
        &host,
        port,
        "CONNECT",
    );

    // Tunnel until either side closes; the outcome (byte counts or I/O error) is only logged.
    let result = tokio::io::copy_bidirectional(stream, &mut upstream).await;
    debug!("CONNECT tunnel closed for {}:{}: {:?}", host, port, result);
    Ok(())
}
/// Connect to the first reachable address in `addrs`, trying them in order.
///
/// Each attempt is bounded by [`UPSTREAM_CONNECT_TIMEOUT`]. The first successful stream is
/// returned immediately; remaining addresses are not tried.
///
/// # Errors
///
/// Returns [`ProxyError::UpstreamConnect`] for `host` when no address connects. Its reason is
/// the failure of the *last* attempt, or a fixed message when `addrs` is empty.
async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
    let mut failure: Option<String> = None;

    for &target in addrs {
        let attempt =
            tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(target)).await;
        let why = match attempt {
            Ok(Ok(connected)) => return Ok(connected),
            Ok(Err(io_err)) => {
                debug!("Connect to {} failed: {}", target, io_err);
                io_err.to_string()
            }
            Err(_elapsed) => {
                debug!("Connect to {} timed out", target);
                "connection timed out".to_string()
            }
        };
        failure = Some(why);
    }

    let reason = failure.unwrap_or_else(|| "no addresses to connect to".to_string());
    Err(ProxyError::UpstreamConnect {
        host: host.to_string(),
        reason,
    })
}
/// Parse a `CONNECT` request line into its target host and port.
///
/// The second whitespace-separated token is the authority: `host`, `host:port`,
/// `[ipv6]` or `[ipv6]:port`. Bracketed IPv6 literals keep their brackets in the returned
/// host. A missing port defaults to `443`.
///
/// # Errors
///
/// Returns `ProxyError::HttpParse` when:
/// - the line is not a `CONNECT` with a target;
/// - the host is empty;
/// - an unbracketed host contains `:` (such as a bare IPv6 literal, which cannot be split
///   unambiguously);
/// - a bracket is unterminated or followed by anything other than `:port`;
/// - the port is not a valid `u16`.
fn parse_connect_target(line: &str) -> Result<(String, u16)> {
    let mut parts = line.split_whitespace();
    let authority = match (parts.next(), parts.next()) {
        (Some("CONNECT"), Some(authority)) => authority,
        _ => {
            return Err(ProxyError::HttpParse(format!(
                "malformed CONNECT line: {}",
                line
            )));
        }
    };
    let invalid_authority =
        || ProxyError::HttpParse(format!("invalid authority in CONNECT: {}", authority));

    // An IPv6 literal contains colons of its own, so the port separator must be searched
    // for after the closing bracket rather than with a plain rsplit on ':'.
    let (host, port_str) = if authority.starts_with('[') {
        let close = authority.find(']').ok_or_else(invalid_authority)?;
        let (host, rest) = authority.split_at(close + 1);
        if rest.is_empty() {
            (host, None)
        } else {
            let digits = rest.strip_prefix(':').ok_or_else(invalid_authority)?;
            (host, Some(digits))
        }
    } else {
        match authority.rsplit_once(':') {
            Some((host, digits)) => (host, Some(digits)),
            None => (authority, None),
        }
    };

    // Reject an empty host, and reject stray colons outside brackets, which would
    // otherwise be mis-split into a bogus host and port.
    let bracketed = host.starts_with('[');
    if host.is_empty() || host == "[]" || (!bracketed && host.contains(':')) {
        return Err(invalid_authority());
    }

    let port = match port_str {
        Some(digits) => digits.parse::<u16>().map_err(|_| {
            ProxyError::HttpParse(format!("invalid port in CONNECT: {}", authority))
        })?,
        None => 443,
    };
    Ok((host.to_string(), port))
}
/// Check the request header bytes for a proxy credential matching `session_token`.
///
/// Thin wrapper that forwards to `token::validate_proxy_auth` and returns its result
/// unchanged. A `Proxy-Authorization: Bearer <token>` line with the right token is accepted.
/// A wrong token or a missing header yields `Err`.
fn validate_proxy_auth(header_bytes: &[u8], session_token: &Zeroizing<String>) -> Result<()> {
    token::validate_proxy_auth(header_bytes, session_token)
}
/// Write a minimal HTTP/1.1 response consisting only of a status line and an empty header
/// section, then flush.
///
/// Every control character in `reason` (including CR and LF) is replaced with a space. This
/// stops caller-supplied text from injecting header lines or otherwise breaking the response
/// framing.
///
/// # Errors
///
/// Returns an error if writing to or flushing `stream` fails.
async fn send_response<S: AsyncWrite + Unpin>(
    stream: &mut S,
    status: u16,
    reason: &str,
) -> Result<()> {
    // A reason-phrase may only contain HTAB, SP, VCHAR and obs-text (RFC 9112). The text here
    // may carry request-derived or upstream-error strings, so scrub all control characters,
    // not just CR/LF.
    let sanitised_reason: String = reason
        .chars()
        .map(|c| if c.is_control() { ' ' } else { c })
        .collect();
    let response = format!("HTTP/1.1 {} {}\r\n\r\n", status, sanitised_reason);
    stream.write_all(response.as_bytes()).await?;
    stream.flush().await?;
    Ok(())
}
/// Record an upstream-connect failure and tell the client with a `502` status line.
///
/// The audit entry is written first, in CONNECT mode with the `UpstreamConnectFailed` denial
/// category, so the event is recorded even if the client write then fails. Passing `None` for
/// `audit_log` is allowed.
///
/// # Errors
///
/// Returns an error if writing the `502` response to `stream` fails.
async fn write_upstream_failure<S: AsyncWrite + Unpin>(
    stream: &mut S,
    audit_log: Option<&audit::SharedAuditLog>,
    host: &str,
    port: u16,
    reason: &str,
) -> Result<()> {
    let context = audit::EventContext {
        denial_category: Some(nono::undo::NetworkAuditDenialCategory::UpstreamConnectFailed),
        ..audit::EventContext::default()
    };
    audit::log_denied(
        audit_log,
        audit::ProxyMode::Connect,
        &context,
        host,
        port,
        reason,
    );

    let status_text = format!("Upstream connect failed: {}", reason);
    send_response(stream, 502, &status_text).await
}
#[cfg(test)]
// Test failures should panic with the underlying error; `unwrap` keeps the assertions terse.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use nono::undo::{NetworkAuditDecision, NetworkAuditDenialCategory, NetworkAuditMode};
    use tokio::io::{AsyncReadExt, duplex};

    /// Read `reader` to EOF and decode the collected bytes as UTF-8.
    async fn read_to_string<R: tokio::io::AsyncRead + Unpin>(mut reader: R) -> String {
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        String::from_utf8(buf).unwrap()
    }

    #[test]
    fn test_parse_connect_with_port() {
        let (host, port) = parse_connect_target("CONNECT api.openai.com:443 HTTP/1.1").unwrap();
        assert_eq!(host, "api.openai.com");
        assert_eq!(port, 443);
    }

    #[test]
    fn test_parse_connect_without_port() {
        let (host, port) = parse_connect_target("CONNECT example.com HTTP/1.1").unwrap();
        assert_eq!(host, "example.com");
        assert_eq!(port, 443);
    }

    #[test]
    fn test_parse_connect_custom_port() {
        let (host, port) = parse_connect_target("CONNECT internal:8443 HTTP/1.1").unwrap();
        assert_eq!(host, "internal");
        assert_eq!(port, 8443);
    }

    #[test]
    fn test_parse_connect_malformed() {
        assert!(parse_connect_target("GET /").is_err());
        assert!(parse_connect_target("").is_err());
    }

    #[test]
    fn test_validate_proxy_auth_valid() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Proxy-Authorization: Bearer abc123\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_ok());
    }

    #[test]
    fn test_validate_proxy_auth_invalid() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Proxy-Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_err());
    }

    #[test]
    fn test_validate_proxy_auth_missing() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Host: example.com\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_err());
    }

    #[tokio::test]
    async fn write_upstream_failure_sends_502_status_line() {
        let (mut server, client) = duplex(1024);
        write_upstream_failure(&mut server, None, "example.com", 443, "connection refused")
            .await
            .unwrap();
        // Dropping the writer half signals EOF so the reader side can finish.
        drop(server);
        let response = read_to_string(client).await;
        assert!(
            response.starts_with("HTTP/1.1 502 "),
            "expected 502 status line, got: {:?}",
            response
        );
        assert!(
            response.contains("Upstream connect failed: connection refused"),
            "expected reason on status line, got: {:?}",
            response
        );
    }

    #[tokio::test]
    async fn write_upstream_failure_records_audit_entry() {
        let (mut server, _client) = duplex(1024);
        let log = audit::new_audit_log();
        write_upstream_failure(
            &mut server,
            Some(&log),
            "example.com",
            443,
            "connection refused",
        )
        .await
        .unwrap();
        let events = audit::drain_audit_events(&log);
        assert_eq!(events.len(), 1);
        let event = &events[0];
        assert_eq!(event.mode, NetworkAuditMode::Connect);
        assert_eq!(event.decision, NetworkAuditDecision::Deny);
        assert_eq!(
            event.denial_category,
            Some(NetworkAuditDenialCategory::UpstreamConnectFailed)
        );
        assert_eq!(event.target, "example.com");
        assert_eq!(event.port, Some(443));
        assert_eq!(event.reason.as_deref(), Some("connection refused"));
    }

    #[tokio::test]
    async fn write_upstream_failure_without_audit_log_still_writes_response() {
        let (mut server, client) = duplex(1024);
        write_upstream_failure(&mut server, None, "example.com", 443, "connection refused")
            .await
            .unwrap();
        drop(server);
        let response = read_to_string(client).await;
        assert!(response.starts_with("HTTP/1.1 502 "));
    }

    #[tokio::test]
    async fn write_upstream_failure_sanitises_crlf_in_reason() {
        let (mut server, client) = duplex(1024);
        write_upstream_failure(
            &mut server,
            None,
            "example.com",
            443,
            "connection refused\r\nX-Injected: yes",
        )
        .await
        .unwrap();
        drop(server);
        let response = read_to_string(client).await;
        let terminator = "\r\n\r\n";
        let body_end = response.find(terminator).expect("response must terminate");
        let status_line = &response[..body_end];
        assert!(
            !status_line.contains('\r'),
            "status line must not contain CR, got: {:?}",
            status_line
        );
        assert!(
            !status_line.contains('\n'),
            "status line must not contain LF, got: {:?}",
            status_line
        );
        assert!(
            !response.contains("\r\nX-Injected:"),
            "injected header must not be split into a real header: {:?}",
            response
        );
    }

    #[tokio::test]
    async fn write_upstream_failure_round_trips_timeout_reason() {
        let (mut server, client) = duplex(1024);
        let log = audit::new_audit_log();
        write_upstream_failure(
            &mut server,
            Some(&log),
            "slow.example.com",
            443,
            "connection timed out",
        )
        .await
        .unwrap();
        drop(server);
        let response = read_to_string(client).await;
        assert!(response.contains("connection timed out"));
        let events = audit::drain_audit_events(&log);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].reason.as_deref(), Some("connection timed out"));
    }
}