Skip to main content

trueno/blis/microkernels/
codegen.rs

1//! Generated GEMM microkernels via trueno-gemm-codegen proc macro.
2//!
3//! Contract: cgp-gemm-codegen-v1.yaml (C-CODEGEN-001 through C-CODEGEN-004)
4//! Sovereign: all code generated at compile time from trueno's own proc macro.
5//!
6//! These kernels are shape-specialized at compile time, producing fully-unrolled
7//! FMA code with optimal register allocation for each (MR, NR) combination.
8
9use trueno_gemm_codegen::{avx512_microkernel, avx512_microkernel_broadcast_b};
10
// === Broadcast-A kernels (A scalar broadcast, B zmm vector load) ===
//
// Each invocation below expands to a function that the tests in this file call as
// `microkernel_{MR}x{NR}_avx512_gen(k, a, b, c, nr)` from inside an `unsafe` block.

// Generate the same 8x32 shape as the hand-written kernel for validation.
// C-CODEGEN-001: must match hand-written output within 1e-5.
// C-CODEGEN-002: must not be slower than hand-written (within 5%).
// NOTE(review): the unit tests in this file compare against a scalar reference with a
// 1e-2 tolerance; confirm where the 1e-5 and 5% bounds above are actually enforced.
avx512_microkernel!(mr = 8, nr = 32);

// Generate 8x16 for small-N path validation.
avx512_microkernel!(mr = 8, nr = 16);

// New shapes not previously hand-written — explore register space.
// 8x48: 24 accumulators (8*3 zmm) + 3 B loads = 27 zmm. Fits in the 32 zmm registers.
avx512_microkernel!(mr = 8, nr = 48);

// === Broadcast-B kernels (faer-style: A zmm vector load, B scalar broadcast) ===
// Advantage: small NR → tiny B panel → large KC → less packing overhead.
// MR must be a multiple of 16 (one zmm register holds 16 f32 lanes).
//
// Each invocation below expands to a function that the tests in this file call as
// `microkernel_{MR}x{NR}_avx512_bcast_b(k, a, b, c, nr)` from inside an `unsafe` block.

// 32×6: 2 A-loads × 6 B-broadcasts = 12 accumulators + 2 A + 4 headroom = 18 zmm.
avx512_microkernel_broadcast_b!(mr = 32, nr = 6);

// 48×6: 3 A-loads × 6 B-broadcasts = 18 accumulators + 3 A + 4 headroom = 25 zmm.
avx512_microkernel_broadcast_b!(mr = 48, nr = 6);

// 64×6: 4 A-loads × 6 B-broadcasts = 24 accumulators + 4 A + 4 headroom = 32 zmm.
// Matches faer's register utilization. Maximum tile that fits in zmm register file.
avx512_microkernel_broadcast_b!(mr = 64, nr = 6);
38
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar reference GEMM used as ground truth: accumulates `A * B` into `c`.
    ///
    /// Indexing, exactly as performed by the loop below:
    /// - `a[p * m + i]`: for each of the `k` steps, `m` contiguous values.
    /// - `b[p * n + j]`: for each of the `k` steps, `n` contiguous values.
    /// - `c[i * n + j]`: `m` rows of `n` contiguous values.
    ///
    /// Products are added onto whatever `c` already holds, so callers that want the plain
    /// product must zero `c` first. Panics (slice index out of bounds) if a slice is too short.
    fn gemm_reference(m: usize, n: usize, k: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
        for i in 0..m {
            for j in 0..n {
                for p in 0..k {
                    c[i * n + j] += a[p * m + i] * b[p * n + j];
                }
            }
        }
    }

    /// Largest element-wise `|got - want|`, with NaN treated as the worst possible outcome.
    ///
    /// A plain `fold(0.0, f32::max)` would drop NaN differences, because `f32::max` returns
    /// the non-NaN operand; a kernel emitting NaN would then look perfectly accurate. Here a
    /// NaN difference is sticky: once seen, the result is NaN and any `< tolerance` check fails.
    fn max_abs_diff(got: &[f32], want: &[f32]) -> f32 {
        assert_eq!(got.len(), want.len(), "output and reference must have the same length");
        got.iter()
            .zip(want)
            .map(|(g, r)| (g - r).abs())
            // `d > worst` is false whenever `worst` is already NaN, so NaN is never replaced.
            .fold(0.0f32, |worst, d| if d > worst || d.is_nan() { d } else { worst })
    }

    /// Runs one generated kernel on deterministic inputs and compares it with
    /// [`gemm_reference`].
    ///
    /// Returns `Some(max_abs_diff)` when the kernel was executed, or `None` (after printing a
    /// skip notice) when the CPU does not report AVX-512F, so nothing is computed in that case.
    ///
    /// `kernel` is invoked exactly once as `kernel(k, a, b, c, nr)` where `a`, `b` and `c`
    /// point to live buffers of `mr * k`, `k * nr` and `mr * nr` `f32` values; `c` is zeroed.
    /// NOTE(review): these extents are assumed to cover everything the generated kernels
    /// read and write — confirm against the code generator.
    fn max_diff_vs_reference<F>(mr: usize, nr: usize, k: usize, kernel: F) -> Option<f32>
    where
        F: FnOnce(usize, *const f32, *const f32, *mut f32, usize),
    {
        if !std::arch::is_x86_feature_detected!("avx512f") {
            eprintln!("skipping {mr}x{nr} microkernel test: AVX-512F not detected on this CPU");
            return None;
        }

        // Small repeating patterns in [0, 1) keep the inputs deterministic and the sums bounded.
        let a: Vec<f32> = (0..mr * k).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0).collect();
        let b: Vec<f32> = (0..k * nr).map(|i| ((i * 11 + 5) % 100) as f32 / 100.0).collect();
        let mut c_gen = vec![0.0f32; mr * nr];
        let mut c_ref = vec![0.0f32; mr * nr];

        gemm_reference(mr, nr, k, &a, &b, &mut c_ref);
        kernel(k, a.as_ptr(), b.as_ptr(), c_gen.as_mut_ptr(), nr);

        Some(max_abs_diff(&c_gen, &c_ref))
    }

    /// FALSIFY-CODEGEN-001: Generated 8x32 matches scalar reference.
    #[test]
    fn test_codegen_8x32_correctness() {
        let outcome = max_diff_vs_reference(8, 32, 64, |k, a, b, c, ldc| {
            // SAFETY: the helper calls this only after detecting AVX-512F at runtime and
            // passes pointers into live, correctly sized buffers (see its documentation).
            unsafe {
                microkernel_8x32_avx512_gen(k, a, b, c, ldc);
            }
        });

        if let Some(max_diff) = outcome {
            assert!(max_diff < 1e-2, "C-CODEGEN-001: max diff {max_diff} >= 1e-2 for 8x32");
        }
    }

    /// FALSIFY-CODEGEN-001b: Generated 8x16 matches scalar reference.
    #[test]
    fn test_codegen_8x16_correctness() {
        let outcome = max_diff_vs_reference(8, 16, 64, |k, a, b, c, ldc| {
            // SAFETY: the helper calls this only after detecting AVX-512F at runtime and
            // passes pointers into live, correctly sized buffers (see its documentation).
            unsafe {
                microkernel_8x16_avx512_gen(k, a, b, c, ldc);
            }
        });

        if let Some(max_diff) = outcome {
            assert!(max_diff < 1e-2, "C-CODEGEN-001: max diff {max_diff} >= 1e-2 for 8x16");
        }
    }

    /// FALSIFY-CODEGEN-001c: Generated 8x48 (new shape) matches scalar reference.
    #[test]
    fn test_codegen_8x48_correctness() {
        let outcome = max_diff_vs_reference(8, 48, 32, |k, a, b, c, ldc| {
            // SAFETY: the helper calls this only after detecting AVX-512F at runtime and
            // passes pointers into live, correctly sized buffers (see its documentation).
            unsafe {
                microkernel_8x48_avx512_gen(k, a, b, c, ldc);
            }
        });

        if let Some(max_diff) = outcome {
            assert!(max_diff < 1e-2, "C-CODEGEN-001: max diff {max_diff} >= 1e-2 for 8x48");
        }
    }

    // === Broadcast-B kernel tests ===

    /// FALSIFY-CODEGEN-002a: broadcast-B 32×6 matches scalar reference.
    #[test]
    fn test_codegen_bcast_b_32x6_correctness() {
        let outcome = max_diff_vs_reference(32, 6, 64, |k, a, b, c, ldc| {
            // SAFETY: the helper calls this only after detecting AVX-512F at runtime and
            // passes pointers into live, correctly sized buffers (see its documentation).
            unsafe {
                microkernel_32x6_avx512_bcast_b(k, a, b, c, ldc);
            }
        });

        if let Some(max_diff) = outcome {
            assert!(max_diff < 1e-2, "FALSIFY-CODEGEN-002a: max diff {max_diff} >= 1e-2");
        }
    }

    /// FALSIFY-CODEGEN-002b: broadcast-B 48×6 matches scalar reference.
    #[test]
    fn test_codegen_bcast_b_48x6_correctness() {
        let outcome = max_diff_vs_reference(48, 6, 64, |k, a, b, c, ldc| {
            // SAFETY: the helper calls this only after detecting AVX-512F at runtime and
            // passes pointers into live, correctly sized buffers (see its documentation).
            unsafe {
                microkernel_48x6_avx512_bcast_b(k, a, b, c, ldc);
            }
        });

        if let Some(max_diff) = outcome {
            assert!(max_diff < 1e-2, "FALSIFY-CODEGEN-002b: max diff {max_diff} >= 1e-2");
        }
    }

    /// FALSIFY-CODEGEN-002c: broadcast-B 64×6 matches scalar reference.
    #[test]
    fn test_codegen_bcast_b_64x6_correctness() {
        let outcome = max_diff_vs_reference(64, 6, 32, |k, a, b, c, ldc| {
            // SAFETY: the helper calls this only after detecting AVX-512F at runtime and
            // passes pointers into live, correctly sized buffers (see its documentation).
            unsafe {
                microkernel_64x6_avx512_bcast_b(k, a, b, c, ldc);
            }
        });

        if let Some(max_diff) = outcome {
            assert!(max_diff < 1e-2, "FALSIFY-CODEGEN-002c: max diff {max_diff} >= 1e-2");
        }
    }
}