1use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::token;
23use std::net::SocketAddr;
24use std::time::Duration;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsConnector;
28use tracing::{debug, warn};
29use zeroize::Zeroizing;
30
/// Largest request body the proxy will buffer (16 MiB). A larger declared
/// `Content-Length` is answered with `413 Payload Too Large`.
const MAX_REQUEST_BODY: usize = 16 << 20;

/// Timeout applied to upstream connection attempts.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
36
37pub struct ReverseProxyCtx<'a> {
43 pub credential_store: &'a CredentialStore,
45 pub session_token: &'a Zeroizing<String>,
47 pub filter: &'a ProxyFilter,
49 pub tls_connector: &'a TlsConnector,
51}
52
53pub async fn handle_reverse_proxy(
67 first_line: &str,
68 stream: &mut TcpStream,
69 remaining_header: &[u8],
70 ctx: &ReverseProxyCtx<'_>,
71 buffered_body: &[u8],
72) -> Result<()> {
73 let (method, path, version) = parse_request_line(first_line)?;
75 debug!("Reverse proxy: {} {}", method, path);
76
77 let (service, upstream_path) = parse_service_prefix(&path)?;
79
80 let cred = ctx
82 .credential_store
83 .get(&service)
84 .ok_or_else(|| ProxyError::UnknownService {
85 prefix: service.clone(),
86 })?;
87
88 if let Err(e) = validate_phantom_token_for_mode(
93 &cred.inject_mode,
94 remaining_header,
95 &upstream_path,
96 &cred.header_name,
97 cred.path_pattern.as_deref(),
98 cred.query_param_name.as_deref(),
99 ctx.session_token,
100 ) {
101 audit::log_denied(audit::ProxyMode::Reverse, &service, 0, &e.to_string());
102 send_error(stream, 401, "Unauthorized").await?;
103 return Ok(());
104 }
105
106 let transformed_path = transform_path_for_mode(
108 &cred.inject_mode,
109 &upstream_path,
110 cred.path_pattern.as_deref(),
111 cred.path_replacement.as_deref(),
112 cred.query_param_name.as_deref(),
113 &cred.raw_credential,
114 )?;
115
116 let upstream_url = format!(
118 "{}{}",
119 cred.upstream.trim_end_matches('/'),
120 transformed_path
121 );
122 debug!("Forwarding to upstream: {} {}", method, upstream_url);
123
124 let (upstream_host, upstream_port, upstream_path_full) = parse_upstream_url(&upstream_url)?;
125
126 let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
128 if !check.result.is_allowed() {
129 let reason = check.result.reason();
130 warn!("Upstream host denied by filter: {}", reason);
131 send_error(stream, 403, "Forbidden").await?;
132 audit::log_denied(audit::ProxyMode::Reverse, &service, 0, &reason);
133 return Ok(());
134 }
135
136 let filtered_headers = filter_headers(remaining_header);
138 let content_length = extract_content_length(remaining_header);
139
140 let body = if let Some(len) = content_length {
144 if len > MAX_REQUEST_BODY {
145 send_error(stream, 413, "Payload Too Large").await?;
146 return Ok(());
147 }
148 let mut buf = Vec::with_capacity(len);
149 let pre = buffered_body.len().min(len);
150 buf.extend_from_slice(&buffered_body[..pre]);
151 let remaining = len - pre;
152 if remaining > 0 {
153 let mut rest = vec![0u8; remaining];
154 stream.read_exact(&mut rest).await?;
155 buf.extend_from_slice(&rest);
156 }
157 buf
158 } else {
159 Vec::new()
160 };
161
162 let upstream_result = connect_upstream_tls(
164 &upstream_host,
165 upstream_port,
166 &check.resolved_addrs,
167 ctx.tls_connector,
168 )
169 .await;
170 let mut tls_stream = match upstream_result {
171 Ok(s) => s,
172 Err(e) => {
173 warn!("Upstream connection failed: {}", e);
174 send_error(stream, 502, "Bad Gateway").await?;
175 audit::log_denied(audit::ProxyMode::Reverse, &service, 0, &e.to_string());
176 return Ok(());
177 }
178 };
179
180 let mut request = Zeroizing::new(format!(
184 "{} {} {}\r\nHost: {}\r\n",
185 method, upstream_path_full, version, upstream_host
186 ));
187
188 inject_credential_for_mode(cred, &mut request);
190
191 let auth_header_lower = cred.header_name.to_lowercase();
193 for (name, value) in &filtered_headers {
194 if matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
197 && name.to_lowercase() == auth_header_lower
198 {
199 continue;
200 }
201 request.push_str(&format!("{}: {}\r\n", name, value));
202 }
203
204 if !body.is_empty() {
206 request.push_str(&format!("Content-Length: {}\r\n", body.len()));
207 }
208 request.push_str("\r\n");
209
210 tls_stream.write_all(request.as_bytes()).await?;
211 if !body.is_empty() {
212 tls_stream.write_all(&body).await?;
213 }
214 tls_stream.flush().await?;
215
216 let mut response_buf = [0u8; 8192];
219 let mut status_code: u16 = 502;
220 let mut first_chunk = true;
221
222 loop {
223 let n = match tls_stream.read(&mut response_buf).await {
224 Ok(0) => break,
225 Ok(n) => n,
226 Err(e) => {
227 debug!("Upstream read error: {}", e);
228 break;
229 }
230 };
231
232 if first_chunk {
236 status_code = parse_response_status(&response_buf[..n]);
237 first_chunk = false;
238 }
239
240 stream.write_all(&response_buf[..n]).await?;
241 stream.flush().await?;
242 }
243
244 audit::log_reverse_proxy(&service, &method, &upstream_path, status_code);
245 Ok(())
246}
247
248fn parse_request_line(line: &str) -> Result<(String, String, String)> {
250 let parts: Vec<&str> = line.split_whitespace().collect();
251 if parts.len() < 3 {
252 return Err(ProxyError::HttpParse(format!(
253 "malformed request line: {}",
254 line
255 )));
256 }
257 Ok((
258 parts[0].to_string(),
259 parts[1].to_string(),
260 parts[2].to_string(),
261 ))
262}
263
264fn parse_service_prefix(path: &str) -> Result<(String, String)> {
269 let trimmed = path.strip_prefix('/').unwrap_or(path);
270 if let Some((prefix, rest)) = trimmed.split_once('/') {
271 Ok((prefix.to_string(), format!("/{}", rest)))
272 } else {
273 Ok((trimmed.to_string(), "/".to_string()))
275 }
276}
277
278fn validate_phantom_token(
285 header_bytes: &[u8],
286 header_name: &str,
287 session_token: &Zeroizing<String>,
288) -> Result<()> {
289 let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
290 let header_name_lower = header_name.to_lowercase();
291
292 for line in header_str.lines() {
293 let lower = line.to_lowercase();
294 if lower.starts_with(&format!("{}:", header_name_lower)) {
295 let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
296
297 let value_lower = value.to_lowercase();
300 let token_value = if value_lower.starts_with("bearer ") {
301 value[7..].trim()
303 } else {
304 value
305 };
306
307 if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
308 return Ok(());
309 }
310 warn!("Invalid phantom token in {} header", header_name);
311 return Err(ProxyError::InvalidToken);
312 }
313 }
314
315 warn!(
316 "Missing {} header for phantom token validation",
317 header_name
318 );
319 Err(ProxyError::InvalidToken)
320}
321
/// Parses raw header lines into trimmed `(name, value)` pairs for forwarding.
///
/// Several headers are dropped, matched case-insensitively by line prefix:
/// - `Host` and `Content-Length`, which the proxy rewrites
/// - the common credential headers (`Authorization`, `x-api-key`,
///   `x-goog-api-key`)
/// - blank lines and lines without a `:`
///
/// Invalid UTF-8 yields an empty list.
fn filter_headers(header_bytes: &[u8]) -> Vec<(String, String)> {
    const STRIPPED_PREFIXES: [&str; 5] = [
        "host:",
        "content-length:",
        "authorization:",
        "x-api-key:",
        "x-goog-api-key:",
    ];

    std::str::from_utf8(header_bytes)
        .unwrap_or("")
        .lines()
        .filter(|line| !line.trim().is_empty())
        .filter(|line| {
            let lowered = line.to_lowercase();
            !STRIPPED_PREFIXES
                .iter()
                .any(|prefix| lowered.starts_with(prefix))
        })
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim().to_string(), value.trim().to_string()))
        .collect()
}
349
/// Returns the value of the first `Content-Length` header, if any.
///
/// The header name is matched case-insensitively. If the first such header
/// holds a value that does not parse as `usize`, the result is `None`; later
/// `Content-Length` lines are not consulted. Invalid UTF-8 also yields `None`.
fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    let text = std::str::from_utf8(header_bytes).ok()?;
    let length_line = text
        .lines()
        .find(|candidate| candidate.to_lowercase().starts_with("content-length:"))?;
    let (_, raw) = length_line.split_once(':')?;
    raw.trim().parse().ok()
}
361
362fn parse_upstream_url(url_str: &str) -> Result<(String, u16, String)> {
364 let parsed = url::Url::parse(url_str)
365 .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
366
367 let scheme = parsed.scheme();
368 if scheme != "https" && scheme != "http" {
369 return Err(ProxyError::HttpParse(format!(
370 "unsupported URL scheme: {}",
371 url_str
372 )));
373 }
374
375 let host = parsed
376 .host_str()
377 .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
378 .to_string();
379
380 let default_port = if scheme == "https" { 443 } else { 80 };
381 let port = parsed.port().unwrap_or(default_port);
382
383 let path = parsed.path().to_string();
384 let path = if path.is_empty() {
385 "/".to_string()
386 } else {
387 path
388 };
389
390 let path_with_query = if let Some(query) = parsed.query() {
392 format!("{}?{}", path, query)
393 } else {
394 path
395 };
396
397 Ok((host, port, path_with_query))
398}
399
400async fn connect_upstream_tls(
409 host: &str,
410 port: u16,
411 resolved_addrs: &[SocketAddr],
412 connector: &TlsConnector,
413) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
414 let tcp = if resolved_addrs.is_empty() {
415 let addr = format!("{}:{}", host, port);
417 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
418 Ok(Ok(s)) => s,
419 Ok(Err(e)) => {
420 return Err(ProxyError::UpstreamConnect {
421 host: host.to_string(),
422 reason: e.to_string(),
423 });
424 }
425 Err(_) => {
426 return Err(ProxyError::UpstreamConnect {
427 host: host.to_string(),
428 reason: "connection timed out".to_string(),
429 });
430 }
431 }
432 } else {
433 connect_to_resolved(resolved_addrs, host).await?
434 };
435
436 let server_name = rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|_| {
437 ProxyError::UpstreamConnect {
438 host: host.to_string(),
439 reason: "invalid server name for TLS".to_string(),
440 }
441 })?;
442
443 let tls_stream =
444 connector
445 .connect(server_name, tcp)
446 .await
447 .map_err(|e| ProxyError::UpstreamConnect {
448 host: host.to_string(),
449 reason: format!("TLS handshake failed: {}", e),
450 })?;
451
452 Ok(tls_stream)
453}
454
455async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
457 let mut last_err = None;
458 for addr in addrs {
459 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
460 Ok(Ok(stream)) => return Ok(stream),
461 Ok(Err(e)) => {
462 debug!("Connect to {} failed: {}", addr, e);
463 last_err = Some(e.to_string());
464 }
465 Err(_) => {
466 debug!("Connect to {} timed out", addr);
467 last_err = Some("connection timed out".to_string());
468 }
469 }
470 }
471 Err(ProxyError::UpstreamConnect {
472 host: host.to_string(),
473 reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
474 })
475}
476
/// Extracts the numeric status code from the start of an HTTP response.
///
/// Only the first line (cut at CR/LF and capped at 64 bytes) is inspected.
/// It must start with an `HTTP/` token followed by a three-character code.
/// Anything unparsable yields 502.
fn parse_response_status(data: &[u8]) -> u16 {
    const FALLBACK: u16 = 502;

    let end = data
        .iter()
        .position(|&byte| byte == b'\r' || byte == b'\n')
        .unwrap_or(data.len())
        .min(64);

    std::str::from_utf8(&data[..end])
        .ok()
        .and_then(|status_line| {
            let mut fields = status_line.split_whitespace();
            let version = fields.next()?;
            if !version.starts_with("HTTP/") {
                return None;
            }
            let code = fields.next()?;
            if code.len() != 3 {
                return None;
            }
            code.parse().ok()
        })
        .unwrap_or(FALLBACK)
}
505
506async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
508 let body = format!("{{\"error\":\"{}\"}}", reason);
509 let response = format!(
510 "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
511 status,
512 reason,
513 body.len(),
514 body
515 );
516 stream.write_all(response.as_bytes()).await?;
517 stream.flush().await?;
518 Ok(())
519}
520
521fn validate_phantom_token_for_mode(
532 mode: &InjectMode,
533 header_bytes: &[u8],
534 path: &str,
535 header_name: &str,
536 path_pattern: Option<&str>,
537 query_param_name: Option<&str>,
538 session_token: &Zeroizing<String>,
539) -> Result<()> {
540 match mode {
541 InjectMode::Header | InjectMode::BasicAuth => {
542 validate_phantom_token(header_bytes, header_name, session_token)
544 }
545 InjectMode::UrlPath => {
546 let pattern = path_pattern.ok_or_else(|| {
548 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
549 })?;
550 validate_phantom_token_in_path(path, pattern, session_token)
551 }
552 InjectMode::QueryParam => {
553 let param_name = query_param_name.ok_or_else(|| {
555 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
556 })?;
557 validate_phantom_token_in_query(path, param_name, session_token)
558 }
559 }
560}
561
562fn validate_phantom_token_in_path(
567 path: &str,
568 pattern: &str,
569 session_token: &Zeroizing<String>,
570) -> Result<()> {
571 let parts: Vec<&str> = pattern.split("{}").collect();
573 if parts.len() != 2 {
574 return Err(ProxyError::HttpParse(format!(
575 "invalid path_pattern '{}': must contain exactly one {{}}",
576 pattern
577 )));
578 }
579 let (prefix, suffix) = (parts[0], parts[1]);
580
581 if let Some(start) = path.find(prefix) {
583 let after_prefix = start + prefix.len();
584
585 let end_offset = if suffix.is_empty() {
587 path[after_prefix..]
588 .find(['/', '?'])
589 .unwrap_or(path[after_prefix..].len())
590 } else {
591 match path[after_prefix..].find(suffix) {
592 Some(offset) => offset,
593 None => {
594 warn!("Missing phantom token in URL path (pattern: {})", pattern);
595 return Err(ProxyError::InvalidToken);
596 }
597 }
598 };
599
600 let token = &path[after_prefix..after_prefix + end_offset];
601 if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
602 return Ok(());
603 }
604 warn!("Invalid phantom token in URL path");
605 return Err(ProxyError::InvalidToken);
606 }
607
608 warn!("Missing phantom token in URL path (pattern: {})", pattern);
609 Err(ProxyError::InvalidToken)
610}
611
612fn validate_phantom_token_in_query(
614 path: &str,
615 param_name: &str,
616 session_token: &Zeroizing<String>,
617) -> Result<()> {
618 if let Some(query_start) = path.find('?') {
620 let query = &path[query_start + 1..];
621 for pair in query.split('&') {
622 if let Some((name, value)) = pair.split_once('=') {
623 if name == param_name {
624 let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
626 if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
627 return Ok(());
628 }
629 warn!("Invalid phantom token in query parameter '{}'", param_name);
630 return Err(ProxyError::InvalidToken);
631 }
632 }
633 }
634 }
635
636 warn!("Missing phantom token in query parameter '{}'", param_name);
637 Err(ProxyError::InvalidToken)
638}
639
640fn transform_path_for_mode(
646 mode: &InjectMode,
647 path: &str,
648 path_pattern: Option<&str>,
649 path_replacement: Option<&str>,
650 query_param_name: Option<&str>,
651 credential: &Zeroizing<String>,
652) -> Result<String> {
653 match mode {
654 InjectMode::Header | InjectMode::BasicAuth => {
655 Ok(path.to_string())
657 }
658 InjectMode::UrlPath => {
659 let pattern = path_pattern.ok_or_else(|| {
660 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
661 })?;
662 let replacement = path_replacement.unwrap_or(pattern);
663 transform_url_path(path, pattern, replacement, credential)
664 }
665 InjectMode::QueryParam => {
666 let param_name = query_param_name.ok_or_else(|| {
667 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
668 })?;
669 transform_query_param(path, param_name, credential)
670 }
671 }
672}
673
674fn transform_url_path(
678 path: &str,
679 pattern: &str,
680 replacement: &str,
681 credential: &Zeroizing<String>,
682) -> Result<String> {
683 let parts: Vec<&str> = pattern.split("{}").collect();
685 if parts.len() != 2 {
686 return Err(ProxyError::HttpParse(format!(
687 "invalid path_pattern '{}': must contain exactly one {{}}",
688 pattern
689 )));
690 }
691 let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
692
693 let repl_parts: Vec<&str> = replacement.split("{}").collect();
695 if repl_parts.len() != 2 {
696 return Err(ProxyError::HttpParse(format!(
697 "invalid path_replacement '{}': must contain exactly one {{}}",
698 replacement
699 )));
700 }
701 let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
702
703 if let Some(start) = path.find(pattern_prefix) {
705 let after_prefix = start + pattern_prefix.len();
706
707 let end_offset = if pattern_suffix.is_empty() {
709 path[after_prefix..]
711 .find(['/', '?'])
712 .unwrap_or(path[after_prefix..].len())
713 } else {
714 match path[after_prefix..].find(pattern_suffix) {
716 Some(offset) => offset,
717 None => {
718 return Err(ProxyError::HttpParse(format!(
719 "path '{}' does not match pattern '{}'",
720 path, pattern
721 )));
722 }
723 }
724 };
725
726 let before = &path[..start];
727 let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
728 return Ok(format!(
729 "{}{}{}{}{}",
730 before,
731 repl_prefix,
732 credential.as_str(),
733 repl_suffix,
734 after
735 ));
736 }
737
738 Err(ProxyError::HttpParse(format!(
739 "path '{}' does not match pattern '{}'",
740 path, pattern
741 )))
742}
743
744fn transform_query_param(
746 path: &str,
747 param_name: &str,
748 credential: &Zeroizing<String>,
749) -> Result<String> {
750 let encoded_value = urlencoding::encode(credential.as_str());
751
752 if let Some(query_start) = path.find('?') {
753 let base_path = &path[..query_start];
754 let query = &path[query_start + 1..];
755
756 let mut found = false;
758 let new_query: Vec<String> = query
759 .split('&')
760 .map(|pair| {
761 if let Some((name, _)) = pair.split_once('=') {
762 if name == param_name {
763 found = true;
764 return format!("{}={}", param_name, encoded_value);
765 }
766 }
767 pair.to_string()
768 })
769 .collect();
770
771 if found {
772 Ok(format!("{}?{}", base_path, new_query.join("&")))
773 } else {
774 Ok(format!(
776 "{}?{}&{}={}",
777 base_path, query, param_name, encoded_value
778 ))
779 }
780 } else {
781 Ok(format!("{}?{}={}", path, param_name, encoded_value))
783 }
784}
785
786fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
791 match cred.inject_mode {
792 InjectMode::Header | InjectMode::BasicAuth => {
793 request.push_str(&format!(
795 "{}: {}\r\n",
796 cred.header_name,
797 cred.header_value.as_str()
798 ));
799 }
800 InjectMode::UrlPath | InjectMode::QueryParam => {
801 }
804 }
805}
806
// Unit tests for the pure parsing, token-validation and path-rewriting
// helpers in this module; none of them open network connections.
#[cfg(test)]
// Tests unwrap freely: a panic is simply a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // --- request line / service prefix parsing ---

    #[test]
    fn test_parse_request_line() {
        let (method, path, version) = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
        assert_eq!(method, "POST");
        assert_eq!(path, "/openai/v1/chat");
        assert_eq!(version, "HTTP/1.1");
    }

    #[test]
    fn test_parse_request_line_malformed() {
        assert!(parse_request_line("GET").is_err());
    }

    #[test]
    fn test_parse_service_prefix() {
        let (service, path) = parse_service_prefix("/openai/v1/chat/completions").unwrap();
        assert_eq!(service, "openai");
        assert_eq!(path, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_service_prefix_no_subpath() {
        let (service, path) = parse_service_prefix("/anthropic").unwrap();
        assert_eq!(service, "anthropic");
        assert_eq!(path, "/");
    }

    // --- header-mode phantom token validation ---

    #[test]
    fn test_validate_phantom_token_bearer_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_bearer_invalid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_x_api_key_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-api-key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_x_goog_api_key_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-goog-api-key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_missing() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Content-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_case_insensitive_header() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"AUTHORIZATION: Bearer secret123\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
    }

    // --- header filtering and Content-Length extraction ---

    #[test]
    fn test_filter_headers_removes_host_auth() {
        let header = b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n";
        let filtered = filter_headers(header);
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered[0].0, "Content-Type");
        assert_eq!(filtered[1].0, "Accept");
    }

    #[test]
    fn test_filter_headers_removes_x_api_key() {
        let header = b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n";
        let filtered = filter_headers(header);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "Content-Type");
    }

    #[test]
    fn test_filter_headers_removes_x_goog_api_key() {
        let header = b"x-goog-api-key: gemini-key\r\nContent-Type: application/json\r\n\r\n";
        let filtered = filter_headers(header);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "Content-Type");
    }

    #[test]
    fn test_extract_content_length() {
        let header = b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n";
        assert_eq!(extract_content_length(header), Some(42));
    }

    #[test]
    fn test_extract_content_length_missing() {
        let header = b"Content-Type: application/json\r\n\r\n";
        assert_eq!(extract_content_length(header), None);
    }

    // --- upstream URL parsing ---

    #[test]
    fn test_parse_upstream_url_https() {
        let (host, port, path) =
            parse_upstream_url("https://api.openai.com/v1/chat/completions").unwrap();
        assert_eq!(host, "api.openai.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_upstream_url_http_with_port() {
        let (host, port, path) = parse_upstream_url("http://localhost:8080/api").unwrap();
        assert_eq!(host, "localhost");
        assert_eq!(port, 8080);
        assert_eq!(path, "/api");
    }

    #[test]
    fn test_parse_upstream_url_no_path() {
        let (host, port, path) = parse_upstream_url("https://api.anthropic.com").unwrap();
        assert_eq!(host, "api.anthropic.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/");
    }

    #[test]
    fn test_parse_upstream_url_invalid_scheme() {
        assert!(parse_upstream_url("ftp://example.com").is_err());
    }

    // --- response status parsing (502 is the fallback) ---

    #[test]
    fn test_parse_response_status_200() {
        let data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n";
        assert_eq!(parse_response_status(data), 200);
    }

    #[test]
    fn test_parse_response_status_404() {
        let data = b"HTTP/1.1 404 Not Found\r\n\r\n";
        assert_eq!(parse_response_status(data), 404);
    }

    #[test]
    fn test_parse_response_status_garbage() {
        let data = b"not an http response";
        assert_eq!(parse_response_status(data), 502);
    }

    #[test]
    fn test_parse_response_status_empty() {
        assert_eq!(parse_response_status(b""), 502);
    }

    #[test]
    fn test_parse_response_status_partial() {
        let data = b"HTTP/1.1 ";
        assert_eq!(parse_response_status(data), 502);
    }

    // --- url_path mode: validation and rewriting ---

    #[test]
    fn test_validate_phantom_token_in_path_valid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/bot/session123/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_path_invalid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/bot/wrong_token/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_path_missing() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
    }

    #[test]
    fn test_transform_url_path_basic() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/bot/phantom_token/getMe";
        let pattern = "/bot/{}/";
        let replacement = "/bot/{}/";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/bot/real_token/getMe");
    }

    #[test]
    fn test_transform_url_path_different_replacement() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/api/v1/phantom_token/chat";
        let pattern = "/api/v1/{}/";
        let replacement = "/v2/bot/{}/";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/v2/bot/real_token/chat");
    }

    #[test]
    fn test_transform_url_path_no_trailing_slash() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/bot/phantom_token";
        let pattern = "/bot/{}";
        let replacement = "/bot/{}";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/bot/real_token");
    }

    // --- query_param mode: validation and rewriting ---

    #[test]
    fn test_validate_phantom_token_in_query_valid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?api_key=session123&other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_query_invalid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?api_key=wrong_token";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_missing_param() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_no_query_string() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_url_encoded() {
        let token = Zeroizing::new("token with spaces".to_string());
        let path = "/api/data?api_key=token%20with%20spaces";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
    }

    #[test]
    fn test_transform_query_param_add_to_no_query() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_add_to_existing_query() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?other=value";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?other=value&api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_replace_existing() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?api_key=phantom&other=value";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=real_key&other=value");
    }

    #[test]
    fn test_transform_query_param_url_encodes_special_chars() {
        let credential = Zeroizing::new("key with spaces".to_string());
        let path = "/api/data";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=key%20with%20spaces");
    }
}