Skip to main content

ferrum_testkit/op_diff/
gemm.rs

//! `gemm` op-diff harness — covers the basic fp16 matmul that backs
//! `qkv_proj`, `o_proj`, `gate_up_proj`, `down_proj`, and the lm_head
//! projection. Per an nsys profile on Vast 4090 / M3, the `Marlin<256,...>`
//! matmul kernel accounts for ~55% of GPU time at c=16; this op-diff
//! validates the non-quantized fallback path against CPU.
6
7use super::{random_vec, OpUnderTest, Output};
8use ferrum_kernels::backend::Backend;
9
/// Dense matmul case: `C[m, n] = A[m, k] · B[n, k]^T`.
///
/// All operands are row-major. `B` is stored as `[n, k]`, i.e. already
/// transposed, which matches the `Backend::gemm` signature used by `Linear`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GemmOp {
    /// Rows of `A` and of the output `C`.
    pub m: usize,
    /// Rows of `B` (its transposed layout) and columns of the output `C`.
    pub n: usize,
    /// Shared reduction dimension: columns of both `A` and `B`.
    pub k: usize,
}
17
18impl GemmOp {
19    fn input_a_len(&self) -> usize {
20        self.m * self.k
21    }
22    fn input_b_len(&self) -> usize {
23        self.n * self.k
24    }
25    fn output_len(&self) -> usize {
26        self.m * self.n
27    }
28
29    fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<f32>) {
30        let a = random_vec(self.input_a_len(), -1.0, 1.0, seed);
31        let b = random_vec(self.input_b_len(), -1.0, 1.0, seed.wrapping_add(1));
32        (a, b)
33    }
34}
35
36impl OpUnderTest for GemmOp {
37    fn name(&self) -> &str {
38        "gemm"
39    }
40
41    fn run_cpu(&self, seed: u64) -> Output {
42        use ferrum_kernels::backend::cpu::CpuBackend;
43        let (a, b) = self.build_input(seed);
44        let mut ctx = CpuBackend::new_context();
45        let a_buf = CpuBackend::from_slice(&a);
46        let b_buf = CpuBackend::from_slice(&b);
47        let mut out = CpuBackend::alloc(self.output_len());
48        CpuBackend::gemm(&mut ctx, &a_buf, &b_buf, &mut out, self.m, self.n, self.k);
49        CpuBackend::sync(&mut ctx);
50        CpuBackend::to_vec(&out, self.output_len())
51    }
52
53    #[cfg(all(target_os = "macos", feature = "metal"))]
54    fn run_metal(&self, seed: u64) -> Output {
55        use ferrum_kernels::backend::metal::MetalBackend;
56        let (a, b) = self.build_input(seed);
57        let mut ctx = MetalBackend::new_context();
58        let a_buf = MetalBackend::from_slice(&a);
59        let b_buf = MetalBackend::from_slice(&b);
60        let mut out = MetalBackend::alloc(self.output_len());
61        MetalBackend::gemm(&mut ctx, &a_buf, &b_buf, &mut out, self.m, self.n, self.k);
62        MetalBackend::sync(&mut ctx);
63        MetalBackend::to_vec(&out, self.output_len())
64    }
65
66    #[cfg(feature = "cuda")]
67    fn run_cuda(&self, seed: u64) -> Output {
68        use ferrum_kernels::backend::cuda::CudaBackend;
69        let (a, b) = self.build_input(seed);
70        let mut ctx = CudaBackend::new_context();
71        let a_buf = CudaBackend::from_slice(&a);
72        let b_buf = CudaBackend::from_slice(&b);
73        let mut out = CudaBackend::alloc(self.output_len());
74        CudaBackend::gemm(&mut ctx, &a_buf, &b_buf, &mut out, self.m, self.n, self.k);
75        CudaBackend::sync(&mut ctx);
76        CudaBackend::to_vec(&out, self.output_len())
77    }
78}