1use crate::{
2 HipWmmaCompiler,
3 compute::{HipServer, context::HipContext},
4 device::AmdDevice,
5};
6use cubecl_common::{
7 device::{Device, DeviceState},
8 profile::TimingMethod,
9};
10use cubecl_core::{
11 MemoryConfiguration, Runtime,
12 ir::{
13 ContiguousElements, DeviceProperties, HardwareProperties, LineSize, MatrixLayout,
14 MemoryDeviceProperties, MmaProperties, TargetProperties, features::Plane,
15 },
16 server::ServerUtilities,
17};
18use cubecl_cpp::{
19 hip::{HipDialect, arch::AMDArchitecture, mma::contiguous_elements_rdna3},
20 register_supported_types,
21 shared::{
22 Architecture, CompilationOptions, CppCompiler, CppSupportedFeatures, DialectWmmaCompiler,
23 register_mma_features, register_scaled_mma_features, register_wmma_features,
24 },
25};
26use cubecl_hip_sys::{HIP_SUCCESS, hipDeviceScheduleSpin, hipSetDeviceFlags};
27use cubecl_runtime::{client::ComputeClient, logging::ServerLogger};
28use cubecl_zspace::striding::has_pitched_row_major_strides;
29use std::{ffi::CStr, mem::MaybeUninit, sync::Arc};
30
31#[derive(Default)]
33pub struct RuntimeOptions {
34 pub memory_config: MemoryConfiguration,
36}
37
/// Marker type that implements [`Runtime`] for AMD devices driven through HIP.
#[derive(Debug)]
pub struct HipRuntime;
40
41pub type HipCompiler = CppCompiler<HipDialect<HipWmmaCompiler>>;
42
43impl DeviceState for HipServer {
44 fn init(device_id: cubecl_common::device::DeviceId) -> Self {
45 let device = AmdDevice::from_id(device_id);
46
47 #[allow(unused_assignments)]
48 let mut prop_warp_size = 0;
49 #[allow(unused_assignments)]
50 let mut prop_arch_name = "";
51 #[allow(unused_assignments)]
52 let mut prop_max_shared_memory_size = 0;
53 #[allow(unused_assignments)]
54 let mut max_cube_count = (1, 1, 1);
55 #[allow(unused_assignments)]
56 let mut prop_max_threads = 0;
57 let mut max_cube_dim = (1, 1, 1);
58 let mut mem_alignment = 32;
59 unsafe {
60 let mut ll_device_props = MaybeUninit::uninit();
61 let status = cubecl_hip_sys::hipGetDevicePropertiesR0600(
62 ll_device_props.as_mut_ptr(),
63 device.index as cubecl_hip_sys::hipDevice_t,
64 );
65 assert_eq!(status, HIP_SUCCESS, "Should get device properties");
66 let ll_device_props = ll_device_props.assume_init();
67 prop_warp_size = ll_device_props.warpSize;
68 prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr())
69 .to_str()
70 .unwrap();
71 prop_max_shared_memory_size = ll_device_props.sharedMemPerBlock;
72 max_cube_count = (
73 ll_device_props.maxGridSize[0] as u32,
74 ll_device_props.maxGridSize[1] as u32,
75 ll_device_props.maxGridSize[2] as u32,
76 );
77 prop_max_threads = ll_device_props.maxThreadsPerBlock as u32;
78 max_cube_dim.0 = ll_device_props.maxThreadsDim[0] as u32;
79 max_cube_dim.1 = ll_device_props.maxThreadsDim[1] as u32;
80 max_cube_dim.2 = ll_device_props.maxThreadsDim[2] as u32;
81
82 mem_alignment = usize::max(mem_alignment, ll_device_props.textureAlignment);
84 mem_alignment = usize::max(mem_alignment, ll_device_props.surfaceAlignment);
85 };
86 let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name);
87 let arch = AMDArchitecture::parse(normalized_arch_name).unwrap();
88 assert_eq!(prop_warp_size as u32, arch.warp_size());
89
90 unsafe {
91 let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t);
92 hipSetDeviceFlags(hipDeviceScheduleSpin);
93
94 assert_eq!(
95 status, HIP_SUCCESS,
96 "Should set the default device for the current thread"
97 );
98 }
99
100 let max_memory = unsafe {
101 let free: usize = 0;
102 let total: usize = 0;
103 let status = cubecl_hip_sys::hipMemGetInfo(
104 &free as *const _ as *mut usize,
105 &total as *const _ as *mut usize,
106 );
107 assert_eq!(
108 status, HIP_SUCCESS,
109 "Should get the available memory of the device"
110 );
111 total
112 };
113 let mem_properties = MemoryDeviceProperties {
114 max_page_size: max_memory as u64 / 4,
115 alignment: mem_alignment as u64,
116 };
117
118 let supported_wmma_combinations = HipWmmaCompiler::supported_wmma_combinations(&arch);
119 let supported_mma_combinations = HipWmmaCompiler::supported_mma_combinations(&arch);
120 let supported_scaled_mma_combinations =
121 HipWmmaCompiler::supported_scaled_mma_combinations(&arch);
122
123 let topology = HardwareProperties {
124 load_width: 128,
125 plane_size_min: prop_warp_size as u32,
126 plane_size_max: prop_warp_size as u32,
127 max_bindings: crate::device::AMD_MAX_BINDINGS,
128 max_shared_memory_size: prop_max_shared_memory_size,
129 max_cube_count,
130 max_units_per_cube: prop_max_threads,
131 max_cube_dim,
132 num_streaming_multiprocessors: None,
133 num_tensor_cores: None,
134 min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
135 None
136 } else {
137 Some(16)
138 },
139 num_cpu_cores: None,
140 };
141
142 let mut device_props = DeviceProperties::new(
143 Default::default(),
144 mem_properties.clone(),
145 topology,
146 TimingMethod::System,
147 );
148 register_supported_types(&mut device_props);
149
150 device_props.features.dynamic_line_size = true;
155 device_props.features.alignment = true;
156 device_props.features.plane.insert(Plane::Ops);
157 device_props
158 .features
159 .plane
160 .insert(Plane::NonUniformControlFlow);
161
162 register_wmma_features(supported_wmma_combinations, &mut device_props);
163 register_mma_features(supported_mma_combinations, &mut device_props);
164 register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);
165
166 let comp_opts = CompilationOptions {
167 warp_size: arch.warp_size(),
168 supports_features: CppSupportedFeatures {
169 fast_math: true,
170 ..Default::default()
171 },
172 };
173 let hip_ctx = HipContext::new(comp_opts);
174 let logger = Arc::new(ServerLogger::default());
175 let utilities = ServerUtilities::new(device_props, logger, ());
176 let options = RuntimeOptions::default();
177
178 HipServer::new(
179 hip_ctx,
180 mem_properties,
181 options.memory_config,
182 mem_alignment,
183 utilities,
184 )
185 }
186}
187
188impl Runtime for HipRuntime {
189 type Compiler = HipCompiler;
190 type Server = HipServer;
191 type Device = AmdDevice;
192
193 fn client(device: &Self::Device) -> ComputeClient<Self> {
194 ComputeClient::load(device)
195 }
196
197 fn name(_client: &ComputeClient<Self>) -> &'static str {
198 "hip"
199 }
200
201 fn require_array_lengths() -> bool {
202 true
203 }
204
205 fn supported_line_sizes() -> &'static [LineSize] {
206 &[16, 8, 4, 2, 1]
207 }
208
209 fn max_cube_count() -> (u32, u32, u32) {
210 (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
211 }
212
213 fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
214 if shape.is_empty() {
215 return true;
216 }
217 has_pitched_row_major_strides(shape, strides)
218 }
219
220 fn target_properties() -> TargetProperties {
221 TargetProperties {
222 mma: MmaProperties {
223 register_size_bits: 32,
224 const_plane_size: 32,
225 register_layout_a: MatrixLayout::RowMajor,
226 register_layout_b: MatrixLayout::ColMajor,
227 register_layout_acc: MatrixLayout::ColMajor,
228 register_duplication_a: 2,
229 register_duplication_b: 2,
230 register_duplication_acc: 1,
231 contiguous_elements: ContiguousElements::new(contiguous_elements_rdna3),
232 },
233 }
234 }
235}