//! Load-balanced reverse proxying for axum (`axum_reverse_proxy/balanced_proxy.rs`):
//! static round-robin balancing (`BalancedProxy`) and discovery-driven
//! balancing with pluggable strategies (`DiscoverableBalancedProxy`).

1use axum::body::Body;
2#[cfg(all(feature = "tls", not(feature = "native-tls")))]
3use hyper_rustls::HttpsConnector;
4#[cfg(feature = "native-tls")]
5use hyper_tls::HttpsConnector as NativeTlsHttpsConnector;
6use hyper_util::client::legacy::{
7    Client,
8    connect::{Connect, HttpConnector},
9};
10use std::collections::HashMap;
11use std::convert::Infallible;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
14use std::time::{Duration, Instant};
15
16use tower::discover::{Change, Discover};
17use tracing::{debug, error, trace, warn};
18
19use crate::proxy::ReverseProxy;
20
21// For custom P2C implementation
22use rand::Rng;
23
/// Load balancing strategy for distributing requests across discovered services
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalancingStrategy {
    /// Simple round-robin distribution (default)
    #[default]
    RoundRobin,
    /// Power of Two Choices with pending request count as load metric:
    /// of two randomly sampled upstreams, pick the one with fewer
    /// in-flight requests.
    P2cPendingRequests,
    /// Power of Two Choices with peak EWMA latency as load metric:
    /// of two randomly sampled upstreams, pick the one with the lower
    /// time-decayed peak latency estimate.
    P2cPeakEwma,
}
35
/// A reverse proxy that distributes requests round-robin across a fixed set
/// of upstream targets.
#[derive(Clone)]
pub struct BalancedProxy<C: Connect + Clone + Send + Sync + 'static> {
    // Base path this proxy handles; shared by every per-target proxy.
    path: String,
    // One `ReverseProxy` per configured upstream target.
    proxies: Vec<ReverseProxy<C>>,
    // Monotonically increasing request counter used for round-robin selection.
    counter: Arc<AtomicUsize>,
}
42
/// `BalancedProxy` specialized for the rustls-backed connector (`tls` feature).
#[cfg(all(feature = "tls", not(feature = "native-tls")))]
pub type StandardBalancedProxy = BalancedProxy<HttpsConnector<HttpConnector>>;
/// `BalancedProxy` specialized for the native-tls-backed connector.
#[cfg(feature = "native-tls")]
pub type StandardBalancedProxy = BalancedProxy<NativeTlsHttpsConnector<HttpConnector>>;
/// `BalancedProxy` specialized for plain HTTP (no TLS feature enabled).
#[cfg(all(not(feature = "tls"), not(feature = "native-tls")))]
pub type StandardBalancedProxy = BalancedProxy<HttpConnector>;
49
50impl StandardBalancedProxy {
51    pub fn new<S>(path: S, targets: Vec<S>) -> Self
52    where
53        S: Into<String> + Clone,
54    {
55        let mut connector = HttpConnector::new();
56        connector.set_nodelay(true);
57        connector.enforce_http(false);
58        connector.set_keepalive(Some(std::time::Duration::from_secs(60)));
59        connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
60        connector.set_reuse_address(true);
61
62        #[cfg(all(feature = "tls", not(feature = "native-tls")))]
63        let connector = {
64            use hyper_rustls::HttpsConnectorBuilder;
65            HttpsConnectorBuilder::new()
66                .with_native_roots()
67                .unwrap()
68                .https_or_http()
69                .enable_http1()
70                .wrap_connector(connector)
71        };
72
73        #[cfg(feature = "native-tls")]
74        let connector = NativeTlsHttpsConnector::new_with_connector(connector);
75
76        let client = Client::builder(hyper_util::rt::TokioExecutor::new())
77            .pool_idle_timeout(std::time::Duration::from_secs(60))
78            .pool_max_idle_per_host(32)
79            .retry_canceled_requests(true)
80            .set_host(true)
81            .build(connector);
82
83        Self::new_with_client(path, targets, client)
84    }
85}
86
87impl<C> BalancedProxy<C>
88where
89    C: Connect + Clone + Send + Sync + 'static,
90{
91    pub fn new_with_client<S>(path: S, targets: Vec<S>, client: Client<C, Body>) -> Self
92    where
93        S: Into<String> + Clone,
94    {
95        let path = path.into();
96        let proxies = targets
97            .into_iter()
98            .map(|t| ReverseProxy::new_with_client(path.clone(), t.into(), client.clone()))
99            .collect();
100
101        Self {
102            path,
103            proxies,
104            counter: Arc::new(AtomicUsize::new(0)),
105        }
106    }
107
108    pub fn path(&self) -> &str {
109        &self.path
110    }
111
112    fn next_proxy(&self) -> Option<ReverseProxy<C>> {
113        if self.proxies.is_empty() {
114            None
115        } else {
116            let idx = self.counter.fetch_add(1, Ordering::Relaxed) % self.proxies.len();
117            Some(self.proxies[idx].clone())
118        }
119    }
120}
121
122use std::{
123    future::Future,
124    pin::Pin,
125    task::{Context, Poll},
126};
127use tower::Service;
128
129impl<C> Service<axum::http::Request<Body>> for BalancedProxy<C>
130where
131    C: Connect + Clone + Send + Sync + 'static,
132{
133    type Response = axum::http::Response<Body>;
134    type Error = Infallible;
135    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
136
137    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138        Poll::Ready(Ok(()))
139    }
140
141    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
142        if let Some(mut proxy) = self.next_proxy() {
143            trace!("balanced proxying via upstream {}", proxy.target());
144            Box::pin(async move { proxy.call(req).await })
145        } else {
146            warn!("No upstream services available");
147            Box::pin(async move {
148                Ok(axum::http::Response::builder()
149                    .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
150                    .body(Body::from("No upstream services available"))
151                    .unwrap())
152            })
153        }
154    }
155}
156
/// A balanced proxy that supports dynamic service discovery.
///
/// This proxy uses the tower::discover trait to dynamically add and remove
/// upstream services. Services are load-balanced using a configurable strategy.
///
/// Features:
/// - High-performance request handling with minimal overhead
/// - Atomic service updates that don't block ongoing requests
/// - Efficient round-robin load balancing
/// - Zero-downtime service discovery changes
#[derive(Clone)]
pub struct DiscoverableBalancedProxy<C, D>
where
    C: Connect + Clone + Send + Sync + 'static,
    D: Discover + Clone + Send + Sync + 'static,
    D::Service: Into<String> + Send,
    D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
    D::Error: std::fmt::Debug + Send,
{
    // Base path handled by every per-service proxy.
    path: String,
    // Shared hyper client cloned into each discovered `ReverseProxy`.
    client: Client<C, Body>,
    // Copy-on-write snapshot of active proxies: readers cheaply clone the
    // inner `Arc`; discovery builds a new `Vec` and swaps it in atomically.
    proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
    proxy_keys: Arc<tokio::sync::RwLock<HashMap<D::Key, usize>>>, // key -> index mapping
    // Monotonic counter used for round-robin selection.
    counter: Arc<AtomicUsize>,
    // Discovery source polled by `start_discovery`.
    discover: D,
    strategy: LoadBalancingStrategy,
    // Custom P2C balancer for strategies that need it
    p2c_balancer: Option<Arc<CustomP2cBalancer<C>>>,
}
186
/// `DiscoverableBalancedProxy` specialized for the rustls-backed connector.
#[cfg(all(feature = "tls", not(feature = "native-tls")))]
pub type StandardDiscoverableBalancedProxy<D> =
    DiscoverableBalancedProxy<HttpsConnector<HttpConnector>, D>;
/// `DiscoverableBalancedProxy` specialized for the native-tls-backed connector.
#[cfg(feature = "native-tls")]
pub type StandardDiscoverableBalancedProxy<D> =
    DiscoverableBalancedProxy<NativeTlsHttpsConnector<HttpConnector>, D>;
/// `DiscoverableBalancedProxy` specialized for plain HTTP (no TLS feature).
#[cfg(all(not(feature = "tls"), not(feature = "native-tls")))]
pub type StandardDiscoverableBalancedProxy<D> = DiscoverableBalancedProxy<HttpConnector, D>;
195
196impl<C, D> DiscoverableBalancedProxy<C, D>
197where
198    C: Connect + Clone + Send + Sync + 'static,
199    D: Discover + Clone + Send + Sync + 'static,
200    D::Service: Into<String> + Send,
201    D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
202    D::Error: std::fmt::Debug + Send,
203{
204    /// Creates a new discoverable balanced proxy with a custom client and discover implementation.
205    /// Uses round-robin load balancing by default.
206    pub fn new_with_client<S>(path: S, client: Client<C, Body>, discover: D) -> Self
207    where
208        S: Into<String>,
209    {
210        Self::new_with_client_and_strategy(path, client, discover, LoadBalancingStrategy::default())
211    }
212
213    /// Creates a new discoverable balanced proxy with a custom client, discover implementation, and load balancing strategy.
214    pub fn new_with_client_and_strategy<S>(
215        path: S,
216        client: Client<C, Body>,
217        discover: D,
218        strategy: LoadBalancingStrategy,
219    ) -> Self
220    where
221        S: Into<String>,
222    {
223        let path = path.into();
224        let proxies_snapshot = Arc::new(std::sync::RwLock::new(Arc::new(Vec::new())));
225
226        // Create P2C balancer if needed
227        let p2c_balancer = match strategy {
228            LoadBalancingStrategy::P2cPendingRequests | LoadBalancingStrategy::P2cPeakEwma => {
229                Some(Arc::new(CustomP2cBalancer::new(
230                    strategy,
231                    Arc::clone(&proxies_snapshot),
232                )))
233            }
234            LoadBalancingStrategy::RoundRobin => None,
235        };
236
237        Self {
238            path,
239            client,
240            proxies_snapshot,
241            proxy_keys: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
242            counter: Arc::new(AtomicUsize::new(0)),
243            discover: discover.clone(),
244            strategy,
245            p2c_balancer,
246        }
247    }
248
249    /// Get the base path this proxy is configured to handle
250    pub fn path(&self) -> &str {
251        &self.path
252    }
253
254    /// Get the load balancing strategy being used
255    pub fn strategy(&self) -> LoadBalancingStrategy {
256        self.strategy
257    }
258
259    /// Start the discovery process in the background.
260    /// This should be called once to begin monitoring for service changes.
261    pub async fn start_discovery(&mut self) {
262        let discover = self.discover.clone();
263        let proxies_snapshot = Arc::clone(&self.proxies_snapshot);
264        let proxy_keys = Arc::clone(&self.proxy_keys);
265        let client = self.client.clone();
266        let path = self.path.clone();
267
268        tokio::spawn(async move {
269            use futures_util::future::poll_fn;
270
271            let mut discover = Box::pin(discover);
272
273            loop {
274                let change_result =
275                    poll_fn(|cx: &mut Context<'_>| discover.as_mut().poll_discover(cx)).await;
276
277                match change_result {
278                    Some(Ok(change)) => match change {
279                        Change::Insert(key, service) => {
280                            let target: String = service.into();
281                            debug!("Discovered new service: {:?} -> {}", key, target);
282
283                            let proxy =
284                                ReverseProxy::new_with_client(path.clone(), target, client.clone());
285
286                            {
287                                let mut keys_guard = proxy_keys.write().await;
288
289                                // Get current snapshot and create new one with added service
290                                let current_snapshot = {
291                                    let snapshot_guard = proxies_snapshot.read().unwrap();
292                                    Arc::clone(&*snapshot_guard)
293                                };
294
295                                let mut new_proxies = (*current_snapshot).clone();
296                                let index = new_proxies.len();
297                                new_proxies.push(proxy);
298                                keys_guard.insert(key, index);
299
300                                // Atomically update the snapshot
301                                {
302                                    let mut snapshot_guard = proxies_snapshot.write().unwrap();
303                                    *snapshot_guard = Arc::new(new_proxies);
304                                }
305                            }
306                        }
307                        Change::Remove(key) => {
308                            debug!("Removing service: {:?}", key);
309
310                            {
311                                let mut keys_guard = proxy_keys.write().await;
312
313                                if let Some(index) = keys_guard.remove(&key) {
314                                    // Get current snapshot and create new one with removed service
315                                    let current_snapshot = {
316                                        let snapshot_guard = proxies_snapshot.read().unwrap();
317                                        Arc::clone(&*snapshot_guard)
318                                    };
319
320                                    let mut new_proxies = (*current_snapshot).clone();
321                                    new_proxies.remove(index);
322
323                                    // Update indices for all keys after the removed index
324                                    for (_, idx) in keys_guard.iter_mut() {
325                                        if *idx > index {
326                                            *idx -= 1;
327                                        }
328                                    }
329
330                                    // Atomically update the snapshot
331                                    {
332                                        let mut snapshot_guard = proxies_snapshot.write().unwrap();
333                                        *snapshot_guard = Arc::new(new_proxies);
334                                    }
335                                }
336                            }
337                        }
338                    },
339                    Some(Err(e)) => {
340                        error!("Discovery error: {:?}", e);
341                    }
342                    None => {
343                        warn!("Discovery stream ended");
344                        break;
345                    }
346                }
347            }
348        });
349    }
350
351    /// Get the current number of discovered services
352    pub async fn service_count(&self) -> usize {
353        let snapshot = {
354            let guard = self.proxies_snapshot.read().unwrap();
355            Arc::clone(&*guard)
356        };
357        snapshot.len()
358    }
359}
360
361impl<C, D> Service<axum::http::Request<Body>> for DiscoverableBalancedProxy<C, D>
362where
363    C: Connect + Clone + Send + Sync + 'static,
364    D: Discover + Clone + Send + Sync + 'static,
365    D::Service: Into<String> + Send,
366    D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
367    D::Error: std::fmt::Debug + Send,
368{
369    type Response = axum::http::Response<Body>;
370    type Error = Infallible;
371    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
372
373    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
374        Poll::Ready(Ok(()))
375    }
376
377    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
378        // Get current proxy snapshot
379        let proxies_snapshot = {
380            let guard = self.proxies_snapshot.read().unwrap();
381            Arc::clone(&*guard)
382        };
383        let counter = Arc::clone(&self.counter);
384        let strategy = self.strategy;
385        let p2c_balancer = self.p2c_balancer.clone();
386
387        Box::pin(async move {
388            match strategy {
389                LoadBalancingStrategy::RoundRobin => {
390                    // Round-robin load balancing
391                    if proxies_snapshot.is_empty() {
392                        warn!("No upstream services available");
393                        Ok(axum::http::Response::builder()
394                            .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
395                            .body(Body::from("No upstream services available"))
396                            .unwrap())
397                    } else {
398                        let idx = counter.fetch_add(1, Ordering::Relaxed) % proxies_snapshot.len();
399                        let mut proxy = proxies_snapshot[idx].clone();
400                        proxy.call(req).await
401                    }
402                }
403                LoadBalancingStrategy::P2cPendingRequests | LoadBalancingStrategy::P2cPeakEwma => {
404                    // Use the custom P2C balancer
405                    if let Some(balancer) = p2c_balancer {
406                        balancer.call_with_p2c(req).await
407                    } else {
408                        // Fallback to error if balancer is not available
409                        error!("P2C balancer not available for strategy {:?}", strategy);
410                        Ok(axum::http::Response::builder()
411                            .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
412                            .body(Body::from("P2C balancer not available"))
413                            .unwrap())
414                    }
415                }
416            }
417        })
418    }
419}
420
/// Custom P2C load balancer that uses atomic operations for low-contention metrics tracking
struct CustomP2cBalancer<C: Connect + Clone + Send + Sync + 'static> {
    // Which P2C load metric to use (pending requests or peak EWMA);
    // never `RoundRobin` — constructors only build this for P2C strategies.
    strategy: LoadBalancingStrategy,
    // Shared with the owning `DiscoverableBalancedProxy`; read-only here.
    proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
    /// Metrics for each service, indexed by position in proxies_snapshot
    /// We use Arc<Vec<Arc<ServiceMetrics>>> to allow concurrent access with minimal locking
    metrics: Arc<std::sync::RwLock<Arc<Vec<Arc<ServiceMetrics>>>>>,
}
429
430impl<C: Connect + Clone + Send + Sync + 'static> CustomP2cBalancer<C> {
431    fn new(
432        strategy: LoadBalancingStrategy,
433        proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
434    ) -> Self {
435        let initial_metrics = Arc::new(Vec::new());
436        Self {
437            strategy,
438            proxies_snapshot,
439            metrics: Arc::new(std::sync::RwLock::new(initial_metrics)),
440        }
441    }
442
443    async fn call_with_p2c(
444        &self,
445        req: axum::http::Request<Body>,
446    ) -> Result<axum::http::Response<Body>, Infallible> {
447        // Get current proxy snapshot
448        let proxies = {
449            let guard = self.proxies_snapshot.read().unwrap();
450            Arc::clone(&*guard)
451        };
452
453        if proxies.is_empty() {
454            return Ok(axum::http::Response::builder()
455                .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
456                .body(Body::from("No upstream services available"))
457                .unwrap());
458        }
459
460        // Ensure metrics vector is up to date
461        self.ensure_metrics_size(proxies.len());
462
463        // Get metrics snapshot
464        let metrics = {
465            let guard = self.metrics.read().unwrap();
466            Arc::clone(&*guard)
467        };
468
469        // P2C: Pick two random services and choose the one with lower load
470        let selected_idx = if proxies.len() == 1 {
471            0
472        } else {
473            let mut rng = rand::rng();
474            let idx1 = rng.random_range(0..proxies.len());
475            let idx2 = loop {
476                let i = rng.random_range(0..proxies.len());
477                if i != idx1 {
478                    break i;
479                }
480            };
481
482            // Compare load metrics based on strategy
483            let load1 = self.get_load(&metrics[idx1]);
484            let load2 = self.get_load(&metrics[idx2]);
485
486            if load1 <= load2 { idx1 } else { idx2 }
487        };
488
489        // Track request start for pending requests
490        let request_guard = if matches!(self.strategy, LoadBalancingStrategy::P2cPendingRequests) {
491            metrics[selected_idx]
492                .pending_requests
493                .fetch_add(1, Ordering::Relaxed);
494            Some(PendingRequestGuard {
495                metrics: Arc::clone(&metrics[selected_idx]),
496            })
497        } else {
498            None
499        };
500
501        // Record start time for latency tracking
502        let start = Instant::now();
503
504        // Make the actual request
505        let mut proxy = proxies[selected_idx].clone();
506        let result = proxy.call(req).await;
507
508        // Update latency metrics for EWMA strategy
509        if matches!(self.strategy, LoadBalancingStrategy::P2cPeakEwma) {
510            let latency = start.elapsed();
511            self.update_ewma(&metrics[selected_idx], latency);
512        }
513
514        // Request guard will decrement pending count when dropped
515        drop(request_guard);
516
517        result
518    }
519
520    fn ensure_metrics_size(&self, size: usize) {
521        let mut metrics_guard = self.metrics.write().unwrap();
522        let current_metrics = Arc::clone(&*metrics_guard);
523
524        if current_metrics.len() != size {
525            let mut new_metrics = Vec::with_capacity(size);
526
527            // Copy existing metrics
528            for (i, metric) in current_metrics.iter().enumerate() {
529                if i < size {
530                    new_metrics.push(Arc::clone(metric));
531                }
532            }
533
534            // Add new metrics if needed
535            while new_metrics.len() < size {
536                new_metrics.push(Arc::new(ServiceMetrics::new()));
537            }
538
539            *metrics_guard = Arc::new(new_metrics);
540        }
541    }
542
543    fn get_load(&self, metrics: &ServiceMetrics) -> u64 {
544        match self.strategy {
545            LoadBalancingStrategy::P2cPendingRequests => {
546                metrics.pending_requests.load(Ordering::Relaxed) as u64
547            }
548            LoadBalancingStrategy::P2cPeakEwma => {
549                // Apply decay based on time since last update
550                let last_update = *metrics.last_update.lock().unwrap();
551                let elapsed = last_update.elapsed();
552
553                // Simple exponential decay: reduce by ~50% every 5 seconds
554                let current = metrics.peak_ewma_micros.load(Ordering::Relaxed);
555                let decay_factor = (-elapsed.as_secs_f64() / 5.0).exp();
556                (current as f64 * decay_factor) as u64
557            }
558            _ => unreachable!("CustomP2cBalancer should only be used with P2C strategies"),
559        }
560    }
561
562    fn update_ewma(&self, metrics: &ServiceMetrics, latency: Duration) {
563        let latency_micros = latency.as_micros() as u64;
564
565        // Update with exponential weighted moving average
566        // Using compare-and-swap loop for lock-free update
567        loop {
568            let current = metrics.peak_ewma_micros.load(Ordering::Relaxed);
569
570            // If this is the first measurement, just set it
571            if current == 0 {
572                if metrics
573                    .peak_ewma_micros
574                    .compare_exchange(0, latency_micros, Ordering::Relaxed, Ordering::Relaxed)
575                    .is_ok()
576                {
577                    *metrics.last_update.lock().unwrap() = Instant::now();
578                    break;
579                }
580                continue;
581            }
582
583            // Apply decay based on time since last update
584            let mut last_update_guard = metrics.last_update.lock().unwrap();
585            let elapsed = last_update_guard.elapsed();
586
587            // Decay factor: reduce by ~50% every 5 seconds
588            let decay_factor = (-elapsed.as_secs_f64() / 5.0).exp();
589            let decayed_current = (current as f64 * decay_factor) as u64;
590
591            // Peak EWMA: take the maximum of the decayed value and the new measurement
592            let peak = decayed_current.max(latency_micros);
593
594            // EWMA with alpha = 0.25 (25% new value, 75% old value)
595            // This gives more weight to recent measurements
596            let ewma = ((peak as f64 * 0.25) + (decayed_current as f64 * 0.75)) as u64;
597
598            if metrics
599                .peak_ewma_micros
600                .compare_exchange(current, ewma, Ordering::Relaxed, Ordering::Relaxed)
601                .is_ok()
602            {
603                // Update last update time
604                *last_update_guard = Instant::now();
605                break;
606            }
607            drop(last_update_guard); // Release lock before retrying
608        }
609    }
610}
611
/// RAII guard to decrement pending request count when request completes
struct PendingRequestGuard {
    // Metrics of the upstream whose pending counter was incremented;
    // the matching decrement happens in `Drop`.
    metrics: Arc<ServiceMetrics>,
}
616
617impl Drop for PendingRequestGuard {
618    fn drop(&mut self) {
619        self.metrics
620            .pending_requests
621            .fetch_sub(1, Ordering::Relaxed);
622    }
623}
624
/// Metrics for a single service used in P2C load balancing
#[derive(Debug)]
struct ServiceMetrics {
    /// Number of pending requests (for P2cPendingRequests strategy)
    pending_requests: AtomicUsize,
    /// Peak EWMA latency in microseconds (for P2cPeakEwma strategy)
    peak_ewma_micros: AtomicU64,
    /// Last update time for EWMA decay calculation
    // Mutex (not atomic) because `Instant` has no atomic representation.
    last_update: std::sync::Mutex<Instant>,
}
635
636impl ServiceMetrics {
637    fn new() -> Self {
638        Self {
639            pending_requests: AtomicUsize::new(0),
640            peak_ewma_micros: AtomicU64::new(0),
641            last_update: std::sync::Mutex::new(Instant::now()),
642        }
643    }
644}