1use crate::{
2 HipWmmaCompiler,
3 compute::{HipServer, context::HipContext},
4 device::AmdDevice,
5};
6use core::ffi::c_int;
7use cubecl_common::{
8 device::{Device, DeviceService},
9 profile::TimingMethod,
10};
11use cubecl_core::{
12 MemoryConfiguration, Runtime,
13 device::{DeviceId, ServerUtilitiesHandle},
14 ir::{
15 ContiguousElements, DeviceProperties, HardwareProperties, MatrixLayout,
16 MemoryDeviceProperties, MmaProperties, TargetProperties, VectorSize, features::Plane,
17 },
18 server::ServerUtilities,
19 zspace::{Shape, Strides, striding::has_pitched_row_major_strides},
20};
21use cubecl_cpp::{
22 ComputeKernel,
23 hip::{HipDialect, arch::AMDArchitecture, mma::contiguous_elements_rdna3},
24 register_supported_types,
25 shared::{
26 Architecture, CompilationOptions, CppCompiler, CppSupportedFeatures, DialectWmmaCompiler,
27 register_mma_features, register_scaled_mma_features, register_wmma_features,
28 },
29};
30use cubecl_hip_sys::{HIP_SUCCESS, hipDeviceScheduleSpin, hipGetDeviceCount, hipSetDeviceFlags};
31use cubecl_runtime::{
32 allocator::PitchedMemoryLayoutPolicy, client::ComputeClient, logging::ServerLogger,
33};
34use std::{ffi::CStr, mem::MaybeUninit, sync::Arc};
35
36#[derive(Default)]
38pub struct RuntimeOptions {
39 pub memory_config: MemoryConfiguration,
41}
42
/// Marker type selecting the HIP (AMD GPU) backend as a cubecl [`Runtime`].
#[derive(Clone, Debug)]
pub struct HipRuntime;
45
46pub type HipCompiler = CppCompiler<HipDialect<HipWmmaCompiler>>;
47pub type HipComputeKernel = ComputeKernel<HipDialect<HipWmmaCompiler>>;
48
49impl DeviceService for HipServer {
50 fn init(device_id: cubecl_common::device::DeviceId) -> Self {
51 let device = AmdDevice::from_id(device_id);
52
53 #[allow(unused_assignments)]
54 let mut prop_warp_size = 0;
55 #[allow(unused_assignments)]
56 let mut prop_arch_name = "";
57 #[allow(unused_assignments)]
58 let mut prop_max_shared_memory_size = 0;
59 #[allow(unused_assignments)]
60 let mut max_cube_count = (1, 1, 1);
61 #[allow(unused_assignments)]
62 let mut prop_max_threads = 0;
63 let mut max_cube_dim = (1, 1, 1);
64 let mut mem_alignment = 32;
65 unsafe {
69 let mut ll_device_props = MaybeUninit::uninit();
70 let status = cubecl_hip_sys::hipGetDevicePropertiesR0600(
71 ll_device_props.as_mut_ptr(),
72 device.index as cubecl_hip_sys::hipDevice_t,
73 );
74 assert_eq!(status, HIP_SUCCESS, "Should get device properties");
75 let ll_device_props = ll_device_props.assume_init();
76 prop_warp_size = ll_device_props.warpSize;
77 prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr())
78 .to_str()
79 .unwrap();
80 prop_max_shared_memory_size = ll_device_props.sharedMemPerBlock;
81 max_cube_count = (
82 ll_device_props.maxGridSize[0] as u32,
83 ll_device_props.maxGridSize[1] as u32,
84 ll_device_props.maxGridSize[2] as u32,
85 );
86 prop_max_threads = ll_device_props.maxThreadsPerBlock as u32;
87 max_cube_dim.0 = ll_device_props.maxThreadsDim[0] as u32;
88 max_cube_dim.1 = ll_device_props.maxThreadsDim[1] as u32;
89 max_cube_dim.2 = ll_device_props.maxThreadsDim[2] as u32;
90
91 mem_alignment = usize::max(mem_alignment, ll_device_props.textureAlignment);
93 mem_alignment = usize::max(mem_alignment, ll_device_props.surfaceAlignment);
94 };
95 let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name);
96 let arch = AMDArchitecture::parse(normalized_arch_name).unwrap();
97 assert_eq!(prop_warp_size as u32, arch.warp_size());
98
99 unsafe {
103 let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t);
104 hipSetDeviceFlags(hipDeviceScheduleSpin);
105
106 assert_eq!(
107 status, HIP_SUCCESS,
108 "Should set the default device for the current thread"
109 );
110 }
111
112 let max_memory = unsafe {
116 let free: usize = 0;
117 let total: usize = 0;
118 let status = cubecl_hip_sys::hipMemGetInfo(
119 &free as *const _ as *mut usize,
120 &total as *const _ as *mut usize,
121 );
122 assert_eq!(
123 status, HIP_SUCCESS,
124 "Should get the available memory of the device"
125 );
126 total
127 };
128 let mem_properties = MemoryDeviceProperties {
129 max_page_size: max_memory as u64 / 4,
130 alignment: mem_alignment as u64,
131 };
132
133 let supported_wmma_combinations = HipWmmaCompiler::supported_wmma_combinations(&arch);
134 let supported_mma_combinations = HipWmmaCompiler::supported_mma_combinations(&arch);
135 let supported_scaled_mma_combinations =
136 HipWmmaCompiler::supported_scaled_mma_combinations(&arch);
137
138 let topology = HardwareProperties {
139 load_width: 128,
140 plane_size_min: prop_warp_size as u32,
141 plane_size_max: prop_warp_size as u32,
142 max_bindings: crate::device::AMD_MAX_BINDINGS,
143 max_shared_memory_size: prop_max_shared_memory_size,
144 max_cube_count,
145 max_units_per_cube: prop_max_threads,
146 max_cube_dim,
147 num_streaming_multiprocessors: None,
148 num_tensor_cores: None,
149 min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
150 None
151 } else {
152 Some(16)
153 },
154 num_cpu_cores: None,
155 max_vector_size: VectorSize::MAX,
156 };
157
158 let mut device_props = DeviceProperties::new(
159 Default::default(),
160 mem_properties.clone(),
161 topology,
162 TimingMethod::System,
163 );
164 register_supported_types(&mut device_props);
165
166 device_props.features.memory_reinterpret = true;
171 device_props.features.alignment = true;
172 device_props.features.plane.insert(Plane::Ops);
173 device_props
174 .features
175 .plane
176 .insert(Plane::NonUniformControlFlow);
177
178 register_wmma_features(supported_wmma_combinations, &mut device_props);
179 register_mma_features(supported_mma_combinations, &mut device_props);
180 register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);
181
182 let comp_opts = CompilationOptions {
183 warp_size: arch.warp_size(),
184 supports_features: CppSupportedFeatures {
185 fast_math: true,
186 ..Default::default()
187 },
188 };
189 let hip_ctx = HipContext::new(comp_opts, device_props.clone());
190 let logger = Arc::new(ServerLogger::default());
191 let policy = PitchedMemoryLayoutPolicy::new(device_props.memory.alignment as usize);
192 let utilities = ServerUtilities::new(device_props, logger, (), policy);
193 let options = RuntimeOptions::default();
194
195 let is_integrated = unsafe { is_integrated_gpu(device_id.index_id as i32) };
197
198 HipServer::new(
199 hip_ctx,
200 mem_properties,
201 options.memory_config,
202 mem_alignment,
203 is_integrated,
204 utilities,
205 )
206 }
207
208 fn utilities(&self) -> ServerUtilitiesHandle {
209 self.utilities() as ServerUtilitiesHandle
210 }
211}
212
213impl Runtime for HipRuntime {
214 type Compiler = HipCompiler;
215 type Server = HipServer;
216 type Device = AmdDevice;
217
218 fn client(device: &Self::Device) -> ComputeClient<Self> {
219 ComputeClient::load(device)
220 }
221
222 fn name(_client: &ComputeClient<Self>) -> &'static str {
223 "hip"
224 }
225
226 fn require_array_lengths() -> bool {
227 true
228 }
229
230 fn max_cube_count() -> (u32, u32, u32) {
231 (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
232 }
233
234 fn can_read_tensor(shape: &Shape, strides: &Strides) -> bool {
235 if shape.is_empty() {
236 return true;
237 }
238 has_pitched_row_major_strides(shape, strides)
239 }
240
241 fn target_properties() -> TargetProperties {
242 TargetProperties {
243 mma: MmaProperties {
244 register_size_bits: 32,
245 const_plane_size: 32,
246 register_layout_a: MatrixLayout::RowMajor,
247 register_layout_b: MatrixLayout::ColMajor,
248 register_layout_acc: MatrixLayout::ColMajor,
249 register_duplication_a: 2,
250 register_duplication_b: 2,
251 register_duplication_acc: 1,
252 contiguous_elements: ContiguousElements::new(contiguous_elements_rdna3),
253 },
254 }
255 }
256
257 fn enumerate_devices(
258 _: u16,
259 _: &<Self::Server as cubecl_core::server::ComputeServer>::Info,
260 ) -> Vec<cubecl_core::device::DeviceId> {
261 fn device_count() -> usize {
262 let mut device_count: c_int = 0;
263 let result;
264 unsafe {
267 result = hipGetDeviceCount(&mut device_count);
268 }
269 if result == HIP_SUCCESS {
270 device_count.try_into().unwrap_or(0)
271 } else {
272 0
273 }
274 }
275 (0..device_count())
276 .map(|i| DeviceId::new(0, i as u16))
277 .collect()
278 }
279}
280
281unsafe fn is_integrated_gpu(device_id: i32) -> bool {
287 let mut props = unsafe { std::mem::zeroed::<cubecl_hip_sys::hipDeviceProp_tR0600>() };
289 let status = unsafe { cubecl_hip_sys::hipGetDevicePropertiesR0600(&mut props, device_id) };
291 if status != HIP_SUCCESS {
292 return false; }
294 props.integrated != 0
295}