1use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::token;
23use std::net::SocketAddr;
24use std::time::Duration;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsConnector;
28use tracing::{debug, warn};
29use zeroize::Zeroizing;
30
/// Largest request body (16 MiB) the reverse proxy will buffer; requests whose
/// Content-Length exceeds it are answered with 413.
const MAX_REQUEST_BODY: usize = 16 << 20;

/// Time allowed for each TCP connect attempt to an upstream address.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
36
37pub struct ReverseProxyCtx<'a> {
43 pub credential_store: &'a CredentialStore,
45 pub session_token: &'a Zeroizing<String>,
47 pub filter: &'a ProxyFilter,
49 pub tls_connector: &'a TlsConnector,
51 pub audit_log: Option<&'a audit::SharedAuditLog>,
53}
54
55pub async fn handle_reverse_proxy(
69 first_line: &str,
70 stream: &mut TcpStream,
71 remaining_header: &[u8],
72 ctx: &ReverseProxyCtx<'_>,
73 buffered_body: &[u8],
74) -> Result<()> {
75 let (method, path, version) = parse_request_line(first_line)?;
77 debug!("Reverse proxy: {} {}", method, path);
78
79 let (service, upstream_path) = parse_service_prefix(&path)?;
81
82 let cred = ctx
84 .credential_store
85 .get(&service)
86 .ok_or_else(|| ProxyError::UnknownService {
87 prefix: service.clone(),
88 })?;
89
90 if !cred.endpoint_rules.is_allowed(&method, &upstream_path) {
93 let reason = format!(
94 "endpoint denied: {} {} on service '{}'",
95 method, upstream_path, service
96 );
97 warn!("{}", reason);
98 audit::log_denied(
99 ctx.audit_log,
100 audit::ProxyMode::Reverse,
101 &service,
102 0,
103 &reason,
104 );
105 send_error(stream, 403, "Forbidden").await?;
106 return Ok(());
107 }
108
109 if let Err(e) = validate_phantom_token_for_mode(
114 &cred.inject_mode,
115 remaining_header,
116 &upstream_path,
117 &cred.header_name,
118 cred.path_pattern.as_deref(),
119 cred.query_param_name.as_deref(),
120 ctx.session_token,
121 ) {
122 audit::log_denied(
123 ctx.audit_log,
124 audit::ProxyMode::Reverse,
125 &service,
126 0,
127 &e.to_string(),
128 );
129 send_error(stream, 401, "Unauthorized").await?;
130 return Ok(());
131 }
132
133 let transformed_path = transform_path_for_mode(
135 &cred.inject_mode,
136 &upstream_path,
137 cred.path_pattern.as_deref(),
138 cred.path_replacement.as_deref(),
139 cred.query_param_name.as_deref(),
140 &cred.raw_credential,
141 )?;
142
143 let upstream_url = format!(
145 "{}{}",
146 cred.upstream.trim_end_matches('/'),
147 transformed_path
148 );
149 debug!("Forwarding to upstream: {} {}", method, upstream_url);
150
151 let (upstream_host, upstream_port, upstream_path_full) = parse_upstream_url(&upstream_url)?;
152
153 let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
155 if !check.result.is_allowed() {
156 let reason = check.result.reason();
157 warn!("Upstream host denied by filter: {}", reason);
158 send_error(stream, 403, "Forbidden").await?;
159 audit::log_denied(
160 ctx.audit_log,
161 audit::ProxyMode::Reverse,
162 &service,
163 0,
164 &reason,
165 );
166 return Ok(());
167 }
168
169 let filtered_headers = filter_headers(remaining_header, &cred.header_name);
171 let content_length = extract_content_length(remaining_header);
172
173 let body = if let Some(len) = content_length {
177 if len > MAX_REQUEST_BODY {
178 send_error(stream, 413, "Payload Too Large").await?;
179 return Ok(());
180 }
181 let mut buf = Vec::with_capacity(len);
182 let pre = buffered_body.len().min(len);
183 buf.extend_from_slice(&buffered_body[..pre]);
184 let remaining = len - pre;
185 if remaining > 0 {
186 let mut rest = vec![0u8; remaining];
187 stream.read_exact(&mut rest).await?;
188 buf.extend_from_slice(&rest);
189 }
190 buf
191 } else {
192 Vec::new()
193 };
194
195 let upstream_result = connect_upstream_tls(
197 &upstream_host,
198 upstream_port,
199 &check.resolved_addrs,
200 ctx.tls_connector,
201 )
202 .await;
203 let mut tls_stream = match upstream_result {
204 Ok(s) => s,
205 Err(e) => {
206 warn!("Upstream connection failed: {}", e);
207 send_error(stream, 502, "Bad Gateway").await?;
208 audit::log_denied(
209 ctx.audit_log,
210 audit::ProxyMode::Reverse,
211 &service,
212 0,
213 &e.to_string(),
214 );
215 return Ok(());
216 }
217 };
218
219 let mut request = Zeroizing::new(format!(
223 "{} {} {}\r\nHost: {}\r\n",
224 method, upstream_path_full, version, upstream_host
225 ));
226
227 inject_credential_for_mode(cred, &mut request);
229
230 let auth_header_lower = cred.header_name.to_lowercase();
232 for (name, value) in &filtered_headers {
233 if matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
236 && name.to_lowercase() == auth_header_lower
237 {
238 continue;
239 }
240 request.push_str(&format!("{}: {}\r\n", name, value));
241 }
242
243 if !body.is_empty() {
245 request.push_str(&format!("Content-Length: {}\r\n", body.len()));
246 }
247 request.push_str("\r\n");
248
249 tls_stream.write_all(request.as_bytes()).await?;
250 if !body.is_empty() {
251 tls_stream.write_all(&body).await?;
252 }
253 tls_stream.flush().await?;
254
255 let mut response_buf = [0u8; 8192];
258 let mut status_code: u16 = 502;
259 let mut first_chunk = true;
260
261 loop {
262 let n = match tls_stream.read(&mut response_buf).await {
263 Ok(0) => break,
264 Ok(n) => n,
265 Err(e) => {
266 debug!("Upstream read error: {}", e);
267 break;
268 }
269 };
270
271 if first_chunk {
275 status_code = parse_response_status(&response_buf[..n]);
276 first_chunk = false;
277 }
278
279 stream.write_all(&response_buf[..n]).await?;
280 stream.flush().await?;
281 }
282
283 audit::log_reverse_proxy(
284 ctx.audit_log,
285 &service,
286 &method,
287 &upstream_path,
288 status_code,
289 );
290 Ok(())
291}
292
293fn parse_request_line(line: &str) -> Result<(String, String, String)> {
295 let parts: Vec<&str> = line.split_whitespace().collect();
296 if parts.len() < 3 {
297 return Err(ProxyError::HttpParse(format!(
298 "malformed request line: {}",
299 line
300 )));
301 }
302 Ok((
303 parts[0].to_string(),
304 parts[1].to_string(),
305 parts[2].to_string(),
306 ))
307}
308
309fn parse_service_prefix(path: &str) -> Result<(String, String)> {
314 let trimmed = path.strip_prefix('/').unwrap_or(path);
315 if let Some((prefix, rest)) = trimmed.split_once('/') {
316 Ok((prefix.to_string(), format!("/{}", rest)))
317 } else {
318 Ok((trimmed.to_string(), "/".to_string()))
320 }
321}
322
323fn validate_phantom_token(
330 header_bytes: &[u8],
331 header_name: &str,
332 session_token: &Zeroizing<String>,
333) -> Result<()> {
334 let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
335 let header_name_lower = header_name.to_lowercase();
336
337 for line in header_str.lines() {
338 let lower = line.to_lowercase();
339 if lower.starts_with(&format!("{}:", header_name_lower)) {
340 let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
341
342 let value_lower = value.to_lowercase();
345 let token_value = if value_lower.starts_with("bearer ") {
346 value[7..].trim()
348 } else {
349 value
350 };
351
352 if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
353 return Ok(());
354 }
355 warn!("Invalid phantom token in {} header", header_name);
356 return Err(ProxyError::InvalidToken);
357 }
358 }
359
360 warn!(
361 "Missing {} header for phantom token validation",
362 header_name
363 );
364 Err(ProxyError::InvalidToken)
365}
366
/// Returns the client headers that may be forwarded upstream.
///
/// The result is a list of `(name, value)` pairs, trimmed, in original order.
/// The following are dropped:
/// - `Host` and `Content-Length`, which the proxy writes itself;
/// - the credential header `cred_header`, which carries the phantom token;
/// - `Transfer-Encoding`;
/// - blank lines and lines without a colon.
///
/// Name matching is case-insensitive. Non-UTF-8 header bytes yield an empty list.
fn filter_headers(header_bytes: &[u8], cred_header: &str) -> Vec<(String, String)> {
    // Transfer-Encoding is dropped because the proxy re-frames the body with
    // Content-Length. Forwarding both would let the upstream choose a different
    // body boundary than the proxy did, or wait for chunks that never come.
    const DROPPED_PREFIXES: [&str; 3] = ["host:", "content-length:", "transfer-encoding:"];

    let header_str = std::str::from_utf8(header_bytes).unwrap_or("");
    let cred_header_prefix = format!("{}:", cred_header.to_lowercase());

    header_str
        .lines()
        .filter(|line| !line.trim().is_empty())
        .filter(|line| {
            let lower = line.to_lowercase();
            !lower.starts_with(cred_header_prefix.as_str())
                && !DROPPED_PREFIXES
                    .iter()
                    .any(|prefix| lower.starts_with(*prefix))
        })
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim().to_string(), value.trim().to_string()))
        .collect()
}
393
/// Returns the value of the first `Content-Length` header, if it parses as `usize`.
///
/// The header name is matched case-insensitively. `None` is returned when the
/// header block is not UTF-8, the header is absent, or the first such header
/// has an unparsable value.
fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    let headers = std::str::from_utf8(header_bytes).ok()?;
    let line = headers
        .lines()
        .find(|candidate| candidate.to_lowercase().starts_with("content-length:"))?;
    let (_, raw_value) = line.split_once(':')?;
    raw_value.trim().parse().ok()
}
405
406fn parse_upstream_url(url_str: &str) -> Result<(String, u16, String)> {
408 let parsed = url::Url::parse(url_str)
409 .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
410
411 let scheme = parsed.scheme();
412 if scheme != "https" && scheme != "http" {
413 return Err(ProxyError::HttpParse(format!(
414 "unsupported URL scheme: {}",
415 url_str
416 )));
417 }
418
419 let host = parsed
420 .host_str()
421 .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
422 .to_string();
423
424 let default_port = if scheme == "https" { 443 } else { 80 };
425 let port = parsed.port().unwrap_or(default_port);
426
427 let path = parsed.path().to_string();
428 let path = if path.is_empty() {
429 "/".to_string()
430 } else {
431 path
432 };
433
434 let path_with_query = if let Some(query) = parsed.query() {
436 format!("{}?{}", path, query)
437 } else {
438 path
439 };
440
441 Ok((host, port, path_with_query))
442}
443
444async fn connect_upstream_tls(
453 host: &str,
454 port: u16,
455 resolved_addrs: &[SocketAddr],
456 connector: &TlsConnector,
457) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
458 let tcp = if resolved_addrs.is_empty() {
459 let addr = format!("{}:{}", host, port);
461 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
462 Ok(Ok(s)) => s,
463 Ok(Err(e)) => {
464 return Err(ProxyError::UpstreamConnect {
465 host: host.to_string(),
466 reason: e.to_string(),
467 });
468 }
469 Err(_) => {
470 return Err(ProxyError::UpstreamConnect {
471 host: host.to_string(),
472 reason: "connection timed out".to_string(),
473 });
474 }
475 }
476 } else {
477 connect_to_resolved(resolved_addrs, host).await?
478 };
479
480 let server_name = rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|_| {
481 ProxyError::UpstreamConnect {
482 host: host.to_string(),
483 reason: "invalid server name for TLS".to_string(),
484 }
485 })?;
486
487 let tls_stream =
488 connector
489 .connect(server_name, tcp)
490 .await
491 .map_err(|e| ProxyError::UpstreamConnect {
492 host: host.to_string(),
493 reason: format!("TLS handshake failed: {}", e),
494 })?;
495
496 Ok(tls_stream)
497}
498
499async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
501 let mut last_err = None;
502 for addr in addrs {
503 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
504 Ok(Ok(stream)) => return Ok(stream),
505 Ok(Err(e)) => {
506 debug!("Connect to {} failed: {}", addr, e);
507 last_err = Some(e.to_string());
508 }
509 Err(_) => {
510 debug!("Connect to {} timed out", addr);
511 last_err = Some("connection timed out".to_string());
512 }
513 }
514 }
515 Err(ProxyError::UpstreamConnect {
516 host: host.to_string(),
517 reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
518 })
519}
520
/// Extracts the status code from the start of an HTTP response.
///
/// Only the first line (capped at 64 bytes) is examined. It must begin with an
/// `HTTP/` version token followed by exactly three ASCII digits. Anything else
/// yields 502.
///
/// Parsing works on raw bytes, so a reason phrase containing non-UTF-8 bytes,
/// or a multi-byte character cut by the 64-byte cap, no longer hides a valid
/// status code.
fn parse_response_status(data: &[u8]) -> u16 {
    const FALLBACK: u16 = 502;

    let line_end = data
        .iter()
        .position(|&b| b == b'\r' || b == b'\n')
        .unwrap_or(data.len());
    let first_line = &data[..line_end.min(64)];

    let mut parts = first_line
        .split(|b| b.is_ascii_whitespace())
        .filter(|part| !part.is_empty());
    let (Some(version), Some(code)) = (parts.next(), parts.next()) else {
        return FALLBACK;
    };

    // Reject anything but three ASCII digits (e.g. "+12" would parse as a u16).
    if !version.starts_with(b"HTTP/") || code.len() != 3 || !code.iter().all(u8::is_ascii_digit) {
        return FALLBACK;
    }
    code.iter()
        .fold(0u16, |acc, &digit| acc * 10 + u16::from(digit - b'0'))
}
549
550async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
552 let body = format!("{{\"error\":\"{}\"}}", reason);
553 let response = format!(
554 "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
555 status,
556 reason,
557 body.len(),
558 body
559 );
560 stream.write_all(response.as_bytes()).await?;
561 stream.flush().await?;
562 Ok(())
563}
564
565fn validate_phantom_token_for_mode(
576 mode: &InjectMode,
577 header_bytes: &[u8],
578 path: &str,
579 header_name: &str,
580 path_pattern: Option<&str>,
581 query_param_name: Option<&str>,
582 session_token: &Zeroizing<String>,
583) -> Result<()> {
584 match mode {
585 InjectMode::Header | InjectMode::BasicAuth => {
586 validate_phantom_token(header_bytes, header_name, session_token)
588 }
589 InjectMode::UrlPath => {
590 let pattern = path_pattern.ok_or_else(|| {
592 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
593 })?;
594 validate_phantom_token_in_path(path, pattern, session_token)
595 }
596 InjectMode::QueryParam => {
597 let param_name = query_param_name.ok_or_else(|| {
599 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
600 })?;
601 validate_phantom_token_in_query(path, param_name, session_token)
602 }
603 }
604}
605
606fn validate_phantom_token_in_path(
611 path: &str,
612 pattern: &str,
613 session_token: &Zeroizing<String>,
614) -> Result<()> {
615 let parts: Vec<&str> = pattern.split("{}").collect();
617 if parts.len() != 2 {
618 return Err(ProxyError::HttpParse(format!(
619 "invalid path_pattern '{}': must contain exactly one {{}}",
620 pattern
621 )));
622 }
623 let (prefix, suffix) = (parts[0], parts[1]);
624
625 if let Some(start) = path.find(prefix) {
627 let after_prefix = start + prefix.len();
628
629 let end_offset = if suffix.is_empty() {
631 path[after_prefix..]
632 .find(['/', '?'])
633 .unwrap_or(path[after_prefix..].len())
634 } else {
635 match path[after_prefix..].find(suffix) {
636 Some(offset) => offset,
637 None => {
638 warn!("Missing phantom token in URL path (pattern: {})", pattern);
639 return Err(ProxyError::InvalidToken);
640 }
641 }
642 };
643
644 let token = &path[after_prefix..after_prefix + end_offset];
645 if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
646 return Ok(());
647 }
648 warn!("Invalid phantom token in URL path");
649 return Err(ProxyError::InvalidToken);
650 }
651
652 warn!("Missing phantom token in URL path (pattern: {})", pattern);
653 Err(ProxyError::InvalidToken)
654}
655
656fn validate_phantom_token_in_query(
658 path: &str,
659 param_name: &str,
660 session_token: &Zeroizing<String>,
661) -> Result<()> {
662 if let Some(query_start) = path.find('?') {
664 let query = &path[query_start + 1..];
665 for pair in query.split('&') {
666 if let Some((name, value)) = pair.split_once('=') {
667 if name == param_name {
668 let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
670 if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
671 return Ok(());
672 }
673 warn!("Invalid phantom token in query parameter '{}'", param_name);
674 return Err(ProxyError::InvalidToken);
675 }
676 }
677 }
678 }
679
680 warn!("Missing phantom token in query parameter '{}'", param_name);
681 Err(ProxyError::InvalidToken)
682}
683
684fn transform_path_for_mode(
690 mode: &InjectMode,
691 path: &str,
692 path_pattern: Option<&str>,
693 path_replacement: Option<&str>,
694 query_param_name: Option<&str>,
695 credential: &Zeroizing<String>,
696) -> Result<String> {
697 match mode {
698 InjectMode::Header | InjectMode::BasicAuth => {
699 Ok(path.to_string())
701 }
702 InjectMode::UrlPath => {
703 let pattern = path_pattern.ok_or_else(|| {
704 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
705 })?;
706 let replacement = path_replacement.unwrap_or(pattern);
707 transform_url_path(path, pattern, replacement, credential)
708 }
709 InjectMode::QueryParam => {
710 let param_name = query_param_name.ok_or_else(|| {
711 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
712 })?;
713 transform_query_param(path, param_name, credential)
714 }
715 }
716}
717
718fn transform_url_path(
722 path: &str,
723 pattern: &str,
724 replacement: &str,
725 credential: &Zeroizing<String>,
726) -> Result<String> {
727 let parts: Vec<&str> = pattern.split("{}").collect();
729 if parts.len() != 2 {
730 return Err(ProxyError::HttpParse(format!(
731 "invalid path_pattern '{}': must contain exactly one {{}}",
732 pattern
733 )));
734 }
735 let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
736
737 let repl_parts: Vec<&str> = replacement.split("{}").collect();
739 if repl_parts.len() != 2 {
740 return Err(ProxyError::HttpParse(format!(
741 "invalid path_replacement '{}': must contain exactly one {{}}",
742 replacement
743 )));
744 }
745 let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
746
747 if let Some(start) = path.find(pattern_prefix) {
749 let after_prefix = start + pattern_prefix.len();
750
751 let end_offset = if pattern_suffix.is_empty() {
753 path[after_prefix..]
755 .find(['/', '?'])
756 .unwrap_or(path[after_prefix..].len())
757 } else {
758 match path[after_prefix..].find(pattern_suffix) {
760 Some(offset) => offset,
761 None => {
762 return Err(ProxyError::HttpParse(format!(
763 "path '{}' does not match pattern '{}'",
764 path, pattern
765 )));
766 }
767 }
768 };
769
770 let before = &path[..start];
771 let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
772 return Ok(format!(
773 "{}{}{}{}{}",
774 before,
775 repl_prefix,
776 credential.as_str(),
777 repl_suffix,
778 after
779 ));
780 }
781
782 Err(ProxyError::HttpParse(format!(
783 "path '{}' does not match pattern '{}'",
784 path, pattern
785 )))
786}
787
788fn transform_query_param(
790 path: &str,
791 param_name: &str,
792 credential: &Zeroizing<String>,
793) -> Result<String> {
794 let encoded_value = urlencoding::encode(credential.as_str());
795
796 if let Some(query_start) = path.find('?') {
797 let base_path = &path[..query_start];
798 let query = &path[query_start + 1..];
799
800 let mut found = false;
802 let new_query: Vec<String> = query
803 .split('&')
804 .map(|pair| {
805 if let Some((name, _)) = pair.split_once('=') {
806 if name == param_name {
807 found = true;
808 return format!("{}={}", param_name, encoded_value);
809 }
810 }
811 pair.to_string()
812 })
813 .collect();
814
815 if found {
816 Ok(format!("{}?{}", base_path, new_query.join("&")))
817 } else {
818 Ok(format!(
820 "{}?{}&{}={}",
821 base_path, query, param_name, encoded_value
822 ))
823 }
824 } else {
825 Ok(format!("{}?{}={}", path, param_name, encoded_value))
827 }
828}
829
830fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
835 match cred.inject_mode {
836 InjectMode::Header | InjectMode::BasicAuth => {
837 request.push_str(&format!(
839 "{}: {}\r\n",
840 cred.header_name,
841 cred.header_value.as_str()
842 ));
843 }
844 InjectMode::UrlPath | InjectMode::QueryParam => {
845 }
848 }
849}
850
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Wraps a literal in `Zeroizing`, the way the proxy holds tokens.
    fn secret(value: &str) -> Zeroizing<String> {
        Zeroizing::new(value.to_string())
    }

    // ---- request line and service prefix ----

    #[test]
    fn test_parse_request_line() {
        let parsed = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
        let expected = (
            "POST".to_string(),
            "/openai/v1/chat".to_string(),
            "HTTP/1.1".to_string(),
        );
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_request_line_malformed() {
        let outcome = parse_request_line("GET");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_parse_service_prefix() {
        let (service, rest) = parse_service_prefix("/openai/v1/chat/completions").unwrap();
        assert_eq!(
            (service.as_str(), rest.as_str()),
            ("openai", "/v1/chat/completions")
        );
    }

    #[test]
    fn test_parse_service_prefix_no_subpath() {
        let (service, rest) = parse_service_prefix("/anthropic").unwrap();
        assert_eq!((service.as_str(), rest.as_str()), ("anthropic", "/"));
    }

    // ---- phantom token in headers ----

    #[test]
    fn test_validate_phantom_token_bearer_valid() {
        let raw = b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(raw, "Authorization", &secret("secret123")).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_bearer_invalid() {
        let raw = b"Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_phantom_token(raw, "Authorization", &secret("secret123")).is_err());
    }

    #[test]
    fn test_validate_phantom_token_x_api_key_valid() {
        let raw = b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(raw, "x-api-key", &secret("secret123")).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_x_goog_api_key_valid() {
        let raw = b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(raw, "x-goog-api-key", &secret("secret123")).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_missing() {
        let raw = b"Content-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(raw, "Authorization", &secret("secret123")).is_err());
    }

    #[test]
    fn test_validate_phantom_token_case_insensitive_header() {
        let raw = b"AUTHORIZATION: Bearer secret123\r\n\r\n";
        assert!(validate_phantom_token(raw, "Authorization", &secret("secret123")).is_ok());
    }

    // ---- header filtering and Content-Length ----

    #[test]
    fn test_filter_headers_removes_host_auth() {
        let raw = b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n";
        let names: Vec<String> = filter_headers(raw, "Authorization")
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        assert_eq!(names, ["Content-Type", "Accept"]);
    }

    #[test]
    fn test_filter_headers_removes_x_api_key() {
        let raw = b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n";
        let names: Vec<String> = filter_headers(raw, "x-api-key")
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        assert_eq!(names, ["Content-Type"]);
    }

    #[test]
    fn test_filter_headers_removes_custom_header() {
        let raw = b"PRIVATE-TOKEN: phantom123\r\nContent-Type: application/json\r\n\r\n";
        let names: Vec<String> = filter_headers(raw, "PRIVATE-TOKEN")
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        assert_eq!(names, ["Content-Type"]);
    }

    #[test]
    fn test_extract_content_length() {
        let raw = b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n";
        assert_eq!(extract_content_length(raw), Some(42));
    }

    #[test]
    fn test_extract_content_length_missing() {
        let raw = b"Content-Type: application/json\r\n\r\n";
        assert_eq!(extract_content_length(raw), None);
    }

    // ---- upstream URL parsing ----

    #[test]
    fn test_parse_upstream_url_https() {
        assert_eq!(
            parse_upstream_url("https://api.openai.com/v1/chat/completions").unwrap(),
            (
                "api.openai.com".to_string(),
                443,
                "/v1/chat/completions".to_string()
            )
        );
    }

    #[test]
    fn test_parse_upstream_url_http_with_port() {
        assert_eq!(
            parse_upstream_url("http://localhost:8080/api").unwrap(),
            ("localhost".to_string(), 8080, "/api".to_string())
        );
    }

    #[test]
    fn test_parse_upstream_url_no_path() {
        assert_eq!(
            parse_upstream_url("https://api.anthropic.com").unwrap(),
            ("api.anthropic.com".to_string(), 443, "/".to_string())
        );
    }

    #[test]
    fn test_parse_upstream_url_invalid_scheme() {
        assert!(parse_upstream_url("ftp://example.com").is_err());
    }

    // ---- response status parsing ----

    #[test]
    fn test_parse_response_status_200() {
        let raw = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n";
        assert_eq!(parse_response_status(raw), 200);
    }

    #[test]
    fn test_parse_response_status_404() {
        assert_eq!(parse_response_status(b"HTTP/1.1 404 Not Found\r\n\r\n"), 404);
    }

    #[test]
    fn test_parse_response_status_garbage() {
        assert_eq!(parse_response_status(b"not an http response"), 502);
    }

    #[test]
    fn test_parse_response_status_empty() {
        assert_eq!(parse_response_status(b""), 502);
    }

    #[test]
    fn test_parse_response_status_partial() {
        assert_eq!(parse_response_status(b"HTTP/1.1 "), 502);
    }

    // ---- phantom token in URL path ----

    #[test]
    fn test_validate_phantom_token_in_path_valid() {
        let outcome =
            validate_phantom_token_in_path("/bot/session123/getMe", "/bot/{}/", &secret("session123"));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_path_invalid() {
        let outcome =
            validate_phantom_token_in_path("/bot/wrong_token/getMe", "/bot/{}/", &secret("session123"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_path_missing() {
        let outcome = validate_phantom_token_in_path("/api/getMe", "/bot/{}/", &secret("session123"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_transform_url_path_basic() {
        let rewritten = transform_url_path(
            "/bot/phantom_token/getMe",
            "/bot/{}/",
            "/bot/{}/",
            &secret("real_token"),
        )
        .unwrap();
        assert_eq!(rewritten, "/bot/real_token/getMe");
    }

    #[test]
    fn test_transform_url_path_different_replacement() {
        let rewritten = transform_url_path(
            "/api/v1/phantom_token/chat",
            "/api/v1/{}/",
            "/v2/bot/{}/",
            &secret("real_token"),
        )
        .unwrap();
        assert_eq!(rewritten, "/v2/bot/real_token/chat");
    }

    #[test]
    fn test_transform_url_path_no_trailing_slash() {
        let rewritten =
            transform_url_path("/bot/phantom_token", "/bot/{}", "/bot/{}", &secret("real_token"))
                .unwrap();
        assert_eq!(rewritten, "/bot/real_token");
    }

    // ---- phantom token in query parameters ----

    #[test]
    fn test_validate_phantom_token_in_query_valid() {
        let outcome = validate_phantom_token_in_query(
            "/api/data?api_key=session123&other=value",
            "api_key",
            &secret("session123"),
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_query_invalid() {
        let outcome = validate_phantom_token_in_query(
            "/api/data?api_key=wrong_token",
            "api_key",
            &secret("session123"),
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_missing_param() {
        let outcome =
            validate_phantom_token_in_query("/api/data?other=value", "api_key", &secret("session123"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_no_query_string() {
        let outcome = validate_phantom_token_in_query("/api/data", "api_key", &secret("session123"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_url_encoded() {
        let outcome = validate_phantom_token_in_query(
            "/api/data?api_key=token%20with%20spaces",
            "api_key",
            &secret("token with spaces"),
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_transform_query_param_add_to_no_query() {
        let rewritten = transform_query_param("/api/data", "api_key", &secret("real_key")).unwrap();
        assert_eq!(rewritten, "/api/data?api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_add_to_existing_query() {
        let rewritten =
            transform_query_param("/api/data?other=value", "api_key", &secret("real_key")).unwrap();
        assert_eq!(rewritten, "/api/data?other=value&api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_replace_existing() {
        let rewritten = transform_query_param(
            "/api/data?api_key=phantom&other=value",
            "api_key",
            &secret("real_key"),
        )
        .unwrap();
        assert_eq!(rewritten, "/api/data?api_key=real_key&other=value");
    }

    #[test]
    fn test_transform_query_param_url_encodes_special_chars() {
        let rewritten =
            transform_query_param("/api/data", "api_key", &secret("key with spaces")).unwrap();
        assert_eq!(rewritten, "/api/data?api_key=key%20with%20spaces");
    }
}
1149}