cubecl_cuda/runtime.rs

use crate::{
    WmmaCompiler,
    compute::{CudaServer, context::CudaContext, valid_strides},
    device::CudaDevice,
};
use cubecl_common::{
    device::{Device, DeviceState},
    profile::TimingMethod,
};
use cubecl_core::{
    CubeCount, CubeDim, MemoryConfiguration, Runtime,
    ir::{
        ContiguousElements, ElemType, FloatKind, MatrixLayout, MmaProperties, SemanticType,
        StorageType, TargetProperties,
    },
    server::ServerUtilities,
};
use cubecl_cpp::{
    DialectWmmaCompiler,
    cuda::{CudaDialect, arch::CudaArchitecture, mma::contiguous_elements_cuda},
    register_supported_types,
    shared::{
        CompilationOptions, CppCompiler, CppSupportedFeatures, register_mma_features,
        register_scaled_mma_features, register_wmma_features,
    },
};
use cubecl_runtime::{
    DeviceProperties, Plane, Tma, TypeUsage,
    client::ComputeClient,
    logging::ServerLogger,
    memory_management::{HardwareProperties, MemoryDeviceProperties},
};
use cudarc::driver::sys::{CUDA_VERSION, cuDeviceTotalMem_v2};
use std::{mem::MaybeUninit, sync::Arc};

/// Options configuring the CUDA runtime.
#[derive(Default)]
pub struct RuntimeOptions {
    /// Configures the memory management.
    pub memory_config: MemoryConfiguration,
}

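/// The CUDA runtime, implementing [`Runtime`] on top of the CUDA driver API via `cudarc`.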
#[derive(Debug)]
pub struct CudaRuntime;

impl DeviceState for CudaServer {
    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
        let options = RuntimeOptions::default();
        let device = CudaDevice::from_id(device_id);

        // To get the supported WMMA features and memory properties, we have to initialize the server immediately.
        cudarc::driver::result::init().unwrap();
        let device_id = device.index as i32;
        let device_ptr = cudarc::driver::result::device::get(device_id).unwrap();
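        // Compute capability is encoded as `major * 10 + minor` (e.g. 80 for sm_80).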
        let arch_major;
        let arch_version = unsafe {
            arch_major = cudarc::driver::result::device::get_attribute(
                device_ptr,
                cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
            )
            .unwrap();
            let minor = cudarc::driver::result::device::get_attribute(
                device_ptr,
                cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
            )
            .unwrap();
            arch_major * 10 + minor
        } as u32;

        // This is the alignment returned by `cuMemAllocPitch`, so it's the one considered optimal
        // for row alignment by CUDA. It hasn't changed since at least the GTX 700 series.
        // Querying the texture row alignment is a heuristic, and it isn't guaranteed to be the same anyway.
        let mem_alignment = 512;

        // Ask the WMMA compiler for its supported combinations.
        let arch = CudaArchitecture {
            version: arch_version,
        };
        let supported_wmma_combinations = WmmaCompiler::supported_wmma_combinations(&arch);
        let supported_mma_combinations = WmmaCompiler::supported_mma_combinations(&arch);
        let supported_scaled_mma_combinations =
            WmmaCompiler::supported_scaled_mma_combinations(&arch);

        let ctx = unsafe {
            let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
            cudarc::driver::result::ctx::set_current(ctx).unwrap();
            ctx
        };

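        // Total device memory; the maximum page size is capped to a quarter of it below.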
        let max_memory = unsafe {
            let mut bytes = MaybeUninit::uninit();
            cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr);
            bytes.assume_init() as u64
        };
        let mem_properties = MemoryDeviceProperties {
            max_page_size: max_memory / 4,
            alignment: mem_alignment as u64,
        };

        let mut comp_opts = CompilationOptions {
            supports_features: CppSupportedFeatures {
                fast_math: true,
                ..Default::default()
            },
            ..Default::default()
        };

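        // Query per-device hardware limits (warp size, shared memory, block and grid dims) from the driver.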
        let hardware_props = unsafe {
            use cudarc::driver::{result::device::get_attribute, sys::CUdevice_attribute::*};
            let warp_size =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_WARP_SIZE).unwrap() as u32;
            let max_shared = get_attribute(
                device_ptr,
                CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
            )
            .unwrap() as usize;
            let max_threads = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
                .unwrap() as u32;
            let block_dim_x =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X).unwrap();
            let block_dim_y =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y).unwrap();
            let block_dim_z =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z).unwrap();
            let max_cube_dim =
                CubeDim::new_3d(block_dim_x as u32, block_dim_y as u32, block_dim_z as u32);

            let grid_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X).unwrap();
            let grid_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y).unwrap();
            let grid_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z).unwrap();
            let max_cube_count =
                CubeCount::new_3d(grid_dim_x as u32, grid_dim_y as u32, grid_dim_z as u32);

            let num_streaming_multiprocessors = Some(
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT).unwrap() as u32,
            );
            let num_tensor_cores = tensor_cores_per_sm(arch_version);

            comp_opts.warp_size = warp_size;

            HardwareProperties {
                load_width: 128,
                plane_size_min: warp_size,
                plane_size_max: warp_size,
                max_bindings: crate::device::CUDA_MAX_BINDINGS,
                max_shared_memory_size: max_shared,
                max_cube_count,
                max_units_per_cube: max_threads,
                max_cube_dim,
                num_streaming_multiprocessors,
                num_tensor_cores,
                min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
                    None
                } else {
                    Some(8)
                },
            }
        };

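        // Build the device properties, then register supported types and features per compute capability.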
        let mut device_props = DeviceProperties::new(
            Default::default(),
            mem_properties.clone(),
            hardware_props,
            TimingMethod::System,
        );
        register_supported_types(&mut device_props);
        device_props.register_type_usage(ElemType::Float(FloatKind::TF32), TypeUsage::Conversion);
        if arch_version >= 60 {
            device_props.register_type_usage(
                StorageType::Atomic(ElemType::Float(FloatKind::F64)),
                TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
            );
        }
        if arch_version >= 70 {
            device_props.register_type_usage(
                StorageType::Atomic(ElemType::Float(FloatKind::F16)),
                TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
            );
            device_props.register_semantic_type(SemanticType::Pipeline);
            device_props.register_semantic_type(SemanticType::Barrier);
            device_props.features.plane.insert(Plane::Sync);

            comp_opts.supports_features.grid_constants = true;
        }

        if arch_version >= 75 {
            device_props
                .features
                .ldmatrix
                .insert(ElemType::Float(FloatKind::F16).into());
            device_props
                .features
                .ldmatrix
                .insert(ElemType::Float(FloatKind::BF16).into());
            comp_opts.supports_features.fast_tanh = CUDA_VERSION >= 12080;
        }

        // NOTE: This is commented out because I observed synchronisation issues with atomic add for bf16.
        // if arch.get_version() >= 80 {
        //     device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16)));
        // }

        if arch_version >= 89 {
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E4M3),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E5M2),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
        }
        if arch_version >= 90 {
            device_props.features.tma.insert(Tma::Base);
            device_props.register_semantic_type(SemanticType::TensorMap);
            device_props.features.cube_cluster = true;
            comp_opts.supports_features.clusters = true;
            comp_opts.supports_features.elect_sync = true;
            device_props
                .features
                .stmatrix
                .insert(ElemType::Float(FloatKind::F16).into());
            device_props
                .features
                .stmatrix
                .insert(ElemType::Float(FloatKind::BF16).into());
        }

        if arch_version >= 100 {
            device_props.features.tma.insert(Tma::Im2colWide);
            // Breaks swizzle, so it's disabled for now; fix it in a PR dedicated to this.
            // if CUDA_VERSION >= 12090 {
            //     device_props.hardware.load_width = 256;
            // }
        }

        // NOTE: FP6/FP4 is explicitly not marked as forward compatible, but is compatible within a
        // major version. Try to keep this up to date with new arch major revisions if they also
        // implement it.
        if arch_major == 10 || arch_major == 11 || arch_major == 12 {
            device_props
                .register_type_usage(ElemType::Float(FloatKind::E2M1), TypeUsage::Conversion);
            device_props.register_type_usage(
                StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E2M3),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E3M2),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::UE8M0),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );

            if CUDA_VERSION >= 12080 {
                device_props.features.tma.insert(Tma::SwizzleAtomicity);
            }
        }

        device_props.features.dynamic_line_size = true;
        device_props.features.alignment = true;
        device_props.features.plane.insert(Plane::Ops);

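        // Register the WMMA/MMA instruction shapes reported as supported by the compiler for this architecture.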
        register_wmma_features(supported_wmma_combinations, &mut device_props);
        register_mma_features(supported_mma_combinations, &mut device_props);
        register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);

        let cuda_ctx = CudaContext::new(comp_opts, ctx, arch);
        let logger = Arc::new(ServerLogger::default());
        let utilities = ServerUtilities::new(device_props, logger, ());

        CudaServer::new(
            cuda_ctx,
            mem_properties,
            options.memory_config,
            mem_alignment,
            device_id,
            utilities,
        )
    }
}

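/// The C++ compiler instantiated with the CUDA dialect.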
pub type CudaCompiler = CppCompiler<CudaDialect<WmmaCompiler>>;

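/// Number of tensor cores per streaming multiprocessor for a given compute capability, if known.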
fn tensor_cores_per_sm(version: u32) -> Option<u32> {
    match version {
        70 | 75 => Some(8),                           // Volta, Turing
        80 | 86 | 89 | 90 | 91 | 92 | 100 => Some(4), // Ampere, Hopper, Blackwell
        _ => None,                                    // Unknown or unsupported architecture
    }
}

impl Runtime for CudaRuntime {
    type Compiler = CudaCompiler;
    type Server = CudaServer;
    type Device = CudaDevice;

    fn client(device: &Self::Device) -> ComputeClient<Self> {
        ComputeClient::load(device)
    }

    fn name(_client: &ComputeClient<Self>) -> &'static str {
        "cuda"
    }

    fn require_array_lengths() -> bool {
        true
    }

    fn supported_line_sizes() -> &'static [u8] {
        &[16, 8, 4, 2, 1]
    }

    fn max_cube_count() -> (u32, u32, u32) {
        (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
    }

    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
        valid_strides(shape, strides)
    }

    fn target_properties() -> TargetProperties {
        TargetProperties {
            mma: MmaProperties {
                register_size_bits: 32,
                const_plane_size: 32,
                register_layout_a: MatrixLayout::RowMajor,
                register_layout_b: MatrixLayout::ColMajor,
                register_layout_acc: MatrixLayout::RowMajor,
                register_duplication_a: 1,
                register_duplication_b: 1,
                register_duplication_acc: 1,
                contiguous_elements: ContiguousElements::new(contiguous_elements_cuda),
            },
        }
    }
}