1use crate::{
2 WmmaCompiler,
3 compute::{CudaServer, context::CudaContext, valid_strides},
4 device::CudaDevice,
5};
6use cubecl_common::{
7 device::{Device, DeviceState},
8 profile::TimingMethod,
9};
10use cubecl_core::{
11 CubeCount, CubeDim, MemoryConfiguration, Runtime,
12 ir::{
13 ContiguousElements, ElemType, FloatKind, MatrixLayout, MmaProperties, SemanticType,
14 StorageType, TargetProperties,
15 },
16 server::ServerUtilities,
17};
18use cubecl_cpp::{
19 DialectWmmaCompiler,
20 cuda::{CudaDialect, arch::CudaArchitecture, mma::contiguous_elements_cuda},
21 register_supported_types,
22 shared::{
23 CompilationOptions, CppCompiler, CppSupportedFeatures, register_mma_features,
24 register_scaled_mma_features, register_wmma_features,
25 },
26};
27use cubecl_runtime::{
28 DeviceProperties, Plane, Tma, TypeUsage,
29 client::ComputeClient,
30 logging::ServerLogger,
31 memory_management::{HardwareProperties, MemoryDeviceProperties},
32};
33use cudarc::driver::sys::{CUDA_VERSION, cuDeviceTotalMem_v2};
34use std::{mem::MaybeUninit, sync::Arc};
35
/// Options accepted when constructing the CUDA server (see `DeviceState::init`).
#[derive(Default)]
pub struct RuntimeOptions {
    // Memory-pool configuration forwarded to `CudaServer::new`.
    pub memory_config: MemoryConfiguration,
}
42
/// Marker type implementing [`Runtime`] for NVIDIA CUDA devices,
/// compiling kernels through the C++ CUDA dialect (`CudaCompiler`).
#[derive(Debug)]
pub struct CudaRuntime;
45
46impl DeviceState for CudaServer {
47 fn init(device_id: cubecl_common::device::DeviceId) -> Self {
48 let options = RuntimeOptions::default();
49 let device = CudaDevice::from_id(device_id);
50
51 cudarc::driver::result::init().unwrap();
53 let device_id = device.index as i32;
54 let device_ptr = cudarc::driver::result::device::get(device_id).unwrap();
55 let arch_major;
56 let arch_version = unsafe {
57 arch_major = cudarc::driver::result::device::get_attribute(
58 device_ptr,
59 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
60 )
61 .unwrap();
62 let minor = cudarc::driver::result::device::get_attribute(
63 device_ptr,
64 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
65 )
66 .unwrap();
67 arch_major * 10 + minor
68 } as u32;
69
70 let mem_alignment = 512;
74
75 let arch = CudaArchitecture {
77 version: arch_version,
78 };
79 let supported_wmma_combinations = WmmaCompiler::supported_wmma_combinations(&arch);
80 let supported_mma_combinations = WmmaCompiler::supported_mma_combinations(&arch);
81 let supported_scaled_mma_combinations =
82 WmmaCompiler::supported_scaled_mma_combinations(&arch);
83
84 let ctx = unsafe {
85 let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
86 cudarc::driver::result::ctx::set_current(ctx).unwrap();
87 ctx
88 };
89
90 let max_memory = unsafe {
91 let mut bytes = MaybeUninit::uninit();
92 cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr);
93 bytes.assume_init() as u64
94 };
95 let mem_properties = MemoryDeviceProperties {
96 max_page_size: max_memory / 4,
97 alignment: mem_alignment as u64,
98 };
99
100 let mut comp_opts = CompilationOptions {
101 supports_features: CppSupportedFeatures {
102 fast_math: true,
103 ..Default::default()
104 },
105 ..Default::default()
106 };
107
108 let hardware_props = unsafe {
109 use cudarc::driver::{result::device::get_attribute, sys::CUdevice_attribute::*};
110 let warp_size =
111 get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_WARP_SIZE).unwrap() as u32;
112 let max_shared = get_attribute(
113 device_ptr,
114 CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
115 )
116 .unwrap() as usize;
117 let max_threads = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
118 .unwrap() as u32;
119 let block_dim_x =
120 get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X).unwrap();
121 let block_dim_y =
122 get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y).unwrap();
123 let block_dim_z =
124 get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z).unwrap();
125 let max_cube_dim =
126 CubeDim::new_3d(block_dim_x as u32, block_dim_y as u32, block_dim_z as u32);
127
128 let grid_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X).unwrap();
129 let grid_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y).unwrap();
130 let grid_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z).unwrap();
131 let max_cube_count =
132 CubeCount::new_3d(grid_dim_x as u32, grid_dim_y as u32, grid_dim_z as u32);
133
134 let num_streaming_multiprocessors = Some(
135 get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT).unwrap() as u32,
136 );
137 let num_tensor_cores = tensor_cores_per_sm(arch_version);
138
139 comp_opts.warp_size = warp_size;
140
141 HardwareProperties {
142 load_width: 128,
143 plane_size_min: warp_size,
144 plane_size_max: warp_size,
145 max_bindings: crate::device::CUDA_MAX_BINDINGS,
146 max_shared_memory_size: max_shared,
147 max_cube_count,
148 max_units_per_cube: max_threads,
149 max_cube_dim,
150 num_streaming_multiprocessors,
151 num_tensor_cores,
152 min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
153 None
154 } else {
155 Some(8)
156 },
157 }
158 };
159
160 let mut device_props = DeviceProperties::new(
161 Default::default(),
162 mem_properties.clone(),
163 hardware_props,
164 TimingMethod::System,
165 );
166 register_supported_types(&mut device_props);
167 device_props.register_type_usage(ElemType::Float(FloatKind::TF32), TypeUsage::Conversion);
168 if arch_version >= 60 {
169 device_props.register_type_usage(
170 StorageType::Atomic(ElemType::Float(FloatKind::F64)),
171 TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
172 );
173 }
174 if arch_version >= 70 {
175 device_props.register_type_usage(
176 StorageType::Atomic(ElemType::Float(FloatKind::F16)),
177 TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
178 );
179 device_props.register_semantic_type(SemanticType::Pipeline);
180 device_props.register_semantic_type(SemanticType::Barrier);
181 device_props.features.plane.insert(Plane::Sync);
182
183 comp_opts.supports_features.grid_constants = true;
184 }
185
186 if arch_version >= 75 {
187 device_props
188 .features
189 .ldmatrix
190 .insert(ElemType::Float(FloatKind::F16).into());
191 device_props
192 .features
193 .ldmatrix
194 .insert(ElemType::Float(FloatKind::BF16).into());
195 comp_opts.supports_features.fast_tanh = CUDA_VERSION >= 12080;
196 }
197
198 if arch_version >= 89 {
204 device_props.register_type_usage(
205 ElemType::Float(FloatKind::E4M3),
206 TypeUsage::Conversion | TypeUsage::Buffer,
207 );
208 device_props.register_type_usage(
209 ElemType::Float(FloatKind::E5M2),
210 TypeUsage::Conversion | TypeUsage::Buffer,
211 );
212 }
213 if arch_version >= 90 {
214 device_props.features.tma.insert(Tma::Base);
215 device_props.register_semantic_type(SemanticType::TensorMap);
216 device_props.features.cube_cluster = true;
217 comp_opts.supports_features.clusters = true;
218 comp_opts.supports_features.elect_sync = true;
219 device_props
220 .features
221 .stmatrix
222 .insert(ElemType::Float(FloatKind::F16).into());
223 device_props
224 .features
225 .stmatrix
226 .insert(ElemType::Float(FloatKind::BF16).into());
227 }
228
229 if arch_version >= 100 {
230 device_props.features.tma.insert(Tma::Im2colWide);
231 }
236
237 if arch_major == 10 || arch_major == 11 || arch_major == 12 {
241 device_props
242 .register_type_usage(ElemType::Float(FloatKind::E2M1), TypeUsage::Conversion);
243 device_props.register_type_usage(
244 StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2),
245 TypeUsage::Conversion | TypeUsage::Buffer,
246 );
247 device_props.register_type_usage(
248 ElemType::Float(FloatKind::E2M3),
249 TypeUsage::Conversion | TypeUsage::Buffer,
250 );
251 device_props.register_type_usage(
252 ElemType::Float(FloatKind::E3M2),
253 TypeUsage::Conversion | TypeUsage::Buffer,
254 );
255 device_props.register_type_usage(
256 ElemType::Float(FloatKind::UE8M0),
257 TypeUsage::Conversion | TypeUsage::Buffer,
258 );
259
260 if CUDA_VERSION >= 12080 {
261 device_props.features.tma.insert(Tma::SwizzleAtomicity);
262 }
263 }
264
265 device_props.features.dynamic_line_size = true;
266 device_props.features.alignment = true;
267 device_props.features.plane.insert(Plane::Ops);
268
269 register_wmma_features(supported_wmma_combinations, &mut device_props);
270 register_mma_features(supported_mma_combinations, &mut device_props);
271 register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);
272
273 let cuda_ctx = CudaContext::new(comp_opts, ctx, arch);
274 let logger = Arc::new(ServerLogger::default());
275 let utilities = ServerUtilities::new(device_props, logger, ());
276
277 CudaServer::new(
278 cuda_ctx,
279 mem_properties,
280 options.memory_config,
281 mem_alignment,
282 device_id,
283 utilities,
284 )
285 }
286}
287
/// The C++ kernel compiler specialized for the CUDA dialect and its WMMA backend.
pub type CudaCompiler = CppCompiler<CudaDialect<WmmaCompiler>>;
289
/// Number of tensor cores per streaming multiprocessor for the given compute
/// capability (`major * 10 + minor`), or `None` when the architecture is not
/// in this table (no tensor cores, or unknown version).
fn tensor_cores_per_sm(version: u32) -> Option<u32> {
    // Volta/Turing SMs carry 8 tensor cores; Ampere through sm_100 carry 4.
    const EIGHT_PER_SM: &[u32] = &[70, 75];
    const FOUR_PER_SM: &[u32] = &[80, 86, 89, 90, 91, 92, 100];

    if EIGHT_PER_SM.contains(&version) {
        Some(8)
    } else if FOUR_PER_SM.contains(&version) {
        Some(4)
    } else {
        None
    }
}
297
impl Runtime for CudaRuntime {
    type Compiler = CudaCompiler;
    type Server = CudaServer;
    type Device = CudaDevice;

    fn client(device: &Self::Device) -> ComputeClient<Self> {
        ComputeClient::load(device)
    }

    fn name(_client: &ComputeClient<Self>) -> &'static str {
        "cuda"
    }

    /// Kernels on this runtime require explicit array lengths as metadata.
    fn require_array_lengths() -> bool {
        true
    }

    /// Supported vectorization widths in elements, widest first.
    fn supported_line_sizes() -> &'static [u8] {
        &[16, 8, 4, 2, 1]
    }

    /// Maximum launch grid: x fits in 31 bits, y and z in 16 bits.
    fn max_cube_count() -> (u32, u32, u32) {
        (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
    }

    /// A tensor is readable when its strides form a valid layout for `shape`
    /// (delegated to `valid_strides`).
    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
        valid_strides(shape, strides)
    }

    /// Register-level layout of matrix-multiply-accumulate fragments on CUDA.
    fn target_properties() -> TargetProperties {
        TargetProperties {
            mma: MmaProperties {
                register_size_bits: 32,
                const_plane_size: 32,
                register_layout_a: MatrixLayout::RowMajor,
                register_layout_b: MatrixLayout::ColMajor,
                register_layout_acc: MatrixLayout::RowMajor,
                register_duplication_a: 1,
                register_duplication_b: 1,
                register_duplication_acc: 1,
                contiguous_elements: ContiguousElements::new(contiguous_elements_cuda),
            },
        }
    }
}