1use axum::body::Body;
2#[cfg(all(feature = "tls", not(feature = "native-tls")))]
3use hyper_rustls::HttpsConnector;
4#[cfg(feature = "native-tls")]
5use hyper_tls::HttpsConnector as NativeTlsHttpsConnector;
6use hyper_util::client::legacy::{
7 Client,
8 connect::{Connect, HttpConnector},
9};
10use std::collections::HashMap;
11use std::convert::Infallible;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
14use std::time::{Duration, Instant};
15
16use tower::discover::{Change, Discover};
17use tracing::{debug, error, trace, warn};
18
19use crate::proxy::ReverseProxy;
20
21use rand::Rng;
23
/// Strategy used to pick an upstream for each incoming request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalancingStrategy {
    /// Cycle through the upstreams in order using a shared counter. This is the default.
    #[default]
    RoundRobin,
    /// Power-of-two-choices: sample two distinct upstreams at random and use the one
    /// with fewer in-flight requests.
    P2cPendingRequests,
    /// Power-of-two-choices: sample two distinct upstreams at random and use the one
    /// with the lower time-decayed peak-EWMA latency.
    P2cPeakEwma,
}
35
/// Round-robin load balancer over a fixed list of upstream reverse proxies.
///
/// Cloning is cheap-ish and clones share the same rotation counter (it lives behind an `Arc`).
#[derive(Clone)]
pub struct BalancedProxy<C: Connect + Clone + Send + Sync + 'static> {
    /// Path handed to every upstream `ReverseProxy` at construction time.
    path: String,
    /// One reverse proxy per configured target, in the order the targets were supplied.
    proxies: Vec<ReverseProxy<C>>,
    /// Monotonic request counter; `counter % proxies.len()` selects the next upstream.
    counter: Arc<AtomicUsize>,
}
42
/// [`BalancedProxy`] over the connector selected by cargo features: rustls-backed HTTPS
/// when `tls` is enabled and `native-tls` is not.
#[cfg(all(feature = "tls", not(feature = "native-tls")))]
pub type StandardBalancedProxy = BalancedProxy<HttpsConnector<HttpConnector>>;
/// [`BalancedProxy`] over the native-tls HTTPS connector (`native-tls` feature enabled).
#[cfg(feature = "native-tls")]
pub type StandardBalancedProxy = BalancedProxy<NativeTlsHttpsConnector<HttpConnector>>;
/// [`BalancedProxy`] over the plain HTTP connector (neither TLS feature enabled).
#[cfg(all(not(feature = "tls"), not(feature = "native-tls")))]
pub type StandardBalancedProxy = BalancedProxy<HttpConnector>;
49
50impl StandardBalancedProxy {
51 pub fn new<S>(path: S, targets: Vec<S>) -> Self
52 where
53 S: Into<String> + Clone,
54 {
55 let mut connector = HttpConnector::new();
56 connector.set_nodelay(true);
57 connector.enforce_http(false);
58 connector.set_keepalive(Some(std::time::Duration::from_secs(60)));
59 connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
60 connector.set_reuse_address(true);
61
62 #[cfg(all(feature = "tls", not(feature = "native-tls")))]
63 let connector = {
64 use hyper_rustls::HttpsConnectorBuilder;
65 HttpsConnectorBuilder::new()
66 .with_native_roots()
67 .unwrap()
68 .https_or_http()
69 .enable_http1()
70 .wrap_connector(connector)
71 };
72
73 #[cfg(feature = "native-tls")]
74 let connector = NativeTlsHttpsConnector::new_with_connector(connector);
75
76 let client = Client::builder(hyper_util::rt::TokioExecutor::new())
77 .pool_idle_timeout(std::time::Duration::from_secs(60))
78 .pool_max_idle_per_host(32)
79 .retry_canceled_requests(true)
80 .set_host(true)
81 .build(connector);
82
83 Self::new_with_client(path, targets, client)
84 }
85}
86
87impl<C> BalancedProxy<C>
88where
89 C: Connect + Clone + Send + Sync + 'static,
90{
91 pub fn new_with_client<S>(path: S, targets: Vec<S>, client: Client<C, Body>) -> Self
92 where
93 S: Into<String> + Clone,
94 {
95 let path = path.into();
96 let proxies = targets
97 .into_iter()
98 .map(|t| ReverseProxy::new_with_client(path.clone(), t.into(), client.clone()))
99 .collect();
100
101 Self {
102 path,
103 proxies,
104 counter: Arc::new(AtomicUsize::new(0)),
105 }
106 }
107
108 pub fn path(&self) -> &str {
109 &self.path
110 }
111
112 fn next_proxy(&self) -> Option<ReverseProxy<C>> {
113 if self.proxies.is_empty() {
114 None
115 } else {
116 let idx = self.counter.fetch_add(1, Ordering::Relaxed) % self.proxies.len();
117 Some(self.proxies[idx].clone())
118 }
119 }
120}
121
122use std::{
123 future::Future,
124 pin::Pin,
125 task::{Context, Poll},
126};
127use tower::Service;
128
129impl<C> Service<axum::http::Request<Body>> for BalancedProxy<C>
130where
131 C: Connect + Clone + Send + Sync + 'static,
132{
133 type Response = axum::http::Response<Body>;
134 type Error = Infallible;
135 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
136
137 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138 Poll::Ready(Ok(()))
139 }
140
141 fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
142 if let Some(mut proxy) = self.next_proxy() {
143 trace!("balanced proxying via upstream {}", proxy.target());
144 Box::pin(async move { proxy.call(req).await })
145 } else {
146 warn!("No upstream services available");
147 Box::pin(async move {
148 Ok(axum::http::Response::builder()
149 .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
150 .body(Body::from("No upstream services available"))
151 .unwrap())
152 })
153 }
154 }
155}
156
/// Load balancer whose upstream set is driven by a `tower::discover::Discover` stream.
///
/// Clones share the upstream snapshot, the key map, the round-robin counter and the
/// optional P2C balancer, since all of those live behind `Arc`s.
#[derive(Clone)]
pub struct DiscoverableBalancedProxy<C, D>
where
    C: Connect + Clone + Send + Sync + 'static,
    D: Discover + Clone + Send + Sync + 'static,
    D::Service: Into<String> + Send,
    D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
    D::Error: std::fmt::Debug + Send,
{
    /// Path handed to every upstream `ReverseProxy` created by discovery.
    path: String,
    /// HTTP client cloned into each discovered upstream proxy.
    client: Client<C, Body>,
    /// Current upstream list. The discovery task replaces the inner `Arc<Vec<_>>`
    /// wholesale; request handlers clone the inner `Arc` and release the lock.
    proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
    /// Maps each discovery key to its index in the upstream list.
    proxy_keys: Arc<tokio::sync::RwLock<HashMap<D::Key, usize>>>,
    /// Request counter used by the round-robin strategy.
    counter: Arc<AtomicUsize>,
    /// Discovery source; cloned into the background task by `start_discovery`.
    discover: D,
    /// Strategy selected at construction time.
    strategy: LoadBalancingStrategy,
    /// Present only for the P2C strategies; `None` for round robin.
    p2c_balancer: Option<Arc<CustomP2cBalancer<C>>>,
}
186
/// [`DiscoverableBalancedProxy`] over the rustls-backed HTTPS connector
/// (`tls` enabled, `native-tls` not enabled).
#[cfg(all(feature = "tls", not(feature = "native-tls")))]
pub type StandardDiscoverableBalancedProxy<D> =
    DiscoverableBalancedProxy<HttpsConnector<HttpConnector>, D>;
/// [`DiscoverableBalancedProxy`] over the native-tls HTTPS connector (`native-tls` enabled).
#[cfg(feature = "native-tls")]
pub type StandardDiscoverableBalancedProxy<D> =
    DiscoverableBalancedProxy<NativeTlsHttpsConnector<HttpConnector>, D>;
/// [`DiscoverableBalancedProxy`] over the plain HTTP connector (neither TLS feature enabled).
#[cfg(all(not(feature = "tls"), not(feature = "native-tls")))]
pub type StandardDiscoverableBalancedProxy<D> = DiscoverableBalancedProxy<HttpConnector, D>;
195
196impl<C, D> DiscoverableBalancedProxy<C, D>
197where
198 C: Connect + Clone + Send + Sync + 'static,
199 D: Discover + Clone + Send + Sync + 'static,
200 D::Service: Into<String> + Send,
201 D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
202 D::Error: std::fmt::Debug + Send,
203{
204 pub fn new_with_client<S>(path: S, client: Client<C, Body>, discover: D) -> Self
207 where
208 S: Into<String>,
209 {
210 Self::new_with_client_and_strategy(path, client, discover, LoadBalancingStrategy::default())
211 }
212
213 pub fn new_with_client_and_strategy<S>(
215 path: S,
216 client: Client<C, Body>,
217 discover: D,
218 strategy: LoadBalancingStrategy,
219 ) -> Self
220 where
221 S: Into<String>,
222 {
223 let path = path.into();
224 let proxies_snapshot = Arc::new(std::sync::RwLock::new(Arc::new(Vec::new())));
225
226 let p2c_balancer = match strategy {
228 LoadBalancingStrategy::P2cPendingRequests | LoadBalancingStrategy::P2cPeakEwma => {
229 Some(Arc::new(CustomP2cBalancer::new(
230 strategy,
231 Arc::clone(&proxies_snapshot),
232 )))
233 }
234 LoadBalancingStrategy::RoundRobin => None,
235 };
236
237 Self {
238 path,
239 client,
240 proxies_snapshot,
241 proxy_keys: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
242 counter: Arc::new(AtomicUsize::new(0)),
243 discover: discover.clone(),
244 strategy,
245 p2c_balancer,
246 }
247 }
248
249 pub fn path(&self) -> &str {
251 &self.path
252 }
253
254 pub fn strategy(&self) -> LoadBalancingStrategy {
256 self.strategy
257 }
258
259 pub async fn start_discovery(&mut self) {
262 let discover = self.discover.clone();
263 let proxies_snapshot = Arc::clone(&self.proxies_snapshot);
264 let proxy_keys = Arc::clone(&self.proxy_keys);
265 let client = self.client.clone();
266 let path = self.path.clone();
267
268 tokio::spawn(async move {
269 use futures_util::future::poll_fn;
270
271 let mut discover = Box::pin(discover);
272
273 loop {
274 let change_result =
275 poll_fn(|cx: &mut Context<'_>| discover.as_mut().poll_discover(cx)).await;
276
277 match change_result {
278 Some(Ok(change)) => match change {
279 Change::Insert(key, service) => {
280 let target: String = service.into();
281 debug!("Discovered new service: {:?} -> {}", key, target);
282
283 let proxy =
284 ReverseProxy::new_with_client(path.clone(), target, client.clone());
285
286 {
287 let mut keys_guard = proxy_keys.write().await;
288
289 let current_snapshot = {
291 let snapshot_guard = proxies_snapshot.read().unwrap();
292 Arc::clone(&*snapshot_guard)
293 };
294
295 let mut new_proxies = (*current_snapshot).clone();
296 let index = new_proxies.len();
297 new_proxies.push(proxy);
298 keys_guard.insert(key, index);
299
300 {
302 let mut snapshot_guard = proxies_snapshot.write().unwrap();
303 *snapshot_guard = Arc::new(new_proxies);
304 }
305 }
306 }
307 Change::Remove(key) => {
308 debug!("Removing service: {:?}", key);
309
310 {
311 let mut keys_guard = proxy_keys.write().await;
312
313 if let Some(index) = keys_guard.remove(&key) {
314 let current_snapshot = {
316 let snapshot_guard = proxies_snapshot.read().unwrap();
317 Arc::clone(&*snapshot_guard)
318 };
319
320 let mut new_proxies = (*current_snapshot).clone();
321 new_proxies.remove(index);
322
323 for (_, idx) in keys_guard.iter_mut() {
325 if *idx > index {
326 *idx -= 1;
327 }
328 }
329
330 {
332 let mut snapshot_guard = proxies_snapshot.write().unwrap();
333 *snapshot_guard = Arc::new(new_proxies);
334 }
335 }
336 }
337 }
338 },
339 Some(Err(e)) => {
340 error!("Discovery error: {:?}", e);
341 }
342 None => {
343 warn!("Discovery stream ended");
344 break;
345 }
346 }
347 }
348 });
349 }
350
351 pub async fn service_count(&self) -> usize {
353 let snapshot = {
354 let guard = self.proxies_snapshot.read().unwrap();
355 Arc::clone(&*guard)
356 };
357 snapshot.len()
358 }
359}
360
361impl<C, D> Service<axum::http::Request<Body>> for DiscoverableBalancedProxy<C, D>
362where
363 C: Connect + Clone + Send + Sync + 'static,
364 D: Discover + Clone + Send + Sync + 'static,
365 D::Service: Into<String> + Send,
366 D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
367 D::Error: std::fmt::Debug + Send,
368{
369 type Response = axum::http::Response<Body>;
370 type Error = Infallible;
371 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
372
373 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
374 Poll::Ready(Ok(()))
375 }
376
377 fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
378 let proxies_snapshot = {
380 let guard = self.proxies_snapshot.read().unwrap();
381 Arc::clone(&*guard)
382 };
383 let counter = Arc::clone(&self.counter);
384 let strategy = self.strategy;
385 let p2c_balancer = self.p2c_balancer.clone();
386
387 Box::pin(async move {
388 match strategy {
389 LoadBalancingStrategy::RoundRobin => {
390 if proxies_snapshot.is_empty() {
392 warn!("No upstream services available");
393 Ok(axum::http::Response::builder()
394 .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
395 .body(Body::from("No upstream services available"))
396 .unwrap())
397 } else {
398 let idx = counter.fetch_add(1, Ordering::Relaxed) % proxies_snapshot.len();
399 let mut proxy = proxies_snapshot[idx].clone();
400 proxy.call(req).await
401 }
402 }
403 LoadBalancingStrategy::P2cPendingRequests | LoadBalancingStrategy::P2cPeakEwma => {
404 if let Some(balancer) = p2c_balancer {
406 balancer.call_with_p2c(req).await
407 } else {
408 error!("P2C balancer not available for strategy {:?}", strategy);
410 Ok(axum::http::Response::builder()
411 .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
412 .body(Body::from("P2C balancer not available"))
413 .unwrap())
414 }
415 }
416 }
417 })
418 }
419}
420
/// Power-of-two-choices selector used for the two P2C strategies.
struct CustomP2cBalancer<C: Connect + Clone + Send + Sync + 'static> {
    /// Which load signal to compare (pending requests or peak-EWMA latency).
    strategy: LoadBalancingStrategy,
    /// Upstream list handed in at construction; read on every request.
    proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
    /// Per-upstream metrics, positional: entry `i` is used for upstream `i` of the snapshot.
    metrics: Arc<std::sync::RwLock<Arc<Vec<Arc<ServiceMetrics>>>>>,
}
429
430impl<C: Connect + Clone + Send + Sync + 'static> CustomP2cBalancer<C> {
431 fn new(
432 strategy: LoadBalancingStrategy,
433 proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
434 ) -> Self {
435 let initial_metrics = Arc::new(Vec::new());
436 Self {
437 strategy,
438 proxies_snapshot,
439 metrics: Arc::new(std::sync::RwLock::new(initial_metrics)),
440 }
441 }
442
443 async fn call_with_p2c(
444 &self,
445 req: axum::http::Request<Body>,
446 ) -> Result<axum::http::Response<Body>, Infallible> {
447 let proxies = {
449 let guard = self.proxies_snapshot.read().unwrap();
450 Arc::clone(&*guard)
451 };
452
453 if proxies.is_empty() {
454 return Ok(axum::http::Response::builder()
455 .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
456 .body(Body::from("No upstream services available"))
457 .unwrap());
458 }
459
460 self.ensure_metrics_size(proxies.len());
462
463 let metrics = {
465 let guard = self.metrics.read().unwrap();
466 Arc::clone(&*guard)
467 };
468
469 let selected_idx = if proxies.len() == 1 {
471 0
472 } else {
473 let mut rng = rand::rng();
474 let idx1 = rng.random_range(0..proxies.len());
475 let idx2 = loop {
476 let i = rng.random_range(0..proxies.len());
477 if i != idx1 {
478 break i;
479 }
480 };
481
482 let load1 = self.get_load(&metrics[idx1]);
484 let load2 = self.get_load(&metrics[idx2]);
485
486 if load1 <= load2 { idx1 } else { idx2 }
487 };
488
489 let request_guard = if matches!(self.strategy, LoadBalancingStrategy::P2cPendingRequests) {
491 metrics[selected_idx]
492 .pending_requests
493 .fetch_add(1, Ordering::Relaxed);
494 Some(PendingRequestGuard {
495 metrics: Arc::clone(&metrics[selected_idx]),
496 })
497 } else {
498 None
499 };
500
501 let start = Instant::now();
503
504 let mut proxy = proxies[selected_idx].clone();
506 let result = proxy.call(req).await;
507
508 if matches!(self.strategy, LoadBalancingStrategy::P2cPeakEwma) {
510 let latency = start.elapsed();
511 self.update_ewma(&metrics[selected_idx], latency);
512 }
513
514 drop(request_guard);
516
517 result
518 }
519
520 fn ensure_metrics_size(&self, size: usize) {
521 let mut metrics_guard = self.metrics.write().unwrap();
522 let current_metrics = Arc::clone(&*metrics_guard);
523
524 if current_metrics.len() != size {
525 let mut new_metrics = Vec::with_capacity(size);
526
527 for (i, metric) in current_metrics.iter().enumerate() {
529 if i < size {
530 new_metrics.push(Arc::clone(metric));
531 }
532 }
533
534 while new_metrics.len() < size {
536 new_metrics.push(Arc::new(ServiceMetrics::new()));
537 }
538
539 *metrics_guard = Arc::new(new_metrics);
540 }
541 }
542
543 fn get_load(&self, metrics: &ServiceMetrics) -> u64 {
544 match self.strategy {
545 LoadBalancingStrategy::P2cPendingRequests => {
546 metrics.pending_requests.load(Ordering::Relaxed) as u64
547 }
548 LoadBalancingStrategy::P2cPeakEwma => {
549 let last_update = *metrics.last_update.lock().unwrap();
551 let elapsed = last_update.elapsed();
552
553 let current = metrics.peak_ewma_micros.load(Ordering::Relaxed);
555 let decay_factor = (-elapsed.as_secs_f64() / 5.0).exp();
556 (current as f64 * decay_factor) as u64
557 }
558 _ => unreachable!("CustomP2cBalancer should only be used with P2C strategies"),
559 }
560 }
561
562 fn update_ewma(&self, metrics: &ServiceMetrics, latency: Duration) {
563 let latency_micros = latency.as_micros() as u64;
564
565 loop {
568 let current = metrics.peak_ewma_micros.load(Ordering::Relaxed);
569
570 if current == 0 {
572 if metrics
573 .peak_ewma_micros
574 .compare_exchange(0, latency_micros, Ordering::Relaxed, Ordering::Relaxed)
575 .is_ok()
576 {
577 *metrics.last_update.lock().unwrap() = Instant::now();
578 break;
579 }
580 continue;
581 }
582
583 let mut last_update_guard = metrics.last_update.lock().unwrap();
585 let elapsed = last_update_guard.elapsed();
586
587 let decay_factor = (-elapsed.as_secs_f64() / 5.0).exp();
589 let decayed_current = (current as f64 * decay_factor) as u64;
590
591 let peak = decayed_current.max(latency_micros);
593
594 let ewma = ((peak as f64 * 0.25) + (decayed_current as f64 * 0.75)) as u64;
597
598 if metrics
599 .peak_ewma_micros
600 .compare_exchange(current, ewma, Ordering::Relaxed, Ordering::Relaxed)
601 .is_ok()
602 {
603 *last_update_guard = Instant::now();
605 break;
606 }
607 drop(last_update_guard); }
609 }
610}
611
612struct PendingRequestGuard {
614 metrics: Arc<ServiceMetrics>,
615}
616
617impl Drop for PendingRequestGuard {
618 fn drop(&mut self) {
619 self.metrics
620 .pending_requests
621 .fetch_sub(1, Ordering::Relaxed);
622 }
623}
624
/// Load metrics tracked for one upstream.
#[derive(Debug)]
struct ServiceMetrics {
    /// Number of requests currently in flight to the upstream.
    pending_requests: AtomicUsize,
    /// Peak-EWMA latency in microseconds; `0` means no sample has been recorded yet.
    peak_ewma_micros: AtomicU64,
    /// When `peak_ewma_micros` was last written (initially the creation time).
    last_update: std::sync::Mutex<Instant>,
}

impl ServiceMetrics {
    /// Creates metrics with no pending requests, no latency sample, and the
    /// last-update timestamp set to now.
    fn new() -> Self {
        let created_at = Instant::now();
        let pending_requests = AtomicUsize::new(0);
        let peak_ewma_micros = AtomicU64::new(0);

        Self {
            pending_requests,
            peak_ewma_micros,
            last_update: std::sync::Mutex::new(created_at),
        }
    }
}