1use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::token;
23use std::net::SocketAddr;
24use std::time::Duration;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsConnector;
28use tracing::{debug, warn};
29use zeroize::Zeroizing;
30
/// Upper bound (16 MiB) on a request body the proxy will buffer before
/// forwarding; a larger declared `Content-Length` is answered with 413.
const MAX_REQUEST_BODY: usize = 16 << 20;

/// Timeout applied to each upstream TCP connect attempt.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
37pub struct ReverseProxyCtx<'a> {
43 pub credential_store: &'a CredentialStore,
45 pub session_token: &'a Zeroizing<String>,
47 pub filter: &'a ProxyFilter,
49 pub tls_connector: &'a TlsConnector,
51 pub audit_log: Option<&'a audit::SharedAuditLog>,
53}
54
55pub async fn handle_reverse_proxy(
69 first_line: &str,
70 stream: &mut TcpStream,
71 remaining_header: &[u8],
72 ctx: &ReverseProxyCtx<'_>,
73 buffered_body: &[u8],
74) -> Result<()> {
75 let (method, path, version) = parse_request_line(first_line)?;
77 debug!("Reverse proxy: {} {}", method, path);
78
79 let (service, upstream_path) = parse_service_prefix(&path)?;
81
82 let cred = ctx
84 .credential_store
85 .get(&service)
86 .ok_or_else(|| ProxyError::UnknownService {
87 prefix: service.clone(),
88 })?;
89
90 if !cred.endpoint_rules.is_allowed(&method, &upstream_path) {
93 let reason = format!(
94 "endpoint denied: {} {} on service '{}'",
95 method, upstream_path, service
96 );
97 warn!("{}", reason);
98 audit::log_denied(
99 ctx.audit_log,
100 audit::ProxyMode::Reverse,
101 &service,
102 0,
103 &reason,
104 );
105 send_error(stream, 403, "Forbidden").await?;
106 return Ok(());
107 }
108
109 if let Err(e) = validate_phantom_token_for_mode(
114 &cred.inject_mode,
115 remaining_header,
116 &upstream_path,
117 &cred.header_name,
118 cred.path_pattern.as_deref(),
119 cred.query_param_name.as_deref(),
120 ctx.session_token,
121 ) {
122 audit::log_denied(
123 ctx.audit_log,
124 audit::ProxyMode::Reverse,
125 &service,
126 0,
127 &e.to_string(),
128 );
129 send_error(stream, 401, "Unauthorized").await?;
130 return Ok(());
131 }
132
133 let transformed_path = transform_path_for_mode(
135 &cred.inject_mode,
136 &upstream_path,
137 cred.path_pattern.as_deref(),
138 cred.path_replacement.as_deref(),
139 cred.query_param_name.as_deref(),
140 &cred.raw_credential,
141 )?;
142
143 let upstream_url = format!(
145 "{}{}",
146 cred.upstream.trim_end_matches('/'),
147 transformed_path
148 );
149 debug!("Forwarding to upstream: {} {}", method, upstream_url);
150
151 let (upstream_host, upstream_port, upstream_path_full) = parse_upstream_url(&upstream_url)?;
152
153 let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
155 if !check.result.is_allowed() {
156 let reason = check.result.reason();
157 warn!("Upstream host denied by filter: {}", reason);
158 send_error(stream, 403, "Forbidden").await?;
159 audit::log_denied(
160 ctx.audit_log,
161 audit::ProxyMode::Reverse,
162 &service,
163 0,
164 &reason,
165 );
166 return Ok(());
167 }
168
169 let filtered_headers = filter_headers(remaining_header);
171 let content_length = extract_content_length(remaining_header);
172
173 let body = if let Some(len) = content_length {
177 if len > MAX_REQUEST_BODY {
178 send_error(stream, 413, "Payload Too Large").await?;
179 return Ok(());
180 }
181 let mut buf = Vec::with_capacity(len);
182 let pre = buffered_body.len().min(len);
183 buf.extend_from_slice(&buffered_body[..pre]);
184 let remaining = len - pre;
185 if remaining > 0 {
186 let mut rest = vec![0u8; remaining];
187 stream.read_exact(&mut rest).await?;
188 buf.extend_from_slice(&rest);
189 }
190 buf
191 } else {
192 Vec::new()
193 };
194
195 let upstream_result = connect_upstream_tls(
197 &upstream_host,
198 upstream_port,
199 &check.resolved_addrs,
200 ctx.tls_connector,
201 )
202 .await;
203 let mut tls_stream = match upstream_result {
204 Ok(s) => s,
205 Err(e) => {
206 warn!("Upstream connection failed: {}", e);
207 send_error(stream, 502, "Bad Gateway").await?;
208 audit::log_denied(
209 ctx.audit_log,
210 audit::ProxyMode::Reverse,
211 &service,
212 0,
213 &e.to_string(),
214 );
215 return Ok(());
216 }
217 };
218
219 let mut request = Zeroizing::new(format!(
223 "{} {} {}\r\nHost: {}\r\n",
224 method, upstream_path_full, version, upstream_host
225 ));
226
227 inject_credential_for_mode(cred, &mut request);
229
230 let auth_header_lower = cred.header_name.to_lowercase();
232 for (name, value) in &filtered_headers {
233 if matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
236 && name.to_lowercase() == auth_header_lower
237 {
238 continue;
239 }
240 request.push_str(&format!("{}: {}\r\n", name, value));
241 }
242
243 if !body.is_empty() {
245 request.push_str(&format!("Content-Length: {}\r\n", body.len()));
246 }
247 request.push_str("\r\n");
248
249 tls_stream.write_all(request.as_bytes()).await?;
250 if !body.is_empty() {
251 tls_stream.write_all(&body).await?;
252 }
253 tls_stream.flush().await?;
254
255 let mut response_buf = [0u8; 8192];
258 let mut status_code: u16 = 502;
259 let mut first_chunk = true;
260
261 loop {
262 let n = match tls_stream.read(&mut response_buf).await {
263 Ok(0) => break,
264 Ok(n) => n,
265 Err(e) => {
266 debug!("Upstream read error: {}", e);
267 break;
268 }
269 };
270
271 if first_chunk {
275 status_code = parse_response_status(&response_buf[..n]);
276 first_chunk = false;
277 }
278
279 stream.write_all(&response_buf[..n]).await?;
280 stream.flush().await?;
281 }
282
283 audit::log_reverse_proxy(
284 ctx.audit_log,
285 &service,
286 &method,
287 &upstream_path,
288 status_code,
289 );
290 Ok(())
291}
292
293fn parse_request_line(line: &str) -> Result<(String, String, String)> {
295 let parts: Vec<&str> = line.split_whitespace().collect();
296 if parts.len() < 3 {
297 return Err(ProxyError::HttpParse(format!(
298 "malformed request line: {}",
299 line
300 )));
301 }
302 Ok((
303 parts[0].to_string(),
304 parts[1].to_string(),
305 parts[2].to_string(),
306 ))
307}
308
309fn parse_service_prefix(path: &str) -> Result<(String, String)> {
314 let trimmed = path.strip_prefix('/').unwrap_or(path);
315 if let Some((prefix, rest)) = trimmed.split_once('/') {
316 Ok((prefix.to_string(), format!("/{}", rest)))
317 } else {
318 Ok((trimmed.to_string(), "/".to_string()))
320 }
321}
322
323fn validate_phantom_token(
330 header_bytes: &[u8],
331 header_name: &str,
332 session_token: &Zeroizing<String>,
333) -> Result<()> {
334 let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
335 let header_name_lower = header_name.to_lowercase();
336
337 for line in header_str.lines() {
338 let lower = line.to_lowercase();
339 if lower.starts_with(&format!("{}:", header_name_lower)) {
340 let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
341
342 let value_lower = value.to_lowercase();
345 let token_value = if value_lower.starts_with("bearer ") {
346 value[7..].trim()
348 } else {
349 value
350 };
351
352 if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
353 return Ok(());
354 }
355 warn!("Invalid phantom token in {} header", header_name);
356 return Err(ProxyError::InvalidToken);
357 }
358 }
359
360 warn!(
361 "Missing {} header for phantom token validation",
362 header_name
363 );
364 Err(ProxyError::InvalidToken)
365}
366
/// Parses the client's header block into `(name, value)` pairs to forward
/// upstream. Pairs are trimmed and kept in their original order.
///
/// Headers the proxy regenerates or that may carry client credentials are
/// removed: Host, Content-Length, Authorization, X-Api-Key and
/// X-Goog-Api-Key. Lines without a colon and empty header names are also
/// dropped.
///
/// Names are compared after trimming and ASCII case folding. Previously the
/// untrimmed line was matched but the trimmed name was forwarded, so
/// `" Host: x"` or `"Content-Length : 5"` slipped through as duplicates
/// (RFC 9112 §5.1 requires a proxy to strip such whitespace).
///
/// A block that is not valid UTF-8 yields no headers.
fn filter_headers(header_bytes: &[u8]) -> Vec<(String, String)> {
    const STRIPPED: [&str; 5] = [
        "host",
        "content-length",
        "authorization",
        "x-api-key",
        "x-goog-api-key",
    ];

    let header_str = std::str::from_utf8(header_bytes).unwrap_or("");
    header_str
        .lines()
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim(), value.trim()))
        .filter(|(name, _)| {
            !name.is_empty() && !STRIPPED.iter().any(|s| name.eq_ignore_ascii_case(s))
        })
        .map(|(name, value)| (name.to_string(), value.to_string()))
        .collect()
}
394
/// Returns the value of the first Content-Length header, if present and
/// parseable as `usize`.
///
/// The header name is matched after trimming and ASCII case folding, so
/// `Content-Length : 5` is recognized. This keeps the decision consistent
/// with how header names are normalized elsewhere.
///
/// Returns `None` in three cases:
/// - the block is not UTF-8;
/// - there is no Content-Length header;
/// - its value is not a number.
fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    let header_str = std::str::from_utf8(header_bytes).ok()?;
    let value = header_str
        .lines()
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
        .map(|(_, value)| value.trim())?;
    value.parse().ok()
}
406
407fn parse_upstream_url(url_str: &str) -> Result<(String, u16, String)> {
409 let parsed = url::Url::parse(url_str)
410 .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
411
412 let scheme = parsed.scheme();
413 if scheme != "https" && scheme != "http" {
414 return Err(ProxyError::HttpParse(format!(
415 "unsupported URL scheme: {}",
416 url_str
417 )));
418 }
419
420 let host = parsed
421 .host_str()
422 .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
423 .to_string();
424
425 let default_port = if scheme == "https" { 443 } else { 80 };
426 let port = parsed.port().unwrap_or(default_port);
427
428 let path = parsed.path().to_string();
429 let path = if path.is_empty() {
430 "/".to_string()
431 } else {
432 path
433 };
434
435 let path_with_query = if let Some(query) = parsed.query() {
437 format!("{}?{}", path, query)
438 } else {
439 path
440 };
441
442 Ok((host, port, path_with_query))
443}
444
445async fn connect_upstream_tls(
454 host: &str,
455 port: u16,
456 resolved_addrs: &[SocketAddr],
457 connector: &TlsConnector,
458) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
459 let tcp = if resolved_addrs.is_empty() {
460 let addr = format!("{}:{}", host, port);
462 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
463 Ok(Ok(s)) => s,
464 Ok(Err(e)) => {
465 return Err(ProxyError::UpstreamConnect {
466 host: host.to_string(),
467 reason: e.to_string(),
468 });
469 }
470 Err(_) => {
471 return Err(ProxyError::UpstreamConnect {
472 host: host.to_string(),
473 reason: "connection timed out".to_string(),
474 });
475 }
476 }
477 } else {
478 connect_to_resolved(resolved_addrs, host).await?
479 };
480
481 let server_name = rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|_| {
482 ProxyError::UpstreamConnect {
483 host: host.to_string(),
484 reason: "invalid server name for TLS".to_string(),
485 }
486 })?;
487
488 let tls_stream =
489 connector
490 .connect(server_name, tcp)
491 .await
492 .map_err(|e| ProxyError::UpstreamConnect {
493 host: host.to_string(),
494 reason: format!("TLS handshake failed: {}", e),
495 })?;
496
497 Ok(tls_stream)
498}
499
500async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
502 let mut last_err = None;
503 for addr in addrs {
504 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
505 Ok(Ok(stream)) => return Ok(stream),
506 Ok(Err(e)) => {
507 debug!("Connect to {} failed: {}", addr, e);
508 last_err = Some(e.to_string());
509 }
510 Err(_) => {
511 debug!("Connect to {} timed out", addr);
512 last_err = Some("connection timed out".to_string());
513 }
514 }
515 }
516 Err(ProxyError::UpstreamConnect {
517 host: host.to_string(),
518 reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
519 })
520}
521
/// Extracts the status code from the first chunk of an upstream HTTP/1.x
/// response.
///
/// Only the status line (up to the first CR or LF) is inspected, and it is
/// tokenized byte-wise. This fixes two problems with the old version:
/// - a non-UTF-8 or truncated reason phrase could no longer hide a valid
///   code;
/// - a code like `+20` would no longer be accepted.
///
/// Returns 502 unless the line starts with an `HTTP/` version token
/// followed by exactly three ASCII digits.
fn parse_response_status(data: &[u8]) -> u16 {
    const FALLBACK: u16 = 502;

    let line_end = data
        .iter()
        .position(|&b| b == b'\r' || b == b'\n')
        .unwrap_or(data.len());
    let mut tokens = data[..line_end]
        .split(|b| b.is_ascii_whitespace())
        .filter(|t| !t.is_empty());

    match (tokens.next(), tokens.next()) {
        (Some(version), Some(code))
            if version.starts_with(b"HTTP/")
                && code.len() == 3
                && code.iter().all(u8::is_ascii_digit) =>
        {
            code.iter()
                .fold(0u16, |acc, &d| acc * 10 + u16::from(d - b'0'))
        }
        _ => FALLBACK,
    }
}
550
551async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
553 let body = format!("{{\"error\":\"{}\"}}", reason);
554 let response = format!(
555 "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
556 status,
557 reason,
558 body.len(),
559 body
560 );
561 stream.write_all(response.as_bytes()).await?;
562 stream.flush().await?;
563 Ok(())
564}
565
566fn validate_phantom_token_for_mode(
577 mode: &InjectMode,
578 header_bytes: &[u8],
579 path: &str,
580 header_name: &str,
581 path_pattern: Option<&str>,
582 query_param_name: Option<&str>,
583 session_token: &Zeroizing<String>,
584) -> Result<()> {
585 match mode {
586 InjectMode::Header | InjectMode::BasicAuth => {
587 validate_phantom_token(header_bytes, header_name, session_token)
589 }
590 InjectMode::UrlPath => {
591 let pattern = path_pattern.ok_or_else(|| {
593 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
594 })?;
595 validate_phantom_token_in_path(path, pattern, session_token)
596 }
597 InjectMode::QueryParam => {
598 let param_name = query_param_name.ok_or_else(|| {
600 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
601 })?;
602 validate_phantom_token_in_query(path, param_name, session_token)
603 }
604 }
605}
606
607fn validate_phantom_token_in_path(
612 path: &str,
613 pattern: &str,
614 session_token: &Zeroizing<String>,
615) -> Result<()> {
616 let parts: Vec<&str> = pattern.split("{}").collect();
618 if parts.len() != 2 {
619 return Err(ProxyError::HttpParse(format!(
620 "invalid path_pattern '{}': must contain exactly one {{}}",
621 pattern
622 )));
623 }
624 let (prefix, suffix) = (parts[0], parts[1]);
625
626 if let Some(start) = path.find(prefix) {
628 let after_prefix = start + prefix.len();
629
630 let end_offset = if suffix.is_empty() {
632 path[after_prefix..]
633 .find(['/', '?'])
634 .unwrap_or(path[after_prefix..].len())
635 } else {
636 match path[after_prefix..].find(suffix) {
637 Some(offset) => offset,
638 None => {
639 warn!("Missing phantom token in URL path (pattern: {})", pattern);
640 return Err(ProxyError::InvalidToken);
641 }
642 }
643 };
644
645 let token = &path[after_prefix..after_prefix + end_offset];
646 if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
647 return Ok(());
648 }
649 warn!("Invalid phantom token in URL path");
650 return Err(ProxyError::InvalidToken);
651 }
652
653 warn!("Missing phantom token in URL path (pattern: {})", pattern);
654 Err(ProxyError::InvalidToken)
655}
656
657fn validate_phantom_token_in_query(
659 path: &str,
660 param_name: &str,
661 session_token: &Zeroizing<String>,
662) -> Result<()> {
663 if let Some(query_start) = path.find('?') {
665 let query = &path[query_start + 1..];
666 for pair in query.split('&') {
667 if let Some((name, value)) = pair.split_once('=') {
668 if name == param_name {
669 let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
671 if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
672 return Ok(());
673 }
674 warn!("Invalid phantom token in query parameter '{}'", param_name);
675 return Err(ProxyError::InvalidToken);
676 }
677 }
678 }
679 }
680
681 warn!("Missing phantom token in query parameter '{}'", param_name);
682 Err(ProxyError::InvalidToken)
683}
684
685fn transform_path_for_mode(
691 mode: &InjectMode,
692 path: &str,
693 path_pattern: Option<&str>,
694 path_replacement: Option<&str>,
695 query_param_name: Option<&str>,
696 credential: &Zeroizing<String>,
697) -> Result<String> {
698 match mode {
699 InjectMode::Header | InjectMode::BasicAuth => {
700 Ok(path.to_string())
702 }
703 InjectMode::UrlPath => {
704 let pattern = path_pattern.ok_or_else(|| {
705 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
706 })?;
707 let replacement = path_replacement.unwrap_or(pattern);
708 transform_url_path(path, pattern, replacement, credential)
709 }
710 InjectMode::QueryParam => {
711 let param_name = query_param_name.ok_or_else(|| {
712 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
713 })?;
714 transform_query_param(path, param_name, credential)
715 }
716 }
717}
718
719fn transform_url_path(
723 path: &str,
724 pattern: &str,
725 replacement: &str,
726 credential: &Zeroizing<String>,
727) -> Result<String> {
728 let parts: Vec<&str> = pattern.split("{}").collect();
730 if parts.len() != 2 {
731 return Err(ProxyError::HttpParse(format!(
732 "invalid path_pattern '{}': must contain exactly one {{}}",
733 pattern
734 )));
735 }
736 let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
737
738 let repl_parts: Vec<&str> = replacement.split("{}").collect();
740 if repl_parts.len() != 2 {
741 return Err(ProxyError::HttpParse(format!(
742 "invalid path_replacement '{}': must contain exactly one {{}}",
743 replacement
744 )));
745 }
746 let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
747
748 if let Some(start) = path.find(pattern_prefix) {
750 let after_prefix = start + pattern_prefix.len();
751
752 let end_offset = if pattern_suffix.is_empty() {
754 path[after_prefix..]
756 .find(['/', '?'])
757 .unwrap_or(path[after_prefix..].len())
758 } else {
759 match path[after_prefix..].find(pattern_suffix) {
761 Some(offset) => offset,
762 None => {
763 return Err(ProxyError::HttpParse(format!(
764 "path '{}' does not match pattern '{}'",
765 path, pattern
766 )));
767 }
768 }
769 };
770
771 let before = &path[..start];
772 let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
773 return Ok(format!(
774 "{}{}{}{}{}",
775 before,
776 repl_prefix,
777 credential.as_str(),
778 repl_suffix,
779 after
780 ));
781 }
782
783 Err(ProxyError::HttpParse(format!(
784 "path '{}' does not match pattern '{}'",
785 path, pattern
786 )))
787}
788
789fn transform_query_param(
791 path: &str,
792 param_name: &str,
793 credential: &Zeroizing<String>,
794) -> Result<String> {
795 let encoded_value = urlencoding::encode(credential.as_str());
796
797 if let Some(query_start) = path.find('?') {
798 let base_path = &path[..query_start];
799 let query = &path[query_start + 1..];
800
801 let mut found = false;
803 let new_query: Vec<String> = query
804 .split('&')
805 .map(|pair| {
806 if let Some((name, _)) = pair.split_once('=') {
807 if name == param_name {
808 found = true;
809 return format!("{}={}", param_name, encoded_value);
810 }
811 }
812 pair.to_string()
813 })
814 .collect();
815
816 if found {
817 Ok(format!("{}?{}", base_path, new_query.join("&")))
818 } else {
819 Ok(format!(
821 "{}?{}&{}={}",
822 base_path, query, param_name, encoded_value
823 ))
824 }
825 } else {
826 Ok(format!("{}?{}={}", path, param_name, encoded_value))
828 }
829}
830
831fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
836 match cred.inject_mode {
837 InjectMode::Header | InjectMode::BasicAuth => {
838 request.push_str(&format!(
840 "{}: {}\r\n",
841 cred.header_name,
842 cred.header_value.as_str()
843 ));
844 }
845 InjectMode::UrlPath | InjectMode::QueryParam => {
846 }
849 }
850}
851
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Wraps a literal in the zeroizing container the validators expect.
    fn secret(value: &str) -> Zeroizing<String> {
        Zeroizing::new(value.to_string())
    }

    #[test]
    fn test_parse_request_line() {
        let parsed = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
        assert_eq!(
            parsed,
            (
                "POST".to_string(),
                "/openai/v1/chat".to_string(),
                "HTTP/1.1".to_string()
            )
        );
    }

    #[test]
    fn test_parse_request_line_malformed() {
        assert!(parse_request_line("GET").is_err());
    }

    #[test]
    fn test_parse_service_prefix() {
        let parsed = parse_service_prefix("/openai/v1/chat/completions").unwrap();
        assert_eq!(
            parsed,
            ("openai".to_string(), "/v1/chat/completions".to_string())
        );
    }

    #[test]
    fn test_parse_service_prefix_no_subpath() {
        let parsed = parse_service_prefix("/anthropic").unwrap();
        assert_eq!(parsed, ("anthropic".to_string(), "/".to_string()));
    }

    #[test]
    fn test_validate_phantom_token_bearer_valid() {
        let header = b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &secret("secret123")).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_bearer_invalid() {
        let header = b"Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &secret("secret123")).is_err());
    }

    #[test]
    fn test_validate_phantom_token_x_api_key_valid() {
        let header = b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-api-key", &secret("secret123")).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_x_goog_api_key_valid() {
        let header = b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-goog-api-key", &secret("secret123")).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_missing() {
        let header = b"Content-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &secret("secret123")).is_err());
    }

    #[test]
    fn test_validate_phantom_token_case_insensitive_header() {
        let header = b"AUTHORIZATION: Bearer secret123\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &secret("secret123")).is_ok());
    }

    #[test]
    fn test_filter_headers_removes_host_auth() {
        let header = b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n";
        let names: Vec<String> = filter_headers(header).into_iter().map(|(n, _)| n).collect();
        assert_eq!(names, vec!["Content-Type".to_string(), "Accept".to_string()]);
    }

    #[test]
    fn test_filter_headers_removes_x_api_key() {
        let header = b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n";
        let names: Vec<String> = filter_headers(header).into_iter().map(|(n, _)| n).collect();
        assert_eq!(names, vec!["Content-Type".to_string()]);
    }

    #[test]
    fn test_filter_headers_removes_x_goog_api_key() {
        let header = b"x-goog-api-key: gemini-key\r\nContent-Type: application/json\r\n\r\n";
        let names: Vec<String> = filter_headers(header).into_iter().map(|(n, _)| n).collect();
        assert_eq!(names, vec!["Content-Type".to_string()]);
    }

    #[test]
    fn test_extract_content_length() {
        let header = b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n";
        assert_eq!(extract_content_length(header), Some(42));
    }

    #[test]
    fn test_extract_content_length_missing() {
        let header = b"Content-Type: application/json\r\n\r\n";
        assert_eq!(extract_content_length(header), None);
    }

    #[test]
    fn test_parse_upstream_url_https() {
        let parsed = parse_upstream_url("https://api.openai.com/v1/chat/completions").unwrap();
        assert_eq!(
            parsed,
            (
                "api.openai.com".to_string(),
                443,
                "/v1/chat/completions".to_string()
            )
        );
    }

    #[test]
    fn test_parse_upstream_url_http_with_port() {
        let parsed = parse_upstream_url("http://localhost:8080/api").unwrap();
        assert_eq!(parsed, ("localhost".to_string(), 8080, "/api".to_string()));
    }

    #[test]
    fn test_parse_upstream_url_no_path() {
        let parsed = parse_upstream_url("https://api.anthropic.com").unwrap();
        assert_eq!(parsed, ("api.anthropic.com".to_string(), 443, "/".to_string()));
    }

    #[test]
    fn test_parse_upstream_url_invalid_scheme() {
        assert!(parse_upstream_url("ftp://example.com").is_err());
    }

    #[test]
    fn test_parse_response_status_200() {
        let data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n";
        assert_eq!(parse_response_status(data), 200);
    }

    #[test]
    fn test_parse_response_status_404() {
        assert_eq!(parse_response_status(b"HTTP/1.1 404 Not Found\r\n\r\n"), 404);
    }

    #[test]
    fn test_parse_response_status_garbage() {
        assert_eq!(parse_response_status(b"not an http response"), 502);
    }

    #[test]
    fn test_parse_response_status_empty() {
        assert_eq!(parse_response_status(b""), 502);
    }

    #[test]
    fn test_parse_response_status_partial() {
        assert_eq!(parse_response_status(b"HTTP/1.1 "), 502);
    }

    #[test]
    fn test_validate_phantom_token_in_path_valid() {
        let outcome =
            validate_phantom_token_in_path("/bot/session123/getMe", "/bot/{}/", &secret("session123"));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_path_invalid() {
        let outcome =
            validate_phantom_token_in_path("/bot/wrong_token/getMe", "/bot/{}/", &secret("session123"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_path_missing() {
        let outcome = validate_phantom_token_in_path("/api/getMe", "/bot/{}/", &secret("session123"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_transform_url_path_basic() {
        let rewritten = transform_url_path(
            "/bot/phantom_token/getMe",
            "/bot/{}/",
            "/bot/{}/",
            &secret("real_token"),
        )
        .unwrap();
        assert_eq!(rewritten, "/bot/real_token/getMe");
    }

    #[test]
    fn test_transform_url_path_different_replacement() {
        let rewritten = transform_url_path(
            "/api/v1/phantom_token/chat",
            "/api/v1/{}/",
            "/v2/bot/{}/",
            &secret("real_token"),
        )
        .unwrap();
        assert_eq!(rewritten, "/v2/bot/real_token/chat");
    }

    #[test]
    fn test_transform_url_path_no_trailing_slash() {
        let rewritten = transform_url_path(
            "/bot/phantom_token",
            "/bot/{}",
            "/bot/{}",
            &secret("real_token"),
        )
        .unwrap();
        assert_eq!(rewritten, "/bot/real_token");
    }

    #[test]
    fn test_validate_phantom_token_in_query_valid() {
        let path = "/api/data?api_key=session123&other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &secret("session123")).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_query_invalid() {
        let path = "/api/data?api_key=wrong_token";
        assert!(validate_phantom_token_in_query(path, "api_key", &secret("session123")).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_missing_param() {
        let path = "/api/data?other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &secret("session123")).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_no_query_string() {
        assert!(validate_phantom_token_in_query("/api/data", "api_key", &secret("session123")).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_url_encoded() {
        let path = "/api/data?api_key=token%20with%20spaces";
        assert!(
            validate_phantom_token_in_query(path, "api_key", &secret("token with spaces")).is_ok()
        );
    }

    #[test]
    fn test_transform_query_param_add_to_no_query() {
        let rewritten = transform_query_param("/api/data", "api_key", &secret("real_key")).unwrap();
        assert_eq!(rewritten, "/api/data?api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_add_to_existing_query() {
        let rewritten =
            transform_query_param("/api/data?other=value", "api_key", &secret("real_key")).unwrap();
        assert_eq!(rewritten, "/api/data?other=value&api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_replace_existing() {
        let rewritten = transform_query_param(
            "/api/data?api_key=phantom&other=value",
            "api_key",
            &secret("real_key"),
        )
        .unwrap();
        assert_eq!(rewritten, "/api/data?api_key=real_key&other=value");
    }

    #[test]
    fn test_transform_query_param_url_encodes_special_chars() {
        let rewritten =
            transform_query_param("/api/data", "api_key", &secret("key with spaces")).unwrap();
        assert_eq!(rewritten, "/api/data?api_key=key%20with%20spaces");
    }
}