Skip to main content

cubecl_cuda/
runtime.rs

1use crate::{
2    WmmaCompiler,
3    compute::{CudaServer, context::CudaContext},
4    device::CudaDevice,
5};
6use cubecl_common::{
7    device::{Device, DeviceService},
8    profile::TimingMethod,
9};
10use cubecl_core::{
11    MemoryConfiguration, Runtime,
12    device::{DeviceId, ServerUtilitiesHandle},
13    ir::{
14        BarrierLevel, ContiguousElements, DeviceProperties, ElemType, FloatKind,
15        HardwareProperties, MatrixLayout, MemoryDeviceProperties, MmaProperties, OpaqueType,
16        SemanticType, StorageType, TargetProperties, Type, VectorSize,
17        features::{AtomicUsage, Plane, Tma, TypeUsage},
18    },
19    server::ServerUtilities,
20    zspace::{Shape, Strides, striding::has_pitched_row_major_strides},
21};
22use cubecl_cpp::{
23    ComputeKernel, DialectWmmaCompiler,
24    cuda::{CudaDialect, arch::CudaArchitecture, mma::contiguous_elements_cuda},
25    register_supported_types,
26    shared::{
27        CompilationOptions, CppCompiler, CppSupportedFeatures, register_mma_features,
28        register_scaled_mma_features, register_wmma_features,
29    },
30};
31use cubecl_runtime::{
32    allocator::PitchedMemoryLayoutPolicy, client::ComputeClient, logging::ServerLogger,
33};
34use cudarc::driver::sys::{CUDA_VERSION, cuDeviceTotalMem_v2};
35use std::{mem::MaybeUninit, sync::Arc};
36
37/// Options configuring the CUDA runtime.
38#[derive(Default)]
39pub struct RuntimeOptions {
40    /// Configures the memory management.
41    pub memory_config: MemoryConfiguration,
42}
43
/// Zero-sized marker type selecting the CUDA backend.
///
/// It carries no state; all behavior lives in its [`Runtime`] implementation.
#[derive(Debug, Clone)]
pub struct CudaRuntime;
46
47impl DeviceService for CudaServer {
48    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
49        let options = RuntimeOptions::default();
50        let device = CudaDevice::from_id(device_id);
51
52        // To get the supported WMMA features, and memory properties, we have to initialize the server immediately.
53        cudarc::driver::result::init().unwrap();
54        let device_index = device.index as i32;
55        let device_ptr = cudarc::driver::result::device::get(device_index).unwrap();
56        let arch_major;
57        // SAFETY: Calling CUDA driver FFI to query compute capability attributes.
58        // `device_ptr` is a valid device handle obtained from `cudarc::driver::result::device::get`.
59        let arch_version = unsafe {
60            arch_major = cudarc::driver::result::device::get_attribute(
61            device_ptr,
62            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
63        )
64        .unwrap();
65            let minor = cudarc::driver::result::device::get_attribute(
66            device_ptr,
67            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
68        )
69        .unwrap();
70            arch_major * 10 + minor
71        } as u32;
72
73        // This is the alignment returned by `cuMallocPitched`, so it's the one considered optimal
74        // for row alignment by CUDA. This hasn't changed since at least the GTX 700 series.
75        // Querying texture row align is a heuristic, but also not guaranteed to be the same.
76        let mem_alignment = 512;
77
78        // Ask the wmma compiler for its supported combinations
79        let arch = CudaArchitecture {
80            version: arch_version,
81        };
82        let supported_wmma_combinations = WmmaCompiler::supported_wmma_combinations(&arch);
83        let supported_mma_combinations = WmmaCompiler::supported_mma_combinations(&arch);
84        let supported_scaled_mma_combinations =
85            WmmaCompiler::supported_scaled_mma_combinations(&arch);
86
87        // SAFETY: `device_ptr` is a valid CUDA device. `primary_ctx::retain` returns the
88        // primary context which is then set as current for the calling thread.
89        let ctx = unsafe {
90            let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
91            cudarc::driver::result::ctx::set_current(ctx).unwrap();
92            ctx
93        };
94
95        // SAFETY: `device_ptr` is valid. `cuDeviceTotalMem_v2` writes the total device memory
96        // into the `MaybeUninit`, making `assume_init()` valid on success.
97        let max_memory = unsafe {
98            let mut bytes = MaybeUninit::uninit();
99            cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr);
100            bytes.assume_init() as u64
101        };
102        let mem_properties = MemoryDeviceProperties {
103            max_page_size: max_memory / 4,
104            alignment: mem_alignment as u64,
105        };
106
107        let mut comp_opts = CompilationOptions {
108            supports_features: CppSupportedFeatures {
109                fast_math: true,
110                ..Default::default()
111            },
112            ..Default::default()
113        };
114
115        // SAFETY: `device_ptr` is a valid CUDA device. All `get_attribute` calls query
116        // read-only device properties via the CUDA driver API.
117        let hardware_props = unsafe {
118            use cudarc::driver::{result::device::get_attribute, sys::CUdevice_attribute::*};
119            let warp_size =
120                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_WARP_SIZE).unwrap() as u32;
121            let max_shared = get_attribute(
122                device_ptr,
123                CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
124            )
125            .unwrap() as usize;
126            let max_threads = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
127                .unwrap() as u32;
128            let block_dim_x =
129                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X).unwrap();
130            let block_dim_y =
131                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y).unwrap();
132            let block_dim_z =
133                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z).unwrap();
134            let max_cube_dim = (block_dim_x as u32, block_dim_y as u32, block_dim_z as u32);
135
136            let grid_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X).unwrap();
137            let grid_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y).unwrap();
138            let grid_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z).unwrap();
139            let max_cube_count = (grid_dim_x as u32, grid_dim_y as u32, grid_dim_z as u32);
140
141            let num_streaming_multiprocessors = Some(
142                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT).unwrap() as u32,
143            );
144            let num_tensor_cores = tensor_cores_per_sm(arch_version);
145
146            comp_opts.warp_size = warp_size;
147
148            HardwareProperties {
149                load_width: 128,
150                plane_size_min: warp_size,
151                plane_size_max: warp_size,
152                max_bindings: crate::device::CUDA_MAX_BINDINGS,
153                max_shared_memory_size: max_shared,
154                max_cube_count,
155                max_units_per_cube: max_threads,
156                max_cube_dim,
157                num_streaming_multiprocessors,
158                num_tensor_cores,
159                min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
160                    None
161                } else {
162                    Some(8)
163                },
164                num_cpu_cores: None,
165                max_vector_size: VectorSize::MAX,
166            }
167        };
168
169        let mut device_props = DeviceProperties::new(
170            Default::default(),
171            mem_properties.clone(),
172            hardware_props,
173            TimingMethod::System,
174        );
175        register_supported_types(&mut device_props);
176        device_props.register_type_usage(ElemType::Float(FloatKind::TF32), TypeUsage::Conversion);
177        if arch_version >= 60 {
178            device_props.register_atomic_type_usage(
179                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F64))),
180                AtomicUsage::Add | AtomicUsage::LoadStore,
181            );
182        }
183        if arch_version >= 70 {
184            device_props.register_atomic_type_usage(
185                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F16))),
186                AtomicUsage::Add | AtomicUsage::LoadStore,
187            );
188            device_props.register_atomic_type_usage(
189                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F16))).with_vector_size(2),
190                AtomicUsage::Add | AtomicUsage::LoadStore,
191            );
192            device_props.register_semantic_type(SemanticType::Pipeline);
193            device_props
194                .register_type_usage(OpaqueType::Barrier(BarrierLevel::Unit), TypeUsage::Buffer);
195            device_props
196                .register_type_usage(OpaqueType::Barrier(BarrierLevel::Cube), TypeUsage::Buffer);
197            device_props.features.plane.insert(Plane::Sync);
198            comp_opts.supports_features.grid_constants = true;
199        }
200
201        if arch_version >= 75 {
202            device_props
203                .features
204                .matmul
205                .ldmatrix
206                .insert(ElemType::Float(FloatKind::F16).into());
207            device_props
208                .features
209                .matmul
210                .ldmatrix
211                .insert(ElemType::Float(FloatKind::BF16).into());
212            comp_opts.supports_features.fast_tanh = CUDA_VERSION >= 12080;
213        }
214
215        if arch_version >= 80 {
216            device_props.features.copy_async = true;
217        }
218
219        // NOTE: I commented that since I observed synchronisation issues with atomic add for bf16.
220        // if arch.get_version() >= 80 {
221        //     device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16)));
222        // }
223
224        if arch_version >= 89 {
225            device_props.register_type_usage(
226                ElemType::Float(FloatKind::E4M3),
227                TypeUsage::Conversion | TypeUsage::Buffer,
228            );
229            device_props.register_type_usage(
230                ElemType::Float(FloatKind::E5M2),
231                TypeUsage::Conversion | TypeUsage::Buffer,
232            );
233        }
234        if arch_version >= 90 {
235            device_props.features.tma.insert(Tma::Base);
236            device_props.register_semantic_type(SemanticType::TensorMap);
237            device_props.features.cube_cluster = true;
238            comp_opts.supports_features.clusters = true;
239            comp_opts.supports_features.elect_sync = true;
240            device_props
241                .features
242                .matmul
243                .stmatrix
244                .insert(ElemType::Float(FloatKind::F16).into());
245            device_props
246                .features
247                .matmul
248                .stmatrix
249                .insert(ElemType::Float(FloatKind::BF16).into());
250            device_props.register_atomic_type_usage(
251                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F32))).with_vector_size(2),
252                AtomicUsage::LoadStore | AtomicUsage::Add,
253            );
254            device_props.register_atomic_type_usage(
255                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F32))).with_vector_size(4),
256                AtomicUsage::LoadStore | AtomicUsage::Add,
257            );
258        }
259
260        if arch_version >= 100 {
261            device_props.features.tma.insert(Tma::Im2colWide);
262            // Breaks swizzle so disable for now and fix in a PR specifically for this
263            // if CUDA_VERSION >= 12090 {
264            //     device_props.hardware.load_width = 256;
265            // }
266        }
267
268        // NOTE: FP6/FP4 is explicitly not marked as forward compatible, but is compatible within a
269        // major version. Try to keep this up to date with new arch major revisions if they also
270        // implement it.
271        if arch_major == 10 || arch_major == 11 || arch_major == 12 {
272            device_props
273                .register_type_usage(ElemType::Float(FloatKind::E2M1), TypeUsage::Conversion);
274            device_props.register_type_usage(
275                StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2),
276                TypeUsage::Conversion | TypeUsage::Buffer,
277            );
278            device_props.register_type_usage(
279                ElemType::Float(FloatKind::E2M3),
280                TypeUsage::Conversion | TypeUsage::Buffer,
281            );
282            device_props.register_type_usage(
283                ElemType::Float(FloatKind::E3M2),
284                TypeUsage::Conversion | TypeUsage::Buffer,
285            );
286            device_props.register_type_usage(
287                ElemType::Float(FloatKind::UE8M0),
288                TypeUsage::Conversion | TypeUsage::Buffer,
289            );
290
291            if CUDA_VERSION >= 12080 {
292                device_props.features.tma.insert(Tma::SwizzleAtomicity);
293            }
294        }
295
296        device_props.features.memory_reinterpret = true;
297        device_props.features.alignment = true;
298        device_props.features.plane.insert(Plane::Ops);
299        device_props
300            .features
301            .plane
302            .insert(Plane::NonUniformControlFlow);
303
304        register_wmma_features(supported_wmma_combinations, &mut device_props);
305        register_mma_features(supported_mma_combinations, &mut device_props);
306        register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);
307
308        let cuda_ctx = CudaContext::new(comp_opts, device_props.clone(), ctx, arch);
309        let logger = Arc::new(ServerLogger::default());
310        let policy = PitchedMemoryLayoutPolicy::new(device_props.memory.alignment as usize);
311        let utilities = ServerUtilities::new(device_props, logger, (), policy);
312
313        CudaServer::new(
314            cuda_ctx,
315            mem_properties,
316            options.memory_config,
317            mem_alignment,
318            device_id,
319            utilities,
320        )
321    }
322
323    fn utilities(&self) -> ServerUtilitiesHandle {
324        self.utilities() as ServerUtilitiesHandle
325    }
326}
327
328pub type CudaCompiler = CppCompiler<CudaDialect<WmmaCompiler>>;
329pub type CudaComputeKernel = ComputeKernel<CudaDialect<WmmaCompiler>>;
330
/// Returns the number of tensor cores per streaming multiprocessor for a device.
///
/// `version` is the compute capability encoded as `major * 10 + minor` (e.g. `89` for 8.9,
/// `120` for 12.0). Returns `None` for architectures without tensor cores or not listed here.
fn tensor_cores_per_sm(version: u32) -> Option<u32> {
    match version {
        // Volta (incl. Jetson Xavier, 7.2), Turing.
        70 | 72 | 75 => Some(8),
        // Ampere (incl. Jetson Orin, 8.7), Ada, Hopper, datacenter Blackwell (10.x) and
        // consumer Blackwell (12.x).
        80 | 86 | 87 | 89 | 90 | 91 | 92 | 100 | 103 | 120 | 121 => Some(4),
        // Unknown or unsupported architecture.
        _ => None,
    }
}
338
339impl Runtime for CudaRuntime {
340    type Compiler = CudaCompiler;
341    type Server = CudaServer;
342    type Device = CudaDevice;
343
344    fn client(device: &Self::Device) -> ComputeClient<Self> {
345        ComputeClient::load(device)
346    }
347
348    fn name(_client: &ComputeClient<Self>) -> &'static str {
349        "cuda"
350    }
351
352    fn require_array_lengths() -> bool {
353        true
354    }
355
356    fn max_cube_count() -> (u32, u32, u32) {
357        (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
358    }
359
360    fn can_read_tensor(shape: &Shape, strides: &Strides) -> bool {
361        has_pitched_row_major_strides(shape, strides)
362    }
363
364    fn target_properties() -> TargetProperties {
365        TargetProperties {
366            mma: MmaProperties {
367                register_size_bits: 32,
368                const_plane_size: 32,
369                register_layout_a: MatrixLayout::RowMajor,
370                register_layout_b: MatrixLayout::ColMajor,
371                register_layout_acc: MatrixLayout::RowMajor,
372                register_duplication_a: 1,
373                register_duplication_b: 1,
374                register_duplication_acc: 1,
375                contiguous_elements: ContiguousElements::new(contiguous_elements_cuda),
376            },
377        }
378    }
379
380    fn enumerate_devices(
381        _: u16,
382        _: &<Self::Server as cubecl_core::server::ComputeServer>::Info,
383    ) -> Vec<cubecl_core::device::DeviceId> {
384        let count = cudarc::driver::CudaContext::device_count().unwrap_or(0) as usize;
385        (0..count)
386            .map(|i| DeviceId {
387                type_id: 0,
388                index_id: i as u16,
389            })
390            .collect()
391    }
392}