use http::Request;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tower::{Layer, Service};
use super::{ServiceError, SimBody};
/// Tower [`Layer`] that wraps an inner service in a [`DesLoadBalancer`].
///
/// The layer only carries the selection strategy; each wrapped service becomes the
/// sole backend of a freshly built balancer.
#[derive(Debug, Clone)]
pub struct DesLoadBalancerLayer {
    /// Strategy handed to every balancer this layer builds.
    strategy: DesLoadBalanceStrategy,
}
impl DesLoadBalancerLayer {
pub fn new(strategy: DesLoadBalanceStrategy) -> Self {
Self { strategy }
}
pub fn round_robin() -> Self {
Self::new(DesLoadBalanceStrategy::RoundRobin)
}
pub fn random() -> Self {
Self::new(DesLoadBalanceStrategy::Random)
}
}
/// Wraps `inner` as the single backend of a new [`DesLoadBalancer`].
///
/// With exactly one backend every strategy resolves to index `0`, so all requests reach
/// `inner`. The balancer is built via `DesLoadBalancer::new`, i.e. with the default seed.
/// To balance across several backends, construct a [`DesLoadBalancer`] directly.
impl<S: Clone> Layer<S> for DesLoadBalancerLayer {
    type Service = DesLoadBalancer<S>;

    fn layer(&self, inner: S) -> Self::Service {
        let backends = vec![inner];
        DesLoadBalancer::new(backends, self.strategy.clone())
    }
}
/// Deterministic load balancer that routes each request to one of `services`.
///
/// Selection is split into a side-effect-free "peek" (used while polling for readiness)
/// and a "commit" (performed when a request is actually dispatched), so the choice made
/// in `poll_ready` is the one used by `call`.
pub struct DesLoadBalancer<S> {
    /// Backends, addressed by index.
    services: Vec<S>,
    /// How the next backend is chosen.
    strategy: DesLoadBalanceStrategy,
    /// Round-robin cursor: index of the backend to try first on the next selection.
    current_index: AtomicUsize,
    /// Seeded generator used by the `Random` strategy.
    rng: Arc<Mutex<ChaCha8Rng>>,
    /// Backend reserved by the last successful `poll_ready`, consumed by `call`.
    ready_index: Option<usize>,
}
impl<S: Clone> Clone for DesLoadBalancer<S> {
fn clone(&self) -> Self {
Self {
services: self.services.clone(),
strategy: self.strategy.clone(),
current_index: AtomicUsize::new(self.current_index.load(Ordering::Relaxed)),
rng: Arc::new(Mutex::new(self.rng.lock().unwrap().clone())),
ready_index: None,
}
}
}
/// How a [`DesLoadBalancer`] picks the backend for each request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DesLoadBalanceStrategy {
    /// Cycle through backends in order.
    ///
    /// The search starts just after the last backend used and skips any that are not ready.
    RoundRobin,
    /// Start the search at a backend drawn from the balancer's seeded ChaCha8 generator.
    Random,
    /// Intended to prefer the backend with the fewest in-flight requests.
    ///
    /// NOTE: in-flight counts are not tracked yet, so this currently behaves exactly like
    /// [`RoundRobin`](Self::RoundRobin).
    LeastConnections,
}
impl<S> DesLoadBalancer<S> {
    /// Builds a balancer over `services` using `strategy` and the fixed default seed `42`.
    pub fn new(services: Vec<S>, strategy: DesLoadBalanceStrategy) -> Self {
        const DEFAULT_SEED: u64 = 42;
        Self::with_seed(services, strategy, DEFAULT_SEED)
    }

    /// Builds a balancer whose random generator is seeded with `seed`.
    ///
    /// The seed only affects [`DesLoadBalanceStrategy::Random`]. The round-robin cursor
    /// always starts at backend `0`.
    pub fn with_seed(services: Vec<S>, strategy: DesLoadBalanceStrategy, seed: u64) -> Self {
        let generator = ChaCha8Rng::seed_from_u64(seed);
        Self {
            services,
            strategy,
            current_index: AtomicUsize::new(0),
            rng: Arc::new(Mutex::new(generator)),
            ready_index: None,
        }
    }

    /// Round-robin balancer over `services`.
    pub fn round_robin(services: Vec<S>) -> Self {
        Self::new(services, DesLoadBalanceStrategy::RoundRobin)
    }

    /// Random balancer over `services`, seeded with the default seed for reproducibility.
    pub fn random(services: Vec<S>) -> Self {
        Self::new(services, DesLoadBalanceStrategy::Random)
    }

    /// Random balancer over `services` seeded with `seed`.
    pub fn random_with_seed(services: Vec<S>, seed: u64) -> Self {
        Self::with_seed(services, DesLoadBalanceStrategy::Random, seed)
    }

    /// Returns the backend the strategy would pick next, without consuming selection state.
    ///
    /// For `Random`, this draws from a throwaway copy of the generator. Repeated peeks
    /// therefore agree with each other, and with the draw made by the next
    /// `commit_selected_index`. Returns `0` when there are no backends.
    fn peek_start_index(&self) -> usize {
        let count = self.services.len();
        if count == 0 {
            return 0;
        }
        match self.strategy {
            DesLoadBalanceStrategy::Random => {
                let mut preview = self.rng.lock().unwrap().clone();
                preview.gen_range(0..count)
            }
            DesLoadBalanceStrategy::RoundRobin | DesLoadBalanceStrategy::LeastConnections => {
                self.current_index.load(Ordering::Relaxed) % count
            }
        }
    }

    /// Advances the selection state past the backend at `idx`.
    ///
    /// Round-robin (and `LeastConnections`, which shares its cursor) moves the cursor to
    /// the slot after `idx`. `Random` consumes exactly one draw, so the next peek sees a
    /// new value. Does nothing when there are no backends.
    fn commit_selected_index(&self, idx: usize) {
        let count = self.services.len();
        if count == 0 {
            return;
        }
        match self.strategy {
            DesLoadBalanceStrategy::Random => {
                // The drawn value is discarded: `idx` was already chosen by the caller.
                // This draw mirrors the one made on the cloned generator in the peek.
                let _ = self.rng.lock().unwrap().gen_range(0..count);
            }
            DesLoadBalanceStrategy::RoundRobin | DesLoadBalanceStrategy::LeastConnections => {
                let next = (idx + 1) % count;
                self.current_index.store(next, Ordering::Relaxed);
            }
        }
    }

    /// Picks the backend for the next `call` and commits it.
    ///
    /// A backend reserved by a prior successful `poll_ready` takes precedence; otherwise
    /// the strategy's current pick is used. The reservation is cleared in either case.
    fn select_service(&mut self) -> usize {
        let chosen = match self.ready_index.take() {
            Some(reserved) => reserved,
            None => self.peek_start_index(),
        };
        self.commit_selected_index(chosen);
        chosen
    }
}
/// Routes each request to one backend chosen by the configured strategy.
///
/// `poll_ready` reserves a ready backend. The search starts at the strategy's current pick
/// and wraps around. `call` then sends the request to that reserved backend and advances
/// the strategy. An error from any polled backend is returned as the balancer's own error.
// NOTE(review): `ReqBody: Clone` is not used by this impl (requests are moved, never
// cloned). It is kept to avoid changing the public bound; consider dropping it.
impl<S, ReqBody> Service<Request<ReqBody>> for DesLoadBalancer<S>
where
    S: Service<Request<ReqBody>, Response = http::Response<SimBody>, Error = ServiceError>,
    ReqBody: Clone,
{
    type Response = S::Response;
    type Error = ServiceError;
    type Future = S::Future;

    /// Finds and reserves a ready backend.
    ///
    /// Returns `Ready(Err(ServiceError::Internal(..)))` when there are no backends. Any
    /// backend error is propagated as-is. Returns `Pending` only after every backend
    /// has been polled and reported `Pending`.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let count = self.services.len();
        if count == 0 {
            return Poll::Ready(Err(ServiceError::Internal(
                "load balancer has no backends".to_string(),
            )));
        }

        // Fast path: re-check the backend reserved by an earlier poll. The reservation is
        // taken out here, so an error or a Pending leaves no stale reservation behind.
        let already_polled = match self.ready_index.take() {
            Some(idx) => match self.services[idx].poll_ready(cx) {
                Poll::Ready(Ok(())) => {
                    self.ready_index = Some(idx);
                    return Poll::Ready(Ok(()));
                }
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                // It has just registered `cx`'s waker; don't poll it a second time below.
                Poll::Pending => Some(idx),
            },
            None => None,
        };

        let start = self.peek_start_index();
        for offset in 0..count {
            let idx = (start + offset) % count;
            if Some(idx) == already_polled {
                continue;
            }
            match self.services[idx].poll_ready(cx) {
                Poll::Ready(Ok(())) => {
                    self.ready_index = Some(idx);
                    return Poll::Ready(Ok(()));
                }
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Pending => {}
            }
        }
        Poll::Pending
    }

    /// Sends `req` to the backend reserved by `poll_ready`, or to the strategy's current
    /// pick if nothing is reserved.
    ///
    /// # Panics
    ///
    /// Panics if the balancer has no backends. Under the `tower::Service` contract, `call`
    /// may only follow a `poll_ready` that returned `Ready(Ok(()))`, and that never happens
    /// for an empty balancer.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        assert!(
            !self.services.is_empty(),
            "DesLoadBalancer::call invoked with no backends; poll_ready must succeed first"
        );
        let index = self.select_service();
        self.services[index].call(req)
    }
}