//! axum_reverse_proxy/balanced_proxy.rs — load-balanced reverse proxy implementations.

1use axum::body::Body;
2use hyper_util::client::legacy::{Client, connect::Connect};
3use std::collections::HashMap;
4use std::convert::Infallible;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
7use std::time::{Duration, Instant};
8
9use tower::discover::{Change, Discover};
10use tracing::{debug, error, trace, warn};
11
12use crate::forward::{ProxyConnector, create_http_connector};
13use crate::proxy::ReverseProxy;
14
15// For custom P2C implementation
16use rand::Rng;
17
/// Load balancing strategy for distributing requests across discovered services
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalancingStrategy {
    /// Simple round-robin distribution (default)
    #[default]
    RoundRobin,
    /// Power of Two Choices with pending request count as load metric:
    /// two upstreams are sampled at random and the one with fewer
    /// in-flight requests receives the request.
    P2cPendingRequests,
    /// Power of Two Choices with peak EWMA latency as load metric:
    /// two upstreams are sampled at random and the one with the lower
    /// decayed peak latency receives the request.
    P2cPeakEwma,
}
29
/// A reverse proxy that distributes requests across a fixed set of upstream
/// targets using round-robin selection.
#[derive(Clone)]
pub struct BalancedProxy<C: Connect + Clone + Send + Sync + 'static> {
    // Base path this proxy handles; shared by every upstream `ReverseProxy`.
    path: String,
    // One pre-built proxy per configured upstream target.
    proxies: Vec<ReverseProxy<C>>,
    // Monotonically increasing request counter; taken modulo the proxy count
    // to pick the next upstream. Shared across clones of this proxy.
    counter: Arc<AtomicUsize>,
}

/// `BalancedProxy` specialized to the crate's default HTTP connector.
pub type StandardBalancedProxy = BalancedProxy<ProxyConnector>;
38
39impl StandardBalancedProxy {
40    pub fn new<S>(path: S, targets: Vec<S>) -> Self
41    where
42        S: Into<String> + Clone,
43    {
44        let client = Client::builder(hyper_util::rt::TokioExecutor::new())
45            .pool_idle_timeout(std::time::Duration::from_secs(60))
46            .pool_max_idle_per_host(32)
47            .retry_canceled_requests(true)
48            .set_host(true)
49            .build(create_http_connector());
50
51        Self::new_with_client(path, targets, client)
52    }
53}
54
55impl<C> BalancedProxy<C>
56where
57    C: Connect + Clone + Send + Sync + 'static,
58{
59    pub fn new_with_client<S>(path: S, targets: Vec<S>, client: Client<C, Body>) -> Self
60    where
61        S: Into<String> + Clone,
62    {
63        let path = path.into();
64        let proxies = targets
65            .into_iter()
66            .map(|t| ReverseProxy::new_with_client(path.clone(), t.into(), client.clone()))
67            .collect();
68
69        Self {
70            path,
71            proxies,
72            counter: Arc::new(AtomicUsize::new(0)),
73        }
74    }
75
76    pub fn path(&self) -> &str {
77        &self.path
78    }
79
80    fn next_proxy(&self) -> Option<ReverseProxy<C>> {
81        if self.proxies.is_empty() {
82            None
83        } else {
84            let idx = self.counter.fetch_add(1, Ordering::Relaxed) % self.proxies.len();
85            Some(self.proxies[idx].clone())
86        }
87    }
88}
89
90use std::{
91    future::Future,
92    pin::Pin,
93    task::{Context, Poll},
94};
95use tower::Service;
96
97impl<C> Service<axum::http::Request<Body>> for BalancedProxy<C>
98where
99    C: Connect + Clone + Send + Sync + 'static,
100{
101    type Response = axum::http::Response<Body>;
102    type Error = Infallible;
103    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
104
105    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106        Poll::Ready(Ok(()))
107    }
108
109    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
110        if let Some(mut proxy) = self.next_proxy() {
111            trace!("balanced proxying via upstream {}", proxy.target());
112            Box::pin(async move { proxy.call(req).await })
113        } else {
114            warn!("No upstream services available");
115            Box::pin(async move {
116                Ok(axum::http::Response::builder()
117                    .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
118                    .body(Body::from("No upstream services available"))
119                    .unwrap())
120            })
121        }
122    }
123}
124
/// A balanced proxy that supports dynamic service discovery.
///
/// This proxy uses the tower::discover trait to dynamically add and remove
/// upstream services. Services are load-balanced using a configurable strategy.
///
/// Features:
/// - High-performance request handling with minimal overhead
/// - Atomic service updates that don't block ongoing requests
/// - Efficient round-robin load balancing
/// - Zero-downtime service discovery changes
#[derive(Clone)]
pub struct DiscoverableBalancedProxy<C, D>
where
    C: Connect + Clone + Send + Sync + 'static,
    D: Discover + Clone + Send + Sync + 'static,
    D::Service: Into<String> + Send,
    D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
    D::Error: std::fmt::Debug + Send,
{
    // Base path handled by every dynamically created upstream proxy.
    path: String,
    // Client shared by all `ReverseProxy` instances built during discovery.
    client: Client<C, Body>,
    // Snapshot-swap pattern: readers cheaply clone the inner Arc, while the
    // discovery task builds a new Vec and swaps it in under the write lock.
    proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
    proxy_keys: Arc<tokio::sync::RwLock<HashMap<D::Key, usize>>>, // key -> index mapping
    // Round-robin counter (only consulted by the RoundRobin strategy).
    counter: Arc<AtomicUsize>,
    // Discovery source; a clone of it is polled by the background task
    // spawned in `start_discovery`.
    discover: D,
    strategy: LoadBalancingStrategy,
    // Custom P2C balancer for strategies that need it
    p2c_balancer: Option<Arc<CustomP2cBalancer<C>>>,
}

/// `DiscoverableBalancedProxy` specialized to the crate's default HTTP connector.
pub type StandardDiscoverableBalancedProxy<D> = DiscoverableBalancedProxy<ProxyConnector, D>;
156
157impl<C, D> DiscoverableBalancedProxy<C, D>
158where
159    C: Connect + Clone + Send + Sync + 'static,
160    D: Discover + Clone + Send + Sync + 'static,
161    D::Service: Into<String> + Send,
162    D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
163    D::Error: std::fmt::Debug + Send,
164{
165    /// Creates a new discoverable balanced proxy with a custom client and discover implementation.
166    /// Uses round-robin load balancing by default.
167    pub fn new_with_client<S>(path: S, client: Client<C, Body>, discover: D) -> Self
168    where
169        S: Into<String>,
170    {
171        Self::new_with_client_and_strategy(path, client, discover, LoadBalancingStrategy::default())
172    }
173
174    /// Creates a new discoverable balanced proxy with a custom client, discover implementation, and load balancing strategy.
175    pub fn new_with_client_and_strategy<S>(
176        path: S,
177        client: Client<C, Body>,
178        discover: D,
179        strategy: LoadBalancingStrategy,
180    ) -> Self
181    where
182        S: Into<String>,
183    {
184        let path = path.into();
185        let proxies_snapshot = Arc::new(std::sync::RwLock::new(Arc::new(Vec::new())));
186
187        // Create P2C balancer if needed
188        let p2c_balancer = match strategy {
189            LoadBalancingStrategy::P2cPendingRequests | LoadBalancingStrategy::P2cPeakEwma => {
190                Some(Arc::new(CustomP2cBalancer::new(
191                    strategy,
192                    Arc::clone(&proxies_snapshot),
193                )))
194            }
195            LoadBalancingStrategy::RoundRobin => None,
196        };
197
198        Self {
199            path,
200            client,
201            proxies_snapshot,
202            proxy_keys: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
203            counter: Arc::new(AtomicUsize::new(0)),
204            discover: discover.clone(),
205            strategy,
206            p2c_balancer,
207        }
208    }
209
210    /// Get the base path this proxy is configured to handle
211    pub fn path(&self) -> &str {
212        &self.path
213    }
214
215    /// Get the load balancing strategy being used
216    pub fn strategy(&self) -> LoadBalancingStrategy {
217        self.strategy
218    }
219
220    /// Start the discovery process in the background.
221    /// This should be called once to begin monitoring for service changes.
222    pub async fn start_discovery(&mut self) {
223        let discover = self.discover.clone();
224        let proxies_snapshot = Arc::clone(&self.proxies_snapshot);
225        let proxy_keys = Arc::clone(&self.proxy_keys);
226        let client = self.client.clone();
227        let path = self.path.clone();
228
229        tokio::spawn(async move {
230            use futures_util::future::poll_fn;
231
232            let mut discover = Box::pin(discover);
233
234            loop {
235                let change_result =
236                    poll_fn(|cx: &mut Context<'_>| discover.as_mut().poll_discover(cx)).await;
237
238                match change_result {
239                    Some(Ok(change)) => match change {
240                        Change::Insert(key, service) => {
241                            let target: String = service.into();
242                            debug!("Discovered new service: {:?} -> {}", key, target);
243
244                            let proxy =
245                                ReverseProxy::new_with_client(path.clone(), target, client.clone());
246
247                            {
248                                let mut keys_guard = proxy_keys.write().await;
249
250                                // Get current snapshot and create new one with added service
251                                let current_snapshot = {
252                                    let snapshot_guard = proxies_snapshot.read().unwrap();
253                                    Arc::clone(&*snapshot_guard)
254                                };
255
256                                let mut new_proxies = (*current_snapshot).clone();
257                                let index = new_proxies.len();
258                                new_proxies.push(proxy);
259                                keys_guard.insert(key, index);
260
261                                // Atomically update the snapshot
262                                {
263                                    let mut snapshot_guard = proxies_snapshot.write().unwrap();
264                                    *snapshot_guard = Arc::new(new_proxies);
265                                }
266                            }
267                        }
268                        Change::Remove(key) => {
269                            debug!("Removing service: {:?}", key);
270
271                            {
272                                let mut keys_guard = proxy_keys.write().await;
273
274                                if let Some(index) = keys_guard.remove(&key) {
275                                    // Get current snapshot and create new one with removed service
276                                    let current_snapshot = {
277                                        let snapshot_guard = proxies_snapshot.read().unwrap();
278                                        Arc::clone(&*snapshot_guard)
279                                    };
280
281                                    let mut new_proxies = (*current_snapshot).clone();
282                                    new_proxies.remove(index);
283
284                                    // Update indices for all keys after the removed index
285                                    for (_, idx) in keys_guard.iter_mut() {
286                                        if *idx > index {
287                                            *idx -= 1;
288                                        }
289                                    }
290
291                                    // Atomically update the snapshot
292                                    {
293                                        let mut snapshot_guard = proxies_snapshot.write().unwrap();
294                                        *snapshot_guard = Arc::new(new_proxies);
295                                    }
296                                }
297                            }
298                        }
299                    },
300                    Some(Err(e)) => {
301                        error!("Discovery error: {:?}", e);
302                    }
303                    None => {
304                        warn!("Discovery stream ended");
305                        break;
306                    }
307                }
308            }
309        });
310    }
311
312    /// Get the current number of discovered services
313    pub async fn service_count(&self) -> usize {
314        let snapshot = {
315            let guard = self.proxies_snapshot.read().unwrap();
316            Arc::clone(&*guard)
317        };
318        snapshot.len()
319    }
320}
321
322impl<C, D> Service<axum::http::Request<Body>> for DiscoverableBalancedProxy<C, D>
323where
324    C: Connect + Clone + Send + Sync + 'static,
325    D: Discover + Clone + Send + Sync + 'static,
326    D::Service: Into<String> + Send,
327    D::Key: Clone + std::fmt::Debug + Send + Sync + std::hash::Hash,
328    D::Error: std::fmt::Debug + Send,
329{
330    type Response = axum::http::Response<Body>;
331    type Error = Infallible;
332    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
333
334    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
335        Poll::Ready(Ok(()))
336    }
337
338    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
339        // Get current proxy snapshot
340        let proxies_snapshot = {
341            let guard = self.proxies_snapshot.read().unwrap();
342            Arc::clone(&*guard)
343        };
344        let counter = Arc::clone(&self.counter);
345        let strategy = self.strategy;
346        let p2c_balancer = self.p2c_balancer.clone();
347
348        Box::pin(async move {
349            match strategy {
350                LoadBalancingStrategy::RoundRobin => {
351                    // Round-robin load balancing
352                    if proxies_snapshot.is_empty() {
353                        warn!("No upstream services available");
354                        Ok(axum::http::Response::builder()
355                            .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
356                            .body(Body::from("No upstream services available"))
357                            .unwrap())
358                    } else {
359                        let idx = counter.fetch_add(1, Ordering::Relaxed) % proxies_snapshot.len();
360                        let mut proxy = proxies_snapshot[idx].clone();
361                        proxy.call(req).await
362                    }
363                }
364                LoadBalancingStrategy::P2cPendingRequests | LoadBalancingStrategy::P2cPeakEwma => {
365                    // Use the custom P2C balancer
366                    if let Some(balancer) = p2c_balancer {
367                        balancer.call_with_p2c(req).await
368                    } else {
369                        // Fallback to error if balancer is not available
370                        error!("P2C balancer not available for strategy {:?}", strategy);
371                        Ok(axum::http::Response::builder()
372                            .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
373                            .body(Body::from("P2C balancer not available"))
374                            .unwrap())
375                    }
376                }
377            }
378        })
379    }
380}
381
/// Custom P2C load balancer that uses atomic operations for low-contention metrics tracking
struct CustomP2cBalancer<C: Connect + Clone + Send + Sync + 'static> {
    // Which P2C load metric to consult (pending requests or peak EWMA latency).
    strategy: LoadBalancingStrategy,
    // Shared with the owning `DiscoverableBalancedProxy`: the live upstream set.
    proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
    /// Metrics for each service, indexed by position in proxies_snapshot
    /// We use Arc<Vec<Arc<ServiceMetrics>>> to allow concurrent access with minimal locking
    metrics: Arc<std::sync::RwLock<Arc<Vec<Arc<ServiceMetrics>>>>>,
}
390
391impl<C: Connect + Clone + Send + Sync + 'static> CustomP2cBalancer<C> {
392    fn new(
393        strategy: LoadBalancingStrategy,
394        proxies_snapshot: Arc<std::sync::RwLock<Arc<Vec<ReverseProxy<C>>>>>,
395    ) -> Self {
396        let initial_metrics = Arc::new(Vec::new());
397        Self {
398            strategy,
399            proxies_snapshot,
400            metrics: Arc::new(std::sync::RwLock::new(initial_metrics)),
401        }
402    }
403
404    async fn call_with_p2c(
405        &self,
406        req: axum::http::Request<Body>,
407    ) -> Result<axum::http::Response<Body>, Infallible> {
408        // Get current proxy snapshot
409        let proxies = {
410            let guard = self.proxies_snapshot.read().unwrap();
411            Arc::clone(&*guard)
412        };
413
414        if proxies.is_empty() {
415            return Ok(axum::http::Response::builder()
416                .status(axum::http::StatusCode::SERVICE_UNAVAILABLE)
417                .body(Body::from("No upstream services available"))
418                .unwrap());
419        }
420
421        // Ensure metrics vector is up to date
422        self.ensure_metrics_size(proxies.len());
423
424        // Get metrics snapshot
425        let metrics = {
426            let guard = self.metrics.read().unwrap();
427            Arc::clone(&*guard)
428        };
429
430        // P2C: Pick two random services and choose the one with lower load
431        let selected_idx = if proxies.len() == 1 {
432            0
433        } else {
434            let mut rng = rand::rng();
435            let idx1 = rng.random_range(0..proxies.len());
436            let idx2 = loop {
437                let i = rng.random_range(0..proxies.len());
438                if i != idx1 {
439                    break i;
440                }
441            };
442
443            // Compare load metrics based on strategy
444            let load1 = self.get_load(&metrics[idx1]);
445            let load2 = self.get_load(&metrics[idx2]);
446
447            if load1 <= load2 { idx1 } else { idx2 }
448        };
449
450        // Track request start for pending requests
451        let request_guard = if matches!(self.strategy, LoadBalancingStrategy::P2cPendingRequests) {
452            metrics[selected_idx]
453                .pending_requests
454                .fetch_add(1, Ordering::Relaxed);
455            Some(PendingRequestGuard {
456                metrics: Arc::clone(&metrics[selected_idx]),
457            })
458        } else {
459            None
460        };
461
462        // Record start time for latency tracking
463        let start = Instant::now();
464
465        // Make the actual request
466        let mut proxy = proxies[selected_idx].clone();
467        let result = proxy.call(req).await;
468
469        // Update latency metrics for EWMA strategy
470        if matches!(self.strategy, LoadBalancingStrategy::P2cPeakEwma) {
471            let latency = start.elapsed();
472            self.update_ewma(&metrics[selected_idx], latency);
473        }
474
475        // Request guard will decrement pending count when dropped
476        drop(request_guard);
477
478        result
479    }
480
481    fn ensure_metrics_size(&self, size: usize) {
482        let mut metrics_guard = self.metrics.write().unwrap();
483        let current_metrics = Arc::clone(&*metrics_guard);
484
485        if current_metrics.len() != size {
486            let mut new_metrics = Vec::with_capacity(size);
487
488            // Copy existing metrics
489            for (i, metric) in current_metrics.iter().enumerate() {
490                if i < size {
491                    new_metrics.push(Arc::clone(metric));
492                }
493            }
494
495            // Add new metrics if needed
496            while new_metrics.len() < size {
497                new_metrics.push(Arc::new(ServiceMetrics::new()));
498            }
499
500            *metrics_guard = Arc::new(new_metrics);
501        }
502    }
503
504    fn get_load(&self, metrics: &ServiceMetrics) -> u64 {
505        match self.strategy {
506            LoadBalancingStrategy::P2cPendingRequests => {
507                metrics.pending_requests.load(Ordering::Relaxed) as u64
508            }
509            LoadBalancingStrategy::P2cPeakEwma => {
510                // Apply decay based on time since last update
511                let last_update = *metrics.last_update.lock().unwrap();
512                let elapsed = last_update.elapsed();
513
514                // Simple exponential decay: reduce by ~50% every 5 seconds
515                let current = metrics.peak_ewma_micros.load(Ordering::Relaxed);
516                let decay_factor = (-elapsed.as_secs_f64() / 5.0).exp();
517                (current as f64 * decay_factor) as u64
518            }
519            _ => unreachable!("CustomP2cBalancer should only be used with P2C strategies"),
520        }
521    }
522
523    fn update_ewma(&self, metrics: &ServiceMetrics, latency: Duration) {
524        let latency_micros = latency.as_micros() as u64;
525
526        // Update with exponential weighted moving average
527        // Using compare-and-swap loop for lock-free update
528        loop {
529            let current = metrics.peak_ewma_micros.load(Ordering::Relaxed);
530
531            // If this is the first measurement, just set it
532            if current == 0 {
533                if metrics
534                    .peak_ewma_micros
535                    .compare_exchange(0, latency_micros, Ordering::Relaxed, Ordering::Relaxed)
536                    .is_ok()
537                {
538                    *metrics.last_update.lock().unwrap() = Instant::now();
539                    break;
540                }
541                continue;
542            }
543
544            // Apply decay based on time since last update
545            let mut last_update_guard = metrics.last_update.lock().unwrap();
546            let elapsed = last_update_guard.elapsed();
547
548            // Decay factor: reduce by ~50% every 5 seconds
549            let decay_factor = (-elapsed.as_secs_f64() / 5.0).exp();
550            let decayed_current = (current as f64 * decay_factor) as u64;
551
552            // Peak EWMA: take the maximum of the decayed value and the new measurement
553            let peak = decayed_current.max(latency_micros);
554
555            // EWMA with alpha = 0.25 (25% new value, 75% old value)
556            // This gives more weight to recent measurements
557            let ewma = ((peak as f64 * 0.25) + (decayed_current as f64 * 0.75)) as u64;
558
559            if metrics
560                .peak_ewma_micros
561                .compare_exchange(current, ewma, Ordering::Relaxed, Ordering::Relaxed)
562                .is_ok()
563            {
564                // Update last update time
565                *last_update_guard = Instant::now();
566                break;
567            }
568            drop(last_update_guard); // Release lock before retrying
569        }
570    }
571}
572
/// RAII guard to decrement pending request count when request completes
struct PendingRequestGuard {
    // Metrics entry whose `pending_requests` counter was incremented when the
    // request was dispatched; shared so it outlives snapshot swaps.
    metrics: Arc<ServiceMetrics>,
}

impl Drop for PendingRequestGuard {
    // Runs on every exit path (including panic unwind), keeping the
    // pending-request counter balanced with its matching `fetch_add`.
    fn drop(&mut self) {
        self.metrics
            .pending_requests
            .fetch_sub(1, Ordering::Relaxed);
    }
}
585
/// Metrics for a single service used in P2C load balancing
///
/// Shared via `Arc` between the balancer's metrics vector and any in-flight
/// request guards, so counters remain valid across snapshot swaps.
#[derive(Debug)]
struct ServiceMetrics {
    /// Number of pending requests (for P2cPendingRequests strategy)
    pending_requests: AtomicUsize,
    /// Peak EWMA latency in microseconds (for P2cPeakEwma strategy)
    peak_ewma_micros: AtomicU64,
    /// Last update time for EWMA decay calculation
    last_update: std::sync::Mutex<Instant>,
}
596
597impl ServiceMetrics {
598    fn new() -> Self {
599        Self {
600            pending_requests: AtomicUsize::new(0),
601            peak_ewma_micros: AtomicU64::new(0),
602            last_update: std::sync::Mutex::new(Instant::now()),
603        }
604    }
605}