use axum::body::Body;
use hyper_util::client::legacy::{Client, connect::Connect};
use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::{Duration, Instant};

use tower::discover::{Change, Discover};
use tracing::{debug, error, trace, warn};

use crate::forward::{ProxyConnector, create_http_connector};
use crate::proxy::ReverseProxy;

use rand::Rng;

/// Strategy used to choose an upstream for each incoming request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalancingStrategy {
    /// Cycle through the upstreams in order (default).
    #[default]
    RoundRobin,
    /// Power of two choices, comparing the number of in-flight requests.
    P2cPendingRequests,
    /// Power of two choices, comparing peak-EWMA response latency.
    P2cPeakEwma,
}

/// Round-robin reverse proxy over a fixed set of upstream targets.
#[derive(Clone)]
pub struct BalancedProxy<C: Connect + Clone + Send + Sync + 'static> {
    path: String,
    proxies: Vec<ReverseProxy<C>>,
    counter: Arc<AtomicUsize>,
}

pub type StandardBalancedProxy = BalancedProxy<ProxyConnector>;

impl StandardBalancedProxy {
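    /// Builds a round-robin balancer over `targets`, sharing one pooled HTTP
    /// client (60 s idle timeout, up to 32 idle connections per host).
    ///
    /// Example (a sketch; the path and upstream URLs are illustrative):
    ///
    /// ```ignore
    /// let proxy = StandardBalancedProxy::new(
    ///     "/api",
    ///     vec!["http://127.0.0.1:3001", "http://127.0.0.1:3002"],
    /// );
    /// ```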
    pub fn new<S>(path: S, targets: Vec<S>) -> Self
    where
        S: Into<String> + Clone,
    {
        let client = Client::builder(hyper_util::rt::TokioExecutor::new())
            .pool_idle_timeout(std::time::Duration::from_secs(60))
            .pool_max_idle_per_host(32)
            .retry_canceled_requests(true)
            .set_host(true)
            .build(create_http_connector());

        Self::new_with_client(path, targets, client)
    }
}

impl<C> BalancedProxy<C>
where
    C: Connect + Clone + Send + Sync + 'static,
{
    pub fn new_with_client<S>(path: S, targets: Vec<S>, client: Client<C, Body>) -> Self
    where
        S: Into<String> + Clone,
    {
        let path = path.into();
        let proxies = targets
            .into_iter()
            .map(|t| ReverseProxy::new_with_client(path.clone(), t.into(), client.clone()))
            .collect();

        Self {
            path,
            proxies,
            counter: Arc::new(AtomicUsize::new(0)),
        }
    }

    pub fn path(&self) -> &str {
        &self.path
    }

    /// Picks the next upstream in round-robin order, or `None` if the set is empty.
    fn next_proxy(&self) -> Option<ReverseProxy<C>> {
        if self.proxies.is_empty() {
            None
        } else {
            let idx = self.counter.fetch_add(1, Ordering::Relaxed) % self.proxies.len();
            Some(self.proxies[idx].clone())
        }
    }
}
89
90use std::{
91 future::Future,
92 pin::Pin,
93 task::{Context, Poll},
94};
95use tower::Service;
96
impl<C> Service<axum::http::Request<Body>> for BalancedProxy<C>
where
    C: Connect + Clone + Send + Sync + 'static,
{
    type Response = axum::http::Response<Body>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
        if let Some(mut proxy) = self.next_proxy() {
            trace!("balanced proxying via upstream {}", proxy.target());
            Box::pin(async move { proxy.call(req).await })
        } else {
            warn!("No upstream services available");
            Box::pin(async move {
                Ok(axum::http::Response::builder()
                    .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
                    .body(Body::from("No upstream services available"))
                    .unwrap())
            })
        }
    }
}

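/// A load-balanced reverse proxy whose upstream set is driven by a
/// [`tower::discover::Discover`] stream. Inserts and removals from the
/// discovery source are applied to a copy-on-write snapshot of proxies, so
/// request handling only takes a brief read lock.
///
/// Example (a sketch; `client` is a `hyper_util` legacy client and
/// `my_discover` stands for any `Discover` implementation that satisfies the
/// bounds below):
///
/// ```ignore
/// let mut proxy = StandardDiscoverableBalancedProxy::new_with_client_and_strategy(
///     "/api",
///     client,
///     my_discover,
///     LoadBalancingStrategy::P2cPeakEwma,
/// );
/// proxy.start_discovery().await;
/// ```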
#[derive(Clone)]
pub struct DiscoverableBalancedProxy<C, D>
where
    C: Connect + Clone + Send + Sync + 'static,
    D: Discover + Clone + Send + Sync + 'static,
    D::Service: Into<String> + Send,
    D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
    D::Error: std::fmt::Debug + Send,
{
    path: String,
    client: Client<C, Body>,
    /// Copy-on-write snapshot of the current upstream proxies.
    proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
    /// Maps discovery keys to their index in the snapshot.
    proxy_keys: Arc<tokio::sync::RwLock<HashMap<D::Key, usize>>>,
    /// Round-robin cursor.
    counter: Arc<AtomicUsize>,
    discover: D,
    strategy: LoadBalancingStrategy,
    /// Present only for the P2C strategies.
    p2c_balancer: Option<Arc<CustomP2cBalancer<C>>>,
}

pub type StandardDiscoverableBalancedProxy<D> = DiscoverableBalancedProxy<ProxyConnector, D>;

impl<C, D> DiscoverableBalancedProxy<C, D>
where
    C: Connect + Clone + Send + Sync + 'static,
    D: Discover + Clone + Send + Sync + 'static,
    D::Service: Into<String> + Send,
    D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
    D::Error: std::fmt::Debug + Send,
{
    /// Creates a discovery-driven balancer using the default round-robin strategy.
    pub fn new_with_client<S>(path: S, client: Client<C, Body>, discover: D) -> Self
    where
        S: Into<String>,
    {
        Self::new_with_client_and_strategy(path, client, discover, LoadBalancingStrategy::default())
    }

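    /// Like [`new_with_client`](Self::new_with_client), but with an explicit
    /// [`LoadBalancingStrategy`]. The P2C strategies allocate a shared
    /// `CustomP2cBalancer` up front; round-robin needs no extra state.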
    pub fn new_with_client_and_strategy<S>(
        path: S,
        client: Client<C, Body>,
        discover: D,
        strategy: LoadBalancingStrategy,
    ) -> Self
    where
        S: Into<String>,
    {
        let path = path.into();
        let proxies_snapshot = Arc::new(std::sync::RwLock::new(Arc::new(Vec::new())));

        // P2C strategies need per-upstream metrics; round-robin does not.
        let p2c_balancer = match strategy {
            LoadBalancingStrategy::P2cPendingRequests | LoadBalancingStrategy::P2cPeakEwma => {
                Some(Arc::new(CustomP2cBalancer::new(
                    strategy,
                    Arc::clone(&proxies_snapshot),
                )))
            }
            LoadBalancingStrategy::RoundRobin => None,
        };

        Self {
            path,
            client,
            proxies_snapshot,
            proxy_keys: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            counter: Arc::new(AtomicUsize::new(0)),
            discover,
            strategy,
            p2c_balancer,
        }
    }

    /// The base path this balancer serves under.
    pub fn path(&self) -> &str {
        &self.path
    }

    /// The configured load-balancing strategy.
    pub fn strategy(&self) -> LoadBalancingStrategy {
        self.strategy
    }

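    /// Spawns a background task that polls the discovery stream and applies
    /// `Change::Insert` / `Change::Remove` events to the shared proxy snapshot.
    /// The task exits when the stream yields `None`.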
    pub async fn start_discovery(&mut self) {
        let discover = self.discover.clone();
        let proxies_snapshot = Arc::clone(&self.proxies_snapshot);
        let proxy_keys = Arc::clone(&self.proxy_keys);
        let client = self.client.clone();
        let path = self.path.clone();

        tokio::spawn(async move {
            use futures_util::future::poll_fn;

            let mut discover = Box::pin(discover);

            loop {
                let change_result =
                    poll_fn(|cx: &mut Context<'_>| discover.as_mut().poll_discover(cx)).await;

                match change_result {
                    Some(Ok(change)) => match change {
                        Change::Insert(key, service) => {
                            let target: String = service.into();
                            debug!("Discovered new service: {:?} -> {}", key, target);

                            let proxy =
                                ReverseProxy::new_with_client(path.clone(), target, client.clone());

                            {
                                let mut keys_guard = proxy_keys.write().await;

                                // Copy-on-write: clone the current snapshot, append, then swap.
                                let current_snapshot = {
                                    let snapshot_guard = proxies_snapshot.read().unwrap();
                                    Arc::clone(&*snapshot_guard)
                                };

                                let mut new_proxies = (*current_snapshot).clone();
                                let index = new_proxies.len();
                                new_proxies.push(proxy);
                                keys_guard.insert(key, index);

                                {
                                    let mut snapshot_guard = proxies_snapshot.write().unwrap();
                                    *snapshot_guard = Arc::new(new_proxies);
                                }
                            }
                        }
                        Change::Remove(key) => {
                            debug!("Removing service: {:?}", key);

                            {
                                let mut keys_guard = proxy_keys.write().await;

                                if let Some(index) = keys_guard.remove(&key) {
                                    let current_snapshot = {
                                        let snapshot_guard = proxies_snapshot.read().unwrap();
                                        Arc::clone(&*snapshot_guard)
                                    };

                                    let mut new_proxies = (*current_snapshot).clone();
                                    new_proxies.remove(index);

                                    // Shift indices of every proxy stored after the removed one.
                                    for (_, idx) in keys_guard.iter_mut() {
                                        if *idx > index {
                                            *idx -= 1;
                                        }
                                    }

                                    {
                                        let mut snapshot_guard = proxies_snapshot.write().unwrap();
                                        *snapshot_guard = Arc::new(new_proxies);
                                    }
                                }
                            }
                        }
                    },
                    Some(Err(e)) => {
                        error!("Discovery error: {:?}", e);
                    }
                    None => {
                        warn!("Discovery stream ended");
                        break;
                    }
                }
            }
        });
    }

    /// Number of upstreams currently known to the balancer.
    pub async fn service_count(&self) -> usize {
        let snapshot = {
            let guard = self.proxies_snapshot.read().unwrap();
            Arc::clone(&*guard)
        };
        snapshot.len()
    }
}

impl<C, D> Service<axum::http::Request<Body>> for DiscoverableBalancedProxy<C, D>
where
    C: Connect + Clone + Send + Sync + 'static,
    D: Discover + Clone + Send + Sync + 'static,
    D::Service: Into<String> + Send,
    D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
    D::Error: std::fmt::Debug + Send,
{
    type Response = axum::http::Response<Body>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
        // Clone the snapshot and balancer handles so the boxed future owns everything it needs.
        let proxies_snapshot = {
            let guard = self.proxies_snapshot.read().unwrap();
            Arc::clone(&*guard)
        };
        let counter = Arc::clone(&self.counter);
        let strategy = self.strategy;
        let p2c_balancer = self.p2c_balancer.clone();

        Box::pin(async move {
            match strategy {
                LoadBalancingStrategy::RoundRobin => {
                    if proxies_snapshot.is_empty() {
                        warn!("No upstream services available");
                        Ok(axum::http::Response::builder()
                            .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
                            .body(Body::from("No upstream services available"))
                            .unwrap())
                    } else {
                        let idx = counter.fetch_add(1, Ordering::Relaxed) % proxies_snapshot.len();
                        let mut proxy = proxies_snapshot[idx].clone();
                        proxy.call(req).await
                    }
                }
                LoadBalancingStrategy::P2cPendingRequests | LoadBalancingStrategy::P2cPeakEwma => {
                    if let Some(balancer) = p2c_balancer {
                        balancer.call_with_p2c(req).await
                    } else {
                        error!("P2C balancer not available for strategy {:?}", strategy);
                        Ok(axum::http::Response::builder()
                            .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
                            .body(Body::from("P2C balancer not available"))
                            .unwrap())
                    }
                }
            }
        })
    }
}

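/// Power-of-two-choices balancer: for each request it samples two distinct
/// upstreams at random and forwards to the one reporting the lower load, where
/// "load" is either the in-flight request count (`P2cPendingRequests`) or a
/// decayed peak-EWMA of response latency (`P2cPeakEwma`).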
struct CustomP2cBalancer<C: Connect + Clone + Send + Sync + 'static> {
    strategy: LoadBalancingStrategy,
    proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
    /// Per-upstream metrics, index-aligned with `proxies_snapshot`.
    metrics: Arc<std::sync::RwLock<Arc<Vec<Arc<ServiceMetrics>>>>>,
}

impl<C: Connect + Clone + Send + Sync + 'static> CustomP2cBalancer<C> {
    fn new(
        strategy: LoadBalancingStrategy,
        proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
    ) -> Self {
        let initial_metrics = Arc::new(Vec::new());
        Self {
            strategy,
            proxies_snapshot,
            metrics: Arc::new(std::sync::RwLock::new(initial_metrics)),
        }
    }

    async fn call_with_p2c(
        &self,
        req: axum::http::Request<Body>,
    ) -> Result<axum::http::Response<Body>, Infallible> {
        let proxies = {
            let guard = self.proxies_snapshot.read().unwrap();
            Arc::clone(&*guard)
        };

        if proxies.is_empty() {
            return Ok(axum::http::Response::builder()
                .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
                .body(Body::from("No upstream services available"))
                .unwrap());
        }

        // Keep the metrics vector in step with the current number of upstreams.
        self.ensure_metrics_size(proxies.len());

        let metrics = {
            let guard = self.metrics.read().unwrap();
            Arc::clone(&*guard)
        };

        // Power of two choices: sample two distinct upstreams and keep the less loaded one.
        let selected_idx = if proxies.len() == 1 {
            0
        } else {
            let mut rng = rand::rng();
            let idx1 = rng.random_range(0..proxies.len());
            let idx2 = loop {
                let i = rng.random_range(0..proxies.len());
                if i != idx1 {
                    break i;
                }
            };

            let load1 = self.get_load(&metrics[idx1]);
            let load2 = self.get_load(&metrics[idx2]);

            if load1 <= load2 { idx1 } else { idx2 }
        };

        // For the pending-requests strategy, count this request as in flight until it finishes.
        let request_guard = if matches!(self.strategy, LoadBalancingStrategy::P2cPendingRequests) {
            metrics[selected_idx]
                .pending_requests
                .fetch_add(1, Ordering::Relaxed);
            Some(PendingRequestGuard {
                metrics: Arc::clone(&metrics[selected_idx]),
            })
        } else {
            None
        };

        let start = Instant::now();

        let mut proxy = proxies[selected_idx].clone();
        let result = proxy.call(req).await;

        // For peak EWMA, fold the observed latency into this upstream's estimate.
        if matches!(self.strategy, LoadBalancingStrategy::P2cPeakEwma) {
            let latency = start.elapsed();
            self.update_ewma(&metrics[selected_idx], latency);
        }

        drop(request_guard);

        result
    }

    fn ensure_metrics_size(&self, size: usize) {
        let mut metrics_guard = self.metrics.write().unwrap();
        let current_metrics = Arc::clone(&*metrics_guard);

        if current_metrics.len() != size {
            let mut new_metrics = Vec::with_capacity(size);

            // Keep existing metrics for the upstreams that are still present.
            for (i, metric) in current_metrics.iter().enumerate() {
                if i < size {
                    new_metrics.push(Arc::clone(metric));
                }
            }

            // Add fresh metrics for any newly discovered upstreams.
            while new_metrics.len() < size {
                new_metrics.push(Arc::new(ServiceMetrics::new()));
            }

            *metrics_guard = Arc::new(new_metrics);
        }
    }

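    /// Returns the comparable "load" for one upstream: the raw in-flight count
    /// for `P2cPendingRequests`, or the stored peak-EWMA latency decayed by the
    /// time elapsed since its last update for `P2cPeakEwma`.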
    fn get_load(&self, metrics: &ServiceMetrics) -> u64 {
        match self.strategy {
            LoadBalancingStrategy::P2cPendingRequests => {
                metrics.pending_requests.load(Ordering::Relaxed) as u64
            }
            LoadBalancingStrategy::P2cPeakEwma => {
                let last_update = *metrics.last_update.lock().unwrap();
                let elapsed = last_update.elapsed();

                // Decay the stored estimate exponentially with a 5-second time constant.
                let current = metrics.peak_ewma_micros.load(Ordering::Relaxed);
                let decay_factor = (-elapsed.as_secs_f64() / 5.0).exp();
                (current as f64 * decay_factor) as u64
            }
            _ => unreachable!("CustomP2cBalancer should only be used with P2C strategies"),
        }
    }

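    /// Folds a newly observed latency into the peak-EWMA estimate: the stored
    /// value is decayed by `exp(-elapsed / 5 s)`, compared with the new sample
    /// to form a "peak", and blended as `0.25 * peak + 0.75 * decayed`.
    /// For example (numbers are illustrative), a stored estimate that has
    /// decayed to 2000 µs combined with a 6000 µs sample yields
    /// `0.25 * 6000 + 0.75 * 2000 = 3000 µs`.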
    fn update_ewma(&self, metrics: &ServiceMetrics, latency: Duration) {
        let latency_micros = latency.as_micros() as u64;

        // Lock-free update: retry the compare-exchange until no other writer races us.
        loop {
            let current = metrics.peak_ewma_micros.load(Ordering::Relaxed);

            // First observation: seed the estimate with the raw latency.
            if current == 0 {
                if metrics
                    .peak_ewma_micros
                    .compare_exchange(0, latency_micros, Ordering::Relaxed, Ordering::Relaxed)
                    .is_ok()
                {
                    *metrics.last_update.lock().unwrap() = Instant::now();
                    break;
                }
                continue;
            }

            let mut last_update_guard = metrics.last_update.lock().unwrap();
            let elapsed = last_update_guard.elapsed();

            // Decay the stored estimate, then blend it with the peak of the new sample.
            let decay_factor = (-elapsed.as_secs_f64() / 5.0).exp();
            let decayed_current = (current as f64 * decay_factor) as u64;

            let peak = decayed_current.max(latency_micros);

            let ewma = ((peak as f64 * 0.25) + (decayed_current as f64 * 0.75)) as u64;

            if metrics
                .peak_ewma_micros
                .compare_exchange(current, ewma, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
            {
                *last_update_guard = Instant::now();
                break;
            }
            // Lost the race: release the lock and retry with the freshly stored value.
            drop(last_update_guard);
        }
    }
}

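/// RAII guard that decrements an upstream's pending-request count when the
/// request completes (or its future is dropped).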
struct PendingRequestGuard {
    metrics: Arc<ServiceMetrics>,
}

impl Drop for PendingRequestGuard {
    fn drop(&mut self) {
        self.metrics
            .pending_requests
            .fetch_sub(1, Ordering::Relaxed);
    }
}

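/// Per-upstream load counters shared between the balancer and its in-flight
/// request guards.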
#[derive(Debug)]
struct ServiceMetrics {
    /// Requests currently in flight to this upstream.
    pending_requests: AtomicUsize,
    /// Peak-EWMA latency estimate, in microseconds (0 means no sample yet).
    peak_ewma_micros: AtomicU64,
    /// When `peak_ewma_micros` was last written; used for the decay calculation.
    last_update: std::sync::Mutex<Instant>,
}

impl ServiceMetrics {
    fn new() -> Self {
        Self {
            pending_requests: AtomicUsize::new(0),
            peak_ewma_micros: AtomicU64::new(0),
            last_update: std::sync::Mutex::new(Instant::now()),
        }
    }
}