cubecl_hip/
runtime.rs

1use crate::{
2    HipWmmaCompiler,
3    compute::{HipServer, context::HipContext, contiguous_strides},
4    device::AmdDevice,
5};
6use cubecl_common::{
7    device::{Device, DeviceState},
8    profile::TimingMethod,
9};
10use cubecl_core::{
11    CubeCount, CubeDim, MemoryConfiguration, Runtime,
12    ir::{MatrixLayout, MmaProperties, TargetProperties},
13    server::ServerUtilities,
14};
15use cubecl_cpp::{
16    hip::{HipDialect, arch::AMDArchitecture},
17    register_supported_types,
18    shared::{
19        Architecture, CompilationOptions, CppCompiler, DialectWmmaCompiler, register_mma_features,
20        register_scaled_mma_features, register_wmma_features,
21    },
22};
23use cubecl_hip_sys::HIP_SUCCESS;
24use cubecl_runtime::{
25    DeviceProperties, Plane,
26    client::ComputeClient,
27    logging::ServerLogger,
28    memory_management::{HardwareProperties, MemoryDeviceProperties},
29};
30use std::{ffi::CStr, mem::MaybeUninit, sync::Arc};
31
32/// The values that control how a HIP Runtime will perform its calculations.
33#[derive(Default)]
34pub struct RuntimeOptions {
35    /// Configures the memory management.
36    pub memory_config: MemoryConfiguration,
37}
38
39#[derive(Debug)]
40pub struct HipRuntime;
41
42pub type HipCompiler = CppCompiler<HipDialect<HipWmmaCompiler>>;
43
44impl DeviceState for HipServer {
45    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
46        let device = AmdDevice::from_id(device_id);
47
48        #[allow(unused_assignments)]
49        let mut prop_warp_size = 0;
50        #[allow(unused_assignments)]
51        let mut prop_arch_name = "";
52        #[allow(unused_assignments)]
53        let mut prop_max_shared_memory_size = 0;
54        #[allow(unused_assignments)]
55        let mut max_cube_count = CubeCount::new_single();
56        #[allow(unused_assignments)]
57        let mut prop_max_threads = 0;
58        let mut max_cube_dim = CubeDim::new_single();
59        let mut mem_alignment = 32;
60        unsafe {
61            let mut ll_device_props = MaybeUninit::uninit();
62            let status = cubecl_hip_sys::hipGetDevicePropertiesR0600(
63                ll_device_props.as_mut_ptr(),
64                device.index as cubecl_hip_sys::hipDevice_t,
65            );
66            assert_eq!(status, HIP_SUCCESS, "Should get device properties");
67            let ll_device_props = ll_device_props.assume_init();
68            prop_warp_size = ll_device_props.warpSize;
69            prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr())
70                .to_str()
71                .unwrap();
72            prop_max_shared_memory_size = ll_device_props.sharedMemPerBlock;
73            max_cube_count = CubeCount::new_3d(
74                ll_device_props.maxGridSize[0] as u32,
75                ll_device_props.maxGridSize[1] as u32,
76                ll_device_props.maxGridSize[2] as u32,
77            );
78            prop_max_threads = ll_device_props.maxThreadsPerBlock as u32;
79            max_cube_dim.x = ll_device_props.maxThreadsDim[0] as u32;
80            max_cube_dim.y = ll_device_props.maxThreadsDim[1] as u32;
81            max_cube_dim.z = ll_device_props.maxThreadsDim[2] as u32;
82
83            // Just to be sure we check both.
84            mem_alignment = usize::max(mem_alignment, ll_device_props.textureAlignment);
85            mem_alignment = usize::max(mem_alignment, ll_device_props.surfaceAlignment);
86        };
87        let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name);
88        let arch = AMDArchitecture::parse(normalized_arch_name).unwrap();
89        assert_eq!(prop_warp_size as u32, arch.warp_size());
90
91        unsafe {
92            let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t);
93            assert_eq!(
94                status, HIP_SUCCESS,
95                "Should set the default device for the current thread"
96            );
97        }
98
99        let max_memory = unsafe {
100            let free: usize = 0;
101            let total: usize = 0;
102            let status = cubecl_hip_sys::hipMemGetInfo(
103                &free as *const _ as *mut usize,
104                &total as *const _ as *mut usize,
105            );
106            assert_eq!(
107                status, HIP_SUCCESS,
108                "Should get the available memory of the device"
109            );
110            total
111        };
112        let mem_properties = MemoryDeviceProperties {
113            max_page_size: max_memory as u64 / 4,
114            alignment: mem_alignment as u64,
115        };
116
117        let supported_wmma_combinations = HipWmmaCompiler::supported_wmma_combinations(&arch);
118        let supported_mma_combinations = HipWmmaCompiler::supported_mma_combinations(&arch);
119        let supported_scaled_mma_combinations =
120            HipWmmaCompiler::supported_scaled_mma_combinations(&arch);
121
122        let topology = HardwareProperties {
123            plane_size_min: prop_warp_size as u32,
124            plane_size_max: prop_warp_size as u32,
125            max_bindings: crate::device::AMD_MAX_BINDINGS,
126            max_shared_memory_size: prop_max_shared_memory_size,
127            max_cube_count,
128            max_units_per_cube: prop_max_threads,
129            max_cube_dim,
130            num_streaming_multiprocessors: None,
131            num_tensor_cores: None,
132            min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
133                None
134            } else {
135                Some(16)
136            },
137        };
138
139        let mut device_props = DeviceProperties::new(
140            Default::default(),
141            mem_properties.clone(),
142            topology,
143            TimingMethod::System,
144        );
145        register_supported_types(&mut device_props);
146
147        // TODO look into unsafeAtomicAdd (https://github.com/ROCm/HIP/issues/3573120)
148        // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F16)));
149        // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16)));
150
151        device_props.features.dynamic_line_size = true;
152        device_props.features.plane.insert(Plane::Ops);
153
154        register_wmma_features(supported_wmma_combinations, &mut device_props);
155        register_mma_features(supported_mma_combinations, &mut device_props);
156        register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);
157
158        let comp_opts = CompilationOptions {
159            warp_size: arch.warp_size(),
160            grid_constants: false,
161            supports_clusters: false,
162        };
163        let hip_ctx = HipContext::new(comp_opts);
164        let logger = Arc::new(ServerLogger::default());
165        let utilities = ServerUtilities::new(device_props, logger, ());
166        let options = RuntimeOptions::default();
167
168        HipServer::new(
169            hip_ctx,
170            mem_properties,
171            options.memory_config,
172            mem_alignment,
173            utilities,
174        )
175    }
176}
177
178impl Runtime for HipRuntime {
179    type Compiler = HipCompiler;
180    type Server = HipServer;
181    type Device = AmdDevice;
182
183    fn client(device: &Self::Device) -> ComputeClient<Self::Server> {
184        ComputeClient::load(device)
185    }
186
187    fn name(_client: &ComputeClient<Self::Server>) -> &'static str {
188        "hip"
189    }
190
191    fn require_array_lengths() -> bool {
192        true
193    }
194
195    fn supported_line_sizes() -> &'static [u8] {
196        &[16, 8, 4, 2, 1]
197    }
198
199    fn max_cube_count() -> (u32, u32, u32) {
200        (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
201    }
202
203    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
204        if shape.is_empty() {
205            return true;
206        }
207
208        for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
209            if expected != stride {
210                return false;
211            }
212        }
213
214        true
215    }
216
217    fn target_properties() -> TargetProperties {
218        TargetProperties {
219            mma: MmaProperties {
220                register_size_bits: 32,
221                const_plane_size: 32,
222                register_layout_a: MatrixLayout::ColMajor,
223                register_layout_b: MatrixLayout::RowMajor,
224                register_layout_acc: MatrixLayout::RowMajor,
225                register_duplication_a: 2,
226                register_duplication_b: 2,
227                register_duplication_acc: 1,
228            },
229        }
230    }
231}