1use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::route::RouteStore;
23use crate::token;
24use std::net::SocketAddr;
25use std::time::Duration;
26use tokio::io::{AsyncReadExt, AsyncWriteExt};
27use tokio::net::TcpStream;
28use tokio_rustls::TlsConnector;
29use tracing::{debug, warn};
30use zeroize::Zeroizing;
31
/// Largest request body (16 MiB) the proxy will buffer before forwarding;
/// larger declared `Content-Length` values are answered with 413.
const MAX_REQUEST_BODY: usize = 16 << 20;

/// Time limit applied to each upstream TCP connect attempt.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
37
38pub struct ReverseProxyCtx<'a> {
44 pub route_store: &'a RouteStore,
46 pub credential_store: &'a CredentialStore,
48 pub session_token: &'a Zeroizing<String>,
50 pub filter: &'a ProxyFilter,
52 pub tls_connector: &'a TlsConnector,
54 pub audit_log: Option<&'a audit::SharedAuditLog>,
56}
57
58pub async fn handle_reverse_proxy(
72 first_line: &str,
73 stream: &mut TcpStream,
74 remaining_header: &[u8],
75 ctx: &ReverseProxyCtx<'_>,
76 buffered_body: &[u8],
77) -> Result<()> {
78 let (method, path, version) = parse_request_line(first_line)?;
80 debug!("Reverse proxy: {} {}", method, path);
81
82 let (service, upstream_path) = parse_service_prefix(&path)?;
84
85 let route = ctx
87 .route_store
88 .get(&service)
89 .ok_or_else(|| ProxyError::UnknownService {
90 prefix: service.clone(),
91 })?;
92
93 if !route.endpoint_rules.is_allowed(&method, &upstream_path) {
97 let reason = format!(
98 "endpoint denied: {} {} on service '{}'",
99 method, upstream_path, service
100 );
101 warn!("{}", reason);
102 audit::log_denied(
103 ctx.audit_log,
104 audit::ProxyMode::Reverse,
105 &service,
106 0,
107 &reason,
108 );
109 send_error(stream, 403, "Forbidden").await?;
110 return Ok(());
111 }
112
113 let cred = ctx.credential_store.get(&service);
115
116 if let Some(cred) = cred {
120 if let Err(e) = validate_phantom_token_for_mode(
123 &cred.inject_mode,
124 remaining_header,
125 &upstream_path,
126 &cred.header_name,
127 cred.path_pattern.as_deref(),
128 cred.query_param_name.as_deref(),
129 ctx.session_token,
130 ) {
131 audit::log_denied(
132 ctx.audit_log,
133 audit::ProxyMode::Reverse,
134 &service,
135 0,
136 &e.to_string(),
137 );
138 send_error(stream, 401, "Unauthorized").await?;
139 return Ok(());
140 }
141 } else {
142 if let Err(e) = token::validate_proxy_auth(remaining_header, ctx.session_token) {
146 audit::log_denied(
147 ctx.audit_log,
148 audit::ProxyMode::Reverse,
149 &service,
150 0,
151 &e.to_string(),
152 );
153 send_error(stream, 407, "Proxy Authentication Required").await?;
154 return Ok(());
155 }
156 }
157
158 let transformed_path = if let Some(cred) = cred {
161 transform_path_for_mode(
162 &cred.inject_mode,
163 &upstream_path,
164 cred.path_pattern.as_deref(),
165 cred.path_replacement.as_deref(),
166 cred.query_param_name.as_deref(),
167 &cred.raw_credential,
168 )?
169 } else {
170 upstream_path.clone()
171 };
172
173 let upstream_url = format!(
176 "{}{}",
177 route.upstream.trim_end_matches('/'),
178 transformed_path
179 );
180 debug!("Forwarding to upstream: {} {}", method, upstream_url);
181
182 let (upstream_host, upstream_port, upstream_path_full) = parse_upstream_url(&upstream_url)?;
183
184 let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
186 if !check.result.is_allowed() {
187 let reason = check.result.reason();
188 warn!("Upstream host denied by filter: {}", reason);
189 send_error(stream, 403, "Forbidden").await?;
190 audit::log_denied(
191 ctx.audit_log,
192 audit::ProxyMode::Reverse,
193 &service,
194 0,
195 &reason,
196 );
197 return Ok(());
198 }
199
200 let strip_header = cred.map(|c| c.header_name.as_str()).unwrap_or("");
207 let filtered_headers = filter_headers(remaining_header, strip_header);
208 let content_length = extract_content_length(remaining_header);
209
210 let body = if let Some(len) = content_length {
214 if len > MAX_REQUEST_BODY {
215 send_error(stream, 413, "Payload Too Large").await?;
216 return Ok(());
217 }
218 let mut buf = Vec::with_capacity(len);
219 let pre = buffered_body.len().min(len);
220 buf.extend_from_slice(&buffered_body[..pre]);
221 let remaining = len - pre;
222 if remaining > 0 {
223 let mut rest = vec![0u8; remaining];
224 stream.read_exact(&mut rest).await?;
225 buf.extend_from_slice(&rest);
226 }
227 buf
228 } else {
229 Vec::new()
230 };
231
232 let connector = route.tls_connector.as_ref().unwrap_or(ctx.tls_connector);
236 let upstream_result = connect_upstream_tls(
237 &upstream_host,
238 upstream_port,
239 &check.resolved_addrs,
240 connector,
241 )
242 .await;
243 let mut tls_stream = match upstream_result {
244 Ok(s) => s,
245 Err(e) => {
246 warn!("Upstream connection failed: {}", e);
247 send_error(stream, 502, "Bad Gateway").await?;
248 audit::log_denied(
249 ctx.audit_log,
250 audit::ProxyMode::Reverse,
251 &service,
252 0,
253 &e.to_string(),
254 );
255 return Ok(());
256 }
257 };
258
259 let mut request = Zeroizing::new(format!(
263 "{} {} {}\r\nHost: {}\r\n",
264 method, upstream_path_full, version, upstream_host
265 ));
266
267 if let Some(cred) = cred {
269 inject_credential_for_mode(cred, &mut request);
270 }
271
272 for (name, value) in &filtered_headers {
276 request.push_str(&format!("{}: {}\r\n", name, value));
277 }
278
279 if !body.is_empty() {
281 request.push_str(&format!("Content-Length: {}\r\n", body.len()));
282 }
283 request.push_str("\r\n");
284
285 tls_stream.write_all(request.as_bytes()).await?;
286 if !body.is_empty() {
287 tls_stream.write_all(&body).await?;
288 }
289 tls_stream.flush().await?;
290
291 let mut response_buf = [0u8; 8192];
294 let mut status_code: u16 = 502;
295 let mut first_chunk = true;
296
297 loop {
298 let n = match tls_stream.read(&mut response_buf).await {
299 Ok(0) => break,
300 Ok(n) => n,
301 Err(e) => {
302 debug!("Upstream read error: {}", e);
303 break;
304 }
305 };
306
307 if first_chunk {
311 status_code = parse_response_status(&response_buf[..n]);
312 first_chunk = false;
313 }
314
315 stream.write_all(&response_buf[..n]).await?;
316 stream.flush().await?;
317 }
318
319 audit::log_reverse_proxy(
320 ctx.audit_log,
321 &service,
322 &method,
323 &upstream_path,
324 status_code,
325 );
326 Ok(())
327}
328
329fn parse_request_line(line: &str) -> Result<(String, String, String)> {
331 let parts: Vec<&str> = line.split_whitespace().collect();
332 if parts.len() < 3 {
333 return Err(ProxyError::HttpParse(format!(
334 "malformed request line: {}",
335 line
336 )));
337 }
338 Ok((
339 parts[0].to_string(),
340 parts[1].to_string(),
341 parts[2].to_string(),
342 ))
343}
344
345fn parse_service_prefix(path: &str) -> Result<(String, String)> {
350 let trimmed = path.strip_prefix('/').unwrap_or(path);
351 if let Some((prefix, rest)) = trimmed.split_once('/') {
352 Ok((prefix.to_string(), format!("/{}", rest)))
353 } else {
354 Ok((trimmed.to_string(), "/".to_string()))
356 }
357}
358
359fn validate_phantom_token(
366 header_bytes: &[u8],
367 header_name: &str,
368 session_token: &Zeroizing<String>,
369) -> Result<()> {
370 let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
371 let header_name_lower = header_name.to_lowercase();
372
373 for line in header_str.lines() {
374 let lower = line.to_lowercase();
375 if lower.starts_with(&format!("{}:", header_name_lower)) {
376 let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
377
378 let value_lower = value.to_lowercase();
381 let token_value = if value_lower.starts_with("bearer ") {
382 value[7..].trim()
384 } else {
385 value
386 };
387
388 if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
389 return Ok(());
390 }
391 warn!("Invalid phantom token in {} header", header_name);
392 return Err(ProxyError::InvalidToken);
393 }
394 }
395
396 warn!(
397 "Missing {} header for phantom token validation",
398 header_name
399 );
400 Err(ProxyError::InvalidToken)
401}
402
/// Parses the client's raw header block into `(name, value)` pairs to forward upstream.
///
/// These headers are dropped because the proxy regenerates them or must not
/// forward them:
/// - `Host`, `Content-Length` and `Proxy-Authorization`;
/// - when `cred_header` is non-empty, the header that carries the credential.
///
/// Header names are parsed from each line, trimmed and compared
/// case-insensitively, so `Content-Length : 5` is dropped exactly like
/// `Content-Length: 5`. Lines without a `:` and headers with an empty name are
/// skipped; names and values are trimmed. Non-UTF-8 input yields no headers.
fn filter_headers(header_bytes: &[u8], cred_header: &str) -> Vec<(String, String)> {
    let header_str = std::str::from_utf8(header_bytes).unwrap_or("");

    let is_stripped = |name: &str| {
        name.eq_ignore_ascii_case("host")
            || name.eq_ignore_ascii_case("content-length")
            || name.eq_ignore_ascii_case("proxy-authorization")
            || (!cred_header.is_empty() && name.eq_ignore_ascii_case(cred_header))
    };

    header_str
        .lines()
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim(), value.trim()))
        .filter(|&(name, _)| !name.is_empty() && !is_stripped(name))
        .map(|(name, value)| (name.to_string(), value.to_string()))
        .collect()
}
440
/// Returns the request's declared `Content-Length`, if any.
///
/// The first header whose trimmed name equals `Content-Length`
/// (case-insensitively) is used. `None` when the header is absent, its value
/// does not parse as a `usize`, or the header block is not UTF-8. Matching on
/// the parsed name means `Content-Length : 5` is recognised as well.
fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    std::str::from_utf8(header_bytes)
        .ok()?
        .lines()
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
        .and_then(|(_, value)| value.trim().parse().ok())
}
452
453fn parse_upstream_url(url_str: &str) -> Result<(String, u16, String)> {
455 let parsed = url::Url::parse(url_str)
456 .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
457
458 let scheme = parsed.scheme();
459 if scheme != "https" && scheme != "http" {
460 return Err(ProxyError::HttpParse(format!(
461 "unsupported URL scheme: {}",
462 url_str
463 )));
464 }
465
466 let host = parsed
467 .host_str()
468 .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
469 .to_string();
470
471 let default_port = if scheme == "https" { 443 } else { 80 };
472 let port = parsed.port().unwrap_or(default_port);
473
474 let path = parsed.path().to_string();
475 let path = if path.is_empty() {
476 "/".to_string()
477 } else {
478 path
479 };
480
481 let path_with_query = if let Some(query) = parsed.query() {
483 format!("{}?{}", path, query)
484 } else {
485 path
486 };
487
488 Ok((host, port, path_with_query))
489}
490
491async fn connect_upstream_tls(
500 host: &str,
501 port: u16,
502 resolved_addrs: &[SocketAddr],
503 connector: &TlsConnector,
504) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
505 let tcp = if resolved_addrs.is_empty() {
506 let addr = format!("{}:{}", host, port);
508 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
509 Ok(Ok(s)) => s,
510 Ok(Err(e)) => {
511 return Err(ProxyError::UpstreamConnect {
512 host: host.to_string(),
513 reason: e.to_string(),
514 });
515 }
516 Err(_) => {
517 return Err(ProxyError::UpstreamConnect {
518 host: host.to_string(),
519 reason: "connection timed out".to_string(),
520 });
521 }
522 }
523 } else {
524 connect_to_resolved(resolved_addrs, host).await?
525 };
526
527 let server_name = rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|_| {
528 ProxyError::UpstreamConnect {
529 host: host.to_string(),
530 reason: "invalid server name for TLS".to_string(),
531 }
532 })?;
533
534 let tls_stream =
535 connector
536 .connect(server_name, tcp)
537 .await
538 .map_err(|e| ProxyError::UpstreamConnect {
539 host: host.to_string(),
540 reason: format!("TLS handshake failed: {}", e),
541 })?;
542
543 Ok(tls_stream)
544}
545
546async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
548 let mut last_err = None;
549 for addr in addrs {
550 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
551 Ok(Ok(stream)) => return Ok(stream),
552 Ok(Err(e)) => {
553 debug!("Connect to {} failed: {}", addr, e);
554 last_err = Some(e.to_string());
555 }
556 Err(_) => {
557 debug!("Connect to {} timed out", addr);
558 last_err = Some("connection timed out".to_string());
559 }
560 }
561 }
562 Err(ProxyError::UpstreamConnect {
563 host: host.to_string(),
564 reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
565 })
566}
567
/// Extracts the status code from the start of an HTTP response.
///
/// Only the first line, truncated to 64 bytes, is examined. The code is
/// returned when that line starts with an `HTTP/` token followed by a
/// three-character numeric token. Anything else yields 502.
fn parse_response_status(data: &[u8]) -> u16 {
    const FALLBACK: u16 = 502;

    let line_len = data
        .iter()
        .position(|&b| b == b'\r' || b == b'\n')
        .unwrap_or(data.len())
        .min(64);

    std::str::from_utf8(&data[..line_len])
        .ok()
        .and_then(|status_line| {
            let mut fields = status_line.split_whitespace();
            let version = fields.next()?;
            if !version.starts_with("HTTP/") {
                return None;
            }
            let code = fields.next()?;
            if code.len() != 3 {
                return None;
            }
            Some(code.parse().unwrap_or(FALLBACK))
        })
        .unwrap_or(FALLBACK)
}
596
597async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
599 let body = format!("{{\"error\":\"{}\"}}", reason);
600 let response = format!(
601 "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
602 status,
603 reason,
604 body.len(),
605 body
606 );
607 stream.write_all(response.as_bytes()).await?;
608 stream.flush().await?;
609 Ok(())
610}
611
612fn validate_phantom_token_for_mode(
623 mode: &InjectMode,
624 header_bytes: &[u8],
625 path: &str,
626 header_name: &str,
627 path_pattern: Option<&str>,
628 query_param_name: Option<&str>,
629 session_token: &Zeroizing<String>,
630) -> Result<()> {
631 match mode {
632 InjectMode::Header | InjectMode::BasicAuth => {
633 validate_phantom_token(header_bytes, header_name, session_token)
635 }
636 InjectMode::UrlPath => {
637 let pattern = path_pattern.ok_or_else(|| {
639 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
640 })?;
641 validate_phantom_token_in_path(path, pattern, session_token)
642 }
643 InjectMode::QueryParam => {
644 let param_name = query_param_name.ok_or_else(|| {
646 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
647 })?;
648 validate_phantom_token_in_query(path, param_name, session_token)
649 }
650 }
651}
652
653fn validate_phantom_token_in_path(
658 path: &str,
659 pattern: &str,
660 session_token: &Zeroizing<String>,
661) -> Result<()> {
662 let parts: Vec<&str> = pattern.split("{}").collect();
664 if parts.len() != 2 {
665 return Err(ProxyError::HttpParse(format!(
666 "invalid path_pattern '{}': must contain exactly one {{}}",
667 pattern
668 )));
669 }
670 let (prefix, suffix) = (parts[0], parts[1]);
671
672 if let Some(start) = path.find(prefix) {
674 let after_prefix = start + prefix.len();
675
676 let end_offset = if suffix.is_empty() {
678 path[after_prefix..]
679 .find(['/', '?'])
680 .unwrap_or(path[after_prefix..].len())
681 } else {
682 match path[after_prefix..].find(suffix) {
683 Some(offset) => offset,
684 None => {
685 warn!("Missing phantom token in URL path (pattern: {})", pattern);
686 return Err(ProxyError::InvalidToken);
687 }
688 }
689 };
690
691 let token = &path[after_prefix..after_prefix + end_offset];
692 if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
693 return Ok(());
694 }
695 warn!("Invalid phantom token in URL path");
696 return Err(ProxyError::InvalidToken);
697 }
698
699 warn!("Missing phantom token in URL path (pattern: {})", pattern);
700 Err(ProxyError::InvalidToken)
701}
702
703fn validate_phantom_token_in_query(
705 path: &str,
706 param_name: &str,
707 session_token: &Zeroizing<String>,
708) -> Result<()> {
709 if let Some(query_start) = path.find('?') {
711 let query = &path[query_start + 1..];
712 for pair in query.split('&') {
713 if let Some((name, value)) = pair.split_once('=') {
714 if name == param_name {
715 let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
717 if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
718 return Ok(());
719 }
720 warn!("Invalid phantom token in query parameter '{}'", param_name);
721 return Err(ProxyError::InvalidToken);
722 }
723 }
724 }
725 }
726
727 warn!("Missing phantom token in query parameter '{}'", param_name);
728 Err(ProxyError::InvalidToken)
729}
730
731fn transform_path_for_mode(
737 mode: &InjectMode,
738 path: &str,
739 path_pattern: Option<&str>,
740 path_replacement: Option<&str>,
741 query_param_name: Option<&str>,
742 credential: &Zeroizing<String>,
743) -> Result<String> {
744 match mode {
745 InjectMode::Header | InjectMode::BasicAuth => {
746 Ok(path.to_string())
748 }
749 InjectMode::UrlPath => {
750 let pattern = path_pattern.ok_or_else(|| {
751 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
752 })?;
753 let replacement = path_replacement.unwrap_or(pattern);
754 transform_url_path(path, pattern, replacement, credential)
755 }
756 InjectMode::QueryParam => {
757 let param_name = query_param_name.ok_or_else(|| {
758 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
759 })?;
760 transform_query_param(path, param_name, credential)
761 }
762 }
763}
764
765fn transform_url_path(
769 path: &str,
770 pattern: &str,
771 replacement: &str,
772 credential: &Zeroizing<String>,
773) -> Result<String> {
774 let parts: Vec<&str> = pattern.split("{}").collect();
776 if parts.len() != 2 {
777 return Err(ProxyError::HttpParse(format!(
778 "invalid path_pattern '{}': must contain exactly one {{}}",
779 pattern
780 )));
781 }
782 let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
783
784 let repl_parts: Vec<&str> = replacement.split("{}").collect();
786 if repl_parts.len() != 2 {
787 return Err(ProxyError::HttpParse(format!(
788 "invalid path_replacement '{}': must contain exactly one {{}}",
789 replacement
790 )));
791 }
792 let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
793
794 if let Some(start) = path.find(pattern_prefix) {
796 let after_prefix = start + pattern_prefix.len();
797
798 let end_offset = if pattern_suffix.is_empty() {
800 path[after_prefix..]
802 .find(['/', '?'])
803 .unwrap_or(path[after_prefix..].len())
804 } else {
805 match path[after_prefix..].find(pattern_suffix) {
807 Some(offset) => offset,
808 None => {
809 return Err(ProxyError::HttpParse(format!(
810 "path '{}' does not match pattern '{}'",
811 path, pattern
812 )));
813 }
814 }
815 };
816
817 let before = &path[..start];
818 let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
819 return Ok(format!(
820 "{}{}{}{}{}",
821 before,
822 repl_prefix,
823 credential.as_str(),
824 repl_suffix,
825 after
826 ));
827 }
828
829 Err(ProxyError::HttpParse(format!(
830 "path '{}' does not match pattern '{}'",
831 path, pattern
832 )))
833}
834
835fn transform_query_param(
837 path: &str,
838 param_name: &str,
839 credential: &Zeroizing<String>,
840) -> Result<String> {
841 let encoded_value = urlencoding::encode(credential.as_str());
842
843 if let Some(query_start) = path.find('?') {
844 let base_path = &path[..query_start];
845 let query = &path[query_start + 1..];
846
847 let mut found = false;
849 let new_query: Vec<String> = query
850 .split('&')
851 .map(|pair| {
852 if let Some((name, _)) = pair.split_once('=') {
853 if name == param_name {
854 found = true;
855 return format!("{}={}", param_name, encoded_value);
856 }
857 }
858 pair.to_string()
859 })
860 .collect();
861
862 if found {
863 Ok(format!("{}?{}", base_path, new_query.join("&")))
864 } else {
865 Ok(format!(
867 "{}?{}&{}={}",
868 base_path, query, param_name, encoded_value
869 ))
870 }
871 } else {
872 Ok(format!("{}?{}={}", path, param_name, encoded_value))
874 }
875}
876
877fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
882 match cred.inject_mode {
883 InjectMode::Header | InjectMode::BasicAuth => {
884 request.push_str(&format!(
886 "{}: {}\r\n",
887 cred.header_name,
888 cred.header_value.as_str()
889 ));
890 }
891 InjectMode::UrlPath | InjectMode::QueryParam => {
892 }
895 }
896}
897
/// Unit tests for the pure parsing, validation and path-transformation helpers
/// of the reverse proxy (no network I/O is exercised here).
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // --- parse_request_line ---

    #[test]
    fn test_parse_request_line() {
        let (method, path, version) = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
        assert_eq!(method, "POST");
        assert_eq!(path, "/openai/v1/chat");
        assert_eq!(version, "HTTP/1.1");
    }

    #[test]
    fn test_parse_request_line_malformed() {
        assert!(parse_request_line("GET").is_err());
    }

    // --- parse_service_prefix ---

    #[test]
    fn test_parse_service_prefix() {
        let (service, path) = parse_service_prefix("/openai/v1/chat/completions").unwrap();
        assert_eq!(service, "openai");
        assert_eq!(path, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_service_prefix_no_subpath() {
        let (service, path) = parse_service_prefix("/anthropic").unwrap();
        assert_eq!(service, "anthropic");
        assert_eq!(path, "/");
    }

    // --- validate_phantom_token (header / bearer modes) ---

    #[test]
    fn test_validate_phantom_token_bearer_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_bearer_invalid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_x_api_key_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-api-key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_x_goog_api_key_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-goog-api-key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_missing() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Content-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_case_insensitive_header() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"AUTHORIZATION: Bearer secret123\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
    }

    // --- filter_headers ---

    #[test]
    fn test_filter_headers_removes_host_auth() {
        let header = b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n";
        let filtered = filter_headers(header, "Authorization");
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered[0].0, "Content-Type");
        assert_eq!(filtered[1].0, "Accept");
    }

    #[test]
    fn test_filter_headers_removes_x_api_key() {
        let header = b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n";
        let filtered = filter_headers(header, "x-api-key");
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "Content-Type");
    }

    #[test]
    fn test_filter_headers_removes_custom_header() {
        let header = b"PRIVATE-TOKEN: phantom123\r\nContent-Type: application/json\r\n\r\n";
        let filtered = filter_headers(header, "PRIVATE-TOKEN");
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "Content-Type");
    }

    // --- extract_content_length ---

    #[test]
    fn test_extract_content_length() {
        let header = b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n";
        assert_eq!(extract_content_length(header), Some(42));
    }

    #[test]
    fn test_extract_content_length_missing() {
        let header = b"Content-Type: application/json\r\n\r\n";
        assert_eq!(extract_content_length(header), None);
    }

    // --- parse_upstream_url ---

    #[test]
    fn test_parse_upstream_url_https() {
        let (host, port, path) =
            parse_upstream_url("https://api.openai.com/v1/chat/completions").unwrap();
        assert_eq!(host, "api.openai.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_upstream_url_http_with_port() {
        let (host, port, path) = parse_upstream_url("http://localhost:8080/api").unwrap();
        assert_eq!(host, "localhost");
        assert_eq!(port, 8080);
        assert_eq!(path, "/api");
    }

    #[test]
    fn test_parse_upstream_url_no_path() {
        let (host, port, path) = parse_upstream_url("https://api.anthropic.com").unwrap();
        assert_eq!(host, "api.anthropic.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/");
    }

    #[test]
    fn test_parse_upstream_url_invalid_scheme() {
        assert!(parse_upstream_url("ftp://example.com").is_err());
    }

    // --- parse_response_status (502 is the fallback for unparseable input) ---

    #[test]
    fn test_parse_response_status_200() {
        let data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n";
        assert_eq!(parse_response_status(data), 200);
    }

    #[test]
    fn test_parse_response_status_404() {
        let data = b"HTTP/1.1 404 Not Found\r\n\r\n";
        assert_eq!(parse_response_status(data), 404);
    }

    #[test]
    fn test_parse_response_status_garbage() {
        let data = b"not an http response";
        assert_eq!(parse_response_status(data), 502);
    }

    #[test]
    fn test_parse_response_status_empty() {
        assert_eq!(parse_response_status(b""), 502);
    }

    #[test]
    fn test_parse_response_status_partial() {
        let data = b"HTTP/1.1 ";
        assert_eq!(parse_response_status(data), 502);
    }

    // --- url_path mode: validation and transformation ---

    #[test]
    fn test_validate_phantom_token_in_path_valid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/bot/session123/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_path_invalid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/bot/wrong_token/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_path_missing() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
    }

    #[test]
    fn test_transform_url_path_basic() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/bot/phantom_token/getMe";
        let pattern = "/bot/{}/";
        let replacement = "/bot/{}/";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/bot/real_token/getMe");
    }

    #[test]
    fn test_transform_url_path_different_replacement() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/api/v1/phantom_token/chat";
        let pattern = "/api/v1/{}/";
        let replacement = "/v2/bot/{}/";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/v2/bot/real_token/chat");
    }

    #[test]
    fn test_transform_url_path_no_trailing_slash() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/bot/phantom_token";
        let pattern = "/bot/{}";
        let replacement = "/bot/{}";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/bot/real_token");
    }

    // --- query_param mode: validation and transformation ---

    #[test]
    fn test_validate_phantom_token_in_query_valid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?api_key=session123&other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_query_invalid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?api_key=wrong_token";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_missing_param() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_no_query_string() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_url_encoded() {
        // The supplied value is percent-decoded before comparison.
        let token = Zeroizing::new("token with spaces".to_string());
        let path = "/api/data?api_key=token%20with%20spaces";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
    }

    #[test]
    fn test_transform_query_param_add_to_no_query() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_add_to_existing_query() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?other=value";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?other=value&api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_replace_existing() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?api_key=phantom&other=value";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=real_key&other=value");
    }

    #[test]
    fn test_transform_query_param_url_encodes_special_chars() {
        // The injected credential is percent-encoded.
        let credential = Zeroizing::new("key with spaces".to_string());
        let path = "/api/data";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=key%20with%20spaces");
    }
}