use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Selects how work is distributed across workers.
///
/// The variant names match the balancer types defined in this module.
/// `RoundRobin` is the default.
#[derive(Clone, Copy, Debug, Default)]
pub enum LoadBalancingStrategy {
    /// Cycle through worker indices in order.
    #[default]
    RoundRobin,
    /// Choose a worker index at random.
    Random,
    /// Choose the worker whose tracked load is currently smallest.
    LeastLoaded,
}
/// Hands out worker indices in a repeating `0, 1, …, n-1` cycle.
///
/// A single shared counter is bumped atomically on every non-empty pick, so concurrent
/// callers each draw a distinct ticket. The ticket is reduced modulo the worker count
/// passed to that particular call.
#[derive(Debug, Default)]
pub struct RoundRobinBalancer {
    /// Next ticket to hand out; `fetch_add` wraps around on overflow.
    current: AtomicUsize,
}

impl RoundRobinBalancer {
    /// Creates a balancer whose first pick is index `0`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the next worker index in `0..worker_count`.
    ///
    /// When `worker_count` is `0` this returns `0` and does not advance the counter.
    pub fn next(&self, worker_count: usize) -> usize {
        match worker_count {
            0 => 0,
            n => self.current.fetch_add(1, Ordering::SeqCst) % n,
        }
    }
}
/// Picks a worker index at random on every call.
///
/// The balancer is stateless; each pick draws a fresh `u64` from `rand::random` and reduces it
/// modulo the worker count.
#[derive(Debug, Default)]
pub struct RandomBalancer;

impl RandomBalancer {
    /// Creates a new (stateless) random balancer.
    pub fn new() -> Self {
        RandomBalancer
    }

    /// Returns a random worker index in `0..worker_count`, or `0` when `worker_count` is `0`.
    pub fn next(&self, worker_count: usize) -> usize {
        if worker_count > 0 {
            let roll: u64 = rand::random();
            (roll % worker_count as u64) as usize
        } else {
            0
        }
    }
}
/// Routes work to the worker with the fewest outstanding units, as tracked by the caller.
///
/// Callers report load with [`increment_load`](Self::increment_load) and
/// [`decrement_load`](Self::decrement_load). [`next`](Self::next) only *reads* the counters
/// and does not reserve the chosen worker, so concurrent callers of `next` may receive the
/// same index. Call `increment_load` after dispatching.
#[derive(Debug)]
pub struct LeastLoadedBalancer {
    /// Per-worker load counters, indexed by worker id.
    ///
    /// The lock guards only the vector's shape: `add_worker` takes it for writing to swap in a
    /// grown vector, while counter updates happen atomically under the shared read lock.
    worker_loads: std::sync::RwLock<Arc<Vec<AtomicUsize>>>,
}

impl LeastLoadedBalancer {
    /// Creates a balancer tracking `worker_count` workers, all starting at load `0`.
    pub fn new(worker_count: usize) -> Self {
        let loads: Vec<AtomicUsize> = (0..worker_count).map(|_| AtomicUsize::new(0)).collect();
        Self {
            worker_loads: std::sync::RwLock::new(Arc::new(loads)),
        }
    }

    /// Acquires the shared read lock on the counters.
    ///
    /// A poisoned lock is recovered rather than propagated as a panic. The protected data is a
    /// vector of atomics that is only ever replaced wholesale, so a panicking holder cannot
    /// leave it half-updated.
    fn read_loads(&self) -> std::sync::RwLockReadGuard<'_, Arc<Vec<AtomicUsize>>> {
        self.worker_loads
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Appends a new worker with load `0`; existing workers keep their current loads.
    pub fn add_worker(&self) {
        let mut loads = self
            .worker_loads
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // The exclusive write lock blocks every increment/decrement (they hold the read lock),
        // so the snapshot below cannot lose a concurrent update.
        let grown: Vec<AtomicUsize> = loads
            .iter()
            .map(|load| AtomicUsize::new(load.load(Ordering::Relaxed)))
            .chain(std::iter::once(AtomicUsize::new(0)))
            .collect();
        *loads = Arc::new(grown);
    }

    /// Returns the index of the worker with the smallest current load.
    ///
    /// Ties go to the lowest index. Returns `0` when there are no workers.
    pub fn next(&self) -> usize {
        let loads = self.read_loads();
        // `min_by_key` yields the *first* minimum, preserving lowest-index tie-breaking.
        loads
            .iter()
            .enumerate()
            .min_by_key(|(_, load)| load.load(Ordering::Relaxed))
            .map(|(index, _)| index)
            .unwrap_or(0)
    }

    /// Records one more unit of work on `worker_index`; out-of-range indices are ignored.
    pub fn increment_load(&self, worker_index: usize) {
        let loads = self.read_loads();
        if let Some(load) = loads.get(worker_index) {
            load.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Records one completed unit of work on `worker_index`; out-of-range indices are ignored.
    ///
    /// The counter saturates at `0`. An unmatched decrement must not wrap to `usize::MAX`,
    /// because that would make the worker look permanently saturated and starve it of work.
    pub fn decrement_load(&self, worker_index: usize) {
        let loads = self.read_loads();
        if let Some(load) = loads.get(worker_index) {
            // `Err` here just means the counter was already 0; nothing to do.
            let _ = load.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| v.checked_sub(1));
        }
    }

    /// Returns the current load of `worker_index`, or `0` for an out-of-range index.
    pub fn get_load(&self, worker_index: usize) -> usize {
        let loads = self.read_loads();
        loads
            .get(worker_index)
            .map(|load| load.load(Ordering::Relaxed))
            .unwrap_or(0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_strategy_is_round_robin() {
        assert!(matches!(
            LoadBalancingStrategy::default(),
            LoadBalancingStrategy::RoundRobin
        ));
    }

    #[test]
    fn test_round_robin_distribution() {
        let balancer = RoundRobinBalancer::new();
        assert_eq!(balancer.next(3), 0);
        assert_eq!(balancer.next(3), 1);
        assert_eq!(balancer.next(3), 2);
        assert_eq!(balancer.next(3), 0);
        assert_eq!(balancer.next(3), 1);
    }

    #[test]
    fn test_round_robin_single_worker() {
        let balancer = RoundRobinBalancer::new();
        assert_eq!(balancer.next(1), 0);
        assert_eq!(balancer.next(1), 0);
        assert_eq!(balancer.next(1), 0);
    }

    #[test]
    fn test_round_robin_zero_workers() {
        let balancer = RoundRobinBalancer::new();
        assert_eq!(balancer.next(0), 0);
    }

    #[test]
    fn test_round_robin_zero_workers_does_not_advance() {
        // A zero-worker call must not consume a ticket from the shared counter.
        let balancer = RoundRobinBalancer::new();
        assert_eq!(balancer.next(0), 0);
        assert_eq!(balancer.next(2), 0);
        assert_eq!(balancer.next(2), 1);
    }

    #[test]
    fn test_random_balancer() {
        let balancer = RandomBalancer::new();
        let idx = balancer.next(5);
        assert!(idx < 5);
        let idx = balancer.next(1);
        assert_eq!(idx, 0);
    }

    #[test]
    fn test_random_balancer_stays_in_range() {
        let balancer = RandomBalancer::new();
        assert_eq!(balancer.next(0), 0);
        assert!((0..100).all(|_| balancer.next(3) < 3));
    }

    #[test]
    fn test_least_loaded_initial() {
        let balancer = LeastLoadedBalancer::new(3);
        assert_eq!(balancer.next(), 0);
    }

    #[test]
    fn test_least_loaded_after_increment() {
        let balancer = LeastLoadedBalancer::new(3);
        balancer.increment_load(0);
        balancer.increment_load(0);
        assert_eq!(balancer.next(), 1);
    }

    #[test]
    fn test_least_loaded_tie_picks_lowest_index() {
        let balancer = LeastLoadedBalancer::new(3);
        balancer.increment_load(0);
        balancer.increment_load(1);
        balancer.increment_load(2);
        assert_eq!(balancer.next(), 0);
        balancer.increment_load(0);
        assert_eq!(balancer.next(), 1);
    }

    #[test]
    fn test_least_loaded_load_tracking() {
        let balancer = LeastLoadedBalancer::new(3);
        balancer.increment_load(0);
        balancer.increment_load(0);
        balancer.increment_load(1);
        assert_eq!(balancer.get_load(0), 2);
        assert_eq!(balancer.get_load(1), 1);
        assert_eq!(balancer.get_load(2), 0);
        balancer.decrement_load(0);
        assert_eq!(balancer.get_load(0), 1);
    }

    #[test]
    fn test_least_loaded_add_worker_preserves_loads() {
        let balancer = LeastLoadedBalancer::new(2);
        balancer.increment_load(0);
        balancer.add_worker();
        assert_eq!(balancer.get_load(0), 1);
        assert_eq!(balancer.get_load(1), 0);
        assert_eq!(balancer.get_load(2), 0);
        assert_eq!(balancer.next(), 1);
    }

    #[test]
    fn test_least_loaded_out_of_range_is_noop() {
        let balancer = LeastLoadedBalancer::new(2);
        balancer.increment_load(10);
        assert_eq!(balancer.get_load(10), 0);
        assert_eq!(balancer.get_load(0), 0);
        assert_eq!(balancer.get_load(1), 0);
    }

    #[test]
    fn test_least_loaded_empty() {
        let balancer = LeastLoadedBalancer::new(0);
        assert_eq!(balancer.next(), 0);
    }
}