cubecl_hip/
runtime.rs

1use std::{ffi::CStr, mem::MaybeUninit, str::FromStr};
2
3use cubecl_cpp::{
4    hip::HipDialect,
5    register_supported_types,
6    shared::{
7        Architecture, CompilationOptions, CppCompiler, DialectWmmaCompiler, register_wmma_features,
8    },
9};
10
11use cubecl_core::{
12    AtomicFeature, CubeDim, DeviceId, Feature, MemoryConfiguration, Runtime,
13    ir::{Elem, FloatKind},
14};
15use cubecl_hip_sys::HIP_SUCCESS;
16use cubecl_runtime::{
17    ComputeRuntime, DeviceProperties,
18    channel::MutexComputeChannel,
19    client::ComputeClient,
20    memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement},
21    storage::ComputeStorage,
22};
23
24use crate::{
25    HipWmmaCompiler,
26    compute::{HipContext, HipServer, HipStorage},
27    device::HipDevice,
28};
29
30/// The values that control how a HIP Runtime will perform its calculations.
31#[derive(Default)]
32pub struct RuntimeOptions {
33    /// Configures the memory management.
34    pub memory_config: MemoryConfiguration,
35}
36
/// Marker type identifying the HIP runtime.
///
/// It carries no data; the runtime behaviour is provided through its `Runtime`
/// implementation.
#[derive(Debug)]
pub struct HipRuntime;
39
40static RUNTIME: ComputeRuntime<HipDevice, Server, MutexComputeChannel<Server>> =
41    ComputeRuntime::new();
42
43pub type HipCompiler = CppCompiler<HipDialect<HipWmmaCompiler>>;
44
45type Server = HipServer;
46type Channel = MutexComputeChannel<Server>;
47
48fn create_client<M: DialectWmmaCompiler<HipDialect<M>>>(
49    device: &HipDevice,
50    options: RuntimeOptions,
51) -> ComputeClient<Server, Channel> {
52    #[allow(unused_assignments)]
53    let mut prop_warp_size = 0;
54    #[allow(unused_assignments)]
55    let mut prop_arch_name = "";
56    #[allow(unused_assignments)]
57    let mut prop_max_shared_memory_size = 0;
58    let mut max_cube_count = CubeDim::new_single();
59    #[allow(unused_assignments)]
60    let mut prop_max_threads = 0;
61    let mut max_cube_dim = CubeDim::new_single();
62    unsafe {
63        let mut ll_device_props = MaybeUninit::uninit();
64        let status = cubecl_hip_sys::hipGetDevicePropertiesR0600(
65            ll_device_props.as_mut_ptr(),
66            device.index as cubecl_hip_sys::hipDevice_t,
67        );
68        assert_eq!(status, HIP_SUCCESS, "Should get device properties");
69        let ll_device_props = ll_device_props.assume_init();
70        prop_warp_size = ll_device_props.warpSize;
71        prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr())
72            .to_str()
73            .unwrap();
74        prop_max_shared_memory_size = ll_device_props.sharedMemPerBlock;
75        max_cube_count.x = ll_device_props.maxGridSize[0] as u32;
76        max_cube_count.y = ll_device_props.maxGridSize[1] as u32;
77        max_cube_count.z = ll_device_props.maxGridSize[2] as u32;
78        prop_max_threads = ll_device_props.maxThreadsPerBlock as u32;
79        max_cube_dim.x = ll_device_props.maxThreadsDim[0] as u32;
80        max_cube_dim.y = ll_device_props.maxThreadsDim[1] as u32;
81        max_cube_dim.z = ll_device_props.maxThreadsDim[2] as u32;
82    };
83    let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name);
84    let arch = M::Architecture::from_str(normalized_arch_name).unwrap();
85    assert_eq!(prop_warp_size as u32, arch.warp_size());
86
87    unsafe {
88        let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t);
89        assert_eq!(
90            status, HIP_SUCCESS,
91            "Should set the default device for the current thread"
92        );
93    }
94
95    let stream = unsafe {
96        let mut stream: cubecl_hip_sys::hipStream_t = std::ptr::null_mut();
97        let stream_status = cubecl_hip_sys::hipStreamCreate(&mut stream);
98        assert_eq!(stream_status, HIP_SUCCESS, "Should create a stream");
99        stream
100    };
101
102    let max_memory = unsafe {
103        let free: usize = 0;
104        let total: usize = 0;
105        let status = cubecl_hip_sys::hipMemGetInfo(
106            &free as *const _ as *mut usize,
107            &total as *const _ as *mut usize,
108        );
109        assert_eq!(
110            status, HIP_SUCCESS,
111            "Should get the available memory of the device"
112        );
113        total
114    };
115    let storage = HipStorage::new(stream);
116    let mem_properties = MemoryDeviceProperties {
117        max_page_size: max_memory as u64 / 4,
118        alignment: HipStorage::ALIGNMENT,
119    };
120    let topology = HardwareProperties {
121        plane_size_min: prop_warp_size as u32,
122        plane_size_max: prop_warp_size as u32,
123        // This is a guess - not clear if ROCM has a limit on the number of bindings,
124        // but it's dubious it's more than this.
125        max_bindings: 1024,
126        max_shared_memory_size: prop_max_shared_memory_size,
127        max_cube_count,
128        max_units_per_cube: prop_max_threads,
129        max_cube_dim,
130        num_streaming_multiprocessors: None,
131        num_tensor_cores: None,
132    };
133    let memory_management =
134        MemoryManagement::from_configuration(storage, &mem_properties, options.memory_config);
135    let mut device_props = DeviceProperties::new(
136        &[Feature::Plane],
137        mem_properties,
138        topology,
139        cubecl_runtime::TimeMeasurement::System,
140    );
141    register_supported_types(&mut device_props);
142    // Not sure if there's a good way to check for support on HIP
143    device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F64)));
144    // TODO look into unsafeAtomicAdd (https://github.com/ROCm/HIP/issues/3573120)
145    // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F16)));
146    // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16)));
147
148    device_props.register_feature(Feature::AtomicFloat(AtomicFeature::LoadStore));
149    device_props.register_feature(Feature::AtomicFloat(AtomicFeature::Add));
150
151    device_props.register_feature(Feature::DynamicLineSize);
152
153    let supported_wmma_combinations = M::supported_wmma_combinations(&arch);
154    register_wmma_features(supported_wmma_combinations, &mut device_props);
155
156    let comp_opts = CompilationOptions {
157        warp_size: arch.warp_size(),
158        grid_constants: false,
159        supports_clusters: false,
160    };
161    let hip_ctx = HipContext::new(memory_management, comp_opts, stream);
162    let server = HipServer::new(hip_ctx);
163    ComputeClient::new(MutexComputeChannel::new(server), device_props, ())
164}
165
166impl Runtime for HipRuntime {
167    type Compiler = HipCompiler;
168    type Server = HipServer;
169    type Channel = MutexComputeChannel<HipServer>;
170    type Device = HipDevice;
171
172    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
173        RUNTIME.client(device, move || {
174            create_client::<HipWmmaCompiler>(device, RuntimeOptions::default())
175        })
176    }
177
178    fn name(_client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str {
179        "hip"
180    }
181
182    fn require_array_lengths() -> bool {
183        true
184    }
185
186    fn supported_line_sizes() -> &'static [u8] {
187        &[8, 4, 2, 1]
188    }
189
190    fn max_cube_count() -> (u32, u32, u32) {
191        (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
192    }
193
194    fn device_id(device: &Self::Device) -> cubecl_core::DeviceId {
195        DeviceId::new(0, device.index as u32)
196    }
197}