cubecl_cuda/runtime.rs

use crate::{
    WmmaCompiler,
    compute::{CudaContext, CudaServer, CudaStorage},
    device::CudaDevice,
};
use cubecl_common::profile::TimingMethod;
use cubecl_core::{
    AtomicFeature, CubeCount, CubeDim, Feature, MemoryConfiguration, Runtime, TmaFeature,
    ir::{Elem, FloatKind, IntKind, UIntKind},
};
use cubecl_cpp::{
    DialectWmmaCompiler,
    cuda::{CudaDialect, arch::CudaArchitecture},
    register_supported_types,
    shared::{CompilationOptions, CppCompiler, register_wmma_features},
};
use cubecl_runtime::{
    ComputeRuntime, DeviceProperties,
    channel::MutexComputeChannel,
    client::ComputeClient,
    id::DeviceId,
    memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement},
};
use cudarc::driver::sys::cuDeviceTotalMem_v2;
use std::mem::MaybeUninit;

/// Options configuring the CUDA runtime.
#[derive(Default)]
pub struct RuntimeOptions {
    /// Configures the memory management.
    pub memory_config: MemoryConfiguration,
}

/// The CUDA runtime.
#[derive(Debug)]
pub struct CudaRuntime;

type Server = CudaServer;
type Channel = MutexComputeChannel<Server>;

static RUNTIME: ComputeRuntime<CudaDevice, Server, Channel> = ComputeRuntime::new();

pub type CudaCompiler = CppCompiler<CudaDialect<WmmaCompiler>>;

/// Creates a compute client for the given device, querying the CUDA driver for the
/// architecture, memory, and hardware properties used to build the device properties.
fn create_client<M: DialectWmmaCompiler<CudaDialect<M>>>(
    device: &CudaDevice,
    options: RuntimeOptions,
) -> ComputeClient<Server, Channel> {
    // To get the supported WMMA features and memory properties, we have to initialize the server immediately.
    cudarc::driver::result::init().unwrap();
    let device_ptr = cudarc::driver::result::device::get(device.index as i32).unwrap();
    let arch_major;
    let arch_version = unsafe {
        arch_major = cudarc::driver::result::device::get_attribute(
            device_ptr,
            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
        )
        .unwrap();
        let minor = cudarc::driver::result::device::get_attribute(
            device_ptr,
            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
        )
        .unwrap();
        arch_major * 10 + minor
    } as u32;
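    // Note: the value above packs the compute capability as `major * 10 + minor`, so a
    // compute capability 8.6 device yields `arch_version == 86` (with `arch_major == 8`),
    // and 9.0 yields `arch_version == 90`.
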
    // 32 bytes is enough to handle a double4 worth of alignment.
    // NB: cudaMalloc and co. actually align to _256_ bytes. Worth
    // trying that in the future to see if it improves memory coalescing.
    //
    // TODO: Find the correct value from the driver.
    let mem_alignment = 32;

    // Ask the wmma compiler for its supported combinations
    let arch = CudaArchitecture {
        version: arch_version,
    };
    let supported_wmma_combinations = M::supported_wmma_combinations(&arch);

    let ctx = unsafe {
        let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
        cudarc::driver::result::ctx::set_current(ctx).unwrap();
        ctx
    };

    let stream = cudarc::driver::result::stream::create(
        cudarc::driver::result::stream::StreamKind::NonBlocking,
    )
    .unwrap();
    let max_memory = unsafe {
        let mut bytes = MaybeUninit::uninit();
        cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr);
        bytes.assume_init() as u64
    };
    let storage = CudaStorage::new(mem_alignment, stream);
    let mem_properties = MemoryDeviceProperties {
        max_page_size: max_memory / 4,
        alignment: mem_alignment as u64,
    };

    let mut comp_opts = CompilationOptions::default();

    let hardware_props = unsafe {
        use cudarc::driver::{result::device::get_attribute, sys::CUdevice_attribute::*};
        let warp_size = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_WARP_SIZE).unwrap() as u32;
        let max_shared = get_attribute(
            device_ptr,
            CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
        )
        .unwrap() as usize;
        let max_threads =
            get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK).unwrap() as u32;
        let block_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X).unwrap();
        let block_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y).unwrap();
        let block_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z).unwrap();
        let max_cube_dim =
            CubeDim::new_3d(block_dim_x as u32, block_dim_y as u32, block_dim_z as u32);

        let grid_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X).unwrap();
        let grid_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y).unwrap();
        let grid_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z).unwrap();
        let max_cube_count =
            CubeCount::new_3d(grid_dim_x as u32, grid_dim_y as u32, grid_dim_z as u32);

        let num_streaming_multiprocessors = Some(
            get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT).unwrap() as u32,
        );
        let num_tensor_cores = tensor_cores_per_sm(arch_version);

        comp_opts.warp_size = warp_size;

        HardwareProperties {
            plane_size_min: warp_size,
            plane_size_max: warp_size,
            max_bindings: crate::device::CUDA_MAX_BINDINGS,
            max_shared_memory_size: max_shared,
            max_cube_count,
            max_units_per_cube: max_threads,
            max_cube_dim,
            num_streaming_multiprocessors,
            num_tensor_cores,
            min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
                None
            } else {
                Some(8)
            },
        }
    };

    let memory_management =
        MemoryManagement::from_configuration(storage, &mem_properties, options.memory_config);

    let mut device_props = DeviceProperties::new(
        &[Feature::Plane],
        mem_properties,
        hardware_props,
        TimingMethod::System,
    );
    register_supported_types(&mut device_props);
    device_props.register_feature(Feature::Type(Elem::Float(FloatKind::TF32)));
    if arch_version >= 60 {
        device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F64)));
    }
    if arch_version >= 70 {
        device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F16)));
        device_props.register_feature(Feature::Pipeline);
        device_props.register_feature(Feature::Barrier);
        device_props.register_feature(Feature::SyncPlane);

        comp_opts.grid_constants = true;
    }

    // NOTE: This is commented out because synchronisation issues were observed with atomic add
    // for bf16.
    // if arch.get_version() >= 80 {
    //     device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16)));
    // }

    if arch_version >= 89 {
        device_props.register_feature(Feature::Type(Elem::Float(FloatKind::E4M3)));
        device_props.register_feature(Feature::Type(Elem::Float(FloatKind::E5M2)));
    }
    if arch_version >= 90 {
        device_props.register_feature(Feature::Tma(TmaFeature::Base));
        device_props.register_feature(Feature::CubeCluster);
        comp_opts.supports_clusters = true;
    }

    if arch_version >= 100 {
        device_props.register_feature(Feature::Tma(TmaFeature::Im2colWide));
    }

    // NOTE: FP6/FP4 is explicitly not marked as forward compatible, but is compatible within a
    // major version. Try to keep this up to date with new arch major revisions if they also
    // implement it.
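    // (As of this writing, arch_major 10 corresponds to datacenter Blackwell, e.g. sm_100,
    // and arch_major 12 to consumer Blackwell, e.g. sm_120.)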
    if arch_major == 10 || arch_major == 12 {
        device_props.register_feature(Feature::Type(Elem::Float(FloatKind::E2M1)));
        device_props.register_feature(Feature::Type(Elem::Float(FloatKind::E2M3)));
        device_props.register_feature(Feature::Type(Elem::Float(FloatKind::E3M2)));
        device_props.register_feature(Feature::Type(Elem::Float(FloatKind::UE8M0)));
    }

    device_props.register_feature(Feature::AtomicFloat(AtomicFeature::LoadStore));
    device_props.register_feature(Feature::AtomicFloat(AtomicFeature::Add));

    // Supported by all architectures
    device_props.register_feature(Feature::Type(Elem::AtomicInt(IntKind::I32)));
    device_props.register_feature(Feature::Type(Elem::AtomicUInt(UIntKind::U32)));
    device_props.register_feature(Feature::AtomicInt(AtomicFeature::LoadStore));
    device_props.register_feature(Feature::AtomicInt(AtomicFeature::Add));
    device_props.register_feature(Feature::AtomicUInt(AtomicFeature::LoadStore));
    device_props.register_feature(Feature::AtomicUInt(AtomicFeature::Add));

    device_props.register_feature(Feature::DynamicLineSize);

    register_wmma_features(supported_wmma_combinations, &mut device_props);

    let cuda_ctx = CudaContext::new(memory_management, comp_opts, stream, ctx, arch);
    let server = CudaServer::new(mem_alignment, cuda_ctx);
    ComputeClient::new(MutexComputeChannel::new(server), device_props, ())
}

/// Number of tensor cores per streaming multiprocessor for a given compute capability,
/// or `None` when the architecture is unknown or has no tensor cores.
fn tensor_cores_per_sm(version: u32) -> Option<u32> {
    match version {
        70 | 75 => Some(8),                           // Volta, Turing
        80 | 86 | 89 | 90 | 91 | 92 | 100 => Some(4), // Ampere, Ada, Hopper, Blackwell
        _ => None,                                    // Unknown or unsupported architecture
    }
}

impl Runtime for CudaRuntime {
    type Compiler = CudaCompiler;
    type Server = CudaServer;

    type Channel = MutexComputeChannel<CudaServer>;
    type Device = CudaDevice;

    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
        RUNTIME.client(device, move || {
            create_client::<WmmaCompiler>(device, RuntimeOptions::default())
        })
    }

    fn device_id(device: &Self::Device) -> DeviceId {
        DeviceId::new(0, device.index as u32)
    }

    fn name(_client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str {
        "cuda"
    }

    fn require_array_lengths() -> bool {
        true
    }

    fn supported_line_sizes() -> &'static [u8] {
        &[8, 4, 2, 1]
    }

    fn max_cube_count() -> (u32, u32, u32) {
        (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
    }

    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
        let rank = shape.len();
        // The innermost dimension must be contiguous.
        if strides[rank - 1] != 1 {
            return false;
        }
        if rank <= 1 {
            return true;
        }

        // Strides must be in non-increasing order (a row-major-like layout).
        let mut sorted = strides.to_vec();
        sorted.sort();
        sorted.reverse();

        if sorted != strides {
            return false;
        }

        // Every stride except the last two must equal the next inner shape times the next
        // inner stride, i.e. the outer dimensions are tightly packed.
        for i in 0..rank - 2 {
            if strides[i] != shape[i + 1] * strides[i + 1] {
                return false;
            }
        }
        true
    }
}
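
// A minimal sketch (not part of the original file) showing how `can_read_tensor` treats a
// few stride layouts. It only exercises the host-side stride check, so no GPU is required.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn can_read_tensor_accepts_contiguous_and_rejects_permuted_strides() {
        // Contiguous row-major strides for a [2, 3, 4] tensor: accepted.
        assert!(CudaRuntime::can_read_tensor(&[2, 3, 4], &[12, 4, 1]));
        // Innermost dimension is not contiguous: rejected.
        assert!(!CudaRuntime::can_read_tensor(&[2, 3, 4], &[1, 2, 6]));
        // Strides are not in non-increasing order (permuted layout): rejected.
        assert!(!CudaRuntime::can_read_tensor(&[2, 3, 4], &[4, 12, 1]));
    }
}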