1use std::{ffi::CStr, mem::MaybeUninit, str::FromStr};
2
3use cubecl_cpp::{
4 hip::HipDialect,
5 register_supported_types,
6 shared::{register_wmma_features, Architecture, CompilationOptions, CppCompiler, WmmaCompiler},
7};
8
9use cubecl_core::{
10 ir::{Elem, FloatKind},
11 AtomicFeature, Feature, MemoryConfiguration, Runtime,
12};
13use cubecl_hip_sys::HIP_SUCCESS;
14use cubecl_runtime::{
15 channel::MutexComputeChannel,
16 client::ComputeClient,
17 memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement},
18 storage::ComputeStorage,
19 ComputeRuntime, DeviceProperties,
20};
21
22use crate::{
23 compute::{HipContext, HipServer, HipStorage},
24 device::HipDevice,
25 HipWmmaCompiler,
26};
27
28#[derive(Default)]
30pub struct RuntimeOptions {
31 pub memory_config: MemoryConfiguration,
33}
34
/// Zero-sized marker type selecting the HIP backend.
///
/// All behaviour lives in its `Runtime` trait implementation.
#[derive(Debug)]
pub struct HipRuntime;
37
38static RUNTIME: ComputeRuntime<HipDevice, Server, MutexComputeChannel<Server>> =
39 ComputeRuntime::new();
40
41pub type HipCompiler = CppCompiler<HipDialect<HipWmmaCompiler>>;
42
43type Server = HipServer;
44type Channel = MutexComputeChannel<Server>;
45
46fn create_client<M: WmmaCompiler<HipDialect<M>>>(
47 device: &HipDevice,
48 options: RuntimeOptions,
49) -> ComputeClient<Server, Channel> {
50 #[allow(unused_assignments)]
51 let mut prop_warp_size = 0;
52 #[allow(unused_assignments)]
53 let mut prop_arch_name = "";
54 unsafe {
55 let mut ll_device_props = MaybeUninit::uninit();
56 let status = cubecl_hip_sys::hipGetDevicePropertiesR0600(
57 ll_device_props.as_mut_ptr(),
58 device.index as cubecl_hip_sys::hipDevice_t,
59 );
60 assert_eq!(status, HIP_SUCCESS, "Should get device properties");
61 let ll_device_props = ll_device_props.assume_init();
62 prop_warp_size = ll_device_props.warpSize;
63 prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr())
64 .to_str()
65 .unwrap();
66 };
67 let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name);
68 let arch = M::Architecture::from_str(normalized_arch_name).unwrap();
69 assert_eq!(prop_warp_size as u32, arch.warp_size());
70
71 unsafe {
72 let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t);
73 assert_eq!(
74 status, HIP_SUCCESS,
75 "Should set the default device for the current thread"
76 );
77 }
78
79 let stream = unsafe {
80 let mut stream: cubecl_hip_sys::hipStream_t = std::ptr::null_mut();
81 let stream_status = cubecl_hip_sys::hipStreamCreate(&mut stream);
82 assert_eq!(stream_status, HIP_SUCCESS, "Should create a stream");
83 stream
84 };
85
86 let max_memory = unsafe {
87 let free: usize = 0;
88 let total: usize = 0;
89 let status = cubecl_hip_sys::hipMemGetInfo(
90 &free as *const _ as *mut usize,
91 &total as *const _ as *mut usize,
92 );
93 assert_eq!(
94 status, HIP_SUCCESS,
95 "Should get the available memory of the device"
96 );
97 total
98 };
99 let storage = HipStorage::new(stream);
100 let mem_properties = MemoryDeviceProperties {
101 max_page_size: max_memory as u64 / 4,
102 alignment: HipStorage::ALIGNMENT,
103 };
104 let topology = HardwareProperties {
105 plane_size_min: prop_warp_size as u32,
106 plane_size_max: prop_warp_size as u32,
107 max_bindings: 1024,
110 };
111 let memory_management = MemoryManagement::from_configuration(
112 storage,
113 mem_properties.clone(),
114 options.memory_config,
115 );
116 let mut device_props = DeviceProperties::new(&[Feature::Plane], mem_properties, topology);
117 register_supported_types(&mut device_props);
118 device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F64)));
120 device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F16)));
121 device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16)));
122
123 device_props.register_feature(Feature::AtomicFloat(AtomicFeature::LoadStore));
124 device_props.register_feature(Feature::AtomicFloat(AtomicFeature::Add));
125
126 let supported_wmma_combinations = M::supported_wmma_combinations(&arch);
127 register_wmma_features(supported_wmma_combinations, &mut device_props);
128
129 let comp_opts = CompilationOptions {
130 warp_size: arch.warp_size(),
131 };
132 let hip_ctx = HipContext::new(memory_management, comp_opts, stream);
133 let server = HipServer::new(hip_ctx);
134 ComputeClient::new(MutexComputeChannel::new(server), device_props)
135}
136
137impl Runtime for HipRuntime {
138 type Compiler = HipCompiler;
139 type Server = HipServer;
140 type Channel = MutexComputeChannel<HipServer>;
141 type Device = HipDevice;
142
143 fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
144 RUNTIME.client(device, move || {
145 create_client::<HipWmmaCompiler>(device, RuntimeOptions::default())
146 })
147 }
148
149 fn name() -> &'static str {
150 "hip"
151 }
152
153 fn require_array_lengths() -> bool {
154 true
155 }
156
157 fn supported_line_sizes() -> &'static [u8] {
158 &[8, 4, 2, 1]
159 }
160
161 fn max_cube_count() -> (u32, u32, u32) {
162 (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
163 }
164
165 fn extension() -> &'static str {
166 "hip"
167 }
168}