cubecl_hip/
runtime.rs

1use crate::{
2    HipWmmaCompiler,
3    compute::{HipServer, context::HipContext},
4    device::AmdDevice,
5};
6use cubecl_common::{
7    device::{Device, DeviceState},
8    profile::TimingMethod,
9};
10use cubecl_core::{
11    MemoryConfiguration, Runtime,
12    ir::{
13        ContiguousElements, DeviceProperties, HardwareProperties, LineSize, MatrixLayout,
14        MemoryDeviceProperties, MmaProperties, TargetProperties, features::Plane,
15    },
16    server::ServerUtilities,
17};
18use cubecl_cpp::{
19    hip::{HipDialect, arch::AMDArchitecture, mma::contiguous_elements_rdna3},
20    register_supported_types,
21    shared::{
22        Architecture, CompilationOptions, CppCompiler, CppSupportedFeatures, DialectWmmaCompiler,
23        register_mma_features, register_scaled_mma_features, register_wmma_features,
24    },
25};
26use cubecl_hip_sys::{HIP_SUCCESS, hipDeviceScheduleSpin, hipSetDeviceFlags};
27use cubecl_runtime::{client::ComputeClient, logging::ServerLogger};
28use cubecl_zspace::striding::has_pitched_row_major_strides;
29use std::{ffi::CStr, mem::MaybeUninit, sync::Arc};
30
31/// The values that control how a HIP Runtime will perform its calculations.
32#[derive(Default)]
33pub struct RuntimeOptions {
34    /// Configures the memory management.
35    pub memory_config: MemoryConfiguration,
36}
37
/// Marker type for the HIP backend.
///
/// It carries no data; the backend behavior is provided by its `Runtime` implementation.
#[derive(Debug)]
pub struct HipRuntime;
40
41pub type HipCompiler = CppCompiler<HipDialect<HipWmmaCompiler>>;
42
43impl DeviceState for HipServer {
44    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
45        let device = AmdDevice::from_id(device_id);
46
47        #[allow(unused_assignments)]
48        let mut prop_warp_size = 0;
49        #[allow(unused_assignments)]
50        let mut prop_arch_name = "";
51        #[allow(unused_assignments)]
52        let mut prop_max_shared_memory_size = 0;
53        #[allow(unused_assignments)]
54        let mut max_cube_count = (1, 1, 1);
55        #[allow(unused_assignments)]
56        let mut prop_max_threads = 0;
57        let mut max_cube_dim = (1, 1, 1);
58        let mut mem_alignment = 32;
59        unsafe {
60            let mut ll_device_props = MaybeUninit::uninit();
61            let status = cubecl_hip_sys::hipGetDevicePropertiesR0600(
62                ll_device_props.as_mut_ptr(),
63                device.index as cubecl_hip_sys::hipDevice_t,
64            );
65            assert_eq!(status, HIP_SUCCESS, "Should get device properties");
66            let ll_device_props = ll_device_props.assume_init();
67            prop_warp_size = ll_device_props.warpSize;
68            prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr())
69                .to_str()
70                .unwrap();
71            prop_max_shared_memory_size = ll_device_props.sharedMemPerBlock;
72            max_cube_count = (
73                ll_device_props.maxGridSize[0] as u32,
74                ll_device_props.maxGridSize[1] as u32,
75                ll_device_props.maxGridSize[2] as u32,
76            );
77            prop_max_threads = ll_device_props.maxThreadsPerBlock as u32;
78            max_cube_dim.0 = ll_device_props.maxThreadsDim[0] as u32;
79            max_cube_dim.1 = ll_device_props.maxThreadsDim[1] as u32;
80            max_cube_dim.2 = ll_device_props.maxThreadsDim[2] as u32;
81
82            // Just to be sure we check both.
83            mem_alignment = usize::max(mem_alignment, ll_device_props.textureAlignment);
84            mem_alignment = usize::max(mem_alignment, ll_device_props.surfaceAlignment);
85        };
86        let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name);
87        let arch = AMDArchitecture::parse(normalized_arch_name).unwrap();
88        assert_eq!(prop_warp_size as u32, arch.warp_size());
89
90        unsafe {
91            let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t);
92            hipSetDeviceFlags(hipDeviceScheduleSpin);
93
94            assert_eq!(
95                status, HIP_SUCCESS,
96                "Should set the default device for the current thread"
97            );
98        }
99
100        let max_memory = unsafe {
101            let free: usize = 0;
102            let total: usize = 0;
103            let status = cubecl_hip_sys::hipMemGetInfo(
104                &free as *const _ as *mut usize,
105                &total as *const _ as *mut usize,
106            );
107            assert_eq!(
108                status, HIP_SUCCESS,
109                "Should get the available memory of the device"
110            );
111            total
112        };
113        let mem_properties = MemoryDeviceProperties {
114            max_page_size: max_memory as u64 / 4,
115            alignment: mem_alignment as u64,
116        };
117
118        let supported_wmma_combinations = HipWmmaCompiler::supported_wmma_combinations(&arch);
119        let supported_mma_combinations = HipWmmaCompiler::supported_mma_combinations(&arch);
120        let supported_scaled_mma_combinations =
121            HipWmmaCompiler::supported_scaled_mma_combinations(&arch);
122
123        let topology = HardwareProperties {
124            load_width: 128,
125            plane_size_min: prop_warp_size as u32,
126            plane_size_max: prop_warp_size as u32,
127            max_bindings: crate::device::AMD_MAX_BINDINGS,
128            max_shared_memory_size: prop_max_shared_memory_size,
129            max_cube_count,
130            max_units_per_cube: prop_max_threads,
131            max_cube_dim,
132            num_streaming_multiprocessors: None,
133            num_tensor_cores: None,
134            min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
135                None
136            } else {
137                Some(16)
138            },
139            num_cpu_cores: None,
140        };
141
142        let mut device_props = DeviceProperties::new(
143            Default::default(),
144            mem_properties.clone(),
145            topology,
146            TimingMethod::System,
147        );
148        register_supported_types(&mut device_props);
149
150        // TODO look into unsafeAtomicAdd (https://github.com/ROCm/HIP/issues/3573120)
151        // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F16)));
152        // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16)));
153
154        device_props.features.dynamic_line_size = true;
155        device_props.features.alignment = true;
156        device_props.features.plane.insert(Plane::Ops);
157        device_props
158            .features
159            .plane
160            .insert(Plane::NonUniformControlFlow);
161
162        register_wmma_features(supported_wmma_combinations, &mut device_props);
163        register_mma_features(supported_mma_combinations, &mut device_props);
164        register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);
165
166        let comp_opts = CompilationOptions {
167            warp_size: arch.warp_size(),
168            supports_features: CppSupportedFeatures {
169                fast_math: true,
170                ..Default::default()
171            },
172        };
173        let hip_ctx = HipContext::new(comp_opts);
174        let logger = Arc::new(ServerLogger::default());
175        let utilities = ServerUtilities::new(device_props, logger, ());
176        let options = RuntimeOptions::default();
177
178        HipServer::new(
179            hip_ctx,
180            mem_properties,
181            options.memory_config,
182            mem_alignment,
183            utilities,
184        )
185    }
186}
187
188impl Runtime for HipRuntime {
189    type Compiler = HipCompiler;
190    type Server = HipServer;
191    type Device = AmdDevice;
192
193    fn client(device: &Self::Device) -> ComputeClient<Self> {
194        ComputeClient::load(device)
195    }
196
197    fn name(_client: &ComputeClient<Self>) -> &'static str {
198        "hip"
199    }
200
201    fn require_array_lengths() -> bool {
202        true
203    }
204
205    fn supported_line_sizes() -> &'static [LineSize] {
206        &[16, 8, 4, 2, 1]
207    }
208
209    fn max_cube_count() -> (u32, u32, u32) {
210        (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
211    }
212
213    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
214        if shape.is_empty() {
215            return true;
216        }
217        has_pitched_row_major_strides(shape, strides)
218    }
219
220    fn target_properties() -> TargetProperties {
221        TargetProperties {
222            mma: MmaProperties {
223                register_size_bits: 32,
224                const_plane_size: 32,
225                register_layout_a: MatrixLayout::RowMajor,
226                register_layout_b: MatrixLayout::ColMajor,
227                register_layout_acc: MatrixLayout::ColMajor,
228                register_duplication_a: 2,
229                register_duplication_b: 2,
230                register_duplication_acc: 1,
231                contiguous_elements: ContiguousElements::new(contiguous_elements_rdna3),
232            },
233        }
234    }
235}