use crate::{
    WmmaCompiler,
    compute::{CudaContext, CudaServer, CudaStorage},
    device::CudaDevice,
};
use cubecl_common::profile::TimingMethod;
use cubecl_core::{
    AtomicFeature, CubeCount, CubeDim, Feature, MemoryConfiguration, Runtime, TmaFeature,
    ir::{Elem, FloatKind, IntKind, UIntKind},
};
use cubecl_cpp::{
    DialectWmmaCompiler,
    cuda::{CudaDialect, arch::CudaArchitecture},
    register_supported_types,
    shared::{CompilationOptions, CppCompiler, register_wmma_features},
};
use cubecl_runtime::{
    ComputeRuntime, DeviceProperties,
    channel::MutexComputeChannel,
    client::ComputeClient,
    id::DeviceId,
    memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement},
};
use cudarc::driver::sys::cuDeviceTotalMem_v2;
use std::mem::MaybeUninit;

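/// Configuration options for the CUDA runtime.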
#[derive(Default)]
pub struct RuntimeOptions {
    pub memory_config: MemoryConfiguration,
}

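/// Marker type implementing [`Runtime`] for NVIDIA GPUs via the CUDA driver API.
///
/// # Example
///
/// A minimal sketch of obtaining a compute client (assumes `CudaDevice::new`
/// is the index-based device constructor; adjust to the actual device API):
///
/// ```ignore
/// use cubecl_core::Runtime;
///
/// // Device 0 is the first GPU reported by the CUDA driver.
/// let device = CudaDevice::new(0);
/// let client = CudaRuntime::client(&device);
/// ```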
#[derive(Debug)]
pub struct CudaRuntime;

type Server = CudaServer;
type Channel = MutexComputeChannel<Server>;

static RUNTIME: ComputeRuntime<CudaDevice, Server, Channel> = ComputeRuntime::new();

pub type CudaCompiler = CppCompiler<CudaDialect<WmmaCompiler>>;

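/// Builds a [`ComputeClient`] for one CUDA device: initializes the driver,
/// queries the device's compute capability and hardware limits, sets up the
/// primary context, stream, and memory pools, then registers the features the
/// device supports.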
fn create_client<M: DialectWmmaCompiler<CudaDialect<M>>>(
    device: &CudaDevice,
    options: RuntimeOptions,
) -> ComputeClient<Server, Channel> {
    cudarc::driver::result::init().unwrap();
    let device_ptr = cudarc::driver::result::device::get(device.index as i32).unwrap();
    let arch_major;
    let arch_version = unsafe {
        arch_major = cudarc::driver::result::device::get_attribute(
            device_ptr,
            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
        )
        .unwrap();
        let minor = cudarc::driver::result::device::get_attribute(
            device_ptr,
            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
        )
        .unwrap();
        arch_major * 10 + minor
    } as u32;
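    // The compute capability is encoded as `major * 10 + minor`, so e.g.
    // Ampere (8.6) becomes 86 and Hopper (9.0) becomes 90; the feature gates
    // below compare against this encoded value.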
    // Alignment (in bytes) for buffer allocations made through `CudaStorage`.
    let mem_alignment = 32;

    let arch = CudaArchitecture {
        version: arch_version,
    };
    let supported_wmma_combinations = M::supported_wmma_combinations(&arch);

    // Retain the device's primary context and make it current on this thread.
    let ctx = unsafe {
        let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
        cudarc::driver::result::ctx::set_current(ctx).unwrap();
        ctx
    };

    let stream = cudarc::driver::result::stream::create(
        cudarc::driver::result::stream::StreamKind::NonBlocking,
    )
    .unwrap();
    // Total device memory, queried through the raw driver API; the returned
    // status code is not checked here.
    let max_memory = unsafe {
        let mut bytes = MaybeUninit::uninit();
        cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr);
        bytes.assume_init() as u64
    };
    let storage = CudaStorage::new(mem_alignment, stream);
    let mem_properties = MemoryDeviceProperties {
        // Cap a single memory-pool page at a quarter of total device memory.
        max_page_size: max_memory / 4,
        alignment: mem_alignment as u64,
    };

    let mut comp_opts = CompilationOptions::default();

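    // Query per-device hardware limits through the driver attribute API.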
    let hardware_props = unsafe {
        use cudarc::driver::{result::device::get_attribute, sys::CUdevice_attribute::*};
        let warp_size = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_WARP_SIZE).unwrap() as u32;
        let max_shared = get_attribute(
            device_ptr,
            CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
        )
        .unwrap() as usize;
        let max_threads =
            get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK).unwrap() as u32;
        let block_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X).unwrap();
        let block_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y).unwrap();
        let block_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z).unwrap();
        let max_cube_dim =
            CubeDim::new_3d(block_dim_x as u32, block_dim_y as u32, block_dim_z as u32);

        let grid_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X).unwrap();
        let grid_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y).unwrap();
        let grid_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z).unwrap();
        let max_cube_count =
            CubeCount::new_3d(grid_dim_x as u32, grid_dim_y as u32, grid_dim_z as u32);

        let num_streaming_multiprocessors = Some(
            get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT).unwrap() as u32,
        );
        let num_tensor_cores = tensor_cores_per_sm(arch_version);

        comp_opts.warp_size = warp_size;

        HardwareProperties {
            plane_size_min: warp_size,
            plane_size_max: warp_size,
            max_bindings: crate::device::CUDA_MAX_BINDINGS,
            max_shared_memory_size: max_shared,
            max_cube_count,
            max_units_per_cube: max_threads,
            max_cube_dim,
            num_streaming_multiprocessors,
            num_tensor_cores,
            // 8 is the smallest m/n/k dimension among the supported WMMA
            // tile shapes (e.g. 8x32x16 and 32x8x16).
            min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
                None
            } else {
                Some(8)
            },
        }
    };

    let memory_management =
        MemoryManagement::from_configuration(storage, &mem_properties, options.memory_config);

    let mut device_props = DeviceProperties::new(
        &[Feature::Plane],
        mem_properties,
        hardware_props,
        TimingMethod::System,
    );
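    // Register supported element types and features, gated on the encoded
    // compute capability (e.g. 89 = Ada, 90 = Hopper).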
    register_supported_types(&mut device_props);
    device_props.register_feature(Feature::Type(Elem::Float(FloatKind::TF32)));
    if arch_version >= 60 {
        // CC 6.0+ (Pascal): hardware f64 atomic add.
        device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F64)));
    }
    if arch_version >= 70 {
        // CC 7.0+ (Volta): f16 atomics, plus pipeline/barrier synchronization.
        device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F16)));
        device_props.register_feature(Feature::Pipeline);
        device_props.register_feature(Feature::Barrier);
        device_props.register_feature(Feature::SyncPlane);

        comp_opts.grid_constants = true;
    }

    if arch_version >= 89 {
        // CC 8.9+ (Ada): FP8 formats.
        device_props.register_feature(Feature::Type(Elem::Float(FloatKind::E4M3)));
        device_props.register_feature(Feature::Type(Elem::Float(FloatKind::E5M2)));
    }
    if arch_version >= 90 {
        // CC 9.0+ (Hopper): tensor memory accelerator and thread-block clusters.
        device_props.register_feature(Feature::Tma(TmaFeature::Base));
        device_props.register_feature(Feature::CubeCluster);
        comp_opts.supports_clusters = true;
    }

    if arch_version >= 100 {
        // CC 10.0+ (Blackwell): wide im2col TMA variant.
        device_props.register_feature(Feature::Tma(TmaFeature::Im2colWide));
    }

    if arch_major == 10 || arch_major == 12 {
        // Blackwell-family chips (CC 10.x/12.x): sub-byte float formats and
        // the UE8M0 scale type.
        device_props.register_feature(Feature::Type(Elem::Float(FloatKind::E2M1)));
        device_props.register_feature(Feature::Type(Elem::Float(FloatKind::E2M3)));
        device_props.register_feature(Feature::Type(Elem::Float(FloatKind::E3M2)));
        device_props.register_feature(Feature::Type(Elem::Float(FloatKind::UE8M0)));
    }

    device_props.register_feature(Feature::AtomicFloat(AtomicFeature::LoadStore));
    device_props.register_feature(Feature::AtomicFloat(AtomicFeature::Add));

    device_props.register_feature(Feature::Type(Elem::AtomicInt(IntKind::I32)));
    device_props.register_feature(Feature::Type(Elem::AtomicUInt(UIntKind::U32)));
    device_props.register_feature(Feature::AtomicInt(AtomicFeature::LoadStore));
    device_props.register_feature(Feature::AtomicInt(AtomicFeature::Add));
    device_props.register_feature(Feature::AtomicUInt(AtomicFeature::LoadStore));
    device_props.register_feature(Feature::AtomicUInt(AtomicFeature::Add));

    device_props.register_feature(Feature::DynamicLineSize);

    register_wmma_features(supported_wmma_combinations, &mut device_props);

    let cuda_ctx = CudaContext::new(memory_management, comp_opts, stream, ctx, arch);
    let server = CudaServer::new(mem_alignment, cuda_ctx);
    ComputeClient::new(MutexComputeChannel::new(server), device_props, ())
}

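/// Tensor cores per streaming multiprocessor for a given compute capability:
/// 8 on Volta/Turing, 4 on the listed newer architectures, `None` for
/// capabilities without tensor cores or not yet listed here.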
fn tensor_cores_per_sm(version: u32) -> Option<u32> {
    match version {
        70 | 75 => Some(8),
        80 | 86 | 89 | 90 | 91 | 92 | 100 => Some(4),
        _ => None,
    }
}

impl Runtime for CudaRuntime {
    type Compiler = CudaCompiler;
    type Server = CudaServer;

    type Channel = MutexComputeChannel<CudaServer>;
    type Device = CudaDevice;

    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
        RUNTIME.client(device, move || {
            create_client::<WmmaCompiler>(device, RuntimeOptions::default())
        })
    }

    fn device_id(device: &Self::Device) -> DeviceId {
        DeviceId::new(0, device.index as u32)
    }

    fn name(_client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str {
        "cuda"
    }

    fn require_array_lengths() -> bool {
        true
    }

    fn supported_line_sizes() -> &'static [u8] {
        &[8, 4, 2, 1]
    }

    fn max_cube_count() -> (u32, u32, u32) {
        // CUDA grid limits: x up to 2^31 - 1, y and z up to 65535.
        (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
    }

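    // A tensor is directly readable when its strides are in non-increasing
    // order, the innermost dimension is tightly packed, and every dimension
    // except the second-innermost is contiguous. Leaving the row pitch
    // unconstrained presumably permits padded/pitched layouts that 2D copies
    // can still handle.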
    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
        let rank = shape.len();
        if strides[rank - 1] != 1 {
            return false;
        }
        if rank <= 1 {
            return true;
        }

        let mut sorted = strides.to_vec();
        sorted.sort();
        sorted.reverse();

        if sorted != strides {
            return false;
        }

        for i in 0..rank - 2 {
            if strides[i] != shape[i + 1] * strides[i + 1] {
                return false;
            }
        }
        true
    }
}