1use crate::{
2 HipWmmaCompiler,
3 compute::{HipServer, context::HipContext, contiguous_strides},
4 device::AmdDevice,
5};
6use cubecl_common::{
7 device::{Device, DeviceState},
8 profile::TimingMethod,
9};
10use cubecl_core::{
11 CubeCount, CubeDim, MemoryConfiguration, Runtime,
12 ir::{MatrixLayout, MmaProperties, TargetProperties},
13 server::ServerUtilities,
14};
15use cubecl_cpp::{
16 hip::{HipDialect, arch::AMDArchitecture},
17 register_supported_types,
18 shared::{
19 Architecture, CompilationOptions, CppCompiler, DialectWmmaCompiler, register_mma_features,
20 register_scaled_mma_features, register_wmma_features,
21 },
22};
23use cubecl_hip_sys::HIP_SUCCESS;
24use cubecl_runtime::{
25 DeviceProperties, Plane,
26 client::ComputeClient,
27 logging::ServerLogger,
28 memory_management::{HardwareProperties, MemoryDeviceProperties},
29};
30use std::{ffi::CStr, mem::MaybeUninit, sync::Arc};
31
/// User-configurable options for constructing the HIP runtime server.
#[derive(Default)]
pub struct RuntimeOptions {
    // Memory-pool configuration forwarded to the server's memory management.
    pub memory_config: MemoryConfiguration,
}
38
/// Marker type implementing [`Runtime`] for AMD GPUs through the HIP API.
#[derive(Debug)]
pub struct HipRuntime;
41
42pub type HipCompiler = CppCompiler<HipDialect<HipWmmaCompiler>>;
43
44impl DeviceState for HipServer {
45 fn init(device_id: cubecl_common::device::DeviceId) -> Self {
46 let device = AmdDevice::from_id(device_id);
47
48 #[allow(unused_assignments)]
49 let mut prop_warp_size = 0;
50 #[allow(unused_assignments)]
51 let mut prop_arch_name = "";
52 #[allow(unused_assignments)]
53 let mut prop_max_shared_memory_size = 0;
54 #[allow(unused_assignments)]
55 let mut max_cube_count = CubeCount::new_single();
56 #[allow(unused_assignments)]
57 let mut prop_max_threads = 0;
58 let mut max_cube_dim = CubeDim::new_single();
59 let mut mem_alignment = 32;
60 unsafe {
61 let mut ll_device_props = MaybeUninit::uninit();
62 let status = cubecl_hip_sys::hipGetDevicePropertiesR0600(
63 ll_device_props.as_mut_ptr(),
64 device.index as cubecl_hip_sys::hipDevice_t,
65 );
66 assert_eq!(status, HIP_SUCCESS, "Should get device properties");
67 let ll_device_props = ll_device_props.assume_init();
68 prop_warp_size = ll_device_props.warpSize;
69 prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr())
70 .to_str()
71 .unwrap();
72 prop_max_shared_memory_size = ll_device_props.sharedMemPerBlock;
73 max_cube_count = CubeCount::new_3d(
74 ll_device_props.maxGridSize[0] as u32,
75 ll_device_props.maxGridSize[1] as u32,
76 ll_device_props.maxGridSize[2] as u32,
77 );
78 prop_max_threads = ll_device_props.maxThreadsPerBlock as u32;
79 max_cube_dim.x = ll_device_props.maxThreadsDim[0] as u32;
80 max_cube_dim.y = ll_device_props.maxThreadsDim[1] as u32;
81 max_cube_dim.z = ll_device_props.maxThreadsDim[2] as u32;
82
83 mem_alignment = usize::max(mem_alignment, ll_device_props.textureAlignment);
85 mem_alignment = usize::max(mem_alignment, ll_device_props.surfaceAlignment);
86 };
87 let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name);
88 let arch = AMDArchitecture::parse(normalized_arch_name).unwrap();
89 assert_eq!(prop_warp_size as u32, arch.warp_size());
90
91 unsafe {
92 let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t);
93 assert_eq!(
94 status, HIP_SUCCESS,
95 "Should set the default device for the current thread"
96 );
97 }
98
99 let max_memory = unsafe {
100 let free: usize = 0;
101 let total: usize = 0;
102 let status = cubecl_hip_sys::hipMemGetInfo(
103 &free as *const _ as *mut usize,
104 &total as *const _ as *mut usize,
105 );
106 assert_eq!(
107 status, HIP_SUCCESS,
108 "Should get the available memory of the device"
109 );
110 total
111 };
112 let mem_properties = MemoryDeviceProperties {
113 max_page_size: max_memory as u64 / 4,
114 alignment: mem_alignment as u64,
115 };
116
117 let supported_wmma_combinations = HipWmmaCompiler::supported_wmma_combinations(&arch);
118 let supported_mma_combinations = HipWmmaCompiler::supported_mma_combinations(&arch);
119 let supported_scaled_mma_combinations =
120 HipWmmaCompiler::supported_scaled_mma_combinations(&arch);
121
122 let topology = HardwareProperties {
123 plane_size_min: prop_warp_size as u32,
124 plane_size_max: prop_warp_size as u32,
125 max_bindings: crate::device::AMD_MAX_BINDINGS,
126 max_shared_memory_size: prop_max_shared_memory_size,
127 max_cube_count,
128 max_units_per_cube: prop_max_threads,
129 max_cube_dim,
130 num_streaming_multiprocessors: None,
131 num_tensor_cores: None,
132 min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
133 None
134 } else {
135 Some(16)
136 },
137 };
138
139 let mut device_props = DeviceProperties::new(
140 Default::default(),
141 mem_properties.clone(),
142 topology,
143 TimingMethod::System,
144 );
145 register_supported_types(&mut device_props);
146
147 device_props.features.dynamic_line_size = true;
152 device_props.features.plane.insert(Plane::Ops);
153
154 register_wmma_features(supported_wmma_combinations, &mut device_props);
155 register_mma_features(supported_mma_combinations, &mut device_props);
156 register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);
157
158 let comp_opts = CompilationOptions {
159 warp_size: arch.warp_size(),
160 grid_constants: false,
161 supports_clusters: false,
162 };
163 let hip_ctx = HipContext::new(comp_opts);
164 let logger = Arc::new(ServerLogger::default());
165 let utilities = ServerUtilities::new(device_props, logger, ());
166 let options = RuntimeOptions::default();
167
168 HipServer::new(
169 hip_ctx,
170 mem_properties,
171 options.memory_config,
172 mem_alignment,
173 utilities,
174 )
175 }
176}
177
178impl Runtime for HipRuntime {
179 type Compiler = HipCompiler;
180 type Server = HipServer;
181 type Device = AmdDevice;
182
183 fn client(device: &Self::Device) -> ComputeClient<Self::Server> {
184 ComputeClient::load(device)
185 }
186
187 fn name(_client: &ComputeClient<Self::Server>) -> &'static str {
188 "hip"
189 }
190
191 fn require_array_lengths() -> bool {
192 true
193 }
194
195 fn supported_line_sizes() -> &'static [u8] {
196 &[16, 8, 4, 2, 1]
197 }
198
199 fn max_cube_count() -> (u32, u32, u32) {
200 (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
201 }
202
203 fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
204 if shape.is_empty() {
205 return true;
206 }
207
208 for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
209 if expected != stride {
210 return false;
211 }
212 }
213
214 true
215 }
216
217 fn target_properties() -> TargetProperties {
218 TargetProperties {
219 mma: MmaProperties {
220 register_size_bits: 32,
221 const_plane_size: 32,
222 register_layout_a: MatrixLayout::ColMajor,
223 register_layout_b: MatrixLayout::RowMajor,
224 register_layout_acc: MatrixLayout::RowMajor,
225 register_duplication_a: 2,
226 register_duplication_b: 2,
227 register_duplication_acc: 1,
228 },
229 }
230 }
231}