//! boostr 0.1.0
//!
//! ML framework built on numr — attention, quantization, model architectures.
//! CUDA INT4 GEMM dispatch helpers (AWQ, GPTQ, Marlin)

use crate::error::{Error, Result};
use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::LaunchConfig;
use numr::runtime::Device;
use numr::runtime::cuda::{CudaClient, CudaRuntime};
use numr::tensor::Tensor;

use super::kernels::{self, INT4_GEMM_GPTQ_MODULE, INT4_GEMM_MODULE, MARLIN_GEMM_MODULE};

/// Row-count threshold for kernel selection: launches with `m <=` this value
/// take the GEMV (skinny-matrix) path, larger `m` takes the tiled GEMM path.
const GEMV_THRESHOLD: u32 = 4;

/// Dispatch the CUDA AWQ INT4 matmul: `output[m, n] = input[m, k] × dequant(qweight)`.
///
/// Selects one of two kernels from [`INT4_GEMM_MODULE`] based on `m`:
/// - `m <= GEMV_THRESHOLD` → `int4_gemv_f32` (skinny-matrix GEMV path),
/// - otherwise → `int4_gemm_f32` (32×32 tiled GEMM path).
///
/// `group_size` is the AWQ quantization group size shared by `scales`/`zeros`.
/// NOTE(review): the GEMV grid computes `n / 8` packed columns, so `n` is
/// assumed to be a multiple of 8 — confirm callers guarantee this.
///
/// # Errors
///
/// Returns [`Error::QuantError`] if module loading, kernel lookup, or the
/// kernel launch fails.
#[allow(clippy::too_many_arguments)]
pub fn launch_int4_gemm(
    client: &CudaClient,
    input: &Tensor<CudaRuntime>,
    qweight: &Tensor<CudaRuntime>,
    scales: &Tensor<CudaRuntime>,
    zeros: &Tensor<CudaRuntime>,
    output: &Tensor<CudaRuntime>,
    m: u32,
    k: u32,
    n: u32,
    group_size: u32,
) -> Result<()> {
    let device_index = input.device().id();
    let module = kernels::get_or_load_module(client.context(), device_index, INT4_GEMM_MODULE)?;

    // Both paths pass the same device pointers; hoist them out of the branch
    // (matches the structure of `launch_int4_gemm_gptq`).
    let input_ptr = input.ptr();
    let qweight_ptr = qweight.ptr();
    let scales_ptr = scales.ptr();
    let zeros_ptr = zeros.ptr();
    let output_ptr = output.ptr();

    if m <= GEMV_THRESHOLD {
        tracing::debug!(
            m,
            k,
            n,
            path = "awq_int4_gemv",
            "CUDA AWQ kernel: INT4 GEMV (optimized)"
        );
        // GEMV path: 4 warps/block (128 threads), each warp handles 8 cols (one packed u32)
        // 32 output columns per block. Shared memory caches input row.
        let func = kernels::get_kernel_function(&module, "int4_gemv_f32")?;
        // 8 INT4 values are packed per u32 column word.
        let n_packed = n / 8;
        let cfg = LaunchConfig {
            grid_dim: (n_packed.div_ceil(4), m, 1),
            block_dim: (128, 1, 1),
            shared_mem_bytes: 0,
        };

        // SAFETY: the argument list must match the `int4_gemv_f32` kernel
        // signature exactly, and the caller guarantees the tensors are live
        // device allocations of the shapes implied by (m, k, n, group_size).
        unsafe {
            let mut builder = client.stream().launch_builder(&func);
            builder.arg(&input_ptr);
            builder.arg(&qweight_ptr);
            builder.arg(&scales_ptr);
            builder.arg(&zeros_ptr);
            builder.arg(&output_ptr);
            builder.arg(&m);
            builder.arg(&k);
            builder.arg(&n);
            builder.arg(&group_size);
            builder.launch(cfg).map_err(|e| Error::QuantError {
                reason: format!("CUDA int4_gemv launch failed: {:?}", e),
            })?;
        }
    } else {
        tracing::debug!(
            m,
            k,
            n,
            path = "awq_int4_gemm",
            "CUDA AWQ kernel: INT4 tiled GEMM (optimized)"
        );
        // Tiled GEMM path: BM=32, BN=32, BK=32
        // Block: (32, 4) = 128 threads, each thread handles 8 rows
        let func = kernels::get_kernel_function(&module, "int4_gemm_f32")?;
        let cfg = LaunchConfig {
            grid_dim: (n.div_ceil(32), m.div_ceil(32), 1),
            block_dim: (32, 4, 1),
            shared_mem_bytes: 0,
        };

        // SAFETY: same contract as the GEMV path above, against the
        // `int4_gemm_f32` kernel signature.
        unsafe {
            let mut builder = client.stream().launch_builder(&func);
            builder.arg(&input_ptr);
            builder.arg(&qweight_ptr);
            builder.arg(&scales_ptr);
            builder.arg(&zeros_ptr);
            builder.arg(&output_ptr);
            builder.arg(&m);
            builder.arg(&k);
            builder.arg(&n);
            builder.arg(&group_size);
            builder.launch(cfg).map_err(|e| Error::QuantError {
                reason: format!("CUDA int4_gemm launch failed: {:?}", e),
            })?;
        }
    }
    Ok(())
}

/// Dispatch the CUDA GPTQ INT4 matmul: `output[m, n] = input[m, k] × dequant(qweight)`,
/// applying the GPTQ group-index permutation `g_idx`.
///
/// Selects one of two kernels from [`INT4_GEMM_GPTQ_MODULE`] based on `m`:
/// - `m <= GEMV_THRESHOLD` → `int4_gemv_gptq_f32`, which accumulates partial
///   K-tile sums with `atomicAdd`, so the output buffer is zeroed first via an
///   async memset on the same stream;
/// - otherwise → `int4_gemm_gptq_f32` (32×32 tiled GEMM).
///
/// # Errors
///
/// Returns [`Error::QuantError`] if module loading, kernel lookup, the output
/// memset, or the kernel launch fails.
#[allow(clippy::too_many_arguments)]
pub fn launch_int4_gemm_gptq(
    client: &CudaClient,
    input: &Tensor<CudaRuntime>,
    qweight: &Tensor<CudaRuntime>,
    qzeros: &Tensor<CudaRuntime>,
    scales: &Tensor<CudaRuntime>,
    g_idx: &Tensor<CudaRuntime>,
    output: &Tensor<CudaRuntime>,
    m: u32,
    k: u32,
    n: u32,
) -> Result<()> {
    let device_index = input.device().id();
    let module =
        kernels::get_or_load_module(client.context(), device_index, INT4_GEMM_GPTQ_MODULE)?;

    let input_ptr = input.ptr();
    let qweight_ptr = qweight.ptr();
    let qzeros_ptr = qzeros.ptr();
    let scales_ptr = scales.ptr();
    let g_idx_ptr = g_idx.ptr();
    let output_ptr = output.ptr();

    if m <= GEMV_THRESHOLD {
        tracing::debug!(
            m,
            k,
            n,
            path = "gptq_int4_gemv",
            "CUDA GPTQ kernel: INT4 GEMV (optimized)"
        );
        // GEMV: 128 threads (one per output col), tiled over K in chunks of 128
        // Grid: (ceil(K/128), ceil(N/128), M). Uses atomicAdd → output must be zeroed.
        let func = kernels::get_kernel_function(&module, "int4_gemv_gptq_f32")?;
        let k_blocks = k.div_ceil(128);
        let cfg = LaunchConfig {
            grid_dim: (k_blocks, n.div_ceil(128), m),
            block_dim: (128, 1, 1),
            shared_mem_bytes: 0,
        };

        // Zero output before atomicAdd accumulation. Widen each factor to
        // usize *before* multiplying: `m * n` in u32 can overflow for large
        // shapes (debug panic / release wrap → undersized memset).
        let output_bytes = (m as usize) * (n as usize) * std::mem::size_of::<f32>();
        // SAFETY: `output` is a live device allocation on this stream's
        // device, large enough for `m * n` f32 elements (caller contract).
        unsafe {
            cudarc::driver::result::memset_d8_async(
                output_ptr,
                0,
                output_bytes,
                client.stream().cu_stream(),
            )
            .map_err(|e| Error::QuantError {
                reason: format!("CUDA memset failed: {:?}", e),
            })?;
        }

        // SAFETY: the argument list must match the `int4_gemv_gptq_f32`
        // kernel signature exactly; tensors are live device allocations of
        // the shapes implied by (m, k, n).
        unsafe {
            let mut builder = client.stream().launch_builder(&func);
            builder.arg(&input_ptr);
            builder.arg(&qweight_ptr);
            builder.arg(&qzeros_ptr);
            builder.arg(&scales_ptr);
            builder.arg(&g_idx_ptr);
            builder.arg(&output_ptr);
            builder.arg(&m);
            builder.arg(&k);
            builder.arg(&n);
            builder.launch(cfg).map_err(|e| Error::QuantError {
                reason: format!("CUDA int4_gemv_gptq launch failed: {:?}", e),
            })?;
        }
    } else {
        tracing::debug!(
            m,
            k,
            n,
            path = "gptq_int4_gemm",
            "CUDA GPTQ kernel: INT4 tiled GEMM (optimized)"
        );
        // Tiled GEMM: BM=32, BN=32, BK=32
        let func = kernels::get_kernel_function(&module, "int4_gemm_gptq_f32")?;
        let cfg = LaunchConfig {
            grid_dim: (n.div_ceil(32), m.div_ceil(32), 1),
            block_dim: (32, 4, 1),
            shared_mem_bytes: 0,
        };

        // SAFETY: same contract as the GEMV path above, against the
        // `int4_gemm_gptq_f32` kernel signature. No memset needed: the tiled
        // kernel writes (rather than accumulates into) each output element.
        unsafe {
            let mut builder = client.stream().launch_builder(&func);
            builder.arg(&input_ptr);
            builder.arg(&qweight_ptr);
            builder.arg(&qzeros_ptr);
            builder.arg(&scales_ptr);
            builder.arg(&g_idx_ptr);
            builder.arg(&output_ptr);
            builder.arg(&m);
            builder.arg(&k);
            builder.arg(&n);
            builder.launch(cfg).map_err(|e| Error::QuantError {
                reason: format!("CUDA int4_gemm_gptq launch failed: {:?}", e),
            })?;
        }
    }
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn launch_marlin_gemm(
    client: &CudaClient,
    input: &Tensor<CudaRuntime>,
    weight: &Tensor<CudaRuntime>,
    scales: &Tensor<CudaRuntime>,
    zeros: &Tensor<CudaRuntime>,
    output: &Tensor<CudaRuntime>,
    m: u32,
    k: u32,
    n: u32,
    group_size: u32,
) -> Result<()> {
    let device_index = input.device().id();
    let module = kernels::get_or_load_module(client.context(), device_index, MARLIN_GEMM_MODULE)?;
    let func = kernels::get_kernel_function(&module, "marlin_gemm_f32")?;

    let cfg = LaunchConfig {
        grid_dim: (n.div_ceil(16), m.div_ceil(16), 1),
        block_dim: (16, 16, 1),
        shared_mem_bytes: 0,
    };

    let input_ptr = input.ptr();
    let weight_ptr = weight.ptr();
    let scales_ptr = scales.ptr();
    let zeros_ptr = zeros.ptr();
    let output_ptr = output.ptr();

    unsafe {
        let mut builder = client.stream().launch_builder(&func);
        builder.arg(&input_ptr);
        builder.arg(&weight_ptr);
        builder.arg(&scales_ptr);
        builder.arg(&zeros_ptr);
        builder.arg(&output_ptr);
        builder.arg(&m);
        builder.arg(&k);
        builder.arg(&n);
        builder.arg(&group_size);
        builder.launch(cfg).map_err(|e| Error::QuantError {
            reason: format!("CUDA marlin_gemm launch failed: {:?}", e),
        })?;
    }
    Ok(())
}