atomr-accel-cutlass 0.10.0

CUTLASS kernel-template instantiation via NVRTC for atomr-accel. Provides GEMM, grouped GEMM, implicit-GEMM convolution, and EVT (epilogue visitor tree) actors that JIT CUTLASS C++ templates against the per-arch toolchain pinned by atomr-accel-cuda's NvrtcActor.
Documentation
//! Generated `.cu` source emitters.
//!
//! Each function here renders a small CUDA C++ translation unit that
//! `#include`s the vendored CUTLASS headers and instantiates the
//! requested template. The translation unit is fed to
//! `atomr_accel_cuda::kernel::NvrtcActor` for compilation; the
//! returned `KernelHandle` is then cached by the `CutlassActor`
//! against the [`crate::plan_cache::PlanKey`].
//!
//! The strings produced here are deterministic — same request →
//! byte-identical output — so the upstream NVRTC disk cache hits even
//! across process restarts.

use core::fmt::Write as _;

use crate::conv::{ConvKind, ConvLayout, ConvShape};
use crate::dtype::{CutlassDtype, GemmSupported, SmArch};
use crate::gemm::GemmRequest;

#[cfg(feature = "grouped")]
use crate::grouped_gemm::GroupedGemmRequest;

/// Stable kernel-name suffix for a plan key (in practice a
/// [`crate::plan_cache::PlanKey`]).
///
/// The key is fed through its `Hash` impl into a 64-bit FNV-1a hasher, and
/// the digest is rendered as 16 lowercase hex digits. This keeps the
/// resulting C identifier fragment short and fixed-width.
///
/// A hand-rolled FNV-1a is used instead of
/// `std::collections::hash_map::DefaultHasher` because std documents the
/// latter's algorithm as unspecified and liable to change between Rust
/// releases. With `DefaultHasher`, a toolchain upgrade could silently rename
/// every generated kernel and defeat the NVRTC disk cache that the module
/// docs rely on.
///
/// Multi-byte integers are written little-endian and `usize` is widened to
/// 64 bits, so the suffix does not depend on host byte order.
fn key_suffix<K: std::hash::Hash + ?Sized>(key: &K) -> String {
    use std::hash::{Hash, Hasher};

    // 64-bit FNV-1a: offset basis 0xcbf29ce484222325, prime 0x100000001b3.
    struct Fnv1a64(u64);

    impl Hasher for Fnv1a64 {
        fn write(&mut self, bytes: &[u8]) {
            for &byte in bytes {
                self.0 ^= u64::from(byte);
                self.0 = self.0.wrapping_mul(0x0000_0100_0000_01b3);
            }
        }

        // The std defaults encode these in native byte order; pin them to
        // little-endian so the digest is identical on every host.
        fn write_u16(&mut self, i: u16) {
            self.write(&i.to_le_bytes());
        }

        fn write_u32(&mut self, i: u32) {
            self.write(&i.to_le_bytes());
        }

        fn write_u64(&mut self, i: u64) {
            self.write(&i.to_le_bytes());
        }

        fn write_u128(&mut self, i: u128) {
            self.write(&i.to_le_bytes());
        }

        fn write_usize(&mut self, i: usize) {
            // Lossless: Rust targets have pointers of at most 64 bits.
            self.write_u64(i as u64);
        }

        fn finish(&self) -> u64 {
            self.0
        }
    }

    let mut hasher = Fnv1a64(0xcbf2_9ce4_8422_2325);
    key.hash(&mut hasher);
    format!("{:016x}", hasher.finish())
}

/// `#include` prelude shared by every generated CUTLASS translation unit.
///
/// The text starts and ends with a newline. Each header path resolves
/// against the vendored CUTLASS tree, because the NVRTC actor adds
/// `--include-path=<crate>/cutlass/include`. Only the headers needed by the
/// basic `gemm_universal` and implicit-GEMM convolution templates are
/// listed; fuller coverage is a follow-up.
fn header_preamble() -> &'static str {
    concat!(
        "\n",
        "#include <cutlass/cutlass.h>\n",
        "#include <cutlass/numeric_types.h>\n",
        "#include <cutlass/gemm/device/gemm_universal.h>\n",
        "#include <cutlass/conv/device/implicit_gemm_convolution.h>\n"
    )
}

/// Renders the CUDA translation unit for a single GEMM request.
///
/// Returns `(source, kernel_name)`. The kernel name is
/// `atomr_cutlass_gemm_<suffix>`, where `<suffix>` is produced by
/// `key_suffix` from `req.plan_key()`. Requests with equal plan keys
/// therefore share a symbol.
///
/// The emitted source is, in order:
/// 1. the shared CUTLASS `#include` preamble;
/// 2. a comment recording the shape, layouts, dtypes, arch and `persistent`;
/// 3. a `cutlass::gemm::device::GemmUniversal` alias named `AtomrGemm`.
///    A and B both use `T::DTYPE` and C uses `req.output_dtype`. The alias
///    also passes `OpClassTensorOp` and the `cutlass::arch::Sm<N>` tag for
///    `req.arch`;
/// 4. a comment naming the epilogue;
/// 5. an `extern "C" __global__` stub that takes `AtomrGemm::Arguments`
///    and only discards it.
///
/// `persistent` and the epilogue appear only in comments here; neither
/// affects the `AtomrGemm` template arguments.
///
/// Results of `write!`/`writeln!` are discarded. Appending to a `String`
/// can only fail if one of the `Display` impls itself reports an error.
pub fn render_gemm<T: GemmSupported>(req: &GemmRequest<T>) -> (String, String) {
    let key = req.plan_key();
    let suffix = key_suffix(&key);
    let kname = format!("atomr_cutlass_gemm_{suffix}");

    let mut src = String::with_capacity(2048);
    src.push_str(header_preamble());
    // Provenance header. Any change to this text changes the emitted bytes
    // and therefore misses the NVRTC cache.
    let _ = writeln!(
        src,
        "// atomr-accel-cutlass: GEMM template instantiation\n\
         //   shape=({}, {}, {}) layout=({}, {}, {}) dtype={} accum={} out={} arch={} persistent={}",
        req.shape.m,
        req.shape.n,
        req.shape.k,
        req.layout_a.short_name(),
        req.layout_b.short_name(),
        req.layout_c.short_name(),
        T::DTYPE,
        req.accum_dtype,
        req.output_dtype,
        req.arch,
        req.persistent,
    );

    // Template alias. Each trailing `\` line continuation also strips the
    // next line's leading whitespace, so the C++ argument lines come out
    // unindented.
    let _ = write!(
        src,
        "using AtomrGemm = cutlass::gemm::device::GemmUniversal<\n\
            {ta}, {la},\n\
            {tb}, {lb},\n\
            {tc}, {lc},\n\
            {accum},\n\
            cutlass::arch::OpClassTensorOp,\n\
            cutlass::arch::Sm{arch_num}>;\n",
        ta = T::DTYPE.as_cutlass_type(),
        tb = T::DTYPE.as_cutlass_type(),
        tc = req.output_dtype.as_cutlass_type(),
        la = req.layout_a.cutlass_layout(),
        lb = req.layout_b.cutlass_layout(),
        lc = req.layout_c.cutlass_layout(),
        accum = req.accum_dtype.as_cutlass_type(),
        arch_num = sm_arch_num(req.arch),
    );

    // The epilogue is emitted only as a comment, so the .cu source diffs
    // predictably when only the epilogue changes. The actual epilogue is
    // wired through the device-side launcher (or via the EVT module).
    let _ = writeln!(src, "// epilogue: {}", req.epilogue.short_name());

    // Entry-point stub named `kname`; its body only discards `args`.
    let _ = write!(
        src,
        "extern \"C\" __global__ void {kname}(\n    AtomrGemm::Arguments args)\n\
         {{\n    // CUTLASS launches via a host-side adapter; the kernel body\n    // is generated by CUTLASS itself. This stub holds the\n    // device-side entry symbol so NVRTC's name-expression lookup\n    // resolves cleanly.\n    (void)args;\n}}\n"
    );

    (src, kname)
}

/// Renders the CUDA translation unit for a grouped GEMM request.
///
/// Returns `(source, kernel_name)`. The kernel name is
/// `atomr_cutlass_grouped_gemm_<suffix>`, where `<suffix>` comes from
/// `key_suffix` applied to `req.plan_key()`.
///
/// The source contains:
/// - the shared `#include` preamble;
/// - a comment recording the group count, grouped layout, dtype and arch;
/// - an `AtomrGroupedGemm` alias;
/// - an empty `extern "C" __global__` stub taking a single `int`.
///
/// The alias currently uses `cutlass::gemm::device::GemmUniversal` as a
/// stand-in for a dedicated grouped kernel. Like `render_gemm`, it passes
/// the operator class and the `cutlass::arch::Sm<N>` tag explicitly.
/// Without them the template falls back to CUTLASS's defaults
/// (`OpClassSimt`, `Sm70`), and `req.arch` would be ignored.
///
/// Results of `writeln!` are discarded, matching the other emitters.
#[cfg(feature = "grouped")]
pub fn render_grouped_gemm<T: GemmSupported>(req: &GroupedGemmRequest<T>) -> (String, String) {
    let key = req.plan_key();
    let suffix = key_suffix(&key);
    let kname = format!("atomr_cutlass_grouped_gemm_{suffix}");

    let mut src = String::with_capacity(2048);
    src.push_str(header_preamble());
    let _ = writeln!(
        src,
        "// atomr-accel-cutlass: Grouped GEMM template instantiation\n\
         //   group_count={} layout={} dtype={} arch={}",
        req.shape.group_count(),
        req.grouped_layout.short_name(),
        T::DTYPE,
        req.arch,
    );
    // Same argument order as `render_gemm`, so the requested arch actually
    // reaches the template rather than only the comment above.
    let _ = writeln!(
        src,
        "using AtomrGroupedGemm = cutlass::gemm::device::GemmUniversal /* grouped variant */<\n\
            {ta}, {la},\n\
            {tb}, {lb},\n\
            {tc}, {lc},\n\
            {accum},\n\
            cutlass::arch::OpClassTensorOp,\n\
            cutlass::arch::Sm{arch_num}>;",
        ta = T::DTYPE.as_cutlass_type(),
        tb = T::DTYPE.as_cutlass_type(),
        tc = req.output_dtype.as_cutlass_type(),
        la = req.layout_a.cutlass_layout(),
        lb = req.layout_b.cutlass_layout(),
        lc = req.layout_c.cutlass_layout(),
        accum = req.accum_dtype.as_cutlass_type(),
        arch_num = sm_arch_num(req.arch),
    );
    let _ = writeln!(
        src,
        "// kernel: GroupedGemm symbol is generated by CUTLASS host-side adapter."
    );
    let _ = writeln!(src, "extern \"C\" __global__ void {kname}(int) {{}}");
    (src, kname)
}

/// Placeholder emitter compiled only when the `grouped` feature is off.
///
/// Returns a comment-only "source" and a fixed, recognisable kernel name
/// instead of a real translation unit.
#[cfg(not(feature = "grouped"))]
// Exists for parity with the feature-enabled build; it may have no callers
// in this configuration, so the dead-code lint is silenced.
#[allow(dead_code)]
pub fn render_grouped_gemm_unsupported() -> (String, String) {
    let source = String::from("// grouped feature not enabled");
    let kernel_name = String::from("atomr_cutlass_grouped_gemm_disabled");
    (source, kernel_name)
}

pub fn render_conv<T: GemmSupported>(
    kind: ConvKind,
    shape: ConvShape,
    layout: ConvLayout,
    accum: CutlassDtype,
    out: CutlassDtype,
    arch: SmArch,
) -> (String, String) {
    use crate::plan_cache::PlanKey;
    let key = PlanKey::conv::<T>(kind, shape, layout, accum, out, arch);
    let suffix = key_suffix(&key);
    let kname = format!(
        "atomr_cutlass_conv_{kind}_{suffix}",
        kind = kind.short_name()
    );

    let mut src = String::with_capacity(2048);
    src.push_str(header_preamble());
    let _ = writeln!(
        src,
        "// atomr-accel-cutlass: Implicit-GEMM convolution\n\
         //   kind={} shape=N{}H{}W{}C{}K{}R{}S{} layout={} dtype={} accum={} out={} arch={}",
        kind.short_name(),
        shape.n,
        shape.h,
        shape.w,
        shape.c,
        shape.k,
        shape.r,
        shape.s,
        layout.short_name(),
        T::DTYPE,
        accum,
        out,
        arch,
    );
    let _ = writeln!(
        src,
        "using AtomrConv = {kernel}</* simplified template arg list */>;",
        kernel = kind.cutlass_kernel(),
    );
    let _ = writeln!(
        src,
        "// CUTLASS implicit-GEMM convolution; ImplicitGemmConvolution path."
    );
    let _ = writeln!(src, "extern \"C\" __global__ void {kname}(int) {{}}");
    (src, kname)
}

/// Numeric suffix of the `cutlass::arch::Sm<N>` tag used for `arch`.
///
/// `Sm90a` shares the plain `Sm90` tag, so both map to `90`.
fn sm_arch_num(arch: SmArch) -> u32 {
    let tag = match arch {
        SmArch::Sm80 => 80,
        SmArch::Sm86 => 86,
        SmArch::Sm89 => 89,
        SmArch::Sm90 | SmArch::Sm90a => 90,
        SmArch::Sm100 => 100,
        SmArch::Sm120 => 120,
    };
    tag
}