use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::context::Context;
use crate::device::Device;
use crate::error::{CudaError, CudaResult};
/// A pool of CUDA devices, each paired with its own driver [`Context`].
///
/// Contexts are wrapped in [`Arc`] so callers can hold a context beyond the
/// lifetime of a lookup; the atomic counter backs round-robin selection in
/// `next_device`.
pub struct DevicePool {
// One (device, context) pair per pooled device. Never empty after
// construction: both constructors reject an empty device list.
entries: Vec<(Device, Arc<Context>)>,
// Monotonically increasing ticket counter for round-robin selection.
// Wrap-around on overflow is harmless because only the value modulo
// `entries.len()` is ever used.
round_robin: AtomicUsize,
}
// SAFETY: presumably `Context` wraps a raw driver handle that the CUDA driver
// allows to be used from any thread, which is why Send/Sync are asserted
// manually here. NOTE(review): confirm that guarantee against the driver API;
// if `Device` and `Context` are themselves `Send + Sync`, these impls are
// redundant and the auto-derived traits would be preferable.
unsafe impl Send for DevicePool {}
unsafe impl Sync for DevicePool {}
impl DevicePool {
    /// Builds a pool covering every device the driver reports.
    ///
    /// # Errors
    /// Returns [`CudaError::NoDevice`] when the driver lists no devices, or
    /// any error raised while creating a per-device context.
    pub fn new() -> CudaResult<Self> {
        let devices = crate::device::list_devices()?;
        if devices.is_empty() {
            return Err(CudaError::NoDevice);
        }
        Self::with_devices(&devices)
    }

    /// Builds a pool from an explicit, non-empty device list, creating one
    /// context per device. Stops at the first context-creation failure.
    ///
    /// # Errors
    /// Returns [`CudaError::InvalidValue`] for an empty slice, or whatever
    /// error `Context::new` produced for the first failing device.
    pub fn with_devices(devices: &[Device]) -> CudaResult<Self> {
        if devices.is_empty() {
            return Err(CudaError::InvalidValue);
        }
        // Short-circuits on the first failing context, matching a manual loop.
        let entries = devices
            .iter()
            .map(|dev| Context::new(dev).map(|ctx| (*dev, Arc::new(ctx))))
            .collect::<CudaResult<Vec<_>>>()?;
        Ok(Self {
            entries,
            round_robin: AtomicUsize::new(0),
        })
    }

    /// Returns the context owned by the device with the given ordinal.
    ///
    /// # Errors
    /// Returns [`CudaError::InvalidDevice`] when no pooled device matches.
    pub fn context(&self, device_ordinal: i32) -> CudaResult<&Arc<Context>> {
        for (dev, ctx) in &self.entries {
            if dev.ordinal() == device_ordinal {
                return Ok(ctx);
            }
        }
        Err(CudaError::InvalidDevice)
    }

    /// Number of devices held by this pool (always at least one).
    #[inline]
    pub fn device_count(&self) -> usize {
        self.entries.len()
    }

    /// Picks the pooled device with the largest total memory.
    ///
    /// Ties keep the earliest device; the first entry is returned when every
    /// device reports zero bytes.
    ///
    /// # Errors
    /// Propagates the first `total_memory` query failure.
    pub fn best_available_device(&self) -> CudaResult<Device> {
        // (device, total bytes) of the best candidate seen so far. `entries`
        // is never empty, so indexing the first element cannot panic.
        let mut best: (Device, usize) = (self.entries[0].0, 0);
        for (dev, _) in &self.entries {
            let mem = dev.total_memory()?;
            if mem > best.1 {
                best = (*dev, mem);
            }
        }
        Ok(best.0)
    }

    /// Hands out devices in round-robin order across calls (thread-safe).
    ///
    /// # Errors
    /// Currently infallible; the `Result` is kept for interface stability.
    pub fn next_device(&self) -> CudaResult<Device> {
        let ticket = self.round_robin.fetch_add(1, Ordering::Relaxed);
        let slot = ticket % self.entries.len();
        Ok(self.entries[slot].0)
    }

    /// Iterates over every pooled (device, context) pair in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = (&Device, &Arc<Context>)> {
        self.entries.iter().map(|entry| (&entry.0, &entry.1))
    }

    /// Returns the context at `index` (insertion order).
    ///
    /// # Errors
    /// Returns [`CudaError::InvalidValue`] when `index` is out of bounds.
    pub fn context_at(&self, index: usize) -> CudaResult<&Arc<Context>> {
        match self.entries.get(index) {
            Some((_, ctx)) => Ok(ctx),
            None => Err(CudaError::InvalidValue),
        }
    }

    /// Returns the device at `index` (insertion order).
    ///
    /// # Errors
    /// Returns [`CudaError::InvalidValue`] when `index` is out of bounds.
    pub fn device_at(&self, index: usize) -> CudaResult<Device> {
        match self.entries.get(index) {
            Some((dev, _)) => Ok(*dev),
            None => Err(CudaError::InvalidValue),
        }
    }
}
impl std::fmt::Debug for DevicePool {
    /// Renders the pool as `DevicePool { device_count: N, devices: [..] }`,
    /// listing only device ordinals (contexts carry no useful debug state).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ordinals: Vec<_> = self.entries.iter().map(|(d, _)| d.ordinal()).collect();
        f.debug_struct("DevicePool")
            .field("device_count", &self.entries.len())
            .field("devices", &ordinals)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty device slice must be rejected with `InvalidValue`.
    #[test]
    fn pool_with_empty_devices_returns_error() {
        let result = DevicePool::with_devices(&[]);
        assert!(result.is_err());
        assert_eq!(result.err(), Some(CudaError::InvalidValue));
    }

    /// `new` must not panic on machines without a CUDA driver; we only
    /// require that it returns (Ok with devices, or a clean error).
    #[test]
    fn pool_new_returns_error_without_driver() {
        let _result = DevicePool::new();
    }

    /// Exercises the real `Debug` impl when a driver is available.
    /// (The previous version formatted an unrelated string literal and
    /// therefore tested nothing about `DevicePool`.)
    #[test]
    fn device_pool_debug_format() {
        if let Ok(pool) = DevicePool::new() {
            let fmt = format!("{:?}", pool);
            assert!(fmt.contains("DevicePool"));
            assert!(fmt.contains("device_count"));
        }
    }

    /// The round-robin index must wrap around the pool size.
    #[test]
    fn round_robin_counter_wraps() {
        let counter = AtomicUsize::new(0);
        let pool_size = 3;
        for expected in [0, 1, 2, 0, 1, 2, 0] {
            let idx = counter.fetch_add(1, Ordering::Relaxed) % pool_size;
            assert_eq!(idx, expected);
        }
    }

    /// With a single device the round-robin index is always zero.
    #[test]
    fn round_robin_single_device() {
        let counter = AtomicUsize::new(0);
        let pool_size = 1;
        for _ in 0..10 {
            let idx = counter.fetch_add(1, Ordering::Relaxed) % pool_size;
            assert_eq!(idx, 0);
        }
    }

    /// Out-of-bounds indexed lookups must fail. (The previous version never
    /// called `context_at`; it now does when a driver is present, and always
    /// checks the error's raw code mapping.)
    #[test]
    fn context_at_out_of_bounds_returns_error() {
        if let Ok(pool) = DevicePool::new() {
            assert!(pool.context_at(pool.device_count()).is_err());
            assert!(pool.device_at(usize::MAX).is_err());
        }
        // CUDA_ERROR_INVALID_VALUE has raw code 1 in the driver API.
        assert_eq!(CudaError::InvalidValue.as_raw(), 1);
    }

    /// End-to-end smoke test against real hardware; gated behind the
    /// `gpu-tests` feature so CI without GPUs skips it.
    #[cfg(feature = "gpu-tests")]
    #[test]
    fn pool_with_real_devices() {
        crate::init().ok();
        let result = DevicePool::new();
        if let Ok(pool) = result {
            assert!(pool.device_count() > 0);
            let dev = pool.next_device().expect("next_device failed");
            assert!(pool.context(dev.ordinal()).is_ok());
            let best = pool.best_available_device().expect("best_available failed");
            assert!(best.total_memory().is_ok());
            for (d, _c) in pool.iter() {
                assert!(d.name().is_ok());
            }
        }
    }
}