1use crate::audit;
14use crate::error::{ProxyError, Result};
15use crate::filter::ProxyFilter;
16use crate::token;
17use std::net::SocketAddr;
18use std::time::Duration;
19use tokio::io::AsyncWriteExt;
20use tokio::net::TcpStream;
21use tracing::debug;
22use zeroize::Zeroizing;
23
/// Upper bound on a single upstream TCP connect attempt (30 seconds per resolved address).
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
26
27pub async fn handle_connect(
32 first_line: &str,
33 stream: &mut TcpStream,
34 filter: &ProxyFilter,
35 session_token: &Zeroizing<String>,
36 remaining_header: &[u8],
37 audit_log: Option<&audit::SharedAuditLog>,
38) -> Result<()> {
39 let (host, port) = parse_connect_target(first_line)?;
41 debug!("CONNECT request to {}:{}", host, port);
42
43 if let Err(e) = validate_proxy_auth(remaining_header, session_token) {
47 debug!("CONNECT auth skipped: {}", e);
48 }
49
50 let check = filter.check_host(&host, port).await?;
52 if !check.result.is_allowed() {
53 let reason = check.result.reason();
54 audit::log_denied(audit_log, audit::ProxyMode::Connect, &host, port, &reason);
55 send_response(stream, 403, &format!("Forbidden: {}", reason)).await?;
56 return Err(ProxyError::HostDenied { host, reason });
57 }
58
59 let resolved = &check.resolved_addrs;
63 if resolved.is_empty() {
64 let reason = "DNS resolution returned no addresses".to_string();
65 audit::log_denied(audit_log, audit::ProxyMode::Connect, &host, port, &reason);
66 send_response(stream, 502, "DNS resolution failed").await?;
67 return Err(ProxyError::UpstreamConnect {
68 host: host.clone(),
69 reason,
70 });
71 }
72
73 let mut upstream = connect_to_resolved(resolved, &host).await?;
74
75 send_response(stream, 200, "Connection Established").await?;
77 audit::log_allowed(audit_log, audit::ProxyMode::Connect, &host, port, "CONNECT");
78
79 let result = tokio::io::copy_bidirectional(stream, &mut upstream).await;
81 debug!("CONNECT tunnel closed for {}:{}: {:?}", host, port, result);
82
83 Ok(())
84}
85
86async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
91 let mut last_err = None;
92 for addr in addrs {
93 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
94 Ok(Ok(stream)) => return Ok(stream),
95 Ok(Err(e)) => {
96 debug!("Connect to {} failed: {}", addr, e);
97 last_err = Some(e.to_string());
98 }
99 Err(_) => {
100 debug!("Connect to {} timed out", addr);
101 last_err = Some("connection timed out".to_string());
102 }
103 }
104 }
105 Err(ProxyError::UpstreamConnect {
106 host: host.to_string(),
107 reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
108 })
109}
110
111fn parse_connect_target(line: &str) -> Result<(String, u16)> {
115 let parts: Vec<&str> = line.split_whitespace().collect();
116 if parts.len() < 2 || parts[0] != "CONNECT" {
117 return Err(ProxyError::HttpParse(format!(
118 "malformed CONNECT line: {}",
119 line
120 )));
121 }
122
123 let authority = parts[1];
124 if let Some((host, port_str)) = authority.rsplit_once(':') {
125 let port = port_str.parse::<u16>().map_err(|_| {
126 ProxyError::HttpParse(format!("invalid port in CONNECT: {}", authority))
127 })?;
128 Ok((host.to_string(), port))
129 } else {
130 Ok((authority.to_string(), 443))
132 }
133}
134
135fn validate_proxy_auth(header_bytes: &[u8], session_token: &Zeroizing<String>) -> Result<()> {
140 token::validate_proxy_auth(header_bytes, session_token)
141}
142
143async fn send_response(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
145 let response = format!("HTTP/1.1 {} {}\r\n\r\n", status, reason);
146 stream.write_all(response.as_bytes()).await?;
147 stream.flush().await?;
148 Ok(())
149}
150
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Session token shared by the proxy-auth tests.
    fn session_token() -> Zeroizing<String> {
        Zeroizing::new(String::from("abc123"))
    }

    #[test]
    fn test_parse_connect_with_port() {
        let target = parse_connect_target("CONNECT api.openai.com:443 HTTP/1.1").unwrap();
        assert_eq!(target, (String::from("api.openai.com"), 443));
    }

    #[test]
    fn test_parse_connect_without_port() {
        // A bare host falls back to port 443.
        let target = parse_connect_target("CONNECT example.com HTTP/1.1").unwrap();
        assert_eq!(target, (String::from("example.com"), 443));
    }

    #[test]
    fn test_parse_connect_custom_port() {
        let target = parse_connect_target("CONNECT internal:8443 HTTP/1.1").unwrap();
        assert_eq!(target, (String::from("internal"), 8443));
    }

    #[test]
    fn test_parse_connect_malformed() {
        for line in &["GET /", ""] {
            assert!(parse_connect_target(line).is_err(), "accepted {:?}", line);
        }
    }

    #[test]
    fn test_validate_proxy_auth_valid() {
        let headers: &[u8] = b"Proxy-Authorization: Bearer abc123\r\n\r\n";
        assert!(validate_proxy_auth(headers, &session_token()).is_ok());
    }

    #[test]
    fn test_validate_proxy_auth_invalid() {
        let headers: &[u8] = b"Proxy-Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_proxy_auth(headers, &session_token()).is_err());
    }

    #[test]
    fn test_validate_proxy_auth_missing() {
        let headers: &[u8] = b"Host: example.com\r\n\r\n";
        assert!(validate_proxy_auth(headers, &session_token()).is_err());
    }
}