1use crate::{
2 WmmaCompiler,
3 compute::{CudaServer, context::CudaContext},
4 device::CudaDevice,
5};
6use cubecl_common::{
7 device::{Device, DeviceService},
8 profile::TimingMethod,
9};
10use cubecl_core::{
11 MemoryConfiguration, Runtime,
12 device::{DeviceId, ServerUtilitiesHandle},
13 ir::{
14 BarrierLevel, ContiguousElements, DeviceProperties, ElemType, FloatKind,
15 HardwareProperties, MatrixLayout, MemoryDeviceProperties, MmaProperties, OpaqueType,
16 SemanticType, StorageType, TargetProperties, Type, VectorSize,
17 features::{AtomicUsage, Plane, Tma, TypeUsage},
18 },
19 server::ServerUtilities,
20 zspace::{Shape, Strides, striding::has_pitched_row_major_strides},
21};
22use cubecl_cpp::{
23 ComputeKernel, DialectWmmaCompiler,
24 cuda::{CudaDialect, arch::CudaArchitecture, mma::contiguous_elements_cuda},
25 register_supported_types,
26 shared::{
27 CompilationOptions, CppCompiler, CppSupportedFeatures, register_mma_features,
28 register_scaled_mma_features, register_wmma_features,
29 },
30};
31use cubecl_runtime::{
32 allocator::PitchedMemoryLayoutPolicy, client::ComputeClient, logging::ServerLogger,
33};
34use cudarc::driver::sys::{CUDA_VERSION, cuDeviceTotalMem_v2};
35use std::{mem::MaybeUninit, sync::Arc};
36
/// Options configuring the CUDA runtime when a server is initialized.
#[derive(Default)]
pub struct RuntimeOptions {
    /// Memory configuration forwarded to the server's memory management.
    pub memory_config: MemoryConfiguration,
}
43
/// Marker type implementing [`Runtime`] for NVIDIA CUDA devices,
/// compiling kernels through the CUDA C++ dialect (see [`CudaCompiler`]).
#[derive(Debug, Clone)]
pub struct CudaRuntime;
46
47impl DeviceService for CudaServer {
48 fn init(device_id: cubecl_common::device::DeviceId) -> Self {
49 let options = RuntimeOptions::default();
50 let device = CudaDevice::from_id(device_id);
51
52 cudarc::driver::result::init().unwrap();
54 let device_index = device.index as i32;
55 let device_ptr = cudarc::driver::result::device::get(device_index).unwrap();
56 let arch_major;
57 let arch_version = unsafe {
60 arch_major = cudarc::driver::result::device::get_attribute(
61 device_ptr,
62 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
63 )
64 .unwrap();
65 let minor = cudarc::driver::result::device::get_attribute(
66 device_ptr,
67 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
68 )
69 .unwrap();
70 arch_major * 10 + minor
71 } as u32;
72
73 let mem_alignment = 512;
77
78 let arch = CudaArchitecture {
80 version: arch_version,
81 };
82 let supported_wmma_combinations = WmmaCompiler::supported_wmma_combinations(&arch);
83 let supported_mma_combinations = WmmaCompiler::supported_mma_combinations(&arch);
84 let supported_scaled_mma_combinations =
85 WmmaCompiler::supported_scaled_mma_combinations(&arch);
86
87 let ctx = unsafe {
90 let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
91 cudarc::driver::result::ctx::set_current(ctx).unwrap();
92 ctx
93 };
94
95 let max_memory = unsafe {
98 let mut bytes = MaybeUninit::uninit();
99 cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr);
100 bytes.assume_init() as u64
101 };
102 let mem_properties = MemoryDeviceProperties {
103 max_page_size: max_memory / 4,
104 alignment: mem_alignment as u64,
105 };
106
107 let mut comp_opts = CompilationOptions {
108 supports_features: CppSupportedFeatures {
109 fast_math: true,
110 ..Default::default()
111 },
112 ..Default::default()
113 };
114
115 let hardware_props = unsafe {
118 use cudarc::driver::{result::device::get_attribute, sys::CUdevice_attribute::*};
119 let warp_size =
120 get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_WARP_SIZE).unwrap() as u32;
121 let max_shared = get_attribute(
122 device_ptr,
123 CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
124 )
125 .unwrap() as usize;
126 let max_threads = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
127 .unwrap() as u32;
128 let block_dim_x =
129 get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X).unwrap();
130 let block_dim_y =
131 get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y).unwrap();
132 let block_dim_z =
133 get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z).unwrap();
134 let max_cube_dim = (block_dim_x as u32, block_dim_y as u32, block_dim_z as u32);
135
136 let grid_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X).unwrap();
137 let grid_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y).unwrap();
138 let grid_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z).unwrap();
139 let max_cube_count = (grid_dim_x as u32, grid_dim_y as u32, grid_dim_z as u32);
140
141 let num_streaming_multiprocessors = Some(
142 get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT).unwrap() as u32,
143 );
144 let num_tensor_cores = tensor_cores_per_sm(arch_version);
145
146 comp_opts.warp_size = warp_size;
147
148 HardwareProperties {
149 load_width: 128,
150 plane_size_min: warp_size,
151 plane_size_max: warp_size,
152 max_bindings: crate::device::CUDA_MAX_BINDINGS,
153 max_shared_memory_size: max_shared,
154 max_cube_count,
155 max_units_per_cube: max_threads,
156 max_cube_dim,
157 num_streaming_multiprocessors,
158 num_tensor_cores,
159 min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
160 None
161 } else {
162 Some(8)
163 },
164 num_cpu_cores: None,
165 max_vector_size: VectorSize::MAX,
166 }
167 };
168
169 let mut device_props = DeviceProperties::new(
170 Default::default(),
171 mem_properties.clone(),
172 hardware_props,
173 TimingMethod::System,
174 );
175 register_supported_types(&mut device_props);
176 device_props.register_type_usage(ElemType::Float(FloatKind::TF32), TypeUsage::Conversion);
177 if arch_version >= 60 {
178 device_props.register_atomic_type_usage(
179 Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F64))),
180 AtomicUsage::Add | AtomicUsage::LoadStore,
181 );
182 }
183 if arch_version >= 70 {
184 device_props.register_atomic_type_usage(
185 Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F16))),
186 AtomicUsage::Add | AtomicUsage::LoadStore,
187 );
188 device_props.register_atomic_type_usage(
189 Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F16))).with_vector_size(2),
190 AtomicUsage::Add | AtomicUsage::LoadStore,
191 );
192 device_props.register_semantic_type(SemanticType::Pipeline);
193 device_props
194 .register_type_usage(OpaqueType::Barrier(BarrierLevel::Unit), TypeUsage::Buffer);
195 device_props
196 .register_type_usage(OpaqueType::Barrier(BarrierLevel::Cube), TypeUsage::Buffer);
197 device_props.features.plane.insert(Plane::Sync);
198 comp_opts.supports_features.grid_constants = true;
199 }
200
201 if arch_version >= 75 {
202 device_props
203 .features
204 .matmul
205 .ldmatrix
206 .insert(ElemType::Float(FloatKind::F16).into());
207 device_props
208 .features
209 .matmul
210 .ldmatrix
211 .insert(ElemType::Float(FloatKind::BF16).into());
212 comp_opts.supports_features.fast_tanh = CUDA_VERSION >= 12080;
213 }
214
215 if arch_version >= 80 {
216 device_props.features.copy_async = true;
217 }
218
219 if arch_version >= 89 {
225 device_props.register_type_usage(
226 ElemType::Float(FloatKind::E4M3),
227 TypeUsage::Conversion | TypeUsage::Buffer,
228 );
229 device_props.register_type_usage(
230 ElemType::Float(FloatKind::E5M2),
231 TypeUsage::Conversion | TypeUsage::Buffer,
232 );
233 }
234 if arch_version >= 90 {
235 device_props.features.tma.insert(Tma::Base);
236 device_props.register_semantic_type(SemanticType::TensorMap);
237 device_props.features.cube_cluster = true;
238 comp_opts.supports_features.clusters = true;
239 comp_opts.supports_features.elect_sync = true;
240 device_props
241 .features
242 .matmul
243 .stmatrix
244 .insert(ElemType::Float(FloatKind::F16).into());
245 device_props
246 .features
247 .matmul
248 .stmatrix
249 .insert(ElemType::Float(FloatKind::BF16).into());
250 device_props.register_atomic_type_usage(
251 Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F32))).with_vector_size(2),
252 AtomicUsage::LoadStore | AtomicUsage::Add,
253 );
254 device_props.register_atomic_type_usage(
255 Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F32))).with_vector_size(4),
256 AtomicUsage::LoadStore | AtomicUsage::Add,
257 );
258 }
259
260 if arch_version >= 100 {
261 device_props.features.tma.insert(Tma::Im2colWide);
262 }
267
268 if arch_major == 10 || arch_major == 11 || arch_major == 12 {
272 device_props
273 .register_type_usage(ElemType::Float(FloatKind::E2M1), TypeUsage::Conversion);
274 device_props.register_type_usage(
275 StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2),
276 TypeUsage::Conversion | TypeUsage::Buffer,
277 );
278 device_props.register_type_usage(
279 ElemType::Float(FloatKind::E2M3),
280 TypeUsage::Conversion | TypeUsage::Buffer,
281 );
282 device_props.register_type_usage(
283 ElemType::Float(FloatKind::E3M2),
284 TypeUsage::Conversion | TypeUsage::Buffer,
285 );
286 device_props.register_type_usage(
287 ElemType::Float(FloatKind::UE8M0),
288 TypeUsage::Conversion | TypeUsage::Buffer,
289 );
290
291 if CUDA_VERSION >= 12080 {
292 device_props.features.tma.insert(Tma::SwizzleAtomicity);
293 }
294 }
295
296 device_props.features.memory_reinterpret = true;
297 device_props.features.alignment = true;
298 device_props.features.plane.insert(Plane::Ops);
299 device_props
300 .features
301 .plane
302 .insert(Plane::NonUniformControlFlow);
303
304 register_wmma_features(supported_wmma_combinations, &mut device_props);
305 register_mma_features(supported_mma_combinations, &mut device_props);
306 register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);
307
308 let cuda_ctx = CudaContext::new(comp_opts, device_props.clone(), ctx, arch);
309 let logger = Arc::new(ServerLogger::default());
310 let policy = PitchedMemoryLayoutPolicy::new(device_props.memory.alignment as usize);
311 let utilities = ServerUtilities::new(device_props, logger, (), policy);
312
313 CudaServer::new(
314 cuda_ctx,
315 mem_properties,
316 options.memory_config,
317 mem_alignment,
318 device_id,
319 utilities,
320 )
321 }
322
323 fn utilities(&self) -> ServerUtilitiesHandle {
324 self.utilities() as ServerUtilitiesHandle
325 }
326}
327
/// The C++ compiler specialized for the CUDA dialect.
pub type CudaCompiler = CppCompiler<CudaDialect<WmmaCompiler>>;
/// A compiled kernel targeting the CUDA dialect.
pub type CudaComputeKernel = ComputeKernel<CudaDialect<WmmaCompiler>>;
330
/// Number of tensor cores per streaming multiprocessor for the given compute
/// capability (`major * 10 + minor`), or `None` when the architecture is not
/// in the known list.
fn tensor_cores_per_sm(version: u32) -> Option<u32> {
    // Architectures with 8 tensor cores per SM (Volta/Turing generations).
    const EIGHT_PER_SM: [u32; 2] = [70, 75];
    // Architectures with 4 tensor cores per SM (Ampere and newer).
    const FOUR_PER_SM: [u32; 7] = [80, 86, 89, 90, 91, 92, 100];

    if EIGHT_PER_SM.contains(&version) {
        Some(8)
    } else if FOUR_PER_SM.contains(&version) {
        Some(4)
    } else {
        None
    }
}
338
339impl Runtime for CudaRuntime {
340 type Compiler = CudaCompiler;
341 type Server = CudaServer;
342 type Device = CudaDevice;
343
344 fn client(device: &Self::Device) -> ComputeClient<Self> {
345 ComputeClient::load(device)
346 }
347
348 fn name(_client: &ComputeClient<Self>) -> &'static str {
349 "cuda"
350 }
351
352 fn require_array_lengths() -> bool {
353 true
354 }
355
356 fn max_cube_count() -> (u32, u32, u32) {
357 (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
358 }
359
360 fn can_read_tensor(shape: &Shape, strides: &Strides) -> bool {
361 has_pitched_row_major_strides(shape, strides)
362 }
363
364 fn target_properties() -> TargetProperties {
365 TargetProperties {
366 mma: MmaProperties {
367 register_size_bits: 32,
368 const_plane_size: 32,
369 register_layout_a: MatrixLayout::RowMajor,
370 register_layout_b: MatrixLayout::ColMajor,
371 register_layout_acc: MatrixLayout::RowMajor,
372 register_duplication_a: 1,
373 register_duplication_b: 1,
374 register_duplication_acc: 1,
375 contiguous_elements: ContiguousElements::new(contiguous_elements_cuda),
376 },
377 }
378 }
379
380 fn enumerate_devices(
381 _: u16,
382 _: &<Self::Server as cubecl_core::server::ComputeServer>::Info,
383 ) -> Vec<cubecl_core::device::DeviceId> {
384 let count = cudarc::driver::CudaContext::device_count().unwrap_or(0) as usize;
385 (0..count)
386 .map(|i| DeviceId {
387 type_id: 0,
388 index_id: i as u16,
389 })
390 .collect()
391 }
392}