pub mod config;
pub mod error;
pub mod events;
pub mod selection;
pub use config::WeightedRouterBuilder;
pub use error::WeightedRouterError;
pub use events::RouterEvent;
pub use selection::SelectionStrategy;
use config::RouterConfig;
use selection::WeightedSelector;
use std::task::{Context, Poll};
use tower_service::Service;
/// Routes each request to one of several backend services, chosen by weight.
///
/// Backends are stored with their `u32` weights in registration order. For every
/// `call`, an internal selector (built from those weights and the strategy held in
/// the router config) yields the index of the backend that receives the request.
///
/// As a tower `Service`, the router reports ready only once every backend is ready.
pub struct WeightedRouter<S> {
    // Backend services paired with their routing weights, in registration order.
    backends: Vec<(S, u32)>,
    // Yields, per call, the index into `backends` that should handle the request.
    selector: WeightedSelector,
    // Router name, selection strategy and event listeners used when routing.
    config: RouterConfig,
}
impl<S> WeightedRouter<S> {
    /// Starts a `WeightedRouterBuilder` for assembling a router.
    pub fn builder() -> WeightedRouterBuilder<S> {
        WeightedRouterBuilder::new()
    }

    /// Assembles a router from `(backend, weight)` pairs and a finished config.
    ///
    /// The weights are extracted in backend order and passed, together with
    /// `config.strategy`, to `WeightedSelector::new`.
    pub(crate) fn new(backends: Vec<(S, u32)>, config: RouterConfig) -> Self {
        let selector = WeightedSelector::new(&Self::weights_of(&backends), config.strategy);
        Self {
            backends,
            selector,
            config,
        }
    }

    /// Number of registered backends.
    pub fn backend_count(&self) -> usize {
        self.backends.len()
    }

    /// The configured weights, one per backend, in registration order.
    pub fn weights(&self) -> Vec<u32> {
        Self::weights_of(&self.backends)
    }

    /// The router's name, as stored in its configuration.
    pub fn name(&self) -> &str {
        &self.config.name
    }

    /// Copies the weight out of each `(backend, weight)` pair, preserving order.
    fn weights_of(backends: &[(S, u32)]) -> Vec<u32> {
        backends.iter().map(|&(_, weight)| weight).collect()
    }
}
impl<S: Clone> Clone for WeightedRouter<S> {
    /// Produces an independent copy of the router.
    ///
    /// The backend list, the selector and the configuration are each cloned
    /// separately, in that order.
    fn clone(&self) -> Self {
        let Self {
            backends,
            selector,
            config,
        } = self;
        Self {
            backends: backends.clone(),
            selector: selector.clone(),
            config: config.clone(),
        }
    }
}
impl<S, Request> Service<Request> for WeightedRouter<S>
where
    S: Service<Request>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    /// Reports ready only when every backend reports ready.
    ///
    /// Backends are polled in registration order. The first backend that
    /// returns `Poll::Pending` ends this pass; the remaining backends are not
    /// polled until the next call. The first backend error is returned as-is.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        for (backend, _weight) in self.backends.iter_mut() {
            if backend.poll_ready(cx)?.is_pending() {
                return Poll::Pending;
            }
        }
        Poll::Ready(Ok(()))
    }

    /// Asks the selector for a backend index and forwards `request` to that backend.
    ///
    /// The routing decision is recorded before dispatch, in this order:
    /// 1. a `router_requests_routed_total` counter increment (`metrics` feature);
    /// 2. a debug trace (`tracing` feature);
    /// 3. a `RouterEvent::RequestRouted` emitted through `config.event_listeners`.
    ///
    /// The backend's own future is returned unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the selector yields an index outside the backend list.
    fn call(&mut self, request: Request) -> Self::Future {
        let chosen = self.selector.select();
        let slot = &mut self.backends[chosen];
        let backend_weight = slot.1;

        #[cfg(feature = "metrics")]
        {
            let labels = [
                ("router", self.config.name.clone()),
                ("backend", chosen.to_string()),
            ];
            metrics::counter!("router_requests_routed_total", &labels).increment(1);
        }

        #[cfg(feature = "tracing")]
        {
            tracing::debug!(
                router = %self.config.name,
                backend_index = chosen,
                backend_weight = backend_weight,
                "routing request to backend"
            );
        }

        self.config.event_listeners.emit(&RouterEvent::RequestRouted {
            pattern_name: self.config.name.clone(),
            timestamp: std::time::Instant::now(),
            backend_index: chosen,
            backend_weight,
        });

        slot.0.call(request)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tower::util::BoxService;
    use tower::ServiceExt;

    /// Boxed backend shared by every test: takes `()` and answers with a static label.
    type BoxSvc = BoxService<(), &'static str, TestError>;

    /// Minimal error type so the test backends have a concrete `Error`.
    #[derive(Clone, Debug)]
    struct TestError;

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("test error")
        }
    }

    impl std::error::Error for TestError {}

    /// A fresh hit counter starting at zero.
    fn new_counter() -> Arc<AtomicUsize> {
        Arc::new(AtomicUsize::new(0))
    }

    /// Backend that bumps `counter` on every request and answers with `label`.
    fn counting_svc(counter: Arc<AtomicUsize>, label: &'static str) -> BoxSvc {
        let handler = move |_: ()| {
            let hits = Arc::clone(&counter);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
                Ok::<_, TestError>(label)
            }
        };
        BoxService::new(tower::service_fn(handler))
    }

    /// Backend that always answers `"ok"` and counts nothing.
    fn ok_svc() -> BoxSvc {
        BoxService::new(tower::service_fn(|_: ()| async { Ok::<_, TestError>("ok") }))
    }

    /// Sends `times` requests through `router`, awaiting readiness before each
    /// one and discarding the responses.
    async fn dispatch(router: &mut WeightedRouter<BoxSvc>, times: usize) {
        for _ in 0..times {
            let _ = router.ready().await.unwrap().call(()).await;
        }
    }

    #[tokio::test]
    async fn routes_by_weight_deterministic() {
        let (a, b) = (new_counter(), new_counter());
        let mut router = WeightedRouter::builder()
            .route(counting_svc(Arc::clone(&a), "a"), 80)
            .route(counting_svc(Arc::clone(&b), "b"), 20)
            .build();

        dispatch(&mut router, 100).await;

        assert_eq!(a.load(Ordering::SeqCst), 80);
        assert_eq!(b.load(Ordering::SeqCst), 20);
    }

    #[tokio::test]
    async fn single_backend_gets_all_traffic() {
        let hits = new_counter();
        let mut router = WeightedRouter::builder()
            .route(counting_svc(Arc::clone(&hits), "ok"), 1)
            .build();

        dispatch(&mut router, 50).await;

        assert_eq!(hits.load(Ordering::SeqCst), 50);
    }

    #[tokio::test]
    async fn three_backends() {
        let counters: Vec<Arc<AtomicUsize>> = (0..3).map(|_| new_counter()).collect();
        let mut router = WeightedRouter::builder()
            .route(counting_svc(Arc::clone(&counters[0]), "0"), 50)
            .route(counting_svc(Arc::clone(&counters[1]), "1"), 30)
            .route(counting_svc(Arc::clone(&counters[2]), "2"), 20)
            .build();

        dispatch(&mut router, 100).await;

        let observed: Vec<usize> = counters
            .iter()
            .map(|counter| counter.load(Ordering::SeqCst))
            .collect();
        assert_eq!(observed, vec![50, 30, 20]);
    }

    #[tokio::test]
    async fn error_propagates_from_backend() {
        let failing: BoxSvc =
            BoxService::new(tower::service_fn(|_: ()| async { Err::<&str, _>(TestError) }));
        let mut router = WeightedRouter::builder().route(failing, 1).build();

        let outcome = router.ready().await.unwrap().call(()).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn event_listener_fires() {
        let fired = new_counter();
        let fired_in_listener = Arc::clone(&fired);
        let mut router = WeightedRouter::builder()
            .route(ok_svc(), 1)
            .on_request_routed(move |_idx, _weight| {
                fired_in_listener.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        dispatch(&mut router, 5).await;

        assert_eq!(fired.load(Ordering::SeqCst), 5);
    }

    #[tokio::test]
    async fn builder_accessors() {
        let router = WeightedRouter::builder()
            .name("canary")
            .route(counting_svc(new_counter(), "a"), 90)
            .route(counting_svc(new_counter(), "b"), 10)
            .build();

        assert_eq!(router.backend_count(), 2);
        assert_eq!(router.weights(), vec![90, 10]);
        assert_eq!(router.name(), "canary");
    }

    #[test]
    #[should_panic(expected = "at least one backend is required")]
    fn panics_on_no_backends() {
        let _router: WeightedRouter<BoxSvc> = WeightedRouter::builder().build();
    }

    #[test]
    #[should_panic(expected = "weight 0")]
    fn panics_on_zero_weight() {
        let _router = WeightedRouter::builder().route(ok_svc(), 0).build();
    }

    #[tokio::test]
    async fn random_strategy_converges() {
        const TOTAL: usize = 10_000;
        let (a, b) = (new_counter(), new_counter());
        let mut router = WeightedRouter::builder()
            .route(counting_svc(Arc::clone(&a), "a"), 80)
            .route(counting_svc(Arc::clone(&b), "b"), 20)
            .random()
            .build();

        dispatch(&mut router, TOTAL).await;

        let ratio = a.load(Ordering::SeqCst) as f64 / TOTAL as f64;
        assert!(
            (0.75..=0.85).contains(&ratio),
            "expected ~80%, got {:.1}%",
            ratio * 100.0
        );
    }
}