/// Column count at or above which `gemv` dispatches to the cache-tiled AVX2
/// kernel instead of the streaming kernel.
const GEMV_TILE_THRESHOLD: usize = 8192;

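/// Streaming AVX2 GEMV kernel: accumulates `c[j] += Σ_i a[i] * b[i * n + j]`
/// for a length-`k` vector `a` and a row-major `k x n` matrix `b`, processing
/// four rows of `b` and eight columns of `c` per inner step.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2 and FMA, and that
/// `a.len() >= k`, `b.len() >= k * n`, and `c.len() >= n`.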
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn gemv_avx2(k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    unsafe {
        use std::arch::x86_64::*;

        let n8 = n / 8 * 8;
        let k4 = k / 4 * 4;

        let mut ki = 0;
        while ki < k4 {
            let a0 = _mm256_set1_ps(*a.get_unchecked(ki));
            let a1 = _mm256_set1_ps(*a.get_unchecked(ki + 1));
            let a2 = _mm256_set1_ps(*a.get_unchecked(ki + 2));
            let a3 = _mm256_set1_ps(*a.get_unchecked(ki + 3));
            let b0_base = ki * n;
            let b1_base = b0_base + n;
            let b2_base = b1_base + n;
            let b3_base = b2_base + n;

            let mut j = 0;
            let b_ptr = b.as_ptr();
            let c_ptr = c.as_mut_ptr();
            while j < n8 {
                let cv = _mm256_loadu_ps(c_ptr.add(j));
                let bv0 = _mm256_loadu_ps(b_ptr.add(b0_base + j));
                let bv1 = _mm256_loadu_ps(b_ptr.add(b1_base + j));
                let bv2 = _mm256_loadu_ps(b_ptr.add(b2_base + j));
                let bv3 = _mm256_loadu_ps(b_ptr.add(b3_base + j));

                let r = _mm256_fmadd_ps(a0, bv0, cv);
                let r = _mm256_fmadd_ps(a1, bv1, r);
                let r = _mm256_fmadd_ps(a2, bv2, r);
                let r = _mm256_fmadd_ps(a3, bv3, r);

                _mm256_storeu_ps(c_ptr.add(j), r);
                j += 8;
            }

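            // Scalar tail for the last n % 8 columns of this 4-row block.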
            while j < n {
                *c.get_unchecked_mut(j) += *a.get_unchecked(ki) * *b.get_unchecked(b0_base + j)
                    + *a.get_unchecked(ki + 1) * *b.get_unchecked(b1_base + j)
                    + *a.get_unchecked(ki + 2) * *b.get_unchecked(b2_base + j)
                    + *a.get_unchecked(ki + 3) * *b.get_unchecked(b3_base + j);
                j += 1;
            }

            ki += 4;
        }

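        // k tail: remaining k % 4 rows, one row at a time.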
        while ki < k {
            let ak = *a.get_unchecked(ki);
            let bk_base = ki * n;
            let ak_v = _mm256_set1_ps(ak);

            let mut j = 0;
            let b_ptr = b.as_ptr();
            let c_ptr = c.as_mut_ptr();
            while j < n8 {
                let cv = _mm256_loadu_ps(c_ptr.add(j));
                let bv = _mm256_loadu_ps(b_ptr.add(bk_base + j));
                let r = _mm256_fmadd_ps(ak_v, bv, cv);
                _mm256_storeu_ps(c_ptr.add(j), r);
                j += 8;
            }
            while j < n {
                *c.get_unchecked_mut(j) += ak * *b.get_unchecked(bk_base + j);
                j += 1;
            }
            ki += 1;
        }
    }
}

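/// Cache-tiled AVX2 GEMV kernel for wide outputs: each 64-column tile of `c`
/// is kept in eight 8-lane accumulators for the entire k loop, so `c` is
/// loaded and stored once per tile instead of once per row of `b`.
///
/// # Safety
///
/// Same requirements as [`gemv_avx2`].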
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemv_tiled_avx2(k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    unsafe {
        use std::arch::x86_64::*;

        const NT: usize = 64;

        let k4 = k / 4 * 4;
        let nt_end = n / NT * NT;

        for j0 in (0..nt_end).step_by(NT) {
            // Seed the accumulators from the current contents of c so the
            // tiled kernel accumulates (c += a * B), matching gemv_avx2 and
            // gemv_scalar as well as this kernel's own remainder path.
            let cp = c.as_mut_ptr().add(j0);
            let mut acc0 = _mm256_loadu_ps(cp);
            let mut acc1 = _mm256_loadu_ps(cp.add(8));
            let mut acc2 = _mm256_loadu_ps(cp.add(16));
            let mut acc3 = _mm256_loadu_ps(cp.add(24));
            let mut acc4 = _mm256_loadu_ps(cp.add(32));
            let mut acc5 = _mm256_loadu_ps(cp.add(40));
            let mut acc6 = _mm256_loadu_ps(cp.add(48));
            let mut acc7 = _mm256_loadu_ps(cp.add(56));

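            // Main loop: four rows of b per step, fully unrolled across the
            // 64-column tile (8 vectors x 8 floats).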
            let mut ki = 0;
            while ki < k4 {
                let a0 = _mm256_set1_ps(*a.get_unchecked(ki));
                let a1 = _mm256_set1_ps(*a.get_unchecked(ki + 1));
                let a2 = _mm256_set1_ps(*a.get_unchecked(ki + 2));
                let a3 = _mm256_set1_ps(*a.get_unchecked(ki + 3));

                let b0 = ki * n + j0;
                let b1 = b0 + n;
                let b2 = b1 + n;
                let b3 = b2 + n;

                // Prefetch the b row two 4-row blocks ahead of this tile.
                if ki + 8 < k {
                    let pf = (ki + 8) * n + j0;
                    _mm_prefetch(b.as_ptr().add(pf) as *const i8, _MM_HINT_T0);
                    _mm_prefetch(b.as_ptr().add(pf + 32) as *const i8, _MM_HINT_T0);
                }

                let bv = _mm256_loadu_ps(b.get_unchecked(b0));
                acc0 = _mm256_fmadd_ps(a0, bv, acc0);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1));
                acc0 = _mm256_fmadd_ps(a1, bv, acc0);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2));
                acc0 = _mm256_fmadd_ps(a2, bv, acc0);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3));
                acc0 = _mm256_fmadd_ps(a3, bv, acc0);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 8));
                acc1 = _mm256_fmadd_ps(a0, bv, acc1);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 8));
                acc1 = _mm256_fmadd_ps(a1, bv, acc1);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 8));
                acc1 = _mm256_fmadd_ps(a2, bv, acc1);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 8));
                acc1 = _mm256_fmadd_ps(a3, bv, acc1);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 16));
                acc2 = _mm256_fmadd_ps(a0, bv, acc2);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 16));
                acc2 = _mm256_fmadd_ps(a1, bv, acc2);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 16));
                acc2 = _mm256_fmadd_ps(a2, bv, acc2);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 16));
                acc2 = _mm256_fmadd_ps(a3, bv, acc2);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 24));
                acc3 = _mm256_fmadd_ps(a0, bv, acc3);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 24));
                acc3 = _mm256_fmadd_ps(a1, bv, acc3);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 24));
                acc3 = _mm256_fmadd_ps(a2, bv, acc3);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 24));
                acc3 = _mm256_fmadd_ps(a3, bv, acc3);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 32));
                acc4 = _mm256_fmadd_ps(a0, bv, acc4);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 32));
                acc4 = _mm256_fmadd_ps(a1, bv, acc4);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 32));
                acc4 = _mm256_fmadd_ps(a2, bv, acc4);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 32));
                acc4 = _mm256_fmadd_ps(a3, bv, acc4);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 40));
                acc5 = _mm256_fmadd_ps(a0, bv, acc5);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 40));
                acc5 = _mm256_fmadd_ps(a1, bv, acc5);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 40));
                acc5 = _mm256_fmadd_ps(a2, bv, acc5);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 40));
                acc5 = _mm256_fmadd_ps(a3, bv, acc5);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 48));
                acc6 = _mm256_fmadd_ps(a0, bv, acc6);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 48));
                acc6 = _mm256_fmadd_ps(a1, bv, acc6);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 48));
                acc6 = _mm256_fmadd_ps(a2, bv, acc6);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 48));
                acc6 = _mm256_fmadd_ps(a3, bv, acc6);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 56));
                acc7 = _mm256_fmadd_ps(a0, bv, acc7);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 56));
                acc7 = _mm256_fmadd_ps(a1, bv, acc7);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 56));
                acc7 = _mm256_fmadd_ps(a2, bv, acc7);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 56));
                acc7 = _mm256_fmadd_ps(a3, bv, acc7);

                ki += 4;
            }

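            // k tail: remaining k % 4 rows, one at a time over the tile.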
            while ki < k {
                let av = _mm256_set1_ps(*a.get_unchecked(ki));
                let base = ki * n + j0;

                acc0 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base)), acc0);
                acc1 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 8)), acc1);
                acc2 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 16)), acc2);
                acc3 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 24)), acc3);
                acc4 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 32)), acc4);
                acc5 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 40)), acc5);
                acc6 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 48)), acc6);
                acc7 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 56)), acc7);
                ki += 1;
            }

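            // Write the finished 64-column tile back to c.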
            _mm256_storeu_ps(cp, acc0);
            _mm256_storeu_ps(cp.add(8), acc1);
            _mm256_storeu_ps(cp.add(16), acc2);
            _mm256_storeu_ps(cp.add(24), acc3);
            _mm256_storeu_ps(cp.add(32), acc4);
            _mm256_storeu_ps(cp.add(40), acc5);
            _mm256_storeu_ps(cp.add(48), acc6);
            _mm256_storeu_ps(cp.add(56), acc7);
        }

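        // Column remainder: the last n % 64 columns, handled like gemv_avx2.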
        if nt_end < n {
            let rem_n = n - nt_end;
            let rem8 = rem_n / 8 * 8;

            let mut ki = 0;
            while ki < k4 {
                let a0 = _mm256_set1_ps(*a.get_unchecked(ki));
                let a1 = _mm256_set1_ps(*a.get_unchecked(ki + 1));
                let a2 = _mm256_set1_ps(*a.get_unchecked(ki + 2));
                let a3 = _mm256_set1_ps(*a.get_unchecked(ki + 3));
                let b0 = ki * n + nt_end;
                let b1 = b0 + n;
                let b2 = b1 + n;
                let b3 = b2 + n;

                let mut j = 0;
                while j < rem8 {
                    let cv = _mm256_loadu_ps(c.get_unchecked(nt_end + j));
                    let r = _mm256_fmadd_ps(a0, _mm256_loadu_ps(b.get_unchecked(b0 + j)), cv);
                    let r = _mm256_fmadd_ps(a1, _mm256_loadu_ps(b.get_unchecked(b1 + j)), r);
                    let r = _mm256_fmadd_ps(a2, _mm256_loadu_ps(b.get_unchecked(b2 + j)), r);
                    let r = _mm256_fmadd_ps(a3, _mm256_loadu_ps(b.get_unchecked(b3 + j)), r);
                    _mm256_storeu_ps(c.get_unchecked_mut(nt_end + j), r);
                    j += 8;
                }
                while j < rem_n {
                    let idx = nt_end + j;
                    *c.get_unchecked_mut(idx) += *a.get_unchecked(ki) * *b.get_unchecked(b0 + j)
                        + *a.get_unchecked(ki + 1) * *b.get_unchecked(b1 + j)
                        + *a.get_unchecked(ki + 2) * *b.get_unchecked(b2 + j)
                        + *a.get_unchecked(ki + 3) * *b.get_unchecked(b3 + j);
                    j += 1;
                }
                ki += 4;
            }

            while ki < k {
                let ak = *a.get_unchecked(ki);
                let bk = ki * n + nt_end;
                let ak_v = _mm256_set1_ps(ak);

                let mut j = 0;
                while j < rem8 {
                    let cv = _mm256_loadu_ps(c.get_unchecked(nt_end + j));
                    let bv = _mm256_loadu_ps(b.get_unchecked(bk + j));
                    _mm256_storeu_ps(
                        c.get_unchecked_mut(nt_end + j),
                        _mm256_fmadd_ps(ak_v, bv, cv),
                    );
                    j += 8;
                }
                while j < rem_n {
                    *c.get_unchecked_mut(nt_end + j) += ak * *b.get_unchecked(bk + j);
                    j += 1;
                }
                ki += 1;
            }
        }
    }
}

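/// AVX-512 variant of the tiled kernel: 128-column tiles held in eight
/// 16-lane accumulators. Currently not selected by [`gemv`], hence
/// `#[allow(dead_code)]`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX-512F; slice length
/// requirements match [`gemv_avx2`].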
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
#[allow(dead_code)]
unsafe fn gemv_tiled_avx512(k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    unsafe {
        use std::arch::x86_64::*;

        const NT: usize = 128;

        let k4 = k / 4 * 4;
        let nt_end = n / NT * NT;

        for j0 in (0..nt_end).step_by(NT) {
            // Seed the accumulators from c so this kernel also accumulates
            // (c += a * B), matching the scalar and AVX2 paths.
            let cp = c.as_mut_ptr().add(j0);
            let mut acc0 = _mm512_loadu_ps(cp);
            let mut acc1 = _mm512_loadu_ps(cp.add(16));
            let mut acc2 = _mm512_loadu_ps(cp.add(32));
            let mut acc3 = _mm512_loadu_ps(cp.add(48));
            let mut acc4 = _mm512_loadu_ps(cp.add(64));
            let mut acc5 = _mm512_loadu_ps(cp.add(80));
            let mut acc6 = _mm512_loadu_ps(cp.add(96));
            let mut acc7 = _mm512_loadu_ps(cp.add(112));

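            // Main loop: four rows of b per step, fully unrolled across the
            // 128-column tile (8 vectors x 16 floats).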
            let mut ki = 0;
            while ki < k4 {
                let a0 = _mm512_set1_ps(*a.get_unchecked(ki));
                let a1 = _mm512_set1_ps(*a.get_unchecked(ki + 1));
                let a2 = _mm512_set1_ps(*a.get_unchecked(ki + 2));
                let a3 = _mm512_set1_ps(*a.get_unchecked(ki + 3));

                let b0 = ki * n + j0;
                let b1 = b0 + n;
                let b2 = b1 + n;
                let b3 = b2 + n;

                // Prefetch the b row one 4-row block ahead of this tile.
                if ki + 4 < k {
                    let pf = (ki + 4) * n + j0;
                    _mm_prefetch(b.as_ptr().add(pf) as *const i8, _MM_HINT_T0);
                    _mm_prefetch(b.as_ptr().add(pf + 64) as *const i8, _MM_HINT_T0);
                }

                let bv = _mm512_loadu_ps(b.get_unchecked(b0));
                acc0 = _mm512_fmadd_ps(a0, bv, acc0);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1));
                acc0 = _mm512_fmadd_ps(a1, bv, acc0);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2));
                acc0 = _mm512_fmadd_ps(a2, bv, acc0);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3));
                acc0 = _mm512_fmadd_ps(a3, bv, acc0);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 16));
                acc1 = _mm512_fmadd_ps(a0, bv, acc1);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 16));
                acc1 = _mm512_fmadd_ps(a1, bv, acc1);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 16));
                acc1 = _mm512_fmadd_ps(a2, bv, acc1);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 16));
                acc1 = _mm512_fmadd_ps(a3, bv, acc1);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 32));
                acc2 = _mm512_fmadd_ps(a0, bv, acc2);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 32));
                acc2 = _mm512_fmadd_ps(a1, bv, acc2);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 32));
                acc2 = _mm512_fmadd_ps(a2, bv, acc2);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 32));
                acc2 = _mm512_fmadd_ps(a3, bv, acc2);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 48));
                acc3 = _mm512_fmadd_ps(a0, bv, acc3);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 48));
                acc3 = _mm512_fmadd_ps(a1, bv, acc3);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 48));
                acc3 = _mm512_fmadd_ps(a2, bv, acc3);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 48));
                acc3 = _mm512_fmadd_ps(a3, bv, acc3);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 64));
                acc4 = _mm512_fmadd_ps(a0, bv, acc4);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 64));
                acc4 = _mm512_fmadd_ps(a1, bv, acc4);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 64));
                acc4 = _mm512_fmadd_ps(a2, bv, acc4);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 64));
                acc4 = _mm512_fmadd_ps(a3, bv, acc4);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 80));
                acc5 = _mm512_fmadd_ps(a0, bv, acc5);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 80));
                acc5 = _mm512_fmadd_ps(a1, bv, acc5);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 80));
                acc5 = _mm512_fmadd_ps(a2, bv, acc5);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 80));
                acc5 = _mm512_fmadd_ps(a3, bv, acc5);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 96));
                acc6 = _mm512_fmadd_ps(a0, bv, acc6);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 96));
                acc6 = _mm512_fmadd_ps(a1, bv, acc6);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 96));
                acc6 = _mm512_fmadd_ps(a2, bv, acc6);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 96));
                acc6 = _mm512_fmadd_ps(a3, bv, acc6);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 112));
                acc7 = _mm512_fmadd_ps(a0, bv, acc7);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 112));
                acc7 = _mm512_fmadd_ps(a1, bv, acc7);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 112));
                acc7 = _mm512_fmadd_ps(a2, bv, acc7);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 112));
                acc7 = _mm512_fmadd_ps(a3, bv, acc7);

                ki += 4;
            }

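            // k tail: remaining k % 4 rows, one at a time over the tile.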
            while ki < k {
                let av = _mm512_set1_ps(*a.get_unchecked(ki));
                let base = ki * n + j0;
                acc0 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base)), acc0);
                acc1 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 16)), acc1);
                acc2 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 32)), acc2);
                acc3 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 48)), acc3);
                acc4 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 64)), acc4);
                acc5 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 80)), acc5);
                acc6 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 96)), acc6);
                acc7 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 112)), acc7);
                ki += 1;
            }

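            // Write the finished 128-column tile back to c.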
            _mm512_storeu_ps(cp, acc0);
            _mm512_storeu_ps(cp.add(16), acc1);
            _mm512_storeu_ps(cp.add(32), acc2);
            _mm512_storeu_ps(cp.add(48), acc3);
            _mm512_storeu_ps(cp.add(64), acc4);
            _mm512_storeu_ps(cp.add(80), acc5);
            _mm512_storeu_ps(cp.add(96), acc6);
            _mm512_storeu_ps(cp.add(112), acc7);
        }

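        // Column remainder: the last n % 128 columns, 16 lanes at a time,
        // then a scalar tail.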
        if nt_end < n {
            let rem = n - nt_end;
            let rem16 = rem / 16 * 16;

            for j0 in (0..rem16).step_by(16) {
                let j = nt_end + j0;
                // Seed from c so the remainder accumulates like the main tiles.
                let mut acc = _mm512_loadu_ps(c.as_ptr().add(j));
                for ki in 0..k {
                    let av = _mm512_set1_ps(*a.get_unchecked(ki));
                    let bv = _mm512_loadu_ps(b.get_unchecked(ki * n + j));
                    acc = _mm512_fmadd_ps(av, bv, acc);
                }
                _mm512_storeu_ps(c.as_mut_ptr().add(j), acc);
            }

            for j in (nt_end + rem16)..n {
                let mut sum = 0.0f32;
                for ki in 0..k {
                    sum += *a.get_unchecked(ki) * *b.get_unchecked(ki * n + j);
                }
                *c.get_unchecked_mut(j) += sum;
            }
        }
    }
}

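/// Portable scalar fallback: accumulates `c[j] += Σ_i a[i] * b[i * n + j]`,
/// with the k loop unrolled by four. Also serves as the reference
/// implementation in the tests.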
pub fn gemv_scalar(k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    let k4 = k / 4 * 4;
    for ki in (0..k4).step_by(4) {
        let a0 = a[ki];
        let a1 = a[ki + 1];
        let a2 = a[ki + 2];
        let a3 = a[ki + 3];
        let b0 = ki * n;
        let b1 = b0 + n;
        let b2 = b1 + n;
        let b3 = b2 + n;
        for j in 0..n {
            c[j] += a0 * b[b0 + j] + a1 * b[b1 + j] + a2 * b[b2 + j] + a3 * b[b3 + j];
        }
    }

    for ki in k4..k {
        let a_k = a[ki];
        let b_start = ki * n;
        for j in 0..n {
            c[j] += a_k * b[b_start + j];
        }
    }
}

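/// GEMV entry point: accumulates `c += a * B` for a length-`k` vector `a` and
/// a row-major `k x n` matrix `b`, dispatching to the fastest kernel the CPU
/// supports. Zero `c` first to get `c = a * B`.
///
/// A minimal usage sketch (marked `ignore` because the crate path is elided
/// here):
///
/// ```ignore
/// let a = [1.0f32, 2.0];           // 1 x k
/// let b = [1.0f32, 2.0, 3.0, 4.0]; // k x n, row-major: rows [1, 2] and [3, 4]
/// let mut c = [0.0f32; 2];         // zeroed so the result is c = a * B
/// gemv(2, 2, &a, &b, &mut c);      // c == [7.0, 10.0]
/// ```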
pub fn gemv(k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    contract_pre_gemv!(a, b);
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            unsafe {
                // At or above the threshold, wide rows amortize the tiled
                // kernel's setup; below it, the streaming kernel is used.
                if n >= GEMV_TILE_THRESHOLD {
                    gemv_tiled_avx2(k, n, a, b, c);
                } else {
                    gemv_avx2(k, n, a, b, c);
                }
            }
            return;
        }
    }
    gemv_scalar(k, n, a, b, c);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gemv_basic() {
        let a = [1.0, 2.0, 3.0];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let mut c = [0.0f32; 4];

        gemv(3, 4, &a, &b, &mut c);

        assert!((c[0] - 38.0).abs() < 1e-5);
        assert!((c[1] - 44.0).abs() < 1e-5);
        assert!((c[2] - 50.0).abs() < 1e-5);
        assert!((c[3] - 56.0).abs() < 1e-5);
    }

    #[test]
    fn test_gemv_identity_row_select() {
        let a = [0.0, 1.0, 0.0];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let mut c = [0.0f32; 3];

        gemv(3, 3, &a, &b, &mut c);

        assert!((c[0] - 4.0).abs() < 1e-5);
        assert!((c[1] - 5.0).abs() < 1e-5);
        assert!((c[2] - 6.0).abs() < 1e-5);
    }

    #[test]
    fn test_gemv_large_n() {
        let k = 2;
        let n = 17;
        let a = [1.0f32, 2.0];
        let b: Vec<f32> = (0..k * n).map(|i| i as f32).collect();
        let mut c = vec![0.0f32; n];

        gemv(k, n, &a, &b, &mut c);

        for j in 0..n {
            let expected = a[0] * b[j] + a[1] * b[n + j];
            assert!((c[j] - expected).abs() < 1e-4, "c[{j}] = {} expected {expected}", c[j]);
        }
    }

    #[test]
    fn test_gemv_zeros() {
        let a = [0.0f32; 4];
        let b = vec![1.0f32; 4 * 8];
        let mut c = vec![0.0f32; 8];

        gemv(4, 8, &a, &b, &mut c);

        for j in 0..8 {
            assert!(c[j].abs() < 1e-10);
        }
    }

    #[test]
    fn test_gemv_tiled_large_n() {
        let k = 64;
        // n equals GEMV_TILE_THRESHOLD, so the tiled AVX2 kernel is selected.
        let n = 8192;
        let a: Vec<f32> = (0..k).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0 - 0.5).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i * 13 + 7) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut c_tiled = vec![0.0f32; n];
        let mut c_scalar = vec![0.0f32; n];

        gemv(k, n, &a, &b, &mut c_tiled);
        gemv_scalar(k, n, &a, &b, &mut c_scalar);

        for j in 0..n {
            let diff = (c_tiled[j] - c_scalar[j]).abs();
            assert!(diff < 1e-2, "j={j}: tiled={} scalar={} diff={diff}", c_tiled[j], c_scalar[j]);
        }
    }

    #[test]
    fn test_gemv_tiled_llm_size() {
        let k = 256;
        let n = 11008;

        let a: Vec<f32> = (0..k).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i * 13 + 7) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut c_tiled = vec![0.0f32; n];
        let mut c_scalar = vec![0.0f32; n];

        gemv(k, n, &a, &b, &mut c_tiled);
        gemv_scalar(k, n, &a, &b, &mut c_scalar);

        for j in 0..n {
            let diff = (c_tiled[j] - c_scalar[j]).abs();
            assert!(diff < 1e-1, "j={j}: tiled={} scalar={} diff={diff}", c_tiled[j], c_scalar[j]);
        }
    }

    #[test]
    fn test_gemv_tiled_remainder() {
        let k = 32;
        // Above the tile threshold with n % 64 == 14, so both the vector and
        // scalar column remainders of the tiled kernel are exercised. (The
        // previous n = 5000 never reached the tiled kernel at all.)
        let n = 8270;
        let a: Vec<f32> = (0..k).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0 - 0.5).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i * 13 + 7) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut c_tiled = vec![0.0f32; n];
        let mut c_scalar = vec![0.0f32; n];

        gemv(k, n, &a, &b, &mut c_tiled);
        gemv_scalar(k, n, &a, &b, &mut c_scalar);

        for j in 0..n {
            let diff = (c_tiled[j] - c_scalar[j]).abs();
            assert!(diff < 1e-2, "j={j}: tiled={} scalar={} diff={diff}", c_tiled[j], c_scalar[j]);
        }
    }

    #[test]
    fn test_gemv_tiled_k_remainder() {
        let k = 67;
        let n = 8192;

        let a: Vec<f32> = (0..k).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0 - 0.5).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i * 13 + 7) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut c_tiled = vec![0.0f32; n];
        let mut c_scalar = vec![0.0f32; n];

        gemv(k, n, &a, &b, &mut c_tiled);
        gemv_scalar(k, n, &a, &b, &mut c_scalar);

        for j in 0..n {
            let diff = (c_tiled[j] - c_scalar[j]).abs();
            assert!(diff < 1e-2, "j={j}: tiled={} scalar={} diff={diff}", c_tiled[j], c_scalar[j]);
        }
    }

    #[test]
    fn test_gemv_avx512_attention_size() {
        let k = 128;
        let n = 512;

        let a: Vec<f32> = (0..k).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i * 13 + 7) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut c_gemv = vec![0.0f32; n];
        let mut c_scalar = vec![0.0f32; n];

        gemv(k, n, &a, &b, &mut c_gemv);
        gemv_scalar(k, n, &a, &b, &mut c_scalar);

        let max_diff =
            c_gemv.iter().zip(c_scalar.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
        assert!(max_diff < 1e-2, "FALSIFY-AVX512-GEMV-001: max diff {max_diff}");
    }

    #[test]
    fn test_gemv_avx512_remainder() {
        let k = 128;
        let n = 300;
        let a: Vec<f32> = (0..k).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i * 13 + 7) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut c_gemv = vec![0.0f32; n];
        let mut c_scalar = vec![0.0f32; n];

        gemv(k, n, &a, &b, &mut c_gemv);
        gemv_scalar(k, n, &a, &b, &mut c_scalar);

        let max_diff =
            c_gemv.iter().zip(c_scalar.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
        assert!(max_diff < 1e-2, "FALSIFY-AVX512-GEMV-002: max diff {max_diff}");
    }
}