use std::sync::Arc;
#[cfg(feature = "gpu")]
use cudarc::driver::CudaContext;
use crate::error::{DbxError, DbxResult};
#[cfg(feature = "gpu")]
/// Coordinates work across a fixed set of CUDA devices.
///
/// Holds one shared `CudaContext` handle per device; handles are handed out
/// as `Arc` clones (cheap refcount bump) rather than borrows.
pub struct MultiDeviceCoordinator {
// One context per managed device; the index into this Vec is the device id
// used throughout this type — presumably the CUDA ordinal, TODO confirm
// against whoever constructs the coordinator.
devices: Vec<Arc<CudaContext>>,
}
#[cfg(feature = "gpu")]
impl MultiDeviceCoordinator {
    /// Creates a coordinator over the given set of CUDA device contexts.
    pub fn new(devices: Vec<Arc<CudaContext>>) -> Self {
        Self { devices }
    }

    /// Number of devices managed by this coordinator.
    pub fn device_count(&self) -> usize {
        self.devices.len()
    }

    /// Returns a shared handle to the device at `index`, or `None` if the
    /// index is out of range.
    pub fn device(&self, index: usize) -> Option<Arc<CudaContext>> {
        self.devices.get(index).cloned()
    }

    /// Runs `f` once per device, each invocation on its own OS thread,
    /// passing the device index and a cloned context handle. Results are
    /// returned in device order.
    ///
    /// Uses `std::thread::scope`, which lets `f` borrow from the caller's
    /// stack: the previous `'static` bounds and the `Arc<F>` wrapper were
    /// only needed to satisfy `thread::spawn` and are gone. Relaxing the
    /// bounds is backward-compatible for existing callers. Unlike the
    /// earlier version — which dropped (detached) remaining handles after
    /// the first failed join — every worker is joined before this method
    /// returns, so no device thread outlives the call.
    ///
    /// # Errors
    /// Returns the first `DbxError` produced by `f` (in device order), or
    /// `DbxError::Gpu` if a worker thread panicked.
    pub fn parallel_execute<F, R>(&self, f: F) -> DbxResult<Vec<R>>
    where
        F: Fn(usize, Arc<CudaContext>) -> DbxResult<R> + Send + Sync,
        R: Send,
    {
        use std::thread;
        let f = &f; // shared by reference; scoped workers may borrow it
        thread::scope(|scope| {
            // Spawn all workers first so they run concurrently.
            let handles: Vec<_> = self
                .devices
                .iter()
                .enumerate()
                .map(|(idx, device)| {
                    let device = Arc::clone(device);
                    scope.spawn(move || f(idx, device))
                })
                .collect();
            // Join every worker before propagating any error, so that an
            // early failure never leaves device threads running unobserved.
            let joined: Vec<DbxResult<R>> = handles
                .into_iter()
                .map(|handle| {
                    handle
                        .join()
                        .map_err(|_| DbxError::Gpu("Thread join failed".to_string()))
                        .and_then(|result| result)
                })
                .collect();
            // First error wins; otherwise results stay in device order.
            joined.into_iter().collect()
        })
    }

    /// Blocks until every device has drained its queued work.
    ///
    /// # Errors
    /// Returns `DbxError::Gpu` describing the failure if any device's
    /// `synchronize` call reports an error; devices after the failing one
    /// are not synchronized.
    pub fn synchronize_all(&self) -> DbxResult<()> {
        for device in &self.devices {
            device
                .synchronize()
                .map_err(|e| DbxError::Gpu(format!("Device sync failed: {:?}", e)))?;
        }
        Ok(())
    }
}
#[cfg(not(feature = "gpu"))]
/// No-op stand-in used when the crate is built without the `gpu` feature,
/// so code referencing the coordinator still compiles.
pub struct MultiDeviceCoordinator;
#[cfg(not(feature = "gpu"))]
impl MultiDeviceCoordinator {
pub fn new(_devices: Vec<()>) -> Self {
Self
}
}