Skip to main content

nono_proxy/
connect.rs

1//! HTTP CONNECT tunnel handler (Mode 1 — Host Filtering).
2//!
3//! Handles `CONNECT host:port HTTP/1.1` requests by:
//! 1. Checking the session token (logged only, not enforced — some clients omit it on CONNECT)
5//! 2. Checking the host against the filter (cloud metadata deny list, then allowlist)
6//! 3. Establishing a TCP connection to the upstream
7//! 4. Returning `200 Connection Established`
8//! 5. Relaying bytes bidirectionally (transparent TLS tunnel)
9//!
10//! The proxy never terminates TLS — it just passes encrypted bytes through.
11//! Streaming (SSE, MCP Streamable HTTP, A2A) works transparently.
12
13use crate::audit;
14use crate::error::{ProxyError, Result};
15use crate::filter::ProxyFilter;
16use crate::token;
17use std::net::SocketAddr;
18use std::time::Duration;
19use tokio::io::AsyncWriteExt;
20use tokio::net::TcpStream;
21use tracing::debug;
22use zeroize::Zeroizing;
23
/// Upper bound on a single upstream TCP connect attempt (30 seconds).
///
/// `connect_to_resolved` applies this limit separately to each candidate address.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
26
27/// Handle an HTTP CONNECT request.
28///
29/// `first_line` is the already-read CONNECT line (e.g., "CONNECT api.openai.com:443 HTTP/1.1").
30/// `stream` is the raw TCP stream from the client.
31pub async fn handle_connect(
32    first_line: &str,
33    stream: &mut TcpStream,
34    filter: &ProxyFilter,
35    session_token: &Zeroizing<String>,
36    remaining_header: &[u8],
37    audit_log: Option<&audit::SharedAuditLog>,
38) -> Result<()> {
39    // Parse host:port from CONNECT line
40    let (host, port) = parse_connect_target(first_line)?;
41    debug!("CONNECT request to {}:{}", host, port);
42
43    // Validate session token from Proxy-Authorization header.
44    // Non-fatal for CONNECT: Node.js undici doesn't send Proxy-Authorization
45    // from URL userinfo for CONNECT requests.
46    if let Err(e) = validate_proxy_auth(remaining_header, session_token) {
47        debug!("CONNECT auth skipped: {}", e);
48    }
49
50    // Check host against filter (DNS resolution happens here)
51    let check = filter.check_host(&host, port).await?;
52    if !check.result.is_allowed() {
53        let reason = check.result.reason();
54        audit::log_denied(audit_log, audit::ProxyMode::Connect, &host, port, &reason);
55        send_response(stream, 403, &format!("Forbidden: {}", reason)).await?;
56        return Err(ProxyError::HostDenied { host, reason });
57    }
58
59    // Connect to the resolved IP directly — NOT re-resolving the hostname.
60    // This eliminates the DNS rebinding TOCTOU: the IPs were already checked
61    // against the link-local range in check_host() above.
62    let resolved = &check.resolved_addrs;
63    if resolved.is_empty() {
64        let reason = "DNS resolution returned no addresses".to_string();
65        audit::log_denied(audit_log, audit::ProxyMode::Connect, &host, port, &reason);
66        send_response(stream, 502, "DNS resolution failed").await?;
67        return Err(ProxyError::UpstreamConnect {
68            host: host.clone(),
69            reason,
70        });
71    }
72
73    let mut upstream = connect_to_resolved(resolved, &host).await?;
74
75    // Send 200 Connection Established
76    send_response(stream, 200, "Connection Established").await?;
77    audit::log_allowed(audit_log, audit::ProxyMode::Connect, &host, port, "CONNECT");
78
79    // Bidirectional relay
80    let result = tokio::io::copy_bidirectional(stream, &mut upstream).await;
81    debug!("CONNECT tunnel closed for {}:{}: {:?}", host, port, result);
82
83    Ok(())
84}
85
86/// Connect to one of the pre-resolved socket addresses with timeout.
87///
88/// Tries each address in order until one succeeds. This connects to the
89/// IP directly (not re-resolving the hostname), preventing DNS rebinding.
90async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
91    let mut last_err = None;
92    for addr in addrs {
93        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
94            Ok(Ok(stream)) => return Ok(stream),
95            Ok(Err(e)) => {
96                debug!("Connect to {} failed: {}", addr, e);
97                last_err = Some(e.to_string());
98            }
99            Err(_) => {
100                debug!("Connect to {} timed out", addr);
101                last_err = Some("connection timed out".to_string());
102            }
103        }
104    }
105    Err(ProxyError::UpstreamConnect {
106        host: host.to_string(),
107        reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
108    })
109}
110
111/// Parse the target host and port from a CONNECT request line.
112///
113/// Expected format: "CONNECT host:port HTTP/1.1"
114fn parse_connect_target(line: &str) -> Result<(String, u16)> {
115    let parts: Vec<&str> = line.split_whitespace().collect();
116    if parts.len() < 2 || parts[0] != "CONNECT" {
117        return Err(ProxyError::HttpParse(format!(
118            "malformed CONNECT line: {}",
119            line
120        )));
121    }
122
123    let authority = parts[1];
124    if let Some((host, port_str)) = authority.rsplit_once(':') {
125        let port = port_str.parse::<u16>().map_err(|_| {
126            ProxyError::HttpParse(format!("invalid port in CONNECT: {}", authority))
127        })?;
128        Ok((host.to_string(), port))
129    } else {
130        // No port specified, default to 443 for CONNECT
131        Ok((authority.to_string(), 443))
132    }
133}
134
135/// Validate the Proxy-Authorization header against the session token.
136///
137/// Delegates to `token::validate_proxy_auth` which accepts both Bearer
138/// and Basic auth formats.
139fn validate_proxy_auth(header_bytes: &[u8], session_token: &Zeroizing<String>) -> Result<()> {
140    token::validate_proxy_auth(header_bytes, session_token)
141}
142
143/// Send an HTTP response line to the client.
144async fn send_response(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
145    let response = format!("HTTP/1.1 {} {}\r\n\r\n", status, reason);
146    stream.write_all(response.as_bytes()).await?;
147    stream.flush().await?;
148    Ok(())
149}
150
#[cfg(test)]
// unwrap() is the idiomatic way to fail a test on an unexpected Err (its Debug output is shown).
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_connect_with_port() {
        let (host, port) = parse_connect_target("CONNECT api.openai.com:443 HTTP/1.1").unwrap();
        assert_eq!(host, "api.openai.com");
        assert_eq!(port, 443);
    }

    #[test]
    fn test_parse_connect_without_port() {
        let (host, port) = parse_connect_target("CONNECT example.com HTTP/1.1").unwrap();
        assert_eq!(host, "example.com");
        assert_eq!(port, 443);
    }

    #[test]
    fn test_parse_connect_custom_port() {
        let (host, port) = parse_connect_target("CONNECT internal:8443 HTTP/1.1").unwrap();
        assert_eq!(host, "internal");
        assert_eq!(port, 8443);
    }

    #[test]
    fn test_parse_connect_malformed() {
        assert!(parse_connect_target("GET /").is_err());
        assert!(parse_connect_target("").is_err());
    }

    #[test]
    fn test_parse_connect_invalid_port() {
        // Non-numeric, out-of-range for u16, and empty ports must all be rejected.
        assert!(parse_connect_target("CONNECT example.com:https HTTP/1.1").is_err());
        assert!(parse_connect_target("CONNECT example.com:65536 HTTP/1.1").is_err());
        assert!(parse_connect_target("CONNECT example.com: HTTP/1.1").is_err());
    }

    #[test]
    fn test_parse_connect_missing_target_or_wrong_method() {
        // The target is mandatory and the method token is matched case-sensitively.
        assert!(parse_connect_target("CONNECT").is_err());
        assert!(parse_connect_target("connect example.com:443 HTTP/1.1").is_err());
    }

    #[test]
    fn test_validate_proxy_auth_valid() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Proxy-Authorization: Bearer abc123\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_ok());
    }

    #[test]
    fn test_validate_proxy_auth_invalid() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Proxy-Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_err());
    }

    #[test]
    fn test_validate_proxy_auth_missing() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Host: example.com\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_err());
    }
}