// cubecl-cuda 0.10.0
//
// CUDA runtime for CubeCL
// Documentation
use crate::{
    WmmaCompiler,
    compute::{CudaServer, context::CudaContext},
    device::CudaDevice,
};
use cubecl_common::{
    device::{Device, DeviceService},
    profile::TimingMethod,
};
use cubecl_core::{
    MemoryConfiguration, Runtime,
    device::{DeviceId, ServerUtilitiesHandle},
    ir::{
        BarrierLevel, ContiguousElements, DeviceProperties, ElemType, FloatKind,
        HardwareProperties, MatrixLayout, MemoryDeviceProperties, MmaProperties, OpaqueType,
        SemanticType, StorageType, TargetProperties, Type, VectorSize,
        features::{AtomicUsage, Plane, Tma, TypeUsage},
    },
    server::ServerUtilities,
    zspace::{Shape, Strides, striding::has_pitched_row_major_strides},
};
use cubecl_cpp::{
    ComputeKernel, DialectWmmaCompiler,
    cuda::{CudaDialect, arch::CudaArchitecture, mma::contiguous_elements_cuda},
    register_supported_types,
    shared::{
        CompilationOptions, CppCompiler, CppSupportedFeatures, register_mma_features,
        register_scaled_mma_features, register_wmma_features,
    },
};
use cubecl_runtime::{
    allocator::PitchedMemoryLayoutPolicy, client::ComputeClient, logging::ServerLogger,
};
use cudarc::driver::sys::{CUDA_VERSION, cuDeviceTotalMem_v2};
use std::{mem::MaybeUninit, sync::Arc};

/// Options configuring the CUDA runtime.
///
/// `CudaServer::init` builds the server from `RuntimeOptions::default()` and forwards
/// `memory_config` to `CudaServer::new`.
#[derive(Default)]
pub struct RuntimeOptions {
    /// Configures the memory management.
    pub memory_config: MemoryConfiguration,
}

/// The CUDA runtime.
///
/// This is a unit struct that carries no data; it is the type on which the CUDA
/// [`Runtime`] implementation (compiler, server and device types) is defined.
#[derive(Debug, Clone)]
pub struct CudaRuntime;

impl DeviceService for CudaServer {
    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
        let options = RuntimeOptions::default();
        let device = CudaDevice::from_id(device_id);

        // To get the supported WMMA features, and memory properties, we have to initialize the server immediately.
        cudarc::driver::result::init().unwrap();
        let device_index = device.index as i32;
        let device_ptr = cudarc::driver::result::device::get(device_index).unwrap();
        let arch_major;
        // SAFETY: Calling CUDA driver FFI to query compute capability attributes.
        // `device_ptr` is a valid device handle obtained from `cudarc::driver::result::device::get`.
        let arch_version = unsafe {
            arch_major = cudarc::driver::result::device::get_attribute(
            device_ptr,
            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
        )
        .unwrap();
            let minor = cudarc::driver::result::device::get_attribute(
            device_ptr,
            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
        )
        .unwrap();
            arch_major * 10 + minor
        } as u32;

        // This is the alignment returned by `cuMallocPitched`, so it's the one considered optimal
        // for row alignment by CUDA. This hasn't changed since at least the GTX 700 series.
        // Querying texture row align is a heuristic, but also not guaranteed to be the same.
        let mem_alignment = 512;

        // Ask the wmma compiler for its supported combinations
        let arch = CudaArchitecture {
            version: arch_version,
        };
        let supported_wmma_combinations = WmmaCompiler::supported_wmma_combinations(&arch);
        let supported_mma_combinations = WmmaCompiler::supported_mma_combinations(&arch);
        let supported_scaled_mma_combinations =
            WmmaCompiler::supported_scaled_mma_combinations(&arch);

        // SAFETY: `device_ptr` is a valid CUDA device. `primary_ctx::retain` returns the
        // primary context which is then set as current for the calling thread.
        let ctx = unsafe {
            let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
            cudarc::driver::result::ctx::set_current(ctx).unwrap();
            ctx
        };

        // SAFETY: `device_ptr` is valid. `cuDeviceTotalMem_v2` writes the total device memory
        // into the `MaybeUninit`, making `assume_init()` valid on success.
        let max_memory = unsafe {
            let mut bytes = MaybeUninit::uninit();
            cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr);
            bytes.assume_init() as u64
        };
        let mem_properties = MemoryDeviceProperties {
            max_page_size: max_memory / 4,
            alignment: mem_alignment as u64,
        };

        let mut comp_opts = CompilationOptions {
            supports_features: CppSupportedFeatures {
                fast_math: true,
                ..Default::default()
            },
            ..Default::default()
        };

        // SAFETY: `device_ptr` is a valid CUDA device. All `get_attribute` calls query
        // read-only device properties via the CUDA driver API.
        let hardware_props = unsafe {
            use cudarc::driver::{result::device::get_attribute, sys::CUdevice_attribute::*};
            let warp_size =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_WARP_SIZE).unwrap() as u32;
            let max_shared = get_attribute(
                device_ptr,
                CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
            )
            .unwrap() as usize;
            let max_threads = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
                .unwrap() as u32;
            let block_dim_x =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X).unwrap();
            let block_dim_y =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y).unwrap();
            let block_dim_z =
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z).unwrap();
            let max_cube_dim = (block_dim_x as u32, block_dim_y as u32, block_dim_z as u32);

            let grid_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X).unwrap();
            let grid_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y).unwrap();
            let grid_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z).unwrap();
            let max_cube_count = (grid_dim_x as u32, grid_dim_y as u32, grid_dim_z as u32);

            let num_streaming_multiprocessors = Some(
                get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT).unwrap() as u32,
            );
            let num_tensor_cores = tensor_cores_per_sm(arch_version);

            comp_opts.warp_size = warp_size;

            HardwareProperties {
                load_width: 128,
                plane_size_min: warp_size,
                plane_size_max: warp_size,
                max_bindings: crate::device::CUDA_MAX_BINDINGS,
                max_shared_memory_size: max_shared,
                max_cube_count,
                max_units_per_cube: max_threads,
                max_cube_dim,
                num_streaming_multiprocessors,
                num_tensor_cores,
                min_tensor_cores_dim: if supported_wmma_combinations.is_empty() {
                    None
                } else {
                    Some(8)
                },
                num_cpu_cores: None,
                max_vector_size: VectorSize::MAX,
            }
        };

        let mut device_props = DeviceProperties::new(
            Default::default(),
            mem_properties.clone(),
            hardware_props,
            TimingMethod::System,
        );
        register_supported_types(&mut device_props);
        device_props.register_type_usage(ElemType::Float(FloatKind::TF32), TypeUsage::Conversion);
        if arch_version >= 60 {
            device_props.register_atomic_type_usage(
                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F64))),
                AtomicUsage::Add | AtomicUsage::LoadStore,
            );
        }
        if arch_version >= 70 {
            device_props.register_atomic_type_usage(
                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F16))),
                AtomicUsage::Add | AtomicUsage::LoadStore,
            );
            device_props.register_atomic_type_usage(
                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F16))).with_vector_size(2),
                AtomicUsage::Add | AtomicUsage::LoadStore,
            );
            device_props.register_semantic_type(SemanticType::Pipeline);
            device_props
                .register_type_usage(OpaqueType::Barrier(BarrierLevel::Unit), TypeUsage::Buffer);
            device_props
                .register_type_usage(OpaqueType::Barrier(BarrierLevel::Cube), TypeUsage::Buffer);
            device_props.features.plane.insert(Plane::Sync);
            comp_opts.supports_features.grid_constants = true;
        }

        if arch_version >= 75 {
            device_props
                .features
                .matmul
                .ldmatrix
                .insert(ElemType::Float(FloatKind::F16).into());
            device_props
                .features
                .matmul
                .ldmatrix
                .insert(ElemType::Float(FloatKind::BF16).into());
            comp_opts.supports_features.fast_tanh = CUDA_VERSION >= 12080;
        }

        if arch_version >= 80 {
            device_props.features.copy_async = true;
        }

        // NOTE: I commented that since I observed synchronisation issues with atomic add for bf16.
        // if arch.get_version() >= 80 {
        //     device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16)));
        // }

        if arch_version >= 89 {
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E4M3),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E5M2),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
        }
        if arch_version >= 90 {
            device_props.features.tma.insert(Tma::Base);
            device_props.register_semantic_type(SemanticType::TensorMap);
            device_props.features.cube_cluster = true;
            comp_opts.supports_features.clusters = true;
            comp_opts.supports_features.elect_sync = true;
            device_props
                .features
                .matmul
                .stmatrix
                .insert(ElemType::Float(FloatKind::F16).into());
            device_props
                .features
                .matmul
                .stmatrix
                .insert(ElemType::Float(FloatKind::BF16).into());
            device_props.register_atomic_type_usage(
                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F32))).with_vector_size(2),
                AtomicUsage::LoadStore | AtomicUsage::Add,
            );
            device_props.register_atomic_type_usage(
                Type::new(StorageType::Atomic(ElemType::Float(FloatKind::F32))).with_vector_size(4),
                AtomicUsage::LoadStore | AtomicUsage::Add,
            );
        }

        if arch_version >= 100 {
            device_props.features.tma.insert(Tma::Im2colWide);
            // Breaks swizzle so disable for now and fix in a PR specifically for this
            // if CUDA_VERSION >= 12090 {
            //     device_props.hardware.load_width = 256;
            // }
        }

        // NOTE: FP6/FP4 is explicitly not marked as forward compatible, but is compatible within a
        // major version. Try to keep this up to date with new arch major revisions if they also
        // implement it.
        if arch_major == 10 || arch_major == 11 || arch_major == 12 {
            device_props
                .register_type_usage(ElemType::Float(FloatKind::E2M1), TypeUsage::Conversion);
            device_props.register_type_usage(
                StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E2M3),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::E3M2),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );
            device_props.register_type_usage(
                ElemType::Float(FloatKind::UE8M0),
                TypeUsage::Conversion | TypeUsage::Buffer,
            );

            if CUDA_VERSION >= 12080 {
                device_props.features.tma.insert(Tma::SwizzleAtomicity);
            }
        }

        device_props.features.memory_reinterpret = true;
        device_props.features.alignment = true;
        device_props.features.plane.insert(Plane::Ops);
        device_props
            .features
            .plane
            .insert(Plane::NonUniformControlFlow);

        register_wmma_features(supported_wmma_combinations, &mut device_props);
        register_mma_features(supported_mma_combinations, &mut device_props);
        register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props);

        let cuda_ctx = CudaContext::new(comp_opts, device_props.clone(), ctx, arch);
        let logger = Arc::new(ServerLogger::default());
        let policy = PitchedMemoryLayoutPolicy::new(device_props.memory.alignment as usize);
        let utilities = ServerUtilities::new(device_props, logger, (), policy);

        CudaServer::new(
            cuda_ctx,
            mem_properties,
            options.memory_config,
            mem_alignment,
            device_id,
            utilities,
        )
    }

    fn utilities(&self) -> ServerUtilitiesHandle {
        self.utilities() as ServerUtilitiesHandle
    }
}

/// The C++ compiler specialized for the CUDA dialect with the crate's [`WmmaCompiler`].
pub type CudaCompiler = CppCompiler<CudaDialect<WmmaCompiler>>;
/// A compute kernel specialized for the CUDA dialect with the crate's [`WmmaCompiler`].
pub type CudaComputeKernel = ComputeKernel<CudaDialect<WmmaCompiler>>;

/// Looks up the number of tensor cores per streaming multiprocessor.
///
/// `version` is the compute capability encoded as `major * 10 + minor`. Returns `None`
/// when the version is not one of the listed architectures.
fn tensor_cores_per_sm(version: u32) -> Option<u32> {
    // Volta, Turing
    const EIGHT_PER_SM: [u32; 2] = [70, 75];
    // Ampere, Hopper, Blackwell
    const FOUR_PER_SM: [u32; 7] = [80, 86, 89, 90, 91, 92, 100];

    if EIGHT_PER_SM.contains(&version) {
        Some(8)
    } else if FOUR_PER_SM.contains(&version) {
        Some(4)
    } else {
        // Unknown or unsupported architecture
        None
    }
}

impl Runtime for CudaRuntime {
    type Compiler = CudaCompiler;
    type Server = CudaServer;
    type Device = CudaDevice;

    /// Loads the compute client for `device`.
    fn client(device: &Self::Device) -> ComputeClient<Self> {
        ComputeClient::load(device)
    }

    /// Name of this runtime; always `"cuda"`, independent of the client.
    fn name(_client: &ComputeClient<Self>) -> &'static str {
        "cuda"
    }

    /// Always `true` for the CUDA runtime.
    fn require_array_lengths() -> bool {
        true
    }

    /// Returns `(i32::MAX, u16::MAX, u16::MAX)`, each widened to `u32`.
    fn max_cube_count() -> (u32, u32, u32) {
        let x_limit = i32::MAX as u32;
        let yz_limit = u32::from(u16::MAX);
        (x_limit, yz_limit, yz_limit)
    }

    /// A tensor is readable when its strides are pitched row-major for the given shape.
    fn can_read_tensor(shape: &Shape, strides: &Strides) -> bool {
        has_pitched_row_major_strides(shape, strides)
    }

    /// Describes the MMA register properties of the CUDA target.
    fn target_properties() -> TargetProperties {
        let mma = MmaProperties {
            register_size_bits: 32,
            const_plane_size: 32,
            register_layout_a: MatrixLayout::RowMajor,
            register_layout_b: MatrixLayout::ColMajor,
            register_layout_acc: MatrixLayout::RowMajor,
            register_duplication_a: 1,
            register_duplication_b: 1,
            register_duplication_acc: 1,
            contiguous_elements: ContiguousElements::new(contiguous_elements_cuda),
        };

        TargetProperties { mma }
    }

    /// Lists one device id (with `type_id` 0) per device reported by the CUDA driver.
    ///
    /// If the device count cannot be queried, the list is empty.
    fn enumerate_devices(
        _: u16,
        _: &<Self::Server as cubecl_core::server::ComputeServer>::Info,
    ) -> Vec<cubecl_core::device::DeviceId> {
        let device_count = match cudarc::driver::CudaContext::device_count() {
            Ok(reported) => reported as usize,
            Err(_) => 0,
        };

        let mut device_ids = Vec::new();
        for index in 0..device_count {
            device_ids.push(DeviceId {
                type_id: 0,
                index_id: index as u16,
            });
        }
        device_ids
    }
}