use crate::error::{NumRs2Error, Result};
use std::collections::VecDeque;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
/// Policy a `LoadBalancer` uses to pick the worker for the next task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BalancingStrategy {
    /// Hand out workers in cyclic index order.
    RoundRobin,
    /// Pick the worker with the lowest load factor (the default strategy).
    #[default]
    LeastLoaded,
    /// Pick the worker with the lowest load factor relative to its capacity weight.
    WeightedCapacity,
    /// Periodically inspect the current metrics and switch to another strategy.
    Adaptive,
    /// Work-stealing mode; worker selection currently delegates to least-loaded.
    WorkStealing,
    /// Prefer workers that share the caller's NUMA node.
    NumaAware,
}
/// Snapshot of the load observed across all workers at one point in time.
#[derive(Debug, Clone, Default)]
pub struct WorkloadMetrics {
    /// Number of tasks currently queued across all workers.
    pub active_tasks: u64,
    /// Combined throughput of all workers.
    pub total_throughput: f64,
    /// Average response time.
    pub avg_response_time: Duration,
    /// CPU utilization, one entry per worker.
    pub cpu_utilization: Vec<f64>,
    /// Memory usage, one entry per worker.
    pub memory_usage: Vec<f64>,
    /// Queue length, one entry per worker.
    pub queue_lengths: Vec<usize>,
    /// Load imbalance measure; compared against a threshold by `is_balanced`.
    pub load_imbalance: f64,
    /// Number of work-steal events.
    pub work_steals: u64,
    /// Cache miss rate.
    pub cache_miss_rate: f64,
}

impl WorkloadMetrics {
    /// Coefficient of variation of the queue lengths: population standard
    /// deviation divided by the mean.
    ///
    /// Returns `0.0` when there are no queues or when the mean length is zero.
    pub fn load_distribution_cv(&self) -> f64 {
        let count = self.queue_lengths.len();
        if count == 0 {
            return 0.0;
        }
        let n = count as f64;

        let total: usize = self.queue_lengths.iter().sum();
        let mean = total as f64 / n;
        if mean == 0.0 {
            return 0.0;
        }

        let squared_deviations: f64 = self
            .queue_lengths
            .iter()
            .map(|&len| {
                let delta = len as f64 - mean;
                delta * delta
            })
            .sum();

        (squared_deviations / n).sqrt() / mean
    }

    /// Returns `true` when `load_imbalance` is strictly below `threshold`.
    pub fn is_balanced(&self, threshold: f64) -> bool {
        self.load_imbalance < threshold
    }

    /// Index of the longest queue, or `None` when there are no queues.
    ///
    /// When several queues share the maximum length, the highest index wins.
    pub fn most_loaded_worker(&self) -> Option<usize> {
        (0..self.queue_lengths.len()).max_by_key(|&idx| self.queue_lengths[idx])
    }

    /// Index of the shortest queue, or `None` when there are no queues.
    ///
    /// When several queues share the minimum length, the lowest index wins.
    pub fn least_loaded_worker(&self) -> Option<usize> {
        (0..self.queue_lengths.len()).min_by_key(|&idx| self.queue_lengths[idx])
    }
}
/// Bookkeeping kept for a single worker.
#[derive(Debug)]
struct WorkerState {
    /// Identifier given to the worker at construction.
    #[allow(dead_code)]
    id: usize,
    /// Number of tasks currently queued on the worker.
    queue_length: usize,
    /// Most recently reported CPU utilization.
    cpu_utilization: f64,
    /// Most recently reported memory usage.
    memory_usage: f64,
    /// Count of completed tasks.
    tasks_completed: u64,
    /// Accumulated execution time.
    #[allow(dead_code)]
    total_execution_time: Duration,
    /// Reference instant used by `throughput`; set at construction.
    last_update: Instant,
    /// Relative capacity of the worker; starts at 1.0.
    capacity_weight: f64,
    /// NUMA node the worker is associated with, if known.
    numa_node: Option<usize>,
}

impl WorkerState {
    /// Creates an idle worker record with zeroed counters and a capacity weight of 1.0.
    fn new(id: usize, numa_node: Option<usize>) -> Self {
        let created_at = Instant::now();
        Self {
            id,
            numa_node,
            queue_length: 0,
            cpu_utilization: 0.0,
            memory_usage: 0.0,
            tasks_completed: 0,
            total_execution_time: Duration::ZERO,
            last_update: created_at,
            capacity_weight: 1.0,
        }
    }

    /// Completed tasks divided by the seconds elapsed since `last_update`.
    ///
    /// Returns 0.0 when less than a millisecond has elapsed, to avoid dividing
    /// by a near-zero interval.
    fn throughput(&self) -> f64 {
        let seconds = self.last_update.elapsed().as_secs_f64();
        if seconds < 0.001 {
            return 0.0;
        }
        self.tasks_completed as f64 / seconds
    }

    /// Throughput per unit of CPU utilization; 0.0 when utilization is zero.
    #[allow(dead_code)]
    fn efficiency(&self) -> f64 {
        if self.cpu_utilization == 0.0 {
            return 0.0;
        }
        self.throughput() / self.cpu_utilization
    }

    /// Weighted load score: 40% queue length (scaled by 1/100), 40% CPU
    /// utilization and 20% memory usage.
    fn load_factor(&self) -> f64 {
        let queue_term = (self.queue_length as f64 / 100.0) * 0.4;
        let cpu_term = self.cpu_utilization * 0.4;
        let memory_term = self.memory_usage * 0.2;
        queue_term + cpu_term + memory_term
    }
}
/// Selects workers for incoming tasks according to a switchable strategy.
pub struct LoadBalancer {
    /// Strategy currently used by `select_worker`.
    strategy: RwLock<BalancingStrategy>,
    /// Per-worker state, indexed by worker ID.
    workers: Arc<RwLock<Vec<WorkerState>>>,
    /// Buffer reserved for metric snapshots.
    #[allow(dead_code)]
    metrics_history: Mutex<VecDeque<WorkloadMetrics>>,
    /// Index returned by the next round-robin selection.
    next_worker: Mutex<usize>,
    /// Load imbalance above which a rebalance is requested.
    #[allow(dead_code)]
    rebalance_threshold: f64,
    /// Minimum time between strategy changes made by adaptive selection.
    adaptation_window: Duration,
    /// Instant at which the strategy was last set.
    last_strategy_change: Mutex<Instant>,
}
impl LoadBalancer {
pub fn new(strategy: BalancingStrategy, num_workers: usize) -> Result<Self> {
let mut workers = Vec::new();
for i in 0..num_workers {
let numa_node = Self::detect_numa_node(i);
workers.push(WorkerState::new(i, numa_node));
}
Ok(Self {
strategy: RwLock::new(strategy),
workers: Arc::new(RwLock::new(workers)),
metrics_history: Mutex::new(VecDeque::with_capacity(100)),
next_worker: Mutex::new(0),
rebalance_threshold: 0.3, adaptation_window: Duration::from_secs(10),
last_strategy_change: Mutex::new(Instant::now()),
})
}
/// Returns the index of the worker that should receive the next task,
/// dispatching to the selection routine of the current strategy.
///
/// # Errors
/// Propagates whatever error the chosen selection routine returns.
pub fn select_worker(&self) -> Result<usize> {
    // The strategy is copied out first so no strategy lock is held while a
    // selection routine runs (adaptive selection may need to write it).
    let active = self.current_strategy();
    match active {
        BalancingStrategy::RoundRobin => self.round_robin_selection(),
        BalancingStrategy::LeastLoaded => self.least_loaded_selection(),
        BalancingStrategy::WeightedCapacity => self.weighted_capacity_selection(),
        BalancingStrategy::Adaptive => self.adaptive_selection(),
        BalancingStrategy::WorkStealing => self.work_stealing_selection(),
        BalancingStrategy::NumaAware => self.numa_aware_selection(),
    }
}
/// Records fresh queue, CPU and memory readings for one worker and stamps
/// the update time.
///
/// Outside of test builds, the load imbalance is then checked and a rebalance
/// is requested when it exceeds the threshold.
///
/// # Errors
/// Returns `NumRs2Error::IndexError` when `worker_id` does not name a worker.
pub fn update_worker_metrics(
    &self,
    worker_id: usize,
    queue_length: usize,
    cpu_utilization: f64,
    memory_usage: f64,
) -> Result<()> {
    {
        // Scoped so the write guard is released before the rebalance check
        // takes its own read lock.
        let mut workers = self.workers.write().expect("lock should not be poisoned");
        let worker = workers.get_mut(worker_id).ok_or_else(|| {
            NumRs2Error::IndexError(format!("Invalid worker ID: {}", worker_id))
        })?;
        worker.queue_length = queue_length;
        worker.cpu_utilization = cpu_utilization;
        worker.memory_usage = memory_usage;
        worker.last_update = Instant::now();
    }

    #[cfg(not(test))]
    {
        if self.should_rebalance()? {
            self.rebalance_workload()?;
        }
    }

    Ok(())
}
/// Builds a metrics snapshot from the current worker states.
///
/// `avg_response_time` (100 ms), `work_steals` (0) and `cache_miss_rate`
/// (0.05) are fixed placeholder values. In test builds the throughput is the
/// total completed-task count divided by 10 rather than a time-based rate.
pub fn current_metrics(&self) -> WorkloadMetrics {
    let workers = self.workers.read().expect("lock should not be poisoned");

    let mut active_tasks: u64 = 0;
    let mut queue_lengths = Vec::with_capacity(workers.len());
    let mut cpu_utilization = Vec::with_capacity(workers.len());
    let mut memory_usage = Vec::with_capacity(workers.len());
    for worker in workers.iter() {
        active_tasks += worker.queue_length as u64;
        queue_lengths.push(worker.queue_length);
        cpu_utilization.push(worker.cpu_utilization);
        memory_usage.push(worker.memory_usage);
    }

    let total_throughput: f64 = if cfg!(test) {
        let completed: f64 = workers.iter().map(|w| w.tasks_completed as f64).sum();
        completed / 10.0
    } else {
        workers.iter().map(|w| w.throughput()).sum()
    };

    let load_imbalance = self.calculate_load_imbalance(&workers);

    WorkloadMetrics {
        active_tasks,
        total_throughput,
        avg_response_time: Duration::from_millis(100),
        cpu_utilization,
        memory_usage,
        queue_lengths,
        load_imbalance,
        work_steals: 0,
        cache_miss_rate: 0.05,
    }
}
/// Number of workers managed by this balancer.
pub fn num_workers(&self) -> usize {
    let workers = self.workers.read().expect("lock should not be poisoned");
    workers.len()
}
/// Replaces the active strategy and records the time of the change.
pub fn set_strategy(&self, new_strategy: BalancingStrategy) {
    let mut active = self.strategy.write().expect("lock should not be poisoned");
    *active = new_strategy;

    // The timestamp is updated while the strategy write guard is still held.
    let mut changed_at = self
        .last_strategy_change
        .lock()
        .expect("lock should not be poisoned");
    *changed_at = Instant::now();
}
/// Returns a copy of the strategy currently in effect.
pub fn current_strategy(&self) -> BalancingStrategy {
    let guard = self.strategy.read().expect("lock should not be poisoned");
    *guard
}
fn round_robin_selection(&self) -> Result<usize> {
let mut next = self
.next_worker
.lock()
.expect("lock should not be poisoned");
let workers = self.workers.read().expect("lock should not be poisoned");
let worker_id = *next;
*next = (*next + 1) % workers.len();
Ok(worker_id)
}
/// Returns the index of the worker with the smallest load factor.
///
/// Ties, and loads that cannot be ordered, keep the earlier worker.
///
/// # Errors
/// Returns `NumRs2Error::RuntimeError` when there are no workers.
fn least_loaded_selection(&self) -> Result<usize> {
    let workers = self.workers.read().expect("lock should not be poisoned");

    let mut best: Option<(usize, f64)> = None;
    for (idx, worker) in workers.iter().enumerate() {
        let load = worker.load_factor();
        // Replace the candidate only when it is strictly heavier than this
        // worker; equal or incomparable loads leave the earlier one in place.
        let improves = match best {
            None => true,
            Some((_, lowest)) => lowest > load,
        };
        if improves {
            best = Some((idx, load));
        }
    }

    best.map(|(idx, _)| idx)
        .ok_or_else(|| NumRs2Error::RuntimeError("No workers available".to_string()))
}
/// Returns the index of the worker whose load factor divided by its capacity
/// weight is smallest.
///
/// Ties, and scores that cannot be ordered, keep the earlier worker.
///
/// # Errors
/// Returns `NumRs2Error::RuntimeError` when there are no workers.
fn weighted_capacity_selection(&self) -> Result<usize> {
    let workers = self.workers.read().expect("lock should not be poisoned");

    let mut best: Option<(usize, f64)> = None;
    for (idx, worker) in workers.iter().enumerate() {
        let score = worker.load_factor() / worker.capacity_weight;
        // Only a strictly lower score displaces the current candidate.
        let improves = match best {
            None => true,
            Some((_, lowest)) => lowest > score,
        };
        if improves {
            best = Some((idx, score));
        }
    }

    best.map(|(idx, _)| idx)
        .ok_or_else(|| NumRs2Error::RuntimeError("No workers available".to_string()))
}
/// Gives the balancer a chance to switch strategy, then selects a worker.
///
/// If the strategy is still `Adaptive` afterwards, least-loaded selection is
/// used; otherwise the call is re-dispatched through `select_worker` so the
/// newly chosen strategy takes effect.
fn adaptive_selection(&self) -> Result<usize> {
    self.maybe_adapt_strategy()?;

    if self.current_strategy() == BalancingStrategy::Adaptive {
        self.least_loaded_selection()
    } else {
        self.select_worker()
    }
}
/// Selection for the work-stealing strategy.
///
/// No stealing-specific logic exists here; the choice is delegated to
/// least-loaded selection.
fn work_stealing_selection(&self) -> Result<usize> {
    let chosen = self.least_loaded_selection()?;
    Ok(chosen)
}
fn numa_aware_selection(&self) -> Result<usize> {
let workers = self.workers.read().expect("lock should not be poisoned");
let current_numa = Self::get_current_numa_node();
let same_numa_workers: Vec<_> = workers
.iter()
.enumerate()
.filter(|(_, w)| w.numa_node == current_numa)
.collect();
if same_numa_workers.is_empty() {
return self.least_loaded_selection();
}
let worker_id = same_numa_workers
.iter()
.min_by(|(_, a), (_, b)| {
a.load_factor()
.partial_cmp(&b.load_factor())
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(idx, _)| *idx)
.ok_or_else(|| NumRs2Error::RuntimeError("No NUMA workers available".to_string()))?;
Ok(worker_id)
}
/// Reports whether the current load imbalance exceeds `rebalance_threshold`.
#[allow(dead_code)]
fn should_rebalance(&self) -> Result<bool> {
    let workers = self.workers.read().expect("lock should not be poisoned");
    let imbalance = self.calculate_load_imbalance(&workers);
    Ok(self.rebalance_threshold < imbalance)
}
/// Relative spread of worker load factors: `(max - min) / max`.
///
/// Returns 0.0 when `workers` is empty or when the largest load factor is zero.
fn calculate_load_imbalance(&self, workers: &[WorkerState]) -> f64 {
    if workers.is_empty() {
        return 0.0;
    }

    // The maximum starts at zero and the minimum at infinity, so a single
    // pass yields both extremes.
    let mut max_load = 0.0f64;
    let mut min_load = f64::INFINITY;
    for worker in workers {
        let load = worker.load_factor();
        max_load = max_load.max(load);
        min_load = min_load.min(load);
    }

    if max_load == 0.0 {
        return 0.0;
    }
    (max_load - min_load) / max_load
}
/// Hook invoked when the load imbalance exceeds the threshold.
///
/// Currently a stub: it performs no work and always succeeds.
#[allow(dead_code)]
fn rebalance_workload(&self) -> Result<()> {
    Ok(())
}
fn maybe_adapt_strategy(&self) -> Result<()> {
let last_change = *self
.last_strategy_change
.lock()
.expect("lock should not be poisoned");
if last_change.elapsed() < self.adaptation_window {
return Ok(()); }
let metrics = self.current_metrics();
let current_strategy = *self.strategy.read().expect("lock should not be poisoned");
let new_strategy = if metrics.load_imbalance > 0.4 {
BalancingStrategy::WorkStealing
} else if metrics.cache_miss_rate > 0.1 {
BalancingStrategy::NumaAware
} else if metrics.total_throughput < 10.0 {
BalancingStrategy::WeightedCapacity
} else {
BalancingStrategy::LeastLoaded
};
if new_strategy != current_strategy {
self.set_strategy(new_strategy);
}
Ok(())
}
/// NUMA node for the given worker.
///
/// Placeholder: no topology detection is implemented, so this is always `None`.
fn detect_numa_node(_worker_id: usize) -> Option<usize> {
    None
}
/// NUMA node of the calling thread.
///
/// Placeholder: no topology detection is implemented, so this is always `None`.
fn get_current_numa_node() -> Option<usize> {
    None
}
}
/// Collects metric snapshots and derives strategy recommendations from them.
pub struct LoadBalancingAdvisor {
    /// Recorded snapshots, oldest first.
    metrics_history: VecDeque<WorkloadMetrics>,
    /// Analysis window; stored at construction and not otherwise read here.
    #[allow(dead_code)]
    analysis_window: Duration,
}
impl Default for LoadBalancingAdvisor {
fn default() -> Self {
Self::new()
}
}
impl LoadBalancingAdvisor {
/// Creates an advisor with an empty history (room for 1000 snapshots
/// pre-allocated) and a 60-second analysis window.
pub fn new() -> Self {
    let metrics_history = VecDeque::with_capacity(1000);
    let analysis_window = Duration::from_secs(60);
    Self {
        metrics_history,
        analysis_window,
    }
}
/// Appends a snapshot to the history, discarding the oldest entries so that
/// at most 1000 snapshots are retained.
pub fn record_metrics(&mut self, metrics: WorkloadMetrics) {
    self.metrics_history.push_back(metrics);

    let excess = self.metrics_history.len().saturating_sub(1000);
    for _ in 0..excess {
        self.metrics_history.pop_front();
    }
}
pub fn recommend_strategy(&self) -> BalancingStrategy {
if self.metrics_history.is_empty() {
return BalancingStrategy::LeastLoaded;
}
let recent_metrics: Vec<_> = self.metrics_history.iter().rev().take(10).collect();
let avg_imbalance = recent_metrics.iter().map(|m| m.load_imbalance).sum::<f64>()
/ recent_metrics.len() as f64;
let avg_throughput = recent_metrics
.iter()
.map(|m| m.total_throughput)
.sum::<f64>()
/ recent_metrics.len() as f64;
let avg_cache_miss = recent_metrics
.iter()
.map(|m| m.cache_miss_rate)
.sum::<f64>()
/ recent_metrics.len() as f64;
if avg_imbalance > 0.3 {
BalancingStrategy::WorkStealing
} else if avg_cache_miss > 0.1 {
BalancingStrategy::NumaAware
} else if avg_throughput < 5.0 {
BalancingStrategy::WeightedCapacity
} else {
BalancingStrategy::Adaptive
}
}
/// Compares the oldest and newest recorded snapshots.
///
/// Each trend is the newest value minus the oldest value. With fewer than
/// two snapshots the default analysis is returned.
pub fn analyze_trends(&self) -> LoadBalancingAnalysis {
    let endpoints = (self.metrics_history.front(), self.metrics_history.back());
    let (oldest, newest) = match endpoints {
        (Some(oldest), Some(newest)) if self.metrics_history.len() >= 2 => (oldest, newest),
        _ => return LoadBalancingAnalysis::default(),
    };

    let oldest_response = oldest.avg_response_time.as_secs_f64();
    let newest_response = newest.avg_response_time.as_secs_f64();

    LoadBalancingAnalysis {
        throughput_trend: newest.total_throughput - oldest.total_throughput,
        imbalance_trend: newest.load_imbalance - oldest.load_imbalance,
        response_time_trend: newest_response - oldest_response,
        stability_score: self.calculate_stability_score(),
        recommendation: self.recommend_strategy(),
    }
}
/// Scores throughput stability as `1 / (1 + cv)`, clamped to `[0, 1]`, where
/// `cv` is the coefficient of variation of throughput over the whole history.
///
/// Returns 0.5 when fewer than ten snapshots exist and 0.0 when the mean
/// throughput is zero.
fn calculate_stability_score(&self) -> f64 {
    let sample_count = self.metrics_history.len();
    if sample_count < 10 {
        return 0.5;
    }
    let n = sample_count as f64;

    let mean = self
        .metrics_history
        .iter()
        .map(|m| m.total_throughput)
        .sum::<f64>()
        / n;
    if mean == 0.0 {
        return 0.0;
    }

    let variance = self
        .metrics_history
        .iter()
        .map(|m| (m.total_throughput - mean).powi(2))
        .sum::<f64>()
        / n;

    let cv = variance.sqrt() / mean;
    (1.0 / (1.0 + cv)).clamp(0.0, 1.0)
}
}
/// Result of a trend analysis over recorded metric snapshots.
#[derive(Debug, Clone, Default)]
pub struct LoadBalancingAnalysis {
    /// Newest throughput minus oldest throughput.
    pub throughput_trend: f64,
    /// Newest load imbalance minus oldest load imbalance.
    pub imbalance_trend: f64,
    /// Newest average response time minus the oldest, in seconds.
    pub response_time_trend: f64,
    /// Throughput stability score.
    pub stability_score: f64,
    /// Strategy recommended for the analysed history.
    pub recommendation: BalancingStrategy,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new balancer reports the requested worker count and strategy.
    #[test]
    fn test_load_balancer_creation() {
        let balancer = LoadBalancer::new(BalancingStrategy::LeastLoaded, 4)
            .expect("failed to create load balancer");
        assert_eq!(balancer.num_workers(), 4);
        assert_eq!(balancer.current_strategy(), BalancingStrategy::LeastLoaded);
    }

    /// Round-robin cycles through the workers in index order.
    #[test]
    fn test_round_robin_selection() {
        let balancer = LoadBalancer::new(BalancingStrategy::RoundRobin, 3)
            .expect("failed to create load balancer");
        let selections: Vec<usize> = (0..6)
            .map(|_| balancer.select_worker().expect("failed to select worker"))
            .collect();
        assert_eq!(selections, vec![0, 1, 2, 0, 1, 2]);
    }

    /// Least-loaded selection must avoid the only worker that carries load.
    #[test]
    fn test_least_loaded_selection() {
        let balancer = LoadBalancer::new(BalancingStrategy::LeastLoaded, 3)
            .expect("failed to create load balancer");
        balancer
            .update_worker_metrics(1, 10, 0.5, 0.5)
            .expect("failed to update worker metrics");
        let selection = balancer.select_worker().expect("failed to select worker");
        assert!(selection < 3);
        // Workers 0 and 2 are idle, so the loaded worker 1 must not be picked.
        assert_ne!(selection, 1);
    }

    /// Setting a strategy is visible through `current_strategy`.
    #[test]
    fn test_strategy_switching() {
        let balancer = LoadBalancer::new(BalancingStrategy::RoundRobin, 2)
            .expect("failed to create load balancer");
        assert_eq!(balancer.current_strategy(), BalancingStrategy::RoundRobin);
        balancer.set_strategy(BalancingStrategy::LeastLoaded);
        assert_eq!(balancer.current_strategy(), BalancingStrategy::LeastLoaded);
    }

    /// The metrics snapshot aggregates the per-worker readings.
    #[test]
    fn test_workload_metrics() {
        let balancer = LoadBalancer::new(BalancingStrategy::LeastLoaded, 3)
            .expect("failed to create load balancer");
        balancer
            .update_worker_metrics(0, 5, 0.5, 0.4)
            .expect("failed to update worker 0 metrics");
        balancer
            .update_worker_metrics(1, 3, 0.4, 0.3)
            .expect("failed to update worker 1 metrics");
        balancer
            .update_worker_metrics(2, 7, 0.6, 0.5)
            .expect("failed to update worker 2 metrics");
        let metrics = balancer.current_metrics();
        assert_eq!(metrics.active_tasks, 15);
        assert_eq!(metrics.queue_lengths, vec![5, 3, 7]);
        // The three workers carry different, non-zero loads, so the relative
        // spread lies strictly between 0 and 1.
        assert!(metrics.load_imbalance > 0.0);
        assert!(metrics.load_imbalance < 1.0);
    }

    /// Equal queues have zero variation; spread-out queues have a large one.
    #[test]
    fn test_load_distribution_cv() {
        let metrics = WorkloadMetrics {
            queue_lengths: vec![5, 5, 5],
            ..Default::default()
        };
        assert_eq!(metrics.load_distribution_cv(), 0.0);
        let metrics2 = WorkloadMetrics {
            queue_lengths: vec![1, 5, 9],
            ..Default::default()
        };
        assert!(metrics2.load_distribution_cv() > 0.5);
    }

    /// A high average imbalance leads to a work-stealing recommendation.
    #[test]
    fn test_load_balancing_advisor() {
        let mut advisor = LoadBalancingAdvisor::new();
        let metrics = WorkloadMetrics {
            load_imbalance: 0.5,
            ..Default::default()
        };
        advisor.record_metrics(metrics);
        let recommendation = advisor.recommend_strategy();
        assert_eq!(recommendation, BalancingStrategy::WorkStealing);
    }

    /// The helper accessors identify the extreme queues and the balance state.
    #[test]
    fn test_workload_metrics_helpers() {
        let mut metrics = WorkloadMetrics {
            queue_lengths: vec![1, 5, 3, 7, 2],
            ..Default::default()
        };
        let max_load = 7.0;
        let min_load = 1.0;
        metrics.load_imbalance = (max_load - min_load) / max_load;
        assert_eq!(metrics.most_loaded_worker(), Some(3));
        assert_eq!(metrics.least_loaded_worker(), Some(0));
        assert!(!metrics.is_balanced(0.3));
    }
}