cubecl_hip/
runtime.rs

1use std::{ffi::CStr, mem::MaybeUninit, str::FromStr};
2
3use cubecl_cpp::{
4    hip::HipDialect,
5    register_supported_types,
6    shared::{register_wmma_features, Architecture, CompilationOptions, CppCompiler, WmmaCompiler},
7};
8
9use cubecl_core::{
10    ir::{Elem, FloatKind},
11    AtomicFeature, Feature, MemoryConfiguration, Runtime,
12};
13use cubecl_hip_sys::HIP_SUCCESS;
14use cubecl_runtime::{
15    channel::MutexComputeChannel,
16    client::ComputeClient,
17    memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement},
18    storage::ComputeStorage,
19    ComputeRuntime, DeviceProperties,
20};
21
22use crate::{
23    compute::{HipContext, HipServer, HipStorage},
24    device::HipDevice,
25    HipWmmaCompiler,
26};
27
28/// The values that control how a HIP Runtime will perform its calculations.
29#[derive(Default)]
30pub struct RuntimeOptions {
31    /// Configures the memory management.
32    pub memory_config: MemoryConfiguration,
33}
34
35#[derive(Debug)]
36pub struct HipRuntime;
37
38static RUNTIME: ComputeRuntime<HipDevice, Server, MutexComputeChannel<Server>> =
39    ComputeRuntime::new();
40
41pub type HipCompiler = CppCompiler<HipDialect<HipWmmaCompiler>>;
42
43type Server = HipServer;
44type Channel = MutexComputeChannel<Server>;
45
46fn create_client<M: WmmaCompiler<HipDialect<M>>>(
47    device: &HipDevice,
48    options: RuntimeOptions,
49) -> ComputeClient<Server, Channel> {
50    #[allow(unused_assignments)]
51    let mut prop_warp_size = 0;
52    #[allow(unused_assignments)]
53    let mut prop_arch_name = "";
54    unsafe {
55        let mut ll_device_props = MaybeUninit::uninit();
56        let status = cubecl_hip_sys::hipGetDevicePropertiesR0600(
57            ll_device_props.as_mut_ptr(),
58            device.index as cubecl_hip_sys::hipDevice_t,
59        );
60        assert_eq!(status, HIP_SUCCESS, "Should get device properties");
61        let ll_device_props = ll_device_props.assume_init();
62        prop_warp_size = ll_device_props.warpSize;
63        prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr())
64            .to_str()
65            .unwrap();
66    };
67    let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name);
68    let arch = M::Architecture::from_str(normalized_arch_name).unwrap();
69    assert_eq!(prop_warp_size as u32, arch.warp_size());
70
71    unsafe {
72        let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t);
73        assert_eq!(
74            status, HIP_SUCCESS,
75            "Should set the default device for the current thread"
76        );
77    }
78
79    let stream = unsafe {
80        let mut stream: cubecl_hip_sys::hipStream_t = std::ptr::null_mut();
81        let stream_status = cubecl_hip_sys::hipStreamCreate(&mut stream);
82        assert_eq!(stream_status, HIP_SUCCESS, "Should create a stream");
83        stream
84    };
85
86    let max_memory = unsafe {
87        let free: usize = 0;
88        let total: usize = 0;
89        let status = cubecl_hip_sys::hipMemGetInfo(
90            &free as *const _ as *mut usize,
91            &total as *const _ as *mut usize,
92        );
93        assert_eq!(
94            status, HIP_SUCCESS,
95            "Should get the available memory of the device"
96        );
97        total
98    };
99    let storage = HipStorage::new(stream);
100    let mem_properties = MemoryDeviceProperties {
101        max_page_size: max_memory as u64 / 4,
102        alignment: HipStorage::ALIGNMENT,
103    };
104    let topology = HardwareProperties {
105        plane_size_min: prop_warp_size as u32,
106        plane_size_max: prop_warp_size as u32,
107        // This is a guess - not clear if ROCM has a limit on the number of bindings,
108        // but it's dubious it's more than this.
109        max_bindings: 1024,
110    };
111    let memory_management = MemoryManagement::from_configuration(
112        storage,
113        mem_properties.clone(),
114        options.memory_config,
115    );
116    let mut device_props = DeviceProperties::new(&[Feature::Plane], mem_properties, topology);
117    register_supported_types(&mut device_props);
118    // Not sure if there's a good way to check for support on HIP
119    device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F64)));
120    device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F16)));
121    device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16)));
122
123    device_props.register_feature(Feature::AtomicFloat(AtomicFeature::LoadStore));
124    device_props.register_feature(Feature::AtomicFloat(AtomicFeature::Add));
125
126    let supported_wmma_combinations = M::supported_wmma_combinations(&arch);
127    register_wmma_features(supported_wmma_combinations, &mut device_props);
128
129    let comp_opts = CompilationOptions {
130        warp_size: arch.warp_size(),
131    };
132    let hip_ctx = HipContext::new(memory_management, comp_opts, stream);
133    let server = HipServer::new(hip_ctx);
134    ComputeClient::new(MutexComputeChannel::new(server), device_props)
135}
136
137impl Runtime for HipRuntime {
138    type Compiler = HipCompiler;
139    type Server = HipServer;
140    type Channel = MutexComputeChannel<HipServer>;
141    type Device = HipDevice;
142
143    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
144        RUNTIME.client(device, move || {
145            create_client::<HipWmmaCompiler>(device, RuntimeOptions::default())
146        })
147    }
148
149    fn name() -> &'static str {
150        "hip"
151    }
152
153    fn require_array_lengths() -> bool {
154        true
155    }
156
157    fn supported_line_sizes() -> &'static [u8] {
158        &[8, 4, 2, 1]
159    }
160
161    fn max_cube_count() -> (u32, u32, u32) {
162        (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
163    }
164
165    fn extension() -> &'static str {
166        "hip"
167    }
168}