1use cubecl_common::device::{Device, DeviceId};
2
/// Upper limit on the number of bindings for the CUDA backend (2^10 = 1024).
///
/// NOTE(review): the scope of a "binding" (e.g. buffers per kernel launch) is defined by
/// the code consuming this constant — confirm against its users.
pub const CUDA_MAX_BINDINGS: u32 = 1 << 10;
8
/// Handle to a single CUDA device, identified only by its index.
///
/// `Debug` is implemented by hand so the device prints as `Cuda(<index>)`.
/// The `new` derive (presumably from `derive_new`) generates `CudaDevice::new(index)`;
/// the derived `Default` yields index 0.
#[derive(new, Clone, PartialEq, Eq, Default, Hash)]
pub struct CudaDevice {
    /// Index of the device; mirrored into `DeviceId::index_id` by the `Device` impl.
    pub index: usize,
}
13
14impl core::fmt::Debug for CudaDevice {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 write!(f, "Cuda({})", self.index)
17 }
18}
19
20impl Device for CudaDevice {
21 fn from_id(device_id: DeviceId) -> Self {
22 Self {
23 index: device_id.index_id as usize,
24 }
25 }
26
27 fn to_id(&self) -> DeviceId {
28 DeviceId {
29 type_id: 0,
30 index_id: self.index as u32,
31 }
32 }
33
34 fn device_count(_type_id: u16) -> usize {
35 cudarc::driver::CudaContext::device_count().unwrap_or(0) as usize
36 }
37}