1use crate::acme::CertManager;
7use crate::config::ProxyConfig;
8use crate::error::{ProxyError, Result};
9use crate::lb::LoadBalancer;
10use crate::network_policy::NetworkPolicyChecker;
11use crate::routes::{transform_path, ResolvedService, ServiceRegistry};
12use bytes::Bytes;
13use http::{header, Request, Response, Uri, Version};
14use http_body_util::{BodyExt, Full};
15use hyper::body::Incoming;
16use hyper::upgrade::OnUpgrade;
17use hyper_util::client::legacy::Client;
18use hyper_util::rt::{TokioExecutor, TokioIo};
19use std::collections::VecDeque;
20use std::net::{IpAddr, SocketAddr};
21use std::sync::Arc;
22use std::task::{Context, Poll};
23use std::time::{Duration, Instant};
24use tokio::net::TcpStream;
25use tokio::sync::Mutex;
26use tower::Service;
27use tracing::{debug, error, info, warn};
28use zlayer_spec::ExposeType;
29
/// Upper bound on how long a request waits for a backend to appear after
/// the activator has been invoked.
const ACTIVATE_DEADLINE: Duration = Duration::from_millis(30_000);

/// Delay between successive load-balancer polls while waiting for activation.
const ACTIVATE_POLL_STEP: Duration = Duration::from_millis(200);

/// Width of the sliding window the request-rate tracker counts over.
const RPS_WINDOW: Duration = Duration::from_millis(10_000);
43
44#[async_trait::async_trait]
58pub trait Activator: Send + Sync {
59 async fn activate(&self, service: &str) -> std::result::Result<(), String>;
71}
72
73#[derive(Debug, Default)]
85pub struct RpsRegistry {
86 services: Mutex<std::collections::HashMap<String, VecDeque<Instant>>>,
88}
89
90impl RpsRegistry {
91 #[must_use]
93 pub fn new() -> Self {
94 Self::default()
95 }
96
97 pub async fn record(&self, service: &str) {
100 let now = Instant::now();
101 let cutoff = now.checked_sub(RPS_WINDOW).unwrap_or(now);
102 let mut map = self.services.lock().await;
103 let ring = map.entry(service.to_string()).or_default();
104 ring.push_back(now);
105 while ring.front().is_some_and(|t| *t < cutoff) {
106 ring.pop_front();
107 }
108 }
109
110 pub async fn rps(&self, service: &str) -> f64 {
113 let now = Instant::now();
114 let cutoff = now.checked_sub(RPS_WINDOW).unwrap_or(now);
115 let mut map = self.services.lock().await;
116 let Some(ring) = map.get_mut(service) else {
117 return 0.0;
118 };
119 while ring.front().is_some_and(|t| *t < cutoff) {
120 ring.pop_front();
121 }
122 let count = ring.len();
123 #[allow(clippy::cast_precision_loss)]
124 {
125 count as f64 / RPS_WINDOW.as_secs_f64()
126 }
127 }
128
129 pub async fn snapshot(&self) -> std::collections::HashMap<String, f64> {
132 let now = Instant::now();
133 let cutoff = now.checked_sub(RPS_WINDOW).unwrap_or(now);
134 let window_secs = RPS_WINDOW.as_secs_f64();
135 let mut map = self.services.lock().await;
136 let mut out = std::collections::HashMap::with_capacity(map.len());
137 for (name, ring) in map.iter_mut() {
138 while ring.front().is_some_and(|t| *t < cutoff) {
139 ring.pop_front();
140 }
141 #[allow(clippy::cast_precision_loss)]
142 let rps = ring.len() as f64 / window_secs;
143 out.insert(name.clone(), rps);
144 }
145 out
146 }
147}
148
/// First two octets of the IPv4 overlay network (`10.200.0.0/16`).
const OVERLAY_NETWORK: (u8, u8) = (10, 200);

/// Returns `true` when `ip` lies inside the overlay `/16`.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are unwrapped first, because
/// dual-stack listeners report IPv4 peers in that form. Every other IPv6
/// address is treated as outside the overlay.
fn is_overlay_ip(ip: IpAddr) -> bool {
    let v4 = match ip {
        IpAddr::V4(v4) => v4,
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(v4) => v4,
            None => return false,
        },
    };
    let [first, second, ..] = v4.octets();
    (first, second) == OVERLAY_NETWORK
}
163
164pub type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
166
167#[must_use]
169pub fn empty_body() -> BoxBody {
170 http_body_util::Empty::<Bytes>::new()
171 .map_err(|never| match never {})
172 .boxed()
173}
174
175pub fn full_body(bytes: impl Into<Bytes>) -> BoxBody {
177 Full::new(bytes.into())
178 .map_err(|never| match never {})
179 .boxed()
180}
181
182#[derive(Clone)]
184pub struct ReverseProxyService {
185 registry: Arc<ServiceRegistry>,
187 load_balancer: Arc<LoadBalancer>,
189 client: Client<hyper_util::client::legacy::connect::HttpConnector, BoxBody>,
191 config: Arc<ProxyConfig>,
193 remote_addr: Option<SocketAddr>,
195 is_tls: bool,
197 cert_manager: Option<Arc<CertManager>>,
199 network_policy_checker: Option<NetworkPolicyChecker>,
201 trusted_proxies: Arc<crate::trust::TrustedProxyList>,
206 activator: Option<Arc<dyn Activator>>,
210 rps_registry: Option<Arc<RpsRegistry>>,
214}
215
216impl ReverseProxyService {
217 pub fn new(
219 registry: Arc<ServiceRegistry>,
220 load_balancer: Arc<LoadBalancer>,
221 config: Arc<ProxyConfig>,
222 ) -> Self {
223 let client = Client::builder(TokioExecutor::new())
224 .pool_max_idle_per_host(config.pool.max_idle_per_backend)
225 .pool_idle_timeout(config.pool.idle_timeout)
226 .pool_timer(hyper_util::rt::TokioTimer::new())
227 .build_http();
228
229 Self {
230 registry,
231 load_balancer,
232 client,
233 config,
234 remote_addr: None,
235 is_tls: false,
236 cert_manager: None,
237 network_policy_checker: None,
238 trusted_proxies: Arc::new(crate::trust::TrustedProxyList::localhost_only()),
239 activator: None,
240 rps_registry: None,
241 }
242 }
243
244 #[must_use]
246 pub fn with_remote_addr(mut self, addr: SocketAddr) -> Self {
247 self.remote_addr = Some(addr);
248 self
249 }
250
251 #[must_use]
253 pub fn with_tls(mut self, is_tls: bool) -> Self {
254 self.is_tls = is_tls;
255 self
256 }
257
258 #[must_use]
263 pub fn with_trusted_proxies(mut self, trusted: Arc<crate::trust::TrustedProxyList>) -> Self {
264 self.trusted_proxies = trusted;
265 self
266 }
267
268 #[must_use]
270 pub fn with_cert_manager(mut self, cm: Arc<CertManager>) -> Self {
271 self.cert_manager = Some(cm);
272 self
273 }
274
275 #[must_use]
277 pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
278 self.network_policy_checker = Some(checker);
279 self
280 }
281
282 #[must_use]
289 pub fn with_activator(mut self, activator: Arc<dyn Activator>) -> Self {
290 self.activator = Some(activator);
291 self
292 }
293
294 #[must_use]
299 pub fn with_rps_registry(mut self, rps_registry: Arc<RpsRegistry>) -> Self {
300 self.rps_registry = Some(rps_registry);
301 self
302 }
303
304 #[must_use]
306 pub fn is_tls(&self) -> bool {
307 self.is_tls
308 }
309
310 #[allow(clippy::too_many_lines)]
322 pub async fn proxy_request(&self, mut req: Request<Incoming>) -> Result<Response<BoxBody>> {
323 let start = std::time::Instant::now();
324 let method = req.method().clone();
325 let uri = req.uri().clone();
326
327 let host = req
328 .headers()
329 .get(header::HOST)
330 .and_then(|h| h.to_str().ok())
331 .or_else(|| uri.host())
332 .map(std::string::ToString::to_string);
333
334 let path = uri.path().to_string();
335
336 if let Some(token) = path.strip_prefix("/.well-known/acme-challenge/") {
343 if !token.is_empty() {
344 if let Some(ref cm) = self.cert_manager {
345 if let Some(auth) = cm.get_challenge_response(token) {
346 return Ok(Response::builder()
347 .status(200)
348 .header("content-type", "text/plain")
349 .body(full_body(auth))
350 .unwrap());
351 }
352 }
353 }
354 tracing::warn!(
355 token = %token,
356 cert_manager = self.cert_manager.is_some(),
357 host = host.as_deref().unwrap_or("<none>"),
358 "ACME HTTP-01 challenge token not found; returning 404"
359 );
360 return Ok(Response::builder()
361 .status(404)
362 .header("content-type", "text/plain")
363 .body(full_body("ACME challenge token not found"))
364 .unwrap());
365 }
366
367 if crate::tunnel::is_upgrade_request(&req) {
369 let resolved = self
371 .registry
372 .resolve(host.as_deref(), &path)
373 .await
374 .ok_or_else(|| ProxyError::RouteNotFound {
375 host: host.as_deref().unwrap_or("<none>").to_string(),
376 path: path.clone(),
377 })?;
378
379 if resolved.expose == ExposeType::Internal {
381 if let Some(addr) = self.remote_addr {
382 if !is_overlay_ip(addr.ip()) {
383 return Err(ProxyError::Forbidden(
384 "endpoint is internal-only".to_string(),
385 ));
386 }
387 }
388 }
389
390 if let (Some(checker), Some(addr)) = (&self.network_policy_checker, self.remote_addr) {
392 if !checker
393 .check_access(addr.ip(), &resolved.name, "*", resolved.target_port)
394 .await
395 {
396 return Err(ProxyError::Forbidden(format!(
397 "network policy denied access to service '{}'",
398 resolved.name
399 )));
400 }
401 }
402
403 let backend = self
404 .select_or_activate(&resolved.name)
405 .await
406 .ok_or_else(|| ProxyError::NoHealthyBackends {
407 service: resolved.name.clone(),
408 })?;
409 let _guard = backend.track_connection();
410 let backend_addr = backend.addr;
411
412 if let Some(rps) = &self.rps_registry {
414 rps.record(&resolved.name).await;
415 }
416
417 info!(
418 method = %method,
419 host = ?host,
420 path = %path,
421 backend = %backend_addr,
422 service = %resolved.name,
423 "Forwarding upgrade request"
424 );
425
426 let client_upgrade: OnUpgrade = hyper::upgrade::on(&mut req);
428
429 let original_path = req.uri().path();
431 let transformed_path =
432 transform_path(&resolved.path_prefix, original_path, resolved.strip_prefix);
433 let new_uri = format!(
434 "http://{}{}{}",
435 backend_addr,
436 transformed_path,
437 req.uri()
438 .query()
439 .map(|q| format!("?{q}"))
440 .unwrap_or_default()
441 );
442
443 let (orig_parts, _body) = req.into_parts();
445 let mut backend_parts = http::request::Builder::new()
446 .method(orig_parts.method.clone())
447 .uri(
448 new_uri
449 .parse::<Uri>()
450 .map_err(|e| ProxyError::InvalidRequest(format!("Invalid URI: {e}")))?,
451 )
452 .body(())
453 .unwrap()
454 .into_parts()
455 .0;
456
457 for (name, value) in &orig_parts.headers {
459 backend_parts.headers.insert(name.clone(), value.clone());
460 }
461
462 crate::tunnel::copy_upgrade_headers(&orig_parts, &mut backend_parts);
464
465 self.add_forwarding_headers(&mut backend_parts);
467
468 let tcp_stream = TcpStream::connect(backend_addr).await.map_err(|e| {
470 error!(error = %e, backend = %backend_addr, "Backend upgrade connect failed");
471 ProxyError::BackendConnectionFailed {
472 backend: backend_addr,
473 reason: e.to_string(),
474 }
475 })?;
476 let io = TokioIo::new(tcp_stream);
477
478 let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
480 .preserve_header_case(true)
481 .handshake(io)
482 .await
483 .map_err(|e| {
484 error!(error = %e, backend = %backend_addr, "Backend upgrade handshake failed");
485 ProxyError::BackendRequestFailed(format!("Upgrade handshake failed: {e}"))
486 })?;
487
488 tokio::spawn(async move {
490 if let Err(e) = conn.with_upgrades().await {
491 error!(error = %e, "Backend upgrade connection driver error");
492 }
493 });
494
495 let backend_req =
497 Request::from_parts(backend_parts, http_body_util::Empty::<Bytes>::new());
498 let backend_response = sender.send_request(backend_req).await.map_err(|e| {
499 error!(error = %e, backend = %backend_addr, "Backend upgrade request failed");
500 ProxyError::BackendRequestFailed(e.to_string())
501 })?;
502
503 if backend_response.status() == http::StatusCode::SWITCHING_PROTOCOLS {
504 let server_upgrade: OnUpgrade = hyper::upgrade::on(backend_response);
506
507 let mut resp_builder =
509 Response::builder().status(http::StatusCode::SWITCHING_PROTOCOLS);
510 if let Some(upgrade_val) = orig_parts.headers.get(header::UPGRADE) {
519 resp_builder = resp_builder.header(header::UPGRADE, upgrade_val.clone());
520 }
521 resp_builder = resp_builder.header(header::CONNECTION, "upgrade");
522
523 let client_response = resp_builder.body(empty_body()).map_err(|e| {
524 ProxyError::Internal(format!("Failed to build 101 response: {e}"))
525 })?;
526
527 tokio::spawn(async move {
529 if let Err(e) =
530 crate::tunnel::proxy_upgrade(client_upgrade, server_upgrade).await
531 {
532 debug!(error = %e, "Upgrade tunnel ended");
533 }
534 });
535
536 let (mut parts, body) = client_response.into_parts();
538 if let Ok(hv) = format!("proxy;dur={}", start.elapsed().as_millis()).parse() {
539 parts.headers.insert("server-timing", hv);
540 }
541
542 return Ok(Response::from_parts(parts, body));
543 }
544
545 let (mut parts, body) = backend_response.into_parts();
547 let streaming_body: BoxBody = body.map_err(|e: hyper::Error| e).boxed();
548
549 if self.is_tls && self.config.headers.hsts {
551 let value = if self.config.headers.hsts_subdomains {
552 format!(
553 "max-age={}; includeSubDomains",
554 self.config.headers.hsts_max_age
555 )
556 } else {
557 format!("max-age={}", self.config.headers.hsts_max_age)
558 };
559 if let Ok(hv) = value.parse() {
560 parts.headers.insert("strict-transport-security", hv);
561 }
562 }
563
564 if let Ok(hv) = format!("proxy;dur={}", start.elapsed().as_millis()).parse() {
566 parts.headers.insert("server-timing", hv);
567 }
568
569 return Ok(Response::from_parts(parts, streaming_body));
570 }
571
572 debug!(method = %method, host = ?host, path = %path, "Routing request");
573
574 let resolved = self
576 .registry
577 .resolve(host.as_deref(), &path)
578 .await
579 .ok_or_else(|| ProxyError::RouteNotFound {
580 host: host.as_deref().unwrap_or("<none>").to_string(),
581 path: path.clone(),
582 })?;
583
584 if resolved.expose == ExposeType::Internal {
586 match self.remote_addr {
587 Some(addr) if !is_overlay_ip(addr.ip()) => {
588 warn!(
589 source = %addr.ip(),
590 service = %resolved.name,
591 "Rejected non-overlay source for internal endpoint"
592 );
593 return Err(ProxyError::Forbidden(
594 "endpoint is internal-only".to_string(),
595 ));
596 }
597 None => {
598 debug!(
599 service = %resolved.name,
600 "No remote_addr available; skipping overlay source check"
601 );
602 }
603 _ => {}
604 }
605 }
606
607 if let (Some(checker), Some(addr)) = (&self.network_policy_checker, self.remote_addr) {
609 if !checker
610 .check_access(addr.ip(), &resolved.name, "*", resolved.target_port)
611 .await
612 {
613 return Err(ProxyError::Forbidden(format!(
614 "network policy denied access to service '{}'",
615 resolved.name
616 )));
617 }
618 }
619
620 let backend = self
623 .select_or_activate(&resolved.name)
624 .await
625 .ok_or_else(|| ProxyError::NoHealthyBackends {
626 service: resolved.name.clone(),
627 })?;
628 let _guard = backend.track_connection();
629 let backend_addr = backend.addr;
630
631 if let Some(rps) = &self.rps_registry {
633 rps.record(&resolved.name).await;
634 }
635
636 info!(
637 method = %method,
638 host = ?host,
639 path = %path,
640 backend = %backend_addr,
641 service = %resolved.name,
642 "Forwarding request"
643 );
644
645 let forwarded_req = self.build_forwarded_request(req, &backend_addr, &resolved)?;
647
648 let response = self.client.request(forwarded_req).await.map_err(|e| {
650 error!(error = %e, backend = %backend_addr, "Backend request failed");
651 ProxyError::BackendRequestFailed(e.to_string())
652 })?;
653
654 let (mut parts, body) = response.into_parts();
655 let streaming_body: BoxBody = body.map_err(|e: hyper::Error| e).boxed();
656
657 if self.is_tls && self.config.headers.hsts {
659 let value = if self.config.headers.hsts_subdomains {
660 format!(
661 "max-age={}; includeSubDomains",
662 self.config.headers.hsts_max_age
663 )
664 } else {
665 format!("max-age={}", self.config.headers.hsts_max_age)
666 };
667 if let Ok(hv) = value.parse() {
668 parts.headers.insert("strict-transport-security", hv);
669 }
670 }
671
672 if let Ok(hv) = format!("proxy;dur={}", start.elapsed().as_millis()).parse() {
674 parts.headers.insert("server-timing", hv);
675 }
676
677 Ok(Response::from_parts(parts, streaming_body))
678 }
679
680 async fn select_or_activate(&self, service: &str) -> Option<Arc<crate::lb::Backend>> {
697 if let Some(backend) = self.load_balancer.select(service) {
698 return Some(backend);
699 }
700
701 let Some(activator) = &self.activator else {
702 return None;
703 };
704
705 info!(
706 service = %service,
707 "No healthy backend; invoking activator (scale-to-zero wake-up)"
708 );
709 if let Err(e) = activator.activate(service).await {
710 warn!(service = %service, error = %e, "Activator returned an error; will still poll for a backend");
713 }
714
715 let deadline = Instant::now() + ACTIVATE_DEADLINE;
716 loop {
717 if let Some(backend) = self.load_balancer.select(service) {
718 info!(service = %service, "Backend became available after activation");
719 return Some(backend);
720 }
721 if Instant::now() >= deadline {
722 warn!(
723 service = %service,
724 "Activation deadline elapsed without a healthy backend; falling back to 503"
725 );
726 return None;
727 }
728 tokio::time::sleep(ACTIVATE_POLL_STEP).await;
729 }
730 }
731
732 fn build_forwarded_request(
733 &self,
734 req: Request<Incoming>,
735 backend: &SocketAddr,
736 resolved: &ResolvedService,
737 ) -> Result<Request<BoxBody>> {
738 let (mut parts, body) = req.into_parts();
739
740 let original_path = parts.uri.path();
742 let transformed_path =
743 transform_path(&resolved.path_prefix, original_path, resolved.strip_prefix);
744
745 let new_uri = format!(
747 "http://{}{}{}",
748 backend,
749 transformed_path,
750 parts
751 .uri
752 .query()
753 .map(|q| format!("?{q}"))
754 .unwrap_or_default()
755 );
756
757 parts.uri = new_uri
758 .parse::<Uri>()
759 .map_err(|e| ProxyError::InvalidRequest(format!("Invalid URI: {e}")))?;
760
761 self.add_forwarding_headers(&mut parts);
763
764 Self::remove_hop_by_hop_headers(&mut parts);
766
767 let streaming_body: BoxBody = body.map_err(|e: hyper::Error| e).boxed();
768
769 let req = Request::from_parts(parts, streaming_body);
770 Ok(req)
771 }
772
773 fn add_forwarding_headers(&self, parts: &mut http::request::Parts) {
774 let config = &self.config.headers;
775
776 let peer_is_trusted = self
779 .remote_addr
780 .is_some_and(|addr| self.trusted_proxies.is_trusted(addr.ip()));
781
782 let effective_client_ip: Option<IpAddr> = if peer_is_trusted {
787 let cf_ip = parts
788 .headers
789 .get("cf-connecting-ip")
790 .and_then(|h| h.to_str().ok())
791 .and_then(|s| s.trim().parse::<IpAddr>().ok());
792
793 let xff_leftmost = parts
794 .headers
795 .get("x-forwarded-for")
796 .and_then(|h| h.to_str().ok())
797 .and_then(|s| s.split(',').next())
798 .and_then(|s| s.trim().parse::<IpAddr>().ok());
799
800 cf_ip
801 .or(xff_leftmost)
802 .or_else(|| self.remote_addr.map(|a| a.ip()))
803 } else {
804 self.remote_addr.map(|a| a.ip())
805 };
806
807 if config.x_forwarded_for {
809 if let Some(addr) = self.remote_addr {
810 let existing_xff = parts
811 .headers
812 .get("x-forwarded-for")
813 .and_then(|h| h.to_str().ok())
814 .map(std::string::ToString::to_string);
815
816 let new_value = if peer_is_trusted {
817 let real = effective_client_ip.unwrap_or_else(|| addr.ip()).to_string();
821 match existing_xff {
822 Some(chain) if !chain.trim().is_empty() => format!("{real}, {chain}"),
823 _ => real,
824 }
825 } else {
826 match existing_xff {
829 Some(chain) => format!("{}, {}", chain, addr.ip()),
830 None => addr.ip().to_string(),
831 }
832 };
833
834 if let Ok(value) = new_value.parse() {
835 parts.headers.insert("x-forwarded-for", value);
836 }
837 }
838 }
839
840 if config.x_forwarded_proto && parts.headers.get("x-forwarded-proto").is_none() {
842 let proto = if self.is_tls { "https" } else { "http" };
843 if let Ok(value) = proto.parse() {
844 parts.headers.insert("x-forwarded-proto", value);
845 }
846 }
847
848 if config.x_forwarded_host {
850 if let Some(host) = parts.headers.get(header::HOST).cloned() {
851 if parts.headers.get("x-forwarded-host").is_none() {
852 parts.headers.insert("x-forwarded-host", host);
853 }
854 }
855 }
856
857 if config.x_real_ip {
861 if let Some(ip) = effective_client_ip {
862 if parts.headers.get("x-real-ip").is_none() {
863 if let Ok(value) = ip.to_string().parse() {
864 parts.headers.insert("x-real-ip", value);
865 }
866 }
867 }
868 }
869
870 if config.via {
872 let proto_version = match parts.version {
873 Version::HTTP_09 => "0.9",
874 Version::HTTP_10 => "1.0",
875 Version::HTTP_2 => "2.0",
876 Version::HTTP_3 => "3.0",
877 _ => "1.1",
878 };
879
880 let via_value = format!("{} {}", proto_version, config.server_name);
881 let existing = parts
882 .headers
883 .get(header::VIA)
884 .and_then(|h| h.to_str().ok())
885 .map(|s| format!("{s}, {via_value}"))
886 .unwrap_or(via_value);
887
888 if let Ok(value) = existing.parse() {
889 parts.headers.insert(header::VIA, value);
890 }
891 }
892 }
893
894 fn remove_hop_by_hop_headers(parts: &mut http::request::Parts) {
895 const HOP_BY_HOP: &[&str] = &[
897 "connection",
898 "keep-alive",
899 "proxy-authenticate",
900 "proxy-authorization",
901 "te",
902 "trailer",
903 "transfer-encoding",
904 "upgrade",
905 ];
906
907 let connection_headers: Vec<String> = parts
909 .headers
910 .get(header::CONNECTION)
911 .and_then(|h| h.to_str().ok())
912 .map(|value| value.split(',').map(|s| s.trim().to_lowercase()).collect())
913 .unwrap_or_default();
914
915 for header_name in HOP_BY_HOP {
916 parts.headers.remove(*header_name);
917 }
918
919 for header_name in connection_headers {
921 parts.headers.remove(header_name.as_str());
922 }
923 }
924
925 pub fn error_response(error: &ProxyError) -> Response<BoxBody> {
941 let status = error.status_code();
942 let body = status.canonical_reason().map_or_else(
947 || status.as_str().to_string(),
948 |reason| format!("{} {reason}", status.as_u16()),
949 );
950
951 Response::builder()
952 .status(status)
953 .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
954 .body(full_body(body))
955 .unwrap()
956 }
957}
958
959impl Service<Request<Incoming>> for ReverseProxyService {
960 type Response = Response<BoxBody>;
961 type Error = ProxyError;
962 type Future = std::pin::Pin<
963 Box<
964 dyn std::future::Future<Output = std::result::Result<Self::Response, Self::Error>>
965 + Send,
966 >,
967 >;
968
969 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
970 Poll::Ready(Ok(()))
971 }
972
973 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
974 let this = self.clone();
975 Box::pin(async move { this.proxy_request(req).await })
976 }
977}
978
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lb::LbStrategy;
    use crate::trust::TrustedProxyList;
    use std::sync::atomic::{AtomicBool, Ordering};

    // ----- shared helpers -------------------------------------------------

    /// Proxy service with empty routing, the given peer, and the given trust list.
    fn build_svc(peer: SocketAddr, trusted: TrustedProxyList) -> ReverseProxyService {
        ReverseProxyService::new(
            Arc::new(ServiceRegistry::new()),
            Arc::new(LoadBalancer::new()),
            Arc::new(ProxyConfig::default()),
        )
        .with_remote_addr(peer)
        .with_trusted_proxies(Arc::new(trusted))
    }

    /// Trust list covering the 203.0.113.0/24 documentation range.
    fn doc_range_trust() -> TrustedProxyList {
        TrustedProxyList::new(vec!["203.0.113.0/24".parse().unwrap()], None)
    }

    /// `GET /` request parts carrying `headers`.
    fn parts_with_headers(headers: &[(&str, &str)]) -> http::request::Parts {
        headers
            .iter()
            .fold(
                http::request::Builder::new().method("GET").uri("/"),
                |builder, (name, value)| builder.header(*name, *value),
            )
            .body(())
            .unwrap()
            .into_parts()
            .0
    }

    /// Value of header `name` as a string; panics if it is absent.
    fn header_str<'a>(parts: &'a http::request::Parts, name: &str) -> &'a str {
        parts
            .headers
            .get(name)
            .unwrap_or_else(|| panic!("header {name} missing"))
            .to_str()
            .unwrap()
    }

    /// Proxy service over `lb` with default config and no peer.
    fn build_svc_with_lb(lb: Arc<LoadBalancer>) -> ReverseProxyService {
        ReverseProxyService::new(Arc::new(ServiceRegistry::new()), lb, Arc::new(ProxyConfig::default()))
    }

    // ----- error responses ------------------------------------------------

    #[test]
    fn test_error_response() {
        let route_missing = ProxyError::RouteNotFound {
            host: "example.com".to_string(),
            path: "/api".to_string(),
        };
        let resp = ReverseProxyService::error_response(&route_missing);
        assert_eq!(resp.status(), http::StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_forbidden_error_response() {
        let denied = ProxyError::Forbidden("endpoint 'ws' is internal-only".to_string());
        let resp = ReverseProxyService::error_response(&denied);
        assert_eq!(resp.status(), http::StatusCode::FORBIDDEN);
    }

    // ----- hop-by-hop stripping -------------------------------------------

    #[test]
    fn test_hop_by_hop_headers() {
        let mut parts = parts_with_headers(&[
            ("connection", "keep-alive, x-custom"),
            ("keep-alive", "timeout=5"),
            ("x-custom", "value"),
            ("x-other", "value"),
        ]);

        ReverseProxyService::remove_hop_by_hop_headers(&mut parts);

        for gone in ["connection", "keep-alive", "x-custom"] {
            assert!(parts.headers.get(gone).is_none(), "{gone} should be stripped");
        }
        assert!(parts.headers.get("x-other").is_some());
    }

    // ----- overlay source check -------------------------------------------

    #[test]
    fn test_is_overlay_ip_accepts_overlay_range() {
        for ip in ["10.200.0.1", "10.200.255.254", "10.200.1.100"] {
            assert!(is_overlay_ip(ip.parse().unwrap()), "{ip} should be overlay");
        }
    }

    #[test]
    fn test_is_overlay_ip_rejects_non_overlay() {
        for ip in ["192.168.1.1", "10.0.0.1", "10.201.0.1", "172.16.0.1", "8.8.8.8"] {
            assert!(!is_overlay_ip(ip.parse().unwrap()), "{ip} should not be overlay");
        }
    }

    #[test]
    fn test_is_overlay_ip_rejects_ipv6() {
        for ip in ["::1", "fe80::1"] {
            assert!(!is_overlay_ip(ip.parse().unwrap()), "{ip} should not be overlay");
        }
    }

    // ----- forwarding headers ---------------------------------------------

    #[test]
    fn trusted_peer_cf_connecting_ip_is_honored() {
        let svc = build_svc("203.0.113.50:443".parse().unwrap(), doc_range_trust());

        let mut parts = parts_with_headers(&[("cf-connecting-ip", "198.51.100.7")]);
        svc.add_forwarding_headers(&mut parts);

        assert_eq!(header_str(&parts, "x-real-ip"), "198.51.100.7");
        let xff = header_str(&parts, "x-forwarded-for");
        assert!(
            xff.starts_with("198.51.100.7"),
            "XFF should start with real client IP, got {xff}"
        );
    }

    #[test]
    fn trusted_peer_xff_leftmost_is_honored_when_no_cf_header() {
        let svc = build_svc("203.0.113.50:443".parse().unwrap(), doc_range_trust());

        let mut parts = parts_with_headers(&[("x-forwarded-for", "198.51.100.9, 10.0.0.1")]);
        svc.add_forwarding_headers(&mut parts);

        assert_eq!(header_str(&parts, "x-real-ip"), "198.51.100.9");
        let xff = header_str(&parts, "x-forwarded-for");
        assert!(
            xff.starts_with("198.51.100.9"),
            "XFF should start with leftmost real client, got {xff}"
        );
        assert!(
            xff.contains("10.0.0.1"),
            "original chain should survive: {xff}"
        );
    }

    #[test]
    fn untrusted_peer_cf_connecting_ip_is_ignored() {
        let svc = build_svc("8.8.8.8:443".parse().unwrap(), doc_range_trust());

        let mut parts = parts_with_headers(&[("cf-connecting-ip", "198.51.100.7")]);
        svc.add_forwarding_headers(&mut parts);

        assert_eq!(header_str(&parts, "x-real-ip"), "8.8.8.8");
        let xff = header_str(&parts, "x-forwarded-for");
        assert!(
            xff.ends_with("8.8.8.8"),
            "XFF for untrusted peer should end with peer IP, got {xff}"
        );
    }

    #[test]
    fn no_headers_uses_peer_ip() {
        let svc = build_svc(
            "198.51.100.250:443".parse().unwrap(),
            TrustedProxyList::localhost_only(),
        );

        let mut parts = parts_with_headers(&[]);
        svc.add_forwarding_headers(&mut parts);

        assert_eq!(header_str(&parts, "x-real-ip"), "198.51.100.250");
        assert_eq!(header_str(&parts, "x-forwarded-for"), "198.51.100.250");
    }

    // ----- RPS registry ---------------------------------------------------

    #[tokio::test]
    async fn rps_registry_counts_recorded_requests() {
        let reg = RpsRegistry::new();
        assert!(reg.rps("svc").await.abs() < f64::EPSILON);

        let requests: u32 = 30;
        for _ in 0..requests {
            reg.record("svc").await;
        }
        let want = f64::from(requests) / RPS_WINDOW.as_secs_f64();
        let got = reg.rps("svc").await;
        assert!((got - want).abs() < 1e-9, "expected {want}, got {got}");
    }

    #[tokio::test]
    async fn rps_registry_isolates_services() {
        let reg = RpsRegistry::new();
        for name in ["a", "a", "b"] {
            reg.record(name).await;
        }

        let snap = reg.snapshot().await;
        let rate_a = snap.get("a").copied().unwrap_or_default();
        let rate_b = snap.get("b").copied().unwrap_or_default();
        assert!(
            rate_a > rate_b,
            "service a ({rate_a}) should have a higher rate than b ({rate_b})"
        );
        assert!((rate_b - 1.0 / RPS_WINDOW.as_secs_f64()).abs() < 1e-9);
    }

    #[tokio::test]
    async fn rps_registry_prunes_old_timestamps() {
        let reg = RpsRegistry::new();
        {
            let stale = Instant::now()
                .checked_sub(RPS_WINDOW + Duration::from_secs(5))
                .expect("instant underflow in test");
            let mut map = reg.services.lock().await;
            map.entry("svc".to_string()).or_default().push_back(stale);
        }
        assert!(reg.rps("svc").await.abs() < f64::EPSILON);
    }

    // ----- activation -----------------------------------------------------

    /// Activator that registers a backend on the shared load balancer when called.
    struct TestActivator {
        lb: Arc<LoadBalancer>,
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl Activator for TestActivator {
        async fn activate(&self, service: &str) -> std::result::Result<(), String> {
            self.called.store(true, Ordering::SeqCst);
            self.lb.register(
                service,
                vec!["127.0.0.1:9".parse().unwrap()],
                LbStrategy::RoundRobin,
            );
            Ok(())
        }
    }

    #[tokio::test]
    async fn select_or_activate_returns_none_without_activator() {
        let svc = build_svc_with_lb(Arc::new(LoadBalancer::new()));
        assert!(svc.select_or_activate("svc").await.is_none());
    }

    #[tokio::test]
    async fn select_or_activate_wakes_scaled_to_zero_service() {
        let lb = Arc::new(LoadBalancer::new());
        lb.register("svc", vec![], LbStrategy::RoundRobin);

        let activator = Arc::new(TestActivator {
            lb: Arc::clone(&lb),
            called: AtomicBool::new(false),
        });
        let svc = build_svc_with_lb(Arc::clone(&lb)).with_activator(activator.clone());

        assert!(
            svc.select_or_activate("svc").await.is_some(),
            "activator should have produced a backend"
        );
        assert!(
            activator.called.load(Ordering::SeqCst),
            "activator must have been invoked"
        );
    }

    #[tokio::test]
    async fn select_or_activate_returns_existing_backend_without_calling_activator() {
        let lb = Arc::new(LoadBalancer::new());
        lb.register(
            "svc",
            vec!["127.0.0.1:9".parse().unwrap()],
            LbStrategy::RoundRobin,
        );
        let activator = Arc::new(TestActivator {
            lb: Arc::clone(&lb),
            called: AtomicBool::new(false),
        });
        let svc = build_svc_with_lb(Arc::clone(&lb)).with_activator(activator.clone());

        assert!(svc.select_or_activate("svc").await.is_some());
        assert!(
            !activator.called.load(Ordering::SeqCst),
            "activator must NOT be called when a backend already exists"
        );
    }
}