1use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::forward::{self, AuditCtx, UpstreamScheme, UpstreamSpec, UpstreamStrategy};
23use crate::route::RouteStore;
24use crate::token;
25use std::net::SocketAddr;
26use tokio::io::AsyncReadExt;
27use tokio::io::AsyncWriteExt;
28use tokio::net::TcpStream;
29use tokio_rustls::TlsConnector;
30use tracing::{debug, warn};
31use zeroize::Zeroizing;
32
/// Largest request body (16 MiB) the reverse proxy will buffer; a larger declared
/// `Content-Length` is answered with `413 Payload Too Large` by `read_request_body`.
const MAX_REQUEST_BODY: usize = 16 * 1024 * 1024;
35
36fn auth_mechanism_for_inject_mode(mode: &InjectMode) -> nono::undo::NetworkAuditAuthMechanism {
37 match mode {
38 InjectMode::Header | InjectMode::BasicAuth => {
39 nono::undo::NetworkAuditAuthMechanism::PhantomHeader
40 }
41 InjectMode::UrlPath => nono::undo::NetworkAuditAuthMechanism::PhantomPath,
42 InjectMode::QueryParam => nono::undo::NetworkAuditAuthMechanism::PhantomQuery,
43 }
44}
45
46fn audit_injection_mode_for_inject_mode(
47 mode: &InjectMode,
48) -> nono::undo::NetworkAuditInjectionMode {
49 match mode {
50 InjectMode::Header => nono::undo::NetworkAuditInjectionMode::Header,
51 InjectMode::UrlPath => nono::undo::NetworkAuditInjectionMode::UrlPath,
52 InjectMode::QueryParam => nono::undo::NetworkAuditInjectionMode::QueryParam,
53 InjectMode::BasicAuth => nono::undo::NetworkAuditInjectionMode::BasicAuth,
54 }
55}
56
57fn proxy_auth_event_ctx<'a>(route_id: &'a str) -> audit::EventContext<'a> {
58 audit::EventContext {
59 route_id: Some(route_id),
60 auth_mechanism: Some(nono::undo::NetworkAuditAuthMechanism::ProxyAuthorization),
61 ..audit::EventContext::default()
62 }
63}
64
65fn managed_credential_event_ctx<'a>(
66 route_id: &'a str,
67 proxy_mode: &InjectMode,
68 inject_mode: nono::undo::NetworkAuditInjectionMode,
69) -> audit::EventContext<'a> {
70 audit::EventContext {
71 route_id: Some(route_id),
72 auth_mechanism: Some(auth_mechanism_for_inject_mode(proxy_mode)),
73 managed_credential_active: Some(true),
74 injection_mode: Some(inject_mode),
75 ..audit::EventContext::default()
76 }
77}
78
/// Borrowed, shared state needed by [`handle_reverse_proxy`] to serve one request.
pub struct ReverseProxyCtx<'a> {
    /// Route table; looked up by the service name taken from the first path segment.
    pub route_store: &'a RouteStore,
    /// Static and OAuth2 credentials, looked up by the same service name.
    pub credential_store: &'a CredentialStore,
    /// Session (phantom) token that clients must present; compared against the
    /// token found in the request header, URL path, or query string.
    pub session_token: &'a Zeroizing<String>,
    /// Host filter consulted (via `check_host`) before any upstream connection.
    pub filter: &'a ProxyFilter,
    /// Default TLS connector, used when the route does not provide its own.
    pub tls_connector: &'a TlsConnector,
    /// Audit sink passed to the `audit` helpers; presumably `None` means no audit
    /// output is produced — confirm in `audit::log_denied`.
    pub audit_log: Option<&'a audit::SharedAuditLog>,
}
98
99pub async fn handle_reverse_proxy(
113 first_line: &str,
114 stream: &mut TcpStream,
115 remaining_header: &[u8],
116 ctx: &ReverseProxyCtx<'_>,
117 buffered_body: &[u8],
118) -> Result<()> {
119 let (method, path, version) = parse_request_line(first_line)?;
121 debug!("Reverse proxy: {} {}", method, path);
122
123 let (service, upstream_path) = parse_service_prefix(&path)?;
125 let route = ctx
126 .route_store
127 .get(&service)
128 .ok_or_else(|| ProxyError::UnknownService {
129 prefix: service.clone(),
130 })?;
131 let static_cred = ctx.credential_store.get(&service);
132 let oauth2_route = ctx.credential_store.get_oauth2(&service);
133 let managed_ctx = static_cred.map(|cred| {
134 managed_credential_event_ctx(
135 &service,
136 &cred.proxy_inject_mode,
137 audit_injection_mode_for_inject_mode(&cred.inject_mode),
138 )
139 });
140 let oauth2_ctx = oauth2_route.map(|_| audit::EventContext {
141 route_id: Some(&service),
142 auth_mechanism: Some(nono::undo::NetworkAuditAuthMechanism::PhantomHeader),
143 managed_credential_active: Some(true),
144 injection_mode: Some(nono::undo::NetworkAuditInjectionMode::OAuth2),
145 ..audit::EventContext::default()
146 });
147 let route_ctx = managed_ctx
148 .clone()
149 .or_else(|| oauth2_ctx.clone())
150 .unwrap_or_else(|| audit::EventContext {
151 route_id: Some(&service),
152 managed_credential_active: Some(false),
153 ..audit::EventContext::default()
154 });
155
156 if route.missing_managed_credential(static_cred.is_some(), oauth2_route.is_some()) {
157 let reason = format!(
158 "managed credential unavailable for service '{}': route is configured for proxy-supplied auth",
159 service
160 );
161 warn!("{}", reason);
162 let deny_ctx = audit::EventContext {
163 route_id: Some(&service),
164 auth_mechanism: route.managed_auth_mechanism.clone(),
165 auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Failed),
166 managed_credential_active: Some(false),
167 injection_mode: route.managed_injection_mode.clone(),
168 denial_category: Some(
169 nono::undo::NetworkAuditDenialCategory::ManagedCredentialUnavailable,
170 ),
171 };
172 audit::log_denied(
173 ctx.audit_log,
174 audit::ProxyMode::Reverse,
175 &deny_ctx,
176 &service,
177 0,
178 &reason,
179 );
180 send_error(stream, 503, "Service Unavailable").await?;
181 return Ok(());
182 }
183
184 if !route.endpoint_rules.is_allowed(&method, &upstream_path) {
187 let reason = format!(
188 "endpoint denied: {} {} on service '{}'",
189 method, upstream_path, service
190 );
191 warn!("{}", reason);
192 let deny_ctx = audit::EventContext {
193 denial_category: Some(nono::undo::NetworkAuditDenialCategory::EndpointPolicy),
194 ..route_ctx.clone()
195 };
196 audit::log_denied(
197 ctx.audit_log,
198 audit::ProxyMode::Reverse,
199 &deny_ctx,
200 &service,
201 0,
202 &reason,
203 );
204 send_error(stream, 403, "Forbidden").await?;
205 return Ok(());
206 }
207
208 if let Some(oauth2_route) = oauth2_route {
209 return handle_oauth2_credential(
210 oauth2_route,
211 route,
212 &service,
213 &upstream_path,
214 &method,
215 &version,
216 stream,
217 remaining_header,
218 buffered_body,
219 ctx,
220 )
221 .await;
222 }
223
224 let cred = static_cred;
225
226 if let Some(cred) = cred {
230 if let Err(e) = validate_phantom_token_for_mode(
231 &cred.proxy_inject_mode,
232 remaining_header,
233 &upstream_path,
234 &cred.proxy_header_name,
235 cred.proxy_path_pattern.as_deref(),
236 cred.proxy_query_param_name.as_deref(),
237 ctx.session_token,
238 ) {
239 let deny_ctx = audit::EventContext {
240 auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Failed),
241 denial_category: Some(nono::undo::NetworkAuditDenialCategory::AuthenticationFailed),
242 ..managed_ctx.clone().unwrap_or_else(|| route_ctx.clone())
243 };
244 audit::log_denied(
245 ctx.audit_log,
246 audit::ProxyMode::Reverse,
247 &deny_ctx,
248 &service,
249 0,
250 &e.to_string(),
251 );
252 send_error(stream, 401, "Unauthorized").await?;
253 return Ok(());
254 }
255 } else if let Err(e) = token::validate_proxy_auth(remaining_header, ctx.session_token) {
256 let deny_ctx = audit::EventContext {
257 auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Failed),
258 denial_category: Some(nono::undo::NetworkAuditDenialCategory::AuthenticationFailed),
259 ..proxy_auth_event_ctx(&service)
260 };
261 audit::log_denied(
262 ctx.audit_log,
263 audit::ProxyMode::Reverse,
264 &deny_ctx,
265 &service,
266 0,
267 &e.to_string(),
268 );
269 send_error(stream, 407, "Proxy Authentication Required").await?;
270 return Ok(());
271 }
272
273 let transformed_path = if let Some(cred) = cred {
274 let cleaned_path = strip_proxy_artifacts(
275 &upstream_path,
276 &cred.proxy_inject_mode,
277 &cred.inject_mode,
278 cred.proxy_path_pattern.as_deref(),
279 cred.proxy_query_param_name.as_deref(),
280 );
281 transform_path_for_mode(
282 &cred.inject_mode,
283 &cleaned_path,
284 cred.path_pattern.as_deref(),
285 cred.path_replacement.as_deref(),
286 cred.query_param_name.as_deref(),
287 &cred.raw_credential,
288 )?
289 } else {
290 upstream_path.clone()
291 };
292
293 let upstream_url = format!(
294 "{}{}",
295 route.upstream.trim_end_matches('/'),
296 transformed_path
297 );
298 debug!("Forwarding to upstream: {} {}", method, upstream_url);
299
300 let (upstream_scheme, upstream_host, upstream_port, upstream_path_full) =
301 parse_upstream_url(&upstream_url)?;
302 let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
303 if !check.result.is_allowed() {
304 let reason = check.result.reason();
305 warn!("Upstream host denied by filter: {}", reason);
306 send_error(stream, 403, "Forbidden").await?;
307 let deny_ctx = audit::EventContext {
308 denial_category: Some(nono::undo::NetworkAuditDenialCategory::HostDenied),
309 ..route_ctx.clone()
310 };
311 audit::log_denied(
312 ctx.audit_log,
313 audit::ProxyMode::Reverse,
314 &deny_ctx,
315 &service,
316 0,
317 &reason,
318 );
319 return Ok(());
320 }
321 if let Err(reason) =
322 validate_http_upstream_target(upstream_scheme, &upstream_host, &check.resolved_addrs)
323 {
324 warn!("{}", reason);
325 send_error(stream, 502, "Bad Gateway").await?;
326 let deny_ctx = audit::EventContext {
327 denial_category: Some(nono::undo::NetworkAuditDenialCategory::UpstreamConnectFailed),
328 ..route_ctx.clone()
329 };
330 audit::log_denied(
331 ctx.audit_log,
332 audit::ProxyMode::Reverse,
333 &deny_ctx,
334 &service,
335 0,
336 &reason,
337 );
338 return Ok(());
339 }
340
341 let success_ctx = if let Some(ctx) = managed_ctx.clone() {
342 audit::EventContext {
343 auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Succeeded),
344 ..ctx
345 }
346 } else if oauth2_ctx.is_some() {
347 audit::EventContext {
348 auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Succeeded),
349 ..oauth2_ctx.clone().unwrap_or_default()
350 }
351 } else {
352 audit::EventContext {
353 auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Succeeded),
354 managed_credential_active: Some(false),
355 ..proxy_auth_event_ctx(&service)
356 }
357 };
358
359 let strip_header = cred.map(|c| c.proxy_header_name.as_str()).unwrap_or("");
360 let filtered_headers = filter_headers(remaining_header, strip_header);
361 let content_length = extract_content_length(remaining_header);
362 let body = match read_request_body(stream, content_length, buffered_body).await? {
363 Some(body) => body,
364 None => return Ok(()),
365 };
366
367 let upstream_authority = format_host_header(upstream_scheme, &upstream_host, upstream_port);
368 let mut request = Zeroizing::new(format!(
369 "{} {} {}\r\nHost: {}\r\n",
370 method, upstream_path_full, version, upstream_authority
371 ));
372
373 if let Some(cred) = cred {
374 inject_credential_for_mode(cred, &mut request);
375 }
376
377 let auth_header_lower = cred.map(|c| c.header_name.to_lowercase());
378 for (name, value) in &filtered_headers {
379 if let (Some(cred), Some(header_lower)) = (cred, auth_header_lower.as_ref())
380 && matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
381 && name.to_lowercase() == *header_lower
382 {
383 continue;
384 }
385 request.push_str(&format!("{}: {}\r\n", name, value));
386 }
387
388 request.push_str("Connection: close\r\n");
389 if !body.is_empty() {
390 request.push_str(&format!("Content-Length: {}\r\n", body.len()));
391 }
392 request.push_str("\r\n");
393
394 let connector = route.tls_connector.as_ref().unwrap_or(ctx.tls_connector);
395 let upstream_spec = UpstreamSpec {
396 scheme: upstream_scheme,
397 host: &upstream_host,
398 port: upstream_port,
399 strategy: UpstreamStrategy::Direct {
400 resolved_addrs: &check.resolved_addrs,
401 },
402 tls_connector: connector,
403 };
404 let audit_ctx = AuditCtx {
405 log: ctx.audit_log,
406 mode: audit::ProxyMode::Reverse,
407 event_ctx: success_ctx.clone(),
408 target: &service,
409 method: &method,
410 path: &upstream_path,
411 };
412 if let Err(e) =
413 forward::forward_request(stream, request.as_bytes(), &body, upstream_spec, audit_ctx).await
414 {
415 warn!("Upstream connection failed: {}", e);
416 send_error(stream, 502, "Bad Gateway").await?;
417 let deny_ctx = audit::EventContext {
418 denial_category: Some(nono::undo::NetworkAuditDenialCategory::UpstreamConnectFailed),
419 ..success_ctx.clone()
420 };
421 audit::log_denied(
422 ctx.audit_log,
423 audit::ProxyMode::Reverse,
424 &deny_ctx,
425 &service,
426 0,
427 &e.to_string(),
428 );
429 }
430 Ok(())
431}
432
433#[allow(clippy::too_many_arguments)]
440async fn handle_oauth2_credential(
441 oauth2_route: &crate::credential::OAuth2Route,
442 route: &crate::route::LoadedRoute,
443 service: &str,
444 upstream_path: &str,
445 method: &str,
446 version: &str,
447 stream: &mut TcpStream,
448 remaining_header: &[u8],
449 buffered_body: &[u8],
450 ctx: &ReverseProxyCtx<'_>,
451) -> Result<()> {
452 let access_token = oauth2_route.cache.get_or_refresh().await;
454
455 if let Err(e) = validate_phantom_token(remaining_header, "Authorization", ctx.session_token) {
459 let deny_ctx = audit::EventContext {
460 route_id: Some(service),
461 auth_mechanism: Some(nono::undo::NetworkAuditAuthMechanism::PhantomHeader),
462 auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Failed),
463 managed_credential_active: Some(true),
464 injection_mode: Some(nono::undo::NetworkAuditInjectionMode::OAuth2),
465 denial_category: Some(nono::undo::NetworkAuditDenialCategory::AuthenticationFailed),
466 };
467 audit::log_denied(
468 ctx.audit_log,
469 audit::ProxyMode::Reverse,
470 &deny_ctx,
471 service,
472 0,
473 &e.to_string(),
474 );
475 send_error(stream, 401, "Unauthorized").await?;
476 return Ok(());
477 }
478
479 let upstream_url = format!(
480 "{}{}",
481 oauth2_route.upstream.trim_end_matches('/'),
482 upstream_path
483 );
484 debug!("OAuth2 forwarding to upstream: {} {}", method, upstream_url);
485
486 let (upstream_scheme, upstream_host, upstream_port, upstream_path_full) =
487 parse_upstream_url(&upstream_url)?;
488 let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
490 if !check.result.is_allowed() {
491 let reason = check.result.reason();
492 warn!("Upstream host denied by filter: {}", reason);
493 send_error(stream, 403, "Forbidden").await?;
494 let route_ctx = audit::EventContext {
495 route_id: Some(service),
496 managed_credential_active: Some(true),
497 injection_mode: Some(nono::undo::NetworkAuditInjectionMode::OAuth2),
498 denial_category: Some(nono::undo::NetworkAuditDenialCategory::HostDenied),
499 ..audit::EventContext::default()
500 };
501 audit::log_denied(
502 ctx.audit_log,
503 audit::ProxyMode::Reverse,
504 &route_ctx,
505 service,
506 0,
507 &reason,
508 );
509 return Ok(());
510 }
511 if let Err(reason) =
512 validate_http_upstream_target(upstream_scheme, &upstream_host, &check.resolved_addrs)
513 {
514 warn!("{}", reason);
515 send_error(stream, 502, "Bad Gateway").await?;
516 let route_ctx = audit::EventContext {
517 route_id: Some(service),
518 managed_credential_active: Some(true),
519 injection_mode: Some(nono::undo::NetworkAuditInjectionMode::OAuth2),
520 denial_category: Some(nono::undo::NetworkAuditDenialCategory::UpstreamConnectFailed),
521 ..audit::EventContext::default()
522 };
523 audit::log_denied(
524 ctx.audit_log,
525 audit::ProxyMode::Reverse,
526 &route_ctx,
527 service,
528 0,
529 &reason,
530 );
531 return Ok(());
532 }
533
534 let filtered_headers = filter_headers(remaining_header, "Authorization");
537 let content_length = extract_content_length(remaining_header);
538
539 let body = match read_request_body(stream, content_length, buffered_body).await? {
541 Some(body) => body,
542 None => return Ok(()),
543 };
544
545 let upstream_authority = format_host_header(upstream_scheme, &upstream_host, upstream_port);
547 let mut request = Zeroizing::new(format!(
548 "{} {} {}\r\nHost: {}\r\n",
549 method, upstream_path_full, version, upstream_authority
550 ));
551
552 request.push_str(&format!(
554 "Authorization: Bearer {}\r\n",
555 access_token.as_str()
556 ));
557
558 for (name, value) in &filtered_headers {
560 request.push_str(&format!("{}: {}\r\n", name, value));
561 }
562
563 if !body.is_empty() {
564 request.push_str(&format!("Content-Length: {}\r\n", body.len()));
565 }
566 request.push_str("\r\n");
567
568 let connector = route.tls_connector.as_ref().unwrap_or(ctx.tls_connector);
569 let upstream_spec = UpstreamSpec {
570 scheme: upstream_scheme,
571 host: &upstream_host,
572 port: upstream_port,
573 strategy: UpstreamStrategy::Direct {
574 resolved_addrs: &check.resolved_addrs,
575 },
576 tls_connector: connector,
577 };
578 let audit_ctx = AuditCtx {
579 log: ctx.audit_log,
580 mode: audit::ProxyMode::Reverse,
581 event_ctx: audit::EventContext {
582 route_id: Some(service),
583 auth_mechanism: Some(nono::undo::NetworkAuditAuthMechanism::PhantomHeader),
584 auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Succeeded),
585 managed_credential_active: Some(true),
586 injection_mode: Some(nono::undo::NetworkAuditInjectionMode::OAuth2),
587 denial_category: None,
588 },
589 target: service,
590 method,
591 path: upstream_path,
592 };
593 if let Err(e) =
594 forward::forward_request(stream, request.as_bytes(), &body, upstream_spec, audit_ctx).await
595 {
596 warn!("Upstream connection failed: {}", e);
597 send_error(stream, 502, "Bad Gateway").await?;
598 audit::log_denied(
599 ctx.audit_log,
600 audit::ProxyMode::Reverse,
601 &audit::EventContext {
602 route_id: Some(service),
603 auth_mechanism: Some(nono::undo::NetworkAuditAuthMechanism::PhantomHeader),
604 auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Succeeded),
605 managed_credential_active: Some(true),
606 injection_mode: Some(nono::undo::NetworkAuditInjectionMode::OAuth2),
607 denial_category: Some(
608 nono::undo::NetworkAuditDenialCategory::UpstreamConnectFailed,
609 ),
610 },
611 service,
612 0,
613 &e.to_string(),
614 );
615 }
616 Ok(())
617}
618
619pub(crate) async fn read_request_body<S>(
625 stream: &mut S,
626 content_length: Option<usize>,
627 buffered_body: &[u8],
628) -> Result<Option<Vec<u8>>>
629where
630 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
631{
632 if let Some(len) = content_length {
633 if len > MAX_REQUEST_BODY {
634 send_error_generic(stream, 413, "Payload Too Large").await?;
635 return Ok(None);
636 }
637 let mut buf = Vec::with_capacity(len);
638 let pre = buffered_body.len().min(len);
639 buf.extend_from_slice(&buffered_body[..pre]);
640 let remaining = len - pre;
641 if remaining > 0 {
642 let mut rest = vec![0u8; remaining];
643 stream.read_exact(&mut rest).await?;
644 buf.extend_from_slice(&rest);
645 }
646 Ok(Some(buf))
647 } else {
648 Ok(Some(Vec::new()))
649 }
650}
651
652pub(crate) async fn send_error_generic<S>(stream: &mut S, status: u16, reason: &str) -> Result<()>
654where
655 S: tokio::io::AsyncWrite + Unpin,
656{
657 let body = format!("{{\"error\":\"{}\"}}", reason);
658 let response = format!(
659 "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
660 status,
661 reason,
662 body.len(),
663 body
664 );
665 stream.write_all(response.as_bytes()).await?;
666 stream.flush().await?;
667 Ok(())
668}
669
670fn parse_request_line(line: &str) -> Result<(String, String, String)> {
672 let parts: Vec<&str> = line.split_whitespace().collect();
673 if parts.len() < 3 {
674 return Err(ProxyError::HttpParse(format!(
675 "malformed request line: {}",
676 line
677 )));
678 }
679 Ok((
680 parts[0].to_string(),
681 parts[1].to_string(),
682 parts[2].to_string(),
683 ))
684}
685
686fn parse_service_prefix(path: &str) -> Result<(String, String)> {
691 let trimmed = path.strip_prefix('/').unwrap_or(path);
692 if let Some((prefix, rest)) = trimmed.split_once('/') {
693 Ok((prefix.to_string(), format!("/{}", rest)))
694 } else {
695 Ok((trimmed.to_string(), "/".to_string()))
697 }
698}
699
700fn validate_phantom_token(
707 header_bytes: &[u8],
708 header_name: &str,
709 session_token: &Zeroizing<String>,
710) -> Result<()> {
711 let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
712 let header_name_lower = header_name.to_lowercase();
713
714 for line in header_str.lines() {
715 let lower = line.to_lowercase();
716 if lower.starts_with(&format!("{}:", header_name_lower)) {
717 let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
718
719 let value_lower = value.to_lowercase();
722 let token_value = if value_lower.starts_with("bearer ") {
723 value[7..].trim()
725 } else {
726 value
727 };
728
729 if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
730 return Ok(());
731 }
732 warn!("Invalid phantom token in {} header", header_name);
733 return Err(ProxyError::InvalidToken);
734 }
735 }
736
737 warn!(
738 "Missing {} header for phantom token validation",
739 header_name
740 );
741 Err(ProxyError::InvalidToken)
742}
743
/// Returns the client's header lines as trimmed `(name, value)` pairs, minus the headers
/// the proxy must not forward verbatim.
///
/// Always dropped (case-insensitive name match):
/// * `Host`, `Content-Length`, `Connection`: managed by the proxy when it builds the
///   upstream request.
/// * `Transfer-Encoding`: the proxy reads the body using `Content-Length` only and
///   re-frames it with its own `Content-Length`. Forwarding the client's transfer coding
///   would make the upstream frame the body differently from the proxy, which risks
///   request smuggling or an upstream stuck waiting for chunks.
/// * `Proxy-Authorization`: consumed by the proxy.
/// * `cred_header`, when non-empty: the header carrying the phantom token.
///
/// Blank lines and lines without a `:` are skipped. Header bytes that are not valid UTF-8
/// yield an empty list.
pub(crate) fn filter_headers(header_bytes: &[u8], cred_header: &str) -> Vec<(String, String)> {
    // Lower-cased `name:` prefixes that are never forwarded upstream.
    const DROPPED_PREFIXES: [&str; 5] = [
        "host:",
        "content-length:",
        "transfer-encoding:",
        "connection:",
        "proxy-authorization:",
    ];

    let header_str = std::str::from_utf8(header_bytes).unwrap_or("");
    let cred_prefix =
        (!cred_header.is_empty()).then(|| format!("{}:", cred_header.to_lowercase()));

    let mut headers = Vec::new();
    for line in header_str.lines() {
        if line.trim().is_empty() {
            continue;
        }
        let lower = line.to_lowercase();
        let dropped = DROPPED_PREFIXES
            .iter()
            .any(|prefix| lower.starts_with(*prefix))
            || cred_prefix
                .as_deref()
                .is_some_and(|prefix| lower.starts_with(prefix));
        if dropped {
            continue;
        }
        if let Some((name, value)) = line.split_once(':') {
            headers.push((name.trim().to_string(), value.trim().to_string()));
        }
    }
    headers
}
782
/// Parses the first `Content-Length` header (case-insensitive) from raw header bytes.
///
/// Returns `None` when the bytes are not UTF-8, no such header exists, or the first one
/// found does not parse as a `usize`.
pub(crate) fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    let text = std::str::from_utf8(header_bytes).ok()?;
    let header_line = text
        .lines()
        .find(|candidate| candidate.to_lowercase().starts_with("content-length:"))?;
    let (_, raw) = header_line.split_once(':')?;
    raw.trim().parse().ok()
}
794
795fn validate_http_upstream_target(
796 scheme: UpstreamScheme,
797 host: &str,
798 resolved_addrs: &[SocketAddr],
799) -> std::result::Result<(), String> {
800 if matches!(scheme, UpstreamScheme::Https) {
801 return Ok(());
802 }
803
804 if is_local_only_target(host, resolved_addrs) {
805 Ok(())
806 } else {
807 Err(format!(
808 "refusing insecure http upstream for non-local host '{}'; http is only allowed for loopback addresses",
809 host
810 ))
811 }
812}
813
/// Reports whether an upstream target is reachable only via loopback.
///
/// When resolved addresses are available, every one of them must be loopback. Otherwise
/// the host must be a loopback IP literal. IPv6 literals are accepted with or without the
/// surrounding brackets that `url::Url::host_str` keeps (e.g. `[::1]`); the bracketed form
/// previously failed to parse and was misclassified as non-local. Host names without
/// resolved addresses are treated as non-local.
fn is_local_only_target(host: &str, resolved_addrs: &[SocketAddr]) -> bool {
    if !resolved_addrs.is_empty() {
        return resolved_addrs.iter().all(|addr| addr.ip().is_loopback());
    }

    let literal = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    literal
        .parse::<std::net::IpAddr>()
        .is_ok_and(|ip| ip.is_loopback())
}
825
826pub(crate) fn format_host_header(scheme: UpstreamScheme, host: &str, port: u16) -> String {
827 let default_port = match scheme {
828 UpstreamScheme::Http => 80,
829 UpstreamScheme::Https => 443,
830 };
831 let bracketed_host = if host.contains(':') && !host.starts_with('[') {
832 format!("[{}]", host)
833 } else {
834 host.to_string()
835 };
836
837 if port == default_port {
838 bracketed_host
839 } else {
840 format!("{}:{}", bracketed_host, port)
841 }
842}
843
844fn parse_upstream_url(url_str: &str) -> Result<(UpstreamScheme, String, u16, String)> {
845 let parsed = url::Url::parse(url_str)
846 .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
847
848 let scheme = match parsed.scheme() {
849 "https" => UpstreamScheme::Https,
850 "http" => UpstreamScheme::Http,
851 _ => {
852 return Err(ProxyError::HttpParse(format!(
853 "unsupported URL scheme: {}",
854 url_str
855 )));
856 }
857 };
858
859 let host = parsed
860 .host_str()
861 .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
862 .to_string();
863
864 let default_port = if matches!(scheme, UpstreamScheme::Https) {
865 443
866 } else {
867 80
868 };
869 let port = parsed.port().unwrap_or(default_port);
870
871 let path = parsed.path().to_string();
872 let path = if path.is_empty() {
873 "/".to_string()
874 } else {
875 path
876 };
877
878 let path_with_query = if let Some(query) = parsed.query() {
880 format!("{}?{}", path, query)
881 } else {
882 path
883 };
884
885 Ok((scheme, host, port, path_with_query))
886}
887
888async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
890 let body = format!("{{\"error\":\"{}\"}}", reason);
891 let response = format!(
892 "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
893 status,
894 reason,
895 body.len(),
896 body
897 );
898 stream.write_all(response.as_bytes()).await?;
899 stream.flush().await?;
900 Ok(())
901}
902
903pub(crate) fn validate_phantom_token_for_mode(
914 mode: &InjectMode,
915 header_bytes: &[u8],
916 path: &str,
917 header_name: &str,
918 path_pattern: Option<&str>,
919 query_param_name: Option<&str>,
920 session_token: &Zeroizing<String>,
921) -> Result<()> {
922 match mode {
923 InjectMode::Header | InjectMode::BasicAuth => {
924 validate_phantom_token(header_bytes, header_name, session_token)
926 }
927 InjectMode::UrlPath => {
928 let pattern = path_pattern.ok_or_else(|| {
930 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
931 })?;
932 validate_phantom_token_in_path(path, pattern, session_token)
933 }
934 InjectMode::QueryParam => {
935 let param_name = query_param_name.ok_or_else(|| {
937 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
938 })?;
939 validate_phantom_token_in_query(path, param_name, session_token)
940 }
941 }
942}
943
944fn validate_phantom_token_in_path(
949 path: &str,
950 pattern: &str,
951 session_token: &Zeroizing<String>,
952) -> Result<()> {
953 let parts: Vec<&str> = pattern.split("{}").collect();
955 if parts.len() != 2 {
956 return Err(ProxyError::HttpParse(format!(
957 "invalid path_pattern '{}': must contain exactly one {{}}",
958 pattern
959 )));
960 }
961 let (prefix, suffix) = (parts[0], parts[1]);
962
963 if let Some(start) = path.find(prefix) {
965 let after_prefix = start + prefix.len();
966
967 let end_offset = if suffix.is_empty() {
969 path[after_prefix..]
970 .find(['/', '?'])
971 .unwrap_or(path[after_prefix..].len())
972 } else {
973 match path[after_prefix..].find(suffix) {
974 Some(offset) => offset,
975 None => {
976 warn!("Missing phantom token in URL path (pattern: {})", pattern);
977 return Err(ProxyError::InvalidToken);
978 }
979 }
980 };
981
982 let token = &path[after_prefix..after_prefix + end_offset];
983 if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
984 return Ok(());
985 }
986 warn!("Invalid phantom token in URL path");
987 return Err(ProxyError::InvalidToken);
988 }
989
990 warn!("Missing phantom token in URL path (pattern: {})", pattern);
991 Err(ProxyError::InvalidToken)
992}
993
994fn validate_phantom_token_in_query(
996 path: &str,
997 param_name: &str,
998 session_token: &Zeroizing<String>,
999) -> Result<()> {
1000 if let Some(query_start) = path.find('?') {
1002 let query = &path[query_start + 1..];
1003 for pair in query.split('&') {
1004 if let Some((name, value)) = pair.split_once('=')
1005 && name == param_name
1006 {
1007 let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
1008 if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
1009 return Ok(());
1010 }
1011 warn!("Invalid phantom token in query parameter '{}'", param_name);
1012 return Err(ProxyError::InvalidToken);
1013 }
1014 }
1015 }
1016
1017 warn!("Missing phantom token in query parameter '{}'", param_name);
1018 Err(ProxyError::InvalidToken)
1019}
1020
1021pub(crate) fn transform_path_for_mode(
1027 mode: &InjectMode,
1028 path: &str,
1029 path_pattern: Option<&str>,
1030 path_replacement: Option<&str>,
1031 query_param_name: Option<&str>,
1032 credential: &Zeroizing<String>,
1033) -> Result<String> {
1034 match mode {
1035 InjectMode::Header | InjectMode::BasicAuth => {
1036 Ok(path.to_string())
1038 }
1039 InjectMode::UrlPath => {
1040 let pattern = path_pattern.ok_or_else(|| {
1041 ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
1042 })?;
1043 let replacement = path_replacement.unwrap_or(pattern);
1044 transform_url_path(path, pattern, replacement, credential)
1045 }
1046 InjectMode::QueryParam => {
1047 let param_name = query_param_name.ok_or_else(|| {
1048 ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
1049 })?;
1050 transform_query_param(path, param_name, credential)
1051 }
1052 }
1053}
1054
1055fn transform_url_path(
1059 path: &str,
1060 pattern: &str,
1061 replacement: &str,
1062 credential: &Zeroizing<String>,
1063) -> Result<String> {
1064 let parts: Vec<&str> = pattern.split("{}").collect();
1066 if parts.len() != 2 {
1067 return Err(ProxyError::HttpParse(format!(
1068 "invalid path_pattern '{}': must contain exactly one {{}}",
1069 pattern
1070 )));
1071 }
1072 let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
1073
1074 let repl_parts: Vec<&str> = replacement.split("{}").collect();
1076 if repl_parts.len() != 2 {
1077 return Err(ProxyError::HttpParse(format!(
1078 "invalid path_replacement '{}': must contain exactly one {{}}",
1079 replacement
1080 )));
1081 }
1082 let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
1083
1084 if let Some(start) = path.find(pattern_prefix) {
1086 let after_prefix = start + pattern_prefix.len();
1087
1088 let end_offset = if pattern_suffix.is_empty() {
1090 path[after_prefix..]
1092 .find(['/', '?'])
1093 .unwrap_or(path[after_prefix..].len())
1094 } else {
1095 match path[after_prefix..].find(pattern_suffix) {
1097 Some(offset) => offset,
1098 None => {
1099 return Err(ProxyError::HttpParse(format!(
1100 "path '{}' does not match pattern '{}'",
1101 path, pattern
1102 )));
1103 }
1104 }
1105 };
1106
1107 let before = &path[..start];
1108 let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
1109 return Ok(format!(
1110 "{}{}{}{}{}",
1111 before,
1112 repl_prefix,
1113 credential.as_str(),
1114 repl_suffix,
1115 after
1116 ));
1117 }
1118
1119 Err(ProxyError::HttpParse(format!(
1120 "path '{}' does not match pattern '{}'",
1121 path, pattern
1122 )))
1123}
1124
1125fn transform_query_param(
1127 path: &str,
1128 param_name: &str,
1129 credential: &Zeroizing<String>,
1130) -> Result<String> {
1131 let encoded_value = urlencoding::encode(credential.as_str());
1132
1133 if let Some(query_start) = path.find('?') {
1134 let base_path = &path[..query_start];
1135 let query = &path[query_start + 1..];
1136
1137 let mut found = false;
1139 let new_query: Vec<String> = query
1140 .split('&')
1141 .map(|pair| {
1142 if let Some((name, _)) = pair.split_once('=')
1143 && name == param_name
1144 {
1145 found = true;
1146 return format!("{}={}", param_name, encoded_value);
1147 }
1148 pair.to_string()
1149 })
1150 .collect();
1151
1152 if found {
1153 Ok(format!("{}?{}", base_path, new_query.join("&")))
1154 } else {
1155 Ok(format!(
1157 "{}?{}&{}={}",
1158 base_path, query, param_name, encoded_value
1159 ))
1160 }
1161 } else {
1162 Ok(format!("{}?{}={}", path, param_name, encoded_value))
1164 }
1165}
1166
1167pub(crate) fn strip_proxy_artifacts(
1178 path: &str,
1179 proxy_mode: &InjectMode,
1180 upstream_mode: &InjectMode,
1181 proxy_path_pattern: Option<&str>,
1182 proxy_query_param_name: Option<&str>,
1183) -> String {
1184 if proxy_mode == upstream_mode {
1187 return path.to_string();
1188 }
1189
1190 match proxy_mode {
1191 InjectMode::UrlPath => {
1192 if let Some(pattern) = proxy_path_pattern {
1193 strip_proxy_path_token(path, pattern)
1194 } else {
1195 path.to_string()
1196 }
1197 }
1198 InjectMode::QueryParam => {
1199 if let Some(param_name) = proxy_query_param_name {
1200 strip_proxy_query_param(path, param_name)
1201 } else {
1202 path.to_string()
1203 }
1204 }
1205 InjectMode::Header | InjectMode::BasicAuth => path.to_string(),
1207 }
1208}
1209
/// Cuts the `pattern`-matched phantom-token span out of `path` and re-joins the pieces.
///
/// The pattern must contain exactly one `{}`; otherwise, or when nothing matches, `path` is
/// returned unchanged. After removal, a doubled `/` at the seam is collapsed and a missing
/// `/` between two non-empty pieces is inserted. The result always starts with `/`.
fn strip_proxy_path_token(path: &str, pattern: &str) -> String {
    let mut pieces = pattern.split("{}");
    let (Some(prefix), Some(suffix), None) = (pieces.next(), pieces.next(), pieces.next()) else {
        return path.to_string();
    };

    let Some(start) = path.find(prefix) else {
        return path.to_string();
    };
    let tail = &path[start + prefix.len()..];
    let token_len = if suffix.is_empty() {
        tail.find(['/', '?']).unwrap_or(tail.len())
    } else {
        match tail.find(suffix) {
            Some(offset) => offset,
            None => return path.to_string(),
        }
    };

    let before = &path[..start];
    let after = &tail[token_len + suffix.len()..];
    let before_slash = before.ends_with('/');
    let after_slash = after.starts_with('/');

    let joined = if before_slash && after_slash {
        format!("{}{}", before, &after[1..])
    } else if !before_slash && !after_slash && !before.is_empty() && !after.is_empty() {
        format!("{}/{}", before, after)
    } else {
        format!("{}{}", before, after)
    };

    if joined.starts_with('/') {
        joined
    } else {
        format!("/{}", joined)
    }
}
1265
/// Drops every `param_name=...` pair from the query string of `path`.
///
/// Pairs without an `=` and pairs with any other name are kept in their
/// original order. If nothing survives, the `?` is dropped as well. A path
/// with no query string is returned unchanged. Names are compared exactly
/// (case-sensitive, no percent-decoding).
fn strip_proxy_query_param(path: &str, param_name: &str) -> String {
    let Some((base_path, query)) = path.split_once('?') else {
        return path.to_string();
    };

    let kept = query
        .split('&')
        .filter(|pair| !matches!(pair.split_once('='), Some((name, _)) if name == param_name))
        .collect::<Vec<_>>();

    if kept.is_empty() {
        base_path.to_string()
    } else {
        format!("{base_path}?{}", kept.join("&"))
    }
}
1292
1293pub(crate) fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
1298 match cred.inject_mode {
1299 InjectMode::Header | InjectMode::BasicAuth => {
1300 request.push_str(&format!(
1302 "{}: {}\r\n",
1303 cred.header_name,
1304 cred.header_value.as_str()
1305 ));
1306 }
1307 InjectMode::UrlPath | InjectMode::QueryParam => {
1308 }
1311 }
1312}
1313
// Unit tests for this module's helpers, covering:
// - request-line parsing,
// - phantom-token validation (header, path, and query modes),
// - header filtering and Content-Length extraction,
// - upstream URL and target handling,
// - credential path/query rewriting and proxy-artifact stripping.
#[cfg(test)]
// Unwrapping is acceptable in tests: a panic is the intended failure signal.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // --- Request line and service-prefix parsing ---

    #[test]
    fn test_parse_request_line() {
        let (method, path, version) = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
        assert_eq!(method, "POST");
        assert_eq!(path, "/openai/v1/chat");
        assert_eq!(version, "HTTP/1.1");
    }

    #[test]
    fn test_parse_request_line_malformed() {
        assert!(parse_request_line("GET").is_err());
    }

    #[test]
    fn test_parse_service_prefix() {
        let (service, path) = parse_service_prefix("/openai/v1/chat/completions").unwrap();
        assert_eq!(service, "openai");
        assert_eq!(path, "/v1/chat/completions");
    }

    // A bare service prefix maps to the upstream root path "/".
    #[test]
    fn test_parse_service_prefix_no_subpath() {
        let (service, path) = parse_service_prefix("/anthropic").unwrap();
        assert_eq!(service, "anthropic");
        assert_eq!(path, "/");
    }

    // --- Phantom-token validation in headers ---

    // "Bearer <token>" in Authorization is accepted.
    #[test]
    fn test_validate_phantom_token_bearer_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_bearer_invalid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
    }

    // Custom API-key headers carry the raw token (no "Bearer " prefix).
    #[test]
    fn test_validate_phantom_token_x_api_key_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-api-key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_x_goog_api_key_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-goog-api-key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_missing() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Content-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
    }

    // Header-name matching ignores case.
    #[test]
    fn test_validate_phantom_token_case_insensitive_header() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"AUTHORIZATION: Bearer secret123\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
    }

    // --- Header filtering ---

    // Host and the named auth header are dropped; other headers keep their order.
    #[test]
    fn test_filter_headers_removes_host_auth() {
        let header = b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n";
        let filtered = filter_headers(header, "Authorization");
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered[0].0, "Content-Type");
        assert_eq!(filtered[1].0, "Accept");
    }

    #[test]
    fn test_filter_headers_removes_x_api_key() {
        let header = b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n";
        let filtered = filter_headers(header, "x-api-key");
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "Content-Type");
    }

    #[test]
    fn test_filter_headers_removes_custom_header() {
        let header = b"PRIVATE-TOKEN: phantom123\r\nContent-Type: application/json\r\n\r\n";
        let filtered = filter_headers(header, "PRIVATE-TOKEN");
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "Content-Type");
    }

    // --- Content-Length extraction ---

    #[test]
    fn test_extract_content_length() {
        let header = b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n";
        assert_eq!(extract_content_length(header), Some(42));
    }

    #[test]
    fn test_extract_content_length_missing() {
        let header = b"Content-Type: application/json\r\n\r\n";
        assert_eq!(extract_content_length(header), None);
    }

    // --- Upstream URL parsing ---

    #[test]
    fn test_parse_upstream_url_https() {
        let (scheme, host, port, path) =
            parse_upstream_url("https://api.openai.com/v1/chat/completions").unwrap();
        assert_eq!(scheme, UpstreamScheme::Https);
        assert_eq!(host, "api.openai.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_upstream_url_http_with_port() {
        let (scheme, host, port, path) = parse_upstream_url("http://localhost:8080/api").unwrap();
        assert_eq!(scheme, UpstreamScheme::Http);
        assert_eq!(host, "localhost");
        assert_eq!(port, 8080);
        assert_eq!(path, "/api");
    }

    // A URL without a path defaults to "/".
    #[test]
    fn test_parse_upstream_url_no_path() {
        let (scheme, host, port, path) = parse_upstream_url("https://api.anthropic.com").unwrap();
        assert_eq!(scheme, UpstreamScheme::Https);
        assert_eq!(host, "api.anthropic.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/");
    }

    #[test]
    fn test_parse_upstream_url_invalid_scheme() {
        assert!(parse_upstream_url("ftp://example.com").is_err());
    }

    // --- Plain-HTTP upstream target validation (loopback-only policy) ---

    #[test]
    fn test_validate_http_upstream_target_rejects_non_local_host() {
        let err = validate_http_upstream_target(UpstreamScheme::Http, "api.example.com", &[])
            .expect_err("non-local http upstream should be rejected");
        assert!(err.contains("refusing insecure http upstream"));
    }

    // Literal loopback hosts pass without resolution; "localhost" passes when it
    // resolves to loopback.
    #[test]
    fn test_validate_http_upstream_target_allows_loopback() {
        let loopback = [SocketAddr::from(([127, 0, 0, 1], 8080))];
        assert!(validate_http_upstream_target(UpstreamScheme::Http, "127.0.0.1", &[]).is_ok());
        assert!(validate_http_upstream_target(UpstreamScheme::Http, "::1", &[]).is_ok());
        assert!(
            validate_http_upstream_target(UpstreamScheme::Http, "localhost", &loopback).is_ok()
        );
    }

    #[test]
    fn test_validate_http_upstream_target_rejects_unspecified_addresses() {
        let unspecified = [SocketAddr::from(([0, 0, 0, 0], 8080))];
        let err = validate_http_upstream_target(UpstreamScheme::Http, "0.0.0.0", &[])
            .expect_err("unspecified http upstream should be rejected");
        assert!(err.contains("loopback addresses"));

        let err = validate_http_upstream_target(UpstreamScheme::Http, "localhost", &unspecified)
            .expect_err("localhost resolving to unspecified should be rejected");
        assert!(err.contains("loopback addresses"));
    }

    // Guards against a "localhost" name that resolves to a non-loopback address.
    #[test]
    fn test_validate_http_upstream_target_rejects_localhost_resolving_non_loopback() {
        let poisoned = [SocketAddr::from(([203, 0, 113, 10], 8080))];
        let err = validate_http_upstream_target(UpstreamScheme::Http, "localhost", &poisoned)
            .expect_err("localhost resolving off-host should be rejected");
        assert!(err.contains("refusing insecure http upstream"));
    }

    // --- Host header formatting ---

    #[test]
    fn test_format_host_header_uses_port_for_non_default_http() {
        assert_eq!(
            format_host_header(UpstreamScheme::Http, "localhost", 8080),
            "localhost:8080"
        );
    }

    #[test]
    fn test_format_host_header_omits_default_https_port() {
        assert_eq!(
            format_host_header(UpstreamScheme::Https, "api.openai.com", 443),
            "api.openai.com"
        );
    }

    #[test]
    fn test_format_host_header_brackets_ipv6() {
        assert_eq!(
            format_host_header(UpstreamScheme::Http, "::1", 8080),
            "[::1]:8080"
        );
    }

    // --- URL-path mode: validation and rewriting ---

    #[test]
    fn test_validate_phantom_token_in_path_valid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/bot/session123/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_path_invalid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/bot/wrong_token/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_path_missing() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
    }

    #[test]
    fn test_transform_url_path_basic() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/bot/phantom_token/getMe";
        let pattern = "/bot/{}/";
        let replacement = "/bot/{}/";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/bot/real_token/getMe");
    }

    // The replacement template may differ from the matched pattern.
    #[test]
    fn test_transform_url_path_different_replacement() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/api/v1/phantom_token/chat";
        let pattern = "/api/v1/{}/";
        let replacement = "/v2/bot/{}/";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/v2/bot/real_token/chat");
    }

    #[test]
    fn test_transform_url_path_no_trailing_slash() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/bot/phantom_token";
        let pattern = "/bot/{}";
        let replacement = "/bot/{}";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/bot/real_token");
    }

    // --- Query-parameter mode: validation and rewriting ---

    #[test]
    fn test_validate_phantom_token_in_query_valid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?api_key=session123&other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_query_invalid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?api_key=wrong_token";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_missing_param() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_no_query_string() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    // Percent-encoded query values are decoded before comparison.
    #[test]
    fn test_validate_phantom_token_in_query_url_encoded() {
        let token = Zeroizing::new("token with spaces".to_string());
        let path = "/api/data?api_key=token%20with%20spaces";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
    }

    #[test]
    fn test_transform_query_param_add_to_no_query() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_add_to_existing_query() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?other=value";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?other=value&api_key=real_key");
    }

    // An existing pair with the same name is overwritten in place.
    #[test]
    fn test_transform_query_param_replace_existing() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?api_key=phantom&other=value";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=real_key&other=value");
    }

    #[test]
    fn test_transform_query_param_url_encodes_special_chars() {
        let credential = Zeroizing::new("key with spaces".to_string());
        let path = "/api/data";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=key%20with%20spaces");
    }

    // --- Proxy-side vs upstream-side mode independence ---

    // Validation follows the proxy mode (Header), so a bad query value is ignored.
    #[test]
    fn test_validate_phantom_token_uses_proxy_mode_over_upstream_mode() {
        let token = Zeroizing::new("session123".to_string());
        let header = b"Authorization: Bearer session123\r\n\r\n";
        let path = "/api/data?api_key=wrong";

        let result = validate_phantom_token_for_mode(
            &InjectMode::Header,
            header,
            path,
            "Authorization",
            None,
            Some("api_key"),
            &token,
        );

        assert!(result.is_ok());
    }

    #[test]
    fn test_transform_path_uses_upstream_mode_independently() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?api_key=phantom";

        let transformed = transform_path_for_mode(
            &InjectMode::QueryParam,
            path,
            None,
            None,
            Some("api_key"),
            &credential,
        )
        .expect("query-param transform should succeed");

        assert_eq!(transformed, "/api/data?api_key=real_key");
    }

    // --- Stripping the phantom token from the URL path ---

    #[test]
    fn test_strip_proxy_path_token_basic() {
        let result = strip_proxy_path_token("/PHANTOM123/api/v1/pods", "/{}/");
        assert_eq!(result, "/api/v1/pods");
    }

    #[test]
    fn test_strip_proxy_path_token_nested_pattern() {
        let result = strip_proxy_path_token("/auth/PHANTOM123/api/v1/pods", "/auth/{}/");
        assert_eq!(result, "/api/v1/pods");
    }

    // Removing the whole path still leaves the root "/".
    #[test]
    fn test_strip_proxy_path_token_no_trailing_slash() {
        let result = strip_proxy_path_token("/PHANTOM123", "/{}");
        assert_eq!(result, "/");
    }

    #[test]
    fn test_strip_proxy_path_token_preserves_query() {
        let result = strip_proxy_path_token("/PHANTOM123/api?limit=10", "/{}/");
        assert_eq!(result, "/api?limit=10");
    }

    #[test]
    fn test_strip_proxy_path_token_no_match() {
        let result = strip_proxy_path_token("/api/v1/pods", "/auth/{}/");
        assert_eq!(result, "/api/v1/pods");
    }

    // A token in the middle of the path is removed; the halves are rejoined with one "/".
    #[test]
    fn test_strip_proxy_path_token_mid_path_slash_join() {
        let result = strip_proxy_path_token("/api/k8s/PHANTOM/data", "/k8s/{}/");
        assert_eq!(result, "/api/data");
    }

    #[test]
    fn test_strip_proxy_path_token_no_double_slash() {
        let result = strip_proxy_path_token("/prefix/PHANTOM//suffix", "/prefix/{}/");
        assert_eq!(result, "/suffix");
    }

    // --- Stripping the phantom token from the query string ---

    // When the token is the only pair, the "?" is dropped too.
    #[test]
    fn test_strip_proxy_query_param_only_param() {
        let result = strip_proxy_query_param("/api/v1/pods?token=PHANTOM123", "token");
        assert_eq!(result, "/api/v1/pods");
    }

    #[test]
    fn test_strip_proxy_query_param_with_other_params() {
        let result = strip_proxy_query_param("/api/v1/pods?token=PHANTOM123&limit=10", "token");
        assert_eq!(result, "/api/v1/pods?limit=10");
    }

    #[test]
    fn test_strip_proxy_query_param_middle() {
        let result =
            strip_proxy_query_param("/api/v1/pods?limit=10&token=PHANTOM123&watch=true", "token");
        assert_eq!(result, "/api/v1/pods?limit=10&watch=true");
    }

    #[test]
    fn test_strip_proxy_query_param_no_match() {
        let result = strip_proxy_query_param("/api/v1/pods?limit=10", "token");
        assert_eq!(result, "/api/v1/pods?limit=10");
    }

    #[test]
    fn test_strip_proxy_query_param_no_query_string() {
        let result = strip_proxy_query_param("/api/v1/pods", "token");
        assert_eq!(result, "/api/v1/pods");
    }

    // --- strip_proxy_artifacts mode dispatch ---

    // Same proxy and upstream mode: the path is left untouched.
    #[test]
    fn test_strip_proxy_artifacts_same_mode_noop() {
        let path = "/PHANTOM123/api/v1/pods";
        let result = strip_proxy_artifacts(
            path,
            &InjectMode::UrlPath,
            &InjectMode::UrlPath,
            Some("/{}/"),
            None,
        );
        assert_eq!(result, path);
    }

    #[test]
    fn test_strip_proxy_artifacts_url_path_to_header() {
        let result = strip_proxy_artifacts(
            "/PHANTOM123/api/v1/pods",
            &InjectMode::UrlPath,
            &InjectMode::Header,
            Some("/{}/"),
            None,
        );
        assert_eq!(result, "/api/v1/pods");
    }

    #[test]
    fn test_strip_proxy_artifacts_query_param_to_header() {
        let result = strip_proxy_artifacts(
            "/api/v1/pods?token=PHANTOM123",
            &InjectMode::QueryParam,
            &InjectMode::Header,
            None,
            Some("token"),
        );
        assert_eq!(result, "/api/v1/pods");
    }

    // Header-based proxy auth leaves nothing in the path to strip.
    #[test]
    fn test_strip_proxy_artifacts_header_to_query_param() {
        let path = "/api/v1/pods";
        let result = strip_proxy_artifacts(
            path,
            &InjectMode::Header,
            &InjectMode::QueryParam,
            None,
            None,
        );
        assert_eq!(result, path);
    }

    // --- End-to-end: validate -> strip -> transform ---

    #[test]
    fn test_end_to_end_url_path_proxy_header_upstream() {
        let token = Zeroizing::new("session456".to_string());
        let credential = Zeroizing::new("real_bearer_token".to_string());
        let path = "/session456/api/v1/namespaces";

        assert!(
            validate_phantom_token_for_mode(
                &InjectMode::UrlPath,
                b"\r\n\r\n",
                path,
                "Authorization",
                Some("/{}/"),
                None,
                &token,
            )
            .is_ok()
        );

        let cleaned = strip_proxy_artifacts(
            path,
            &InjectMode::UrlPath,
            &InjectMode::Header,
            Some("/{}/"),
            None,
        );
        assert_eq!(cleaned, "/api/v1/namespaces");

        // Header-mode upstream leaves the cleaned path as-is.
        let transformed =
            transform_path_for_mode(&InjectMode::Header, &cleaned, None, None, None, &credential)
                .unwrap();
        assert_eq!(transformed, "/api/v1/namespaces");
    }

    #[test]
    fn test_end_to_end_query_param_proxy_header_upstream() {
        let token = Zeroizing::new("session789".to_string());
        let credential = Zeroizing::new("real_bearer_token".to_string());
        let path = "/api/v1/pods?token=session789&limit=100";

        assert!(
            validate_phantom_token_for_mode(
                &InjectMode::QueryParam,
                b"\r\n\r\n",
                path,
                "Authorization",
                None,
                Some("token"),
                &token,
            )
            .is_ok()
        );

        let cleaned = strip_proxy_artifacts(
            path,
            &InjectMode::QueryParam,
            &InjectMode::Header,
            None,
            Some("token"),
        );
        assert_eq!(cleaned, "/api/v1/pods?limit=100");

        // Header-mode upstream leaves the cleaned path as-is.
        let transformed =
            transform_path_for_mode(&InjectMode::Header, &cleaned, None, None, None, &credential)
                .unwrap();
        assert_eq!(transformed, "/api/v1/pods?limit=100");
    }
}