use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Instant;
use oxicuda_driver::{CudaError, CudaResult};
/// Strategy used by the pool to pick which device serves the next task.
pub enum DeviceSelectionPolicy {
/// Cycle through devices in order, one per acquisition.
RoundRobin,
/// Pick the available device with the fewest active tasks.
LeastLoaded,
/// Pick the available device with the most memory.
/// NOTE(review): selection keys on `DeviceStatus::total_memory`, not a live
/// free-memory figure — confirm whether that approximation is intended.
MostMemoryFree,
/// Pick the available device with the highest compute capability.
BestCompute,
/// Pick a device at random, biased by per-device weights.
WeightedRandom {
/// One weight per device ordinal; entries beyond the device count are ignored.
weights: Vec<f64>,
},
/// User-supplied selector: receives all device statuses and returns an index
/// into that slice (out-of-range results are clamped to 0 by the pool).
#[allow(clippy::type_complexity)]
Custom(Box<dyn Fn(&[DeviceStatus]) -> usize + Send + Sync>),
}
impl std::fmt::Debug for DeviceSelectionPolicy {
    // Hand-written because the `Custom` variant holds a non-`Debug` closure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::WeightedRandom { weights } => f
                .debug_struct("WeightedRandom")
                .field("weights", weights)
                .finish(),
            Self::Custom(_) => f.write_str("Custom(<closure>)"),
            Self::RoundRobin => f.write_str("RoundRobin"),
            Self::LeastLoaded => f.write_str("LeastLoaded"),
            Self::MostMemoryFree => f.write_str("MostMemoryFree"),
            Self::BestCompute => f.write_str("BestCompute"),
        }
    }
}
/// Point-in-time snapshot of a single GPU as tracked by the pool.
#[derive(Debug, Clone)]
pub struct DeviceStatus {
/// Device ordinal (index used by the driver).
pub ordinal: i32,
/// Human-readable device name.
pub name: String,
/// Total device memory in bytes.
pub total_memory: usize,
/// Number of currently leased tasks on this device.
pub active_tasks: u32,
/// Compute capability as (major, minor).
pub compute_capability: (u32, u32),
/// Number of streaming multiprocessors.
pub sm_count: u32,
/// Whether selection policies may pick this device.
pub is_available: bool,
}
/// Record of a unit of work dispatched to a device.
/// NOTE(review): nothing in this file constructs or consumes `GpuTask`; it
/// appears intended for external callers — confirm before removing.
#[derive(Debug, Clone)]
pub struct GpuTask {
/// Pool-assigned task identifier.
pub id: u64,
/// Ordinal of the device running the task.
pub device_ordinal: i32,
/// Free-form description supplied by the caller.
pub description: String,
/// When the task was started.
pub started_at: Instant,
}
/// RAII handle for a device slot: while the lease is alive the device's
/// active-task count stays incremented; dropping it decrements the count.
pub struct GpuLease {
device_ordinal: i32,
task_id: u64,
description: String,
// Shared pool state; used on drop to return the slot.
pool: Arc<RwLock<PoolInner>>,
}
impl GpuLease {
    /// Identifier assigned by the pool when this lease was acquired.
    #[inline]
    pub fn task_id(&self) -> u64 {
        self.task_id
    }

    /// Ordinal of the device this lease is pinned to.
    #[inline]
    pub fn ordinal(&self) -> i32 {
        self.device_ordinal
    }

    /// Caller-supplied description of the task (may be empty).
    #[inline]
    pub fn description(&self) -> &str {
        self.description.as_str()
    }
}
impl Drop for GpuLease {
    /// Returns the slot to the pool by decrementing the device's active-task
    /// count. A poisoned pool lock is deliberately ignored: panicking inside
    /// `drop` during unwinding would abort the process.
    fn drop(&mut self) {
        match self.pool.write() {
            Ok(mut inner) => inner.decrement_tasks(self.device_ordinal),
            Err(_) => {} // lock poisoned: skip bookkeeping rather than panic
        }
    }
}
impl std::fmt::Debug for GpuLease {
    // Hand-written so the `pool` handle (Arc<RwLock<..>>) is omitted from output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("GpuLease");
        dbg.field("device_ordinal", &self.device_ordinal);
        dbg.field("task_id", &self.task_id);
        dbg.field("description", &self.description);
        dbg.finish()
    }
}
/// State shared behind the pool's `RwLock`: the device table, the selection
/// policy, and the round-robin cursor.
struct PoolInner {
devices: Vec<DeviceStatus>,
policy: DeviceSelectionPolicy,
// Monotonically advancing; taken modulo `devices.len()` on selection.
rr_cursor: usize,
}
impl PoolInner {
    /// Indices of devices currently flagged `is_available`, in vector order.
    fn available_indices(&self) -> impl Iterator<Item = usize> + '_ {
        self.devices
            .iter()
            .enumerate()
            .filter(|(_, d)| d.is_available)
            .map(|(i, _)| i)
    }

    /// Select the index into `self.devices` for the next task according to
    /// the configured policy.
    ///
    /// # Errors
    /// Returns `CudaError::NoDevice` when the pool holds no devices at all.
    /// When a policy's availability filter matches nothing, selection
    /// deliberately falls back to a best-effort choice instead of erroring.
    fn select(&mut self) -> CudaResult<usize> {
        if self.devices.is_empty() {
            return Err(CudaError::NoDevice);
        }
        let idx = match &self.policy {
            DeviceSelectionPolicy::RoundRobin => {
                // Bug fix: the original ignored `is_available` here while every
                // other policy filtered on it. Advance the cursor once per call
                // and scan forward from it for an available device; fall back
                // to the raw cursor position when none is available.
                let n = self.devices.len();
                let start = self.rr_cursor % n;
                self.rr_cursor = self.rr_cursor.wrapping_add(1);
                (0..n)
                    .map(|offset| (start + offset) % n)
                    .find(|&i| self.devices[i].is_available)
                    .unwrap_or(start)
            }
            DeviceSelectionPolicy::LeastLoaded => self
                .available_indices()
                .min_by_key(|&i| self.devices[i].active_tasks)
                .unwrap_or(0),
            // NOTE(review): keyed on total memory — `DeviceStatus` tracks no
            // free-memory figure, so "most memory free" is approximated by the
            // largest device. Confirm whether live free memory was intended.
            DeviceSelectionPolicy::MostMemoryFree => self
                .available_indices()
                .max_by_key(|&i| self.devices[i].total_memory)
                .unwrap_or(0),
            DeviceSelectionPolicy::BestCompute => self
                .available_indices()
                .max_by_key(|&i| self.devices[i].compute_capability)
                .unwrap_or(0),
            DeviceSelectionPolicy::WeightedRandom { weights } => {
                weighted_random_select(weights, self.devices.len())
            }
            DeviceSelectionPolicy::Custom(f) => {
                // Clamp out-of-range indices from the user closure to 0.
                let idx = f(&self.devices);
                if idx >= self.devices.len() { 0 } else { idx }
            }
        };
        Ok(idx)
    }

    /// Bump the active-task count for the device with the given ordinal
    /// (saturating; no-op for unknown ordinals).
    fn increment_tasks(&mut self, ordinal: i32) {
        if let Some(dev) = self.devices.iter_mut().find(|d| d.ordinal == ordinal) {
            dev.active_tasks = dev.active_tasks.saturating_add(1);
        }
    }

    /// Drop the active-task count for the device with the given ordinal
    /// (saturating; no-op for unknown ordinals).
    fn decrement_tasks(&mut self, ordinal: i32) {
        if let Some(dev) = self.devices.iter_mut().find(|d| d.ordinal == ordinal) {
            dev.active_tasks = dev.active_tasks.saturating_sub(1);
        }
    }
}
/// Pick an index in `0..device_count` at random, biased by `weights`.
///
/// Only the first `device_count` weights are considered. Degenerate inputs
/// (no weights, no devices, non-positive weight total) select index 0. The
/// entropy source is the sub-millisecond part of the wall clock: cheap and
/// dependency-free, but not suitable where real randomness matters.
fn weighted_random_select(weights: &[f64], device_count: usize) -> usize {
    let considered = &weights[..weights.len().min(device_count)];
    if considered.is_empty() {
        return 0;
    }
    let total: f64 = considered.iter().sum();
    if total <= 0.0 {
        return 0;
    }
    // Derive a pseudo-random point in [0, total) from the wall clock.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.subsec_nanos())
        .unwrap_or(0);
    let point = f64::from(nanos % 1_000_000) / 1_000_000.0 * total;
    // Walk the cumulative distribution until the point lands inside a slot.
    let mut running = 0.0;
    for (i, &w) in considered.iter().enumerate() {
        running += w;
        if point < running {
            return i;
        }
    }
    // Floating-point slack: fall back to the last index.
    device_count.saturating_sub(1)
}
/// Thread-safe pool of GPUs that hands out `GpuLease`s according to a
/// configurable `DeviceSelectionPolicy`.
pub struct MultiGpuPool {
inner: Arc<RwLock<PoolInner>>,
// Source of unique task ids, starting at 1.
next_task_id: AtomicU64,
}
impl MultiGpuPool {
pub fn new(policy: DeviceSelectionPolicy) -> CudaResult<Self> {
let devices = Self::discover_devices();
let inner = PoolInner {
devices,
policy,
rr_cursor: 0,
};
Ok(Self {
inner: Arc::new(RwLock::new(inner)),
next_task_id: AtomicU64::new(1),
})
}
pub fn with_devices(ordinals: Vec<i32>, policy: DeviceSelectionPolicy) -> CudaResult<Self> {
if ordinals.is_empty() {
return Err(CudaError::InvalidValue);
}
let devices: Vec<DeviceStatus> = ordinals
.iter()
.map(|&ord| Self::device_status_for(ord))
.collect();
let inner = PoolInner {
devices,
policy,
rr_cursor: 0,
};
Ok(Self {
inner: Arc::new(RwLock::new(inner)),
next_task_id: AtomicU64::new(1),
})
}
pub fn acquire(&self) -> CudaResult<GpuLease> {
self.acquire_with_description(String::new())
}
pub fn acquire_with_description(&self, description: String) -> CudaResult<GpuLease> {
let mut inner = self.inner.write().map_err(|_| CudaError::InvalidValue)?;
let idx = inner.select()?;
let ordinal = inner.devices[idx].ordinal;
inner.increment_tasks(ordinal);
let task_id = self.next_task_id.fetch_add(1, Ordering::Relaxed);
Ok(GpuLease {
device_ordinal: ordinal,
task_id,
description,
pool: Arc::clone(&self.inner),
})
}
pub fn release(&self, lease: GpuLease) {
drop(lease);
}
pub fn device_count(&self) -> usize {
self.inner
.read()
.map(|inner| inner.devices.len())
.unwrap_or(0)
}
pub fn status(&self) -> Vec<DeviceStatus> {
self.inner
.read()
.map(|inner| inner.devices.clone())
.unwrap_or_default()
}
pub fn set_policy(&self, policy: DeviceSelectionPolicy) -> CudaResult<()> {
let mut inner = self.inner.write().map_err(|_| CudaError::InvalidValue)?;
inner.policy = policy;
Ok(())
}
fn discover_devices() -> Vec<DeviceStatus> {
if let Ok(devices) = Self::try_discover_real() {
if !devices.is_empty() {
return devices;
}
}
Self::synthetic_devices(2)
}
fn try_discover_real() -> CudaResult<Vec<DeviceStatus>> {
oxicuda_driver::init()?;
let count = oxicuda_driver::Device::count()?;
let mut out = Vec::with_capacity(count as usize);
for i in 0..count {
let dev = oxicuda_driver::Device::get(i)?;
let name = dev.name().unwrap_or_else(|_| format!("GPU-{i}"));
let total_memory = dev.total_memory().unwrap_or(0);
let cc = dev.compute_capability().unwrap_or((0, 0));
let sm = dev.multiprocessor_count().unwrap_or(0);
out.push(DeviceStatus {
ordinal: i,
name,
total_memory,
active_tasks: 0,
compute_capability: (cc.0 as u32, cc.1 as u32),
sm_count: sm as u32,
is_available: true,
});
}
Ok(out)
}
fn synthetic_devices(n: usize) -> Vec<DeviceStatus> {
(0..n)
.map(|i| DeviceStatus {
ordinal: i as i32,
name: format!("Synthetic GPU {i}"),
total_memory: if i == 0 {
16 * 1024 * 1024 * 1024 } else {
8 * 1024 * 1024 * 1024 },
active_tasks: 0,
compute_capability: (8, i as u32),
sm_count: (108 - (i as u32) * 24),
is_available: true,
})
.collect()
}
fn device_status_for(ordinal: i32) -> DeviceStatus {
if oxicuda_driver::init().is_ok() {
if let Ok(dev) = oxicuda_driver::Device::get(ordinal) {
let name = dev.name().unwrap_or_else(|_| format!("GPU-{ordinal}"));
let total_memory = dev.total_memory().unwrap_or(0);
let cc = dev.compute_capability().unwrap_or((0, 0));
let sm = dev.multiprocessor_count().unwrap_or(0);
return DeviceStatus {
ordinal,
name,
total_memory,
active_tasks: 0,
compute_capability: (cc.0 as u32, cc.1 as u32),
sm_count: sm as u32,
is_available: true,
};
}
}
DeviceStatus {
ordinal,
name: format!("Synthetic GPU {ordinal}"),
total_memory: if ordinal == 0 {
16 * 1024 * 1024 * 1024 } else {
8 * 1024 * 1024 * 1024 },
active_tasks: 0,
compute_capability: (8, ordinal as u32),
sm_count: (108_u32).saturating_sub(ordinal as u32 * 24),
is_available: true,
}
}
}
impl std::fmt::Debug for MultiGpuPool {
    // Hand-written: only the device count is meaningful to print; the inner
    // lock and atomic counter are noise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MultiGpuPool")
            .field("device_count", &self.device_count())
            .finish()
    }
}
/// Spreads batches of work round-robin across the device ordinals
/// snapshotted from a `MultiGpuPool` at construction time.
pub struct WorkloadBalancer {
ordinals: Vec<i32>,
}
impl WorkloadBalancer {
    /// Snapshot the pool's current device ordinals for later distribution.
    pub fn new(pool: &MultiGpuPool) -> Self {
        let ordinals: Vec<i32> = pool.status().iter().map(|d| d.ordinal).collect();
        Self { ordinals }
    }

    /// Tag each item with a device ordinal, round-robin over the snapshot.
    /// When no devices are known every item is tagged with `-1`.
    pub fn distribute_batch<T: Send>(&self, items: Vec<T>) -> Vec<(i32, T)> {
        if self.ordinals.is_empty() {
            return items.into_iter().map(|item| (-1, item)).collect();
        }
        items
            .into_iter()
            .enumerate()
            .map(|(i, item)| (self.ordinals[i % self.ordinals.len()], item))
            .collect()
    }

    /// Run `f(ordinal, item)` for every item with one worker thread per
    /// device, returning results in the original item order.
    ///
    /// Items are assigned to devices round-robin. NOTE(review): if a worker
    /// thread panics, its partial results are silently dropped and the output
    /// is shorter than the input — best-effort by design; confirm callers
    /// accept that.
    pub fn parallel_map<T, R, F>(pool: &MultiGpuPool, items: Vec<T>, f: F) -> Vec<R>
    where
        T: Send + 'static,
        R: Send + 'static,
        F: Fn(i32, T) -> R + Send + Sync + 'static,
    {
        // Bug fix: the original read `device_count()` and `status()` under two
        // separate lock acquisitions; if they disagreed, the bucket index
        // (derived from `status()`) could exceed the bucket vector (sized from
        // `device_count()`) and panic. Derive everything from one snapshot.
        let ordinals: Vec<i32> = pool.status().iter().map(|d| d.ordinal).collect();
        if ordinals.is_empty() || items.is_empty() {
            return Vec::new();
        }
        let device_count = ordinals.len();
        let f = Arc::new(f);
        // Bucket items per device, remembering each item's original position.
        let mut buckets: Vec<Vec<(usize, i32, T)>> =
            (0..device_count).map(|_| Vec::new()).collect();
        for (i, item) in items.into_iter().enumerate() {
            let slot = i % device_count;
            buckets[slot].push((i, ordinals[slot], item));
        }
        let mut handles = Vec::with_capacity(device_count);
        for bucket in buckets {
            if bucket.is_empty() {
                continue; // no work for this device; don't spawn an idle thread
            }
            let f = Arc::clone(&f);
            handles.push(std::thread::spawn(move || {
                bucket
                    .into_iter()
                    .map(|(idx, ord, item)| (idx, f(ord, item)))
                    .collect::<Vec<(usize, R)>>()
            }));
        }
        let mut results: Vec<(usize, R)> = Vec::new();
        for handle in handles {
            if let Ok(partial) = handle.join() {
                results.extend(partial);
            }
        }
        // Restore the caller's original ordering.
        results.sort_by_key(|(idx, _)| *idx);
        results.into_iter().map(|(_, r)| r).collect()
    }
}
impl std::fmt::Debug for WorkloadBalancer {
    // Summarize as a device count rather than dumping the ordinal list.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("WorkloadBalancer");
        dbg.field("device_count", &self.ordinals.len());
        dbg.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-device pool built directly from synthetic hardware, bypassing
    /// driver discovery entirely.
    fn synthetic_pool(policy: DeviceSelectionPolicy) -> MultiGpuPool {
        MultiGpuPool {
            inner: Arc::new(RwLock::new(PoolInner {
                devices: MultiGpuPool::synthetic_devices(2),
                policy,
                rr_cursor: 0,
            })),
            next_task_id: AtomicU64::new(1),
        }
    }

    /// Reported active-task count for ordinal `ord` (0 when unknown).
    fn tasks_for(pool: &MultiGpuPool, ord: i32) -> u32 {
        pool.status()
            .iter()
            .find(|d| d.ordinal == ord)
            .map(|d| d.active_tasks)
            .unwrap_or(0)
    }

    #[test]
    fn pool_creation_with_new() {
        let pool = MultiGpuPool::new(DeviceSelectionPolicy::RoundRobin);
        assert!(pool.is_ok());
        assert!(pool.expect("pool should be created").device_count() >= 1);
    }

    #[test]
    fn pool_creation_with_specific_devices() {
        let pool = MultiGpuPool::with_devices(vec![0, 1, 2], DeviceSelectionPolicy::LeastLoaded);
        assert!(pool.is_ok());
        assert_eq!(pool.expect("pool should be created").device_count(), 3);
    }

    #[test]
    fn pool_empty_devices_returns_error() {
        assert!(MultiGpuPool::with_devices(vec![], DeviceSelectionPolicy::RoundRobin).is_err());
    }

    #[test]
    fn round_robin_cycles_correctly() {
        let pool = synthetic_pool(DeviceSelectionPolicy::RoundRobin);
        // Keep all leases alive so every acquisition sees the same pool state.
        let leases: Vec<GpuLease> = (0..4).map(|_| pool.acquire().expect("acquire")).collect();
        let ordinals: Vec<i32> = leases.iter().map(|l| l.ordinal()).collect();
        assert_eq!(ordinals, vec![0, 1, 0, 1]);
    }

    #[test]
    fn least_loaded_selects_idle_device() {
        let pool = synthetic_pool(DeviceSelectionPolicy::LeastLoaded);
        let first = pool.acquire().expect("first acquire");
        let second = pool.acquire().expect("second acquire");
        // The busy device must not be picked twice in a row.
        assert_ne!(first.ordinal(), second.ordinal());
    }

    #[test]
    fn acquire_release_task_counting() {
        let pool = synthetic_pool(DeviceSelectionPolicy::RoundRobin);
        let lease = pool.acquire().expect("acquire");
        let ord = lease.ordinal();
        assert_eq!(tasks_for(&pool, ord), 1);
        pool.release(lease);
        assert_eq!(tasks_for(&pool, ord), 0);
    }

    #[test]
    fn gpu_lease_drop_auto_releases() {
        let pool = synthetic_pool(DeviceSelectionPolicy::RoundRobin);
        let ord = {
            let lease = pool.acquire().expect("acquire");
            assert_eq!(tasks_for(&pool, lease.ordinal()), 1);
            lease.ordinal()
        };
        // Leaving the scope dropped the lease and returned the slot.
        assert_eq!(tasks_for(&pool, ord), 0);
    }

    #[test]
    fn device_status_reporting() {
        let statuses = synthetic_pool(DeviceSelectionPolicy::RoundRobin).status();
        assert_eq!(statuses.len(), 2);
        for s in &statuses {
            assert!(s.is_available);
            assert!(!s.name.is_empty());
            assert!(s.total_memory > 0);
            assert!(s.sm_count > 0);
        }
    }

    #[test]
    fn workload_balancer_distribution() {
        let pool = synthetic_pool(DeviceSelectionPolicy::RoundRobin);
        let balancer = WorkloadBalancer::new(&pool);
        let distributed = balancer.distribute_batch((0..6).collect::<Vec<i32>>());
        assert_eq!(distributed.len(), 6);
        // Items alternate between the two synthetic ordinals.
        for (i, (ord, _)) in distributed.iter().take(4).enumerate() {
            assert_eq!(*ord, (i % 2) as i32);
        }
    }

    #[test]
    fn policy_switching_at_runtime() {
        let pool = synthetic_pool(DeviceSelectionPolicy::RoundRobin);
        let rr = pool.acquire().expect("rr acquire");
        assert_eq!(rr.ordinal(), 0);
        pool.release(rr);
        pool.set_policy(DeviceSelectionPolicy::MostMemoryFree)
            .expect("set_policy");
        // Synthetic device 0 has the most memory (16 GiB vs 8 GiB).
        assert_eq!(pool.acquire().expect("most-memory acquire").ordinal(), 0);
    }

    #[test]
    fn single_device_pool() {
        let pool = MultiGpuPool::with_devices(vec![0], DeviceSelectionPolicy::RoundRobin)
            .expect("single-device pool");
        assert_eq!(pool.device_count(), 1);
        let first = pool.acquire().expect("acquire 0");
        let second = pool.acquire().expect("acquire 1");
        assert_eq!(first.ordinal(), 0);
        assert_eq!(second.ordinal(), 0);
    }

    #[test]
    fn best_compute_selects_highest() {
        // Synthetic devices carry capabilities (8, 0) and (8, 1).
        let pool = synthetic_pool(DeviceSelectionPolicy::BestCompute);
        assert_eq!(pool.acquire().expect("best compute acquire").ordinal(), 1);
    }

    #[test]
    fn custom_policy_selects_correctly() {
        let pick_last = DeviceSelectionPolicy::Custom(Box::new(|statuses: &[DeviceStatus]| {
            statuses.len().saturating_sub(1)
        }));
        let pool = synthetic_pool(pick_last);
        assert_eq!(pool.acquire().expect("custom acquire").ordinal(), 1);
    }

    #[test]
    fn parallel_map_preserves_order() {
        let pool = synthetic_pool(DeviceSelectionPolicy::RoundRobin);
        let doubled =
            WorkloadBalancer::parallel_map(&pool, (0..8).collect::<Vec<i32>>(), |_device, x| x * 2);
        assert_eq!(doubled, vec![0, 2, 4, 6, 8, 10, 12, 14]);
    }

    #[test]
    fn acquire_with_description() {
        let pool = synthetic_pool(DeviceSelectionPolicy::RoundRobin);
        let lease = pool
            .acquire_with_description("matrix multiply".into())
            .expect("acquire with desc");
        assert_eq!(lease.description(), "matrix multiply");
        assert!(lease.task_id() > 0);
    }
}