// trueno/blis/microkernels/avx512.rs
#[cfg(target_arch = "x86_64")]
19#[target_feature(enable = "avx512f")]
20pub unsafe fn microkernel_16x8_avx512(
21 k: usize,
22 a: *const f32,
23 b: *const f32,
24 c: *mut f32,
25 ldc: usize,
26) {
27 unsafe {
28 use std::arch::x86_64::*;
29
30 let mut c0 = _mm512_loadu_ps(c);
32 let mut c1 = _mm512_loadu_ps(c.add(ldc));
33 let mut c2 = _mm512_loadu_ps(c.add(2 * ldc));
34 let mut c3 = _mm512_loadu_ps(c.add(3 * ldc));
35 let mut c4 = _mm512_loadu_ps(c.add(4 * ldc));
36 let mut c5 = _mm512_loadu_ps(c.add(5 * ldc));
37 let mut c6 = _mm512_loadu_ps(c.add(6 * ldc));
38 let mut c7 = _mm512_loadu_ps(c.add(7 * ldc));
39
40 let k4 = k / 4;
41 let k_rem = k % 4;
42
43 for p4 in 0..k4 {
44 let base = p4 * 4;
45
46 let a0 = _mm512_loadu_ps(a.add(base * 16));
48 let bp0 = b.add(base * 8);
49 c0 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0), c0);
50 c1 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(1)), c1);
51 c2 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(2)), c2);
52 c3 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(3)), c3);
53 c4 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(4)), c4);
54 c5 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(5)), c5);
55 c6 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(6)), c6);
56 c7 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(7)), c7);
57
58 let a1 = _mm512_loadu_ps(a.add((base + 1) * 16));
60 let bp1 = b.add((base + 1) * 8);
61 c0 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1), c0);
62 c1 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(1)), c1);
63 c2 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(2)), c2);
64 c3 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(3)), c3);
65 c4 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(4)), c4);
66 c5 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(5)), c5);
67 c6 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(6)), c6);
68 c7 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(7)), c7);
69
70 let a2 = _mm512_loadu_ps(a.add((base + 2) * 16));
72 let bp2 = b.add((base + 2) * 8);
73 c0 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2), c0);
74 c1 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(1)), c1);
75 c2 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(2)), c2);
76 c3 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(3)), c3);
77 c4 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(4)), c4);
78 c5 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(5)), c5);
79 c6 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(6)), c6);
80 c7 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(7)), c7);
81
82 let a3 = _mm512_loadu_ps(a.add((base + 3) * 16));
84 let bp3 = b.add((base + 3) * 8);
85 c0 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3), c0);
86 c1 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(1)), c1);
87 c2 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(2)), c2);
88 c3 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(3)), c3);
89 c4 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(4)), c4);
90 c5 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(5)), c5);
91 c6 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(6)), c6);
92 c7 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(7)), c7);
93 }
94
95 let base_rem = k4 * 4;
97 for p in 0..k_rem {
98 let pp = base_rem + p;
99 let a_col = _mm512_loadu_ps(a.add(pp * 16));
100 let bp = b.add(pp * 8);
101 c0 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp), c0);
102 c1 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(1)), c1);
103 c2 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(2)), c2);
104 c3 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(3)), c3);
105 c4 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(4)), c4);
106 c5 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(5)), c5);
107 c6 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(6)), c6);
108 c7 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(7)), c7);
109 }
110
111 _mm512_storeu_ps(c, c0);
113 _mm512_storeu_ps(c.add(ldc), c1);
114 _mm512_storeu_ps(c.add(2 * ldc), c2);
115 _mm512_storeu_ps(c.add(3 * ldc), c3);
116 _mm512_storeu_ps(c.add(4 * ldc), c4);
117 _mm512_storeu_ps(c.add(5 * ldc), c5);
118 _mm512_storeu_ps(c.add(6 * ldc), c6);
119 _mm512_storeu_ps(c.add(7 * ldc), c7);
120 }
121}
122
123#[cfg(target_arch = "x86_64")]
140#[target_feature(enable = "avx512f")]
141pub unsafe fn microkernel_32x6_avx512(
142 k: usize,
143 a: *const f32,
144 b: *const f32,
145 c: *mut f32,
146 ldc: usize,
147) {
148 unsafe {
149 use std::arch::x86_64::*;
150
151 let mut c00 = _mm512_loadu_ps(c);
153 let mut c01 = _mm512_loadu_ps(c.add(ldc));
154 let mut c02 = _mm512_loadu_ps(c.add(2 * ldc));
155 let mut c03 = _mm512_loadu_ps(c.add(3 * ldc));
156 let mut c04 = _mm512_loadu_ps(c.add(4 * ldc));
157 let mut c05 = _mm512_loadu_ps(c.add(5 * ldc));
158 let mut c10 = _mm512_loadu_ps(c.add(16));
159 let mut c11 = _mm512_loadu_ps(c.add(ldc + 16));
160 let mut c12 = _mm512_loadu_ps(c.add(2 * ldc + 16));
161 let mut c13 = _mm512_loadu_ps(c.add(3 * ldc + 16));
162 let mut c14 = _mm512_loadu_ps(c.add(4 * ldc + 16));
163 let mut c15 = _mm512_loadu_ps(c.add(5 * ldc + 16));
164
165 let k2 = k / 2;
166 let k_rem = k % 2;
167
168 for p2 in 0..k2 {
170 let base = p2 * 2;
171
172 let a0_lo = _mm512_loadu_ps(a.add(base * 32));
174 let a0_hi = _mm512_loadu_ps(a.add(base * 32 + 16));
175 let bp0 = b.add(base * 6);
176
177 let b0 = _mm512_set1_ps(*bp0);
179 c00 = _mm512_fmadd_ps(a0_lo, b0, c00);
180 c10 = _mm512_fmadd_ps(a0_hi, b0, c10);
181 let b1 = _mm512_set1_ps(*bp0.add(1));
182 c01 = _mm512_fmadd_ps(a0_lo, b1, c01);
183 c11 = _mm512_fmadd_ps(a0_hi, b1, c11);
184 let b2 = _mm512_set1_ps(*bp0.add(2));
185 c02 = _mm512_fmadd_ps(a0_lo, b2, c02);
186 c12 = _mm512_fmadd_ps(a0_hi, b2, c12);
187 let b3 = _mm512_set1_ps(*bp0.add(3));
188 c03 = _mm512_fmadd_ps(a0_lo, b3, c03);
189 c13 = _mm512_fmadd_ps(a0_hi, b3, c13);
190 let b4 = _mm512_set1_ps(*bp0.add(4));
191 c04 = _mm512_fmadd_ps(a0_lo, b4, c04);
192 c14 = _mm512_fmadd_ps(a0_hi, b4, c14);
193 let b5 = _mm512_set1_ps(*bp0.add(5));
194 c05 = _mm512_fmadd_ps(a0_lo, b5, c05);
195 c15 = _mm512_fmadd_ps(a0_hi, b5, c15);
196
197 let a1_lo = _mm512_loadu_ps(a.add((base + 1) * 32));
199 let a1_hi = _mm512_loadu_ps(a.add((base + 1) * 32 + 16));
200 let bp1 = b.add((base + 1) * 6);
201
202 let b0 = _mm512_set1_ps(*bp1);
203 c00 = _mm512_fmadd_ps(a1_lo, b0, c00);
204 c10 = _mm512_fmadd_ps(a1_hi, b0, c10);
205 let b1 = _mm512_set1_ps(*bp1.add(1));
206 c01 = _mm512_fmadd_ps(a1_lo, b1, c01);
207 c11 = _mm512_fmadd_ps(a1_hi, b1, c11);
208 let b2 = _mm512_set1_ps(*bp1.add(2));
209 c02 = _mm512_fmadd_ps(a1_lo, b2, c02);
210 c12 = _mm512_fmadd_ps(a1_hi, b2, c12);
211 let b3 = _mm512_set1_ps(*bp1.add(3));
212 c03 = _mm512_fmadd_ps(a1_lo, b3, c03);
213 c13 = _mm512_fmadd_ps(a1_hi, b3, c13);
214 let b4 = _mm512_set1_ps(*bp1.add(4));
215 c04 = _mm512_fmadd_ps(a1_lo, b4, c04);
216 c14 = _mm512_fmadd_ps(a1_hi, b4, c14);
217 let b5 = _mm512_set1_ps(*bp1.add(5));
218 c05 = _mm512_fmadd_ps(a1_lo, b5, c05);
219 c15 = _mm512_fmadd_ps(a1_hi, b5, c15);
220 }
221
222 let base_rem = k2 * 2;
224 for p in 0..k_rem {
225 let pp = base_rem + p;
226 let a_lo = _mm512_loadu_ps(a.add(pp * 32));
227 let a_hi = _mm512_loadu_ps(a.add(pp * 32 + 16));
228 let bp = b.add(pp * 6);
229 let b0 = _mm512_set1_ps(*bp);
230 c00 = _mm512_fmadd_ps(a_lo, b0, c00);
231 c10 = _mm512_fmadd_ps(a_hi, b0, c10);
232 let b1 = _mm512_set1_ps(*bp.add(1));
233 c01 = _mm512_fmadd_ps(a_lo, b1, c01);
234 c11 = _mm512_fmadd_ps(a_hi, b1, c11);
235 let b2 = _mm512_set1_ps(*bp.add(2));
236 c02 = _mm512_fmadd_ps(a_lo, b2, c02);
237 c12 = _mm512_fmadd_ps(a_hi, b2, c12);
238 let b3 = _mm512_set1_ps(*bp.add(3));
239 c03 = _mm512_fmadd_ps(a_lo, b3, c03);
240 c13 = _mm512_fmadd_ps(a_hi, b3, c13);
241 let b4 = _mm512_set1_ps(*bp.add(4));
242 c04 = _mm512_fmadd_ps(a_lo, b4, c04);
243 c14 = _mm512_fmadd_ps(a_hi, b4, c14);
244 let b5 = _mm512_set1_ps(*bp.add(5));
245 c05 = _mm512_fmadd_ps(a_lo, b5, c05);
246 c15 = _mm512_fmadd_ps(a_hi, b5, c15);
247 }
248
249 _mm512_storeu_ps(c, c00);
251 _mm512_storeu_ps(c.add(ldc), c01);
252 _mm512_storeu_ps(c.add(2 * ldc), c02);
253 _mm512_storeu_ps(c.add(3 * ldc), c03);
254 _mm512_storeu_ps(c.add(4 * ldc), c04);
255 _mm512_storeu_ps(c.add(5 * ldc), c05);
256 _mm512_storeu_ps(c.add(16), c10);
257 _mm512_storeu_ps(c.add(ldc + 16), c11);
258 _mm512_storeu_ps(c.add(2 * ldc + 16), c12);
259 _mm512_storeu_ps(c.add(3 * ldc + 16), c13);
260 _mm512_storeu_ps(c.add(4 * ldc + 16), c14);
261 _mm512_storeu_ps(c.add(5 * ldc + 16), c15);
262 }
263}