axum_reverse_proxy/
balanced_proxy.rs

1use axum::body::Body;
2#[cfg(all(feature = "tls", not(feature = "native-tls")))]
3use hyper_rustls::HttpsConnector;
4#[cfg(feature = "native-tls")]
5use hyper_tls::HttpsConnector as NativeTlsHttpsConnector;
6use hyper_util::client::legacy::{
7    Client,
8    connect::{Connect, HttpConnector},
9};
10use std::collections::HashMap;
11use std::convert::Infallible;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
14use std::time::{Duration, Instant};
15
16use tower::discover::{Change, Discover};
17use tracing::{debug, error, trace, warn};
18
19use crate::proxy::ReverseProxy;
20
21// For custom P2C implementation
22use rand::Rng;
23
/// Load balancing strategy for distributing requests across discovered services.
///
/// [`LoadBalancingStrategy::RoundRobin`] is the default, both via [`Default`] and for
/// constructors that do not take an explicit strategy.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Simple round-robin distribution (default)
    #[default]
    RoundRobin,
    /// Power of Two Choices with pending request count as load metric
    P2cPendingRequests,
    /// Power of Two Choices with peak EWMA latency as load metric
    P2cPeakEwma,
}
40
41#[derive(Clone)]
42pub struct BalancedProxy<C: Connect + Clone + Send + Sync + 'static> {
43    path: String,
44    proxies: Vec<ReverseProxy<C>>,
45    counter: Arc<AtomicUsize>,
46}
47
48#[cfg(all(feature = "tls", not(feature = "native-tls")))]
49pub type StandardBalancedProxy = BalancedProxy<HttpsConnector<HttpConnector>>;
50#[cfg(feature = "native-tls")]
51pub type StandardBalancedProxy = BalancedProxy<NativeTlsHttpsConnector<HttpConnector>>;
52#[cfg(all(not(feature = "tls"), not(feature = "native-tls")))]
53pub type StandardBalancedProxy = BalancedProxy<HttpConnector>;
54
55impl StandardBalancedProxy {
56    pub fn new<S>(path: S, targets: Vec<S>) -> Self
57    where
58        S: Into<String> + Clone,
59    {
60        let mut connector = HttpConnector::new();
61        connector.set_nodelay(true);
62        connector.enforce_http(false);
63        connector.set_keepalive(Some(std::time::Duration::from_secs(60)));
64        connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
65        connector.set_reuse_address(true);
66
67        #[cfg(all(feature = "tls", not(feature = "native-tls")))]
68        let connector = {
69            use hyper_rustls::HttpsConnectorBuilder;
70            HttpsConnectorBuilder::new()
71                .with_native_roots()
72                .unwrap()
73                .https_or_http()
74                .enable_http1()
75                .wrap_connector(connector)
76        };
77
78        #[cfg(feature = "native-tls")]
79        let connector = NativeTlsHttpsConnector::new_with_connector(connector);
80
81        let client = Client::builder(hyper_util::rt::TokioExecutor::new())
82            .pool_idle_timeout(std::time::Duration::from_secs(60))
83            .pool_max_idle_per_host(32)
84            .retry_canceled_requests(true)
85            .set_host(true)
86            .build(connector);
87
88        Self::new_with_client(path, targets, client)
89    }
90}
91
92impl<C> BalancedProxy<C>
93where
94    C: Connect + Clone + Send + Sync + 'static,
95{
96    pub fn new_with_client<S>(path: S, targets: Vec<S>, client: Client<C, Body>) -> Self
97    where
98        S: Into<String> + Clone,
99    {
100        let path = path.into();
101        let proxies = targets
102            .into_iter()
103            .map(|t| ReverseProxy::new_with_client(path.clone(), t.into(), client.clone()))
104            .collect();
105
106        Self {
107            path,
108            proxies,
109            counter: Arc::new(AtomicUsize::new(0)),
110        }
111    }
112
113    pub fn path(&self) -> &str {
114        &self.path
115    }
116
117    fn next_proxy(&self) -> Option<ReverseProxy<C>> {
118        if self.proxies.is_empty() {
119            None
120        } else {
121            let idx = self.counter.fetch_add(1, Ordering::Relaxed) % self.proxies.len();
122            Some(self.proxies[idx].clone())
123        }
124    }
125}
126
127use std::{
128    future::Future,
129    pin::Pin,
130    task::{Context, Poll},
131};
132use tower::Service;
133
134impl<C> Service<axum::http::Request<Body>> for BalancedProxy<C>
135where
136    C: Connect + Clone + Send + Sync + 'static,
137{
138    type Response = axum::http::Response<Body>;
139    type Error = Infallible;
140    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
141
142    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
143        Poll::Ready(Ok(()))
144    }
145
146    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
147        if let Some(mut proxy) = self.next_proxy() {
148            trace!("balanced proxying via upstream {}", proxy.target());
149            Box::pin(async move { proxy.call(req).await })
150        } else {
151            warn!("No upstream services available");
152            Box::pin(async move {
153                Ok(axum::http::Response::builder()
154                    .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
155                    .body(Body::from("No upstream services available"))
156                    .unwrap())
157            })
158        }
159    }
160}
161
162/// A balanced proxy that supports dynamic service discovery.
163///
164/// This proxy uses the tower::discover trait to dynamically add and remove
165/// upstream services. Services are load-balanced using a configurable strategy.
166///
167/// Features:
168/// - High-performance request handling with minimal overhead
169/// - Atomic service updates that don't block ongoing requests
170/// - Efficient round-robin load balancing
171/// - Zero-downtime service discovery changes
172#[derive(Clone)]
173pub struct DiscoverableBalancedProxy<C, D>
174where
175    C: Connect + Clone + Send + Sync + 'static,
176    D: Discover + Clone + Send + Sync + 'static,
177    D::Service: Into<String> + Send,
178    D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
179    D::Error: std::fmt::Debug + Send,
180{
181    path: String,
182    client: Client<C, Body>,
183    proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
184    proxy_keys: Arc<tokio::sync::RwLock<HashMap<D::Key, usize>>>, // key -> index mapping
185    counter: Arc<AtomicUsize>,
186    discover: D,
187    strategy: LoadBalancingStrategy,
188    // Custom P2C balancer for strategies that need it
189    p2c_balancer: Option<Arc<CustomP2cBalancer<C>>>,
190}
191
192#[cfg(all(feature = "tls", not(feature = "native-tls")))]
193pub type StandardDiscoverableBalancedProxy<D> =
194    DiscoverableBalancedProxy<HttpsConnector<HttpConnector>, D>;
195#[cfg(feature = "native-tls")]
196pub type StandardDiscoverableBalancedProxy<D> =
197    DiscoverableBalancedProxy<NativeTlsHttpsConnector<HttpConnector>, D>;
198#[cfg(all(not(feature = "tls"), not(feature = "native-tls")))]
199pub type StandardDiscoverableBalancedProxy<D> = DiscoverableBalancedProxy<HttpConnector, D>;
200
201impl<C, D> DiscoverableBalancedProxy<C, D>
202where
203    C: Connect + Clone + Send + Sync + 'static,
204    D: Discover + Clone + Send + Sync + 'static,
205    D::Service: Into<String> + Send,
206    D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
207    D::Error: std::fmt::Debug + Send,
208{
209    /// Creates a new discoverable balanced proxy with a custom client and discover implementation.
210    /// Uses round-robin load balancing by default.
211    pub fn new_with_client<S>(path: S, client: Client<C, Body>, discover: D) -> Self
212    where
213        S: Into<String>,
214    {
215        Self::new_with_client_and_strategy(path, client, discover, LoadBalancingStrategy::default())
216    }
217
218    /// Creates a new discoverable balanced proxy with a custom client, discover implementation, and load balancing strategy.
219    pub fn new_with_client_and_strategy<S>(
220        path: S,
221        client: Client<C, Body>,
222        discover: D,
223        strategy: LoadBalancingStrategy,
224    ) -> Self
225    where
226        S: Into<String>,
227    {
228        let path = path.into();
229        let proxies_snapshot = Arc::new(std::sync::RwLock::new(Arc::new(Vec::new())));
230
231        // Create P2C balancer if needed
232        let p2c_balancer = match strategy {
233            LoadBalancingStrategy::P2cPendingRequests | LoadBalancingStrategy::P2cPeakEwma => {
234                Some(Arc::new(CustomP2cBalancer::new(
235                    strategy,
236                    Arc::clone(&proxies_snapshot),
237                )))
238            }
239            LoadBalancingStrategy::RoundRobin => None,
240        };
241
242        Self {
243            path,
244            client,
245            proxies_snapshot,
246            proxy_keys: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
247            counter: Arc::new(AtomicUsize::new(0)),
248            discover: discover.clone(),
249            strategy,
250            p2c_balancer,
251        }
252    }
253
254    /// Get the base path this proxy is configured to handle
255    pub fn path(&self) -> &str {
256        &self.path
257    }
258
259    /// Get the load balancing strategy being used
260    pub fn strategy(&self) -> LoadBalancingStrategy {
261        self.strategy
262    }
263
264    /// Start the discovery process in the background.
265    /// This should be called once to begin monitoring for service changes.
266    pub async fn start_discovery(&mut self) {
267        let discover = self.discover.clone();
268        let proxies_snapshot = Arc::clone(&self.proxies_snapshot);
269        let proxy_keys = Arc::clone(&self.proxy_keys);
270        let client = self.client.clone();
271        let path = self.path.clone();
272
273        tokio::spawn(async move {
274            use futures_util::future::poll_fn;
275
276            let mut discover = Box::pin(discover);
277
278            loop {
279                let change_result =
280                    poll_fn(|cx: &mut Context<'_>| discover.as_mut().poll_discover(cx)).await;
281
282                match change_result {
283                    Some(Ok(change)) => match change {
284                        Change::Insert(key, service) => {
285                            let target: String = service.into();
286                            debug!("Discovered new service: {:?} -> {}", key, target);
287
288                            let proxy =
289                                ReverseProxy::new_with_client(path.clone(), target, client.clone());
290
291                            {
292                                let mut keys_guard = proxy_keys.write().await;
293
294                                // Get current snapshot and create new one with added service
295                                let current_snapshot = {
296                                    let snapshot_guard = proxies_snapshot.read().unwrap();
297                                    Arc::clone(&*snapshot_guard)
298                                };
299
300                                let mut new_proxies = (*current_snapshot).clone();
301                                let index = new_proxies.len();
302                                new_proxies.push(proxy);
303                                keys_guard.insert(key, index);
304
305                                // Atomically update the snapshot
306                                {
307                                    let mut snapshot_guard = proxies_snapshot.write().unwrap();
308                                    *snapshot_guard = Arc::new(new_proxies);
309                                }
310                            }
311                        }
312                        Change::Remove(key) => {
313                            debug!("Removing service: {:?}", key);
314
315                            {
316                                let mut keys_guard = proxy_keys.write().await;
317
318                                if let Some(index) = keys_guard.remove(&key) {
319                                    // Get current snapshot and create new one with removed service
320                                    let current_snapshot = {
321                                        let snapshot_guard = proxies_snapshot.read().unwrap();
322                                        Arc::clone(&*snapshot_guard)
323                                    };
324
325                                    let mut new_proxies = (*current_snapshot).clone();
326                                    new_proxies.remove(index);
327
328                                    // Update indices for all keys after the removed index
329                                    for (_, idx) in keys_guard.iter_mut() {
330                                        if *idx > index {
331                                            *idx -= 1;
332                                        }
333                                    }
334
335                                    // Atomically update the snapshot
336                                    {
337                                        let mut snapshot_guard = proxies_snapshot.write().unwrap();
338                                        *snapshot_guard = Arc::new(new_proxies);
339                                    }
340                                }
341                            }
342                        }
343                    },
344                    Some(Err(e)) => {
345                        error!("Discovery error: {:?}", e);
346                    }
347                    None => {
348                        warn!("Discovery stream ended");
349                        break;
350                    }
351                }
352            }
353        });
354    }
355
356    /// Get the current number of discovered services
357    pub async fn service_count(&self) -> usize {
358        let snapshot = {
359            let guard = self.proxies_snapshot.read().unwrap();
360            Arc::clone(&*guard)
361        };
362        snapshot.len()
363    }
364}
365
366impl<C, D> Service<axum::http::Request<Body>> for DiscoverableBalancedProxy<C, D>
367where
368    C: Connect + Clone + Send + Sync + 'static,
369    D: Discover + Clone + Send + Sync + 'static,
370    D::Service: Into<String> + Send,
371    D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
372    D::Error: std::fmt::Debug + Send,
373{
374    type Response = axum::http::Response<Body>;
375    type Error = Infallible;
376    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
377
378    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
379        Poll::Ready(Ok(()))
380    }
381
382    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
383        // Get current proxy snapshot
384        let proxies_snapshot = {
385            let guard = self.proxies_snapshot.read().unwrap();
386            Arc::clone(&*guard)
387        };
388        let counter = Arc::clone(&self.counter);
389        let strategy = self.strategy;
390        let p2c_balancer = self.p2c_balancer.clone();
391
392        Box::pin(async move {
393            match strategy {
394                LoadBalancingStrategy::RoundRobin => {
395                    // Round-robin load balancing
396                    if proxies_snapshot.is_empty() {
397                        warn!("No upstream services available");
398                        Ok(axum::http::Response::builder()
399                            .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
400                            .body(Body::from("No upstream services available"))
401                            .unwrap())
402                    } else {
403                        let idx = counter.fetch_add(1, Ordering::Relaxed) % proxies_snapshot.len();
404                        let mut proxy = proxies_snapshot[idx].clone();
405                        proxy.call(req).await
406                    }
407                }
408                LoadBalancingStrategy::P2cPendingRequests | LoadBalancingStrategy::P2cPeakEwma => {
409                    // Use the custom P2C balancer
410                    if let Some(balancer) = p2c_balancer {
411                        balancer.call_with_p2c(req).await
412                    } else {
413                        // Fallback to error if balancer is not available
414                        error!("P2C balancer not available for strategy {:?}", strategy);
415                        Ok(axum::http::Response::builder()
416                            .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
417                            .body(Body::from("P2C balancer not available"))
418                            .unwrap())
419                    }
420                }
421            }
422        })
423    }
424}
425
426/// Custom P2C load balancer that uses atomic operations for low-contention metrics tracking
427struct CustomP2cBalancer<C: Connect + Clone + Send + Sync + 'static> {
428    strategy: LoadBalancingStrategy,
429    proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
430    /// Metrics for each service, indexed by position in proxies_snapshot
431    /// We use Arc<Vec<Arc<ServiceMetrics>>> to allow concurrent access with minimal locking
432    metrics: Arc<std::sync::RwLock<Arc<Vec<Arc<ServiceMetrics>>>>>,
433}
434
435impl<C: Connect + Clone + Send + Sync + 'static> CustomP2cBalancer<C> {
436    fn new(
437        strategy: LoadBalancingStrategy,
438        proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
439    ) -> Self {
440        let initial_metrics = Arc::new(Vec::new());
441        Self {
442            strategy,
443            proxies_snapshot,
444            metrics: Arc::new(std::sync::RwLock::new(initial_metrics)),
445        }
446    }
447
448    async fn call_with_p2c(
449        &self,
450        req: axum::http::Request<Body>,
451    ) -> Result<axum::http::Response<Body>, Infallible> {
452        // Get current proxy snapshot
453        let proxies = {
454            let guard = self.proxies_snapshot.read().unwrap();
455            Arc::clone(&*guard)
456        };
457
458        if proxies.is_empty() {
459            return Ok(axum::http::Response::builder()
460                .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
461                .body(Body::from("No upstream services available"))
462                .unwrap());
463        }
464
465        // Ensure metrics vector is up to date
466        self.ensure_metrics_size(proxies.len());
467
468        // Get metrics snapshot
469        let metrics = {
470            let guard = self.metrics.read().unwrap();
471            Arc::clone(&*guard)
472        };
473
474        // P2C: Pick two random services and choose the one with lower load
475        let selected_idx = if proxies.len() == 1 {
476            0
477        } else {
478            let mut rng = rand::rng();
479            let idx1 = rng.random_range(0..proxies.len());
480            let idx2 = loop {
481                let i = rng.random_range(0..proxies.len());
482                if i != idx1 {
483                    break i;
484                }
485            };
486
487            // Compare load metrics based on strategy
488            let load1 = self.get_load(&metrics[idx1]);
489            let load2 = self.get_load(&metrics[idx2]);
490
491            if load1 <= load2 { idx1 } else { idx2 }
492        };
493
494        // Track request start for pending requests
495        let request_guard = if matches!(self.strategy, LoadBalancingStrategy::P2cPendingRequests) {
496            metrics[selected_idx]
497                .pending_requests
498                .fetch_add(1, Ordering::Relaxed);
499            Some(PendingRequestGuard {
500                metrics: Arc::clone(&metrics[selected_idx]),
501            })
502        } else {
503            None
504        };
505
506        // Record start time for latency tracking
507        let start = Instant::now();
508
509        // Make the actual request
510        let mut proxy = proxies[selected_idx].clone();
511        let result = proxy.call(req).await;
512
513        // Update latency metrics for EWMA strategy
514        if matches!(self.strategy, LoadBalancingStrategy::P2cPeakEwma) {
515            let latency = start.elapsed();
516            self.update_ewma(&metrics[selected_idx], latency);
517        }
518
519        // Request guard will decrement pending count when dropped
520        drop(request_guard);
521
522        result
523    }
524
525    fn ensure_metrics_size(&self, size: usize) {
526        let mut metrics_guard = self.metrics.write().unwrap();
527        let current_metrics = Arc::clone(&*metrics_guard);
528
529        if current_metrics.len() != size {
530            let mut new_metrics = Vec::with_capacity(size);
531
532            // Copy existing metrics
533            for (i, metric) in current_metrics.iter().enumerate() {
534                if i < size {
535                    new_metrics.push(Arc::clone(metric));
536                }
537            }
538
539            // Add new metrics if needed
540            while new_metrics.len() < size {
541                new_metrics.push(Arc::new(ServiceMetrics::new()));
542            }
543
544            *metrics_guard = Arc::new(new_metrics);
545        }
546    }
547
548    fn get_load(&self, metrics: &ServiceMetrics) -> u64 {
549        match self.strategy {
550            LoadBalancingStrategy::P2cPendingRequests => {
551                metrics.pending_requests.load(Ordering::Relaxed) as u64
552            }
553            LoadBalancingStrategy::P2cPeakEwma => {
554                // Apply decay based on time since last update
555                let last_update = *metrics.last_update.lock().unwrap();
556                let elapsed = last_update.elapsed();
557
558                // Simple exponential decay: reduce by ~50% every 5 seconds
559                let current = metrics.peak_ewma_micros.load(Ordering::Relaxed);
560                let decay_factor = (-elapsed.as_secs_f64() / 5.0).exp();
561                (current as f64 * decay_factor) as u64
562            }
563            _ => unreachable!("CustomP2cBalancer should only be used with P2C strategies"),
564        }
565    }
566
567    fn update_ewma(&self, metrics: &ServiceMetrics, latency: Duration) {
568        let latency_micros = latency.as_micros() as u64;
569
570        // Update with exponential weighted moving average
571        // Using compare-and-swap loop for lock-free update
572        loop {
573            let current = metrics.peak_ewma_micros.load(Ordering::Relaxed);
574
575            // If this is the first measurement, just set it
576            if current == 0 {
577                if metrics
578                    .peak_ewma_micros
579                    .compare_exchange(0, latency_micros, Ordering::Relaxed, Ordering::Relaxed)
580                    .is_ok()
581                {
582                    *metrics.last_update.lock().unwrap() = Instant::now();
583                    break;
584                }
585                continue;
586            }
587
588            // Apply decay based on time since last update
589            let mut last_update_guard = metrics.last_update.lock().unwrap();
590            let elapsed = last_update_guard.elapsed();
591
592            // Decay factor: reduce by ~50% every 5 seconds
593            let decay_factor = (-elapsed.as_secs_f64() / 5.0).exp();
594            let decayed_current = (current as f64 * decay_factor) as u64;
595
596            // Peak EWMA: take the maximum of the decayed value and the new measurement
597            let peak = decayed_current.max(latency_micros);
598
599            // EWMA with alpha = 0.25 (25% new value, 75% old value)
600            // This gives more weight to recent measurements
601            let ewma = ((peak as f64 * 0.25) + (decayed_current as f64 * 0.75)) as u64;
602
603            if metrics
604                .peak_ewma_micros
605                .compare_exchange(current, ewma, Ordering::Relaxed, Ordering::Relaxed)
606                .is_ok()
607            {
608                // Update last update time
609                *last_update_guard = Instant::now();
610                break;
611            }
612            drop(last_update_guard); // Release lock before retrying
613        }
614    }
615}
616
617/// RAII guard to decrement pending request count when request completes
618struct PendingRequestGuard {
619    metrics: Arc<ServiceMetrics>,
620}
621
622impl Drop for PendingRequestGuard {
623    fn drop(&mut self) {
624        self.metrics
625            .pending_requests
626            .fetch_sub(1, Ordering::Relaxed);
627    }
628}
629
/// Per-service load metrics consulted by the P2C strategies.
#[derive(Debug)]
struct ServiceMetrics {
    /// Requests currently in flight (read by the `P2cPendingRequests` strategy).
    pending_requests: AtomicUsize,
    /// Peak EWMA latency in microseconds (read by the `P2cPeakEwma` strategy);
    /// zero means no sample has been recorded yet.
    peak_ewma_micros: AtomicU64,
    /// Time of the last EWMA update, used to compute the decay.
    last_update: std::sync::Mutex<Instant>,
}

impl ServiceMetrics {
    /// Creates metrics with no pending requests, no latency sample, and the update
    /// timestamp set to the moment of creation.
    fn new() -> Self {
        let created_at = Instant::now();
        Self {
            pending_requests: AtomicUsize::default(),
            peak_ewma_micros: AtomicU64::default(),
            last_update: std::sync::Mutex::new(created_at),
        }
    }
}