Skip to main content

cutlass_gemm_fp8/
cutlass_gemm_fp8.rs

//! Example: render a Hopper fp8 GEMM template via `CutlassActor`.
//!
//! Runs entirely on the host — no CUDA toolkit required. Prints the
//! plan key, the lowered kernel name, and the generated `.cu` source,
//! then dispatches the same request twice and prints the dispatch
//! count and the plan-cache length.
6
7use atomr_accel_cutlass::dtype::F8E4m3;
8use atomr_accel_cutlass::{CutlassActor, CutlassMsg, GemmEpilogue, GemmRequest, GemmShape, SmArch};
9
10fn main() {
11    let actor = CutlassActor::new(16);
12    let req = GemmRequest::<F8E4m3>::new(GemmShape::new(4096, 4096, 4096), SmArch::Sm90a)
13        .with_epilogue(GemmEpilogue::LinearReLU {
14            alpha: 1.0,
15            beta: 0.0,
16        });
17
18    println!("plan key: {:?}", req.plan_key());
19    let (src, name) = req.render_cu();
20    println!("kernel:   {name}");
21    println!("--- generated .cu ---");
22    println!("{src}");
23
24    actor.handle(CutlassMsg::Gemm(Box::new(req.clone())));
25    actor.handle(CutlassMsg::Gemm(Box::new(req)));
26
27    println!("dispatched: {}", actor.inner().dispatched());
28    println!("plan cache len: {}", actor.inner().plan_cache.len());
29}