// bench_matmul/bench_matmul.rs
use std::hint::black_box;
use std::time::Instant;

use rlx_cuda::backend::CudaExecutable;
use rlx_ir::{DType, Graph, Shape};

26fn bench(m: usize, k: usize, n: usize, warmup: usize, iters: usize) {
27 let mut g = Graph::new("mm");
28 let x = g.input("x", Shape::new(&[m, k], DType::F32));
29 let w = g.param("w", Shape::new(&[k, n], DType::F32));
30 let y = g.matmul(x, w, Shape::new(&[m, n], DType::F32));
31 g.set_outputs(vec![y]);
32
33 let mut exe = CudaExecutable::compile(g);
34 let wv: Vec<f32> = (0..k * n).map(|i| (i as f32) * 1e-3).collect();
35 exe.set_param("w", &wv);
36 let xv: Vec<f32> = (0..m * k).map(|i| (i as f32) * 1e-3).collect();
37
38 for _ in 0..warmup {
39 let _ = exe.run(&[("x", &xv)]);
40 }
41
42 let t0 = Instant::now();
43 for _ in 0..iters {
44 let _ = exe.run(&[("x", &xv)]);
45 }
46 let dt = t0.elapsed().as_secs_f64() / iters as f64;
47 let flops = 2.0 * (m * k * n) as f64;
48 let gflops = flops / dt / 1e9;
49 println!(
50 " M={:>5} K={:>5} N={:>5} {:>8.3} ms {:>8.1} GFLOP/s",
51 m,
52 k,
53 n,
54 dt * 1e3,
55 gflops
56 );
57}
58
59fn main() {
60 if !rlx_cuda::is_available() {
61 println!("CUDA not available on this host — exiting.");
62 return;
63 }
64 println!("rlx-cuda matmul bench");
65 println!("---------------------");
66 let cases: &[(usize, usize, usize)] = &[
67 (128, 128, 128),
68 (512, 512, 512),
69 (1024, 1024, 1024),
70 (2048, 2048, 2048),
71 (4096, 4096, 4096),
72 (8, 4096, 4096),
73 (1024, 4096, 4096),
74 ];
75 for &(m, k, n) in cases {
76 bench(m, k, n, 3, 20);
77 }
78}