1use crate::audit;
8use crate::config::ExternalProxyConfig;
9use crate::error::{ProxyError, Result};
10use crate::filter::ProxyFilter;
11use crate::token;
12use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
13use tokio::net::TcpStream;
14use tracing::debug;
15use zeroize::Zeroizing;
16
/// Case-insensitive hostname matcher built from a list of patterns.
///
/// Two pattern forms are recognised:
/// * `*.domain` — matches any host that ends in `.domain` and has at least one
///   character in front of that suffix (so `domain` itself does not match);
/// * anything else — compared as a whole, literal hostname.
///
/// NOTE(review): presumably used to decide which hosts skip the external
/// proxy — confirm against the callers.
#[derive(Debug, Clone)]
pub struct BypassMatcher {
    /// Lowercased patterns that must equal the host in full.
    exact: Vec<String>,
    /// Lowercased wildcard suffixes, stored with their leading dot
    /// (`*.internal.corp` is kept as `.internal.corp`).
    suffixes: Vec<String>,
}

impl BypassMatcher {
    /// Builds a matcher from `hosts`.
    ///
    /// Every entry is lowercased. Entries starting with `*.` become suffix
    /// patterns; a bare `*.` (nothing after the dot) is dropped. All other
    /// entries — including ones such as `*` or `*corp` — are kept as literal
    /// exact-match names.
    #[must_use]
    pub fn new(hosts: &[String]) -> Self {
        let mut matcher = Self {
            exact: Vec::new(),
            suffixes: Vec::new(),
        };

        for entry in hosts.iter().map(|raw| raw.to_lowercase()) {
            match entry.strip_prefix("*.") {
                // `*.` with no domain behind it carries no information.
                Some("") => {}
                Some(domain) => matcher.suffixes.push(format!(".{domain}")),
                None => matcher.exact.push(entry),
            }
        }

        matcher
    }

    /// Returns `true` when `host` (compared case-insensitively) equals one of
    /// the exact names or falls under one of the wildcard suffixes.
    #[must_use]
    pub fn matches(&self, host: &str) -> bool {
        let needle = host.to_lowercase();

        let exact_hit = self.exact.iter().any(|known| *known == needle);

        // The length guard requires a non-empty label in front of the suffix,
        // so `.internal.corp` on its own is not a match.
        exact_hit
            || self
                .suffixes
                .iter()
                .any(|tail| needle.len() > tail.len() && needle.ends_with(tail.as_str()))
    }

    /// Returns `true` when the matcher holds no patterns at all.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        [&self.exact, &self.suffixes]
            .iter()
            .all(|patterns| patterns.is_empty())
    }
}
86
87pub async fn handle_external_proxy(
96 first_line: &str,
97 stream: &mut TcpStream,
98 remaining_header: &[u8],
99 filter: &ProxyFilter,
100 session_token: &Zeroizing<String>,
101 external_config: &ExternalProxyConfig,
102 audit_log: Option<&audit::SharedAuditLog>,
103) -> Result<()> {
104 let (host, port) = parse_connect_target(first_line)?;
106 debug!("External proxy CONNECT to {}:{}", host, port);
107
108 validate_proxy_auth(remaining_header, session_token)?;
110
111 let check = filter.check_host(&host, port).await?;
114 if !check.result.is_allowed() {
115 let reason = check.result.reason();
116 audit::log_denied(audit_log, audit::ProxyMode::External, &host, port, &reason);
117 send_response(stream, 403, &format!("Forbidden: {}", reason)).await?;
118 return Err(ProxyError::HostDenied { host, reason });
119 }
120
121 let mut proxy_stream = TcpStream::connect(&external_config.address)
123 .await
124 .map_err(|e| {
125 ProxyError::ExternalProxy(format!(
126 "cannot connect to external proxy {}: {}",
127 external_config.address, e
128 ))
129 })?;
130
131 let mut connect_req = format!(
133 "CONNECT {}:{} HTTP/1.1\r\nHost: {}:{}\r\n",
134 host, port, host, port
135 );
136
137 if external_config.auth.is_some() {
141 return Err(ProxyError::ExternalProxy(
142 "external proxy authentication is configured but not yet implemented; \
143 remove the auth section from the external proxy config or wait for \
144 a future release"
145 .to_string(),
146 ));
147 }
148
149 connect_req.push_str("\r\n");
150 proxy_stream
151 .write_all(connect_req.as_bytes())
152 .await
153 .map_err(|e| {
154 ProxyError::ExternalProxy(format!("failed to send CONNECT to external proxy: {}", e))
155 })?;
156
157 let mut buf_reader = BufReader::new(&mut proxy_stream);
159 let mut response_line = String::new();
160 buf_reader
161 .read_line(&mut response_line)
162 .await
163 .map_err(|e| {
164 ProxyError::ExternalProxy(format!(
165 "failed to read response from external proxy: {}",
166 e
167 ))
168 })?;
169
170 let status = parse_status_code(&response_line)?;
172 if status != 200 {
173 audit::log_denied(
174 audit_log,
175 audit::ProxyMode::External,
176 &host,
177 port,
178 &format!("external proxy rejected with status {}", status),
179 );
180 send_response(
181 stream,
182 status,
183 &format!("Blocked by upstream proxy (status {})", status),
184 )
185 .await?;
186 return Err(ProxyError::ExternalProxy(format!(
187 "enterprise proxy rejected CONNECT to {}:{} with status {}",
188 host, port, status
189 )));
190 }
191
192 loop {
194 let mut line = String::new();
195 buf_reader.read_line(&mut line).await.map_err(|e| {
196 ProxyError::ExternalProxy(format!("failed to drain proxy response headers: {}", e))
197 })?;
198 if line.trim().is_empty() {
199 break;
200 }
201 }
202
203 let proxy_stream = buf_reader.into_inner();
205
206 send_response(stream, 200, "Connection Established").await?;
208 audit::log_allowed(
209 audit_log,
210 audit::ProxyMode::External,
211 &host,
212 port,
213 "CONNECT",
214 );
215
216 let result = tokio::io::copy_bidirectional(stream, proxy_stream).await;
218 debug!(
219 "External proxy tunnel closed for {}:{}: {:?}",
220 host, port, result
221 );
222
223 Ok(())
224}
225
226fn parse_connect_target(line: &str) -> Result<(String, u16)> {
228 let parts: Vec<&str> = line.split_whitespace().collect();
229 if parts.len() < 2 {
230 return Err(ProxyError::HttpParse(format!(
231 "malformed CONNECT line: {}",
232 line
233 )));
234 }
235
236 let authority = parts[1];
237 if let Some((host, port_str)) = authority.rsplit_once(':') {
238 let port = port_str.parse::<u16>().map_err(|_| {
239 ProxyError::HttpParse(format!("invalid port in CONNECT: {}", authority))
240 })?;
241 Ok((host.to_string(), port))
242 } else {
243 Ok((authority.to_string(), 443))
244 }
245}
246
/// Validates the client's proxy authentication for this request.
///
/// `header_bytes` is the raw header data that followed the request line and
/// `session_token` is the secret it is checked against. The actual check is
/// performed by `token::validate_proxy_auth`; its result is returned unchanged.
fn validate_proxy_auth(header_bytes: &[u8], session_token: &Zeroizing<String>) -> Result<()> {
    // Thin wrapper: all validation logic lives in the `token` module.
    token::validate_proxy_auth(header_bytes, session_token)
}
254
255fn parse_status_code(line: &str) -> Result<u16> {
257 let parts: Vec<&str> = line.split_whitespace().collect();
258 if parts.len() < 2 {
259 return Err(ProxyError::HttpParse(format!(
260 "malformed HTTP response: {}",
261 line
262 )));
263 }
264 parts[1]
265 .parse::<u16>()
266 .map_err(|_| ProxyError::HttpParse(format!("invalid status code in response: {}", line)))
267}
268
269async fn send_response(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
271 let response = format!("HTTP/1.1 {} {}\r\n\r\n", status, reason);
272 stream.write_all(response.as_bytes()).await?;
273 stream.flush().await?;
274 Ok(())
275}
276
/// Unit tests for the request/response line parsers and for `BypassMatcher`.
#[cfg(test)]
// `unwrap` is acceptable here: a failed unwrap is simply a failed test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_connect_target() {
        let (host, port) = parse_connect_target("CONNECT api.openai.com:443 HTTP/1.1").unwrap();
        assert_eq!(host, "api.openai.com");
        assert_eq!(port, 443);
    }

    #[test]
    fn test_parse_connect_target_default_port() {
        // An authority without a port falls back to 443.
        let (host, port) = parse_connect_target("CONNECT example.com HTTP/1.1").unwrap();
        assert_eq!(host, "example.com");
        assert_eq!(port, 443);
    }

    #[test]
    fn test_parse_connect_target_invalid_port() {
        assert!(parse_connect_target("CONNECT example.com:notaport HTTP/1.1").is_err());
        // Does not fit into a u16.
        assert!(parse_connect_target("CONNECT example.com:70000 HTTP/1.1").is_err());
    }

    #[test]
    fn test_parse_connect_target_malformed() {
        assert!(parse_connect_target("CONNECT").is_err());
        assert!(parse_connect_target("").is_err());
    }

    #[test]
    fn test_parse_status_code_200() {
        assert_eq!(
            parse_status_code("HTTP/1.1 200 Connection Established\r\n").unwrap(),
            200
        );
    }

    #[test]
    fn test_parse_status_code_403() {
        assert_eq!(
            parse_status_code("HTTP/1.1 403 Forbidden\r\n").unwrap(),
            403
        );
    }

    #[test]
    fn test_parse_status_code_malformed() {
        assert!(parse_status_code("garbage").is_err());
    }

    #[test]
    fn test_parse_status_code_non_numeric() {
        assert!(parse_status_code("HTTP/1.1 abc Whatever\r\n").is_err());
        assert!(parse_status_code("").is_err());
    }

    #[test]
    fn test_bypass_matcher_exact() {
        let matcher = BypassMatcher::new(&["internal.corp".to_string()]);
        assert!(matcher.matches("internal.corp"));
        assert!(!matcher.matches("other.corp"));
    }

    #[test]
    fn test_bypass_matcher_case_insensitive() {
        let matcher = BypassMatcher::new(&["Internal.Corp".to_string()]);
        assert!(matcher.matches("internal.corp"));
        assert!(matcher.matches("INTERNAL.CORP"));
    }

    #[test]
    fn test_bypass_matcher_wildcard() {
        let matcher = BypassMatcher::new(&["*.internal.corp".to_string()]);
        assert!(matcher.matches("app.internal.corp"));
        assert!(matcher.matches("deep.sub.internal.corp"));
        // The wildcard covers subdomains only, not the domain itself.
        assert!(!matcher.matches("internal.corp"));
    }

    #[test]
    fn test_bypass_matcher_wildcard_case_insensitive() {
        let matcher = BypassMatcher::new(&["*.Internal.Corp".to_string()]);
        assert!(matcher.matches("APP.INTERNAL.CORP"));
    }

    #[test]
    fn test_bypass_matcher_no_match() {
        let matcher =
            BypassMatcher::new(&["internal.corp".to_string(), "*.private.net".to_string()]);
        assert!(!matcher.matches("api.openai.com"));
        assert!(!matcher.matches("evil.com"));
    }

    #[test]
    fn test_bypass_matcher_empty() {
        let matcher = BypassMatcher::new(&[]);
        assert!(matcher.is_empty());
        assert!(!matcher.matches("anything.com"));
    }

    #[test]
    fn test_bypass_matcher_mixed() {
        let matcher =
            BypassMatcher::new(&["exact.host.com".to_string(), "*.wildcard.com".to_string()]);
        assert!(matcher.matches("exact.host.com"));
        assert!(matcher.matches("sub.wildcard.com"));
        assert!(!matcher.matches("wildcard.com"));
        assert!(!matcher.matches("other.com"));
    }

    #[test]
    fn test_bypass_matcher_bare_star_is_not_wildcard() {
        // A lone `*` is stored as a literal name, so it must not match everything.
        let matcher = BypassMatcher::new(&["*".to_string()]);
        assert!(!matcher.matches("anything.com"));
        assert!(!matcher.matches("internal.corp"));
    }

    #[test]
    fn test_bypass_matcher_star_without_dot_is_literal() {
        // Only the `*.` prefix is a wildcard; `*corp` is compared literally.
        let matcher = BypassMatcher::new(&["*corp".to_string()]);
        assert!(!matcher.matches("internal.corp"));
        assert!(!matcher.matches("subcorp"));
        assert!(matcher.matches("*corp"));
    }

    #[test]
    fn test_bypass_matcher_star_dot_only_is_ignored() {
        // `*.` with nothing after the dot is dropped entirely.
        let matcher = BypassMatcher::new(&["*.".to_string()]);
        assert!(matcher.is_empty());
        assert!(!matcher.matches("anything.com"));
    }
}