trueno/blis/gemv.rs

//! SIMD-accelerated GEMV (General Matrix-Vector Multiply)
//!
//! Specialized kernel for the M=1 matrix-vector product c = a × B,
//! where a is 1×K and B is K×N, both row-major.
//!
//! This bypasses the BLIS 5-loop packing overhead, which dominates for M=1.
//! Instead, it uses direct AVX2 VFMADD on unpacked row-major data.
//!
//! # Algorithm
//!
//! Two strategies based on N:
//!
//! - **Small N (≤ 8192)**: Axpy pattern — outer K, inner N. c[] fits in L1.
//! - **Large N (> 8192)**: N-tiled — outer N-tiles (64), inner K. c[] stays
//!   in YMM registers for all K iterations, eliminating L1 thrashing.
//!
//! # References
//!
//! - GH-380: matvec (M=1) performance gap vs ndarray
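//!
//! # Example
//!
//! A minimal worked instance of the M=1 product (illustrative values only;
//! `gemv` below computes exactly this):
//!
//! ```text
//! a = [1, 2]                  (1×K, K=2)
//! B = [[1, 2, 3],
//!      [4, 5, 6]]             (K×N, N=3, row-major)
//! c = a × B
//!   = [1·1 + 2·4, 1·2 + 2·5, 1·3 + 2·6]
//!   = [9, 12, 15]
//! ```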

/// Threshold: when N exceeds this, switch to the tiled GEMV kernel.
/// Raised from 4096 → 8192 (2026-04-05): the tiled kernel's B access is
/// strided (stride = N*4 bytes between rows), which is TLB-unfriendly at
/// large N. Measured on vecmat 4096×4096: tiled reaches 9.3 GFLOPS while
/// axpy benchmarks faster, so the 4096 case now takes the axpy path.
/// c[] still fits in L1 at N=8192 (32 KB).
const GEMV_TILE_THRESHOLD: usize = 8192;

/// AVX2 GEMV using the axpy pattern: c += a[k] * B[k,:] for each k.
///
/// Outer loop over K (4-way unrolled), inner loop over N with AVX2 VFMADD.
/// This matches row-major B access: B[k,:] is contiguous → sequential reads.
///
/// Best for small N, where c[] fits in the L1 cache.
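///
/// In scalar form, the pattern is equivalent to the following (a reference
/// sketch, not the actual implementation; see `gemv_scalar` below):
///
/// ```text
/// for ki in 0..k {
///     for j in 0..n {
///         c[j] += a[ki] * b[ki * n + j]; // row B[ki,:] read sequentially
///     }
/// }
/// ```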
///
/// # Safety
///
/// Requires AVX2+FMA CPU features. Caller must ensure:
/// - `a` has length >= `k`
/// - `b` has length >= `k * n`
/// - `c` has length >= `n`
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn gemv_avx2(k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    unsafe {
        use std::arch::x86_64::*;

        let n8 = n / 8 * 8;

        // 4-way K-unrolled axpy with AVX2 VFMADD on inner N loop
        let k4 = k / 4 * 4;
        let mut ki = 0;
        while ki < k4 {
            let a0 = _mm256_set1_ps(*a.get_unchecked(ki));
            let a1 = _mm256_set1_ps(*a.get_unchecked(ki + 1));
            let a2 = _mm256_set1_ps(*a.get_unchecked(ki + 2));
            let a3 = _mm256_set1_ps(*a.get_unchecked(ki + 3));
            let b0_base = ki * n;
            let b1_base = b0_base + n;
            let b2_base = b1_base + n;
            let b3_base = b2_base + n;

            let mut j = 0;
            let b_ptr = b.as_ptr();
            let c_ptr = c.as_mut_ptr();
            while j < n8 {
                let cv = _mm256_loadu_ps(c_ptr.add(j));
                let bv0 = _mm256_loadu_ps(b_ptr.add(b0_base + j));
                let bv1 = _mm256_loadu_ps(b_ptr.add(b1_base + j));
                let bv2 = _mm256_loadu_ps(b_ptr.add(b2_base + j));
                let bv3 = _mm256_loadu_ps(b_ptr.add(b3_base + j));

                let r = _mm256_fmadd_ps(a0, bv0, cv);
                let r = _mm256_fmadd_ps(a1, bv1, r);
                let r = _mm256_fmadd_ps(a2, bv2, r);
                let r = _mm256_fmadd_ps(a3, bv3, r);

                _mm256_storeu_ps(c_ptr.add(j), r);
                j += 8;
            }

            // Scalar remainder for N % 8
            while j < n {
                *c.get_unchecked_mut(j) += *a.get_unchecked(ki) * *b.get_unchecked(b0_base + j)
                    + *a.get_unchecked(ki + 1) * *b.get_unchecked(b1_base + j)
                    + *a.get_unchecked(ki + 2) * *b.get_unchecked(b2_base + j)
                    + *a.get_unchecked(ki + 3) * *b.get_unchecked(b3_base + j);
                j += 1;
            }

            ki += 4;
        }

        // Remainder K (scalar axpy)
        while ki < k {
            let ak = *a.get_unchecked(ki);
            let bk_base = ki * n;
            let ak_v = _mm256_set1_ps(ak);

            let mut j = 0;
            let b_ptr = b.as_ptr();
            let c_ptr = c.as_mut_ptr();
            while j < n8 {
                let cv = _mm256_loadu_ps(c_ptr.add(j));
                let bv = _mm256_loadu_ps(b_ptr.add(bk_base + j));
                let r = _mm256_fmadd_ps(ak_v, bv, cv);
                _mm256_storeu_ps(c_ptr.add(j), r);
                j += 8;
            }
            while j < n {
                *c.get_unchecked_mut(j) += ak * *b.get_unchecked(bk_base + j);
                j += 1;
            }
            ki += 1;
        }
    }
}

/// AVX2 GEMV with N-dimension tiling for bandwidth-bound sizes.
///
/// Tiles the N dimension into strips of 64, keeping the c[] accumulator
/// in 8 YMM registers for ALL K iterations. This eliminates the repeated
/// L1 load/store of c[] that dominates the axpy pattern when N exceeds L1.
///
/// For 4096×11008, the axpy pattern does 1024 load-store sweeps of the
/// 43 KB c[] (K=4096, 4-way unrolled). Tiled, each c[j0..j0+64] is never
/// loaded (accumulators start in registers) and is stored once at the end,
/// saving ~88 MB of c[] traffic (1024 sweeps × 43 KB × 2 for load+store).
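///
/// Loop structure, in outline (illustrative pseudocode of the code below):
///
/// ```text
/// for j0 in (0..n).step_by(64) {        // one N-tile
///     acc[0..8] = 0                     // 8 YMM accumulators
///     for ki in 0..k {                  // ALL of K, 4-way unrolled
///         acc[..] += a[ki] * B[ki, j0..j0+64]
///     }
///     c[j0..j0+64] = acc                // single store per tile
/// }
/// ```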
///
/// # Safety
///
/// Requires AVX2+FMA CPU features.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemv_tiled_avx2(k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    unsafe {
        use std::arch::x86_64::*;

        // NT=64: 8 YMM accumulators × 8 f32 = 64 elements.
        // 4 registers for broadcast a, 1-2 for B loads = ~14 registers total.
        const NT: usize = 64;

        let k4 = k / 4 * 4;
        let nt_end = n / NT * NT;

        for j0 in (0..nt_end).step_by(NT) {
            // 8 YMM accumulators — stay in registers for ALL K iterations
            let mut acc0 = _mm256_setzero_ps();
            let mut acc1 = _mm256_setzero_ps();
            let mut acc2 = _mm256_setzero_ps();
            let mut acc3 = _mm256_setzero_ps();
            let mut acc4 = _mm256_setzero_ps();
            let mut acc5 = _mm256_setzero_ps();
            let mut acc6 = _mm256_setzero_ps();
            let mut acc7 = _mm256_setzero_ps();

            // Process ALL K for this N-tile (4-way unrolled)
            let mut ki = 0;
            while ki < k4 {
                let a0 = _mm256_set1_ps(*a.get_unchecked(ki));
                let a1 = _mm256_set1_ps(*a.get_unchecked(ki + 1));
                let a2 = _mm256_set1_ps(*a.get_unchecked(ki + 2));
                let a3 = _mm256_set1_ps(*a.get_unchecked(ki + 3));

                let b0 = ki * n + j0;
                let b1 = b0 + n;
                let b2 = b1 + n;
                let b3 = b2 + n;

                // Software prefetch: B rows 8 iterations ahead
                if ki + 8 < k {
                    let pf = (ki + 8) * n + j0;
                    _mm_prefetch(b.as_ptr().add(pf) as *const i8, _MM_HINT_T0);
                    _mm_prefetch(b.as_ptr().add(pf + 32) as *const i8, _MM_HINT_T0);
                }

                // 8 chunks × 4 K iterations = 32 FMAs
                let bv = _mm256_loadu_ps(b.get_unchecked(b0));
                acc0 = _mm256_fmadd_ps(a0, bv, acc0);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1));
                acc0 = _mm256_fmadd_ps(a1, bv, acc0);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2));
                acc0 = _mm256_fmadd_ps(a2, bv, acc0);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3));
                acc0 = _mm256_fmadd_ps(a3, bv, acc0);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 8));
                acc1 = _mm256_fmadd_ps(a0, bv, acc1);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 8));
                acc1 = _mm256_fmadd_ps(a1, bv, acc1);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 8));
                acc1 = _mm256_fmadd_ps(a2, bv, acc1);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 8));
                acc1 = _mm256_fmadd_ps(a3, bv, acc1);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 16));
                acc2 = _mm256_fmadd_ps(a0, bv, acc2);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 16));
                acc2 = _mm256_fmadd_ps(a1, bv, acc2);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 16));
                acc2 = _mm256_fmadd_ps(a2, bv, acc2);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 16));
                acc2 = _mm256_fmadd_ps(a3, bv, acc2);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 24));
                acc3 = _mm256_fmadd_ps(a0, bv, acc3);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 24));
                acc3 = _mm256_fmadd_ps(a1, bv, acc3);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 24));
                acc3 = _mm256_fmadd_ps(a2, bv, acc3);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 24));
                acc3 = _mm256_fmadd_ps(a3, bv, acc3);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 32));
                acc4 = _mm256_fmadd_ps(a0, bv, acc4);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 32));
                acc4 = _mm256_fmadd_ps(a1, bv, acc4);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 32));
                acc4 = _mm256_fmadd_ps(a2, bv, acc4);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 32));
                acc4 = _mm256_fmadd_ps(a3, bv, acc4);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 40));
                acc5 = _mm256_fmadd_ps(a0, bv, acc5);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 40));
                acc5 = _mm256_fmadd_ps(a1, bv, acc5);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 40));
                acc5 = _mm256_fmadd_ps(a2, bv, acc5);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 40));
                acc5 = _mm256_fmadd_ps(a3, bv, acc5);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 48));
                acc6 = _mm256_fmadd_ps(a0, bv, acc6);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 48));
                acc6 = _mm256_fmadd_ps(a1, bv, acc6);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 48));
                acc6 = _mm256_fmadd_ps(a2, bv, acc6);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 48));
                acc6 = _mm256_fmadd_ps(a3, bv, acc6);

                let bv = _mm256_loadu_ps(b.get_unchecked(b0 + 56));
                acc7 = _mm256_fmadd_ps(a0, bv, acc7);
                let bv = _mm256_loadu_ps(b.get_unchecked(b1 + 56));
                acc7 = _mm256_fmadd_ps(a1, bv, acc7);
                let bv = _mm256_loadu_ps(b.get_unchecked(b2 + 56));
                acc7 = _mm256_fmadd_ps(a2, bv, acc7);
                let bv = _mm256_loadu_ps(b.get_unchecked(b3 + 56));
                acc7 = _mm256_fmadd_ps(a3, bv, acc7);

                ki += 4;
            }

            // Remainder K (1 at a time)
            while ki < k {
                let av = _mm256_set1_ps(*a.get_unchecked(ki));
                let base = ki * n + j0;

                acc0 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base)), acc0);
                acc1 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 8)), acc1);
                acc2 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 16)), acc2);
                acc3 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 24)), acc3);
                acc4 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 32)), acc4);
                acc5 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 40)), acc5);
                acc6 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 48)), acc6);
                acc7 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b.get_unchecked(base + 56)), acc7);
                ki += 1;
            }

            // Store accumulators (one store per tile, not K/4 stores)
            _mm256_storeu_ps(c.get_unchecked_mut(j0), acc0);
            _mm256_storeu_ps(c.get_unchecked_mut(j0 + 8), acc1);
            _mm256_storeu_ps(c.get_unchecked_mut(j0 + 16), acc2);
            _mm256_storeu_ps(c.get_unchecked_mut(j0 + 24), acc3);
            _mm256_storeu_ps(c.get_unchecked_mut(j0 + 32), acc4);
            _mm256_storeu_ps(c.get_unchecked_mut(j0 + 40), acc5);
            _mm256_storeu_ps(c.get_unchecked_mut(j0 + 48), acc6);
            _mm256_storeu_ps(c.get_unchecked_mut(j0 + 56), acc7);
        }

        // Remainder N (< 64 elements) — axpy is fine since c fits in L1
        if nt_end < n {
            let rem_n = n - nt_end;
            let rem8 = rem_n / 8 * 8;
            let k4 = k / 4 * 4;

            let mut ki = 0;
            while ki < k4 {
                let a0 = _mm256_set1_ps(*a.get_unchecked(ki));
                let a1 = _mm256_set1_ps(*a.get_unchecked(ki + 1));
                let a2 = _mm256_set1_ps(*a.get_unchecked(ki + 2));
                let a3 = _mm256_set1_ps(*a.get_unchecked(ki + 3));
                let b0 = ki * n + nt_end;
                let b1 = b0 + n;
                let b2 = b1 + n;
                let b3 = b2 + n;

                let mut j = 0;
                while j < rem8 {
                    let cv = _mm256_loadu_ps(c.get_unchecked(nt_end + j));
                    let r = _mm256_fmadd_ps(a0, _mm256_loadu_ps(b.get_unchecked(b0 + j)), cv);
                    let r = _mm256_fmadd_ps(a1, _mm256_loadu_ps(b.get_unchecked(b1 + j)), r);
                    let r = _mm256_fmadd_ps(a2, _mm256_loadu_ps(b.get_unchecked(b2 + j)), r);
                    let r = _mm256_fmadd_ps(a3, _mm256_loadu_ps(b.get_unchecked(b3 + j)), r);
                    _mm256_storeu_ps(c.get_unchecked_mut(nt_end + j), r);
                    j += 8;
                }
                while j < rem_n {
                    let idx = nt_end + j;
                    *c.get_unchecked_mut(idx) += *a.get_unchecked(ki) * *b.get_unchecked(b0 + j)
                        + *a.get_unchecked(ki + 1) * *b.get_unchecked(b1 + j)
                        + *a.get_unchecked(ki + 2) * *b.get_unchecked(b2 + j)
                        + *a.get_unchecked(ki + 3) * *b.get_unchecked(b3 + j);
                    j += 1;
                }
                ki += 4;
            }

            while ki < k {
                let ak = *a.get_unchecked(ki);
                let bk = ki * n + nt_end;
                let ak_v = _mm256_set1_ps(ak);

                let mut j = 0;
                while j < rem8 {
                    let cv = _mm256_loadu_ps(c.get_unchecked(nt_end + j));
                    let bv = _mm256_loadu_ps(b.get_unchecked(bk + j));
                    _mm256_storeu_ps(
                        c.get_unchecked_mut(nt_end + j),
                        _mm256_fmadd_ps(ak_v, bv, cv),
                    );
                    j += 8;
                }
                while j < rem_n {
                    *c.get_unchecked_mut(nt_end + j) += ak * *b.get_unchecked(bk + j);
                    j += 1;
                }
                ki += 1;
            }
        }
    }
}

/// AVX-512 GEMV with N-dimension tiling — 2× the SIMD width of AVX2.
///
/// NT=128: 8 ZMM accumulators × 16 f32 = 128 elements per tile.
/// 4-way K-unrolled: 32 FMAs per tile per iteration.
/// Software prefetch: B rows 4 iterations ahead.
///
/// For attention scoring (Q @ K_cache^T): head_dim=128, seq_len varies.
/// This is the hottest path in LLM inference (44.3% of compute).
///
/// Despite the wider SIMD, this kernel measured slower than AVX2 on Zen 4;
/// see the negative result in `gemv` below.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
#[allow(dead_code)] // Retained for Intel SPR (no AVX-512 throttle). See negative result in `gemv` below.
unsafe fn gemv_tiled_avx512(k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    unsafe {
        use std::arch::x86_64::*;

        // NT=128: 8 ZMM accumulators × 16 f32 = 128 elements.
        // Fits in 8 of 32 ZMM registers. 4 for A broadcasts + 1 for B load = 13 total.
        const NT: usize = 128;

        let k4 = k / 4 * 4;
        let nt_end = n / NT * NT;

        for j0 in (0..nt_end).step_by(NT) {
            // 8 ZMM accumulators — stay in registers for ALL K iterations
            let mut acc0 = _mm512_setzero_ps();
            let mut acc1 = _mm512_setzero_ps();
            let mut acc2 = _mm512_setzero_ps();
            let mut acc3 = _mm512_setzero_ps();
            let mut acc4 = _mm512_setzero_ps();
            let mut acc5 = _mm512_setzero_ps();
            let mut acc6 = _mm512_setzero_ps();
            let mut acc7 = _mm512_setzero_ps();

            // Process ALL K for this N-tile (4-way unrolled)
            let mut ki = 0;
            while ki < k4 {
                let a0 = _mm512_set1_ps(*a.get_unchecked(ki));
                let a1 = _mm512_set1_ps(*a.get_unchecked(ki + 1));
                let a2 = _mm512_set1_ps(*a.get_unchecked(ki + 2));
                let a3 = _mm512_set1_ps(*a.get_unchecked(ki + 3));

                let b0 = ki * n + j0;
                let b1 = b0 + n;
                let b2 = b1 + n;
                let b3 = b2 + n;

                // Prefetch B rows 4 iterations ahead
                if ki + 4 < k {
                    let pf = (ki + 4) * n + j0;
                    _mm_prefetch(b.as_ptr().add(pf) as *const i8, _MM_HINT_T0);
                    _mm_prefetch(b.as_ptr().add(pf + 64) as *const i8, _MM_HINT_T0);
                }

                // 8 chunks × 4 K iterations = 32 FMAs
                let bv = _mm512_loadu_ps(b.get_unchecked(b0));
                acc0 = _mm512_fmadd_ps(a0, bv, acc0);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1));
                acc0 = _mm512_fmadd_ps(a1, bv, acc0);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2));
                acc0 = _mm512_fmadd_ps(a2, bv, acc0);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3));
                acc0 = _mm512_fmadd_ps(a3, bv, acc0);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 16));
                acc1 = _mm512_fmadd_ps(a0, bv, acc1);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 16));
                acc1 = _mm512_fmadd_ps(a1, bv, acc1);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 16));
                acc1 = _mm512_fmadd_ps(a2, bv, acc1);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 16));
                acc1 = _mm512_fmadd_ps(a3, bv, acc1);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 32));
                acc2 = _mm512_fmadd_ps(a0, bv, acc2);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 32));
                acc2 = _mm512_fmadd_ps(a1, bv, acc2);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 32));
                acc2 = _mm512_fmadd_ps(a2, bv, acc2);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 32));
                acc2 = _mm512_fmadd_ps(a3, bv, acc2);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 48));
                acc3 = _mm512_fmadd_ps(a0, bv, acc3);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 48));
                acc3 = _mm512_fmadd_ps(a1, bv, acc3);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 48));
                acc3 = _mm512_fmadd_ps(a2, bv, acc3);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 48));
                acc3 = _mm512_fmadd_ps(a3, bv, acc3);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 64));
                acc4 = _mm512_fmadd_ps(a0, bv, acc4);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 64));
                acc4 = _mm512_fmadd_ps(a1, bv, acc4);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 64));
                acc4 = _mm512_fmadd_ps(a2, bv, acc4);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 64));
                acc4 = _mm512_fmadd_ps(a3, bv, acc4);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 80));
                acc5 = _mm512_fmadd_ps(a0, bv, acc5);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 80));
                acc5 = _mm512_fmadd_ps(a1, bv, acc5);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 80));
                acc5 = _mm512_fmadd_ps(a2, bv, acc5);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 80));
                acc5 = _mm512_fmadd_ps(a3, bv, acc5);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 96));
                acc6 = _mm512_fmadd_ps(a0, bv, acc6);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 96));
                acc6 = _mm512_fmadd_ps(a1, bv, acc6);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 96));
                acc6 = _mm512_fmadd_ps(a2, bv, acc6);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 96));
                acc6 = _mm512_fmadd_ps(a3, bv, acc6);

                let bv = _mm512_loadu_ps(b.get_unchecked(b0 + 112));
                acc7 = _mm512_fmadd_ps(a0, bv, acc7);
                let bv = _mm512_loadu_ps(b.get_unchecked(b1 + 112));
                acc7 = _mm512_fmadd_ps(a1, bv, acc7);
                let bv = _mm512_loadu_ps(b.get_unchecked(b2 + 112));
                acc7 = _mm512_fmadd_ps(a2, bv, acc7);
                let bv = _mm512_loadu_ps(b.get_unchecked(b3 + 112));
                acc7 = _mm512_fmadd_ps(a3, bv, acc7);

                ki += 4;
            }

            // Remainder K (no unroll)
            while ki < k {
                let av = _mm512_set1_ps(*a.get_unchecked(ki));
                let base = ki * n + j0;
                acc0 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base)), acc0);
                acc1 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 16)), acc1);
                acc2 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 32)), acc2);
                acc3 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 48)), acc3);
                acc4 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 64)), acc4);
                acc5 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 80)), acc5);
                acc6 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 96)), acc6);
                acc7 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b.get_unchecked(base + 112)), acc7);
                ki += 1;
            }

            // Store accumulators
            let cp = c.as_mut_ptr().add(j0);
            _mm512_storeu_ps(cp, acc0);
            _mm512_storeu_ps(cp.add(16), acc1);
            _mm512_storeu_ps(cp.add(32), acc2);
            _mm512_storeu_ps(cp.add(48), acc3);
            _mm512_storeu_ps(cp.add(64), acc4);
            _mm512_storeu_ps(cp.add(80), acc5);
            _mm512_storeu_ps(cp.add(96), acc6);
            _mm512_storeu_ps(cp.add(112), acc7);
        }

        // Remainder N (< 128 elements) — process with AVX-512 individual zmm loads
        // and scalar fallback for < 16 elements. No allocation needed.
        if nt_end < n {
            let rem = n - nt_end;
            let rem16 = rem / 16 * 16;

            // Process 16-wide chunks
            for j0 in (0..rem16).step_by(16) {
                let j = nt_end + j0;
                let mut acc = _mm512_setzero_ps();
                for ki in 0..k {
                    let av = _mm512_set1_ps(*a.get_unchecked(ki));
                    let bv = _mm512_loadu_ps(b.get_unchecked(ki * n + j));
                    acc = _mm512_fmadd_ps(av, bv, acc);
                }
                _mm512_storeu_ps(c.as_mut_ptr().add(j), acc);
            }

            // Scalar remainder (< 16 elements)
            for j in (nt_end + rem16)..n {
                let mut sum = 0.0f32;
                for ki in 0..k {
                    sum += *a.get_unchecked(ki) * *b.get_unchecked(ki * n + j);
                }
                *c.get_unchecked_mut(j) = sum;
            }
        }
    }
}

/// Scalar fallback GEMV for non-x86 or non-AVX2 platforms.
pub fn gemv_scalar(k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    // 4-way K-unrolled axpy (auto-vectorizable)
    let k4 = k / 4 * 4;
    for ki in (0..k4).step_by(4) {
        let a0 = a[ki];
        let a1 = a[ki + 1];
        let a2 = a[ki + 2];
        let a3 = a[ki + 3];
        let b0 = ki * n;
        let b1 = b0 + n;
        let b2 = b1 + n;
        let b3 = b2 + n;
        for j in 0..n {
            c[j] += a0 * b[b0 + j] + a1 * b[b1 + j] + a2 * b[b2 + j] + a3 * b[b3 + j];
        }
    }

    // Remainder K
    for ki in k4..k {
        let a_k = a[ki];
        let b_start = ki * n;
        for j in 0..n {
            c[j] += a_k * b[b_start + j];
        }
    }
}

/// Dispatch GEMV to the best available implementation.
///
/// Computes c = a × B for a 1×K `a` and a row-major K×N `b`. `c` must be
/// zero-initialized on entry: the axpy kernels accumulate into `c`, while
/// the tiled kernels overwrite it, so the paths agree only when `c` starts
/// at zero.
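///
/// # Example
///
/// Illustrative usage (a minimal sketch; the import path is assumed from
/// this file's location and may differ in the actual crate layout):
///
/// ```ignore
/// use trueno::blis::gemv::gemv; // assumed path
///
/// let a = [1.0_f32, 2.0];                 // 1×2
/// let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2×3, row-major
/// let mut c = [0.0_f32; 3];               // must start zeroed
/// gemv(2, 3, &a, &b, &mut c);
/// assert_eq!(c, [9.0, 12.0, 15.0]);
/// ```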
pub fn gemv(k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    contract_pre_gemv!(a, b);
    #[cfg(target_arch = "x86_64")]
    {
        // NEGATIVE RESULT (2026-04-05): AVX-512 GEMV is slower than AVX2 at ALL sizes.
        // GEMV is bandwidth-bound (not compute-bound like GEMM). Zen 4 reduces clock
        // ~10-15% during AVX-512 ops, and the wider SIMD can't compensate since the
        // bottleneck is DRAM bandwidth.
        // Measured (GFLOPS): 128×512 AVX2 74.7 vs AVX-512 61.6; 4096×4096 AVX2 16.3
        // vs AVX-512 10.2. AVX-512 GEMV disabled; AVX2 remains the optimal path.
        // gemv_tiled_avx512() retained for future Intel SPR (no AVX-512 throttle).
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: AVX2+FMA verified by feature detection above.
            // Slice bounds are checked by the caller (matmul_vector_matrix).
            unsafe {
                if n > GEMV_TILE_THRESHOLD {
                    gemv_tiled_avx2(k, n, a, b, c);
                } else {
                    gemv_avx2(k, n, a, b, c);
                }
            }
            return;
        }
    }
    gemv_scalar(k, n, a, b, c);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gemv_basic() {
        // 1×3 @ 3×4 → 1×4
        let a = [1.0, 2.0, 3.0];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let mut c = [0.0f32; 4];

        gemv(3, 4, &a, &b, &mut c);

        // c[j] = 1*B[0,j] + 2*B[1,j] + 3*B[2,j]
        assert!((c[0] - 38.0).abs() < 1e-5);
        assert!((c[1] - 44.0).abs() < 1e-5);
        assert!((c[2] - 50.0).abs() < 1e-5);
        assert!((c[3] - 56.0).abs() < 1e-5);
    }

    #[test]
    fn test_gemv_identity_row_select() {
        // e_1 @ B should give B[1,:]
        let a = [0.0, 1.0, 0.0];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let mut c = [0.0f32; 3];

        gemv(3, 3, &a, &b, &mut c);

        assert!((c[0] - 4.0).abs() < 1e-5);
        assert!((c[1] - 5.0).abs() < 1e-5);
        assert!((c[2] - 6.0).abs() < 1e-5);
    }

    #[test]
    fn test_gemv_large_n() {
        // K=2, N=17 (tests AVX2 8-element chunks + scalar remainder)
        let k = 2;
        let n = 17;
        let a = [1.0f32, 2.0];
        let b: Vec<f32> = (0..k * n).map(|i| i as f32).collect();
        let mut c = vec![0.0f32; n];

        gemv(k, n, &a, &b, &mut c);

        // Verify against scalar reference
        for j in 0..n {
            let expected = a[0] * b[j] + a[1] * b[n + j];
            assert!((c[j] - expected).abs() < 1e-4, "c[{j}] = {} expected {expected}", c[j]);
        }
    }

    #[test]
    fn test_gemv_zeros() {
        let a = [0.0f32; 4];
        let b = vec![1.0f32; 4 * 8];
        let mut c = vec![0.0f32; 8];

        gemv(4, 8, &a, &b, &mut c);

        for j in 0..8 {
            assert!((c[j]).abs() < 1e-10);
        }
    }

    /// Test tiled path: N > GEMV_TILE_THRESHOLD triggers the tiled kernel
    #[test]
    fn test_gemv_tiled_large_n() {
        let k = 64;
        let n = 8256; // > 8192 → tiled path (129 full tiles of 64)

        let a: Vec<f32> = (0..k).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0 - 0.5).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i * 13 + 7) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut c_tiled = vec![0.0f32; n];
        let mut c_scalar = vec![0.0f32; n];

        gemv(k, n, &a, &b, &mut c_tiled);
        gemv_scalar(k, n, &a, &b, &mut c_scalar);

        for j in 0..n {
            let diff = (c_tiled[j] - c_scalar[j]).abs();
            assert!(diff < 1e-2, "j={j}: tiled={} scalar={} diff={diff}", c_tiled[j], c_scalar[j]);
        }
    }

    /// Test tiled path with LLM-size dimensions
    #[test]
    fn test_gemv_tiled_llm_size() {
        let k = 256; // reduced from 4096 for test speed
        let n = 11008;

        let a: Vec<f32> = (0..k).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i * 13 + 7) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut c_tiled = vec![0.0f32; n];
        let mut c_scalar = vec![0.0f32; n];

        gemv(k, n, &a, &b, &mut c_tiled);
        gemv_scalar(k, n, &a, &b, &mut c_scalar);

        for j in 0..n {
            let diff = (c_tiled[j] - c_scalar[j]).abs();
            assert!(diff < 1e-1, "j={j}: tiled={} scalar={} diff={diff}", c_tiled[j], c_scalar[j]);
        }
    }

    /// Test tiled path with N not a multiple of 64 (exercises the N remainder)
    #[test]
    fn test_gemv_tiled_remainder() {
        let k = 32;
        let n = 8200; // > 8192, not a multiple of 64 → remainder = 8200 - 8192 = 8

        let a: Vec<f32> = (0..k).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0 - 0.5).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i * 13 + 7) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut c_tiled = vec![0.0f32; n];
        let mut c_scalar = vec![0.0f32; n];

        gemv(k, n, &a, &b, &mut c_tiled);
        gemv_scalar(k, n, &a, &b, &mut c_scalar);

        for j in 0..n {
            let diff = (c_tiled[j] - c_scalar[j]).abs();
            assert!(diff < 1e-2, "j={j}: tiled={} scalar={} diff={diff}", c_tiled[j], c_scalar[j]);
        }
    }

    /// Test tiled path with non-multiple-of-4 K (exercises the K remainder)
    #[test]
    fn test_gemv_tiled_k_remainder() {
        let k = 67; // not a multiple of 4
        let n = 8256; // > 8192 → tiled path

        let a: Vec<f32> = (0..k).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0 - 0.5).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i * 13 + 7) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut c_tiled = vec![0.0f32; n];
        let mut c_scalar = vec![0.0f32; n];

        gemv(k, n, &a, &b, &mut c_tiled);
        gemv_scalar(k, n, &a, &b, &mut c_scalar);

        for j in 0..n {
            let diff = (c_tiled[j] - c_scalar[j]).abs();
            assert!(diff < 1e-2, "j={j}: tiled={} scalar={} diff={diff}", c_tiled[j], c_scalar[j]);
        }
    }

    /// FALSIFY-AVX512-GEMV-001: AVX-512 GEMV matches scalar for attention-size dims.
    /// k=128 (head_dim), n=512 (seq_len). Calls gemv_tiled_avx512 directly, since
    /// the dispatcher in gemv() no longer selects AVX-512 (see the negative result).
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_gemv_avx512_attention_size() {
        if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("fma") {
            return; // AVX-512 not available on this CPU; nothing to falsify
        }
        let k = 128;
        let n = 512;

        let a: Vec<f32> = (0..k).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i * 13 + 7) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut c_gemv = vec![0.0f32; n];
        let mut c_scalar = vec![0.0f32; n];

        // SAFETY: AVX-512F+FMA verified by feature detection above.
        unsafe { gemv_tiled_avx512(k, n, &a, &b, &mut c_gemv) };
        gemv_scalar(k, n, &a, &b, &mut c_scalar);

        let max_diff =
            c_gemv.iter().zip(c_scalar.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
        assert!(max_diff < 1e-2, "FALSIFY-AVX512-GEMV-001: max diff {max_diff}");
    }

    /// FALSIFY-AVX512-GEMV-002: AVX-512 GEMV with N not a multiple of 128.
    /// Exercises the 16-wide remainder and scalar tail paths.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_gemv_avx512_remainder() {
        if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("fma") {
            return; // AVX-512 not available on this CPU; nothing to falsify
        }
        let k = 128;
        let n = 300; // 300 = 256 (2 tiles) + 32 (2 zmm remainder) + 12 (scalar)

        let a: Vec<f32> = (0..k).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i * 13 + 7) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut c_gemv = vec![0.0f32; n];
        let mut c_scalar = vec![0.0f32; n];

        // SAFETY: AVX-512F+FMA verified by feature detection above.
        unsafe { gemv_tiled_avx512(k, n, &a, &b, &mut c_gemv) };
        gemv_scalar(k, n, &a, &b, &mut c_scalar);

        let max_diff =
            c_gemv.iter().zip(c_scalar.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
        assert!(max_diff < 1e-2, "FALSIFY-AVX512-GEMV-002: max diff {max_diff}");
    }
}