1use std::{ffi::CStr, mem::MaybeUninit};
2
3use cubecl_cpp::{
4 hip::{HipDialect, arch::AMDArchitecture},
5 register_supported_types,
6 shared::{
7 Architecture, CompilationOptions, CppCompiler, DialectWmmaCompiler, register_wmma_features,
8 },
9};
10
11use cubecl_common::profile::TimingMethod;
12use cubecl_core::{
13 AtomicFeature, CubeCount, CubeDim, Feature, MemoryConfiguration, Runtime,
14 ir::{Elem, FloatKind, IntKind, UIntKind},
15};
16use cubecl_hip_sys::HIP_SUCCESS;
17use cubecl_runtime::id::DeviceId;
18use cubecl_runtime::{
19 ComputeRuntime, DeviceProperties,
20 channel::MutexComputeChannel,
21 client::ComputeClient,
22 memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement},
23};
24
25use crate::{
26 HipWmmaCompiler,
27 compute::{HipContext, HipServer, HipStorage, contiguous_strides},
28 device::AmdDevice,
29};
30
/// Options that can be supplied when a HIP compute client is created.
#[derive(Default)]
pub struct RuntimeOptions {
    /// Memory configuration forwarded to the memory management built for the client.
    pub memory_config: MemoryConfiguration,
}
37
/// Zero-sized type on which the [`Runtime`] trait is implemented for AMD devices using HIP.
#[derive(Debug)]
pub struct HipRuntime;
40
41static RUNTIME: ComputeRuntime<AmdDevice, Server, MutexComputeChannel<Server>> =
42 ComputeRuntime::new();
43
/// Compiler used by the HIP runtime: the C++ compiler specialized with the HIP dialect and
/// [`HipWmmaCompiler`] as its WMMA compiler.
pub type HipCompiler = CppCompiler<HipDialect<HipWmmaCompiler>>;
45
// File-local shorthands for the server type and the mutex-based channel wrapping it.
type Server = HipServer;
type Channel = MutexComputeChannel<Server>;
48
49fn create_client<M: DialectWmmaCompiler<HipDialect<M>>>(
50 device: &AmdDevice,
51 options: RuntimeOptions,
52) -> ComputeClient<Server, Channel> {
53 #[allow(unused_assignments)]
54 let mut prop_warp_size = 0;
55 #[allow(unused_assignments)]
56 let mut prop_arch_name = "";
57 #[allow(unused_assignments)]
58 let mut prop_max_shared_memory_size = 0;
59 #[allow(unused_assignments)]
60 let mut max_cube_count = CubeCount::new_single();
61 #[allow(unused_assignments)]
62 let mut prop_max_threads = 0;
63 let mut max_cube_dim = CubeDim::new_single();
64 let mut mem_aligment = 32;
65 unsafe {
66 let mut ll_device_props = MaybeUninit::uninit();
67 let status = cubecl_hip_sys::hipGetDevicePropertiesR0600(
68 ll_device_props.as_mut_ptr(),
69 device.index as cubecl_hip_sys::hipDevice_t,
70 );
71 assert_eq!(status, HIP_SUCCESS, "Should get device properties");
72 let ll_device_props = ll_device_props.assume_init();
73 prop_warp_size = ll_device_props.warpSize;
74 prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr())
75 .to_str()
76 .unwrap();
77 prop_max_shared_memory_size = ll_device_props.sharedMemPerBlock;
78 max_cube_count = CubeCount::new_3d(
79 ll_device_props.maxGridSize[0] as u32,
80 ll_device_props.maxGridSize[1] as u32,
81 ll_device_props.maxGridSize[2] as u32,
82 );
83 prop_max_threads = ll_device_props.maxThreadsPerBlock as u32;
84 max_cube_dim.x = ll_device_props.maxThreadsDim[0] as u32;
85 max_cube_dim.y = ll_device_props.maxThreadsDim[1] as u32;
86 max_cube_dim.z = ll_device_props.maxThreadsDim[2] as u32;
87
88 mem_aligment = usize::max(mem_aligment, ll_device_props.textureAlignment);
90 mem_aligment = usize::max(mem_aligment, ll_device_props.surfaceAlignment);
91 };
92 let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name);
93 let arch = AMDArchitecture::parse(normalized_arch_name).unwrap();
94 assert_eq!(prop_warp_size as u32, arch.warp_size());
95
96 unsafe {
97 let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t);
98 assert_eq!(
99 status, HIP_SUCCESS,
100 "Should set the default device for the current thread"
101 );
102 }
103
104 let stream = unsafe {
105 let mut stream: cubecl_hip_sys::hipStream_t = std::ptr::null_mut();
106 let stream_status = cubecl_hip_sys::hipStreamCreate(&mut stream);
107 assert_eq!(stream_status, HIP_SUCCESS, "Should create a stream");
108 stream
109 };
110
111 let max_memory = unsafe {
112 let free: usize = 0;
113 let total: usize = 0;
114 let status = cubecl_hip_sys::hipMemGetInfo(
115 &free as *const _ as *mut usize,
116 &total as *const _ as *mut usize,
117 );
118 assert_eq!(
119 status, HIP_SUCCESS,
120 "Should get the available memory of the device"
121 );
122 total
123 };
124 let storage = HipStorage::new(mem_aligment, stream);
125 let mem_properties = MemoryDeviceProperties {
126 max_page_size: max_memory as u64 / 4,
127 alignment: mem_aligment as u64,
128 };
129 let supported_wmma_combinations = M::supported_wmma_combinations(&arch);
130 let topology = HardwareProperties {
131 plane_size_min: prop_warp_size as u32,
132 plane_size_max: prop_warp_size as u32,
133 max_bindings: 1024,
136 max_shared_memory_size: prop_max_shared_memory_size,
137 max_cube_count,
138 max_units_per_cube: prop_max_threads,
139 max_cube_dim,
140 num_streaming_multiprocessors: None,
141 num_tensor_cores: None,
142 min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
143 None
144 } else {
145 Some(16)
146 },
147 };
148 let memory_management =
149 MemoryManagement::from_configuration(storage, &mem_properties, options.memory_config);
150 let mut device_props = DeviceProperties::new(
151 &[Feature::Plane],
152 mem_properties,
153 topology,
154 TimingMethod::System,
155 );
156 register_supported_types(&mut device_props);
157 device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F32)));
159 device_props.register_feature(Feature::AtomicFloat(AtomicFeature::LoadStore));
164 device_props.register_feature(Feature::AtomicFloat(AtomicFeature::Add));
165
166 device_props.register_feature(Feature::Type(Elem::AtomicInt(IntKind::I32)));
168 device_props.register_feature(Feature::Type(Elem::AtomicUInt(UIntKind::U32)));
169 device_props.register_feature(Feature::AtomicInt(AtomicFeature::LoadStore));
170 device_props.register_feature(Feature::AtomicInt(AtomicFeature::Add));
171 device_props.register_feature(Feature::AtomicUInt(AtomicFeature::LoadStore));
172 device_props.register_feature(Feature::AtomicUInt(AtomicFeature::Add));
173
174 device_props.register_feature(Feature::DynamicLineSize);
175
176 register_wmma_features(supported_wmma_combinations, &mut device_props);
177
178 let comp_opts = CompilationOptions {
179 warp_size: arch.warp_size(),
180 grid_constants: false,
181 supports_clusters: false,
182 };
183 let hip_ctx = HipContext::new(memory_management, comp_opts, stream);
184 let server = HipServer::new(mem_aligment, hip_ctx);
185 ComputeClient::new(MutexComputeChannel::new(server), device_props, ())
186}
187
188impl Runtime for HipRuntime {
189 type Compiler = HipCompiler;
190 type Server = HipServer;
191 type Channel = MutexComputeChannel<HipServer>;
192 type Device = AmdDevice;
193
194 fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
195 RUNTIME.client(device, move || {
196 create_client::<HipWmmaCompiler>(device, RuntimeOptions::default())
197 })
198 }
199
200 fn name(_client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str {
201 "hip"
202 }
203
204 fn require_array_lengths() -> bool {
205 true
206 }
207
208 fn supported_line_sizes() -> &'static [u8] {
209 &[8, 4, 2, 1]
210 }
211
212 fn max_cube_count() -> (u32, u32, u32) {
213 (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
214 }
215
216 fn device_id(device: &Self::Device) -> DeviceId {
217 DeviceId::new(0, device.index as u32)
218 }
219
220 fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
221 if shape.is_empty() {
222 return true;
223 }
224
225 for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
226 if expected != stride {
227 return false;
228 }
229 }
230
231 true
232 }
233}