Skip to main content

cubecl_cuda/
runtime.rs

1use crate::{
2    WmmaCompiler,
3    compute::{CudaServer, context::CudaContext},
4    device::CudaDevice,
5};
6use cubecl_common::{
7    device::{Device, DeviceState},
8    profile::TimingMethod,
9};
10use cubecl_core::{
11    MemoryConfiguration, Runtime,
12    ir::{
13        BarrierLevel, ContiguousElements, DeviceProperties, ElemType, FloatKind,
14        HardwareProperties, LineSize, MatrixLayout, MemoryDeviceProperties, MmaProperties,
15        OpaqueType, SemanticType, StorageType, TargetProperties,
16        features::{Plane, Tma, TypeUsage},
17    },
18    server::ServerUtilities,
19    zspace::striding::has_pitched_row_major_strides,
20};
21use cubecl_cpp::{
22    ComputeKernel, DialectWmmaCompiler,
23    cuda::{CudaDialect, arch::CudaArchitecture, mma::contiguous_elements_cuda},
24    register_supported_types,
25    shared::{
26        CompilationOptions, CppCompiler, CppSupportedFeatures, register_mma_features,
27        register_scaled_mma_features, register_wmma_features,
28    },
29};
30use cubecl_runtime::{client::ComputeClient, logging::ServerLogger};
31use cudarc::driver::sys::{CUDA_VERSION, cuDeviceTotalMem_v2};
32use std::{mem::MaybeUninit, sync::Arc};
33
34/// Options configuring the CUDA runtime.
35#[derive(Default)]
36pub struct RuntimeOptions {
37    /// Configures the memory management.
38    pub memory_config: MemoryConfiguration,
39}
40
/// Marker type identifying the CUDA backend.
///
/// It carries no data; the backend behaviour lives in its `Runtime` implementation.
#[derive(Debug)]
pub struct CudaRuntime;
43
44impl DeviceState for CudaServer {
45    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
46        let options = RuntimeOptions::default();
47        let device = CudaDevice::from_id(device_id);
48
49        // To get the supported WMMA features, and memory properties, we have to initialize the server immediately.
50        cudarc::driver::result::init().unwrap();
51        let device_id = device.index as i32;
52        let device_ptr = cudarc::driver::result::device::get(device_id).unwrap();
53        let arch_major;
54        let arch_version = unsafe {
55            arch_major = cudarc::driver::result::device::get_attribute(
56            device_ptr,
57            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
58        )
59        .unwrap();
60            let minor = cudarc::driver::result::device::get_attribute(
61            device_ptr,
62            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
63        )
64        .unwrap();
65            arch_major * 10 + minor
66        } as u32;
67
68        // This is the alignment returned by `cuMallocPitched`, so it's the one considered optimal
69        // for row alignment by CUDA. This hasn't changed since at least the GTX 700 series.
70        // Querying texture row align is a heuristic, but also not guaranteed to be the same.
71        let mem_alignment = 512;
72
73        // Ask the wmma compiler for its supported combinations
74        let arch = CudaArchitecture {
75            version: arch_version,
76        };
77        let supported_wmma_combinations = WmmaCompiler::supported_wmma_combinations(&arch);
78        let supported_mma_combinations = WmmaCompiler::supported_mma_combinations(&arch);
79        let supported_scaled_mma_combinations =
80            WmmaCompiler::supported_scaled_mma_combinations(&arch);
81
82        let ctx = unsafe {
83            let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
84            cudarc::driver::result::ctx::set_current(ctx).unwrap();
85            ctx
86        };
87
88        let max_memory = unsafe {
89            let mut bytes = MaybeUninit::uninit();
90            cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr);
91            bytes.assume_init() as u64
92        };
93        let mem_properties = MemoryDeviceProperties {
94            max_page_size: max_memory / 4,
95            alignment: mem_alignment as u64,
96        };
97
98        let mut comp_opts = CompilationOptions {
99            supports_features: CppSupportedFeatures {
100                fast_math: true,
101                ..Default::default()
102            },
103            ..Default::default()
104        };
105
106        let hardware_props = unsafe {
107            use cudarc::driver::{result::device::get_attribute, sys::CUdevice_attribute::*};
108            let warp_size =
109                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_WARP_SIZE).unwrap() as u32;
110            let max_shared = get_attribute(
111                device_ptr,
112                CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
113            )
114            .unwrap() as usize;
115            let max_threads = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
116                .unwrap() as u32;
117            let block_dim_x =
118                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X).unwrap();
119            let block_dim_y =
120                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y).unwrap();
121            let block_dim_z =
122                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z).unwrap();
123            let max_cube_dim = (block_dim_x as u32, block_dim_y as u32, block_dim_z as u32);
124
125            let grid_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X).unwrap();
126            let grid_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y).unwrap();
127            let grid_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z).unwrap();
128            let max_cube_count = (grid_dim_x as u32, grid_dim_y as u32, grid_dim_z as u32);
129
130            let num_streaming_multiprocessors = Some(
131                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT).unwrap() as u32,
132            );
133            let num_tensor_cores = tensor_cores_per_sm(arch_version);
134
135            comp_opts.warp_size = warp_size;
136
137            HardwareProperties {
138                load_width: 128,
139                plane_size_min: warp_size,
140                plane_size_max: warp_size,
141                max_bindings: crate::device::CUDA_MAX_BINDINGS,
142                max_shared_memory_size: max_shared,
143                max_cube_count,
144                max_units_per_cube: max_threads,
145                max_cube_dim,
146                num_streaming_multiprocessors,
147                num_tensor_cores,
148                min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
149                    None
150                } else {
151                    Some(8)
152                },
153                num_cpu_cores: None,
154                max_line_size: LineSize::MAX,
155            }
156        };
157
158        let mut device_props = DeviceProperties::new(
159            Default::default(),
160            mem_properties.clone(),
161            hardware_props,
162            TimingMethod::System,
163        );
164        register_supported_types(&mut device_props);
165        device_props.register_type_usage(ElemType::Float(FloatKind::TF32), TypeUsage::Conversion);
166        if arch_version >= 60 {
167            device_props.register_type_usage(
168                StorageType::Atomic(ElemType::Float(FloatKind::F64)),
169                TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
170            );
171        }
172        if arch_version >= 70 {
173            device_props.register_type_usage(
174                StorageType::Atomic(ElemType::Float(FloatKind::F16)),
175                TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
176            );
177            device_props.register_semantic_type(SemanticType::Pipeline);
178            device_props
179                .register_type_usage(OpaqueType::Barrier(BarrierLevel::Unit), TypeUsage::Buffer);
180            device_props
181                .register_type_usage(OpaqueType::Barrier(BarrierLevel::Cube), TypeUsage::Buffer);
182            device_props.features.plane.insert(Plane::Sync);
183            comp_opts.supports_features.grid_constants = true;
184        }
185
186        if arch_version >= 75 {
187            device_props
188                .features
189                .ldmatrix
190                .insert(ElemType::Float(FloatKind::F16).into());
191            device_props
192                .features
193                .ldmatrix
194                .insert(ElemType::Float(FloatKind::BF16).into());
195            comp_opts.supports_features.fast_tanh = CUDA_VERSION >= 12080;
196        }
197
198        if arch_version >= 80 {
199            device_props.features.copy_async = true;
200        }
201
202        // NOTE: I commented that since I observed synchronisation issues with atomic add for bf16.
203        // if arch.get_version() >= 80 {
204        //     device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16)));
205        // }
206
207        if arch_version >= 89 {
208            device_props.register_type_usage(
209                ElemType::Float(FloatKind::E4M3),
210                TypeUsage::Conversion | TypeUsage::Buffer,
211            );
212            device_props.register_type_usage(
213                ElemType::Float(FloatKind::E5M2),
214                TypeUsage::Conversion | TypeUsage::Buffer,
215            );
216        }
217        if arch_version >= 90 {
218            device_props.features.tma.insert(Tma::Base);
219            device_props.register_semantic_type(SemanticType::TensorMap);
220            device_props.features.cube_cluster = true;
221            comp_opts.supports_features.clusters = true;
222            comp_opts.supports_features.elect_sync = true;
223            device_props
224                .features
225                .stmatrix
226                .insert(ElemType::Float(FloatKind::F16).into());
227            device_props
228                .features
229                .stmatrix
230                .insert(ElemType::Float(FloatKind::BF16).into());
231        }
232
233        if arch_version >= 100 {
234            device_props.features.tma.insert(Tma::Im2colWide);
235            // Breaks swizzle so disable for now and fix in a PR specifically for this
236            // if CUDA_VERSION >= 12090 {
237            //     device_props.hardware.load_width = 256;
238            // }
239        }
240
241        // NOTE: FP6/FP4 is explicitly not marked as forward compatible, but is compatible within a
242        // major version. Try to keep this up to date with new arch major revisions if they also
243        // implement it.
244        if arch_major == 10 || arch_major == 11 || arch_major == 12 {
245            device_props
246                .register_type_usage(ElemType::Float(FloatKind::E2M1), TypeUsage::Conversion);
247            device_props.register_type_usage(
248                StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2),
249                TypeUsage::Conversion | TypeUsage::Buffer,
250            );
251            device_props.register_type_usage(
252                ElemType::Float(FloatKind::E2M3),
253                TypeUsage::Conversion | TypeUsage::Buffer,
254            );
255            device_props.register_type_usage(
256                ElemType::Float(FloatKind::E3M2),
257                TypeUsage::Conversion | TypeUsage::Buffer,
258            );
259            device_props.register_type_usage(
260                ElemType::Float(FloatKind::UE8M0),
261                TypeUsage::Conversion | TypeUsage::Buffer,
262            );
263
264            if CUDA_VERSION >= 12080 {
265                device_props.features.tma.insert(Tma::SwizzleAtomicity);
266            }
267        }
268
269        device_props.features.dynamic_line_size = true;
270        device_props.features.alignment = true;
271        device_props.features.plane.insert(Plane::Ops);
272        device_props
273            .features
274            .plane
275            .insert(Plane::NonUniformControlFlow);
276
277        register_wmma_features(supported_wmma_combinations, &mut device_props);
278        register_mma_features(supported_mma_combinations, &mut device_props);
279        register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);
280
281        let cuda_ctx = CudaContext::new(comp_opts, device_props.clone(), ctx, arch);
282        let logger = Arc::new(ServerLogger::default());
283        let utilities = ServerUtilities::new(device_props, logger, ());
284
285        CudaServer::new(
286            cuda_ctx,
287            mem_properties,
288            options.memory_config,
289            mem_alignment,
290            device_id,
291            utilities,
292        )
293    }
294}
295
296pub type CudaCompiler = CppCompiler<CudaDialect<WmmaCompiler>>;
297pub type CudaComputeKernel = ComputeKernel<CudaDialect<WmmaCompiler>>;
298
/// Looks up the number of tensor cores per streaming multiprocessor for a compute capability.
///
/// `version` is the capability encoded as `major * 10 + minor`. Returns `None` for any value
/// that is not listed in the tables below (unknown or unsupported architecture).
fn tensor_cores_per_sm(version: u32) -> Option<u32> {
    // Volta, Turing
    const EIGHT_PER_SM: [u32; 2] = [70, 75];
    // Ampere, Hopper, Blackwell
    const FOUR_PER_SM: [u32; 7] = [80, 86, 89, 90, 91, 92, 100];

    if EIGHT_PER_SM.contains(&version) {
        Some(8)
    } else if FOUR_PER_SM.contains(&version) {
        Some(4)
    } else {
        None
    }
}
306
307impl Runtime for CudaRuntime {
308    type Compiler = CudaCompiler;
309    type Server = CudaServer;
310    type Device = CudaDevice;
311
312    fn client(device: &Self::Device) -> ComputeClient<Self> {
313        ComputeClient::load(device)
314    }
315
316    fn name(_client: &ComputeClient<Self>) -> &'static str {
317        "cuda"
318    }
319
320    fn require_array_lengths() -> bool {
321        true
322    }
323
324    fn max_cube_count() -> (u32, u32, u32) {
325        (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
326    }
327
328    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
329        has_pitched_row_major_strides(shape, strides)
330    }
331
332    fn target_properties() -> TargetProperties {
333        TargetProperties {
334            mma: MmaProperties {
335                register_size_bits: 32,
336                const_plane_size: 32,
337                register_layout_a: MatrixLayout::RowMajor,
338                register_layout_b: MatrixLayout::ColMajor,
339                register_layout_acc: MatrixLayout::RowMajor,
340                register_duplication_a: 1,
341                register_duplication_b: 1,
342                register_duplication_acc: 1,
343                contiguous_elements: ContiguousElements::new(contiguous_elements_cuda),
344            },
345        }
346    }
347}