use anyhow::{anyhow, Result};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info};
/// Static description of a single GPU device as seen by the load balancer.
///
/// Plain data: identity plus fixed capacity figures. Serializable so the
/// descriptor can be persisted or sent across a config/API boundary.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SimpleGpuDevice {
    /// Unique device identifier; also the key used by `GpuLoadBalancer`.
    pub id: u32,
    /// Human-readable device name (e.g. "GPU-0").
    pub name: String,
    /// Total device memory in megabytes; used as the capacity budget.
    pub memory_mb: u64,
    /// Number of compute units — informational only; the balancer in this
    /// file schedules purely on memory.
    pub compute_units: u32,
}
impl SimpleGpuDevice {
pub fn new(id: u32, name: impl Into<String>, memory_mb: u64, compute_units: u32) -> Self {
Self {
id,
name: name.into(),
memory_mb,
compute_units,
}
}
}
/// Internal per-device bookkeeping: the static descriptor plus the amount
/// of workload currently resident on the device.
#[derive(Debug)]
struct DeviceState {
    device: SimpleGpuDevice,
    // Megabytes of workload currently assigned to this device.
    current_workload_mb: u64,
}
impl DeviceState {
fn new(device: SimpleGpuDevice) -> Self {
Self {
device,
current_workload_mb: 0,
}
}
fn utilization(&self) -> f64 {
if self.device.memory_mb == 0 {
return 0.0;
}
(self.current_workload_mb as f64 / self.device.memory_mb as f64).min(1.0)
}
}
/// Thread-safe, cheaply clonable handle onto shared load-balancer state.
///
/// Cloning shares the same underlying state via `Arc`; all mutation goes
/// through the internal `parking_lot::Mutex`.
#[derive(Debug, Clone)]
pub struct GpuLoadBalancer {
    inner: Arc<Mutex<GpuLoadBalancerInner>>,
}
/// Mutable balancer state guarded by the outer mutex.
#[derive(Debug)]
struct GpuLoadBalancerInner {
    // Device ids in registration order; gives deterministic iteration
    // (HashMap iteration order alone would be unstable).
    device_order: Vec<u32>,
    // Per-device state keyed by device id.
    states: HashMap<u32, DeviceState>,
}
impl GpuLoadBalancerInner {
fn new() -> Self {
Self {
device_order: Vec::new(),
states: HashMap::new(),
}
}
}
impl Default for GpuLoadBalancer {
fn default() -> Self {
Self::new()
}
}
impl GpuLoadBalancer {
    /// Creates a balancer with no registered devices.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(GpuLoadBalancerInner::new())),
        }
    }

    /// Registers (or re-registers) a device.
    ///
    /// Re-registering an existing id replaces its state — resetting its
    /// recorded workload to zero — while preserving the original
    /// registration order.
    pub fn register_device(&self, device: SimpleGpuDevice) {
        let mut g = self.inner.lock();
        let id = device.id;
        info!("Registering GPU device {} ({})", id, device.name);
        if !g.device_order.contains(&id) {
            g.device_order.push(id);
        }
        g.states.insert(id, DeviceState::new(device));
    }

    /// Removes a device; a no-op if the id is unknown.
    pub fn unregister_device(&self, device_id: u32) {
        let mut g = self.inner.lock();
        g.device_order.retain(|&x| x != device_id);
        g.states.remove(&device_id);
        debug!("Unregistered GPU device {}", device_id);
    }

    /// Picks the least-utilized device that has at least `workload_mb` of
    /// free memory, or `None` when no device can fit the workload.
    pub fn select_device(&self, workload_mb: u64) -> Option<u32> {
        let g = self.inner.lock();
        g.device_order
            .iter()
            .filter_map(|&id| g.states.get(&id).map(|s| (id, s)))
            .filter(|(_, s)| {
                // Free memory must cover the requested workload.
                s.device.memory_mb.saturating_sub(s.current_workload_mb) >= workload_mb
            })
            .min_by(|(_, a), (_, b)| {
                // Utilization values are finite, so `Equal` is a safe
                // fallback for the (unreachable) NaN case.
                a.utilization()
                    .partial_cmp(&b.utilization())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(id, _)| id)
    }

    /// Adds `mb` to a device's recorded workload.
    ///
    /// Saturates at `u64::MAX` instead of overflowing, mirroring the
    /// saturating subtraction in [`Self::release_workload`].
    ///
    /// # Errors
    /// Returns an error if `device_id` is not registered.
    pub fn record_workload(&self, device_id: u32, mb: u64) -> Result<()> {
        let mut g = self.inner.lock();
        let state = g
            .states
            .get_mut(&device_id)
            .ok_or_else(|| anyhow!("Device {} not registered", device_id))?;
        // `+=` would panic in debug builds / wrap in release on overflow;
        // saturate instead, consistent with release_workload.
        state.current_workload_mb = state.current_workload_mb.saturating_add(mb);
        debug!(
            "Device {}: workload {} MB (util {:.1}%)",
            device_id,
            state.current_workload_mb,
            state.utilization() * 100.0
        );
        Ok(())
    }

    /// Subtracts `mb` from a device's recorded workload, clamping at zero.
    ///
    /// # Errors
    /// Returns an error if `device_id` is not registered.
    pub fn release_workload(&self, device_id: u32, mb: u64) -> Result<()> {
        let mut g = self.inner.lock();
        let state = g
            .states
            .get_mut(&device_id)
            .ok_or_else(|| anyhow!("Device {} not registered", device_id))?;
        state.current_workload_mb = state.current_workload_mb.saturating_sub(mb);
        debug!(
            "Device {}: released {} MB, now {} MB",
            device_id, mb, state.current_workload_mb
        );
        Ok(())
    }

    /// Current utilization (`0.0..=1.0`) of a device, or `None` if unknown.
    pub fn utilization(&self, device_id: u32) -> Option<f64> {
        let g = self.inner.lock();
        g.states.get(&device_id).map(|s| s.utilization())
    }

    /// Sum of `memory_mb` across all registered devices.
    pub fn total_capacity_mb(&self) -> u64 {
        let g = self.inner.lock();
        g.states.values().map(|s| s.device.memory_mb).sum()
    }

    /// Number of registered devices.
    pub fn device_count(&self) -> usize {
        self.inner.lock().states.len()
    }

    /// `(device_id, utilization)` pairs in registration order.
    pub fn utilization_snapshot(&self) -> Vec<(u32, f64)> {
        let g = self.inner.lock();
        g.device_order
            .iter()
            .filter_map(|&id| g.states.get(&id).map(|s| (id, s.utilization())))
            .collect()
    }
}
/// Half-open index range `[start_idx, end_idx)` of vectors assigned to one
/// device by a `WorkloadDistributor`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WorkloadChunk {
    /// Device this chunk is assigned to.
    pub device_id: u32,
    /// Inclusive start index into the vector batch.
    pub start_idx: usize,
    /// Exclusive end index into the vector batch.
    pub end_idx: usize,
}
impl WorkloadChunk {
    /// Number of vectors covered by this chunk (zero for an inverted range).
    pub fn len(&self) -> usize {
        self.end_idx.saturating_sub(self.start_idx)
    }

    /// True when the chunk covers no vectors.
    pub fn is_empty(&self) -> bool {
        self.start_idx >= self.end_idx
    }
}
/// Stateless strategy object that splits a batch of vectors across devices,
/// either proportionally to memory or evenly.
#[derive(Debug, Clone, Default)]
pub struct WorkloadDistributor;
impl WorkloadDistributor {
    /// Creates a distributor; the type carries no state.
    pub fn new() -> Self {
        Self
    }

    /// Splits `total_vectors` across `devices` proportionally to memory.
    ///
    /// Devices with zero memory are skipped. The final eligible device
    /// absorbs any rounding remainder so the produced chunks always cover
    /// exactly `total_vectors`; iteration stops early once everything has
    /// been assigned.
    ///
    /// # Errors
    /// Fails when no device has nonzero memory.
    pub fn distribute(
        &self,
        total_vectors: usize,
        devices: &[SimpleGpuDevice],
    ) -> Result<Vec<WorkloadChunk>> {
        let eligible: Vec<&SimpleGpuDevice> =
            devices.iter().filter(|d| d.memory_mb > 0).collect();
        if eligible.is_empty() {
            return Err(anyhow!(
                "No eligible GPU devices (all have zero memory or list is empty)"
            ));
        }
        let total_mem: u64 = eligible.iter().map(|d| d.memory_mb).sum();
        let last = eligible.len() - 1;
        let mut chunks: Vec<WorkloadChunk> = Vec::with_capacity(eligible.len());
        let mut cursor = 0usize;
        for (i, device) in eligible.iter().enumerate() {
            let end = if i == last {
                // Last device takes whatever rounding left over.
                total_vectors
            } else {
                let share = device.memory_mb as f64 / total_mem as f64;
                let take = (total_vectors as f64 * share).round() as usize;
                total_vectors.min(cursor + take)
            };
            chunks.push(WorkloadChunk {
                device_id: device.id,
                start_idx: cursor,
                end_idx: end,
            });
            cursor = end;
            if cursor >= total_vectors {
                break;
            }
        }
        Ok(chunks)
    }

    /// Splits `total_vectors` as evenly as possible: the first
    /// `total_vectors % devices.len()` devices receive one extra vector.
    ///
    /// # Errors
    /// Fails when `devices` is empty.
    pub fn distribute_even(
        &self,
        total_vectors: usize,
        devices: &[SimpleGpuDevice],
    ) -> Result<Vec<WorkloadChunk>> {
        if devices.is_empty() {
            return Err(anyhow!("Cannot distribute across zero devices"));
        }
        let n = devices.len();
        let (base, remainder) = (total_vectors / n, total_vectors % n);
        let mut chunks = Vec::with_capacity(n);
        let mut start = 0usize;
        for (i, device) in devices.iter().enumerate() {
            let end = start + base + usize::from(i < remainder);
            chunks.push(WorkloadChunk {
                device_id: device.id,
                start_idx: start,
                end_idx: end,
            });
            start = end;
        }
        Ok(chunks)
    }
}
// Unit tests: balancer registration/selection/accounting, chunk helpers,
// and both distribution strategies (proportional and even).
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Shorthand for a device with the given id/memory and 128 compute units.
    fn make_device(id: u32, mem_mb: u64) -> SimpleGpuDevice {
        SimpleGpuDevice::new(id, format!("GPU-{}", id), mem_mb, 128)
    }

    #[test]
    fn test_simple_gpu_device_fields() {
        let d = SimpleGpuDevice::new(0, "TestGPU", 8192, 128);
        assert_eq!(d.id, 0);
        assert_eq!(d.name, "TestGPU");
        assert_eq!(d.memory_mb, 8192);
        assert_eq!(d.compute_units, 128);
    }

    #[test]
    fn test_register_device_count() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 16384));
        assert_eq!(lb.device_count(), 2);
    }

    #[test]
    fn test_total_capacity_mb() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 4096));
        lb.register_device(make_device(1, 8192));
        assert_eq!(lb.total_capacity_mb(), 12288);
    }

    #[test]
    fn test_select_device_empty_returns_none() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.select_device(100).is_none());
    }

    #[test]
    fn test_select_device_single() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        let sel = lb.select_device(100);
        assert_eq!(sel, Some(0));
    }

    #[test]
    fn test_select_device_insufficient_memory() {
        let lb = GpuLoadBalancer::new();
        // Device has 100 MB total; a 200 MB request cannot fit anywhere.
        lb.register_device(make_device(0, 100));
        assert!(lb.select_device(200).is_none());
    }

    #[test]
    fn test_select_device_prefers_least_loaded() -> Result<()> {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 8192));
        lb.record_workload(0, 7000)?;
        let sel = lb.select_device(500);
        assert_eq!(sel, Some(1), "Should prefer the less-loaded device");
        Ok(())
    }

    #[test]
    fn test_record_and_release_workload() -> Result<()> {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.record_workload(0, 2048)?;
        // 2048 / 8192 = 25% utilization.
        let u1 = lb.utilization(0).expect("utilization(0) was None");
        assert!(
            (u1 - 0.25).abs() < 1e-6,
            "Expected 25% utilisation, got {}",
            u1
        );
        lb.release_workload(0, 2048)?;
        let u2 = lb.utilization(0).expect("utilization(0) was None");
        assert!(u2 < 1e-9, "Expected 0% after release, got {}", u2);
        Ok(())
    }

    #[test]
    fn test_release_clamps_to_zero() -> Result<()> {
        // Releasing more than was recorded must saturate at zero, not wrap.
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.record_workload(0, 100)?;
        lb.release_workload(0, 9999)?;
        let __val = lb.utilization(0).expect("utilization(0) was None");
        assert_eq!(__val, 0.0);
        Ok(())
    }

    #[test]
    fn test_record_unknown_device_errors() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.record_workload(99, 100).is_err());
    }

    #[test]
    fn test_release_unknown_device_errors() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.release_workload(99, 100).is_err());
    }

    #[test]
    fn test_utilization_unknown_device_none() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.utilization(42).is_none());
    }

    #[test]
    fn test_utilization_snapshot() -> Result<()> {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 4096));
        lb.record_workload(0, 4096)?;
        let snap = lb.utilization_snapshot();
        assert_eq!(snap.len(), 2);
        let u0 = snap
            .iter()
            .find(|(id, _)| *id == 0)
            .map(|(_, u)| *u)
            .expect("device 0 not in snapshot");
        assert!((u0 - 0.5).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_unregister_device() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 8192));
        lb.unregister_device(0);
        assert_eq!(lb.device_count(), 1);
        assert!(lb.utilization(0).is_none());
    }

    #[test]
    fn test_reregister_device_resets_workload() -> Result<()> {
        // Re-registering the same id replaces its state, zeroing the workload.
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.record_workload(0, 4096)?;
        lb.register_device(make_device(0, 8192));
        let __val = lb.utilization(0).expect("utilization(0) should be present");
        assert_eq!(__val, 0.0);
        Ok(())
    }

    #[test]
    fn test_workload_chunk_len() {
        let chunk = WorkloadChunk {
            device_id: 0,
            start_idx: 10,
            end_idx: 50,
        };
        assert_eq!(chunk.len(), 40);
    }

    #[test]
    fn test_workload_chunk_is_empty() {
        let chunk = WorkloadChunk {
            device_id: 0,
            start_idx: 5,
            end_idx: 5,
        };
        assert!(chunk.is_empty());
    }

    #[test]
    fn test_distribute_empty_devices_error() {
        let dist = WorkloadDistributor::new();
        assert!(dist.distribute(1000, &[]).is_err());
    }

    #[test]
    fn test_distribute_single_device() -> Result<()> {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 8192)];
        let chunks = dist.distribute(1000, &devices)?;
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].start_idx, 0);
        assert_eq!(chunks[0].end_idx, 1000);
        Ok(())
    }

    #[test]
    fn test_distribute_covers_all_vectors() -> Result<()> {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 4096), make_device(1, 8192)];
        let chunks = dist.distribute(900, &devices)?;
        let covered: usize = chunks.iter().map(|c| c.len()).sum();
        assert_eq!(covered, 900, "All vectors must be covered");
        Ok(())
    }

    #[test]
    fn test_distribute_proportional_to_memory() -> Result<()> {
        // Memory ratio is 1024:3072 = 1:3, so ~25% / ~75% of 1000 vectors.
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 1024), make_device(1, 3072)];
        let chunks = dist.distribute(1000, &devices)?;
        assert_eq!(chunks.len(), 2);
        let c0 = &chunks[0];
        let c1 = &chunks[1];
        assert!(
            c0.len() <= 300,
            "Device 0 should get ~25%, got {}",
            c0.len()
        );
        assert!(
            c1.len() >= 700,
            "Device 1 should get ~75%, got {}",
            c1.len()
        );
        assert_eq!(c0.start_idx, 0);
        assert_eq!(c1.end_idx, 1000);
        Ok(())
    }

    #[test]
    fn test_distribute_skips_zero_memory_device() -> Result<()> {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 0), make_device(1, 8192)];
        let chunks = dist.distribute(100, &devices)?;
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].device_id, 1);
        Ok(())
    }

    #[test]
    fn test_distribute_even_basic() -> Result<()> {
        let dist = WorkloadDistributor::new();
        let devices = vec![
            make_device(0, 4096),
            make_device(1, 4096),
            make_device(2, 4096),
        ];
        let chunks = dist.distribute_even(9, &devices)?;
        assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), 9);
        for chunk in &chunks {
            assert_eq!(chunk.len(), 3);
        }
        Ok(())
    }

    #[test]
    fn test_distribute_even_with_remainder() -> Result<()> {
        // 7 vectors over 2 devices: the first device takes the extra one.
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 4096), make_device(1, 4096)];
        let chunks = dist.distribute_even(7, &devices)?;
        assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), 7);
        assert_eq!(chunks[0].len(), 4);
        assert_eq!(chunks[1].len(), 3);
        Ok(())
    }

    #[test]
    fn test_distribute_even_empty_devices_error() {
        let dist = WorkloadDistributor::new();
        assert!(dist.distribute_even(100, &[]).is_err());
    }
}