1use crate::{
2 WmmaCompiler,
3 compute::{CudaServer, context::CudaContext},
4 device::CudaDevice,
5};
6use cubecl_common::{
7 device::{Device, DeviceState},
8 profile::TimingMethod,
9};
10use cubecl_core::{
11 MemoryConfiguration, Runtime,
12 ir::{
13 BarrierLevel, ContiguousElements, DeviceProperties, ElemType, FloatKind,
14 HardwareProperties, LineSize, MatrixLayout, MemoryDeviceProperties, MmaProperties,
15 OpaqueType, SemanticType, StorageType, TargetProperties,
16 features::{Plane, Tma, TypeUsage},
17 },
18 server::ServerUtilities,
19 zspace::striding::has_pitched_row_major_strides,
20};
21use cubecl_cpp::{
22 ComputeKernel, DialectWmmaCompiler,
23 cuda::{CudaDialect, arch::CudaArchitecture, mma::contiguous_elements_cuda},
24 register_supported_types,
25 shared::{
26 CompilationOptions, CppCompiler, CppSupportedFeatures, register_mma_features,
27 register_scaled_mma_features, register_wmma_features,
28 },
29};
30use cubecl_runtime::{client::ComputeClient, logging::ServerLogger};
31use cudarc::driver::sys::{CUDA_VERSION, cuDeviceTotalMem_v2};
32use std::{mem::MaybeUninit, sync::Arc};
33
/// User-configurable options for constructing the CUDA runtime.
#[derive(Default)]
pub struct RuntimeOptions {
    /// Memory allocator/pool configuration forwarded to the server.
    pub memory_config: MemoryConfiguration,
}
40
/// Marker type implementing [`Runtime`] for NVIDIA GPUs via the CUDA driver API.
#[derive(Debug)]
pub struct CudaRuntime;
43
impl DeviceState for CudaServer {
    /// Initializes a CUDA server for `device_id`: queries the device's compute
    /// capability and hardware limits, retains its primary context, and
    /// registers every type/feature combination the architecture supports.
    ///
    /// # Panics
    /// Panics (`unwrap`) if driver initialization, device lookup, or any
    /// attribute query fails.
    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
        let options = RuntimeOptions::default();
        let device = CudaDevice::from_id(device_id);

        cudarc::driver::result::init().unwrap();
        // Shadow the cubecl `DeviceId` with the raw CUDA device ordinal.
        let device_id = device.index as i32;
        let device_ptr = cudarc::driver::result::device::get(device_id).unwrap();
        let arch_major;
        // Compute capability encoded as `major * 10 + minor` (e.g. 90 for sm_90).
        // All feature gating below compares against this value.
        let arch_version = unsafe {
            arch_major = cudarc::driver::result::device::get_attribute(
                device_ptr,
                cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
            )
            .unwrap();
            let minor = cudarc::driver::result::device::get_attribute(
                device_ptr,
                cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
            )
            .unwrap();
            arch_major * 10 + minor
        } as u32;

        // Fixed buffer alignment used for both the allocator and the memory
        // properties reported to kernels.
        let mem_alignment = 512;

        let arch = CudaArchitecture {
            version: arch_version,
        };
        // Ask the dialect compiler which matmul instruction shapes/types this
        // architecture supports; registered into the device properties below.
        let supported_wmma_combinations = WmmaCompiler::supported_wmma_combinations(&arch);
        let supported_mma_combinations = WmmaCompiler::supported_mma_combinations(&arch);
        let supported_scaled_mma_combinations =
            WmmaCompiler::supported_scaled_mma_combinations(&arch);

        // Retain the device's primary context and make it current on this thread.
        let ctx = unsafe {
            let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
            cudarc::driver::result::ctx::set_current(ctx).unwrap();
            ctx
        };

        // Total device memory in bytes.
        // NOTE(review): the `CUresult` returned by `cuDeviceTotalMem_v2` is
        // ignored; if the call failed, `bytes.assume_init()` would read an
        // uninitialized value (UB). Consider checking the result like the
        // other driver calls above.
        let max_memory = unsafe {
            let mut bytes = MaybeUninit::uninit();
            cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr);
            bytes.assume_init() as u64
        };
        let mem_properties = MemoryDeviceProperties {
            // Cap a single memory page at a quarter of total device memory.
            max_page_size: max_memory / 4,
            alignment: mem_alignment as u64,
        };

        // Compilation options are refined as features are detected below.
        let mut comp_opts = CompilationOptions {
            supports_features: CppSupportedFeatures {
                fast_math: true,
                ..Default::default()
            },
            ..Default::default()
        };

        // Query the per-device hardware limits the launcher needs.
        let hardware_props = unsafe {
            use cudarc::driver::{result::device::get_attribute, sys::CUdevice_attribute::*};
            let warp_size =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_WARP_SIZE).unwrap() as u32;
            // Opt-in maximum shared memory per block (larger than the default limit).
            let max_shared = get_attribute(
                device_ptr,
                CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
            )
            .unwrap() as usize;
            let max_threads = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
                .unwrap() as u32;
            let block_dim_x =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X).unwrap();
            let block_dim_y =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y).unwrap();
            let block_dim_z =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z).unwrap();
            let max_cube_dim = (block_dim_x as u32, block_dim_y as u32, block_dim_z as u32);

            let grid_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X).unwrap();
            let grid_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y).unwrap();
            let grid_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z).unwrap();
            let max_cube_count = (grid_dim_x as u32, grid_dim_y as u32, grid_dim_z as u32);

            let num_streaming_multiprocessors = Some(
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT).unwrap() as u32,
            );
            let num_tensor_cores = tensor_cores_per_sm(arch_version);

            // The codegen needs the actual warp size of this device.
            comp_opts.warp_size = warp_size;

            HardwareProperties {
                load_width: 128,
                // CUDA warps have a fixed size, so min == max.
                plane_size_min: warp_size,
                plane_size_max: warp_size,
                max_bindings: crate::device::CUDA_MAX_BINDINGS,
                max_shared_memory_size: max_shared,
                max_cube_count,
                max_units_per_cube: max_threads,
                max_cube_dim,
                num_streaming_multiprocessors,
                num_tensor_cores,
                // 8 is the smallest wmma tile dimension when any combination
                // is supported at all.
                min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
                    None
                } else {
                    Some(8)
                },
                num_cpu_cores: None,
                max_line_size: LineSize::MAX,
            }
        };

        let mut device_props = DeviceProperties::new(
            Default::default(),
            mem_properties.clone(),
            hardware_props,
            TimingMethod::System,
        );
        register_supported_types(&mut device_props);
        // TF32 is only usable as a conversion target, not a storage type.
        device_props.register_type_usage(ElemType::Float(FloatKind::TF32), TypeUsage::Conversion);
        // sm_60+ (Pascal): f64 atomics.
        if arch_version >= 60 {
            device_props.register_type_usage(
                StorageType::Atomic(ElemType::Float(FloatKind::F64)),
                TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
            );
        }
        // sm_70+ (Volta): f16 atomics, pipelines, barriers, warp sync.
        if arch_version >= 70 {
            device_props.register_type_usage(
                StorageType::Atomic(ElemType::Float(FloatKind::F16)),
                TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
            );
            device_props.register_semantic_type(SemanticType::Pipeline);
            device_props
                .register_type_usage(OpaqueType::Barrier(BarrierLevel::Unit), TypeUsage::Buffer);
            device_props
                .register_type_usage(OpaqueType::Barrier(BarrierLevel::Cube), TypeUsage::Buffer);
            device_props.features.plane.insert(Plane::Sync);
            comp_opts.supports_features.grid_constants = true;
        }

        // sm_75+ (Turing): ldmatrix for f16/bf16.
        if arch_version >= 75 {
            device_props
                .features
                .ldmatrix
                .insert(ElemType::Float(FloatKind::F16).into());
            device_props
                .features
                .ldmatrix
                .insert(ElemType::Float(FloatKind::BF16).into());
            // Fast tanh intrinsic additionally requires a recent toolkit
            // (CUDA_VERSION >= 12080, i.e. CUDA 12.8).
            comp_opts.supports_features.fast_tanh = CUDA_VERSION >= 12080;
        }

        // sm_80+ (Ampere): async global->shared copies.
        if arch_version >= 80 {
            device_props.features.copy_async = true;
        }

        // sm_89+: fp8 (e4m3/e5m2) as buffer and conversion types.
        if arch_version >= 89 {
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E4M3),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E5M2),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
        }
        // sm_90+ (Hopper): TMA, thread-block clusters, elect_sync, stmatrix.
        if arch_version >= 90 {
            device_props.features.tma.insert(Tma::Base);
            device_props.register_semantic_type(SemanticType::TensorMap);
            device_props.features.cube_cluster = true;
            comp_opts.supports_features.clusters = true;
            comp_opts.supports_features.elect_sync = true;
            device_props
                .features
                .stmatrix
                .insert(ElemType::Float(FloatKind::F16).into());
            device_props
                .features
                .stmatrix
                .insert(ElemType::Float(FloatKind::BF16).into());
        }

        // sm_100+: wide im2col TMA loads.
        if arch_version >= 100 {
            device_props.features.tma.insert(Tma::Im2colWide);
        }

        // Compute capability major 10/11/12: sub-byte float formats
        // (fp4/fp6 and the ue8m0 scale type) used by block-scaled mma.
        if arch_major == 10 || arch_major == 11 || arch_major == 12 {
            // E2M1 alone is conversion-only; buffers require the packed pair.
            device_props
                .register_type_usage(ElemType::Float(FloatKind::E2M1), TypeUsage::Conversion);
            device_props.register_type_usage(
                StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E2M3),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E3M2),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::UE8M0),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );

            // Swizzle atomicity also needs toolkit support (CUDA 12.8+).
            if CUDA_VERSION >= 12080 {
                device_props.features.tma.insert(Tma::SwizzleAtomicity);
            }
        }

        // Baseline features available on every supported architecture.
        device_props.features.dynamic_line_size = true;
        device_props.features.alignment = true;
        device_props.features.plane.insert(Plane::Ops);
        device_props
            .features
            .plane
            .insert(Plane::NonUniformControlFlow);

        register_wmma_features(supported_wmma_combinations, &mut device_props);
        register_mma_features(supported_mma_combinations, &mut device_props);
        register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);

        let cuda_ctx = CudaContext::new(comp_opts, device_props.clone(), ctx, arch);
        let logger = Arc::new(ServerLogger::default());
        let utilities = ServerUtilities::new(device_props, logger, ());

        CudaServer::new(
            cuda_ctx,
            mem_properties,
            options.memory_config,
            mem_alignment,
            device_id,
            utilities,
        )
    }
}
295
/// The C++ compiler specialized for the CUDA dialect with wmma support.
pub type CudaCompiler = CppCompiler<CudaDialect<WmmaCompiler>>;
/// Compiled C++ compute kernel for the CUDA dialect.
pub type CudaComputeKernel = ComputeKernel<CudaDialect<WmmaCompiler>>;
298
/// Number of tensor cores per streaming multiprocessor for the given compute
/// capability (encoded as `major * 10 + minor`), or `None` when the
/// architecture is not in this table.
fn tensor_cores_per_sm(version: u32) -> Option<u32> {
    // sm_70/sm_75 expose 8 tensor cores per SM; the sm_80..sm_100 family
    // exposes 4. Anything else is unknown here.
    if matches!(version, 70 | 75) {
        Some(8)
    } else if matches!(version, 80 | 86 | 89 | 90 | 91 | 92 | 100) {
        Some(4)
    } else {
        None
    }
}
306
307impl Runtime for CudaRuntime {
308 type Compiler = CudaCompiler;
309 type Server = CudaServer;
310 type Device = CudaDevice;
311
312 fn client(device: &Self::Device) -> ComputeClient<Self> {
313 ComputeClient::load(device)
314 }
315
316 fn name(_client: &ComputeClient<Self>) -> &'static str {
317 "cuda"
318 }
319
320 fn require_array_lengths() -> bool {
321 true
322 }
323
324 fn max_cube_count() -> (u32, u32, u32) {
325 (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
326 }
327
328 fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
329 has_pitched_row_major_strides(shape, strides)
330 }
331
332 fn target_properties() -> TargetProperties {
333 TargetProperties {
334 mma: MmaProperties {
335 register_size_bits: 32,
336 const_plane_size: 32,
337 register_layout_a: MatrixLayout::RowMajor,
338 register_layout_b: MatrixLayout::ColMajor,
339 register_layout_acc: MatrixLayout::RowMajor,
340 register_duplication_a: 1,
341 register_duplication_b: 1,
342 register_duplication_acc: 1,
343 contiguous_elements: ContiguousElements::new(contiguous_elements_cuda),
344 },
345 }
346 }
347}