1use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::token;
23use std::net::SocketAddr;
24use std::time::Duration;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsConnector;
28use tracing::{debug, warn};
29use zeroize::Zeroizing;
30
/// Largest request body, in bytes (16 MiB), that the reverse proxy will buffer;
/// requests declaring a larger `Content-Length` are answered with 413.
const MAX_REQUEST_BODY: usize = 16 << 20;

/// Upper bound applied to each upstream TCP connection attempt.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
36
37pub struct ReverseProxyCtx<'a> {
43 pub credential_store: &'a CredentialStore,
45 pub session_token: &'a Zeroizing<String>,
47 pub filter: &'a ProxyFilter,
49 pub tls_connector: &'a TlsConnector,
51 pub audit_log: Option<&'a audit::SharedAuditLog>,
53}
54
55pub async fn handle_reverse_proxy(
69 first_line: &str,
70 stream: &mut TcpStream,
71 remaining_header: &[u8],
72 ctx: &ReverseProxyCtx<'_>,
73 buffered_body: &[u8],
74) -> Result<()> {
75 let (method, path, version) = parse_request_line(first_line)?;
77 debug!("Reverse proxy: {} {}", method, path);
78
79 let (service, upstream_path) = parse_service_prefix(&path)?;
81
82 let cred = ctx
84 .credential_store
85 .get(&service)
86 .ok_or_else(|| ProxyError::UnknownService {
87 prefix: service.clone(),
88 })?;
89
90 if let Err(e) = validate_phantom_token_for_mode(
95 &cred.inject_mode,
96 remaining_header,
97 &upstream_path,
98 &cred.header_name,
99 cred.path_pattern.as_deref(),
100 cred.query_param_name.as_deref(),
101 ctx.session_token,
102 ) {
103 audit::log_denied(
104 ctx.audit_log,
105 audit::ProxyMode::Reverse,
106 &service,
107 0,
108 &e.to_string(),
109 );
110 send_error(stream, 401, "Unauthorized").await?;
111 return Ok(());
112 }
113
114 let transformed_path = transform_path_for_mode(
116 &cred.inject_mode,
117 &upstream_path,
118 cred.path_pattern.as_deref(),
119 cred.path_replacement.as_deref(),
120 cred.query_param_name.as_deref(),
121 &cred.raw_credential,
122 )?;
123
124 let upstream_url = format!(
126 "{}{}",
127 cred.upstream.trim_end_matches('/'),
128 transformed_path
129 );
130 debug!("Forwarding to upstream: {} {}", method, upstream_url);
131
132 let (upstream_host, upstream_port, upstream_path_full) = parse_upstream_url(&upstream_url)?;
133
134 let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
136 if !check.result.is_allowed() {
137 let reason = check.result.reason();
138 warn!("Upstream host denied by filter: {}", reason);
139 send_error(stream, 403, "Forbidden").await?;
140 audit::log_denied(
141 ctx.audit_log,
142 audit::ProxyMode::Reverse,
143 &service,
144 0,
145 &reason,
146 );
147 return Ok(());
148 }
149
150 let filtered_headers = filter_headers(remaining_header);
152 let content_length = extract_content_length(remaining_header);
153
154 let body = if let Some(len) = content_length {
158 if len > MAX_REQUEST_BODY {
159 send_error(stream, 413, "Payload Too Large").await?;
160 return Ok(());
161 }
162 let mut buf = Vec::with_capacity(len);
163 let pre = buffered_body.len().min(len);
164 buf.extend_from_slice(&buffered_body[..pre]);
165 let remaining = len - pre;
166 if remaining > 0 {
167 let mut rest = vec![0u8; remaining];
168 stream.read_exact(&mut rest).await?;
169 buf.extend_from_slice(&rest);
170 }
171 buf
172 } else {
173 Vec::new()
174 };
175
176 let upstream_result = connect_upstream_tls(
178 &upstream_host,
179 upstream_port,
180 &check.resolved_addrs,
181 ctx.tls_connector,
182 )
183 .await;
184 let mut tls_stream = match upstream_result {
185 Ok(s) => s,
186 Err(e) => {
187 warn!("Upstream connection failed: {}", e);
188 send_error(stream, 502, "Bad Gateway").await?;
189 audit::log_denied(
190 ctx.audit_log,
191 audit::ProxyMode::Reverse,
192 &service,
193 0,
194 &e.to_string(),
195 );
196 return Ok(());
197 }
198 };
199
200 let mut request = Zeroizing::new(format!(
204 "{} {} {}\r\nHost: {}\r\n",
205 method, upstream_path_full, version, upstream_host
206 ));
207
208 inject_credential_for_mode(cred, &mut request);
210
211 let auth_header_lower = cred.header_name.to_lowercase();
213 for (name, value) in &filtered_headers {
214 if matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
217 && name.to_lowercase() == auth_header_lower
218 {
219 continue;
220 }
221 request.push_str(&format!("{}: {}\r\n", name, value));
222 }
223
224 if !body.is_empty() {
226 request.push_str(&format!("Content-Length: {}\r\n", body.len()));
227 }
228 request.push_str("\r\n");
229
230 tls_stream.write_all(request.as_bytes()).await?;
231 if !body.is_empty() {
232 tls_stream.write_all(&body).await?;
233 }
234 tls_stream.flush().await?;
235
236 let mut response_buf = [0u8; 8192];
239 let mut status_code: u16 = 502;
240 let mut first_chunk = true;
241
242 loop {
243 let n = match tls_stream.read(&mut response_buf).await {
244 Ok(0) => break,
245 Ok(n) => n,
246 Err(e) => {
247 debug!("Upstream read error: {}", e);
248 break;
249 }
250 };
251
252 if first_chunk {
256 status_code = parse_response_status(&response_buf[..n]);
257 first_chunk = false;
258 }
259
260 stream.write_all(&response_buf[..n]).await?;
261 stream.flush().await?;
262 }
263
264 audit::log_reverse_proxy(
265 ctx.audit_log,
266 &service,
267 &method,
268 &upstream_path,
269 status_code,
270 );
271 Ok(())
272}
273
274fn parse_request_line(line: &str) -> Result<(String, String, String)> {
276 let parts: Vec<&str> = line.split_whitespace().collect();
277 if parts.len() < 3 {
278 return Err(ProxyError::HttpParse(format!(
279 "malformed request line: {}",
280 line
281 )));
282 }
283 Ok((
284 parts[0].to_string(),
285 parts[1].to_string(),
286 parts[2].to_string(),
287 ))
288}
289
290fn parse_service_prefix(path: &str) -> Result<(String, String)> {
295 let trimmed = path.strip_prefix('/').unwrap_or(path);
296 if let Some((prefix, rest)) = trimmed.split_once('/') {
297 Ok((prefix.to_string(), format!("/{}", rest)))
298 } else {
299 Ok((trimmed.to_string(), "/".to_string()))
301 }
302}
303
304fn validate_phantom_token(
311 header_bytes: &[u8],
312 header_name: &str,
313 session_token: &Zeroizing<String>,
314) -> Result<()> {
315 let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
316 let header_name_lower = header_name.to_lowercase();
317
318 for line in header_str.lines() {
319 let lower = line.to_lowercase();
320 if lower.starts_with(&format!("{}:", header_name_lower)) {
321 let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
322
323 let value_lower = value.to_lowercase();
326 let token_value = if value_lower.starts_with("bearer ") {
327 value[7..].trim()
329 } else {
330 value
331 };
332
333 if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
334 return Ok(());
335 }
336 warn!("Invalid phantom token in {} header", header_name);
337 return Err(ProxyError::InvalidToken);
338 }
339 }
340
341 warn!(
342 "Missing {} header for phantom token validation",
343 header_name
344 );
345 Err(ProxyError::InvalidToken)
346}
347
/// Parses the raw header block into `(name, value)` pairs to forward upstream.
///
/// `Host` and `Content-Length` (which the proxy writes itself) and the
/// credential-bearing `Authorization`, `x-api-key` and `x-goog-api-key` headers
/// are dropped (names matched case-insensitively). Blank lines and lines without
/// a `:` are dropped too. Names and values are whitespace-trimmed. A non-UTF-8
/// block yields no headers.
fn filter_headers(header_bytes: &[u8]) -> Vec<(String, String)> {
    const DROPPED_PREFIXES: [&str; 5] = [
        "host:",
        "content-length:",
        "authorization:",
        "x-api-key:",
        "x-goog-api-key:",
    ];

    let text = std::str::from_utf8(header_bytes).unwrap_or("");
    text.lines()
        .filter(|line| !line.trim().is_empty())
        .filter(|line| {
            let lowered = line.to_lowercase();
            !DROPPED_PREFIXES
                .iter()
                .any(|&prefix| lowered.starts_with(prefix))
        })
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim().to_string(), value.trim().to_string()))
        .collect()
}
375
/// Returns the value of the first `Content-Length` header (case-insensitive name).
///
/// Yields `None` when the block is not UTF-8, the header is missing, or the first
/// such header's value does not parse as `usize`; later duplicates are not consulted.
fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    let text = std::str::from_utf8(header_bytes).ok()?;
    let line = text
        .lines()
        .find(|candidate| candidate.to_lowercase().starts_with("content-length:"))?;
    let (_, raw_value) = line.split_once(':')?;
    raw_value.trim().parse::<usize>().ok()
}
387
388fn parse_upstream_url(url_str: &str) -> Result<(String, u16, String)> {
390 let parsed = url::Url::parse(url_str)
391 .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
392
393 let scheme = parsed.scheme();
394 if scheme != "https" && scheme != "http" {
395 return Err(ProxyError::HttpParse(format!(
396 "unsupported URL scheme: {}",
397 url_str
398 )));
399 }
400
401 let host = parsed
402 .host_str()
403 .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
404 .to_string();
405
406 let default_port = if scheme == "https" { 443 } else { 80 };
407 let port = parsed.port().unwrap_or(default_port);
408
409 let path = parsed.path().to_string();
410 let path = if path.is_empty() {
411 "/".to_string()
412 } else {
413 path
414 };
415
416 let path_with_query = if let Some(query) = parsed.query() {
418 format!("{}?{}", path, query)
419 } else {
420 path
421 };
422
423 Ok((host, port, path_with_query))
424}
425
426async fn connect_upstream_tls(
435 host: &str,
436 port: u16,
437 resolved_addrs: &[SocketAddr],
438 connector: &TlsConnector,
439) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
440 let tcp = if resolved_addrs.is_empty() {
441 let addr = format!("{}:{}", host, port);
443 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
444 Ok(Ok(s)) => s,
445 Ok(Err(e)) => {
446 return Err(ProxyError::UpstreamConnect {
447 host: host.to_string(),
448 reason: e.to_string(),
449 });
450 }
451 Err(_) => {
452 return Err(ProxyError::UpstreamConnect {
453 host: host.to_string(),
454 reason: "connection timed out".to_string(),
455 });
456 }
457 }
458 } else {
459 connect_to_resolved(resolved_addrs, host).await?
460 };
461
462 let server_name = rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|_| {
463 ProxyError::UpstreamConnect {
464 host: host.to_string(),
465 reason: "invalid server name for TLS".to_string(),
466 }
467 })?;
468
469 let tls_stream =
470 connector
471 .connect(server_name, tcp)
472 .await
473 .map_err(|e| ProxyError::UpstreamConnect {
474 host: host.to_string(),
475 reason: format!("TLS handshake failed: {}", e),
476 })?;
477
478 Ok(tls_stream)
479}
480
481async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
483 let mut last_err = None;
484 for addr in addrs {
485 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
486 Ok(Ok(stream)) => return Ok(stream),
487 Ok(Err(e)) => {
488 debug!("Connect to {} failed: {}", addr, e);
489 last_err = Some(e.to_string());
490 }
491 Err(_) => {
492 debug!("Connect to {} timed out", addr);
493 last_err = Some("connection timed out".to_string());
494 }
495 }
496 }
497 Err(ProxyError::UpstreamConnect {
498 host: host.to_string(),
499 reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
500 })
501}
502
/// Extracts the status code from the beginning of an upstream HTTP response.
///
/// Only the first line, capped at 64 bytes, is examined. It is parsed as raw
/// bytes, not UTF-8, so a non-UTF-8 reason phrase (`obs-text`, permitted by
/// RFC 9112) or a multi-byte character split by the 64-byte cap cannot hide a
/// valid status.
///
/// Returns 502 unless the line looks like `HTTP/<version> <three ASCII digits> ...`.
fn parse_response_status(data: &[u8]) -> u16 {
    const FALLBACK: u16 = 502;

    let line_end = data
        .iter()
        .position(|&b| b == b'\r' || b == b'\n')
        .unwrap_or(data.len());
    let status_line = &data[..line_end.min(64)];

    let mut fields = status_line
        .split(|b| b.is_ascii_whitespace())
        .filter(|field| !field.is_empty());

    match (fields.next(), fields.next()) {
        (Some(version), Some(code))
            if version.starts_with(b"HTTP/")
                && code.len() == 3
                && code.iter().all(u8::is_ascii_digit) =>
        {
            code.iter()
                .fold(0u16, |acc, &digit| acc * 10 + u16::from(digit - b'0'))
        }
        _ => FALLBACK,
    }
}
531
532async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
534 let body = format!("{{\"error\":\"{}\"}}", reason);
535 let response = format!(
536 "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
537 status,
538 reason,
539 body.len(),
540 body
541 );
542 stream.write_all(response.as_bytes()).await?;
543 stream.flush().await?;
544 Ok(())
545}
546
547fn validate_phantom_token_for_mode(
558 mode: &InjectMode,
559 header_bytes: &[u8],
560 path: &str,
561 header_name: &str,
562 path_pattern: Option<&str>,
563 query_param_name: Option<&str>,
564 session_token: &Zeroizing<String>,
565) -> Result<()> {
566 match mode {
567 InjectMode::Header | InjectMode::BasicAuth => {
568 validate_phantom_token(header_bytes, header_name, session_token)
570 }
571 InjectMode::UrlPath => {
572 let pattern = path_pattern.ok_or_else(|| {
574 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
575 })?;
576 validate_phantom_token_in_path(path, pattern, session_token)
577 }
578 InjectMode::QueryParam => {
579 let param_name = query_param_name.ok_or_else(|| {
581 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
582 })?;
583 validate_phantom_token_in_query(path, param_name, session_token)
584 }
585 }
586}
587
588fn validate_phantom_token_in_path(
593 path: &str,
594 pattern: &str,
595 session_token: &Zeroizing<String>,
596) -> Result<()> {
597 let parts: Vec<&str> = pattern.split("{}").collect();
599 if parts.len() != 2 {
600 return Err(ProxyError::HttpParse(format!(
601 "invalid path_pattern '{}': must contain exactly one {{}}",
602 pattern
603 )));
604 }
605 let (prefix, suffix) = (parts[0], parts[1]);
606
607 if let Some(start) = path.find(prefix) {
609 let after_prefix = start + prefix.len();
610
611 let end_offset = if suffix.is_empty() {
613 path[after_prefix..]
614 .find(['/', '?'])
615 .unwrap_or(path[after_prefix..].len())
616 } else {
617 match path[after_prefix..].find(suffix) {
618 Some(offset) => offset,
619 None => {
620 warn!("Missing phantom token in URL path (pattern: {})", pattern);
621 return Err(ProxyError::InvalidToken);
622 }
623 }
624 };
625
626 let token = &path[after_prefix..after_prefix + end_offset];
627 if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
628 return Ok(());
629 }
630 warn!("Invalid phantom token in URL path");
631 return Err(ProxyError::InvalidToken);
632 }
633
634 warn!("Missing phantom token in URL path (pattern: {})", pattern);
635 Err(ProxyError::InvalidToken)
636}
637
638fn validate_phantom_token_in_query(
640 path: &str,
641 param_name: &str,
642 session_token: &Zeroizing<String>,
643) -> Result<()> {
644 if let Some(query_start) = path.find('?') {
646 let query = &path[query_start + 1..];
647 for pair in query.split('&') {
648 if let Some((name, value)) = pair.split_once('=') {
649 if name == param_name {
650 let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
652 if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
653 return Ok(());
654 }
655 warn!("Invalid phantom token in query parameter '{}'", param_name);
656 return Err(ProxyError::InvalidToken);
657 }
658 }
659 }
660 }
661
662 warn!("Missing phantom token in query parameter '{}'", param_name);
663 Err(ProxyError::InvalidToken)
664}
665
666fn transform_path_for_mode(
672 mode: &InjectMode,
673 path: &str,
674 path_pattern: Option<&str>,
675 path_replacement: Option<&str>,
676 query_param_name: Option<&str>,
677 credential: &Zeroizing<String>,
678) -> Result<String> {
679 match mode {
680 InjectMode::Header | InjectMode::BasicAuth => {
681 Ok(path.to_string())
683 }
684 InjectMode::UrlPath => {
685 let pattern = path_pattern.ok_or_else(|| {
686 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
687 })?;
688 let replacement = path_replacement.unwrap_or(pattern);
689 transform_url_path(path, pattern, replacement, credential)
690 }
691 InjectMode::QueryParam => {
692 let param_name = query_param_name.ok_or_else(|| {
693 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
694 })?;
695 transform_query_param(path, param_name, credential)
696 }
697 }
698}
699
700fn transform_url_path(
704 path: &str,
705 pattern: &str,
706 replacement: &str,
707 credential: &Zeroizing<String>,
708) -> Result<String> {
709 let parts: Vec<&str> = pattern.split("{}").collect();
711 if parts.len() != 2 {
712 return Err(ProxyError::HttpParse(format!(
713 "invalid path_pattern '{}': must contain exactly one {{}}",
714 pattern
715 )));
716 }
717 let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
718
719 let repl_parts: Vec<&str> = replacement.split("{}").collect();
721 if repl_parts.len() != 2 {
722 return Err(ProxyError::HttpParse(format!(
723 "invalid path_replacement '{}': must contain exactly one {{}}",
724 replacement
725 )));
726 }
727 let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
728
729 if let Some(start) = path.find(pattern_prefix) {
731 let after_prefix = start + pattern_prefix.len();
732
733 let end_offset = if pattern_suffix.is_empty() {
735 path[after_prefix..]
737 .find(['/', '?'])
738 .unwrap_or(path[after_prefix..].len())
739 } else {
740 match path[after_prefix..].find(pattern_suffix) {
742 Some(offset) => offset,
743 None => {
744 return Err(ProxyError::HttpParse(format!(
745 "path '{}' does not match pattern '{}'",
746 path, pattern
747 )));
748 }
749 }
750 };
751
752 let before = &path[..start];
753 let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
754 return Ok(format!(
755 "{}{}{}{}{}",
756 before,
757 repl_prefix,
758 credential.as_str(),
759 repl_suffix,
760 after
761 ));
762 }
763
764 Err(ProxyError::HttpParse(format!(
765 "path '{}' does not match pattern '{}'",
766 path, pattern
767 )))
768}
769
770fn transform_query_param(
772 path: &str,
773 param_name: &str,
774 credential: &Zeroizing<String>,
775) -> Result<String> {
776 let encoded_value = urlencoding::encode(credential.as_str());
777
778 if let Some(query_start) = path.find('?') {
779 let base_path = &path[..query_start];
780 let query = &path[query_start + 1..];
781
782 let mut found = false;
784 let new_query: Vec<String> = query
785 .split('&')
786 .map(|pair| {
787 if let Some((name, _)) = pair.split_once('=') {
788 if name == param_name {
789 found = true;
790 return format!("{}={}", param_name, encoded_value);
791 }
792 }
793 pair.to_string()
794 })
795 .collect();
796
797 if found {
798 Ok(format!("{}?{}", base_path, new_query.join("&")))
799 } else {
800 Ok(format!(
802 "{}?{}&{}={}",
803 base_path, query, param_name, encoded_value
804 ))
805 }
806 } else {
807 Ok(format!("{}?{}={}", path, param_name, encoded_value))
809 }
810}
811
812fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
817 match cred.inject_mode {
818 InjectMode::Header | InjectMode::BasicAuth => {
819 request.push_str(&format!(
821 "{}: {}\r\n",
822 cred.header_name,
823 cred.header_value.as_str()
824 ));
825 }
826 InjectMode::UrlPath | InjectMode::QueryParam => {
827 }
830 }
831}
832
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Wraps a literal in the zeroizing container used for tokens and credentials.
    fn secret(value: &str) -> Zeroizing<String> {
        Zeroizing::new(value.to_string())
    }

    /// Names of the headers kept by `filter_headers`, in their original order.
    fn kept_header_names(raw: &[u8]) -> Vec<String> {
        filter_headers(raw)
            .into_iter()
            .map(|(name, _)| name)
            .collect()
    }

    const BOT_PATTERN: &str = "/bot/{}/";

    // ---- request line and service prefix ----

    #[test]
    fn test_parse_request_line() {
        let parsed = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
        assert_eq!(
            parsed,
            (
                "POST".to_string(),
                "/openai/v1/chat".to_string(),
                "HTTP/1.1".to_string()
            )
        );
    }

    #[test]
    fn test_parse_request_line_malformed() {
        assert!(parse_request_line("GET").is_err());
    }

    #[test]
    fn test_parse_service_prefix() {
        let split = parse_service_prefix("/openai/v1/chat/completions").unwrap();
        assert_eq!(
            split,
            ("openai".to_string(), "/v1/chat/completions".to_string())
        );
    }

    #[test]
    fn test_parse_service_prefix_no_subpath() {
        let split = parse_service_prefix("/anthropic").unwrap();
        assert_eq!(split, ("anthropic".to_string(), "/".to_string()));
    }

    // ---- header phantom tokens ----

    #[test]
    fn test_validate_phantom_token_bearer_valid() {
        let headers = b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(headers, "Authorization", &secret("secret123")).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_bearer_invalid() {
        let headers = b"Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_phantom_token(headers, "Authorization", &secret("secret123")).is_err());
    }

    #[test]
    fn test_validate_phantom_token_x_api_key_valid() {
        let headers = b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(headers, "x-api-key", &secret("secret123")).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_x_goog_api_key_valid() {
        let headers = b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(headers, "x-goog-api-key", &secret("secret123")).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_missing() {
        let headers = b"Content-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(headers, "Authorization", &secret("secret123")).is_err());
    }

    #[test]
    fn test_validate_phantom_token_case_insensitive_header() {
        let headers = b"AUTHORIZATION: Bearer secret123\r\n\r\n";
        assert!(validate_phantom_token(headers, "Authorization", &secret("secret123")).is_ok());
    }

    // ---- header filtering and framing ----

    #[test]
    fn test_filter_headers_removes_host_auth() {
        let raw = b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n";
        assert_eq!(kept_header_names(raw), ["Content-Type", "Accept"]);
    }

    #[test]
    fn test_filter_headers_removes_x_api_key() {
        let raw = b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n";
        assert_eq!(kept_header_names(raw), ["Content-Type"]);
    }

    #[test]
    fn test_filter_headers_removes_x_goog_api_key() {
        let raw = b"x-goog-api-key: gemini-key\r\nContent-Type: application/json\r\n\r\n";
        assert_eq!(kept_header_names(raw), ["Content-Type"]);
    }

    #[test]
    fn test_extract_content_length() {
        let raw = b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n";
        assert_eq!(extract_content_length(raw), Some(42));
    }

    #[test]
    fn test_extract_content_length_missing() {
        let raw = b"Content-Type: application/json\r\n\r\n";
        assert_eq!(extract_content_length(raw), None);
    }

    // ---- upstream URL parsing ----

    #[test]
    fn test_parse_upstream_url_https() {
        let parts = parse_upstream_url("https://api.openai.com/v1/chat/completions").unwrap();
        assert_eq!(
            parts,
            (
                "api.openai.com".to_string(),
                443,
                "/v1/chat/completions".to_string()
            )
        );
    }

    #[test]
    fn test_parse_upstream_url_http_with_port() {
        let parts = parse_upstream_url("http://localhost:8080/api").unwrap();
        assert_eq!(parts, ("localhost".to_string(), 8080, "/api".to_string()));
    }

    #[test]
    fn test_parse_upstream_url_no_path() {
        let parts = parse_upstream_url("https://api.anthropic.com").unwrap();
        assert_eq!(parts, ("api.anthropic.com".to_string(), 443, "/".to_string()));
    }

    #[test]
    fn test_parse_upstream_url_invalid_scheme() {
        assert!(parse_upstream_url("ftp://example.com").is_err());
    }

    // ---- response status ----

    #[test]
    fn test_parse_response_status_200() {
        let raw = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n";
        assert_eq!(parse_response_status(raw), 200);
    }

    #[test]
    fn test_parse_response_status_404() {
        assert_eq!(parse_response_status(b"HTTP/1.1 404 Not Found\r\n\r\n"), 404);
    }

    #[test]
    fn test_parse_response_status_garbage() {
        assert_eq!(parse_response_status(b"not an http response"), 502);
    }

    #[test]
    fn test_parse_response_status_empty() {
        assert_eq!(parse_response_status(b""), 502);
    }

    #[test]
    fn test_parse_response_status_partial() {
        assert_eq!(parse_response_status(b"HTTP/1.1 "), 502);
    }

    // ---- URL-path phantom tokens and rewriting ----

    #[test]
    fn test_validate_phantom_token_in_path_valid() {
        let token = secret("session123");
        assert!(validate_phantom_token_in_path("/bot/session123/getMe", BOT_PATTERN, &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_path_invalid() {
        let token = secret("session123");
        assert!(validate_phantom_token_in_path("/bot/wrong_token/getMe", BOT_PATTERN, &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_path_missing() {
        let token = secret("session123");
        assert!(validate_phantom_token_in_path("/api/getMe", BOT_PATTERN, &token).is_err());
    }

    #[test]
    fn test_transform_url_path_basic() {
        let rewritten = transform_url_path(
            "/bot/phantom_token/getMe",
            BOT_PATTERN,
            BOT_PATTERN,
            &secret("real_token"),
        )
        .unwrap();
        assert_eq!(rewritten, "/bot/real_token/getMe");
    }

    #[test]
    fn test_transform_url_path_different_replacement() {
        let rewritten = transform_url_path(
            "/api/v1/phantom_token/chat",
            "/api/v1/{}/",
            "/v2/bot/{}/",
            &secret("real_token"),
        )
        .unwrap();
        assert_eq!(rewritten, "/v2/bot/real_token/chat");
    }

    #[test]
    fn test_transform_url_path_no_trailing_slash() {
        let rewritten = transform_url_path(
            "/bot/phantom_token",
            "/bot/{}",
            "/bot/{}",
            &secret("real_token"),
        )
        .unwrap();
        assert_eq!(rewritten, "/bot/real_token");
    }

    // ---- query-parameter phantom tokens and rewriting ----

    #[test]
    fn test_validate_phantom_token_in_query_valid() {
        let path = "/api/data?api_key=session123&other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &secret("session123")).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_query_invalid() {
        let path = "/api/data?api_key=wrong_token";
        assert!(validate_phantom_token_in_query(path, "api_key", &secret("session123")).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_missing_param() {
        let path = "/api/data?other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &secret("session123")).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_no_query_string() {
        let path = "/api/data";
        assert!(validate_phantom_token_in_query(path, "api_key", &secret("session123")).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_url_encoded() {
        let path = "/api/data?api_key=token%20with%20spaces";
        assert!(
            validate_phantom_token_in_query(path, "api_key", &secret("token with spaces")).is_ok()
        );
    }

    #[test]
    fn test_transform_query_param_add_to_no_query() {
        let rewritten = transform_query_param("/api/data", "api_key", &secret("real_key")).unwrap();
        assert_eq!(rewritten, "/api/data?api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_add_to_existing_query() {
        let rewritten =
            transform_query_param("/api/data?other=value", "api_key", &secret("real_key")).unwrap();
        assert_eq!(rewritten, "/api/data?other=value&api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_replace_existing() {
        let rewritten = transform_query_param(
            "/api/data?api_key=phantom&other=value",
            "api_key",
            &secret("real_key"),
        )
        .unwrap();
        assert_eq!(rewritten, "/api/data?api_key=real_key&other=value");
    }

    #[test]
    fn test_transform_query_param_url_encodes_special_chars() {
        let rewritten =
            transform_query_param("/api/data", "api_key", &secret("key with spaces")).unwrap();
        assert_eq!(rewritten, "/api/data?api_key=key%20with%20spaces");
    }
}