trueno/blis/microkernels/
codegen.rs1use trueno_gemm_codegen::{avx512_microkernel, avx512_microkernel_broadcast_b};
10
11avx512_microkernel!(mr = 8, nr = 32);
17
18avx512_microkernel!(mr = 8, nr = 16);
20
21avx512_microkernel!(mr = 8, nr = 48);
24
25avx512_microkernel_broadcast_b!(mr = 32, nr = 6);
31
32avx512_microkernel_broadcast_b!(mr = 48, nr = 6);
34
35avx512_microkernel_broadcast_b!(mr = 64, nr = 6);
38
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar reference GEMM: accumulates the product of `a` and `b` into `c`.
    ///
    /// Indexing used by the loop body:
    /// * `a[p * m + i]` — `m` consecutive values for each of the `k` steps,
    /// * `b[p * n + j]` — `n` consecutive values for each of the `k` steps,
    /// * `c[i * n + j]` — `m` rows of `n` values each.
    ///
    /// `c` is added to, not overwritten, so callers that want a plain product
    /// must pass a zeroed buffer.
    ///
    /// # Panics
    /// Panics on out-of-bounds indexing if any slice is shorter than the
    /// extents implied by `m`, `n` and `k`.
    fn gemm_reference(m: usize, n: usize, k: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
        for i in 0..m {
            for j in 0..n {
                for p in 0..k {
                    c[i * n + j] += a[p * m + i] * b[p * n + j];
                }
            }
        }
    }

    /// Largest absolute element-wise difference between `got` and `want`.
    ///
    /// Returns NaN as soon as any difference is NaN. A plain
    /// `fold(0.0, f32::max)` would silently drop NaN (because `f32::max`
    /// returns the non-NaN operand), letting a kernel that produces NaN pass
    /// a `max_diff < tolerance` check with `max_diff == 0.0`.
    ///
    /// # Panics
    /// Panics if the two slices differ in length.
    fn max_abs_diff(got: &[f32], want: &[f32]) -> f32 {
        assert_eq!(got.len(), want.len(), "output and reference lengths differ");
        got.iter().zip(want).map(|(g, w)| (g - w).abs()).fold(0.0f32, |worst, diff| {
            if worst.is_nan() || diff.is_nan() {
                f32::NAN
            } else {
                worst.max(diff)
            }
        })
    }

    /// Runs `kernel` on deterministic inputs and compares it with
    /// [`gemm_reference`].
    ///
    /// The inputs are `mr * k` values for A and `k * nr` values for B, all in
    /// `[0.0, 1.0)`, and a zeroed `mr * nr` output buffer. `kernel` is invoked
    /// as `kernel(k, a_ptr, b_ptr, c_ptr, nr)`; the last argument equals the
    /// row stride the reference uses for C (presumably the kernel's leading
    /// dimension of C — confirm against the macro docs).
    ///
    /// Returns `None` without calling `kernel` when the CPU does not report
    /// AVX-512F, otherwise `Some(max_abs_diff)`; a NaN result means the
    /// kernel produced at least one NaN.
    fn kernel_max_error(
        mr: usize,
        nr: usize,
        k: usize,
        kernel: impl FnOnce(usize, *const f32, *const f32, *mut f32, usize),
    ) -> Option<f32> {
        let a: Vec<f32> = (0..mr * k).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0).collect();
        let b: Vec<f32> = (0..k * nr).map(|i| ((i * 11 + 5) % 100) as f32 / 100.0).collect();
        let mut c_gen = vec![0.0f32; mr * nr];
        let mut c_ref = vec![0.0f32; mr * nr];

        gemm_reference(mr, nr, k, &a, &b, &mut c_ref);

        if !std::arch::is_x86_feature_detected!("avx512f") {
            // Keep the test green on machines without AVX-512, but say so.
            eprintln!("skipping {mr}x{nr} microkernel check: AVX-512F not detected");
            return None;
        }

        kernel(k, a.as_ptr(), b.as_ptr(), c_gen.as_mut_ptr(), nr);
        Some(max_abs_diff(&c_gen, &c_ref))
    }

    #[test]
    fn test_codegen_8x32_correctness() {
        let outcome = kernel_max_error(8, 32, 64, |k, a, b, c, ldc| {
            // SAFETY: `kernel_max_error` calls this only after detecting
            // AVX-512F, with pointers to live buffers of `mr * k`, `k * nr`
            // and `mr * nr` floats (assumed to cover the kernel's accesses).
            unsafe {
                microkernel_8x32_avx512_gen(k, a, b, c, ldc);
            }
        });

        if let Some(max_diff) = outcome {
            assert!(max_diff < 1e-2, "C-CODEGEN-001: max diff {max_diff} >= 1e-2 for 8x32");
        }
    }

    #[test]
    fn test_codegen_8x16_correctness() {
        let outcome = kernel_max_error(8, 16, 64, |k, a, b, c, ldc| {
            // SAFETY: `kernel_max_error` calls this only after detecting
            // AVX-512F, with pointers to live buffers of `mr * k`, `k * nr`
            // and `mr * nr` floats (assumed to cover the kernel's accesses).
            unsafe {
                microkernel_8x16_avx512_gen(k, a, b, c, ldc);
            }
        });

        if let Some(max_diff) = outcome {
            assert!(max_diff < 1e-2, "C-CODEGEN-001: max diff {max_diff} >= 1e-2 for 8x16");
        }
    }

    #[test]
    fn test_codegen_8x48_correctness() {
        let outcome = kernel_max_error(8, 48, 32, |k, a, b, c, ldc| {
            // SAFETY: `kernel_max_error` calls this only after detecting
            // AVX-512F, with pointers to live buffers of `mr * k`, `k * nr`
            // and `mr * nr` floats (assumed to cover the kernel's accesses).
            unsafe {
                microkernel_8x48_avx512_gen(k, a, b, c, ldc);
            }
        });

        if let Some(max_diff) = outcome {
            assert!(max_diff < 1e-2, "C-CODEGEN-001: max diff {max_diff} >= 1e-2 for 8x48");
        }
    }

    #[test]
    fn test_codegen_bcast_b_32x6_correctness() {
        let outcome = kernel_max_error(32, 6, 64, |k, a, b, c, ldc| {
            // SAFETY: `kernel_max_error` calls this only after detecting
            // AVX-512F, with pointers to live buffers of `mr * k`, `k * nr`
            // and `mr * nr` floats (assumed to cover the kernel's accesses).
            unsafe {
                microkernel_32x6_avx512_bcast_b(k, a, b, c, ldc);
            }
        });

        if let Some(max_diff) = outcome {
            assert!(max_diff < 1e-2, "FALSIFY-CODEGEN-002a: max diff {max_diff} >= 1e-2");
        }
    }

    #[test]
    fn test_codegen_bcast_b_48x6_correctness() {
        let outcome = kernel_max_error(48, 6, 64, |k, a, b, c, ldc| {
            // SAFETY: `kernel_max_error` calls this only after detecting
            // AVX-512F, with pointers to live buffers of `mr * k`, `k * nr`
            // and `mr * nr` floats (assumed to cover the kernel's accesses).
            unsafe {
                microkernel_48x6_avx512_bcast_b(k, a, b, c, ldc);
            }
        });

        if let Some(max_diff) = outcome {
            assert!(max_diff < 1e-2, "FALSIFY-CODEGEN-002b: max diff {max_diff} >= 1e-2");
        }
    }

    #[test]
    fn test_codegen_bcast_b_64x6_correctness() {
        let outcome = kernel_max_error(64, 6, 32, |k, a, b, c, ldc| {
            // SAFETY: `kernel_max_error` calls this only after detecting
            // AVX-512F, with pointers to live buffers of `mr * k`, `k * nr`
            // and `mr * nr` floats (assumed to cover the kernel's accesses).
            unsafe {
                microkernel_64x6_avx512_bcast_b(k, a, b, c, ldc);
            }
        });

        if let Some(max_diff) = outcome {
            assert!(max_diff < 1e-2, "FALSIFY-CODEGEN-002c: max diff {max_diff} >= 1e-2");
        }
    }
}