1use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::token;
23use std::net::SocketAddr;
24use std::time::Duration;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsConnector;
28use tracing::{debug, warn};
29use zeroize::Zeroizing;
30
/// Largest request body, in bytes, that the proxy buffers and forwards (16 MiB).
/// Requests declaring a larger `Content-Length` are answered with 413.
const MAX_REQUEST_BODY: usize = 16 << 20;

/// Upper bound on each TCP connect attempt to an upstream (the TLS handshake is not covered).
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
36
/// Borrowed, shared state needed to serve reverse-proxy requests.
pub struct ReverseProxyCtx<'a> {
    /// Credentials looked up by the service prefix (first path segment) of each request.
    pub credential_store: &'a CredentialStore,
    /// Phantom session token clients must present; compared in constant time.
    pub session_token: &'a Zeroizing<String>,
    /// Host filter consulted before any upstream connection is made.
    pub filter: &'a ProxyFilter,
    /// Default TLS connector, used when a credential does not carry its own.
    pub tls_connector: &'a TlsConnector,
    /// Optional audit sink for denials and completed requests.
    pub audit_log: Option<&'a audit::SharedAuditLog>,
}
54
55pub async fn handle_reverse_proxy(
69 first_line: &str,
70 stream: &mut TcpStream,
71 remaining_header: &[u8],
72 ctx: &ReverseProxyCtx<'_>,
73 buffered_body: &[u8],
74) -> Result<()> {
75 let (method, path, version) = parse_request_line(first_line)?;
77 debug!("Reverse proxy: {} {}", method, path);
78
79 let (service, upstream_path) = parse_service_prefix(&path)?;
81
82 let cred = ctx
84 .credential_store
85 .get(&service)
86 .ok_or_else(|| ProxyError::UnknownService {
87 prefix: service.clone(),
88 })?;
89
90 if !cred.endpoint_rules.is_allowed(&method, &upstream_path) {
93 let reason = format!(
94 "endpoint denied: {} {} on service '{}'",
95 method, upstream_path, service
96 );
97 warn!("{}", reason);
98 audit::log_denied(
99 ctx.audit_log,
100 audit::ProxyMode::Reverse,
101 &service,
102 0,
103 &reason,
104 );
105 send_error(stream, 403, "Forbidden").await?;
106 return Ok(());
107 }
108
109 if let Err(e) = validate_phantom_token_for_mode(
114 &cred.inject_mode,
115 remaining_header,
116 &upstream_path,
117 &cred.header_name,
118 cred.path_pattern.as_deref(),
119 cred.query_param_name.as_deref(),
120 ctx.session_token,
121 ) {
122 audit::log_denied(
123 ctx.audit_log,
124 audit::ProxyMode::Reverse,
125 &service,
126 0,
127 &e.to_string(),
128 );
129 send_error(stream, 401, "Unauthorized").await?;
130 return Ok(());
131 }
132
133 let transformed_path = transform_path_for_mode(
135 &cred.inject_mode,
136 &upstream_path,
137 cred.path_pattern.as_deref(),
138 cred.path_replacement.as_deref(),
139 cred.query_param_name.as_deref(),
140 &cred.raw_credential,
141 )?;
142
143 let upstream_url = format!(
145 "{}{}",
146 cred.upstream.trim_end_matches('/'),
147 transformed_path
148 );
149 debug!("Forwarding to upstream: {} {}", method, upstream_url);
150
151 let (upstream_host, upstream_port, upstream_path_full) = parse_upstream_url(&upstream_url)?;
152
153 let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
155 if !check.result.is_allowed() {
156 let reason = check.result.reason();
157 warn!("Upstream host denied by filter: {}", reason);
158 send_error(stream, 403, "Forbidden").await?;
159 audit::log_denied(
160 ctx.audit_log,
161 audit::ProxyMode::Reverse,
162 &service,
163 0,
164 &reason,
165 );
166 return Ok(());
167 }
168
169 let filtered_headers = filter_headers(remaining_header, &cred.header_name);
171 let content_length = extract_content_length(remaining_header);
172
173 let body = if let Some(len) = content_length {
177 if len > MAX_REQUEST_BODY {
178 send_error(stream, 413, "Payload Too Large").await?;
179 return Ok(());
180 }
181 let mut buf = Vec::with_capacity(len);
182 let pre = buffered_body.len().min(len);
183 buf.extend_from_slice(&buffered_body[..pre]);
184 let remaining = len - pre;
185 if remaining > 0 {
186 let mut rest = vec![0u8; remaining];
187 stream.read_exact(&mut rest).await?;
188 buf.extend_from_slice(&rest);
189 }
190 buf
191 } else {
192 Vec::new()
193 };
194
195 let connector = cred.tls_connector.as_ref().unwrap_or(ctx.tls_connector);
199 let upstream_result = connect_upstream_tls(
200 &upstream_host,
201 upstream_port,
202 &check.resolved_addrs,
203 connector,
204 )
205 .await;
206 let mut tls_stream = match upstream_result {
207 Ok(s) => s,
208 Err(e) => {
209 warn!("Upstream connection failed: {}", e);
210 send_error(stream, 502, "Bad Gateway").await?;
211 audit::log_denied(
212 ctx.audit_log,
213 audit::ProxyMode::Reverse,
214 &service,
215 0,
216 &e.to_string(),
217 );
218 return Ok(());
219 }
220 };
221
222 let mut request = Zeroizing::new(format!(
226 "{} {} {}\r\nHost: {}\r\n",
227 method, upstream_path_full, version, upstream_host
228 ));
229
230 inject_credential_for_mode(cred, &mut request);
232
233 let auth_header_lower = cred.header_name.to_lowercase();
235 for (name, value) in &filtered_headers {
236 if matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
239 && name.to_lowercase() == auth_header_lower
240 {
241 continue;
242 }
243 request.push_str(&format!("{}: {}\r\n", name, value));
244 }
245
246 if !body.is_empty() {
248 request.push_str(&format!("Content-Length: {}\r\n", body.len()));
249 }
250 request.push_str("\r\n");
251
252 tls_stream.write_all(request.as_bytes()).await?;
253 if !body.is_empty() {
254 tls_stream.write_all(&body).await?;
255 }
256 tls_stream.flush().await?;
257
258 let mut response_buf = [0u8; 8192];
261 let mut status_code: u16 = 502;
262 let mut first_chunk = true;
263
264 loop {
265 let n = match tls_stream.read(&mut response_buf).await {
266 Ok(0) => break,
267 Ok(n) => n,
268 Err(e) => {
269 debug!("Upstream read error: {}", e);
270 break;
271 }
272 };
273
274 if first_chunk {
278 status_code = parse_response_status(&response_buf[..n]);
279 first_chunk = false;
280 }
281
282 stream.write_all(&response_buf[..n]).await?;
283 stream.flush().await?;
284 }
285
286 audit::log_reverse_proxy(
287 ctx.audit_log,
288 &service,
289 &method,
290 &upstream_path,
291 status_code,
292 );
293 Ok(())
294}
295
296fn parse_request_line(line: &str) -> Result<(String, String, String)> {
298 let parts: Vec<&str> = line.split_whitespace().collect();
299 if parts.len() < 3 {
300 return Err(ProxyError::HttpParse(format!(
301 "malformed request line: {}",
302 line
303 )));
304 }
305 Ok((
306 parts[0].to_string(),
307 parts[1].to_string(),
308 parts[2].to_string(),
309 ))
310}
311
312fn parse_service_prefix(path: &str) -> Result<(String, String)> {
317 let trimmed = path.strip_prefix('/').unwrap_or(path);
318 if let Some((prefix, rest)) = trimmed.split_once('/') {
319 Ok((prefix.to_string(), format!("/{}", rest)))
320 } else {
321 Ok((trimmed.to_string(), "/".to_string()))
323 }
324}
325
326fn validate_phantom_token(
333 header_bytes: &[u8],
334 header_name: &str,
335 session_token: &Zeroizing<String>,
336) -> Result<()> {
337 let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
338 let header_name_lower = header_name.to_lowercase();
339
340 for line in header_str.lines() {
341 let lower = line.to_lowercase();
342 if lower.starts_with(&format!("{}:", header_name_lower)) {
343 let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
344
345 let value_lower = value.to_lowercase();
348 let token_value = if value_lower.starts_with("bearer ") {
349 value[7..].trim()
351 } else {
352 value
353 };
354
355 if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
356 return Ok(());
357 }
358 warn!("Invalid phantom token in {} header", header_name);
359 return Err(ProxyError::InvalidToken);
360 }
361 }
362
363 warn!(
364 "Missing {} header for phantom token validation",
365 header_name
366 );
367 Err(ProxyError::InvalidToken)
368}
369
/// Collects the request headers that may be forwarded upstream, as `(name, value)` pairs.
///
/// The following lines are dropped:
/// * `Host` and `Content-Length`;
/// * the client's copy of the credential header `cred_header` (matched case-insensitively);
/// * blank lines and lines without a `:`.
///
/// Names and values are trimmed. A non-UTF-8 header block yields no headers at all.
fn filter_headers(header_bytes: &[u8], cred_header: &str) -> Vec<(String, String)> {
    let cred_prefix = format!("{}:", cred_header.to_lowercase());
    let dropped_prefixes = ["host:", "content-length:", cred_prefix.as_str()];
    let is_dropped = |line: &str| -> bool {
        if line.trim().is_empty() {
            return true;
        }
        let lowered = line.to_lowercase();
        dropped_prefixes
            .iter()
            .any(|prefix| lowered.starts_with(*prefix))
    };

    std::str::from_utf8(header_bytes)
        .unwrap_or("")
        .lines()
        .filter(|&line| !is_dropped(line))
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim().to_string(), value.trim().to_string()))
        .collect()
}
396
/// Returns the value of the first `Content-Length` header, if it parses as `usize`.
///
/// Only the first `Content-Length` line is considered. An unparseable value, a missing
/// header, or a non-UTF-8 header block all yield `None`.
fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    let headers = std::str::from_utf8(header_bytes).ok()?;
    let line = headers
        .lines()
        .find(|line| line.to_lowercase().starts_with("content-length:"))?;
    let (_, raw_value) = line.split_once(':')?;
    raw_value.trim().parse().ok()
}
408
409fn parse_upstream_url(url_str: &str) -> Result<(String, u16, String)> {
411 let parsed = url::Url::parse(url_str)
412 .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
413
414 let scheme = parsed.scheme();
415 if scheme != "https" && scheme != "http" {
416 return Err(ProxyError::HttpParse(format!(
417 "unsupported URL scheme: {}",
418 url_str
419 )));
420 }
421
422 let host = parsed
423 .host_str()
424 .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
425 .to_string();
426
427 let default_port = if scheme == "https" { 443 } else { 80 };
428 let port = parsed.port().unwrap_or(default_port);
429
430 let path = parsed.path().to_string();
431 let path = if path.is_empty() {
432 "/".to_string()
433 } else {
434 path
435 };
436
437 let path_with_query = if let Some(query) = parsed.query() {
439 format!("{}?{}", path, query)
440 } else {
441 path
442 };
443
444 Ok((host, port, path_with_query))
445}
446
447async fn connect_upstream_tls(
456 host: &str,
457 port: u16,
458 resolved_addrs: &[SocketAddr],
459 connector: &TlsConnector,
460) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
461 let tcp = if resolved_addrs.is_empty() {
462 let addr = format!("{}:{}", host, port);
464 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
465 Ok(Ok(s)) => s,
466 Ok(Err(e)) => {
467 return Err(ProxyError::UpstreamConnect {
468 host: host.to_string(),
469 reason: e.to_string(),
470 });
471 }
472 Err(_) => {
473 return Err(ProxyError::UpstreamConnect {
474 host: host.to_string(),
475 reason: "connection timed out".to_string(),
476 });
477 }
478 }
479 } else {
480 connect_to_resolved(resolved_addrs, host).await?
481 };
482
483 let server_name = rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|_| {
484 ProxyError::UpstreamConnect {
485 host: host.to_string(),
486 reason: "invalid server name for TLS".to_string(),
487 }
488 })?;
489
490 let tls_stream =
491 connector
492 .connect(server_name, tcp)
493 .await
494 .map_err(|e| ProxyError::UpstreamConnect {
495 host: host.to_string(),
496 reason: format!("TLS handshake failed: {}", e),
497 })?;
498
499 Ok(tls_stream)
500}
501
502async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
504 let mut last_err = None;
505 for addr in addrs {
506 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
507 Ok(Ok(stream)) => return Ok(stream),
508 Ok(Err(e)) => {
509 debug!("Connect to {} failed: {}", addr, e);
510 last_err = Some(e.to_string());
511 }
512 Err(_) => {
513 debug!("Connect to {} timed out", addr);
514 last_err = Some("connection timed out".to_string());
515 }
516 }
517 }
518 Err(ProxyError::UpstreamConnect {
519 host: host.to_string(),
520 reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
521 })
522}
523
/// Extracts the status code from the start of an HTTP response.
///
/// Only the first line, capped at 64 bytes, is inspected. It must look like
/// `HTTP/<ver> <3-char code> ...`. Anything else yields 502.
fn parse_response_status(data: &[u8]) -> u16 {
    const UNPARSEABLE: u16 = 502;

    let eol = data
        .iter()
        .position(|&b| matches!(b, b'\r' | b'\n'))
        .unwrap_or(data.len());
    let status_line = match std::str::from_utf8(&data[..eol.min(64)]) {
        Ok(s) => s,
        Err(_) => return UNPARSEABLE,
    };

    let mut fields = status_line.split_whitespace();
    match (fields.next(), fields.next()) {
        (Some(version), Some(code)) if version.starts_with("HTTP/") && code.len() == 3 => {
            code.parse().unwrap_or(UNPARSEABLE)
        }
        _ => UNPARSEABLE,
    }
}
552
553async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
555 let body = format!("{{\"error\":\"{}\"}}", reason);
556 let response = format!(
557 "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
558 status,
559 reason,
560 body.len(),
561 body
562 );
563 stream.write_all(response.as_bytes()).await?;
564 stream.flush().await?;
565 Ok(())
566}
567
568fn validate_phantom_token_for_mode(
579 mode: &InjectMode,
580 header_bytes: &[u8],
581 path: &str,
582 header_name: &str,
583 path_pattern: Option<&str>,
584 query_param_name: Option<&str>,
585 session_token: &Zeroizing<String>,
586) -> Result<()> {
587 match mode {
588 InjectMode::Header | InjectMode::BasicAuth => {
589 validate_phantom_token(header_bytes, header_name, session_token)
591 }
592 InjectMode::UrlPath => {
593 let pattern = path_pattern.ok_or_else(|| {
595 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
596 })?;
597 validate_phantom_token_in_path(path, pattern, session_token)
598 }
599 InjectMode::QueryParam => {
600 let param_name = query_param_name.ok_or_else(|| {
602 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
603 })?;
604 validate_phantom_token_in_query(path, param_name, session_token)
605 }
606 }
607}
608
609fn validate_phantom_token_in_path(
614 path: &str,
615 pattern: &str,
616 session_token: &Zeroizing<String>,
617) -> Result<()> {
618 let parts: Vec<&str> = pattern.split("{}").collect();
620 if parts.len() != 2 {
621 return Err(ProxyError::HttpParse(format!(
622 "invalid path_pattern '{}': must contain exactly one {{}}",
623 pattern
624 )));
625 }
626 let (prefix, suffix) = (parts[0], parts[1]);
627
628 if let Some(start) = path.find(prefix) {
630 let after_prefix = start + prefix.len();
631
632 let end_offset = if suffix.is_empty() {
634 path[after_prefix..]
635 .find(['/', '?'])
636 .unwrap_or(path[after_prefix..].len())
637 } else {
638 match path[after_prefix..].find(suffix) {
639 Some(offset) => offset,
640 None => {
641 warn!("Missing phantom token in URL path (pattern: {})", pattern);
642 return Err(ProxyError::InvalidToken);
643 }
644 }
645 };
646
647 let token = &path[after_prefix..after_prefix + end_offset];
648 if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
649 return Ok(());
650 }
651 warn!("Invalid phantom token in URL path");
652 return Err(ProxyError::InvalidToken);
653 }
654
655 warn!("Missing phantom token in URL path (pattern: {})", pattern);
656 Err(ProxyError::InvalidToken)
657}
658
659fn validate_phantom_token_in_query(
661 path: &str,
662 param_name: &str,
663 session_token: &Zeroizing<String>,
664) -> Result<()> {
665 if let Some(query_start) = path.find('?') {
667 let query = &path[query_start + 1..];
668 for pair in query.split('&') {
669 if let Some((name, value)) = pair.split_once('=') {
670 if name == param_name {
671 let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
673 if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
674 return Ok(());
675 }
676 warn!("Invalid phantom token in query parameter '{}'", param_name);
677 return Err(ProxyError::InvalidToken);
678 }
679 }
680 }
681 }
682
683 warn!("Missing phantom token in query parameter '{}'", param_name);
684 Err(ProxyError::InvalidToken)
685}
686
687fn transform_path_for_mode(
693 mode: &InjectMode,
694 path: &str,
695 path_pattern: Option<&str>,
696 path_replacement: Option<&str>,
697 query_param_name: Option<&str>,
698 credential: &Zeroizing<String>,
699) -> Result<String> {
700 match mode {
701 InjectMode::Header | InjectMode::BasicAuth => {
702 Ok(path.to_string())
704 }
705 InjectMode::UrlPath => {
706 let pattern = path_pattern.ok_or_else(|| {
707 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
708 })?;
709 let replacement = path_replacement.unwrap_or(pattern);
710 transform_url_path(path, pattern, replacement, credential)
711 }
712 InjectMode::QueryParam => {
713 let param_name = query_param_name.ok_or_else(|| {
714 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
715 })?;
716 transform_query_param(path, param_name, credential)
717 }
718 }
719}
720
721fn transform_url_path(
725 path: &str,
726 pattern: &str,
727 replacement: &str,
728 credential: &Zeroizing<String>,
729) -> Result<String> {
730 let parts: Vec<&str> = pattern.split("{}").collect();
732 if parts.len() != 2 {
733 return Err(ProxyError::HttpParse(format!(
734 "invalid path_pattern '{}': must contain exactly one {{}}",
735 pattern
736 )));
737 }
738 let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
739
740 let repl_parts: Vec<&str> = replacement.split("{}").collect();
742 if repl_parts.len() != 2 {
743 return Err(ProxyError::HttpParse(format!(
744 "invalid path_replacement '{}': must contain exactly one {{}}",
745 replacement
746 )));
747 }
748 let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
749
750 if let Some(start) = path.find(pattern_prefix) {
752 let after_prefix = start + pattern_prefix.len();
753
754 let end_offset = if pattern_suffix.is_empty() {
756 path[after_prefix..]
758 .find(['/', '?'])
759 .unwrap_or(path[after_prefix..].len())
760 } else {
761 match path[after_prefix..].find(pattern_suffix) {
763 Some(offset) => offset,
764 None => {
765 return Err(ProxyError::HttpParse(format!(
766 "path '{}' does not match pattern '{}'",
767 path, pattern
768 )));
769 }
770 }
771 };
772
773 let before = &path[..start];
774 let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
775 return Ok(format!(
776 "{}{}{}{}{}",
777 before,
778 repl_prefix,
779 credential.as_str(),
780 repl_suffix,
781 after
782 ));
783 }
784
785 Err(ProxyError::HttpParse(format!(
786 "path '{}' does not match pattern '{}'",
787 path, pattern
788 )))
789}
790
791fn transform_query_param(
793 path: &str,
794 param_name: &str,
795 credential: &Zeroizing<String>,
796) -> Result<String> {
797 let encoded_value = urlencoding::encode(credential.as_str());
798
799 if let Some(query_start) = path.find('?') {
800 let base_path = &path[..query_start];
801 let query = &path[query_start + 1..];
802
803 let mut found = false;
805 let new_query: Vec<String> = query
806 .split('&')
807 .map(|pair| {
808 if let Some((name, _)) = pair.split_once('=') {
809 if name == param_name {
810 found = true;
811 return format!("{}={}", param_name, encoded_value);
812 }
813 }
814 pair.to_string()
815 })
816 .collect();
817
818 if found {
819 Ok(format!("{}?{}", base_path, new_query.join("&")))
820 } else {
821 Ok(format!(
823 "{}?{}&{}={}",
824 base_path, query, param_name, encoded_value
825 ))
826 }
827 } else {
828 Ok(format!("{}?{}={}", path, param_name, encoded_value))
830 }
831}
832
833fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
838 match cred.inject_mode {
839 InjectMode::Header | InjectMode::BasicAuth => {
840 request.push_str(&format!(
842 "{}: {}\r\n",
843 cred.header_name,
844 cred.header_value.as_str()
845 ));
846 }
847 InjectMode::UrlPath | InjectMode::QueryParam => {
848 }
851 }
852}
853
// Unit tests for the pure parsing, validation and rewriting helpers of the reverse proxy.
#[cfg(test)]
// `unwrap` is fine here: a panic is simply a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // --- Request line and service prefix ------------------------------------------------

    #[test]
    fn test_parse_request_line() {
        let (method, path, version) = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
        assert_eq!(method, "POST");
        assert_eq!(path, "/openai/v1/chat");
        assert_eq!(version, "HTTP/1.1");
    }

    #[test]
    fn test_parse_request_line_malformed() {
        assert!(parse_request_line("GET").is_err());
    }

    #[test]
    fn test_parse_service_prefix() {
        let (service, path) = parse_service_prefix("/openai/v1/chat/completions").unwrap();
        assert_eq!(service, "openai");
        assert_eq!(path, "/v1/chat/completions");
    }

    // A bare service prefix maps to the upstream root.
    #[test]
    fn test_parse_service_prefix_no_subpath() {
        let (service, path) = parse_service_prefix("/anthropic").unwrap();
        assert_eq!(service, "anthropic");
        assert_eq!(path, "/");
    }

    // --- Header-mode phantom token validation --------------------------------------------

    #[test]
    fn test_validate_phantom_token_bearer_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_bearer_invalid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
    }

    // Non-Authorization headers carry the bare token without a `Bearer ` prefix.
    #[test]
    fn test_validate_phantom_token_x_api_key_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-api-key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_x_goog_api_key_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-goog-api-key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_missing() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Content-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
    }

    // Header names are matched case-insensitively.
    #[test]
    fn test_validate_phantom_token_case_insensitive_header() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"AUTHORIZATION: Bearer secret123\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
    }

    // --- Header filtering and Content-Length ---------------------------------------------

    // Host and the credential header are dropped; other headers keep their order.
    #[test]
    fn test_filter_headers_removes_host_auth() {
        let header = b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n";
        let filtered = filter_headers(header, "Authorization");
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered[0].0, "Content-Type");
        assert_eq!(filtered[1].0, "Accept");
    }

    #[test]
    fn test_filter_headers_removes_x_api_key() {
        let header = b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n";
        let filtered = filter_headers(header, "x-api-key");
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "Content-Type");
    }

    #[test]
    fn test_filter_headers_removes_custom_header() {
        let header = b"PRIVATE-TOKEN: phantom123\r\nContent-Type: application/json\r\n\r\n";
        let filtered = filter_headers(header, "PRIVATE-TOKEN");
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "Content-Type");
    }

    #[test]
    fn test_extract_content_length() {
        let header = b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n";
        assert_eq!(extract_content_length(header), Some(42));
    }

    #[test]
    fn test_extract_content_length_missing() {
        let header = b"Content-Type: application/json\r\n\r\n";
        assert_eq!(extract_content_length(header), None);
    }

    // --- Upstream URL parsing ------------------------------------------------------------

    #[test]
    fn test_parse_upstream_url_https() {
        let (host, port, path) =
            parse_upstream_url("https://api.openai.com/v1/chat/completions").unwrap();
        assert_eq!(host, "api.openai.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_upstream_url_http_with_port() {
        let (host, port, path) = parse_upstream_url("http://localhost:8080/api").unwrap();
        assert_eq!(host, "localhost");
        assert_eq!(port, 8080);
        assert_eq!(path, "/api");
    }

    #[test]
    fn test_parse_upstream_url_no_path() {
        let (host, port, path) = parse_upstream_url("https://api.anthropic.com").unwrap();
        assert_eq!(host, "api.anthropic.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/");
    }

    #[test]
    fn test_parse_upstream_url_invalid_scheme() {
        assert!(parse_upstream_url("ftp://example.com").is_err());
    }

    // --- Response status parsing (502 is the fallback for anything unparseable) ----------

    #[test]
    fn test_parse_response_status_200() {
        let data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n";
        assert_eq!(parse_response_status(data), 200);
    }

    #[test]
    fn test_parse_response_status_404() {
        let data = b"HTTP/1.1 404 Not Found\r\n\r\n";
        assert_eq!(parse_response_status(data), 404);
    }

    #[test]
    fn test_parse_response_status_garbage() {
        let data = b"not an http response";
        assert_eq!(parse_response_status(data), 502);
    }

    #[test]
    fn test_parse_response_status_empty() {
        assert_eq!(parse_response_status(b""), 502);
    }

    #[test]
    fn test_parse_response_status_partial() {
        let data = b"HTTP/1.1 ";
        assert_eq!(parse_response_status(data), 502);
    }

    // --- url_path mode: validation and rewriting -----------------------------------------

    #[test]
    fn test_validate_phantom_token_in_path_valid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/bot/session123/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_path_invalid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/bot/wrong_token/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_path_missing() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
    }

    #[test]
    fn test_transform_url_path_basic() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/bot/phantom_token/getMe";
        let pattern = "/bot/{}/";
        let replacement = "/bot/{}/";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/bot/real_token/getMe");
    }

    // The replacement template may differ from the matching pattern.
    #[test]
    fn test_transform_url_path_different_replacement() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/api/v1/phantom_token/chat";
        let pattern = "/api/v1/{}/";
        let replacement = "/v2/bot/{}/";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/v2/bot/real_token/chat");
    }

    // With an empty suffix the token runs to the end of the path.
    #[test]
    fn test_transform_url_path_no_trailing_slash() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/bot/phantom_token";
        let pattern = "/bot/{}";
        let replacement = "/bot/{}";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/bot/real_token");
    }

    // --- query_param mode: validation and rewriting --------------------------------------

    #[test]
    fn test_validate_phantom_token_in_query_valid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?api_key=session123&other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_query_invalid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?api_key=wrong_token";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_missing_param() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_no_query_string() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    // Percent-encoded tokens are decoded before comparison.
    #[test]
    fn test_validate_phantom_token_in_query_url_encoded() {
        let token = Zeroizing::new("token with spaces".to_string());
        let path = "/api/data?api_key=token%20with%20spaces";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
    }

    #[test]
    fn test_transform_query_param_add_to_no_query() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_add_to_existing_query() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?other=value";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?other=value&api_key=real_key");
    }

    // An existing (phantom) value is replaced in place, preserving parameter order.
    #[test]
    fn test_transform_query_param_replace_existing() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?api_key=phantom&other=value";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=real_key&other=value");
    }

    #[test]
    fn test_transform_query_param_url_encodes_special_chars() {
        let credential = Zeroizing::new("key with spaces".to_string());
        let path = "/api/data";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=key%20with%20spaces");
    }
}