cubecl_cuda/
device.rs

1use cubecl_common::device::{Device, DeviceId};
2
/// Maximum number of bindings stored at any one time.
///
/// It is not clear whether CUDA limits how many bindings it can hold at any
/// given time, but a limit above this value is highly unlikely. We also assume
/// that there are never more than this many bindings in flight, so storing
/// only this many bindings is considered 'safe'.
pub const CUDA_MAX_BINDINGS: u32 = 1 << 10;
8
9#[derive(new, Clone, PartialEq, Eq, Default, Hash)]
10pub struct CudaDevice {
11    pub index: usize,
12}
13
14impl core::fmt::Debug for CudaDevice {
15    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16        write!(f, "Cuda({})", self.index)
17    }
18}
19
20impl Device for CudaDevice {
21    fn from_id(device_id: DeviceId) -> Self {
22        Self {
23            index: device_id.index_id as usize,
24        }
25    }
26
27    fn to_id(&self) -> DeviceId {
28        DeviceId {
29            type_id: 0,
30            index_id: self.index as u32,
31        }
32    }
33
34    fn device_count(_type_id: u16) -> usize {
35        cudarc::driver::CudaContext::device_count().unwrap_or(0) as usize
36    }
37}