1use std::{ffi::CStr, mem::MaybeUninit, str::FromStr};
2
3use cubecl_cpp::{
4 hip::HipDialect,
5 register_supported_types,
6 shared::{
7 Architecture, CompilationOptions, CppCompiler, DialectWmmaCompiler, register_wmma_features,
8 },
9};
10
11use cubecl_core::{
12 AtomicFeature, CubeDim, DeviceId, Feature, MemoryConfiguration, Runtime,
13 ir::{Elem, FloatKind},
14};
15use cubecl_hip_sys::HIP_SUCCESS;
16use cubecl_runtime::{
17 ComputeRuntime, DeviceProperties,
18 channel::MutexComputeChannel,
19 client::ComputeClient,
20 memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement},
21 storage::ComputeStorage,
22};
23
24use crate::{
25 HipWmmaCompiler,
26 compute::{HipContext, HipServer, HipStorage},
27 device::HipDevice,
28};
29
/// Options consumed by `create_client` when building the client for a HIP device.
#[derive(Default)]
pub struct RuntimeOptions {
    /// Memory configuration forwarded to `MemoryManagement::from_configuration`.
    pub memory_config: MemoryConfiguration,
}
36
/// Unit type that carries the [`Runtime`] implementation for HIP devices.
#[derive(Debug)]
pub struct HipRuntime;
39
// Global runtime state through which `HipRuntime::client` obtains the compute
// client for a device, passing it a closure that builds the client.
// NOTE(review): presumably caches one client per device — confirm against `ComputeRuntime`.
static RUNTIME: ComputeRuntime<HipDevice, Server, MutexComputeChannel<Server>> =
    ComputeRuntime::new();
42
/// Compiler used by [`HipRuntime`]: the C++ compiler parameterised with the HIP
/// dialect and this crate's WMMA compiler.
pub type HipCompiler = CppCompiler<HipDialect<HipWmmaCompiler>>;
44
// Server type used by the clients created in this module.
type Server = HipServer;
// Channel type wrapping `Server` that the clients created in this module use.
type Channel = MutexComputeChannel<Server>;
47
48fn create_client<M: DialectWmmaCompiler<HipDialect<M>>>(
49 device: &HipDevice,
50 options: RuntimeOptions,
51) -> ComputeClient<Server, Channel> {
52 #[allow(unused_assignments)]
53 let mut prop_warp_size = 0;
54 #[allow(unused_assignments)]
55 let mut prop_arch_name = "";
56 #[allow(unused_assignments)]
57 let mut prop_max_shared_memory_size = 0;
58 let mut max_cube_count = CubeDim::new_single();
59 #[allow(unused_assignments)]
60 let mut prop_max_threads = 0;
61 let mut max_cube_dim = CubeDim::new_single();
62 unsafe {
63 let mut ll_device_props = MaybeUninit::uninit();
64 let status = cubecl_hip_sys::hipGetDevicePropertiesR0600(
65 ll_device_props.as_mut_ptr(),
66 device.index as cubecl_hip_sys::hipDevice_t,
67 );
68 assert_eq!(status, HIP_SUCCESS, "Should get device properties");
69 let ll_device_props = ll_device_props.assume_init();
70 prop_warp_size = ll_device_props.warpSize;
71 prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr())
72 .to_str()
73 .unwrap();
74 prop_max_shared_memory_size = ll_device_props.sharedMemPerBlock;
75 max_cube_count.x = ll_device_props.maxGridSize[0] as u32;
76 max_cube_count.y = ll_device_props.maxGridSize[1] as u32;
77 max_cube_count.z = ll_device_props.maxGridSize[2] as u32;
78 prop_max_threads = ll_device_props.maxThreadsPerBlock as u32;
79 max_cube_dim.x = ll_device_props.maxThreadsDim[0] as u32;
80 max_cube_dim.y = ll_device_props.maxThreadsDim[1] as u32;
81 max_cube_dim.z = ll_device_props.maxThreadsDim[2] as u32;
82 };
83 let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name);
84 let arch = M::Architecture::from_str(normalized_arch_name).unwrap();
85 assert_eq!(prop_warp_size as u32, arch.warp_size());
86
87 unsafe {
88 let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t);
89 assert_eq!(
90 status, HIP_SUCCESS,
91 "Should set the default device for the current thread"
92 );
93 }
94
95 let stream = unsafe {
96 let mut stream: cubecl_hip_sys::hipStream_t = std::ptr::null_mut();
97 let stream_status = cubecl_hip_sys::hipStreamCreate(&mut stream);
98 assert_eq!(stream_status, HIP_SUCCESS, "Should create a stream");
99 stream
100 };
101
102 let max_memory = unsafe {
103 let free: usize = 0;
104 let total: usize = 0;
105 let status = cubecl_hip_sys::hipMemGetInfo(
106 &free as *const _ as *mut usize,
107 &total as *const _ as *mut usize,
108 );
109 assert_eq!(
110 status, HIP_SUCCESS,
111 "Should get the available memory of the device"
112 );
113 total
114 };
115 let storage = HipStorage::new(stream);
116 let mem_properties = MemoryDeviceProperties {
117 max_page_size: max_memory as u64 / 4,
118 alignment: HipStorage::ALIGNMENT,
119 };
120 let topology = HardwareProperties {
121 plane_size_min: prop_warp_size as u32,
122 plane_size_max: prop_warp_size as u32,
123 max_bindings: 1024,
126 max_shared_memory_size: prop_max_shared_memory_size,
127 max_cube_count,
128 max_units_per_cube: prop_max_threads,
129 max_cube_dim,
130 num_streaming_multiprocessors: None,
131 num_tensor_cores: None,
132 };
133 let memory_management =
134 MemoryManagement::from_configuration(storage, &mem_properties, options.memory_config);
135 let mut device_props = DeviceProperties::new(
136 &[Feature::Plane],
137 mem_properties,
138 topology,
139 cubecl_runtime::TimeMeasurement::System,
140 );
141 register_supported_types(&mut device_props);
142 device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F64)));
144 device_props.register_feature(Feature::AtomicFloat(AtomicFeature::LoadStore));
149 device_props.register_feature(Feature::AtomicFloat(AtomicFeature::Add));
150
151 device_props.register_feature(Feature::DynamicLineSize);
152
153 let supported_wmma_combinations = M::supported_wmma_combinations(&arch);
154 register_wmma_features(supported_wmma_combinations, &mut device_props);
155
156 let comp_opts = CompilationOptions {
157 warp_size: arch.warp_size(),
158 grid_constants: false,
159 supports_clusters: false,
160 };
161 let hip_ctx = HipContext::new(memory_management, comp_opts, stream);
162 let server = HipServer::new(hip_ctx);
163 ComputeClient::new(MutexComputeChannel::new(server), device_props, ())
164}
165
166impl Runtime for HipRuntime {
167 type Compiler = HipCompiler;
168 type Server = HipServer;
169 type Channel = MutexComputeChannel<HipServer>;
170 type Device = HipDevice;
171
172 fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
173 RUNTIME.client(device, move || {
174 create_client::<HipWmmaCompiler>(device, RuntimeOptions::default())
175 })
176 }
177
178 fn name(_client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str {
179 "hip"
180 }
181
182 fn require_array_lengths() -> bool {
183 true
184 }
185
186 fn supported_line_sizes() -> &'static [u8] {
187 &[8, 4, 2, 1]
188 }
189
190 fn max_cube_count() -> (u32, u32, u32) {
191 (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
192 }
193
194 fn device_id(device: &Self::Device) -> cubecl_core::DeviceId {
195 DeviceId::new(0, device.index as u32)
196 }
197}