use core::fmt::Write as _;
use crate::conv::{ConvKind, ConvLayout, ConvShape};
use crate::dtype::{CutlassDtype, GemmSupported, SmArch};
use crate::gemm::GemmRequest;
#[cfg(feature = "grouped")]
use crate::grouped_gemm::GroupedGemmRequest;
/// Produces the 16-character, zero-padded lowercase hex suffix that is appended
/// to generated kernel names.
///
/// The suffix is the 64-bit `Hash` digest of `key` computed with the standard
/// library's `DefaultHasher`.
///
/// NOTE(review): the standard library does not promise that `DefaultHasher`
/// yields the same digest across Rust releases; if these names are ever
/// persisted between builds, confirm that is acceptable.
fn key_suffix(key: &crate::plan_cache::PlanKey) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash as _, Hasher as _};

    let mut hasher = DefaultHasher::new();
    key.hash(&mut hasher);
    let digest: u64 = hasher.finish();
    format!("{digest:016x}")
}
/// Returns the `#include` block placed at the top of every generated source.
///
/// The text starts with an empty line and every include line is
/// newline-terminated, so callers can append further text directly.
fn header_preamble() -> &'static str {
    concat!(
        "\n",
        "#include <cutlass/cutlass.h>\n",
        "#include <cutlass/numeric_types.h>\n",
        "#include <cutlass/gemm/device/gemm_universal.h>\n",
        "#include <cutlass/conv/device/implicit_gemm_convolution.h>\n",
    )
}
/// Renders the generated source text for a single GEMM request.
///
/// The output is assembled in this order:
/// 1. the shared include preamble,
/// 2. a comment banner recording shape, layouts, dtypes, arch and the
///    `persistent` flag of the request,
/// 3. a `using AtomrGemm = cutlass::gemm::device::GemmUniversal<...>` alias
///    built from the request's element/output/accumulator types, layouts and
///    SM architecture number,
/// 4. a comment naming the epilogue,
/// 5. an `extern "C"` stub kernel that only discards its argument.
///
/// Returns `(source, kernel_name)`. The kernel name is
/// `atomr_cutlass_gemm_<suffix>`, where `<suffix>` comes from `key_suffix`
/// applied to the request's plan key.
pub fn render_gemm<T: GemmSupported>(req: &GemmRequest<T>) -> (String, String) {
    let kname = format!("atomr_cutlass_gemm_{}", key_suffix(&req.plan_key()));

    let mut src = String::with_capacity(2048);
    src.push_str(header_preamble());

    // Banner comment. Formatting into a `String` cannot fail, so every
    // `fmt::Result` in this function is deliberately discarded.
    let _ = writeln!(
        src,
        "// atomr-accel-cutlass: GEMM template instantiation\n\
         // shape=({m}, {n}, {k}) layout=({la}, {lb}, {lc}) dtype={dtype} accum={accum} out={out} arch={arch} persistent={persistent}",
        m = req.shape.m,
        n = req.shape.n,
        k = req.shape.k,
        la = req.layout_a.short_name(),
        lb = req.layout_b.short_name(),
        lc = req.layout_c.short_name(),
        dtype = T::DTYPE,
        accum = req.accum_dtype,
        out = req.output_dtype,
        arch = req.arch,
        persistent = req.persistent,
    );

    // Template alias. A and B share the element type `T::DTYPE`, so a single
    // named argument feeds both positions.
    let _ = writeln!(
        src,
        "using AtomrGemm = cutlass::gemm::device::GemmUniversal<\n\
         {elem}, {la},\n\
         {elem}, {lb},\n\
         {out}, {lc},\n\
         {accum},\n\
         cutlass::arch::OpClassTensorOp,\n\
         cutlass::arch::Sm{sm}>;",
        elem = T::DTYPE.as_cutlass_type(),
        la = req.layout_a.cutlass_layout(),
        lb = req.layout_b.cutlass_layout(),
        out = req.output_dtype.as_cutlass_type(),
        lc = req.layout_c.cutlass_layout(),
        accum = req.accum_dtype.as_cutlass_type(),
        sm = sm_arch_num(req.arch),
    );

    let _ = writeln!(src, "// epilogue: {}", req.epilogue.short_name());

    // Entry-point stub: only the first line depends on the kernel name; the
    // remainder is fixed text and is appended verbatim.
    let _ = writeln!(src, "extern \"C\" __global__ void {kname}(");
    src.push_str(concat!(
        " AtomrGemm::Arguments args)\n",
        "{\n",
        " // CUTLASS launches via a host-side adapter; the kernel body\n",
        " // is generated by CUTLASS itself. This stub holds the\n",
        " // device-side entry symbol so NVRTC's name-expression lookup\n",
        " // resolves cleanly.\n",
        " (void)args;\n",
        "}\n",
    ));

    (src, kname)
}
/// Renders the generated source text for a grouped GEMM request.
///
/// Only compiled when the `grouped` feature is enabled. The output holds the
/// shared include preamble, a comment banner (group count, grouped layout,
/// dtype, arch), a `using AtomrGroupedGemm = ...` alias built from the
/// request's types and layouts, a fixed explanatory comment, and an empty
/// `extern "C"` stub kernel taking an `int`.
///
/// Returns `(source, kernel_name)`. The kernel name is
/// `atomr_cutlass_grouped_gemm_<suffix>`, where `<suffix>` comes from
/// `key_suffix` applied to the request's plan key.
#[cfg(feature = "grouped")]
pub fn render_grouped_gemm<T: GemmSupported>(req: &GroupedGemmRequest<T>) -> (String, String) {
    let kname = format!(
        "atomr_cutlass_grouped_gemm_{}",
        key_suffix(&req.plan_key())
    );

    let mut src = String::with_capacity(2048);
    src.push_str(header_preamble());

    // Formatting into a `String` cannot fail; the results are discarded.
    let _ = writeln!(
        src,
        "// atomr-accel-cutlass: Grouped GEMM template instantiation\n\
         // group_count={groups} layout={layout} dtype={dtype} arch={arch}",
        groups = req.shape.group_count(),
        layout = req.grouped_layout.short_name(),
        dtype = T::DTYPE,
        arch = req.arch,
    );

    // A and B share the element type `T::DTYPE`, so one named argument is
    // used for both positions.
    let _ = writeln!(
        src,
        "using AtomrGroupedGemm = cutlass::gemm::device::GemmUniversal /* grouped variant */<\n\
         {elem}, {la},\n\
         {elem}, {lb},\n\
         {out}, {lc},\n\
         {accum}>;",
        elem = T::DTYPE.as_cutlass_type(),
        la = req.layout_a.cutlass_layout(),
        lb = req.layout_b.cutlass_layout(),
        out = req.output_dtype.as_cutlass_type(),
        lc = req.layout_c.cutlass_layout(),
        accum = req.accum_dtype.as_cutlass_type(),
    );

    src.push_str("// kernel: GroupedGemm symbol is generated by CUTLASS host-side adapter.\n");
    let _ = writeln!(src, "extern \"C\" __global__ void {kname}(int) {{}}");

    (src, kname)
}
/// Placeholder used when the `grouped` feature is disabled.
///
/// Returns `(source, kernel_name)` where the source is a single comment line
/// (no trailing newline) and the kernel name is a fixed sentinel.
#[cfg(not(feature = "grouped"))]
// NOTE(review): presumably kept for builds where nothing calls it; confirm the
// lint suppression is still needed.
#[allow(dead_code)]
pub fn render_grouped_gemm_unsupported() -> (String, String) {
    let source = String::from("// grouped feature not enabled");
    let kernel_name = String::from("atomr_cutlass_grouped_gemm_disabled");
    (source, kernel_name)
}
pub fn render_conv<T: GemmSupported>(
kind: ConvKind,
shape: ConvShape,
layout: ConvLayout,
accum: CutlassDtype,
out: CutlassDtype,
arch: SmArch,
) -> (String, String) {
use crate::plan_cache::PlanKey;
let key = PlanKey::conv::<T>(kind, shape, layout, accum, out, arch);
let suffix = key_suffix(&key);
let kname = format!(
"atomr_cutlass_conv_{kind}_{suffix}",
kind = kind.short_name()
);
let mut src = String::with_capacity(2048);
src.push_str(header_preamble());
let _ = writeln!(
src,
"// atomr-accel-cutlass: Implicit-GEMM convolution\n\
// kind={} shape=N{}H{}W{}C{}K{}R{}S{} layout={} dtype={} accum={} out={} arch={}",
kind.short_name(),
shape.n,
shape.h,
shape.w,
shape.c,
shape.k,
shape.r,
shape.s,
layout.short_name(),
T::DTYPE,
accum,
out,
arch,
);
let _ = writeln!(
src,
"using AtomrConv = {kernel}</* simplified template arg list */>;",
kernel = kind.cutlass_kernel(),
);
let _ = writeln!(
src,
"// CUTLASS implicit-GEMM convolution; ImplicitGemmConvolution path."
);
let _ = writeln!(src, "extern \"C\" __global__ void {kname}(int) {{}}");
(src, kname)
}
/// Maps an `SmArch` variant to the number that follows `Sm` in the
/// `cutlass::arch::Sm<NN>` tag emitted by the GEMM renderer.
///
/// The match is intentionally exhaustive (no wildcard arm) so that adding a
/// new `SmArch` variant forces this table to be updated.
fn sm_arch_num(arch: SmArch) -> u32 {
    match arch {
        SmArch::Sm80 => 80,
        SmArch::Sm86 => 86,
        SmArch::Sm89 => 89,
        // The `a` variant renders with the same number as plain `Sm90`.
        SmArch::Sm90 | SmArch::Sm90a => 90,
        SmArch::Sm100 => 100,
        SmArch::Sm120 => 120,
    }
}