use super::{GpuDevice, SelectionStrategy};
use crate::error::{GpuAdvancedError, Result};
use parking_lot::RwLock;
use std::cmp::Ordering;
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering as AtomicOrdering};
use std::time::{Duration, Instant};
/// Distributes work across a fixed set of GPU devices and migrates tracked
/// workloads between them.
///
/// All mutable state sits behind `parking_lot` read/write locks (or an atomic
/// counter), which is what lets every public method take `&self`.
pub struct LoadBalancer {
    /// Managed devices; a position in this vector is the "device index" used
    /// throughout the API.
    devices: Vec<Arc<GpuDevice>>,
    /// Policy applied by `select_device`.
    strategy: SelectionStrategy,
    /// Counter advanced on every round-robin selection.
    rr_counter: AtomicUsize,
    /// Per-device task, timing and migration counters.
    stats: Arc<RwLock<LoadStats>>,
    /// Tunables for load classification and migration decisions.
    migration_config: Arc<RwLock<MigrationConfig>>,
    /// Bounded log of past migrations.
    migration_history: Arc<RwLock<MigrationHistory>>,
    /// Per-device utilization samples and registered workloads.
    workload_tracker: Arc<RwLock<WorkloadTracker>>,
}
/// Per-device counters maintained by the load balancer.
///
/// Each vector is indexed by device position. `Default` leaves every vector
/// empty; the balancer's constructor sizes them to the device count.
#[derive(Debug, Clone, Default)]
pub struct LoadStats {
    /// Number of tasks ever started on each device.
    pub tasks_per_device: Vec<usize>,
    /// Accumulated task duration per device, in microseconds.
    pub time_per_device: Vec<u64>,
    /// Tasks started but not yet completed on each device.
    pub active_tasks: Vec<usize>,
    /// Per-device memory counter. NOTE(review): nothing in this module writes
    /// it except zero-initialisation; confirm whether callers elsewhere do.
    pub memory_per_device: Vec<u64>,
    /// Successful migrations whose source was each device.
    pub migrations_from: Vec<usize>,
    /// Successful migrations whose target was each device.
    pub migrations_to: Vec<usize>,
}
/// Tuning knobs for overload detection, migration costing and pacing.
#[derive(Debug, Clone)]
pub struct MigrationConfig {
    /// A device whose combined load is strictly above this is overloaded.
    pub overload_threshold: f32,
    /// A device whose combined load is strictly below this is underutilized.
    pub underutilization_threshold: f32,
    /// Minimum combined-load gap required before devices count as imbalanced.
    pub min_imbalance_threshold: f32,
    /// Fixed cost charged for every migration.
    pub transfer_cost_base: f64,
    /// Additional migration cost per byte of workload memory.
    pub transfer_cost_per_byte: f64,
    /// Workloads smaller than this many bytes are never picked for migration.
    pub min_migration_size: u64,
    /// Upper bound on migrations performed during one rebalance pass.
    pub max_pending_migrations: usize,
    /// Seconds a device stays ineligible as a migration target after a
    /// migration touches it.
    pub migration_cooldown_secs: u64,
    /// Whether rebalancing also relieves devices whose load is trending upward.
    pub enable_predictive_migration: bool,
    /// Window length used for utilization samples and migration history.
    pub history_window_size: usize,
    /// Weight of memory utilization in the combined load score.
    pub memory_weight: f32,
    /// Weight of compute utilization in the combined load score.
    pub compute_weight: f32,
}

impl Default for MigrationConfig {
    fn default() -> Self {
        Self {
            // Thresholds compared against a device's combined load.
            underutilization_threshold: 0.3,
            overload_threshold: 0.8,
            min_imbalance_threshold: 0.2,
            // Cost model: fixed base plus a per-byte charge.
            transfer_cost_base: 1.0,
            transfer_cost_per_byte: 1e-6,
            // Pacing and selection limits.
            min_migration_size: 1024,
            max_pending_migrations: 4,
            migration_cooldown_secs: 5,
            enable_predictive_migration: true,
            history_window_size: 100,
            // Combined-load weights (they sum to 1.0).
            compute_weight: 0.6,
            memory_weight: 0.4,
        }
    }
}
/// A unit of work registered on a device that the balancer may move elsewhere.
#[derive(Debug, Clone)]
pub struct MigratableWorkload {
    /// Caller-visible identifier.
    pub id: u64,
    /// Device index the workload currently lives on.
    pub source_device: usize,
    /// Memory footprint in bytes; drives the per-byte part of the migration cost.
    pub memory_size: u64,
    /// Relative compute demand; feeds both cost and expected benefit.
    pub compute_intensity: f32,
    /// Larger values are preferred when choosing what to migrate.
    pub priority: u32,
    /// When this value was constructed.
    pub created_at: Instant,
    /// Set while a migration is in flight.
    pub migrating: bool,
    /// Ids of workloads this one depends on.
    pub dependencies: Vec<u64>,
}

impl MigratableWorkload {
    /// Builds an idle workload with no dependencies, timestamped now.
    pub fn new(
        id: u64,
        source_device: usize,
        memory_size: u64,
        compute_intensity: f32,
        priority: u32,
    ) -> Self {
        let created_at = Instant::now();
        Self {
            id,
            source_device,
            memory_size,
            compute_intensity,
            priority,
            created_at,
            migrating: false,
            dependencies: Vec::new(),
        }
    }

    /// Builder-style helper that appends `dep_id` to the dependency list.
    pub fn with_dependency(self, dep_id: u64) -> Self {
        let mut workload = self;
        workload.dependencies.push(dep_id);
        workload
    }

    /// Estimated cost of moving this workload:
    /// `base + memory_size * per_byte + compute_intensity * 0.1`.
    pub fn calculate_migration_cost(&self, config: &MigrationConfig) -> f64 {
        let byte_cost = self.memory_size as f64 * config.transfer_cost_per_byte;
        let compute_penalty = f64::from(self.compute_intensity) * 0.1;
        config.transfer_cost_base + byte_cost + compute_penalty
    }
}
/// A proposed move of one workload to a target device, with its cost/benefit estimate.
#[derive(Debug, Clone)]
pub struct MigrationPlan {
    /// Snapshot of the workload to move.
    pub workload: MigratableWorkload,
    /// Device the workload is leaving (copied from `workload.source_device`).
    pub source_device: usize,
    /// Device the workload should move to.
    pub target_device: usize,
    /// Result of `MigratableWorkload::calculate_migration_cost`.
    pub estimated_cost: f64,
    /// Heuristic benefit: `compute_intensity * 10.0` (independent of target).
    pub expected_benefit: f64,
    /// `expected_benefit - estimated_cost`.
    pub net_benefit: f64,
    /// When the plan was built.
    pub created_at: Instant,
    /// Initially `net_benefit > 0.0`; callers may override.
    pub approved: bool,
}

impl MigrationPlan {
    /// Builds a plan, pricing it with `config` and auto-approving positive net benefit.
    pub fn new(
        workload: MigratableWorkload,
        target_device: usize,
        config: &MigrationConfig,
    ) -> Self {
        let estimated_cost = workload.calculate_migration_cost(config);
        let expected_benefit = f64::from(workload.compute_intensity) * 10.0;
        let net_benefit = expected_benefit - estimated_cost;
        Self {
            source_device: workload.source_device,
            target_device,
            estimated_cost,
            expected_benefit,
            net_benefit,
            approved: net_benefit > 0.0,
            created_at: Instant::now(),
            workload,
        }
    }

    /// True only when the plan is approved AND still has a positive net benefit.
    pub fn should_migrate(&self) -> bool {
        self.net_benefit > 0.0 && self.approved
    }
}
/// Outcome of one attempted migration.
#[derive(Debug, Clone)]
pub struct MigrationResult {
    /// Whether the workload was actually moved.
    pub success: bool,
    /// Device the workload was to leave.
    pub source_device: usize,
    /// Device the workload was to move to.
    pub target_device: usize,
    /// Id of the workload concerned.
    pub workload_id: u64,
    /// Wall-clock time spent performing the move.
    pub transfer_time: Duration,
    /// Bytes reported as moved (zero on failure).
    pub bytes_transferred: u64,
    /// Human-readable reason when `success` is false.
    pub error_message: Option<String>,
}
/// Bounded log of recent migrations plus lifetime success/failure counters.
///
/// The per-route queries (`success_rate`, `average_transfer_time`) and
/// `total_bytes_transferred` only see the retained window of at most
/// `max_size` entries. `overall_success_rate` uses the lifetime counters.
#[derive(Debug, Default)]
pub struct MigrationHistory {
    /// Retained entries, oldest at the front.
    entries: VecDeque<MigrationHistoryEntry>,
    /// Maximum number of entries kept in `entries`; zero keeps none.
    max_size: usize,
    /// Lifetime count of successful entries (not limited by the window).
    total_successful: usize,
    /// Lifetime count of failed entries (not limited by the window).
    total_failed: usize,
}

/// One recorded migration attempt.
#[derive(Debug, Clone)]
pub struct MigrationHistoryEntry {
    /// When the attempt was recorded.
    pub timestamp: Instant,
    /// Device the workload left.
    pub source_device: usize,
    /// Device the workload moved to.
    pub target_device: usize,
    /// Whether the attempt succeeded.
    pub success: bool,
    /// Time the attempt took.
    pub transfer_time: Duration,
    /// Bytes moved by the attempt.
    pub bytes_transferred: u64,
}

impl MigrationHistory {
    /// Creates an empty history that retains at most `max_size` entries.
    pub fn new(max_size: usize) -> Self {
        Self {
            entries: VecDeque::with_capacity(max_size),
            max_size,
            total_successful: 0,
            total_failed: 0,
        }
    }

    /// Records `entry`, evicting the oldest entries so at most `max_size` remain.
    ///
    /// The lifetime counters are always updated, even when the window is empty.
    pub fn add_entry(&mut self, entry: MigrationHistoryEntry) {
        if entry.success {
            self.total_successful += 1;
        } else {
            self.total_failed += 1;
        }
        // Without this guard a zero-sized window (e.g. from `Default`) would
        // still retain the newest entry.
        if self.max_size == 0 {
            return;
        }
        while self.entries.len() >= self.max_size {
            self.entries.pop_front();
        }
        self.entries.push_back(entry);
    }

    /// Iterates the retained entries for the `source -> target` route.
    fn route(
        &self,
        source: usize,
        target: usize,
    ) -> impl Iterator<Item = &MigrationHistoryEntry> + '_ {
        self.entries
            .iter()
            .filter(move |e| e.source_device == source && e.target_device == target)
    }

    /// Fraction of retained `source -> target` attempts that succeeded.
    ///
    /// Returns 1.0 (optimistic) when the route has no retained entries.
    pub fn success_rate(&self, source: usize, target: usize) -> f64 {
        let (attempts, successes) = self
            .route(source, target)
            .fold((0usize, 0usize), |(n, ok), e| (n + 1, ok + usize::from(e.success)));
        if attempts == 0 {
            return 1.0;
        }
        successes as f64 / attempts as f64
    }

    /// Mean transfer time of the retained *successful* `source -> target` migrations.
    ///
    /// Returns `None` when there are none.
    pub fn average_transfer_time(&self, source: usize, target: usize) -> Option<Duration> {
        let (count, total) = self
            .route(source, target)
            .filter(|e| e.success)
            .fold((0usize, Duration::ZERO), |(n, sum), e| (n + 1, sum + e.transfer_time));
        if count == 0 {
            return None;
        }
        // `count` is bounded by the retained window size.
        Some(total / count as u32)
    }

    /// Sum of `bytes_transferred` over the retained window only.
    pub fn total_bytes_transferred(&self) -> u64 {
        self.entries.iter().map(|e| e.bytes_transferred).sum()
    }

    /// Lifetime success fraction; 1.0 when nothing has been recorded.
    pub fn overall_success_rate(&self) -> f64 {
        let total = self.total_successful + self.total_failed;
        if total == 0 {
            return 1.0;
        }
        self.total_successful as f64 / total as f64
    }
}
/// Per-device utilization history and registered migratable workloads.
#[derive(Debug)]
pub struct WorkloadTracker {
    /// Rolling utilization samples per device, oldest first.
    utilization_samples: Vec<VecDeque<UtilizationSample>>,
    /// Workloads currently registered on each device.
    pending_workloads: Vec<Vec<MigratableWorkload>>,
    /// Source of workload ids; the first id handed out is 0.
    next_workload_id: AtomicU64,
    /// Time of the last `update_rebalance_time` call per device, if any.
    last_rebalance: Vec<Option<Instant>>,
    /// Maximum number of samples retained per device. Stored explicitly
    /// because `VecDeque::capacity` is only a lower bound and may exceed
    /// the requested size.
    history_size: usize,
}

/// One utilization observation for a device.
#[derive(Debug, Clone)]
pub struct UtilizationSample {
    /// When the sample was taken.
    pub timestamp: Instant,
    /// Compute utilization at sample time.
    pub compute: f32,
    /// Memory utilization at sample time.
    pub memory: f32,
    /// Active task count at sample time.
    pub active_tasks: usize,
}

impl WorkloadTracker {
    /// Creates a tracker for `device_count` devices that keeps at most
    /// `history_size` samples per device (zero keeps none).
    pub fn new(device_count: usize, history_size: usize) -> Self {
        Self {
            utilization_samples: (0..device_count)
                .map(|_| VecDeque::with_capacity(history_size))
                .collect(),
            pending_workloads: (0..device_count).map(|_| Vec::new()).collect(),
            next_workload_id: AtomicU64::new(0),
            last_rebalance: vec![None; device_count],
            history_size,
        }
    }

    /// Returns a fresh workload id (monotonically increasing, starting at 0).
    pub fn next_workload_id(&self) -> u64 {
        self.next_workload_id.fetch_add(1, AtomicOrdering::Relaxed)
    }

    /// Appends `sample` for `device_index`, evicting the oldest samples so
    /// that at most `history_size` remain. Unknown indices are ignored.
    pub fn record_sample(&mut self, device_index: usize, sample: UtilizationSample) {
        let limit = self.history_size;
        if limit == 0 {
            return;
        }
        if let Some(samples) = self.utilization_samples.get_mut(device_index) {
            while samples.len() >= limit {
                samples.pop_front();
            }
            samples.push_back(sample);
        }
    }

    /// Mean `(compute, memory)` over the newest `window` samples.
    ///
    /// Returns `None` for an unknown device, when no samples exist, or when
    /// `window == 0` (which would otherwise divide by zero).
    pub fn average_utilization(&self, device_index: usize, window: usize) -> Option<(f32, f32)> {
        let samples = self.utilization_samples.get(device_index)?;
        let take_count = window.min(samples.len());
        if take_count == 0 {
            return None;
        }
        // Sum newest-first, matching the order of the original reduction.
        let (sum_compute, sum_memory) = samples
            .iter()
            .rev()
            .take(take_count)
            .fold((0.0f32, 0.0f32), |(c, m), s| (c + s.compute, m + s.memory));
        let n = take_count as f32;
        Some((sum_compute / n, sum_memory / n))
    }

    /// Least-squares slope of compute utilization over the newest `window`
    /// samples, with x = sample position (oldest = 0).
    ///
    /// Returns `None` when fewer than two samples fall in the window, and
    /// `Some(0.0)` when the fit is degenerate.
    pub fn utilization_trend(&self, device_index: usize, window: usize) -> Option<f32> {
        let samples = self.utilization_samples.get(device_index)?;
        let take_count = window.min(samples.len());
        if take_count < 2 {
            return None;
        }
        let n = take_count as f32;
        let (mut sum_x, mut sum_y, mut sum_xy, mut sum_xx) = (0.0f32, 0.0f32, 0.0f32, 0.0f32);
        let skip_count = samples.len() - take_count;
        for (i, sample) in samples.iter().skip(skip_count).enumerate() {
            let x = i as f32;
            let y = sample.compute;
            sum_x += x;
            sum_y += y;
            sum_xy += x * y;
            sum_xx += x * x;
        }
        let denominator = n * sum_xx - sum_x * sum_x;
        if denominator.abs() < f32::EPSILON {
            return Some(0.0);
        }
        Some((n * sum_xy - sum_x * sum_y) / denominator)
    }

    /// Registers `workload` on `device_index`; unknown indices are ignored.
    pub fn add_workload(&mut self, device_index: usize, workload: MigratableWorkload) {
        if let Some(workloads) = self.pending_workloads.get_mut(device_index) {
            workloads.push(workload);
        }
    }

    /// Removes and returns the workload with `workload_id` from `device_index`,
    /// preserving the order of the remaining workloads.
    pub fn remove_workload(
        &mut self,
        device_index: usize,
        workload_id: u64,
    ) -> Option<MigratableWorkload> {
        let workloads = self.pending_workloads.get_mut(device_index)?;
        let pos = workloads.iter().position(|w| w.id == workload_id)?;
        Some(workloads.remove(pos))
    }

    /// Workloads on `device_index` that are not migrating and have no
    /// dependencies; empty for unknown devices.
    pub fn get_migratable_workloads(&self, device_index: usize) -> Vec<&MigratableWorkload> {
        self.pending_workloads
            .get(device_index)
            .map(|workloads| {
                workloads
                    .iter()
                    .filter(|w| !w.migrating && w.dependencies.is_empty())
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Stamps `device_index` with the current time for cooldown tracking.
    pub fn update_rebalance_time(&mut self, device_index: usize) {
        if let Some(time) = self.last_rebalance.get_mut(device_index) {
            *time = Some(Instant::now());
        }
    }

    /// True when `device_index` was stamped fewer than `cooldown_secs` whole
    /// seconds ago; false for unknown or never-stamped devices.
    pub fn is_in_cooldown(&self, device_index: usize, cooldown_secs: u64) -> bool {
        self.last_rebalance
            .get(device_index)
            .and_then(|opt| opt.as_ref())
            .map(|t| t.elapsed().as_secs() < cooldown_secs)
            .unwrap_or(false)
    }

    /// Number of workloads registered on `device_index` (0 if unknown).
    pub fn pending_count(&self, device_index: usize) -> usize {
        self.pending_workloads
            .get(device_index)
            .map(|w| w.len())
            .unwrap_or(0)
    }
}
/// Point-in-time load summary for one device.
#[derive(Debug, Clone)]
pub struct DeviceLoad {
    /// Position of the device in the balancer.
    pub device_index: usize,
    /// Compute utilization reported by the device.
    pub compute_utilization: f32,
    /// Memory in use divided by the device's maximum buffer size.
    pub memory_utilization: f32,
    /// Weighted mix of compute and memory utilization.
    pub combined_load: f32,
    /// Tasks currently running on the device.
    pub active_tasks: usize,
    /// Workloads registered on the device.
    pub pending_workloads: usize,
    /// Device score as reported by the device.
    pub score: f32,
    /// Recent slope of compute utilization.
    pub trend: f32,
}

impl DeviceLoad {
    /// True when `combined_load` is strictly above the overload threshold
    /// (false for NaN).
    pub fn is_overloaded(&self, config: &MigrationConfig) -> bool {
        config.overload_threshold < self.combined_load
    }

    /// True when `combined_load` is strictly below the underutilization
    /// threshold (false for NaN).
    pub fn is_underutilized(&self, config: &MigrationConfig) -> bool {
        config.underutilization_threshold > self.combined_load
    }
}
impl LoadBalancer {
pub fn new(devices: Vec<Arc<GpuDevice>>, strategy: SelectionStrategy) -> Self {
let device_count = devices.len();
let stats = LoadStats {
tasks_per_device: vec![0; device_count],
time_per_device: vec![0; device_count],
active_tasks: vec![0; device_count],
memory_per_device: vec![0; device_count],
migrations_from: vec![0; device_count],
migrations_to: vec![0; device_count],
};
let config = MigrationConfig::default();
let tracker = WorkloadTracker::new(device_count, config.history_window_size);
let history = MigrationHistory::new(config.history_window_size);
Self {
devices,
strategy,
rr_counter: AtomicUsize::new(0),
stats: Arc::new(RwLock::new(stats)),
migration_config: Arc::new(RwLock::new(config)),
migration_history: Arc::new(RwLock::new(history)),
workload_tracker: Arc::new(RwLock::new(tracker)),
}
}
pub fn migration_config(&self) -> MigrationConfig {
self.migration_config.read().clone()
}
pub fn set_migration_config(&self, config: MigrationConfig) {
*self.migration_config.write() = config;
}
pub fn select_device(&self) -> Result<Arc<GpuDevice>> {
if self.devices.is_empty() {
return Err(GpuAdvancedError::GpuNotFound(
"No devices available".to_string(),
));
}
match self.strategy {
SelectionStrategy::RoundRobin => self.select_round_robin(),
SelectionStrategy::LeastLoaded => self.select_least_loaded(),
SelectionStrategy::BestScore => self.select_best_score(),
SelectionStrategy::Affinity => self.select_affinity(),
}
}
fn select_round_robin(&self) -> Result<Arc<GpuDevice>> {
let index = self.rr_counter.fetch_add(1, AtomicOrdering::Relaxed) % self.devices.len();
self.devices
.get(index)
.cloned()
.ok_or(GpuAdvancedError::InvalidGpuIndex {
index,
total: self.devices.len(),
})
}
fn select_least_loaded(&self) -> Result<Arc<GpuDevice>> {
let stats = self.stats.read();
let (index, _) = self
.devices
.iter()
.enumerate()
.map(|(i, device)| {
let active_tasks = stats.active_tasks.get(i).copied().unwrap_or(0);
let workload = device.get_workload();
let load = (active_tasks as f32) + workload;
(i, load)
})
.min_by(|(_, load_a), (_, load_b)| {
load_a.partial_cmp(load_b).unwrap_or(Ordering::Equal)
})
.ok_or_else(|| {
GpuAdvancedError::LoadBalancingError("No device available".to_string())
})?;
self.devices
.get(index)
.cloned()
.ok_or(GpuAdvancedError::InvalidGpuIndex {
index,
total: self.devices.len(),
})
}
fn select_best_score(&self) -> Result<Arc<GpuDevice>> {
let (index, _) = self
.devices
.iter()
.enumerate()
.map(|(i, device)| (i, device.get_score()))
.max_by(|(_, score_a), (_, score_b)| {
score_a.partial_cmp(score_b).unwrap_or(Ordering::Equal)
})
.ok_or_else(|| {
GpuAdvancedError::LoadBalancingError("No device available".to_string())
})?;
self.devices
.get(index)
.cloned()
.ok_or(GpuAdvancedError::InvalidGpuIndex {
index,
total: self.devices.len(),
})
}
fn select_affinity(&self) -> Result<Arc<GpuDevice>> {
let thread_id = std::thread::current().id();
let hash = {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
thread_id.hash(&mut hasher);
hasher.finish()
};
let index = (hash as usize) % self.devices.len();
self.devices
.get(index)
.cloned()
.ok_or(GpuAdvancedError::InvalidGpuIndex {
index,
total: self.devices.len(),
})
}
pub fn select_weighted(&self) -> Result<Arc<GpuDevice>> {
if self.devices.is_empty() {
return Err(GpuAdvancedError::GpuNotFound(
"No devices available".to_string(),
));
}
let config = self.migration_config.read();
let mut best_index = 0;
let mut best_score = f32::MIN;
for (i, device) in self.devices.iter().enumerate() {
let compute_util = device.get_workload();
let memory_usage = device.get_memory_usage();
let max_memory = device.info.max_buffer_size;
let memory_util = if max_memory > 0 {
memory_usage as f32 / max_memory as f32
} else {
0.0
};
let availability =
1.0 - (compute_util * config.compute_weight + memory_util * config.memory_weight);
let type_bonus = device.get_score();
let score = availability * type_bonus;
if score > best_score {
best_score = score;
best_index = i;
}
}
self.devices
.get(best_index)
.cloned()
.ok_or(GpuAdvancedError::InvalidGpuIndex {
index: best_index,
total: self.devices.len(),
})
}
pub fn get_device_loads(&self) -> Vec<DeviceLoad> {
let config = self.migration_config.read();
let tracker = self.workload_tracker.read();
let stats = self.stats.read();
self.devices
.iter()
.enumerate()
.map(|(i, device)| {
let compute_utilization = device.get_workload();
let memory_usage = device.get_memory_usage();
let max_memory = device.info.max_buffer_size;
let memory_utilization = if max_memory > 0 {
memory_usage as f32 / max_memory as f32
} else {
0.0
};
let combined_load = compute_utilization * config.compute_weight
+ memory_utilization * config.memory_weight;
let trend = tracker.utilization_trend(i, 10).unwrap_or(0.0);
DeviceLoad {
device_index: i,
compute_utilization,
memory_utilization,
combined_load,
active_tasks: stats.active_tasks.get(i).copied().unwrap_or(0),
pending_workloads: tracker.pending_count(i),
score: device.get_score(),
trend,
}
})
.collect()
}
pub fn identify_overloaded_devices(&self) -> Vec<DeviceLoad> {
let config = self.migration_config.read();
self.get_device_loads()
.into_iter()
.filter(|load| load.is_overloaded(&config))
.collect()
}
pub fn identify_underutilized_devices(&self) -> Vec<DeviceLoad> {
let config = self.migration_config.read();
self.get_device_loads()
.into_iter()
.filter(|load| load.is_underutilized(&config))
.collect()
}
pub fn is_imbalanced(&self) -> bool {
let loads = self.get_device_loads();
if loads.len() < 2 {
return false;
}
let config = self.migration_config.read();
let max_load = loads
.iter()
.map(|l| l.combined_load)
.fold(f32::MIN, f32::max);
let min_load = loads
.iter()
.map(|l| l.combined_load)
.fold(f32::MAX, f32::min);
(max_load - min_load) > config.min_imbalance_threshold
}
pub fn calculate_transfer_cost(
&self,
source_device: usize,
target_device: usize,
data_size: u64,
) -> Result<f64> {
if source_device >= self.devices.len() || target_device >= self.devices.len() {
return Err(GpuAdvancedError::InvalidGpuIndex {
index: source_device.max(target_device),
total: self.devices.len(),
});
}
let config = self.migration_config.read();
let history = self.migration_history.read();
let mut cost =
config.transfer_cost_base + (data_size as f64 * config.transfer_cost_per_byte);
if let Some(avg_time) = history.average_transfer_time(source_device, target_device) {
let time_factor = avg_time.as_secs_f64();
cost *= 1.0 + time_factor;
}
let success_rate = history.success_rate(source_device, target_device);
if success_rate < 1.0 {
cost *= 1.0 + (1.0 - success_rate) * 0.5;
}
Ok(cost)
}
pub fn create_migration_plan(
&self,
workload: MigratableWorkload,
target_device: usize,
) -> Result<MigrationPlan> {
if target_device >= self.devices.len() {
return Err(GpuAdvancedError::InvalidGpuIndex {
index: target_device,
total: self.devices.len(),
});
}
let config = self.migration_config.read();
let plan = MigrationPlan::new(workload, target_device, &config);
Ok(plan)
}
pub fn find_migration_target(&self, source_device: usize) -> Result<Option<usize>> {
let loads = self.get_device_loads();
let config = self.migration_config.read();
let tracker = self.workload_tracker.read();
let source_load = loads
.iter()
.find(|l| l.device_index == source_device)
.ok_or(GpuAdvancedError::InvalidGpuIndex {
index: source_device,
total: self.devices.len(),
})?;
let mut candidates: Vec<_> = loads
.iter()
.filter(|l| {
l.device_index != source_device
&& l.is_underutilized(&config)
&& !tracker.is_in_cooldown(l.device_index, config.migration_cooldown_secs)
})
.collect();
if candidates.is_empty() {
return Ok(None);
}
candidates.sort_by(|a, b| match a.combined_load.partial_cmp(&b.combined_load) {
Some(Ordering::Equal) | None => {
b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
}
Some(ordering) => ordering,
});
if let Some(best) = candidates.first() {
let load_diff = source_load.combined_load - best.combined_load;
if load_diff > config.min_imbalance_threshold {
return Ok(Some(best.device_index));
}
}
Ok(None)
}
pub fn select_workload_for_migration(
&self,
source_device: usize,
) -> Option<MigratableWorkload> {
let config = self.migration_config.read();
let tracker = self.workload_tracker.read();
let migratable = tracker.get_migratable_workloads(source_device);
let mut candidates: Vec<_> = migratable
.into_iter()
.filter(|w| w.memory_size >= config.min_migration_size)
.collect();
candidates.sort_by(|a, b| {
match b.priority.cmp(&a.priority) {
Ordering::Equal => b
.compute_intensity
.partial_cmp(&a.compute_intensity)
.unwrap_or(Ordering::Equal),
other => other,
}
});
candidates.first().map(|w| (*w).clone())
}
pub fn execute_migration(&self, plan: &MigrationPlan) -> Result<MigrationResult> {
if !plan.should_migrate() {
return Ok(MigrationResult {
success: false,
source_device: plan.source_device,
target_device: plan.target_device,
workload_id: plan.workload.id,
transfer_time: Duration::ZERO,
bytes_transferred: 0,
error_message: Some("Migration not approved".to_string()),
});
}
let start = Instant::now();
{
let mut tracker = self.workload_tracker.write();
if tracker
.remove_workload(plan.source_device, plan.workload.id)
.is_none()
{
return Ok(MigrationResult {
success: false,
source_device: plan.source_device,
target_device: plan.target_device,
workload_id: plan.workload.id,
transfer_time: Duration::ZERO,
bytes_transferred: 0,
error_message: Some("Workload not found on source device".to_string()),
});
}
let mut migrated = plan.workload.clone();
migrated.source_device = plan.target_device;
tracker.add_workload(plan.target_device, migrated);
tracker.update_rebalance_time(plan.source_device);
tracker.update_rebalance_time(plan.target_device);
}
{
let mut stats = self.stats.write();
if let Some(from) = stats.migrations_from.get_mut(plan.source_device) {
*from = from.saturating_add(1);
}
if let Some(to) = stats.migrations_to.get_mut(plan.target_device) {
*to = to.saturating_add(1);
}
}
let transfer_time = start.elapsed();
{
let mut history = self.migration_history.write();
history.add_entry(MigrationHistoryEntry {
timestamp: Instant::now(),
source_device: plan.source_device,
target_device: plan.target_device,
success: true,
transfer_time,
bytes_transferred: plan.workload.memory_size,
});
}
Ok(MigrationResult {
success: true,
source_device: plan.source_device,
target_device: plan.target_device,
workload_id: plan.workload.id,
transfer_time,
bytes_transferred: plan.workload.memory_size,
error_message: None,
})
}
pub fn rebalance(&self) -> Result<Vec<MigrationResult>> {
if !self.is_imbalanced() {
return Ok(Vec::new());
}
let mut results = Vec::new();
self.sample_utilization();
let overloaded = self.identify_overloaded_devices();
if overloaded.is_empty() {
return Ok(results);
}
for source_load in overloaded {
if results.len() >= self.migration_config.read().max_pending_migrations {
break;
}
let target = match self.find_migration_target(source_load.device_index)? {
Some(t) => t,
None => continue,
};
let workload = match self.select_workload_for_migration(source_load.device_index) {
Some(w) => w,
None => continue,
};
let plan = self.create_migration_plan(workload, target)?;
if plan.should_migrate() {
let result = self.execute_migration(&plan)?;
results.push(result);
}
}
let config = self.migration_config.read();
if config.enable_predictive_migration {
drop(config);
self.handle_predictive_migrations(&mut results)?;
}
Ok(results)
}
fn sample_utilization(&self) {
let stats = self.stats.read();
let mut tracker = self.workload_tracker.write();
for (i, device) in self.devices.iter().enumerate() {
let compute = device.get_workload();
let memory_usage = device.get_memory_usage();
let max_memory = device.info.max_buffer_size;
let memory = if max_memory > 0 {
memory_usage as f32 / max_memory as f32
} else {
0.0
};
tracker.record_sample(
i,
UtilizationSample {
timestamp: Instant::now(),
compute,
memory,
active_tasks: stats.active_tasks.get(i).copied().unwrap_or(0),
},
);
}
}
fn handle_predictive_migrations(&self, results: &mut Vec<MigrationResult>) -> Result<()> {
let candidates: Vec<(usize, f32, f32)> = {
let config = self.migration_config.read();
let tracker = self.workload_tracker.read();
let device_loads = self.get_device_loads();
self.devices
.iter()
.enumerate()
.filter_map(|(i, _device)| {
let trend = tracker.utilization_trend(i, 20)?;
let load = device_loads.iter().find(|l| l.device_index == i)?;
if trend > 0.05
&& load.combined_load > 0.5
&& load.combined_load < config.overload_threshold
{
Some((i, trend, load.combined_load))
} else {
None
}
})
.collect()
};
let max_migrations = self.migration_config.read().max_pending_migrations;
for (device_index, _trend, _combined_load) in candidates {
if results.len() >= max_migrations {
break;
}
if let Some(target) = self.find_migration_target(device_index)? {
if let Some(workload) = self.select_workload_for_migration(device_index) {
let plan = self.create_migration_plan(workload, target)?;
if plan.should_migrate() {
let result = self.execute_migration(&plan)?;
results.push(result);
}
}
}
}
Ok(())
}
pub fn register_workload(
&self,
device_index: usize,
memory_size: u64,
compute_intensity: f32,
priority: u32,
) -> Result<u64> {
if device_index >= self.devices.len() {
return Err(GpuAdvancedError::InvalidGpuIndex {
index: device_index,
total: self.devices.len(),
});
}
let mut tracker = self.workload_tracker.write();
let workload_id = tracker.next_workload_id();
let workload = MigratableWorkload::new(
workload_id,
device_index,
memory_size,
compute_intensity,
priority,
);
tracker.add_workload(device_index, workload);
Ok(workload_id)
}
pub fn unregister_workload(&self, device_index: usize, workload_id: u64) -> Result<()> {
if device_index >= self.devices.len() {
return Err(GpuAdvancedError::InvalidGpuIndex {
index: device_index,
total: self.devices.len(),
});
}
let mut tracker = self.workload_tracker.write();
tracker.remove_workload(device_index, workload_id);
Ok(())
}
pub fn task_started(&self, device_index: usize) {
let mut stats = self.stats.write();
if let Some(count) = stats.tasks_per_device.get_mut(device_index) {
*count = count.saturating_add(1);
}
if let Some(active) = stats.active_tasks.get_mut(device_index) {
*active = active.saturating_add(1);
}
}
pub fn task_completed(&self, device_index: usize, duration_us: u64) {
let mut stats = self.stats.write();
if let Some(active) = stats.active_tasks.get_mut(device_index) {
*active = active.saturating_sub(1);
}
if let Some(time) = stats.time_per_device.get_mut(device_index) {
*time = time.saturating_add(duration_us);
}
}
pub fn get_stats(&self) -> LoadStats {
self.stats.read().clone()
}
pub fn print_stats(&self) {
let stats = self.stats.read();
println!("\nLoad Balancer Statistics:");
println!(" Strategy: {:?}", self.strategy);
for (i, device) in self.devices.iter().enumerate() {
let tasks = stats.tasks_per_device.get(i).copied().unwrap_or(0);
let time_us = stats.time_per_device.get(i).copied().unwrap_or(0);
let active = stats.active_tasks.get(i).copied().unwrap_or(0);
let avg_time_us = if tasks > 0 {
time_us / (tasks as u64)
} else {
0
};
let migrations_from = stats.migrations_from.get(i).copied().unwrap_or(0);
let migrations_to = stats.migrations_to.get(i).copied().unwrap_or(0);
println!("\n GPU {}: {}", i, device.info.name);
println!(" Total tasks: {}", tasks);
println!(" Active tasks: {}", active);
println!(" Total time: {} ms", time_us / 1000);
println!(" Avg task time: {} us", avg_time_us);
println!(
" Current workload: {:.1}%",
device.get_workload() * 100.0
);
println!(" Migrations from: {}", migrations_from);
println!(" Migrations to: {}", migrations_to);
}
}
pub fn reset_stats(&self) {
let mut stats = self.stats.write();
let device_count = self.devices.len();
stats.tasks_per_device = vec![0; device_count];
stats.time_per_device = vec![0; device_count];
stats.active_tasks = vec![0; device_count];
stats.memory_per_device = vec![0; device_count];
stats.migrations_from = vec![0; device_count];
stats.migrations_to = vec![0; device_count];
}
pub fn get_device_utilization(&self, device_index: usize) -> f32 {
self.devices
.get(device_index)
.map(|device| device.get_workload())
.unwrap_or(0.0)
}
pub fn get_cluster_utilization(&self) -> f32 {
if self.devices.is_empty() {
return 0.0;
}
let total_utilization: f32 = self
.devices
.iter()
.map(|device| device.get_workload())
.sum();
total_utilization / (self.devices.len() as f32)
}
pub fn suggest_device(&self, estimated_memory: u64) -> Result<Arc<GpuDevice>> {
let candidates: Vec<_> = self
.devices
.iter()
.filter(|device| {
let memory_usage = device.get_memory_usage();
let max_memory = device.info.max_buffer_size;
(max_memory - memory_usage) >= estimated_memory
})
.collect();
if candidates.is_empty() {
return Err(GpuAdvancedError::GpuNotFound(
"No device with enough memory".to_string(),
));
}
self.select_device()
}
pub fn get_migration_stats(&self) -> (usize, usize, f64) {
let history = self.migration_history.read();
(
history.total_successful,
history.total_failed,
history.overall_success_rate(),
)
}
pub fn device_count(&self) -> usize {
self.devices.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// History entry on the 0 -> 1 route, which every history test uses.
    fn route_entry(success: bool, millis: u64, bytes: u64) -> MigrationHistoryEntry {
        MigrationHistoryEntry {
            timestamp: Instant::now(),
            source_device: 0,
            target_device: 1,
            success,
            transfer_time: Duration::from_millis(millis),
            bytes_transferred: bytes,
        }
    }

    /// Utilization sample stamped now.
    fn sample(compute: f32, memory: f32, active_tasks: usize) -> UtilizationSample {
        UtilizationSample {
            timestamp: Instant::now(),
            compute,
            memory,
            active_tasks,
        }
    }

    #[test]
    fn test_load_stats() {
        assert!(LoadStats::default().tasks_per_device.is_empty());
    }

    #[test]
    fn test_selection_strategy() {
        let original = SelectionStrategy::RoundRobin;
        let _moved = original;
    }

    #[test]
    fn test_migration_config_default() {
        let cfg = MigrationConfig::default();
        assert!(cfg.overload_threshold > 0.0);
        assert!(cfg.overload_threshold <= 1.0);
        assert!(cfg.underutilization_threshold >= 0.0);
        assert!(cfg.underutilization_threshold < cfg.overload_threshold);
    }

    #[test]
    fn test_migratable_workload() {
        let w = MigratableWorkload::new(1, 0, 1024 * 1024, 0.5, 10);
        assert_eq!(w.id, 1);
        assert_eq!(w.source_device, 0);
        assert_eq!(w.memory_size, 1024 * 1024);
        assert!(!w.migrating);
        let dependent = w.with_dependency(0);
        assert_eq!(dependent.dependencies.len(), 1);
    }

    #[test]
    fn test_migration_cost_calculation() {
        let cfg = MigrationConfig::default();
        let cost = MigratableWorkload::new(1, 0, 1024 * 1024, 0.5, 10).calculate_migration_cost(&cfg);
        assert!(cost > cfg.transfer_cost_base);
    }

    #[test]
    fn test_migration_plan() {
        let cfg = MigrationConfig::default();
        let plan = MigrationPlan::new(MigratableWorkload::new(1, 0, 1024 * 1024, 0.8, 10), 1, &cfg);
        assert_eq!(plan.source_device, 0);
        assert_eq!(plan.target_device, 1);
        assert!(plan.estimated_cost > 0.0);
    }

    #[test]
    fn test_migration_history() {
        let mut log = MigrationHistory::new(10);
        log.add_entry(route_entry(true, 10, 1024));
        assert_eq!(log.total_successful, 1);
        assert_eq!(log.total_failed, 0);
        assert!((log.overall_success_rate() - 1.0).abs() < f64::EPSILON);
        log.add_entry(route_entry(false, 5, 0));
        assert_eq!(log.total_failed, 1);
        assert!((log.overall_success_rate() - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_workload_tracker() {
        let mut tracker = WorkloadTracker::new(2, 100);
        let first = tracker.next_workload_id();
        let second = tracker.next_workload_id();
        assert_ne!(first, second);
        tracker.add_workload(0, MigratableWorkload::new(first, 0, 1024, 0.5, 10));
        assert_eq!(tracker.pending_count(0), 1);
        assert!(tracker.remove_workload(0, first).is_some());
        assert_eq!(tracker.pending_count(0), 0);
    }

    #[test]
    fn test_utilization_sample() {
        let mut tracker = WorkloadTracker::new(2, 100);
        for step in 0..10usize {
            tracker.record_sample(0, sample(0.1 * step as f32, 0.05 * step as f32, step));
        }
        let (avg_compute, avg_memory) = tracker
            .average_utilization(0, 5)
            .expect("Should have samples");
        assert!(avg_compute > 0.0);
        assert!(avg_memory > 0.0);
        let slope = tracker.utilization_trend(0, 10).expect("Should have trend");
        assert!(slope > 0.0);
    }

    #[test]
    fn test_device_load() {
        let cfg = MigrationConfig::default();
        let busy = DeviceLoad {
            device_index: 0,
            compute_utilization: 0.9,
            memory_utilization: 0.5,
            combined_load: 0.85,
            active_tasks: 5,
            pending_workloads: 3,
            score: 0.7,
            trend: 0.1,
        };
        assert!(busy.is_overloaded(&cfg));
        assert!(!busy.is_underutilized(&cfg));
        let idle = DeviceLoad {
            device_index: 1,
            compute_utilization: 0.1,
            memory_utilization: 0.1,
            combined_load: 0.1,
            active_tasks: 0,
            pending_workloads: 0,
            score: 0.9,
            trend: -0.05,
        };
        assert!(!idle.is_overloaded(&cfg));
        assert!(idle.is_underutilized(&cfg));
    }

    #[test]
    fn test_migration_history_average_time() {
        let mut log = MigrationHistory::new(10);
        log.add_entry(route_entry(true, 10, 1024));
        log.add_entry(route_entry(true, 20, 2048));
        let avg = log.average_transfer_time(0, 1).expect("Should have average");
        assert_eq!(avg, Duration::from_millis(15));
        assert!(log.average_transfer_time(1, 0).is_none());
    }

    #[test]
    fn test_migration_history_success_rate() {
        let mut log = MigrationHistory::new(10);
        assert!((log.success_rate(0, 1) - 1.0).abs() < f64::EPSILON);
        (0..3).for_each(|_| log.add_entry(route_entry(true, 10, 1024)));
        log.add_entry(route_entry(false, 5, 0));
        assert!((log.success_rate(0, 1) - 0.75).abs() < f64::EPSILON);
    }

    #[test]
    fn test_workload_tracker_cooldown() {
        let mut tracker = WorkloadTracker::new(2, 100);
        assert!(!tracker.is_in_cooldown(0, 1));
        tracker.update_rebalance_time(0);
        assert!(tracker.is_in_cooldown(0, 1));
        assert!(!tracker.is_in_cooldown(0, 0));
    }

    #[test]
    fn test_workload_tracker_migratable() {
        let mut tracker = WorkloadTracker::new(2, 100);
        let free = MigratableWorkload::new(0, 0, 1024, 0.5, 10);
        let mut in_flight = MigratableWorkload::new(1, 0, 2048, 0.7, 5);
        in_flight.migrating = true;
        let dependent = MigratableWorkload::new(2, 0, 4096, 0.3, 15).with_dependency(0);
        for w in [free, in_flight, dependent] {
            tracker.add_workload(0, w);
        }
        let migratable = tracker.get_migratable_workloads(0);
        assert_eq!(migratable.len(), 1);
        assert_eq!(migratable[0].id, 0);
    }

    #[test]
    fn test_utilization_trend_calculation() {
        let mut rising = WorkloadTracker::new(1, 100);
        for step in 0..20usize {
            rising.record_sample(0, sample(0.05 * step as f32, 0.02 * step as f32, step));
        }
        let up = rising.utilization_trend(0, 20).expect("Should compute trend");
        assert!(up > 0.0, "Trend should be positive for increasing samples");

        let mut falling = WorkloadTracker::new(1, 100);
        for step in 0..20usize {
            falling.record_sample(
                0,
                sample(1.0 - 0.05 * step as f32, 0.5 - 0.02 * step as f32, 20 - step),
            );
        }
        let down = falling.utilization_trend(0, 20).expect("Should compute trend");
        assert!(down < 0.0, "Trend should be negative for decreasing samples");
    }
}