// cubecl_hip/device.rs

1use cubecl_common::device::{Device, DeviceId};
2use cubecl_hip_sys::{HIP_SUCCESS, hipGetDeviceCount};
3use std::ffi::c_int;
4
/// Maximum number of bindings stored at any given time for AMD devices.
///
/// It is not clear whether AMD imposes a limit on the number of bindings it can
/// hold at once, but it is highly unlikely to be more than this. We also assume
/// that there will never be more than this many bindings in flight, so storing
/// only this many bindings is considered safe.
pub const AMD_MAX_BINDINGS: u32 = 1024;
10
11#[derive(new, Clone, PartialEq, Eq, Default, Hash)]
12pub struct AmdDevice {
13    pub index: usize,
14}
15
16impl core::fmt::Debug for AmdDevice {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        f.write_fmt(format_args!("AmdDevice({})", self.index))
19    }
20}
21
22impl Device for AmdDevice {
23    fn from_id(device_id: DeviceId) -> Self {
24        Self {
25            index: device_id.index_id as usize,
26        }
27    }
28
29    fn to_id(&self) -> DeviceId {
30        DeviceId {
31            type_id: 0,
32            index_id: self.index as u32,
33        }
34    }
35
36    fn device_count(_type_id: u16) -> usize {
37        let mut device_count: c_int = 0;
38        let result;
39        unsafe {
40            result = hipGetDeviceCount(&mut device_count);
41        }
42        if result == HIP_SUCCESS {
43            device_count.try_into().unwrap_or(0)
44        } else {
45            0
46        }
47    }
48}