use std::sync::Arc;
#[cfg(feature = "gpu")]
use cudarc::driver::CudaContext;
use crate::error::{DbxError, DbxResult};
/// Kind of interconnect recorded between a pair of devices in a topology.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LinkType {
    /// The pair is recorded as connected over PCI Express.
    PCIe,
    /// The pair is recorded as connected over NVLink.
    NVLink,
    /// No link is known; also the answer given for out-of-range device indices.
    None,
}
/// Snapshot of the CUDA devices in this process and how they relate to each other.
///
/// Both matrices are square with side `device_count` and are indexed as
/// `[source][target]`.
#[cfg(feature = "gpu")]
pub struct DeviceTopology {
    /// Number of devices covered by this topology.
    device_count: usize,
    /// `p2p_matrix[i][j]` is `true` when device `i` is recorded as able to access device `j`.
    p2p_matrix: Vec<Vec<bool>>,
    /// `link_types[i][j]` is the interconnect recorded between devices `i` and `j`.
    link_types: Vec<Vec<LinkType>>,
    /// One context per device, stored in ordinal order.
    devices: Vec<Arc<CudaContext>>,
}
#[cfg(feature = "gpu")]
impl DeviceTopology {
    /// Builds a topology by opening a context on every CUDA device the driver reports.
    ///
    /// No peer-access or interconnect query is issued here. The matrices are filled
    /// with fixed placeholder values:
    ///
    /// * a device paired with itself is recorded as accessible, with link type
    ///   [`LinkType::NVLink`];
    /// * every pair of distinct devices is recorded as *not* peer-accessible, with
    ///   link type [`LinkType::PCIe`].
    ///
    /// # Errors
    ///
    /// Returns [`DbxError::Gpu`] when the device count cannot be obtained, when the
    /// reported count does not fit in a `usize` (e.g. it is negative), when no device
    /// is present, or when a context cannot be created for one of the devices.
    pub fn detect() -> DbxResult<Self> {
        let reported = CudaContext::device_count()
            .map_err(|e| DbxError::Gpu(format!("Failed to get device count: {:?}", e)))?;

        // Checked conversion: an `as` cast would turn a negative count into a huge
        // `usize` and make the matrix allocation below panic instead of erroring.
        let device_count = usize::try_from(reported)
            .map_err(|_| DbxError::Gpu(format!("Invalid CUDA device count: {}", reported)))?;

        if device_count == 0 {
            return Err(DbxError::Gpu("No CUDA devices found".to_string()));
        }

        let mut devices = Vec::with_capacity(device_count);
        for ordinal in 0..device_count {
            let context = CudaContext::new(ordinal).map_err(|e| {
                DbxError::Gpu(format!(
                    "Failed to initialize device {}: {:?}",
                    ordinal, e
                ))
            })?;
            devices.push(context);
        }

        // Start from the values used for distinct devices, then overwrite the diagonal.
        let mut p2p_matrix = vec![vec![false; device_count]; device_count];
        let mut link_types = vec![vec![LinkType::PCIe; device_count]; device_count];
        for same in 0..device_count {
            p2p_matrix[same][same] = true;
            link_types[same][same] = LinkType::NVLink;
        }

        Ok(Self {
            device_count,
            p2p_matrix,
            link_types,
            devices,
        })
    }

    /// Number of devices covered by this topology.
    pub fn device_count(&self) -> usize {
        self.device_count
    }

    /// Whether device `i` is recorded as able to access device `j`.
    ///
    /// Returns `false` when either index is out of range.
    pub fn can_access_peer(&self, i: usize, j: usize) -> bool {
        self.p2p_matrix
            .get(i)
            .and_then(|row| row.get(j))
            .copied()
            .unwrap_or(false)
    }

    /// Link type recorded between devices `i` and `j`.
    ///
    /// Returns [`LinkType::None`] when either index is out of range.
    pub fn link_type(&self, i: usize, j: usize) -> LinkType {
        self.link_types
            .get(i)
            .and_then(|row| row.get(j))
            .copied()
            .unwrap_or(LinkType::None)
    }

    /// Shared handle to the context of device `i`, or `None` if `i` is out of range.
    pub fn device(&self, i: usize) -> Option<Arc<CudaContext>> {
        self.devices.get(i).cloned()
    }

    /// Requests peer access from device `i` to device `j`.
    ///
    /// Asking for access from a device to itself succeeds without doing anything.
    ///
    /// # Errors
    ///
    /// * [`DbxError::Gpu`] when an index is out of range, or when the pair is not
    ///   recorded as peer-accessible.
    /// * [`DbxError::NotImplemented`] for every other pair of distinct devices,
    ///   because no peer-access call is made here.
    pub fn enable_peer_access(&self, i: usize, j: usize) -> DbxResult<()> {
        if i >= self.device_count || j >= self.device_count {
            return Err(DbxError::Gpu(format!(
                "Invalid device indices: {} and {}",
                i, j
            )));
        }

        if i == j {
            return Ok(());
        }

        if !self.p2p_matrix[i][j] {
            return Err(DbxError::Gpu(format!(
                "P2P access not supported between devices {} and {}",
                i, j
            )));
        }

        Err(DbxError::NotImplemented(
            "P2P access requires cudarc peer access API (not yet available)".to_string(),
        ))
    }

    /// Whether any pair of *distinct* devices is recorded as NVLink-connected.
    ///
    /// The diagonal is skipped on purpose: `detect` marks every device as
    /// NVLink-linked to itself, which says nothing about inter-device links.
    pub fn has_nvlink(&self) -> bool {
        self.link_types.iter().enumerate().any(|(i, row)| {
            row.iter()
                .enumerate()
                .any(|(j, &link)| i != j && link == LinkType::NVLink)
        })
    }
}
/// Placeholder used when the crate is built without the `gpu` feature.
///
/// It carries no data; it exists so the `DeviceTopology` name still resolves.
#[cfg(not(feature = "gpu"))]
pub struct DeviceTopology;
#[cfg(not(feature = "gpu"))]
impl DeviceTopology {
    /// Always fails in builds without the `gpu` feature.
    ///
    /// # Errors
    ///
    /// Unconditionally returns [`DbxError::NotImplemented`].
    pub fn detect() -> DbxResult<Self> {
        let reason = String::from("GPU acceleration is not enabled");
        Err(DbxError::NotImplemented(reason))
    }
}