// cubecl_hip/runtime.rs

1use std::{ffi::CStr, mem::MaybeUninit};
2
3use cubecl_cpp::{
4    hip::{HipDialect, arch::AMDArchitecture},
5    register_supported_types,
6    shared::{
7        Architecture, CompilationOptions, CppCompiler, DialectWmmaCompiler, register_wmma_features,
8    },
9};
10
11use cubecl_common::profile::TimingMethod;
12use cubecl_core::{
13    AtomicFeature, CubeCount, CubeDim, Feature, MemoryConfiguration, Runtime,
14    ir::{Elem, FloatKind, IntKind, UIntKind},
15};
16use cubecl_hip_sys::HIP_SUCCESS;
17use cubecl_runtime::id::DeviceId;
18use cubecl_runtime::{
19    ComputeRuntime, DeviceProperties,
20    channel::MutexComputeChannel,
21    client::ComputeClient,
22    memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement},
23};
24
25use crate::{
26    HipWmmaCompiler,
27    compute::{HipContext, HipServer, HipStorage, contiguous_strides},
28    device::AmdDevice,
29};
30
/// The values that control how a HIP Runtime will perform its calculations.
#[derive(Default)]
pub struct RuntimeOptions {
    /// Configures the memory management.
    /// Defaults to the crate-wide default `MemoryConfiguration`.
    pub memory_config: MemoryConfiguration,
}
37
/// The CubeCL runtime backed by AMD's HIP API.
///
/// This is a zero-sized marker type; all state lives in the global
/// `ComputeRuntime` registry and in the per-device servers.
#[derive(Debug)]
pub struct HipRuntime;
40
// Global registry of compute clients, keyed by AMD device. `Runtime::client`
// goes through this so a client is created at most once per device.
static RUNTIME: ComputeRuntime<AmdDevice, Server, MutexComputeChannel<Server>> =
    ComputeRuntime::new();

/// The C++ source compiler specialized for the HIP dialect.
pub type HipCompiler = CppCompiler<HipDialect<HipWmmaCompiler>>;

// Internal shorthands for the server/channel pair used throughout this module.
type Server = HipServer;
type Channel = MutexComputeChannel<Server>;
48
49fn create_client<M: DialectWmmaCompiler<HipDialect<M>>>(
50    device: &AmdDevice,
51    options: RuntimeOptions,
52) -> ComputeClient<Server, Channel> {
53    #[allow(unused_assignments)]
54    let mut prop_warp_size = 0;
55    #[allow(unused_assignments)]
56    let mut prop_arch_name = "";
57    #[allow(unused_assignments)]
58    let mut prop_max_shared_memory_size = 0;
59    #[allow(unused_assignments)]
60    let mut max_cube_count = CubeCount::new_single();
61    #[allow(unused_assignments)]
62    let mut prop_max_threads = 0;
63    let mut max_cube_dim = CubeDim::new_single();
64    let mut mem_aligment = 32;
65    unsafe {
66        let mut ll_device_props = MaybeUninit::uninit();
67        let status = cubecl_hip_sys::hipGetDevicePropertiesR0600(
68            ll_device_props.as_mut_ptr(),
69            device.index as cubecl_hip_sys::hipDevice_t,
70        );
71        assert_eq!(status, HIP_SUCCESS, "Should get device properties");
72        let ll_device_props = ll_device_props.assume_init();
73        prop_warp_size = ll_device_props.warpSize;
74        prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr())
75            .to_str()
76            .unwrap();
77        prop_max_shared_memory_size = ll_device_props.sharedMemPerBlock;
78        max_cube_count = CubeCount::new_3d(
79            ll_device_props.maxGridSize[0] as u32,
80            ll_device_props.maxGridSize[1] as u32,
81            ll_device_props.maxGridSize[2] as u32,
82        );
83        prop_max_threads = ll_device_props.maxThreadsPerBlock as u32;
84        max_cube_dim.x = ll_device_props.maxThreadsDim[0] as u32;
85        max_cube_dim.y = ll_device_props.maxThreadsDim[1] as u32;
86        max_cube_dim.z = ll_device_props.maxThreadsDim[2] as u32;
87
88        // Just to be sure we check both.
89        mem_aligment = usize::max(mem_aligment, ll_device_props.textureAlignment);
90        mem_aligment = usize::max(mem_aligment, ll_device_props.surfaceAlignment);
91    };
92    let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name);
93    let arch = AMDArchitecture::parse(normalized_arch_name).unwrap();
94    assert_eq!(prop_warp_size as u32, arch.warp_size());
95
96    unsafe {
97        let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t);
98        assert_eq!(
99            status, HIP_SUCCESS,
100            "Should set the default device for the current thread"
101        );
102    }
103
104    let stream = unsafe {
105        let mut stream: cubecl_hip_sys::hipStream_t = std::ptr::null_mut();
106        let stream_status = cubecl_hip_sys::hipStreamCreate(&mut stream);
107        assert_eq!(stream_status, HIP_SUCCESS, "Should create a stream");
108        stream
109    };
110
111    let max_memory = unsafe {
112        let free: usize = 0;
113        let total: usize = 0;
114        let status = cubecl_hip_sys::hipMemGetInfo(
115            &free as *const _ as *mut usize,
116            &total as *const _ as *mut usize,
117        );
118        assert_eq!(
119            status, HIP_SUCCESS,
120            "Should get the available memory of the device"
121        );
122        total
123    };
124    let storage = HipStorage::new(mem_aligment, stream);
125    let mem_properties = MemoryDeviceProperties {
126        max_page_size: max_memory as u64 / 4,
127        alignment: mem_aligment as u64,
128    };
129    let supported_wmma_combinations = M::supported_wmma_combinations(&arch);
130    let topology = HardwareProperties {
131        plane_size_min: prop_warp_size as u32,
132        plane_size_max: prop_warp_size as u32,
133        // This is a guess - not clear if ROCM has a limit on the number of bindings,
134        // but it's dubious it's more than this.
135        max_bindings: 1024,
136        max_shared_memory_size: prop_max_shared_memory_size,
137        max_cube_count,
138        max_units_per_cube: prop_max_threads,
139        max_cube_dim,
140        num_streaming_multiprocessors: None,
141        num_tensor_cores: None,
142        min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
143            None
144        } else {
145            Some(16)
146        },
147    };
148    let memory_management =
149        MemoryManagement::from_configuration(storage, &mem_properties, options.memory_config);
150    let mut device_props = DeviceProperties::new(
151        &[Feature::Plane],
152        mem_properties,
153        topology,
154        TimingMethod::System,
155    );
156    register_supported_types(&mut device_props);
157    // Not sure if there's a good way to check for support on HIP
158    device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F32)));
159    // TODO look into unsafeAtomicAdd (https://github.com/ROCm/HIP/issues/3573120)
160    // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F16)));
161    // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16)));
162
163    device_props.register_feature(Feature::AtomicFloat(AtomicFeature::LoadStore));
164    device_props.register_feature(Feature::AtomicFloat(AtomicFeature::Add));
165
166    // Supported by all architectures
167    device_props.register_feature(Feature::Type(Elem::AtomicInt(IntKind::I32)));
168    device_props.register_feature(Feature::Type(Elem::AtomicUInt(UIntKind::U32)));
169    device_props.register_feature(Feature::AtomicInt(AtomicFeature::LoadStore));
170    device_props.register_feature(Feature::AtomicInt(AtomicFeature::Add));
171    device_props.register_feature(Feature::AtomicUInt(AtomicFeature::LoadStore));
172    device_props.register_feature(Feature::AtomicUInt(AtomicFeature::Add));
173
174    device_props.register_feature(Feature::DynamicLineSize);
175
176    register_wmma_features(supported_wmma_combinations, &mut device_props);
177
178    let comp_opts = CompilationOptions {
179        warp_size: arch.warp_size(),
180        grid_constants: false,
181        supports_clusters: false,
182    };
183    let hip_ctx = HipContext::new(memory_management, comp_opts, stream);
184    let server = HipServer::new(mem_aligment, hip_ctx);
185    ComputeClient::new(MutexComputeChannel::new(server), device_props, ())
186}
187
188impl Runtime for HipRuntime {
189    type Compiler = HipCompiler;
190    type Server = HipServer;
191    type Channel = MutexComputeChannel<HipServer>;
192    type Device = AmdDevice;
193
194    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
195        RUNTIME.client(device, move || {
196            create_client::<HipWmmaCompiler>(device, RuntimeOptions::default())
197        })
198    }
199
200    fn name(_client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str {
201        "hip"
202    }
203
204    fn require_array_lengths() -> bool {
205        true
206    }
207
208    fn supported_line_sizes() -> &'static [u8] {
209        &[8, 4, 2, 1]
210    }
211
212    fn max_cube_count() -> (u32, u32, u32) {
213        (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
214    }
215
216    fn device_id(device: &Self::Device) -> DeviceId {
217        DeviceId::new(0, device.index as u32)
218    }
219
220    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
221        if shape.is_empty() {
222            return true;
223        }
224
225        for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
226            if expected != stride {
227                return false;
228            }
229        }
230
231        true
232    }
233}