cubecl_cuda/
runtime.rs

1use crate::{
2    WmmaCompiler,
3    compute::{CudaServer, context::CudaContext, valid_strides},
4    device::CudaDevice,
5};
6use cubecl_common::{
7    device::{Device, DeviceState},
8    profile::TimingMethod,
9};
10use cubecl_core::{
11    CubeCount, CubeDim, MemoryConfiguration, Runtime,
12    ir::{
13        ElemType, FloatKind, MatrixLayout, MmaProperties, SemanticType, StorageType,
14        TargetProperties,
15    },
16    server::ServerUtilities,
17};
18use cubecl_cpp::{
19    DialectWmmaCompiler,
20    cuda::{CudaDialect, arch::CudaArchitecture},
21    register_supported_types,
22    shared::{
23        CompilationOptions, CppCompiler, register_mma_features, register_scaled_mma_features,
24        register_wmma_features,
25    },
26};
27use cubecl_runtime::{
28    DeviceProperties, Plane, Tma, TypeUsage,
29    client::ComputeClient,
30    logging::ServerLogger,
31    memory_management::{HardwareProperties, MemoryDeviceProperties},
32};
33use cudarc::driver::sys::cuDeviceTotalMem_v2;
34use std::{mem::MaybeUninit, sync::Arc};
35
36/// Options configuring the CUDA runtime.
37#[derive(Default)]
38pub struct RuntimeOptions {
39    /// Configures the memory management.
40    pub memory_config: MemoryConfiguration,
41}
42
/// CUDA backend for CubeCL.
///
/// Implements [`Runtime`] on top of [`CudaServer`] and compiles kernels with [`CudaCompiler`].
#[derive(Debug)]
pub struct CudaRuntime;
45
46impl DeviceState for CudaServer {
47    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
48        let options = RuntimeOptions::default();
49        let device = CudaDevice::from_id(device_id);
50
51        // To get the supported WMMA features, and memory properties, we have to initialize the server immediately.
52        cudarc::driver::result::init().unwrap();
53        let device_id = device.index as i32;
54        let device_ptr = cudarc::driver::result::device::get(device_id).unwrap();
55        let arch_major;
56        let arch_version = unsafe {
57            arch_major = cudarc::driver::result::device::get_attribute(
58            device_ptr,
59            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
60        )
61        .unwrap();
62            let minor = cudarc::driver::result::device::get_attribute(
63            device_ptr,
64            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
65        )
66        .unwrap();
67            arch_major * 10 + minor
68        } as u32;
69
70        // cudamalloc and co. align to _256_ bytes.
71        //
72        // TODO: Find the correct value from the driver.
73        let mem_alignment = 256;
74
75        // Ask the wmma compiler for its supported combinations
76        let arch = CudaArchitecture {
77            version: arch_version,
78        };
79        let supported_wmma_combinations = WmmaCompiler::supported_wmma_combinations(&arch);
80        let supported_mma_combinations = WmmaCompiler::supported_mma_combinations(&arch);
81        let supported_scaled_mma_combinations =
82            WmmaCompiler::supported_scaled_mma_combinations(&arch);
83
84        let ctx = unsafe {
85            let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
86            cudarc::driver::result::ctx::set_current(ctx).unwrap();
87            ctx
88        };
89
90        let max_memory = unsafe {
91            let mut bytes = MaybeUninit::uninit();
92            cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr);
93            bytes.assume_init() as u64
94        };
95        let mem_properties = MemoryDeviceProperties {
96            max_page_size: max_memory / 4,
97            alignment: mem_alignment as u64,
98        };
99
100        let mut comp_opts = CompilationOptions::default();
101
102        let hardware_props = unsafe {
103            use cudarc::driver::{result::device::get_attribute, sys::CUdevice_attribute::*};
104            let warp_size =
105                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_WARP_SIZE).unwrap() as u32;
106            let max_shared = get_attribute(
107                device_ptr,
108                CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
109            )
110            .unwrap() as usize;
111            let max_threads = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
112                .unwrap() as u32;
113            let block_dim_x =
114                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X).unwrap();
115            let block_dim_y =
116                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y).unwrap();
117            let block_dim_z =
118                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z).unwrap();
119            let max_cube_dim =
120                CubeDim::new_3d(block_dim_x as u32, block_dim_y as u32, block_dim_z as u32);
121
122            let grid_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X).unwrap();
123            let grid_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y).unwrap();
124            let grid_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z).unwrap();
125            let max_cube_count =
126                CubeCount::new_3d(grid_dim_x as u32, grid_dim_y as u32, grid_dim_z as u32);
127
128            let num_streaming_multiprocessors = Some(
129                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT).unwrap() as u32,
130            );
131            let num_tensor_cores = tensor_cores_per_sm(arch_version);
132
133            comp_opts.warp_size = warp_size;
134
135            HardwareProperties {
136                plane_size_min: warp_size,
137                plane_size_max: warp_size,
138                max_bindings: crate::device::CUDA_MAX_BINDINGS,
139                max_shared_memory_size: max_shared,
140                max_cube_count,
141                max_units_per_cube: max_threads,
142                max_cube_dim,
143                num_streaming_multiprocessors,
144                num_tensor_cores,
145                min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
146                    None
147                } else {
148                    Some(8)
149                },
150            }
151        };
152
153        let mut device_props = DeviceProperties::new(
154            Default::default(),
155            mem_properties.clone(),
156            hardware_props,
157            TimingMethod::System,
158        );
159        register_supported_types(&mut device_props);
160        device_props.register_type_usage(ElemType::Float(FloatKind::TF32), TypeUsage::Conversion);
161        if arch_version >= 60 {
162            device_props.register_type_usage(
163                StorageType::Atomic(ElemType::Float(FloatKind::F64)),
164                TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
165            );
166        }
167        if arch_version >= 70 {
168            device_props.register_type_usage(
169                StorageType::Atomic(ElemType::Float(FloatKind::F16)),
170                TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore,
171            );
172            device_props.register_semantic_type(SemanticType::Pipeline);
173            device_props.register_semantic_type(SemanticType::Barrier);
174            device_props.features.plane.insert(Plane::Sync);
175
176            comp_opts.grid_constants = true;
177        }
178
179        // NOTE: I commented that since I observed synchronisation issues with atomic add for bf16.
180        // if arch.get_version() >= 80 {
181        //     device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16)));
182        // }
183
184        if arch_version >= 89 {
185            device_props.register_type_usage(
186                ElemType::Float(FloatKind::E4M3),
187                TypeUsage::Conversion | TypeUsage::Buffer,
188            );
189            device_props.register_type_usage(
190                ElemType::Float(FloatKind::E5M2),
191                TypeUsage::Conversion | TypeUsage::Buffer,
192            );
193        }
194        if arch_version >= 90 {
195            device_props.features.tma.insert(Tma::Base);
196            device_props.register_semantic_type(SemanticType::TensorMap);
197            device_props.features.cube_cluster = true;
198            comp_opts.supports_clusters = true;
199        }
200
201        if arch_version >= 100 {
202            device_props.features.tma.insert(Tma::Im2colWide);
203        }
204
205        // NOTE: FP6/FP4 is explicitly not marked as forward compatible, but is compatible within a
206        // major version. Try to keep this up to date with new arch major revisions if they also
207        // implement it.
208        if arch_major == 10 || arch_major == 12 {
209            device_props
210                .register_type_usage(ElemType::Float(FloatKind::E2M1), TypeUsage::Conversion);
211            device_props.register_type_usage(
212                StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2),
213                TypeUsage::Conversion | TypeUsage::Buffer,
214            );
215            device_props.register_type_usage(
216                ElemType::Float(FloatKind::E2M3),
217                TypeUsage::Conversion | TypeUsage::Buffer,
218            );
219            device_props.register_type_usage(
220                ElemType::Float(FloatKind::E3M2),
221                TypeUsage::Conversion | TypeUsage::Buffer,
222            );
223            device_props.register_type_usage(
224                ElemType::Float(FloatKind::UE8M0),
225                TypeUsage::Conversion | TypeUsage::Buffer,
226            );
227        }
228
229        device_props.features.dynamic_line_size = true;
230        device_props.features.plane.insert(Plane::Ops);
231
232        register_wmma_features(supported_wmma_combinations, &mut device_props);
233        register_mma_features(supported_mma_combinations, &mut device_props);
234        register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);
235
236        let cuda_ctx = CudaContext::new(comp_opts, ctx, arch);
237        let logger = Arc::new(ServerLogger::default());
238        let utilities = ServerUtilities::new(device_props, logger, ());
239
240        CudaServer::new(
241            cuda_ctx,
242            mem_properties,
243            options.memory_config,
244            mem_alignment,
245            device_id,
246            utilities,
247        )
248    }
249}
250
251pub type CudaCompiler = CppCompiler<CudaDialect<WmmaCompiler>>;
252
/// Returns the number of tensor cores per streaming multiprocessor.
///
/// `version` is a compute capability encoded as `major * 10 + minor` (e.g. `86` for sm_86).
/// Returns `None` for architectures without tensor cores (pre-Volta) or ones this table does
/// not know about.
fn tensor_cores_per_sm(version: u32) -> Option<u32> {
    match version {
        // Volta (sm_70, Xavier sm_72) and Turing (sm_75): 8 tensor cores per SM.
        70 | 72 | 75 => Some(8),
        // Ampere (incl. Orin sm_87), Ada, Hopper and Blackwell (datacenter sm_100, consumer
        // sm_120): 4 tensor cores per SM.
        80 | 86 | 87 | 89 | 90 | 91 | 92 | 100 | 120 => Some(4),
        // Unknown or unsupported architecture.
        _ => None,
    }
}
260
261impl Runtime for CudaRuntime {
262    type Compiler = CudaCompiler;
263    type Server = CudaServer;
264    type Device = CudaDevice;
265
266    fn client(device: &Self::Device) -> ComputeClient<Self::Server> {
267        ComputeClient::load(device)
268    }
269
270    fn name(_client: &ComputeClient<Self::Server>) -> &'static str {
271        "cuda"
272    }
273
274    fn require_array_lengths() -> bool {
275        true
276    }
277
278    fn supported_line_sizes() -> &'static [u8] {
279        &[16, 8, 4, 2, 1]
280    }
281
282    fn max_cube_count() -> (u32, u32, u32) {
283        (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
284    }
285
286    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
287        valid_strides(shape, strides)
288    }
289
290    fn target_properties() -> TargetProperties {
291        TargetProperties {
292            mma: MmaProperties {
293                register_size_bits: 32,
294                const_plane_size: 32,
295                register_layout_a: MatrixLayout::RowMajor,
296                register_layout_b: MatrixLayout::ColMajor,
297                register_layout_acc: MatrixLayout::RowMajor,
298                register_duplication_a: 1,
299                register_duplication_b: 1,
300                register_duplication_acc: 1,
301            },
302        }
303    }
304}