1use axum::body::Body;
2#[cfg(all(feature = "tls", not(feature = "native-tls")))]
3use hyper_rustls::HttpsConnector;
4#[cfg(feature = "native-tls")]
5use hyper_tls::HttpsConnector as NativeTlsHttpsConnector;
6use hyper_util::client::legacy::{
7 Client,
8 connect::{Connect, HttpConnector},
9};
10use std::collections::HashMap;
11use std::convert::Infallible;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
14use std::time::{Duration, Instant};
15
16use tower::discover::{Change, Discover};
17use tracing::{debug, error, trace, warn};
18
19use crate::proxy::ReverseProxy;
20
21use rand::Rng;
23
/// How a balanced proxy picks the upstream for each request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalancingStrategy {
    /// Cycle through the upstreams in order. This is the default.
    #[default]
    RoundRobin,
    /// Power of two choices: sample two distinct upstreams at random and send the request to
    /// the one with fewer requests currently in flight.
    P2cPendingRequests,
    /// Power of two choices: sample two distinct upstreams at random and send the request to
    /// the one with the lower time-decayed peak-EWMA latency.
    P2cPeakEwma,
}
40
41#[derive(Clone)]
42pub struct BalancedProxy<C: Connect + Clone + Send + Sync + 'static> {
43 path: String,
44 proxies: Vec<ReverseProxy<C>>,
45 counter: Arc<AtomicUsize>,
46}
47
48#[cfg(all(feature = "tls", not(feature = "native-tls")))]
49pub type StandardBalancedProxy = BalancedProxy<HttpsConnector<HttpConnector>>;
50#[cfg(feature = "native-tls")]
51pub type StandardBalancedProxy = BalancedProxy<NativeTlsHttpsConnector<HttpConnector>>;
52#[cfg(all(not(feature = "tls"), not(feature = "native-tls")))]
53pub type StandardBalancedProxy = BalancedProxy<HttpConnector>;
54
55impl StandardBalancedProxy {
56 pub fn new<S>(path: S, targets: Vec<S>) -> Self
57 where
58 S: Into<String> + Clone,
59 {
60 let mut connector = HttpConnector::new();
61 connector.set_nodelay(true);
62 connector.enforce_http(false);
63 connector.set_keepalive(Some(std::time::Duration::from_secs(60)));
64 connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
65 connector.set_reuse_address(true);
66
67 #[cfg(all(feature = "tls", not(feature = "native-tls")))]
68 let connector = {
69 use hyper_rustls::HttpsConnectorBuilder;
70 HttpsConnectorBuilder::new()
71 .with_native_roots()
72 .unwrap()
73 .https_or_http()
74 .enable_http1()
75 .wrap_connector(connector)
76 };
77
78 #[cfg(feature = "native-tls")]
79 let connector = NativeTlsHttpsConnector::new_with_connector(connector);
80
81 let client = Client::builder(hyper_util::rt::TokioExecutor::new())
82 .pool_idle_timeout(std::time::Duration::from_secs(60))
83 .pool_max_idle_per_host(32)
84 .retry_canceled_requests(true)
85 .set_host(true)
86 .build(connector);
87
88 Self::new_with_client(path, targets, client)
89 }
90}
91
92impl<C> BalancedProxy<C>
93where
94 C: Connect + Clone + Send + Sync + 'static,
95{
96 pub fn new_with_client<S>(path: S, targets: Vec<S>, client: Client<C, Body>) -> Self
97 where
98 S: Into<String> + Clone,
99 {
100 let path = path.into();
101 let proxies = targets
102 .into_iter()
103 .map(|t| ReverseProxy::new_with_client(path.clone(), t.into(), client.clone()))
104 .collect();
105
106 Self {
107 path,
108 proxies,
109 counter: Arc::new(AtomicUsize::new(0)),
110 }
111 }
112
113 pub fn path(&self) -> &str {
114 &self.path
115 }
116
117 fn next_proxy(&self) -> Option<ReverseProxy<C>> {
118 if self.proxies.is_empty() {
119 None
120 } else {
121 let idx = self.counter.fetch_add(1, Ordering::Relaxed) % self.proxies.len();
122 Some(self.proxies[idx].clone())
123 }
124 }
125}
126
127use std::{
128 future::Future,
129 pin::Pin,
130 task::{Context, Poll},
131};
132use tower::Service;
133
134impl<C> Service<axum::http::Request<Body>> for BalancedProxy<C>
135where
136 C: Connect + Clone + Send + Sync + 'static,
137{
138 type Response = axum::http::Response<Body>;
139 type Error = Infallible;
140 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
141
142 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
143 Poll::Ready(Ok(()))
144 }
145
146 fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
147 if let Some(mut proxy) = self.next_proxy() {
148 trace!("balanced proxying via upstream {}", proxy.target());
149 Box::pin(async move { proxy.call(req).await })
150 } else {
151 warn!("No upstream services available");
152 Box::pin(async move {
153 Ok(axum::http::Response::builder()
154 .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
155 .body(Body::from("No upstream services available"))
156 .unwrap())
157 })
158 }
159 }
160}
161
162#[derive(Clone)]
173pub struct DiscoverableBalancedProxy<C, D>
174where
175 C: Connect + Clone + Send + Sync + 'static,
176 D: Discover + Clone + Send + Sync + 'static,
177 D::Service: Into<String> + Send,
178 D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
179 D::Error: std::fmt::Debug + Send,
180{
181 path: String,
182 client: Client<C, Body>,
183 proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
184 proxy_keys: Arc<tokio::sync::RwLock<HashMap<D::Key, usize>>>, counter: Arc<AtomicUsize>,
186 discover: D,
187 strategy: LoadBalancingStrategy,
188 p2c_balancer: Option<Arc<CustomP2cBalancer<C>>>,
190}
191
192#[cfg(all(feature = "tls", not(feature = "native-tls")))]
193pub type StandardDiscoverableBalancedProxy<D> =
194 DiscoverableBalancedProxy<HttpsConnector<HttpConnector>, D>;
195#[cfg(feature = "native-tls")]
196pub type StandardDiscoverableBalancedProxy<D> =
197 DiscoverableBalancedProxy<NativeTlsHttpsConnector<HttpConnector>, D>;
198#[cfg(all(not(feature = "tls"), not(feature = "native-tls")))]
199pub type StandardDiscoverableBalancedProxy<D> = DiscoverableBalancedProxy<HttpConnector, D>;
200
201impl<C, D> DiscoverableBalancedProxy<C, D>
202where
203 C: Connect + Clone + Send + Sync + 'static,
204 D: Discover + Clone + Send + Sync + 'static,
205 D::Service: Into<String> + Send,
206 D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
207 D::Error: std::fmt::Debug + Send,
208{
209 pub fn new_with_client<S>(path: S, client: Client<C, Body>, discover: D) -> Self
212 where
213 S: Into<String>,
214 {
215 Self::new_with_client_and_strategy(path, client, discover, LoadBalancingStrategy::default())
216 }
217
218 pub fn new_with_client_and_strategy<S>(
220 path: S,
221 client: Client<C, Body>,
222 discover: D,
223 strategy: LoadBalancingStrategy,
224 ) -> Self
225 where
226 S: Into<String>,
227 {
228 let path = path.into();
229 let proxies_snapshot = Arc::new(std::sync::RwLock::new(Arc::new(Vec::new())));
230
231 let p2c_balancer = match strategy {
233 LoadBalancingStrategy::P2cPendingRequests | LoadBalancingStrategy::P2cPeakEwma => {
234 Some(Arc::new(CustomP2cBalancer::new(
235 strategy,
236 Arc::clone(&proxies_snapshot),
237 )))
238 }
239 LoadBalancingStrategy::RoundRobin => None,
240 };
241
242 Self {
243 path,
244 client,
245 proxies_snapshot,
246 proxy_keys: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
247 counter: Arc::new(AtomicUsize::new(0)),
248 discover: discover.clone(),
249 strategy,
250 p2c_balancer,
251 }
252 }
253
254 pub fn path(&self) -> &str {
256 &self.path
257 }
258
259 pub fn strategy(&self) -> LoadBalancingStrategy {
261 self.strategy
262 }
263
264 pub async fn start_discovery(&mut self) {
267 let discover = self.discover.clone();
268 let proxies_snapshot = Arc::clone(&self.proxies_snapshot);
269 let proxy_keys = Arc::clone(&self.proxy_keys);
270 let client = self.client.clone();
271 let path = self.path.clone();
272
273 tokio::spawn(async move {
274 use futures_util::future::poll_fn;
275
276 let mut discover = Box::pin(discover);
277
278 loop {
279 let change_result =
280 poll_fn(|cx: &mut Context<'_>| discover.as_mut().poll_discover(cx)).await;
281
282 match change_result {
283 Some(Ok(change)) => match change {
284 Change::Insert(key, service) => {
285 let target: String = service.into();
286 debug!("Discovered new service: {:?} -> {}", key, target);
287
288 let proxy =
289 ReverseProxy::new_with_client(path.clone(), target, client.clone());
290
291 {
292 let mut keys_guard = proxy_keys.write().await;
293
294 let current_snapshot = {
296 let snapshot_guard = proxies_snapshot.read().unwrap();
297 Arc::clone(&*snapshot_guard)
298 };
299
300 let mut new_proxies = (*current_snapshot).clone();
301 let index = new_proxies.len();
302 new_proxies.push(proxy);
303 keys_guard.insert(key, index);
304
305 {
307 let mut snapshot_guard = proxies_snapshot.write().unwrap();
308 *snapshot_guard = Arc::new(new_proxies);
309 }
310 }
311 }
312 Change::Remove(key) => {
313 debug!("Removing service: {:?}", key);
314
315 {
316 let mut keys_guard = proxy_keys.write().await;
317
318 if let Some(index) = keys_guard.remove(&key) {
319 let current_snapshot = {
321 let snapshot_guard = proxies_snapshot.read().unwrap();
322 Arc::clone(&*snapshot_guard)
323 };
324
325 let mut new_proxies = (*current_snapshot).clone();
326 new_proxies.remove(index);
327
328 for (_, idx) in keys_guard.iter_mut() {
330 if *idx > index {
331 *idx -= 1;
332 }
333 }
334
335 {
337 let mut snapshot_guard = proxies_snapshot.write().unwrap();
338 *snapshot_guard = Arc::new(new_proxies);
339 }
340 }
341 }
342 }
343 },
344 Some(Err(e)) => {
345 error!("Discovery error: {:?}", e);
346 }
347 None => {
348 warn!("Discovery stream ended");
349 break;
350 }
351 }
352 }
353 });
354 }
355
356 pub async fn service_count(&self) -> usize {
358 let snapshot = {
359 let guard = self.proxies_snapshot.read().unwrap();
360 Arc::clone(&*guard)
361 };
362 snapshot.len()
363 }
364}
365
366impl<C, D> Service<axum::http::Request<Body>> for DiscoverableBalancedProxy<C, D>
367where
368 C: Connect + Clone + Send + Sync + 'static,
369 D: Discover + Clone + Send + Sync + 'static,
370 D::Service: Into<String> + Send,
371 D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
372 D::Error: std::fmt::Debug + Send,
373{
374 type Response = axum::http::Response<Body>;
375 type Error = Infallible;
376 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
377
378 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
379 Poll::Ready(Ok(()))
380 }
381
382 fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
383 let proxies_snapshot = {
385 let guard = self.proxies_snapshot.read().unwrap();
386 Arc::clone(&*guard)
387 };
388 let counter = Arc::clone(&self.counter);
389 let strategy = self.strategy;
390 let p2c_balancer = self.p2c_balancer.clone();
391
392 Box::pin(async move {
393 match strategy {
394 LoadBalancingStrategy::RoundRobin => {
395 if proxies_snapshot.is_empty() {
397 warn!("No upstream services available");
398 Ok(axum::http::Response::builder()
399 .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
400 .body(Body::from("No upstream services available"))
401 .unwrap())
402 } else {
403 let idx = counter.fetch_add(1, Ordering::Relaxed) % proxies_snapshot.len();
404 let mut proxy = proxies_snapshot[idx].clone();
405 proxy.call(req).await
406 }
407 }
408 LoadBalancingStrategy::P2cPendingRequests | LoadBalancingStrategy::P2cPeakEwma => {
409 if let Some(balancer) = p2c_balancer {
411 balancer.call_with_p2c(req).await
412 } else {
413 error!("P2C balancer not available for strategy {:?}", strategy);
415 Ok(axum::http::Response::builder()
416 .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
417 .body(Body::from("P2C balancer not available"))
418 .unwrap())
419 }
420 }
421 }
422 })
423 }
424}
425
426struct CustomP2cBalancer<C: Connect + Clone + Send + Sync + 'static> {
428 strategy: LoadBalancingStrategy,
429 proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
430 metrics: Arc<std::sync::RwLock<Arc<Vec<Arc<ServiceMetrics>>>>>,
433}
434
435impl<C: Connect + Clone + Send + Sync + 'static> CustomP2cBalancer<C> {
436 fn new(
437 strategy: LoadBalancingStrategy,
438 proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
439 ) -> Self {
440 let initial_metrics = Arc::new(Vec::new());
441 Self {
442 strategy,
443 proxies_snapshot,
444 metrics: Arc::new(std::sync::RwLock::new(initial_metrics)),
445 }
446 }
447
448 async fn call_with_p2c(
449 &self,
450 req: axum::http::Request<Body>,
451 ) -> Result<axum::http::Response<Body>, Infallible> {
452 let proxies = {
454 let guard = self.proxies_snapshot.read().unwrap();
455 Arc::clone(&*guard)
456 };
457
458 if proxies.is_empty() {
459 return Ok(axum::http::Response::builder()
460 .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
461 .body(Body::from("No upstream services available"))
462 .unwrap());
463 }
464
465 self.ensure_metrics_size(proxies.len());
467
468 let metrics = {
470 let guard = self.metrics.read().unwrap();
471 Arc::clone(&*guard)
472 };
473
474 let selected_idx = if proxies.len() == 1 {
476 0
477 } else {
478 let mut rng = rand::rng();
479 let idx1 = rng.random_range(0..proxies.len());
480 let idx2 = loop {
481 let i = rng.random_range(0..proxies.len());
482 if i != idx1 {
483 break i;
484 }
485 };
486
487 let load1 = self.get_load(&metrics[idx1]);
489 let load2 = self.get_load(&metrics[idx2]);
490
491 if load1 <= load2 { idx1 } else { idx2 }
492 };
493
494 let request_guard = if matches!(self.strategy, LoadBalancingStrategy::P2cPendingRequests) {
496 metrics[selected_idx]
497 .pending_requests
498 .fetch_add(1, Ordering::Relaxed);
499 Some(PendingRequestGuard {
500 metrics: Arc::clone(&metrics[selected_idx]),
501 })
502 } else {
503 None
504 };
505
506 let start = Instant::now();
508
509 let mut proxy = proxies[selected_idx].clone();
511 let result = proxy.call(req).await;
512
513 if matches!(self.strategy, LoadBalancingStrategy::P2cPeakEwma) {
515 let latency = start.elapsed();
516 self.update_ewma(&metrics[selected_idx], latency);
517 }
518
519 drop(request_guard);
521
522 result
523 }
524
525 fn ensure_metrics_size(&self, size: usize) {
526 let mut metrics_guard = self.metrics.write().unwrap();
527 let current_metrics = Arc::clone(&*metrics_guard);
528
529 if current_metrics.len() != size {
530 let mut new_metrics = Vec::with_capacity(size);
531
532 for (i, metric) in current_metrics.iter().enumerate() {
534 if i < size {
535 new_metrics.push(Arc::clone(metric));
536 }
537 }
538
539 while new_metrics.len() < size {
541 new_metrics.push(Arc::new(ServiceMetrics::new()));
542 }
543
544 *metrics_guard = Arc::new(new_metrics);
545 }
546 }
547
548 fn get_load(&self, metrics: &ServiceMetrics) -> u64 {
549 match self.strategy {
550 LoadBalancingStrategy::P2cPendingRequests => {
551 metrics.pending_requests.load(Ordering::Relaxed) as u64
552 }
553 LoadBalancingStrategy::P2cPeakEwma => {
554 let last_update = *metrics.last_update.lock().unwrap();
556 let elapsed = last_update.elapsed();
557
558 let current = metrics.peak_ewma_micros.load(Ordering::Relaxed);
560 let decay_factor = (-elapsed.as_secs_f64() / 5.0).exp();
561 (current as f64 * decay_factor) as u64
562 }
563 _ => unreachable!("CustomP2cBalancer should only be used with P2C strategies"),
564 }
565 }
566
567 fn update_ewma(&self, metrics: &ServiceMetrics, latency: Duration) {
568 let latency_micros = latency.as_micros() as u64;
569
570 loop {
573 let current = metrics.peak_ewma_micros.load(Ordering::Relaxed);
574
575 if current == 0 {
577 if metrics
578 .peak_ewma_micros
579 .compare_exchange(0, latency_micros, Ordering::Relaxed, Ordering::Relaxed)
580 .is_ok()
581 {
582 *metrics.last_update.lock().unwrap() = Instant::now();
583 break;
584 }
585 continue;
586 }
587
588 let mut last_update_guard = metrics.last_update.lock().unwrap();
590 let elapsed = last_update_guard.elapsed();
591
592 let decay_factor = (-elapsed.as_secs_f64() / 5.0).exp();
594 let decayed_current = (current as f64 * decay_factor) as u64;
595
596 let peak = decayed_current.max(latency_micros);
598
599 let ewma = ((peak as f64 * 0.25) + (decayed_current as f64 * 0.75)) as u64;
602
603 if metrics
604 .peak_ewma_micros
605 .compare_exchange(current, ewma, Ordering::Relaxed, Ordering::Relaxed)
606 .is_ok()
607 {
608 *last_update_guard = Instant::now();
610 break;
611 }
612 drop(last_update_guard); }
614 }
615}
616
617struct PendingRequestGuard {
619 metrics: Arc<ServiceMetrics>,
620}
621
622impl Drop for PendingRequestGuard {
623 fn drop(&mut self) {
624 self.metrics
625 .pending_requests
626 .fetch_sub(1, Ordering::Relaxed);
627 }
628}
629
/// Per-upstream load figures consulted by the P2C balancer.
#[derive(Debug)]
struct ServiceMetrics {
    /// Requests currently in flight to this upstream (pending-requests strategy).
    pending_requests: AtomicUsize,
    /// Peak-EWMA latency in microseconds (peak-EWMA strategy); 0 means no sample yet.
    peak_ewma_micros: AtomicU64,
    /// When `peak_ewma_micros` was last written; used to decay it over time.
    last_update: std::sync::Mutex<Instant>,
}

impl ServiceMetrics {
    /// Fresh metrics: nothing in flight, no latency sample, timestamp set to now.
    fn new() -> Self {
        let now = Instant::now();
        Self {
            last_update: std::sync::Mutex::new(now),
            peak_ewma_micros: AtomicU64::new(0),
            pending_requests: AtomicUsize::new(0),
        }
    }
}