use crate::{
WmmaCompiler,
compute::{CudaServer, context::CudaContext},
device::CudaDevice,
};
use cubecl_common::{
device::{Device, DeviceService},
profile::TimingMethod,
};
use cubecl_core::{
MemoryConfiguration, Runtime,
device::{DeviceId, ServerUtilitiesHandle},
ir::{
BarrierLevel, ContiguousElements, DeviceProperties, ElemType, FloatKind,
HardwareProperties, MatrixLayout, MemoryDeviceProperties, MmaProperties, OpaqueType,
SemanticType, StorageType, TargetProperties, Type, VectorSize,
features::{AtomicUsage, Plane, Tma, TypeUsage},
},
server::ServerUtilities,
zspace::{Shape, Strides, striding::has_pitched_row_major_strides},
};
use cubecl_cpp::{
ComputeKernel, DialectWmmaCompiler,
cuda::{CudaDialect, arch::CudaArchitecture, mma::contiguous_elements_cuda},
register_supported_types,
shared::{
CompilationOptions, CppCompiler, CppSupportedFeatures, register_mma_features,
register_scaled_mma_features, register_wmma_features,
},
};
use cubecl_runtime::{
allocator::PitchedMemoryLayoutPolicy, client::ComputeClient, logging::ServerLogger,
};
use cudarc::driver::sys::{CUDA_VERSION, cuDeviceTotalMem_v2};
use std::{mem::MaybeUninit, sync::Arc};
/// User-configurable options applied when a [`CudaServer`] is initialized.
#[derive(Default)]
pub struct RuntimeOptions {
    // Memory-pool configuration forwarded to the server's allocator.
    pub memory_config: MemoryConfiguration,
}
/// Marker type implementing [`Runtime`] for NVIDIA GPUs through the CUDA
/// driver API (via `cudarc`), compiling kernels with [`CudaCompiler`].
#[derive(Debug, Clone)]
pub struct CudaRuntime;
impl DeviceService for CudaServer {
    /// Builds a fully initialized CUDA server for the device identified by
    /// `device_id`.
    ///
    /// Initializes the driver, queries the device's compute capability and
    /// hardware limits, registers the type/feature set supported by the
    /// detected architecture, and assembles a [`CudaServer`] around a retained
    /// primary context.
    ///
    /// # Panics
    /// Panics if the CUDA driver cannot be initialized or any device query
    /// fails (`unwrap` on the driver results).
    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
        let options = RuntimeOptions::default();
        let device = CudaDevice::from_id(device_id);
        cudarc::driver::result::init().unwrap();
        let device_index = device.index as i32;
        let device_ptr = cudarc::driver::result::device::get(device_index).unwrap();
        let arch_major;
        // Compute capability encoded as `major * 10 + minor` (e.g. 89 for sm_89).
        let arch_version = unsafe {
            arch_major = cudarc::driver::result::device::get_attribute(
                device_ptr,
                cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
            )
            .unwrap();
            let minor = cudarc::driver::result::device::get_attribute(
                device_ptr,
                cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
            )
            .unwrap();
            arch_major * 10 + minor
        } as u32;
        let mem_alignment = 512;
        let arch = CudaArchitecture {
            version: arch_version,
        };
        // Matmul instruction shapes/types available on this architecture.
        let supported_wmma_combinations = WmmaCompiler::supported_wmma_combinations(&arch);
        let supported_mma_combinations = WmmaCompiler::supported_mma_combinations(&arch);
        let supported_scaled_mma_combinations =
            WmmaCompiler::supported_scaled_mma_combinations(&arch);
        // Retain the device's primary context and make it current on this thread.
        let ctx = unsafe {
            let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
            cudarc::driver::result::ctx::set_current(ctx).unwrap();
            ctx
        };
        // Total device memory. The status code MUST be checked before
        // `assume_init`: on failure `bytes` would remain uninitialized and
        // reading it would be undefined behavior.
        let max_memory = unsafe {
            let mut bytes = MaybeUninit::uninit();
            let status = cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr);
            assert_eq!(
                status as u32, 0,
                "cuDeviceTotalMem_v2 failed: {status:?}"
            );
            bytes.assume_init() as u64
        };
        let mem_properties = MemoryDeviceProperties {
            max_page_size: max_memory / 4,
            alignment: mem_alignment as u64,
        };
        let mut comp_opts = CompilationOptions {
            supports_features: CppSupportedFeatures {
                fast_math: true,
                ..Default::default()
            },
            ..Default::default()
        };
        // Query per-device hardware limits straight from the driver.
        let hardware_props = unsafe {
            use cudarc::driver::{result::device::get_attribute, sys::CUdevice_attribute::*};
            let warp_size =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_WARP_SIZE).unwrap() as u32;
            // Opt-in maximum, which can exceed the default per-block limit.
            let max_shared = get_attribute(
                device_ptr,
                CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
            )
            .unwrap() as usize;
            let max_threads = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
                .unwrap() as u32;
            let block_dim_x =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X).unwrap();
            let block_dim_y =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y).unwrap();
            let block_dim_z =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z).unwrap();
            let max_cube_dim = (block_dim_x as u32, block_dim_y as u32, block_dim_z as u32);
            let grid_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X).unwrap();
            let grid_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y).unwrap();
            let grid_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z).unwrap();
            let max_cube_count = (grid_dim_x as u32, grid_dim_y as u32, grid_dim_z as u32);
            let num_streaming_multiprocessors = Some(
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT).unwrap() as u32,
            );
            let num_tensor_cores = tensor_cores_per_sm(arch_version);
            comp_opts.warp_size = warp_size;
            HardwareProperties {
                load_width: 128,
                plane_size_min: warp_size,
                plane_size_max: warp_size,
                max_bindings: crate::device::CUDA_MAX_BINDINGS,
                max_shared_memory_size: max_shared,
                max_cube_count,
                max_units_per_cube: max_threads,
                max_cube_dim,
                num_streaming_multiprocessors,
                num_tensor_cores,
                min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
                    None
                } else {
                    Some(8)
                },
                num_cpu_cores: None,
                max_vector_size: VectorSize::MAX,
            }
        };
        let mut device_props = DeviceProperties::new(
            Default::default(),
            mem_properties.clone(),
            hardware_props,
            TimingMethod::System,
        );
        register_supported_types(&mut device_props);
        device_props.register_type_usage(ElemType::Float(FloatKind::TF32), TypeUsage::Conversion);
        // sm_60+: double-precision atomic add.
        if arch_version >= 60 {
            device_props.register_atomic_type_usage(
                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F64))),
                AtomicUsage::Add | AtomicUsage::LoadStore,
            );
        }
        // sm_70+: f16 atomics (scalar and 2-wide), pipelines, barriers,
        // plane-level sync, and grid constants.
        if arch_version >= 70 {
            device_props.register_atomic_type_usage(
                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F16))),
                AtomicUsage::Add | AtomicUsage::LoadStore,
            );
            device_props.register_atomic_type_usage(
                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F16))).with_vector_size(2),
                AtomicUsage::Add | AtomicUsage::LoadStore,
            );
            device_props.register_semantic_type(SemanticType::Pipeline);
            device_props
                .register_type_usage(OpaqueType::Barrier(BarrierLevel::Unit), TypeUsage::Buffer);
            device_props
                .register_type_usage(OpaqueType::Barrier(BarrierLevel::Cube), TypeUsage::Buffer);
            device_props.features.plane.insert(Plane::Sync);
            comp_opts.supports_features.grid_constants = true;
        }
        // sm_75+: ldmatrix for f16/bf16; fast tanh needs CUDA 12.8+.
        if arch_version >= 75 {
            device_props
                .features
                .matmul
                .ldmatrix
                .insert(ElemType::Float(FloatKind::F16).into());
            device_props
                .features
                .matmul
                .ldmatrix
                .insert(ElemType::Float(FloatKind::BF16).into());
            comp_opts.supports_features.fast_tanh = CUDA_VERSION >= 12080;
        }
        // sm_80+: asynchronous global->shared copies.
        if arch_version >= 80 {
            device_props.features.copy_async = true;
        }
        // sm_89+: FP8 (E4M3/E5M2) storage and conversion.
        if arch_version >= 89 {
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E4M3),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E5M2),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
        }
        // sm_90+: TMA, thread-block clusters, stmatrix, vectorized f32 atomics.
        if arch_version >= 90 {
            device_props.features.tma.insert(Tma::Base);
            device_props.register_semantic_type(SemanticType::TensorMap);
            device_props.features.cube_cluster = true;
            comp_opts.supports_features.clusters = true;
            comp_opts.supports_features.elect_sync = true;
            device_props
                .features
                .matmul
                .stmatrix
                .insert(ElemType::Float(FloatKind::F16).into());
            device_props
                .features
                .matmul
                .stmatrix
                .insert(ElemType::Float(FloatKind::BF16).into());
            device_props.register_atomic_type_usage(
                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F32))).with_vector_size(2),
                AtomicUsage::LoadStore | AtomicUsage::Add,
            );
            device_props.register_atomic_type_usage(
                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F32))).with_vector_size(4),
                AtomicUsage::LoadStore | AtomicUsage::Add,
            );
        }
        if arch_version >= 100 {
            device_props.features.tma.insert(Tma::Im2colWide);
        }
        // Compute-capability majors 10/11/12: sub-byte float formats
        // (E2M1/E2M3/E3M2) and the UE8M0 scale type.
        if arch_major == 10 || arch_major == 11 || arch_major == 12 {
            device_props
                .register_type_usage(ElemType::Float(FloatKind::E2M1), TypeUsage::Conversion);
            device_props.register_type_usage(
                StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E2M3),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E3M2),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::UE8M0),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            if CUDA_VERSION >= 12080 {
                device_props.features.tma.insert(Tma::SwizzleAtomicity);
            }
        }
        device_props.features.memory_reinterpret = true;
        device_props.features.alignment = true;
        device_props.features.plane.insert(Plane::Ops);
        device_props
            .features
            .plane
            .insert(Plane::NonUniformControlFlow);
        register_wmma_features(supported_wmma_combinations, &mut device_props);
        register_mma_features(supported_mma_combinations, &mut device_props);
        register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);
        let cuda_ctx = CudaContext::new(comp_opts, device_props.clone(), ctx, arch);
        let logger = Arc::new(ServerLogger::default());
        let policy = PitchedMemoryLayoutPolicy::new(device_props.memory.alignment as usize);
        let utilities = ServerUtilities::new(device_props, logger, (), policy);
        CudaServer::new(
            cuda_ctx,
            mem_properties,
            options.memory_config,
            mem_alignment,
            device_id,
            utilities,
        )
    }
    /// Returns a handle to the server's utilities.
    // NOTE(review): `self.utilities()` resolves to an inherent
    // `CudaServer::utilities` (inherent methods take precedence over trait
    // methods), not to this trait method — presumably not recursive; confirm
    // the inherent method exists on `CudaServer`.
    fn utilities(&self) -> ServerUtilitiesHandle {
        self.utilities() as ServerUtilitiesHandle
    }
}
/// C++ compiler specialized for the CUDA dialect.
pub type CudaCompiler = CppCompiler<CudaDialect<WmmaCompiler>>;
/// Compiled kernel type produced by [`CudaCompiler`].
pub type CudaComputeKernel = ComputeKernel<CudaDialect<WmmaCompiler>>;
/// Number of tensor cores per streaming multiprocessor for a compute
/// capability encoded as `major * 10 + minor`, or `None` for versions this
/// table does not cover.
fn tensor_cores_per_sm(version: u32) -> Option<u32> {
    // Known per-SM tensor-core counts, keyed by compute capability.
    const EIGHT_PER_SM: [u32; 2] = [70, 75];
    const FOUR_PER_SM: [u32; 7] = [80, 86, 89, 90, 91, 92, 100];
    if EIGHT_PER_SM.contains(&version) {
        Some(8)
    } else if FOUR_PER_SM.contains(&version) {
        Some(4)
    } else {
        None
    }
}
impl Runtime for CudaRuntime {
    type Compiler = CudaCompiler;
    type Server = CudaServer;
    type Device = CudaDevice;

    /// Loads the compute client bound to `device`.
    fn client(device: &Self::Device) -> ComputeClient<Self> {
        ComputeClient::load(device)
    }

    /// Human-readable runtime name.
    fn name(_client: &ComputeClient<Self>) -> &'static str {
        "cuda"
    }

    fn require_array_lengths() -> bool {
        true
    }

    /// Maximum launchable grid: (2^31 - 1, 65535, 65535).
    fn max_cube_count() -> (u32, u32, u32) {
        let max_yz = u16::MAX as u32;
        (i32::MAX as u32, max_yz, max_yz)
    }

    /// A tensor is directly readable when its strides are pitched row-major.
    fn can_read_tensor(shape: &Shape, strides: &Strides) -> bool {
        has_pitched_row_major_strides(shape, strides)
    }

    /// Register-level MMA layout description for the CUDA target.
    fn target_properties() -> TargetProperties {
        let mma = MmaProperties {
            register_size_bits: 32,
            const_plane_size: 32,
            register_layout_a: MatrixLayout::RowMajor,
            register_layout_b: MatrixLayout::ColMajor,
            register_layout_acc: MatrixLayout::RowMajor,
            register_duplication_a: 1,
            register_duplication_b: 1,
            register_duplication_acc: 1,
            contiguous_elements: ContiguousElements::new(contiguous_elements_cuda),
        };
        TargetProperties { mma }
    }

    /// Lists one `DeviceId` per CUDA device visible to the driver; an
    /// unavailable driver yields an empty list rather than an error.
    fn enumerate_devices(
        _: u16,
        _: &<Self::Server as cubecl_core::server::ComputeServer>::Info,
    ) -> Vec<cubecl_core::device::DeviceId> {
        let count = cudarc::driver::CudaContext::device_count().unwrap_or(0) as usize;
        let mut ids = Vec::with_capacity(count);
        for index in 0..count {
            ids.push(DeviceId {
                type_id: 0,
                index_id: index as u16,
            });
        }
        ids
    }
}