Skip to main content

bench_matmul/
bench_matmul.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Minimal CUDA-only matmul micro-benchmark. It deliberately does not depend on
//! rlx-cpu or rlx-runtime, so it builds and links on hosts that have no CPU BLAS
//! library (Accelerate / OpenBLAS) available.
//!
//! ```text
//! cargo run --release -p rlx-cuda --example bench_matmul
//! ```

21use std::time::Instant;
22
23use rlx_cuda::backend::CudaExecutable;
24use rlx_ir::{DType, Graph, Shape};
25
26fn bench(m: usize, k: usize, n: usize, warmup: usize, iters: usize) {
27    let mut g = Graph::new("mm");
28    let x = g.input("x", Shape::new(&[m, k], DType::F32));
29    let w = g.param("w", Shape::new(&[k, n], DType::F32));
30    let y = g.matmul(x, w, Shape::new(&[m, n], DType::F32));
31    g.set_outputs(vec![y]);
32
33    let mut exe = CudaExecutable::compile(g);
34    let wv: Vec<f32> = (0..k * n).map(|i| (i as f32) * 1e-3).collect();
35    exe.set_param("w", &wv);
36    let xv: Vec<f32> = (0..m * k).map(|i| (i as f32) * 1e-3).collect();
37
38    for _ in 0..warmup {
39        let _ = exe.run(&[("x", &xv)]);
40    }
41
42    let t0 = Instant::now();
43    for _ in 0..iters {
44        let _ = exe.run(&[("x", &xv)]);
45    }
46    let dt = t0.elapsed().as_secs_f64() / iters as f64;
47    let flops = 2.0 * (m * k * n) as f64;
48    let gflops = flops / dt / 1e9;
49    println!(
50        "  M={:>5} K={:>5} N={:>5}   {:>8.3} ms   {:>8.1} GFLOP/s",
51        m,
52        k,
53        n,
54        dt * 1e3,
55        gflops
56    );
57}
58
59fn main() {
60    if !rlx_cuda::is_available() {
61        println!("CUDA not available on this host — exiting.");
62        return;
63    }
64    println!("rlx-cuda matmul bench");
65    println!("---------------------");
66    let cases: &[(usize, usize, usize)] = &[
67        (128, 128, 128),
68        (512, 512, 512),
69        (1024, 1024, 1024),
70        (2048, 2048, 2048),
71        (4096, 4096, 4096),
72        (8, 4096, 4096),
73        (1024, 4096, 4096),
74    ];
75    for &(m, k, n) in cases {
76        bench(m, k, n, /*warmup*/ 3, /*iters*/ 20);
77    }
78}