1use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::route::RouteStore;
23use crate::token;
24use std::net::SocketAddr;
25use std::time::Duration;
26use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
27use tokio::net::TcpStream;
28use tokio_rustls::TlsConnector;
29use tracing::{debug, warn};
30use zeroize::Zeroizing;
31
/// Largest request body accepted, in bytes (16 MiB). A request whose declared
/// `Content-Length` exceeds this is answered with `413 Payload Too Large`.
const MAX_REQUEST_BODY: usize = 16 * 1024 * 1024;
34
/// Upper bound on each upstream TCP connect attempt (applied per resolved
/// address, and to the hostname fallback connect).
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
37
/// Borrowed, shared state needed to serve reverse-proxy requests.
pub struct ReverseProxyCtx<'a> {
    /// Service-prefix → route table. Each route supplies the upstream base URL,
    /// the endpoint rules and an optional per-route TLS connector.
    pub route_store: &'a RouteStore,
    /// Static and OAuth2 credentials, looked up by service prefix.
    pub credential_store: &'a CredentialStore,
    /// Session token the client must present, either as the phantom token or
    /// through proxy authentication.
    pub session_token: &'a Zeroizing<String>,
    /// Host filter checked for every upstream before connecting.
    pub filter: &'a ProxyFilter,
    /// Default TLS connector, used when the route does not carry its own.
    pub tls_connector: &'a TlsConnector,
    /// Optional audit sink for denials and completed requests.
    pub audit_log: Option<&'a audit::SharedAuditLog>,
}
57
58pub async fn handle_reverse_proxy(
72 first_line: &str,
73 stream: &mut TcpStream,
74 remaining_header: &[u8],
75 ctx: &ReverseProxyCtx<'_>,
76 buffered_body: &[u8],
77) -> Result<()> {
78 let (method, path, version) = parse_request_line(first_line)?;
80 debug!("Reverse proxy: {} {}", method, path);
81
82 let (service, upstream_path) = parse_service_prefix(&path)?;
84 let route = ctx
85 .route_store
86 .get(&service)
87 .ok_or_else(|| ProxyError::UnknownService {
88 prefix: service.clone(),
89 })?;
90 let static_cred = ctx.credential_store.get(&service);
91 let oauth2_route = ctx.credential_store.get_oauth2(&service);
92
93 if !route.endpoint_rules.is_allowed(&method, &upstream_path) {
96 let reason = format!(
97 "endpoint denied: {} {} on service '{}'",
98 method, upstream_path, service
99 );
100 warn!("{}", reason);
101 audit::log_denied(
102 ctx.audit_log,
103 audit::ProxyMode::Reverse,
104 &service,
105 0,
106 &reason,
107 );
108 send_error(stream, 403, "Forbidden").await?;
109 return Ok(());
110 }
111
112 if let Some(oauth2_route) = oauth2_route {
113 return handle_oauth2_credential(
114 oauth2_route,
115 route,
116 &service,
117 &upstream_path,
118 &method,
119 &version,
120 stream,
121 remaining_header,
122 buffered_body,
123 ctx,
124 )
125 .await;
126 }
127
128 let cred = static_cred;
129
130 if let Some(cred) = cred {
134 if let Err(e) = validate_phantom_token_for_mode(
135 &cred.proxy_inject_mode,
136 remaining_header,
137 &upstream_path,
138 &cred.proxy_header_name,
139 cred.proxy_path_pattern.as_deref(),
140 cred.proxy_query_param_name.as_deref(),
141 ctx.session_token,
142 ) {
143 audit::log_denied(
144 ctx.audit_log,
145 audit::ProxyMode::Reverse,
146 &service,
147 0,
148 &e.to_string(),
149 );
150 send_error(stream, 401, "Unauthorized").await?;
151 return Ok(());
152 }
153 } else if let Err(e) = token::validate_proxy_auth(remaining_header, ctx.session_token) {
154 audit::log_denied(
155 ctx.audit_log,
156 audit::ProxyMode::Reverse,
157 &service,
158 0,
159 &e.to_string(),
160 );
161 send_error(stream, 407, "Proxy Authentication Required").await?;
162 return Ok(());
163 }
164
165 let transformed_path = if let Some(cred) = cred {
166 let cleaned_path = strip_proxy_artifacts(
167 &upstream_path,
168 &cred.proxy_inject_mode,
169 &cred.inject_mode,
170 cred.proxy_path_pattern.as_deref(),
171 cred.proxy_query_param_name.as_deref(),
172 );
173 transform_path_for_mode(
174 &cred.inject_mode,
175 &cleaned_path,
176 cred.path_pattern.as_deref(),
177 cred.path_replacement.as_deref(),
178 cred.query_param_name.as_deref(),
179 &cred.raw_credential,
180 )?
181 } else {
182 upstream_path.clone()
183 };
184
185 let upstream_url = format!(
186 "{}{}",
187 route.upstream.trim_end_matches('/'),
188 transformed_path
189 );
190 debug!("Forwarding to upstream: {} {}", method, upstream_url);
191
192 let (upstream_scheme, upstream_host, upstream_port, upstream_path_full) =
193 parse_upstream_url(&upstream_url)?;
194 let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
195 if !check.result.is_allowed() {
196 let reason = check.result.reason();
197 warn!("Upstream host denied by filter: {}", reason);
198 send_error(stream, 403, "Forbidden").await?;
199 audit::log_denied(
200 ctx.audit_log,
201 audit::ProxyMode::Reverse,
202 &service,
203 0,
204 &reason,
205 );
206 return Ok(());
207 }
208 if let Err(reason) =
209 validate_http_upstream_target(upstream_scheme, &upstream_host, &check.resolved_addrs)
210 {
211 warn!("{}", reason);
212 send_error(stream, 502, "Bad Gateway").await?;
213 audit::log_denied(
214 ctx.audit_log,
215 audit::ProxyMode::Reverse,
216 &service,
217 0,
218 &reason,
219 );
220 return Ok(());
221 }
222
223 let strip_header = cred.map(|c| c.proxy_header_name.as_str()).unwrap_or("");
224 let filtered_headers = filter_headers(remaining_header, strip_header);
225 let content_length = extract_content_length(remaining_header);
226 let body = match read_request_body(stream, content_length, buffered_body).await? {
227 Some(body) => body,
228 None => return Ok(()),
229 };
230
231 let upstream_authority = format_host_header(upstream_scheme, &upstream_host, upstream_port);
232 let mut request = Zeroizing::new(format!(
233 "{} {} {}\r\nHost: {}\r\n",
234 method, upstream_path_full, version, upstream_authority
235 ));
236
237 if let Some(cred) = cred {
238 inject_credential_for_mode(cred, &mut request);
239 }
240
241 let auth_header_lower = cred.map(|c| c.header_name.to_lowercase());
242 for (name, value) in &filtered_headers {
243 if let (Some(cred), Some(header_lower)) = (cred, auth_header_lower.as_ref()) {
244 if matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
245 && name.to_lowercase() == *header_lower
246 {
247 continue;
248 }
249 }
250 request.push_str(&format!("{}: {}\r\n", name, value));
251 }
252
253 request.push_str("Connection: close\r\n");
254 if !body.is_empty() {
255 request.push_str(&format!("Content-Length: {}\r\n", body.len()));
256 }
257 request.push_str("\r\n");
258
259 let status_code = match upstream_scheme {
260 UpstreamScheme::Https => {
261 let connector = route.tls_connector.as_ref().unwrap_or(ctx.tls_connector);
262 let mut tls_stream = match connect_upstream_tls(
263 &upstream_host,
264 upstream_port,
265 &check.resolved_addrs,
266 connector,
267 )
268 .await
269 {
270 Ok(s) => s,
271 Err(e) => {
272 warn!("Upstream connection failed: {}", e);
273 send_error(stream, 502, "Bad Gateway").await?;
274 audit::log_denied(
275 ctx.audit_log,
276 audit::ProxyMode::Reverse,
277 &service,
278 0,
279 &e.to_string(),
280 );
281 return Ok(());
282 }
283 };
284
285 write_upstream_request(&mut tls_stream, &request, &body).await?;
286 stream_response(&mut tls_stream, stream).await?
287 }
288 UpstreamScheme::Http => {
289 let mut upstream_stream =
290 match connect_upstream_tcp(&upstream_host, upstream_port, &check.resolved_addrs)
291 .await
292 {
293 Ok(s) => s,
294 Err(e) => {
295 warn!("Upstream connection failed: {}", e);
296 send_error(stream, 502, "Bad Gateway").await?;
297 audit::log_denied(
298 ctx.audit_log,
299 audit::ProxyMode::Reverse,
300 &service,
301 0,
302 &e.to_string(),
303 );
304 return Ok(());
305 }
306 };
307
308 write_upstream_request(&mut upstream_stream, &request, &body).await?;
309 stream_response(&mut upstream_stream, stream).await?
310 }
311 };
312 audit::log_reverse_proxy(
313 ctx.audit_log,
314 &service,
315 &method,
316 &upstream_path,
317 status_code,
318 );
319 Ok(())
320}
321
322#[allow(clippy::too_many_arguments)]
329async fn handle_oauth2_credential(
330 oauth2_route: &crate::credential::OAuth2Route,
331 route: &crate::route::LoadedRoute,
332 service: &str,
333 upstream_path: &str,
334 method: &str,
335 version: &str,
336 stream: &mut TcpStream,
337 remaining_header: &[u8],
338 buffered_body: &[u8],
339 ctx: &ReverseProxyCtx<'_>,
340) -> Result<()> {
341 let access_token = oauth2_route.cache.get_or_refresh().await;
343
344 if let Err(e) = validate_phantom_token(remaining_header, "Authorization", ctx.session_token) {
348 audit::log_denied(
349 ctx.audit_log,
350 audit::ProxyMode::Reverse,
351 service,
352 0,
353 &e.to_string(),
354 );
355 send_error(stream, 401, "Unauthorized").await?;
356 return Ok(());
357 }
358
359 let upstream_url = format!(
360 "{}{}",
361 oauth2_route.upstream.trim_end_matches('/'),
362 upstream_path
363 );
364 debug!("OAuth2 forwarding to upstream: {} {}", method, upstream_url);
365
366 let (upstream_scheme, upstream_host, upstream_port, upstream_path_full) =
367 parse_upstream_url(&upstream_url)?;
368 let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
370 if !check.result.is_allowed() {
371 let reason = check.result.reason();
372 warn!("Upstream host denied by filter: {}", reason);
373 send_error(stream, 403, "Forbidden").await?;
374 audit::log_denied(
375 ctx.audit_log,
376 audit::ProxyMode::Reverse,
377 service,
378 0,
379 &reason,
380 );
381 return Ok(());
382 }
383 if let Err(reason) =
384 validate_http_upstream_target(upstream_scheme, &upstream_host, &check.resolved_addrs)
385 {
386 warn!("{}", reason);
387 send_error(stream, 502, "Bad Gateway").await?;
388 audit::log_denied(
389 ctx.audit_log,
390 audit::ProxyMode::Reverse,
391 service,
392 0,
393 &reason,
394 );
395 return Ok(());
396 }
397
398 let filtered_headers = filter_headers(remaining_header, "Authorization");
401 let content_length = extract_content_length(remaining_header);
402
403 let body = match read_request_body(stream, content_length, buffered_body).await? {
405 Some(body) => body,
406 None => return Ok(()),
407 };
408
409 let upstream_authority = format_host_header(upstream_scheme, &upstream_host, upstream_port);
411 let mut request = Zeroizing::new(format!(
412 "{} {} {}\r\nHost: {}\r\n",
413 method, upstream_path_full, version, upstream_authority
414 ));
415
416 request.push_str(&format!(
418 "Authorization: Bearer {}\r\n",
419 access_token.as_str()
420 ));
421
422 for (name, value) in &filtered_headers {
424 request.push_str(&format!("{}: {}\r\n", name, value));
425 }
426
427 if !body.is_empty() {
428 request.push_str(&format!("Content-Length: {}\r\n", body.len()));
429 }
430 request.push_str("\r\n");
431
432 let status_code = match upstream_scheme {
433 UpstreamScheme::Https => {
434 let connector = route.tls_connector.as_ref().unwrap_or(ctx.tls_connector);
435 let mut tls_stream = match connect_upstream_tls(
436 &upstream_host,
437 upstream_port,
438 &check.resolved_addrs,
439 connector,
440 )
441 .await
442 {
443 Ok(s) => s,
444 Err(e) => {
445 warn!("Upstream connection failed: {}", e);
446 send_error(stream, 502, "Bad Gateway").await?;
447 audit::log_denied(
448 ctx.audit_log,
449 audit::ProxyMode::Reverse,
450 service,
451 0,
452 &e.to_string(),
453 );
454 return Ok(());
455 }
456 };
457
458 write_upstream_request(&mut tls_stream, &request, &body).await?;
459 stream_response(&mut tls_stream, stream).await?
460 }
461 UpstreamScheme::Http => {
462 let mut upstream_stream =
463 match connect_upstream_tcp(&upstream_host, upstream_port, &check.resolved_addrs)
464 .await
465 {
466 Ok(s) => s,
467 Err(e) => {
468 warn!("Upstream connection failed: {}", e);
469 send_error(stream, 502, "Bad Gateway").await?;
470 audit::log_denied(
471 ctx.audit_log,
472 audit::ProxyMode::Reverse,
473 service,
474 0,
475 &e.to_string(),
476 );
477 return Ok(());
478 }
479 };
480
481 write_upstream_request(&mut upstream_stream, &request, &body).await?;
482 stream_response(&mut upstream_stream, stream).await?
483 }
484 };
485
486 audit::log_reverse_proxy(ctx.audit_log, service, method, upstream_path, status_code);
487 Ok(())
488}
489
490async fn write_upstream_request<S>(stream: &mut S, request: &str, body: &[u8]) -> Result<()>
491where
492 S: AsyncWrite + Unpin,
493{
494 stream.write_all(request.as_bytes()).await?;
495 if !body.is_empty() {
496 stream.write_all(body).await?;
497 }
498 stream.flush().await?;
499 Ok(())
500}
501
502async fn read_request_body(
506 stream: &mut TcpStream,
507 content_length: Option<usize>,
508 buffered_body: &[u8],
509) -> Result<Option<Vec<u8>>> {
510 if let Some(len) = content_length {
511 if len > MAX_REQUEST_BODY {
512 send_error(stream, 413, "Payload Too Large").await?;
513 return Ok(None);
514 }
515 let mut buf = Vec::with_capacity(len);
516 let pre = buffered_body.len().min(len);
517 buf.extend_from_slice(&buffered_body[..pre]);
518 let remaining = len - pre;
519 if remaining > 0 {
520 let mut rest = vec![0u8; remaining];
521 stream.read_exact(&mut rest).await?;
522 buf.extend_from_slice(&rest);
523 }
524 Ok(Some(buf))
525 } else {
526 Ok(Some(Vec::new()))
527 }
528}
529
530async fn stream_response<S>(tls_stream: &mut S, stream: &mut TcpStream) -> Result<u16>
534where
535 S: AsyncRead + AsyncWrite + Unpin,
536{
537 let mut response_buf = [0u8; 8192];
538 let mut status_code: u16 = 502;
539 let mut first_chunk = true;
540
541 loop {
542 let n = match tls_stream.read(&mut response_buf).await {
543 Ok(0) => break,
544 Ok(n) => n,
545 Err(e) => {
546 debug!("Upstream read error: {}", e);
547 break;
548 }
549 };
550
551 if first_chunk {
552 status_code = parse_response_status(&response_buf[..n]);
553 first_chunk = false;
554 }
555
556 stream.write_all(&response_buf[..n]).await?;
557 stream.flush().await?;
558 }
559
560 Ok(status_code)
561}
562
563fn parse_request_line(line: &str) -> Result<(String, String, String)> {
565 let parts: Vec<&str> = line.split_whitespace().collect();
566 if parts.len() < 3 {
567 return Err(ProxyError::HttpParse(format!(
568 "malformed request line: {}",
569 line
570 )));
571 }
572 Ok((
573 parts[0].to_string(),
574 parts[1].to_string(),
575 parts[2].to_string(),
576 ))
577}
578
579fn parse_service_prefix(path: &str) -> Result<(String, String)> {
584 let trimmed = path.strip_prefix('/').unwrap_or(path);
585 if let Some((prefix, rest)) = trimmed.split_once('/') {
586 Ok((prefix.to_string(), format!("/{}", rest)))
587 } else {
588 Ok((trimmed.to_string(), "/".to_string()))
590 }
591}
592
593fn validate_phantom_token(
600 header_bytes: &[u8],
601 header_name: &str,
602 session_token: &Zeroizing<String>,
603) -> Result<()> {
604 let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
605 let header_name_lower = header_name.to_lowercase();
606
607 for line in header_str.lines() {
608 let lower = line.to_lowercase();
609 if lower.starts_with(&format!("{}:", header_name_lower)) {
610 let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
611
612 let value_lower = value.to_lowercase();
615 let token_value = if value_lower.starts_with("bearer ") {
616 value[7..].trim()
618 } else {
619 value
620 };
621
622 if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
623 return Ok(());
624 }
625 warn!("Invalid phantom token in {} header", header_name);
626 return Err(ProxyError::InvalidToken);
627 }
628 }
629
630 warn!(
631 "Missing {} header for phantom token validation",
632 header_name
633 );
634 Err(ProxyError::InvalidToken)
635}
636
/// Collect the client headers that may be forwarded upstream, as trimmed
/// `(name, value)` pairs in their original order.
///
/// Dropped:
/// - `Host`, `Content-Length`, `Connection` and `Proxy-Authorization`, which
///   the proxy writes itself or must never forward;
/// - the header named `cred_header`, when it is non-empty;
/// - lines without a colon, including blank lines.
///
/// Names are compared case-insensitively *after trimming*. A line such as
/// `Host : evil` is therefore dropped, instead of slipping past a raw-prefix
/// check and being forwarded as a second `Host` header. A non-UTF-8 header
/// block yields no headers.
fn filter_headers(header_bytes: &[u8], cred_header: &str) -> Vec<(String, String)> {
    // Headers the proxy manages itself.
    const PROXY_MANAGED: [&str; 4] = ["host", "content-length", "connection", "proxy-authorization"];

    let header_str = std::str::from_utf8(header_bytes).unwrap_or("");
    let cred_header_lower = cred_header.to_lowercase();

    header_str
        .lines()
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim(), value.trim()))
        .filter(|(name, _)| {
            let name_lower = name.to_lowercase();
            !PROXY_MANAGED.contains(&name_lower.as_str())
                && (cred_header_lower.is_empty() || name_lower != cred_header_lower)
        })
        .map(|(name, value)| (name.to_string(), value.to_string()))
        .collect()
}
675
/// Return the value of the first `Content-Length` header, if it parses as a
/// `usize`.
///
/// The field name is trimmed and compared case-insensitively, the same way
/// `filter_headers` matches names. A `Content-Length : 5` line is therefore
/// honoured here rather than stripped by one function and ignored by the
/// other.
///
/// Returns `None` when:
/// - the block is not UTF-8;
/// - no such header exists;
/// - the first one is unparsable (later duplicates are not consulted).
fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    let header_str = std::str::from_utf8(header_bytes).ok()?;
    header_str
        .lines()
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
        .and_then(|(_, value)| value.trim().parse().ok())
}
687
/// Scheme of an upstream URL; only `http` and `https` are accepted by
/// `parse_upstream_url`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UpstreamScheme {
    /// Plain HTTP (default port 80); `validate_http_upstream_target` only
    /// permits it for loopback targets.
    Http,
    /// HTTPS via a `tokio_rustls` connector (default port 443).
    Https,
}
694
695fn validate_http_upstream_target(
696 scheme: UpstreamScheme,
697 host: &str,
698 resolved_addrs: &[SocketAddr],
699) -> std::result::Result<(), String> {
700 if matches!(scheme, UpstreamScheme::Https) {
701 return Ok(());
702 }
703
704 if is_local_only_target(host, resolved_addrs) {
705 Ok(())
706 } else {
707 Err(format!(
708 "refusing insecure http upstream for non-local host '{}'; http is only allowed for loopback addresses",
709 host
710 ))
711 }
712}
713
/// Report whether an upstream target is loopback-only.
///
/// With resolved addresses, every one of them must be a loopback address.
/// Without any, the host itself must be a loopback IP literal. Both bare and
/// URL-bracketed IPv6 forms are accepted: `url::Url::host_str` returns IPv6
/// hosts as `[::1]`, which `IpAddr` parsing rejects unless the brackets are
/// removed. Host names such as `localhost` are not trusted without
/// resolution.
fn is_local_only_target(host: &str, resolved_addrs: &[SocketAddr]) -> bool {
    if !resolved_addrs.is_empty() {
        return resolved_addrs.iter().all(|addr| addr.ip().is_loopback());
    }

    let literal = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    matches!(literal.parse::<std::net::IpAddr>(), Ok(ip) if ip.is_loopback())
}
725
726fn format_host_header(scheme: UpstreamScheme, host: &str, port: u16) -> String {
727 let default_port = match scheme {
728 UpstreamScheme::Http => 80,
729 UpstreamScheme::Https => 443,
730 };
731 let bracketed_host = if host.contains(':') && !host.starts_with('[') {
732 format!("[{}]", host)
733 } else {
734 host.to_string()
735 };
736
737 if port == default_port {
738 bracketed_host
739 } else {
740 format!("{}:{}", bracketed_host, port)
741 }
742}
743
744fn parse_upstream_url(url_str: &str) -> Result<(UpstreamScheme, String, u16, String)> {
745 let parsed = url::Url::parse(url_str)
746 .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
747
748 let scheme = match parsed.scheme() {
749 "https" => UpstreamScheme::Https,
750 "http" => UpstreamScheme::Http,
751 _ => {
752 return Err(ProxyError::HttpParse(format!(
753 "unsupported URL scheme: {}",
754 url_str
755 )));
756 }
757 };
758
759 let host = parsed
760 .host_str()
761 .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
762 .to_string();
763
764 let default_port = if matches!(scheme, UpstreamScheme::Https) {
765 443
766 } else {
767 80
768 };
769 let port = parsed.port().unwrap_or(default_port);
770
771 let path = parsed.path().to_string();
772 let path = if path.is_empty() {
773 "/".to_string()
774 } else {
775 path
776 };
777
778 let path_with_query = if let Some(query) = parsed.query() {
780 format!("{}?{}", path, query)
781 } else {
782 path
783 };
784
785 Ok((scheme, host, port, path_with_query))
786}
787
788async fn connect_upstream_tls(
797 host: &str,
798 port: u16,
799 resolved_addrs: &[SocketAddr],
800 connector: &TlsConnector,
801) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
802 let tcp = if resolved_addrs.is_empty() {
803 let addr = format!("{}:{}", host, port);
805 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
806 Ok(Ok(s)) => s,
807 Ok(Err(e)) => {
808 return Err(ProxyError::UpstreamConnect {
809 host: host.to_string(),
810 reason: e.to_string(),
811 });
812 }
813 Err(_) => {
814 return Err(ProxyError::UpstreamConnect {
815 host: host.to_string(),
816 reason: "connection timed out".to_string(),
817 });
818 }
819 }
820 } else {
821 connect_to_resolved(resolved_addrs, host).await?
822 };
823
824 let server_name = rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|_| {
825 ProxyError::UpstreamConnect {
826 host: host.to_string(),
827 reason: "invalid server name for TLS".to_string(),
828 }
829 })?;
830
831 let tls_stream =
832 connector
833 .connect(server_name, tcp)
834 .await
835 .map_err(|e| ProxyError::UpstreamConnect {
836 host: host.to_string(),
837 reason: format!("TLS handshake failed: {}", e),
838 })?;
839
840 Ok(tls_stream)
841}
842
843async fn connect_upstream_tcp(
844 host: &str,
845 port: u16,
846 resolved_addrs: &[SocketAddr],
847) -> Result<TcpStream> {
848 if resolved_addrs.is_empty() {
849 let addr = format!("{}:{}", host, port);
850 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
851 Ok(Ok(s)) => Ok(s),
852 Ok(Err(e)) => Err(ProxyError::UpstreamConnect {
853 host: host.to_string(),
854 reason: e.to_string(),
855 }),
856 Err(_) => Err(ProxyError::UpstreamConnect {
857 host: host.to_string(),
858 reason: "connection timed out".to_string(),
859 }),
860 }
861 } else {
862 connect_to_resolved(resolved_addrs, host).await
863 }
864}
865
866async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
868 let mut last_err = None;
869 for addr in addrs {
870 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
871 Ok(Ok(stream)) => return Ok(stream),
872 Ok(Err(e)) => {
873 debug!("Connect to {} failed: {}", addr, e);
874 last_err = Some(e.to_string());
875 }
876 Err(_) => {
877 debug!("Connect to {} timed out", addr);
878 last_err = Some("connection timed out".to_string());
879 }
880 }
881 }
882 Err(ProxyError::UpstreamConnect {
883 host: host.to_string(),
884 reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
885 })
886}
887
/// Extract the status code from the start of an HTTP response.
///
/// Only the first line, truncated to 64 bytes, is examined. It must start with
/// `HTTP/` and be followed by a 3-character numeric code. Anything else
/// (including non-UTF-8 input) yields 502.
fn parse_response_status(data: &[u8]) -> u16 {
    const FALLBACK: u16 = 502;

    let line_len = data
        .iter()
        .position(|&b| matches!(b, b'\r' | b'\n'))
        .unwrap_or(data.len());
    let head = &data[..line_len.min(64)];

    let status_line = match std::str::from_utf8(head) {
        Ok(s) => s,
        Err(_) => return FALLBACK,
    };

    let mut fields = status_line.split_whitespace();
    match (fields.next(), fields.next()) {
        (Some(version), Some(code)) if version.starts_with("HTTP/") && code.len() == 3 => {
            code.parse().unwrap_or(FALLBACK)
        }
        _ => FALLBACK,
    }
}
916
917async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
919 let body = format!("{{\"error\":\"{}\"}}", reason);
920 let response = format!(
921 "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
922 status,
923 reason,
924 body.len(),
925 body
926 );
927 stream.write_all(response.as_bytes()).await?;
928 stream.flush().await?;
929 Ok(())
930}
931
932fn validate_phantom_token_for_mode(
943 mode: &InjectMode,
944 header_bytes: &[u8],
945 path: &str,
946 header_name: &str,
947 path_pattern: Option<&str>,
948 query_param_name: Option<&str>,
949 session_token: &Zeroizing<String>,
950) -> Result<()> {
951 match mode {
952 InjectMode::Header | InjectMode::BasicAuth => {
953 validate_phantom_token(header_bytes, header_name, session_token)
955 }
956 InjectMode::UrlPath => {
957 let pattern = path_pattern.ok_or_else(|| {
959 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
960 })?;
961 validate_phantom_token_in_path(path, pattern, session_token)
962 }
963 InjectMode::QueryParam => {
964 let param_name = query_param_name.ok_or_else(|| {
966 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
967 })?;
968 validate_phantom_token_in_query(path, param_name, session_token)
969 }
970 }
971}
972
973fn validate_phantom_token_in_path(
978 path: &str,
979 pattern: &str,
980 session_token: &Zeroizing<String>,
981) -> Result<()> {
982 let parts: Vec<&str> = pattern.split("{}").collect();
984 if parts.len() != 2 {
985 return Err(ProxyError::HttpParse(format!(
986 "invalid path_pattern '{}': must contain exactly one {{}}",
987 pattern
988 )));
989 }
990 let (prefix, suffix) = (parts[0], parts[1]);
991
992 if let Some(start) = path.find(prefix) {
994 let after_prefix = start + prefix.len();
995
996 let end_offset = if suffix.is_empty() {
998 path[after_prefix..]
999 .find(['/', '?'])
1000 .unwrap_or(path[after_prefix..].len())
1001 } else {
1002 match path[after_prefix..].find(suffix) {
1003 Some(offset) => offset,
1004 None => {
1005 warn!("Missing phantom token in URL path (pattern: {})", pattern);
1006 return Err(ProxyError::InvalidToken);
1007 }
1008 }
1009 };
1010
1011 let token = &path[after_prefix..after_prefix + end_offset];
1012 if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
1013 return Ok(());
1014 }
1015 warn!("Invalid phantom token in URL path");
1016 return Err(ProxyError::InvalidToken);
1017 }
1018
1019 warn!("Missing phantom token in URL path (pattern: {})", pattern);
1020 Err(ProxyError::InvalidToken)
1021}
1022
1023fn validate_phantom_token_in_query(
1025 path: &str,
1026 param_name: &str,
1027 session_token: &Zeroizing<String>,
1028) -> Result<()> {
1029 if let Some(query_start) = path.find('?') {
1031 let query = &path[query_start + 1..];
1032 for pair in query.split('&') {
1033 if let Some((name, value)) = pair.split_once('=') {
1034 if name == param_name {
1035 let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
1037 if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
1038 return Ok(());
1039 }
1040 warn!("Invalid phantom token in query parameter '{}'", param_name);
1041 return Err(ProxyError::InvalidToken);
1042 }
1043 }
1044 }
1045 }
1046
1047 warn!("Missing phantom token in query parameter '{}'", param_name);
1048 Err(ProxyError::InvalidToken)
1049}
1050
1051fn transform_path_for_mode(
1057 mode: &InjectMode,
1058 path: &str,
1059 path_pattern: Option<&str>,
1060 path_replacement: Option<&str>,
1061 query_param_name: Option<&str>,
1062 credential: &Zeroizing<String>,
1063) -> Result<String> {
1064 match mode {
1065 InjectMode::Header | InjectMode::BasicAuth => {
1066 Ok(path.to_string())
1068 }
1069 InjectMode::UrlPath => {
1070 let pattern = path_pattern.ok_or_else(|| {
1071 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
1072 })?;
1073 let replacement = path_replacement.unwrap_or(pattern);
1074 transform_url_path(path, pattern, replacement, credential)
1075 }
1076 InjectMode::QueryParam => {
1077 let param_name = query_param_name.ok_or_else(|| {
1078 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
1079 })?;
1080 transform_query_param(path, param_name, credential)
1081 }
1082 }
1083}
1084
1085fn transform_url_path(
1089 path: &str,
1090 pattern: &str,
1091 replacement: &str,
1092 credential: &Zeroizing<String>,
1093) -> Result<String> {
1094 let parts: Vec<&str> = pattern.split("{}").collect();
1096 if parts.len() != 2 {
1097 return Err(ProxyError::HttpParse(format!(
1098 "invalid path_pattern '{}': must contain exactly one {{}}",
1099 pattern
1100 )));
1101 }
1102 let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
1103
1104 let repl_parts: Vec<&str> = replacement.split("{}").collect();
1106 if repl_parts.len() != 2 {
1107 return Err(ProxyError::HttpParse(format!(
1108 "invalid path_replacement '{}': must contain exactly one {{}}",
1109 replacement
1110 )));
1111 }
1112 let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
1113
1114 if let Some(start) = path.find(pattern_prefix) {
1116 let after_prefix = start + pattern_prefix.len();
1117
1118 let end_offset = if pattern_suffix.is_empty() {
1120 path[after_prefix..]
1122 .find(['/', '?'])
1123 .unwrap_or(path[after_prefix..].len())
1124 } else {
1125 match path[after_prefix..].find(pattern_suffix) {
1127 Some(offset) => offset,
1128 None => {
1129 return Err(ProxyError::HttpParse(format!(
1130 "path '{}' does not match pattern '{}'",
1131 path, pattern
1132 )));
1133 }
1134 }
1135 };
1136
1137 let before = &path[..start];
1138 let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
1139 return Ok(format!(
1140 "{}{}{}{}{}",
1141 before,
1142 repl_prefix,
1143 credential.as_str(),
1144 repl_suffix,
1145 after
1146 ));
1147 }
1148
1149 Err(ProxyError::HttpParse(format!(
1150 "path '{}' does not match pattern '{}'",
1151 path, pattern
1152 )))
1153}
1154
1155fn transform_query_param(
1157 path: &str,
1158 param_name: &str,
1159 credential: &Zeroizing<String>,
1160) -> Result<String> {
1161 let encoded_value = urlencoding::encode(credential.as_str());
1162
1163 if let Some(query_start) = path.find('?') {
1164 let base_path = &path[..query_start];
1165 let query = &path[query_start + 1..];
1166
1167 let mut found = false;
1169 let new_query: Vec<String> = query
1170 .split('&')
1171 .map(|pair| {
1172 if let Some((name, _)) = pair.split_once('=') {
1173 if name == param_name {
1174 found = true;
1175 return format!("{}={}", param_name, encoded_value);
1176 }
1177 }
1178 pair.to_string()
1179 })
1180 .collect();
1181
1182 if found {
1183 Ok(format!("{}?{}", base_path, new_query.join("&")))
1184 } else {
1185 Ok(format!(
1187 "{}?{}&{}={}",
1188 base_path, query, param_name, encoded_value
1189 ))
1190 }
1191 } else {
1192 Ok(format!("{}?{}={}", path, param_name, encoded_value))
1194 }
1195}
1196
1197fn strip_proxy_artifacts(
1208 path: &str,
1209 proxy_mode: &InjectMode,
1210 upstream_mode: &InjectMode,
1211 proxy_path_pattern: Option<&str>,
1212 proxy_query_param_name: Option<&str>,
1213) -> String {
1214 if proxy_mode == upstream_mode {
1217 return path.to_string();
1218 }
1219
1220 match proxy_mode {
1221 InjectMode::UrlPath => {
1222 if let Some(pattern) = proxy_path_pattern {
1223 strip_proxy_path_token(path, pattern)
1224 } else {
1225 path.to_string()
1226 }
1227 }
1228 InjectMode::QueryParam => {
1229 if let Some(param_name) = proxy_query_param_name {
1230 strip_proxy_query_param(path, param_name)
1231 } else {
1232 path.to_string()
1233 }
1234 }
1235 InjectMode::Header | InjectMode::BasicAuth => path.to_string(),
1237 }
1238}
1239
/// Cut the phantom-token span described by `pattern` (exactly one `{}`) out
/// of `path`, re-joining the halves with a single `/`.
///
/// The span starts at the first occurrence of the pattern's prefix. It ends at
/// the pattern's suffix, or, when the suffix is empty, at the next `/` or `?`.
/// The result always starts with `/`. The path is returned unchanged when the
/// pattern is malformed or does not match.
fn strip_proxy_path_token(path: &str, pattern: &str) -> String {
    let mut pieces = pattern.split("{}");
    let (prefix, suffix) = match (pieces.next(), pieces.next(), pieces.next()) {
        (Some(prefix), Some(suffix), None) => (prefix, suffix),
        _ => return path.to_string(),
    };

    let start = match path.find(prefix) {
        Some(idx) => idx,
        None => return path.to_string(),
    };
    let token_start = start + prefix.len();
    let tail = &path[token_start..];
    let token_len = if suffix.is_empty() {
        tail.find(['/', '?']).unwrap_or(tail.len())
    } else {
        match tail.find(suffix) {
            Some(idx) => idx,
            None => return path.to_string(),
        }
    };

    let before = &path[..start];
    let after = &path[token_start + token_len + suffix.len()..];

    let joined = if before.ends_with('/') && after.starts_with('/') {
        // Avoid a doubled slash at the seam.
        format!("{}{}", before, &after[1..])
    } else if !before.ends_with('/')
        && !after.starts_with('/')
        && !before.is_empty()
        && !after.is_empty()
    {
        // Re-insert the separator the removed span contained.
        format!("{}/{}", before, after)
    } else {
        format!("{}{}", before, after)
    };

    if joined.starts_with('/') {
        joined
    } else {
        format!("/{}", joined)
    }
}
1295
/// Drop every `param_name=...` pair from the query string of `path`.
///
/// Pairs without `=` are kept. The `?` is removed entirely when no pairs
/// remain. A path without a query is returned unchanged.
fn strip_proxy_query_param(path: &str, param_name: &str) -> String {
    let (base_path, query) = match path.split_once('?') {
        Some(parts) => parts,
        None => return path.to_string(),
    };

    let kept: Vec<&str> = query
        .split('&')
        .filter(|pair| !matches!(pair.split_once('='), Some((name, _)) if name == param_name))
        .collect();

    if kept.is_empty() {
        base_path.to_string()
    } else {
        format!("{}?{}", base_path, kept.join("&"))
    }
}
1322
1323fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
1328 match cred.inject_mode {
1329 InjectMode::Header | InjectMode::BasicAuth => {
1330 request.push_str(&format!(
1332 "{}: {}\r\n",
1333 cred.header_name,
1334 cred.header_value.as_str()
1335 ));
1336 }
1337 InjectMode::UrlPath | InjectMode::QueryParam => {
1338 }
1341 }
1342}
1343
#[cfg(test)]
// Tests `unwrap()` results for inputs that are known to be well-formed; a panic
// is the intended failure signal in that case.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // ---- Request line and service prefix parsing ----

    #[test]
    fn test_parse_request_line() {
        let (method, path, version) = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
        assert_eq!(method, "POST");
        assert_eq!(path, "/openai/v1/chat");
        assert_eq!(version, "HTTP/1.1");
    }

    #[test]
    fn test_parse_request_line_malformed() {
        assert!(parse_request_line("GET").is_err());
    }

    #[test]
    fn test_parse_service_prefix() {
        let (service, path) = parse_service_prefix("/openai/v1/chat/completions").unwrap();
        assert_eq!(service, "openai");
        assert_eq!(path, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_service_prefix_no_subpath() {
        let (service, path) = parse_service_prefix("/anthropic").unwrap();
        assert_eq!(service, "anthropic");
        assert_eq!(path, "/");
    }

    // ---- Phantom token validation in headers ----

    #[test]
    fn test_validate_phantom_token_bearer_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_bearer_invalid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_x_api_key_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-api-key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_x_goog_api_key_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-goog-api-key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_missing() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Content-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_case_insensitive_header() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"AUTHORIZATION: Bearer secret123\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
    }

    // ---- Header filtering and Content-Length extraction ----

    #[test]
    fn test_filter_headers_removes_host_auth() {
        let header = b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n";
        let filtered = filter_headers(header, "Authorization");
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered[0].0, "Content-Type");
        assert_eq!(filtered[1].0, "Accept");
    }

    #[test]
    fn test_filter_headers_removes_x_api_key() {
        let header = b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n";
        let filtered = filter_headers(header, "x-api-key");
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "Content-Type");
    }

    #[test]
    fn test_filter_headers_removes_custom_header() {
        let header = b"PRIVATE-TOKEN: phantom123\r\nContent-Type: application/json\r\n\r\n";
        let filtered = filter_headers(header, "PRIVATE-TOKEN");
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "Content-Type");
    }

    #[test]
    fn test_extract_content_length() {
        let header = b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n";
        assert_eq!(extract_content_length(header), Some(42));
    }

    #[test]
    fn test_extract_content_length_missing() {
        let header = b"Content-Type: application/json\r\n\r\n";
        assert_eq!(extract_content_length(header), None);
    }

    // ---- Upstream URL parsing and target validation ----

    #[test]
    fn test_parse_upstream_url_https() {
        let (scheme, host, port, path) =
            parse_upstream_url("https://api.openai.com/v1/chat/completions").unwrap();
        assert_eq!(scheme, UpstreamScheme::Https);
        assert_eq!(host, "api.openai.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_upstream_url_http_with_port() {
        let (scheme, host, port, path) = parse_upstream_url("http://localhost:8080/api").unwrap();
        assert_eq!(scheme, UpstreamScheme::Http);
        assert_eq!(host, "localhost");
        assert_eq!(port, 8080);
        assert_eq!(path, "/api");
    }

    #[test]
    fn test_parse_upstream_url_no_path() {
        let (scheme, host, port, path) = parse_upstream_url("https://api.anthropic.com").unwrap();
        assert_eq!(scheme, UpstreamScheme::Https);
        assert_eq!(host, "api.anthropic.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/");
    }

    #[test]
    fn test_parse_upstream_url_invalid_scheme() {
        assert!(parse_upstream_url("ftp://example.com").is_err());
    }

    #[test]
    fn test_validate_http_upstream_target_rejects_non_local_host() {
        let err = validate_http_upstream_target(UpstreamScheme::Http, "api.example.com", &[])
            .expect_err("non-local http upstream should be rejected");
        assert!(err.contains("refusing insecure http upstream"));
    }

    #[test]
    fn test_validate_http_upstream_target_allows_loopback() {
        let loopback = [SocketAddr::from(([127, 0, 0, 1], 8080))];
        assert!(validate_http_upstream_target(UpstreamScheme::Http, "127.0.0.1", &[]).is_ok());
        assert!(validate_http_upstream_target(UpstreamScheme::Http, "::1", &[]).is_ok());
        assert!(
            validate_http_upstream_target(UpstreamScheme::Http, "localhost", &loopback).is_ok()
        );
    }

    #[test]
    fn test_validate_http_upstream_target_rejects_unspecified_addresses() {
        let unspecified = [SocketAddr::from(([0, 0, 0, 0], 8080))];
        let err = validate_http_upstream_target(UpstreamScheme::Http, "0.0.0.0", &[])
            .expect_err("unspecified http upstream should be rejected");
        assert!(err.contains("loopback addresses"));

        let err = validate_http_upstream_target(UpstreamScheme::Http, "localhost", &unspecified)
            .expect_err("localhost resolving to unspecified should be rejected");
        assert!(err.contains("loopback addresses"));
    }

    #[test]
    fn test_validate_http_upstream_target_rejects_localhost_resolving_non_loopback() {
        let poisoned = [SocketAddr::from(([203, 0, 113, 10], 8080))];
        let err = validate_http_upstream_target(UpstreamScheme::Http, "localhost", &poisoned)
            .expect_err("localhost resolving off-host should be rejected");
        assert!(err.contains("refusing insecure http upstream"));
    }

    // ---- Host header formatting ----

    #[test]
    fn test_format_host_header_uses_port_for_non_default_http() {
        assert_eq!(
            format_host_header(UpstreamScheme::Http, "localhost", 8080),
            "localhost:8080"
        );
    }

    #[test]
    fn test_format_host_header_omits_default_https_port() {
        assert_eq!(
            format_host_header(UpstreamScheme::Https, "api.openai.com", 443),
            "api.openai.com"
        );
    }

    #[test]
    fn test_format_host_header_brackets_ipv6() {
        assert_eq!(
            format_host_header(UpstreamScheme::Http, "::1", 8080),
            "[::1]:8080"
        );
    }

    // ---- Upstream response status parsing (502 on anything unparseable) ----

    #[test]
    fn test_parse_response_status_200() {
        let data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n";
        assert_eq!(parse_response_status(data), 200);
    }

    #[test]
    fn test_parse_response_status_404() {
        let data = b"HTTP/1.1 404 Not Found\r\n\r\n";
        assert_eq!(parse_response_status(data), 404);
    }

    #[test]
    fn test_parse_response_status_garbage() {
        let data = b"not an http response";
        assert_eq!(parse_response_status(data), 502);
    }

    #[test]
    fn test_parse_response_status_empty() {
        assert_eq!(parse_response_status(b""), 502);
    }

    #[test]
    fn test_parse_response_status_partial() {
        let data = b"HTTP/1.1 ";
        assert_eq!(parse_response_status(data), 502);
    }

    // ---- URL-path inject mode: validation and transformation ----

    #[test]
    fn test_validate_phantom_token_in_path_valid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/bot/session123/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_path_invalid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/bot/wrong_token/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_path_missing() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
    }

    #[test]
    fn test_transform_url_path_basic() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/bot/phantom_token/getMe";
        let pattern = "/bot/{}/";
        let replacement = "/bot/{}/";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/bot/real_token/getMe");
    }

    #[test]
    fn test_transform_url_path_different_replacement() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/api/v1/phantom_token/chat";
        let pattern = "/api/v1/{}/";
        let replacement = "/v2/bot/{}/";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/v2/bot/real_token/chat");
    }

    #[test]
    fn test_transform_url_path_no_trailing_slash() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/bot/phantom_token";
        let pattern = "/bot/{}";
        let replacement = "/bot/{}";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/bot/real_token");
    }

    // ---- Query-param inject mode: validation and transformation ----

    #[test]
    fn test_validate_phantom_token_in_query_valid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?api_key=session123&other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_query_invalid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?api_key=wrong_token";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_missing_param() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_no_query_string() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_url_encoded() {
        let token = Zeroizing::new("token with spaces".to_string());
        let path = "/api/data?api_key=token%20with%20spaces";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
    }

    #[test]
    fn test_transform_query_param_add_to_no_query() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_add_to_existing_query() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?other=value";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?other=value&api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_replace_existing() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?api_key=phantom&other=value";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=real_key&other=value");
    }

    #[test]
    fn test_transform_query_param_url_encodes_special_chars() {
        let credential = Zeroizing::new("key with spaces".to_string());
        let path = "/api/data";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=key%20with%20spaces");
    }

    // ---- Proxy-side vs upstream-side inject modes are independent ----

    #[test]
    fn test_validate_phantom_token_uses_proxy_mode_over_upstream_mode() {
        let token = Zeroizing::new("session123".to_string());
        let header = b"Authorization: Bearer session123\r\n\r\n";
        let path = "/api/data?api_key=wrong";

        let result = validate_phantom_token_for_mode(
            &InjectMode::Header,
            header,
            path,
            "Authorization",
            None,
            Some("api_key"),
            &token,
        );

        assert!(result.is_ok());
    }

    #[test]
    fn test_transform_path_uses_upstream_mode_independently() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?api_key=phantom";

        let transformed = transform_path_for_mode(
            &InjectMode::QueryParam,
            path,
            None,
            None,
            Some("api_key"),
            &credential,
        )
        .expect("query-param transform should succeed");

        assert_eq!(transformed, "/api/data?api_key=real_key");
    }

    // ---- Stripping the phantom token from the URL path ----

    #[test]
    fn test_strip_proxy_path_token_basic() {
        let result = strip_proxy_path_token("/PHANTOM123/api/v1/pods", "/{}/");
        assert_eq!(result, "/api/v1/pods");
    }

    #[test]
    fn test_strip_proxy_path_token_nested_pattern() {
        let result = strip_proxy_path_token("/auth/PHANTOM123/api/v1/pods", "/auth/{}/");
        assert_eq!(result, "/api/v1/pods");
    }

    #[test]
    fn test_strip_proxy_path_token_no_trailing_slash() {
        let result = strip_proxy_path_token("/PHANTOM123", "/{}");
        assert_eq!(result, "/");
    }

    #[test]
    fn test_strip_proxy_path_token_preserves_query() {
        let result = strip_proxy_path_token("/PHANTOM123/api?limit=10", "/{}/");
        assert_eq!(result, "/api?limit=10");
    }

    #[test]
    fn test_strip_proxy_path_token_no_match() {
        let result = strip_proxy_path_token("/api/v1/pods", "/auth/{}/");
        assert_eq!(result, "/api/v1/pods");
    }

    #[test]
    fn test_strip_proxy_path_token_mid_path_slash_join() {
        let result = strip_proxy_path_token("/api/k8s/PHANTOM/data", "/k8s/{}/");
        assert_eq!(result, "/api/data");
    }

    #[test]
    fn test_strip_proxy_path_token_no_double_slash() {
        let result = strip_proxy_path_token("/prefix/PHANTOM//suffix", "/prefix/{}/");
        assert_eq!(result, "/suffix");
    }

    // ---- Stripping the phantom token from the query string ----

    #[test]
    fn test_strip_proxy_query_param_only_param() {
        let result = strip_proxy_query_param("/api/v1/pods?token=PHANTOM123", "token");
        assert_eq!(result, "/api/v1/pods");
    }

    #[test]
    fn test_strip_proxy_query_param_with_other_params() {
        let result = strip_proxy_query_param("/api/v1/pods?token=PHANTOM123&limit=10", "token");
        assert_eq!(result, "/api/v1/pods?limit=10");
    }

    #[test]
    fn test_strip_proxy_query_param_middle() {
        let result =
            strip_proxy_query_param("/api/v1/pods?limit=10&token=PHANTOM123&watch=true", "token");
        assert_eq!(result, "/api/v1/pods?limit=10&watch=true");
    }

    #[test]
    fn test_strip_proxy_query_param_no_match() {
        let result = strip_proxy_query_param("/api/v1/pods?limit=10", "token");
        assert_eq!(result, "/api/v1/pods?limit=10");
    }

    #[test]
    fn test_strip_proxy_query_param_no_query_string() {
        let result = strip_proxy_query_param("/api/v1/pods", "token");
        assert_eq!(result, "/api/v1/pods");
    }

    #[test]
    fn test_strip_proxy_query_param_removes_repeated_param() {
        // Every occurrence of the phantom param is removed, not just the first.
        let result = strip_proxy_query_param("/api?token=A&limit=10&token=B", "token");
        assert_eq!(result, "/api?limit=10");
    }

    #[test]
    fn test_strip_proxy_query_param_keeps_prefix_named_param() {
        // Only exact name matches are stripped; `tokenizer` must survive.
        let result = strip_proxy_query_param("/api?tokenizer=bpe&token=A", "token");
        assert_eq!(result, "/api?tokenizer=bpe");
    }

    // ---- Mode-dependent artifact stripping ----

    #[test]
    fn test_strip_proxy_artifacts_same_mode_noop() {
        let path = "/PHANTOM123/api/v1/pods";
        let result = strip_proxy_artifacts(
            path,
            &InjectMode::UrlPath,
            &InjectMode::UrlPath,
            Some("/{}/"),
            None,
        );
        assert_eq!(result, path);
    }

    #[test]
    fn test_strip_proxy_artifacts_url_path_to_header() {
        let result = strip_proxy_artifacts(
            "/PHANTOM123/api/v1/pods",
            &InjectMode::UrlPath,
            &InjectMode::Header,
            Some("/{}/"),
            None,
        );
        assert_eq!(result, "/api/v1/pods");
    }

    #[test]
    fn test_strip_proxy_artifacts_query_param_to_header() {
        let result = strip_proxy_artifacts(
            "/api/v1/pods?token=PHANTOM123",
            &InjectMode::QueryParam,
            &InjectMode::Header,
            None,
            Some("token"),
        );
        assert_eq!(result, "/api/v1/pods");
    }

    #[test]
    fn test_strip_proxy_artifacts_header_to_query_param() {
        let path = "/api/v1/pods";
        let result = strip_proxy_artifacts(
            path,
            &InjectMode::Header,
            &InjectMode::QueryParam,
            None,
            None,
        );
        assert_eq!(result, path);
    }

    // ---- End-to-end: validate, strip, then transform ----

    #[test]
    fn test_end_to_end_url_path_proxy_header_upstream() {
        let token = Zeroizing::new("session456".to_string());
        let credential = Zeroizing::new("real_bearer_token".to_string());
        let path = "/session456/api/v1/namespaces";

        assert!(validate_phantom_token_for_mode(
            &InjectMode::UrlPath,
            b"\r\n\r\n",
            path,
            "Authorization",
            Some("/{}/"),
            None,
            &token,
        )
        .is_ok());

        let cleaned = strip_proxy_artifacts(
            path,
            &InjectMode::UrlPath,
            &InjectMode::Header,
            Some("/{}/"),
            None,
        );
        assert_eq!(cleaned, "/api/v1/namespaces");

        let transformed =
            transform_path_for_mode(&InjectMode::Header, &cleaned, None, None, None, &credential)
                .unwrap();
        assert_eq!(transformed, "/api/v1/namespaces");
    }

    #[test]
    fn test_end_to_end_query_param_proxy_header_upstream() {
        let token = Zeroizing::new("session789".to_string());
        let credential = Zeroizing::new("real_bearer_token".to_string());
        let path = "/api/v1/pods?token=session789&limit=100";

        assert!(validate_phantom_token_for_mode(
            &InjectMode::QueryParam,
            b"\r\n\r\n",
            path,
            "Authorization",
            None,
            Some("token"),
            &token,
        )
        .is_ok());

        let cleaned = strip_proxy_artifacts(
            path,
            &InjectMode::QueryParam,
            &InjectMode::Header,
            None,
            Some("token"),
        );
        assert_eq!(cleaned, "/api/v1/pods?limit=100");

        let transformed =
            transform_path_for_mode(&InjectMode::Header, &cleaned, None, None, None, &credential)
                .unwrap();
        assert_eq!(transformed, "/api/v1/pods?limit=100");
    }
}