Skip to main content

cubecl_hip/
runtime.rs

1use crate::{
2    HipWmmaCompiler,
3    compute::{HipServer, context::HipContext},
4    device::AmdDevice,
5};
6use core::ffi::c_int;
7use cubecl_common::{
8    device::{Device, DeviceService},
9    profile::TimingMethod,
10};
11use cubecl_core::{
12    MemoryConfiguration, Runtime,
13    device::{DeviceId, ServerUtilitiesHandle},
14    ir::{
15        ContiguousElements, DeviceProperties, HardwareProperties, MatrixLayout,
16        MemoryDeviceProperties, MmaProperties, TargetProperties, VectorSize, features::Plane,
17    },
18    server::ServerUtilities,
19    zspace::{Shape, Strides, striding::has_pitched_row_major_strides},
20};
21use cubecl_cpp::{
22    ComputeKernel,
23    hip::{HipDialect, arch::AMDArchitecture, mma::contiguous_elements_rdna3},
24    register_supported_types,
25    shared::{
26        Architecture, CompilationOptions, CppCompiler, CppSupportedFeatures, DialectWmmaCompiler,
27        register_mma_features, register_scaled_mma_features, register_wmma_features,
28    },
29};
30use cubecl_hip_sys::{HIP_SUCCESS, hipDeviceScheduleSpin, hipGetDeviceCount, hipSetDeviceFlags};
31use cubecl_runtime::{
32    allocator::PitchedMemoryLayoutPolicy, client::ComputeClient, logging::ServerLogger,
33};
34use std::{ffi::CStr, mem::MaybeUninit, sync::Arc};
35
36/// The values that control how a HIP Runtime will perform its calculations.
37#[derive(Default)]
38pub struct RuntimeOptions {
39    /// Configures the memory management.
40    pub memory_config: MemoryConfiguration,
41}
42
/// Marker type selecting HIP (AMD GPUs) as the CubeCL [`Runtime`] backend.
///
/// It carries no state; all behavior lives in its trait implementations.
#[derive(Debug, Clone)]
pub struct HipRuntime;
45
46pub type HipCompiler = CppCompiler<HipDialect<HipWmmaCompiler>>;
47pub type HipComputeKernel = ComputeKernel<HipDialect<HipWmmaCompiler>>;
48
49impl DeviceService for HipServer {
50    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
51        let device = AmdDevice::from_id(device_id);
52
53        #[allow(unused_assignments)]
54        let mut prop_warp_size = 0;
55        #[allow(unused_assignments)]
56        let mut prop_arch_name = "";
57        #[allow(unused_assignments)]
58        let mut prop_max_shared_memory_size = 0;
59        #[allow(unused_assignments)]
60        let mut max_cube_count = (1, 1, 1);
61        #[allow(unused_assignments)]
62        let mut prop_max_threads = 0;
63        let mut max_cube_dim = (1, 1, 1);
64        let mut mem_alignment = 32;
65        // SAFETY: Calling HIP FFI to query device properties. The `MaybeUninit` is
66        // initialized by `hipGetDevicePropertiesR0600` on success (asserted below), so
67        // `assume_init()` is valid. The device index is validated by the `AmdDevice` constructor.
68        unsafe {
69            let mut ll_device_props = MaybeUninit::uninit();
70            let status = cubecl_hip_sys::hipGetDevicePropertiesR0600(
71                ll_device_props.as_mut_ptr(),
72                device.index as cubecl_hip_sys::hipDevice_t,
73            );
74            assert_eq!(status, HIP_SUCCESS, "Should get device properties");
75            let ll_device_props = ll_device_props.assume_init();
76            prop_warp_size = ll_device_props.warpSize;
77            prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr())
78                .to_str()
79                .unwrap();
80            prop_max_shared_memory_size = ll_device_props.sharedMemPerBlock;
81            max_cube_count = (
82                ll_device_props.maxGridSize[0] as u32,
83                ll_device_props.maxGridSize[1] as u32,
84                ll_device_props.maxGridSize[2] as u32,
85            );
86            prop_max_threads = ll_device_props.maxThreadsPerBlock as u32;
87            max_cube_dim.0 = ll_device_props.maxThreadsDim[0] as u32;
88            max_cube_dim.1 = ll_device_props.maxThreadsDim[1] as u32;
89            max_cube_dim.2 = ll_device_props.maxThreadsDim[2] as u32;
90
91            // Just to be sure we check both.
92            mem_alignment = usize::max(mem_alignment, ll_device_props.textureAlignment);
93            mem_alignment = usize::max(mem_alignment, ll_device_props.surfaceAlignment);
94        };
95        let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name);
96        let arch = AMDArchitecture::parse(normalized_arch_name).unwrap();
97        assert_eq!(prop_warp_size as u32, arch.warp_size());
98
99        // SAFETY: Calling HIP FFI to set the active device and configure spin-wait scheduling
100        // for the current thread. The device index has been validated above by a successful
101        // `hipGetDevicePropertiesR0600` call.
102        unsafe {
103            let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t);
104            hipSetDeviceFlags(hipDeviceScheduleSpin);
105
106            assert_eq!(
107                status, HIP_SUCCESS,
108                "Should set the default device for the current thread"
109            );
110        }
111
112        // SAFETY: Calling HIP FFI to query device memory info. The pointers to `free` and
113        // `total` are valid stack variables cast to mutable pointers; HIP writes the values
114        // through them on success (asserted below).
115        let max_memory = unsafe {
116            let free: usize = 0;
117            let total: usize = 0;
118            let status = cubecl_hip_sys::hipMemGetInfo(
119                &free as *const _ as *mut usize,
120                &total as *const _ as *mut usize,
121            );
122            assert_eq!(
123                status, HIP_SUCCESS,
124                "Should get the available memory of the device"
125            );
126            total
127        };
128        let mem_properties = MemoryDeviceProperties {
129            max_page_size: max_memory as u64 / 4,
130            alignment: mem_alignment as u64,
131        };
132
133        let supported_wmma_combinations = HipWmmaCompiler::supported_wmma_combinations(&arch);
134        let supported_mma_combinations = HipWmmaCompiler::supported_mma_combinations(&arch);
135        let supported_scaled_mma_combinations =
136            HipWmmaCompiler::supported_scaled_mma_combinations(&arch);
137
138        let topology = HardwareProperties {
139            load_width: 128,
140            plane_size_min: prop_warp_size as u32,
141            plane_size_max: prop_warp_size as u32,
142            max_bindings: crate::device::AMD_MAX_BINDINGS,
143            max_shared_memory_size: prop_max_shared_memory_size,
144            max_cube_count,
145            max_units_per_cube: prop_max_threads,
146            max_cube_dim,
147            num_streaming_multiprocessors: None,
148            num_tensor_cores: None,
149            min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
150                None
151            } else {
152                Some(16)
153            },
154            num_cpu_cores: None,
155            max_vector_size: VectorSize::MAX,
156        };
157
158        let mut device_props = DeviceProperties::new(
159            Default::default(),
160            mem_properties.clone(),
161            topology,
162            TimingMethod::System,
163        );
164        register_supported_types(&mut device_props);
165
166        // TODO look into unsafeAtomicAdd (https://github.com/ROCm/HIP/issues/3573120)
167        // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F16)));
168        // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16)));
169
170        device_props.features.memory_reinterpret = true;
171        device_props.features.alignment = true;
172        device_props.features.plane.insert(Plane::Ops);
173        device_props
174            .features
175            .plane
176            .insert(Plane::NonUniformControlFlow);
177
178        register_wmma_features(supported_wmma_combinations, &mut device_props);
179        register_mma_features(supported_mma_combinations, &mut device_props);
180        register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);
181
182        let comp_opts = CompilationOptions {
183            warp_size: arch.warp_size(),
184            supports_features: CppSupportedFeatures {
185                fast_math: true,
186                ..Default::default()
187            },
188        };
189        let hip_ctx = HipContext::new(comp_opts, device_props.clone());
190        let logger = Arc::new(ServerLogger::default());
191        let policy = PitchedMemoryLayoutPolicy::new(device_props.memory.alignment as usize);
192        let utilities = ServerUtilities::new(device_props, logger, (), policy);
193        let options = RuntimeOptions::default();
194
195        // SAFETY: `is_integrated_gpu` calls HIP FFI functions with a valid device index.
196        let is_integrated = unsafe { is_integrated_gpu(device_id.index_id as i32) };
197
198        HipServer::new(
199            hip_ctx,
200            mem_properties,
201            options.memory_config,
202            mem_alignment,
203            is_integrated,
204            utilities,
205        )
206    }
207
208    fn utilities(&self) -> ServerUtilitiesHandle {
209        self.utilities() as ServerUtilitiesHandle
210    }
211}
212
213impl Runtime for HipRuntime {
214    type Compiler = HipCompiler;
215    type Server = HipServer;
216    type Device = AmdDevice;
217
218    fn client(device: &Self::Device) -> ComputeClient<Self> {
219        ComputeClient::load(device)
220    }
221
222    fn name(_client: &ComputeClient<Self>) -> &'static str {
223        "hip"
224    }
225
226    fn require_array_lengths() -> bool {
227        true
228    }
229
230    fn max_cube_count() -> (u32, u32, u32) {
231        (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
232    }
233
234    fn can_read_tensor(shape: &Shape, strides: &Strides) -> bool {
235        if shape.is_empty() {
236            return true;
237        }
238        has_pitched_row_major_strides(shape, strides)
239    }
240
241    fn target_properties() -> TargetProperties {
242        TargetProperties {
243            mma: MmaProperties {
244                register_size_bits: 32,
245                const_plane_size: 32,
246                register_layout_a: MatrixLayout::RowMajor,
247                register_layout_b: MatrixLayout::ColMajor,
248                register_layout_acc: MatrixLayout::ColMajor,
249                register_duplication_a: 2,
250                register_duplication_b: 2,
251                register_duplication_acc: 1,
252                contiguous_elements: ContiguousElements::new(contiguous_elements_rdna3),
253            },
254        }
255    }
256
257    fn enumerate_devices(
258        _: u16,
259        _: &<Self::Server as cubecl_core::server::ComputeServer>::Info,
260    ) -> Vec<cubecl_core::device::DeviceId> {
261        fn device_count() -> usize {
262            let mut device_count: c_int = 0;
263            let result;
264            // SAFETY: Calling HIP FFI to get the number of available devices.
265            // `device_count` is a valid mutable pointer to a stack-allocated `c_int`.
266            unsafe {
267                result = hipGetDeviceCount(&mut device_count);
268            }
269            if result == HIP_SUCCESS {
270                device_count.try_into().unwrap_or(0)
271            } else {
272                0
273            }
274        }
275        (0..device_count())
276            .map(|i| DeviceId::new(0, i as u16))
277            .collect()
278    }
279}
280
281/// Checks whether the GPU with the given device ID is an integrated (APU) device.
282///
283/// # Safety
284///
285/// Calls HIP FFI functions. The caller must ensure `device_id` is a valid HIP device index.
286unsafe fn is_integrated_gpu(device_id: i32) -> bool {
287    // SAFETY: `hipDeviceProp_tR0600` is a plain-old-data struct; zeroing it is valid.
288    let mut props = unsafe { std::mem::zeroed::<cubecl_hip_sys::hipDeviceProp_tR0600>() };
289    // SAFETY: `props` is a valid mutable reference and `device_id` is assumed valid by the caller.
290    let status = unsafe { cubecl_hip_sys::hipGetDevicePropertiesR0600(&mut props, device_id) };
291    if status != HIP_SUCCESS {
292        return false; // assume discrete if we can't tell
293    }
294    props.integrated != 0
295}