// graph_core/util.rs

//! Shared SIMD utility helpers for the graph kernels.
//!
//! This module was physically moved here from `mnemosyne_rs::util` in
//! Phase 223 so both the native PyO3 crate and the forthcoming
//! `graph_wasm` sub-crate can share the dot-product primitives
//! without re-rolling their own SIMD. The native crate re-exports
//! these symbols through a thin bridge module at `mnemosyne_rs::util`,
//! so the existing `crate::util::dot_simd` / `dot_and_self_dot` call
//! sites keep working unchanged.
//!
//! ## Why a hand-rolled AVX2+FMA path?
//!
//! Phase 221 shipped `wide::f32x8` here. The Phase 222 design-gate
//! investigation showed that at ≥4096 floats the `wide`-backed dot
//! product was up to 6× slower than NumPy's OpenBLAS `sdot`. The
//! root cause is the rustc default target baseline, which on
//! `x86_64-unknown-linux-gnu` is just `sse2`. That forces LLVM to
//! lower `wide::f32x8` (256-bit) into *pairs* of 128-bit SSE2
//! instructions with no FMA — about half the throughput of native
//! AVX2+FMA. See `benchmark_results/cosine_simd_investigation.md`
//! for the disassembly and measurements.
//!
//! The fix is a runtime-dispatched hot path:
//!
//! * On x86 / x86_64, at call time we check `is_x86_feature_detected!`
//!   for `avx2` + `fma`. If both are present we jump into an
//!   `#[target_feature(enable = "avx2,fma")]` function that emits
//!   256-bit `vmovups ymm` / `vfmadd231ps ymm` and gets us to BLAS
//!   throughput for the per-pair dot product.
//! * On every other configuration (x86 without AVX2+FMA, aarch64, any
//!   other target including `wasm32`) we fall back to the portable
//!   `wide::f32x8` path. `wide` emits NEON on aarch64 and SIMD128 on
//!   wasm32, so those targets are unaffected.
//!
//! The dispatch adds a cached atomic load per feature check on each
//! call (the `is_x86_feature_detected!` macro memoises its result). For
//! 4096-float dot products that overhead is ~2 ns out of ~3 µs —
//! well under 0.1%.
//!
//! We also expose a **fused** `dot_and_self_dot` that computes both
//! the cross-dot product (with a query) and the self-dot product
//! (for the row's L2 norm) in a single pass. The per-query-batch
//! kernels are memory-bound at ≥10k candidates × 4096 dims (≈160 MB of
//! `f32` data, far larger than any single-core L3), so cutting
//! candidate-row traffic in half by computing both accumulators during
//! the same row sweep is the second big win on top of AVX2+FMA.
//!
//! No new crates, no baseline bumps, no C deps — the whole fix is
//! ~100 lines of `unsafe` behind a CPU-feature gate.

51/// L2 norm of an `f32` slice. Uses the SIMD dot product internally so
52/// we reuse a single tight inner loop for both `dot` and `norm`.
53#[inline]
54pub fn l2_norm(v: &[f32]) -> f32 {
55    dot_simd(v, v).sqrt()
56}
57
58/// SIMD-accelerated dot product.
59///
60/// At runtime dispatches to the AVX2+FMA fast path on capable x86_64
61/// CPUs (practically every deployment target since ~2014) and falls
62/// back to the portable `wide::f32x8` implementation otherwise.
63///
64/// The portable path uses four independent accumulators so modern
65/// CPUs with multiple multiply/FMA ports can retire one pair of
66/// multiplies per port per cycle; a single accumulator would
67/// serialise every add on the dependency chain.
68#[inline]
69pub fn dot_simd(a: &[f32], b: &[f32]) -> f32 {
70    debug_assert_eq!(a.len(), b.len());
71
72    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
73    {
74        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
75            // SAFETY: the `is_x86_feature_detected!` checks above
76            // guarantee AVX2 + FMA are available on this CPU, which
77            // are the only instruction set extensions used inside
78            // `dot_avx2_fma`.
79            return unsafe { dot_avx2_fma(a, b) };
80        }
81    }
82
83    dot_portable(a, b)
84}
85
86/// Fused `(dot(a,b), dot(b,b))` — computes both the cross-dot product
87/// and the self-dot product of `b` in a single pass through `b`.
88///
89/// This is the per-row inner loop for cosine query-batch kernels.
90/// The plain two-pass implementation reads every candidate row twice
91/// (once for `l2_norm`, once for `dot(q, row)`) which on
92/// memory-bound workloads (≥ L3 footprint) effectively doubles the
93/// required DRAM bandwidth. Fusing lets each row stay hot in L1
94/// throughout both accumulators so we halve candidate-matrix
95/// bandwidth at the cost of ~1.5× the arithmetic on the row —
96/// which is free on AVX2+FMA CPUs (we have the FMA ports spare).
97///
98/// Return tuple is `(dot(a, b), dot(b, b))`.
99#[inline]
100pub fn dot_and_self_dot(a: &[f32], b: &[f32]) -> (f32, f32) {
101    debug_assert_eq!(a.len(), b.len());
102
103    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
104    {
105        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
106            // SAFETY: `is_x86_feature_detected!` above guarantees AVX2
107            // and FMA are available.
108            return unsafe { dot_and_self_dot_avx2_fma(a, b) };
109        }
110    }
111
112    // Portable fallback: run the two-pass dot twice. Slightly
113    // wasteful on non-x86 but only hit when AVX2+FMA isn't available.
114    (dot_portable(a, b), dot_portable(b, b))
115}
116
117/// Portable `wide::f32x8` dot product. Kept as the fallback path for
118/// non-AVX2 x86 targets and for non-x86 targets (e.g. aarch64, where
119/// `wide` emits NEON already, and wasm32 where it emits SIMD128).
120///
121/// Not re-exported at the crate root — callers should go through
122/// [`dot_simd`] which selects the best implementation for the
123/// running CPU.
124#[inline]
125pub(crate) fn dot_portable(a: &[f32], b: &[f32]) -> f32 {
126    let n = a.len();
127
128    let mut acc0 = wide::f32x8::ZERO;
129    let mut acc1 = wide::f32x8::ZERO;
130    let mut acc2 = wide::f32x8::ZERO;
131    let mut acc3 = wide::f32x8::ZERO;
132
133    // Unroll 4× f32x8 = 32 floats per iteration.
134    let chunks_32 = n / 32;
135    let mut i = 0usize;
136    for _ in 0..chunks_32 {
137        let va0 = wide::f32x8::from(&a[i..i + 8]);
138        let vb0 = wide::f32x8::from(&b[i..i + 8]);
139        acc0 += va0 * vb0;
140
141        let va1 = wide::f32x8::from(&a[i + 8..i + 16]);
142        let vb1 = wide::f32x8::from(&b[i + 8..i + 16]);
143        acc1 += va1 * vb1;
144
145        let va2 = wide::f32x8::from(&a[i + 16..i + 24]);
146        let vb2 = wide::f32x8::from(&b[i + 16..i + 24]);
147        acc2 += va2 * vb2;
148
149        let va3 = wide::f32x8::from(&a[i + 24..i + 32]);
150        let vb3 = wide::f32x8::from(&b[i + 24..i + 32]);
151        acc3 += va3 * vb3;
152
153        i += 32;
154    }
155
156    // Remaining 8-aligned chunks (at most three of them).
157    while i + 8 <= n {
158        let va = wide::f32x8::from(&a[i..i + 8]);
159        let vb = wide::f32x8::from(&b[i..i + 8]);
160        acc0 += va * vb;
161        i += 8;
162    }
163
164    // Reduce the four lanes to a single scalar, then tail.
165    let mut total: f32 = (acc0 + acc1 + acc2 + acc3).reduce_add();
166    while i < n {
167        total += a[i] * b[i];
168        i += 1;
169    }
170    total
171}
172
/// AVX2 + FMA dot product. 256-bit `ymm` registers, four independent
/// FMA accumulators, 32-float-per-iteration unroll, then an 8-float
/// cleanup loop and a scalar tail.
///
/// Only the first `min(a.len(), b.len())` elements are read. Clamping
/// the length here, instead of trusting `a.len()`, means the raw
/// pointer loads can never run past the end of either slice, whatever
/// the caller passes.
///
/// # Safety
///
/// The caller must ensure the current CPU supports both AVX2 and FMA
/// (x86 features `avx2` and `fma`). `dot_simd` is the only caller and
/// checks this via `is_x86_feature_detected!` before dispatching here.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_avx2_fma(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Bound every load below by the shorter slice so no read can leave
    // either allocation.
    let n = a.len().min(b.len());
    let pa = a.as_ptr();
    let pb = b.as_ptr();

    // Four independent accumulators. Zen 3+ and Skylake+ have two FMA
    // ports, so four accumulators comfortably saturate the pipeline
    // while breaking the per-add dependency chain.
    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();
    let mut acc2 = _mm256_setzero_ps();
    let mut acc3 = _mm256_setzero_ps();

    let mut i = 0usize;
    // 32-float unroll. Each iteration reads `[i, i + 32)` from both
    // slices, which is in bounds because `i + 32 <= n <= len`.
    while i + 32 <= n {
        let va0 = _mm256_loadu_ps(pa.add(i));
        let vb0 = _mm256_loadu_ps(pb.add(i));
        acc0 = _mm256_fmadd_ps(va0, vb0, acc0);

        let va1 = _mm256_loadu_ps(pa.add(i + 8));
        let vb1 = _mm256_loadu_ps(pb.add(i + 8));
        acc1 = _mm256_fmadd_ps(va1, vb1, acc1);

        let va2 = _mm256_loadu_ps(pa.add(i + 16));
        let vb2 = _mm256_loadu_ps(pb.add(i + 16));
        acc2 = _mm256_fmadd_ps(va2, vb2, acc2);

        let va3 = _mm256_loadu_ps(pa.add(i + 24));
        let vb3 = _mm256_loadu_ps(pb.add(i + 24));
        acc3 = _mm256_fmadd_ps(va3, vb3, acc3);

        i += 32;
    }

    // Remaining whole 8-float groups (at most three), all into `acc0`.
    while i + 8 <= n {
        let va = _mm256_loadu_ps(pa.add(i));
        let vb = _mm256_loadu_ps(pb.add(i));
        acc0 = _mm256_fmadd_ps(va, vb, acc0);
        i += 8;
    }

    // Horizontal reduction: sum the four ymm accumulators, then
    // ymm → xmm, then xmm → scalar.
    let s01 = _mm256_add_ps(acc0, acc1);
    let s23 = _mm256_add_ps(acc2, acc3);
    let s = _mm256_add_ps(s01, s23);

    // Split the ymm into two xmm halves and sum them.
    let hi = _mm256_extractf128_ps(s, 1);
    let lo = _mm256_castps256_ps128(s);
    let sum128 = _mm_add_ps(hi, lo);

    // Sum the four lanes [s0, s1, s2, s3] of the 128-bit vector:
    // movehdup → [s1, s1, s3, s3]; add → lane0 = s0+s1, lane2 = s2+s3;
    // movehl brings lane 2 down to lane 0; add_ss → s0+s1+s2+s3.
    let shuf = _mm_movehdup_ps(sum128);
    let sums = _mm_add_ps(sum128, shuf);
    let shuf = _mm_movehl_ps(shuf, sums);
    let sums = _mm_add_ss(sums, shuf);
    let mut total: f32 = _mm_cvtss_f32(sums);

    // Scalar tail (< 8 elements), still bounded by `n`.
    while i < n {
        total += *pa.add(i) * *pb.add(i);
        i += 1;
    }
    total
}

/// AVX2+FMA fused dot + self-dot. Two independent FMA accumulator
/// chains, one for `sum(a*b)` and one for `sum(b*b)`, both fed by the
/// same `b` load so each chunk of `b` is read from memory only once.
///
/// Only the first `min(a.len(), b.len())` elements are read; clamping
/// here keeps the raw-pointer loads in bounds for any slice pair.
///
/// # Safety
///
/// Caller must guarantee AVX2 and FMA are available. `dot_and_self_dot`
/// checks this before dispatching here.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_and_self_dot_avx2_fma(a: &[f32], b: &[f32]) -> (f32, f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Bound every load below by the shorter slice.
    let n = a.len().min(b.len());
    let pa = a.as_ptr();
    let pb = b.as_ptr();

    // Two accumulator pairs (dot, self_dot). Two of each breaks up the
    // dependency chain on CPUs with two FMA ports per core. Together
    // with the four loaded operands that is eight live ymm values per
    // iteration; four pairs would need sixteen, i.e. the entire x86_64
    // ymm register file (32-bit x86 has only eight).
    let mut dot0 = _mm256_setzero_ps();
    let mut dot1 = _mm256_setzero_ps();
    let mut sdot0 = _mm256_setzero_ps();
    let mut sdot1 = _mm256_setzero_ps();

    let mut i = 0usize;
    // 16-float unroll; reads `[i, i + 16)`, in bounds since `i + 16 <= n`.
    while i + 16 <= n {
        let va0 = _mm256_loadu_ps(pa.add(i));
        let vb0 = _mm256_loadu_ps(pb.add(i));
        dot0 = _mm256_fmadd_ps(va0, vb0, dot0);
        sdot0 = _mm256_fmadd_ps(vb0, vb0, sdot0);

        let va1 = _mm256_loadu_ps(pa.add(i + 8));
        let vb1 = _mm256_loadu_ps(pb.add(i + 8));
        dot1 = _mm256_fmadd_ps(va1, vb1, dot1);
        sdot1 = _mm256_fmadd_ps(vb1, vb1, sdot1);
        i += 16;
    }
    // At most one remaining whole 8-float group.
    while i + 8 <= n {
        let va = _mm256_loadu_ps(pa.add(i));
        let vb = _mm256_loadu_ps(pb.add(i));
        dot0 = _mm256_fmadd_ps(va, vb, dot0);
        sdot0 = _mm256_fmadd_ps(vb, vb, sdot0);
        i += 8;
    }

    // Horizontal reduction helper: sum the 8 lanes of a ymm.
    //
    // It carries its own `target_feature` so its AVX/SSE3 intrinsic
    // calls are compiled with those features enabled in their own
    // right, instead of relying on being force-inlined into this AVX2
    // caller. `avx` is a subset of the caller's `avx2,fma`, so it can
    // still be inlined here. It uses `#[inline]` rather than
    // `#[inline(always)]`, since the latter cannot be combined with
    // `#[target_feature]` on stable Rust.
    #[target_feature(enable = "avx")]
    #[inline]
    unsafe fn hsum(v: __m256) -> f32 {
        let hi = _mm256_extractf128_ps(v, 1);
        let lo = _mm256_castps256_ps128(v);
        let s = _mm_add_ps(hi, lo);
        let shuf = _mm_movehdup_ps(s);
        let sums = _mm_add_ps(s, shuf);
        let shuf = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf))
    }

    let dot_sum = _mm256_add_ps(dot0, dot1);
    let sdot_sum = _mm256_add_ps(sdot0, sdot1);
    let mut dot: f32 = hsum(dot_sum);
    let mut sdot: f32 = hsum(sdot_sum);

    // Scalar tail (< 8 elements), still bounded by `n`.
    while i < n {
        let ax = *pa.add(i);
        let bx = *pb.add(i);
        dot += ax * bx;
        sdot += bx * bx;
        i += 1;
    }
    (dot, sdot)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Naive left-to-right scalar reference, `sum(a[i] * b[i])`.
    fn scalar_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// Mixed absolute/relative tolerance shared by every comparison.
    /// FMA vs separate mul+add, and different accumulator layouts, round
    /// differently. 1e-3 absolute is plenty for |values| ≈ 1, and the
    /// 1e-4 relative term covers large-magnitude sums.
    fn close(got: f32, want: f32) -> bool {
        (got - want).abs() < 1e-3_f32.max(want.abs() * 1e-4)
    }

    /// Every length in `0..=64`, plus a few larger sizes including the
    /// 4096-float working set.
    ///
    /// The dense range hits every remainder modulo 32 (the AVX2 and
    /// portable unroll width), 16 (the fused-kernel unroll) and 8 (the
    /// vector width). That covers every 8-float cleanup count and every
    /// scalar-tail size.
    fn test_lengths() -> impl Iterator<Item = usize> {
        (0..=64).chain([127, 128, 1024, 4096, 4097])
    }

    /// The dispatched `dot_simd` must match the scalar reference at every
    /// length from `test_lengths`. Any AVX2/FMA codegen or tail-handling
    /// bug would surface here.
    #[test]
    fn dot_simd_matches_scalar() {
        // Deterministic pseudo-random fill so tail handling gets
        // non-trivial values (otherwise fmadd-vs-mul-add rounding is
        // invisible).
        for n in test_lengths() {
            let a: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.123).sin()).collect();
            let b: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.456).cos()).collect();
            let got = dot_simd(&a, &b);
            let want = scalar_dot(&a, &b);
            assert!(close(got, want), "n={n} got={got} want={want}");
        }
    }

    /// On capable x86 CPUs `dot_simd` always picks the AVX2+FMA kernel,
    /// so the portable fallback would go untested there. Call it directly
    /// so it is covered on every machine, including its scalar tail.
    #[test]
    fn dot_portable_matches_scalar() {
        for n in test_lengths() {
            let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
            let b: Vec<f32> = (0..n).map(|i| (i as f32) * 0.2).collect();
            let got = dot_portable(&a, &b);
            let want = scalar_dot(&a, &b);
            assert!(close(got, want), "n={n} got={got} want={want}");
        }
    }

    /// Fused dot + self-dot should match two unfused `dot_simd` calls.
    #[test]
    fn dot_and_self_dot_matches_separate_calls() {
        for n in test_lengths() {
            let a: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.17).sin()).collect();
            let b: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.31).cos()).collect();
            let (dot_fused, sdot_fused) = dot_and_self_dot(&a, &b);
            let dot_ref = dot_simd(&a, &b);
            let sdot_ref = dot_simd(&b, &b);
            assert!(
                close(dot_fused, dot_ref),
                "dot mismatch n={n} fused={dot_fused} ref={dot_ref}"
            );
            assert!(
                close(sdot_fused, sdot_ref),
                "sdot mismatch n={n} fused={sdot_fused} ref={sdot_ref}"
            );
        }
    }
}