graph_core/util.rs
1//! Shared SIMD utility helpers for the graph kernels.
2//!
3//! This module was physically moved here from `mnemosyne_rs::util` in
4//! Phase 223 so both the native PyO3 crate and the forthcoming
5//! `graph_wasm` sub-crate can share the dot-product primitives
6//! without re-rolling their own SIMD. The native crate re-exports
7//! these symbols through a thin bridge module at `mnemosyne_rs::util`,
8//! so the existing `crate::util::dot_simd` / `dot_and_self_dot` call
9//! sites keep working unchanged.
10//!
11//! ## Why a hand-rolled AVX2+FMA path?
12//!
13//! Phase 221 shipped `wide::f32x8` here. The Phase 222 design-gate
14//! investigation showed that at ≥4096 floats the `wide`-backed dot
15//! product was up to 6× slower than NumPy's OpenBLAS `sdot`. The
16//! root cause is the rustc default target baseline, which on
17//! `x86_64-unknown-linux-gnu` is just `sse2`. That forces LLVM to
18//! lower `wide::f32x8` (256-bit) into *pairs* of 128-bit SSE2
19//! instructions with no FMA — about half the throughput of native
20//! AVX2+FMA. See `benchmark_results/cosine_simd_investigation.md`
21//! for the disassembly and measurements.
22//!
23//! The fix is a runtime-dispatched hot path:
24//!
//! * On x86 and x86_64, at call time we check
//!   `is_x86_feature_detected!` for `avx2` + `fma`. If both are
//!   present we jump into an `#[target_feature(enable = "avx2,fma")]`
//!   function that emits 256-bit `vmovups ymm` / `vfmadd231ps ymm`
//!   and gets us to BLAS throughput for the per-pair dot product.
//! * On every other configuration (x86 CPUs without AVX2+FMA,
//!   aarch64, any other target including `wasm32`) we fall back to
//!   the portable `wide::f32x8` path. `wide` emits NEON on aarch64
//!   and SIMD128 on wasm32 so those targets are unaffected.
34//!
35//! The dispatch adds one cached atomic load per call (the
36//! `is_x86_feature_detected!` macro memoises its result). For
37//! 4096-float dot products that overhead is ~2 ns out of ~3 µs —
38//! well under 0.1%.
39//!
40//! We also expose a **fused** `dot_and_self_dot` that computes both
41//! the cross-dot product (with a query) and the self-dot product
42//! (for the row's L2 norm) in a single pass. The per-query-batch
43//! kernels are memory-bound at ≥10k candidates × 4096 dims (160 MB
44//! is far larger than any single-core L3), so cutting candidate-row
45//! traffic in half by computing both accumulators during the same
46//! row sweep is the second big win on top of AVX2+FMA.
47//!
48//! No new crates, no baseline bumps, no C deps — the whole fix is
49//! ~100 lines of `unsafe` behind a CPU-feature gate.
50
51/// L2 norm of an `f32` slice. Uses the SIMD dot product internally so
52/// we reuse a single tight inner loop for both `dot` and `norm`.
53#[inline]
54pub fn l2_norm(v: &[f32]) -> f32 {
55 dot_simd(v, v).sqrt()
56}
57
58/// SIMD-accelerated dot product.
59///
60/// At runtime dispatches to the AVX2+FMA fast path on capable x86_64
61/// CPUs (practically every deployment target since ~2014) and falls
62/// back to the portable `wide::f32x8` implementation otherwise.
63///
64/// The portable path uses four independent accumulators so modern
65/// CPUs with multiple multiply/FMA ports can retire one pair of
66/// multiplies per port per cycle; a single accumulator would
67/// serialise every add on the dependency chain.
68#[inline]
69pub fn dot_simd(a: &[f32], b: &[f32]) -> f32 {
70 debug_assert_eq!(a.len(), b.len());
71
72 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
73 {
74 if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
75 // SAFETY: the `is_x86_feature_detected!` checks above
76 // guarantee AVX2 + FMA are available on this CPU, which
77 // are the only instruction set extensions used inside
78 // `dot_avx2_fma`.
79 return unsafe { dot_avx2_fma(a, b) };
80 }
81 }
82
83 dot_portable(a, b)
84}
85
86/// Fused `(dot(a,b), dot(b,b))` — computes both the cross-dot product
87/// and the self-dot product of `b` in a single pass through `b`.
88///
89/// This is the per-row inner loop for cosine query-batch kernels.
90/// The plain two-pass implementation reads every candidate row twice
91/// (once for `l2_norm`, once for `dot(q, row)`) which on
92/// memory-bound workloads (≥ L3 footprint) effectively doubles the
93/// required DRAM bandwidth. Fusing lets each row stay hot in L1
94/// throughout both accumulators so we halve candidate-matrix
95/// bandwidth at the cost of ~1.5× the arithmetic on the row —
96/// which is free on AVX2+FMA CPUs (we have the FMA ports spare).
97///
98/// Return tuple is `(dot(a, b), dot(b, b))`.
99#[inline]
100pub fn dot_and_self_dot(a: &[f32], b: &[f32]) -> (f32, f32) {
101 debug_assert_eq!(a.len(), b.len());
102
103 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
104 {
105 if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
106 // SAFETY: `is_x86_feature_detected!` above guarantees AVX2
107 // and FMA are available.
108 return unsafe { dot_and_self_dot_avx2_fma(a, b) };
109 }
110 }
111
112 // Portable fallback: run the two-pass dot twice. Slightly
113 // wasteful on non-x86 but only hit when AVX2+FMA isn't available.
114 (dot_portable(a, b), dot_portable(b, b))
115}
116
117/// Portable `wide::f32x8` dot product. Kept as the fallback path for
118/// non-AVX2 x86 targets and for non-x86 targets (e.g. aarch64, where
119/// `wide` emits NEON already, and wasm32 where it emits SIMD128).
120///
121/// Not re-exported at the crate root — callers should go through
122/// [`dot_simd`] which selects the best implementation for the
123/// running CPU.
124#[inline]
125pub(crate) fn dot_portable(a: &[f32], b: &[f32]) -> f32 {
126 let n = a.len();
127
128 let mut acc0 = wide::f32x8::ZERO;
129 let mut acc1 = wide::f32x8::ZERO;
130 let mut acc2 = wide::f32x8::ZERO;
131 let mut acc3 = wide::f32x8::ZERO;
132
133 // Unroll 4× f32x8 = 32 floats per iteration.
134 let chunks_32 = n / 32;
135 let mut i = 0usize;
136 for _ in 0..chunks_32 {
137 let va0 = wide::f32x8::from(&a[i..i + 8]);
138 let vb0 = wide::f32x8::from(&b[i..i + 8]);
139 acc0 += va0 * vb0;
140
141 let va1 = wide::f32x8::from(&a[i + 8..i + 16]);
142 let vb1 = wide::f32x8::from(&b[i + 8..i + 16]);
143 acc1 += va1 * vb1;
144
145 let va2 = wide::f32x8::from(&a[i + 16..i + 24]);
146 let vb2 = wide::f32x8::from(&b[i + 16..i + 24]);
147 acc2 += va2 * vb2;
148
149 let va3 = wide::f32x8::from(&a[i + 24..i + 32]);
150 let vb3 = wide::f32x8::from(&b[i + 24..i + 32]);
151 acc3 += va3 * vb3;
152
153 i += 32;
154 }
155
156 // Remaining 8-aligned chunks (at most three of them).
157 while i + 8 <= n {
158 let va = wide::f32x8::from(&a[i..i + 8]);
159 let vb = wide::f32x8::from(&b[i..i + 8]);
160 acc0 += va * vb;
161 i += 8;
162 }
163
164 // Reduce the four lanes to a single scalar, then tail.
165 let mut total: f32 = (acc0 + acc1 + acc2 + acc3).reduce_add();
166 while i < n {
167 total += a[i] * b[i];
168 i += 1;
169 }
170 total
171}
172
/// AVX2 + FMA dot product: 256-bit `ymm` registers, four independent
/// FMA accumulators, 32 floats consumed per unrolled iteration.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`. The loops below read both slices
/// through raw pointers using a single length as the bound, so the
/// lengths are checked unconditionally up front; without the check a
/// shorter `b` would be read out of bounds.
///
/// # Safety
///
/// The caller must ensure the current CPU supports both AVX2 and FMA
/// (x86 features `avx2` and `fma`). `dot_simd` checks this via
/// `is_x86_feature_detected!` before dispatching here.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_avx2_fma(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Every pointer offset below is justified by `n` being the length
    // of BOTH slices, so enforce that in all build profiles.
    assert_eq!(a.len(), b.len(), "dot_avx2_fma: slice lengths differ");

    let n = a.len();
    let pa = a.as_ptr();
    let pb = b.as_ptr();

    // Four independent accumulators break the dependency chain that a
    // single accumulator would put on every FMA.
    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();
    let mut acc2 = _mm256_setzero_ps();
    let mut acc3 = _mm256_setzero_ps();

    let mut i = 0usize;
    // 32-float unroll.
    // SAFETY (loads): the loop condition gives `i + 32 <= n`, and
    // `n == a.len() == b.len()`, so each unaligned 8-float load at
    // offsets `i`, `i + 8`, `i + 16`, `i + 24` stays inside its slice.
    while i + 32 <= n {
        let va0 = _mm256_loadu_ps(pa.add(i));
        let vb0 = _mm256_loadu_ps(pb.add(i));
        acc0 = _mm256_fmadd_ps(va0, vb0, acc0);

        let va1 = _mm256_loadu_ps(pa.add(i + 8));
        let vb1 = _mm256_loadu_ps(pb.add(i + 8));
        acc1 = _mm256_fmadd_ps(va1, vb1, acc1);

        let va2 = _mm256_loadu_ps(pa.add(i + 16));
        let vb2 = _mm256_loadu_ps(pb.add(i + 16));
        acc2 = _mm256_fmadd_ps(va2, vb2, acc2);

        let va3 = _mm256_loadu_ps(pa.add(i + 24));
        let vb3 = _mm256_loadu_ps(pb.add(i + 24));
        acc3 = _mm256_fmadd_ps(va3, vb3, acc3);

        i += 32;
    }

    // Remaining whole 8-float vectors (at most three of them).
    // SAFETY (loads): `i + 8 <= n` bounds each load in both slices.
    while i + 8 <= n {
        let va = _mm256_loadu_ps(pa.add(i));
        let vb = _mm256_loadu_ps(pb.add(i));
        acc0 = _mm256_fmadd_ps(va, vb, acc0);
        i += 8;
    }

    // Horizontal reduction: sum the four ymm accumulators, then
    // ymm → xmm, then xmm → scalar.
    let s01 = _mm256_add_ps(acc0, acc1);
    let s23 = _mm256_add_ps(acc2, acc3);
    let s = _mm256_add_ps(s01, s23);

    // Split the ymm into its two 128-bit halves and add them.
    let hi = _mm256_extractf128_ps(s, 1);
    let lo = _mm256_castps256_ps128(s);
    let sum128 = _mm_add_ps(hi, lo);

    // Sum the four lanes of the 128-bit vector.
    let shuf = _mm_movehdup_ps(sum128); // lanes [1,1,3,3]
    let sums = _mm_add_ps(sum128, shuf);
    let shuf = _mm_movehl_ps(shuf, sums);
    let sums = _mm_add_ss(sums, shuf);
    let mut total: f32 = _mm_cvtss_f32(sums);

    // Scalar tail (< 8 elements).
    // SAFETY (derefs): `i < n`, and `n` is the length of both slices.
    while i < n {
        total += *pa.add(i) * *pb.add(i);
        i += 1;
    }
    total
}
258
/// AVX2+FMA fused dot + self-dot. Two independent FMA accumulator
/// chains each for `sum(a*b)` and `sum(b*b)`, both fed from the same
/// load of `b` so every element of `b` is read once.
///
/// Returns `(sum(a*b), sum(b*b))`.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`. The loops below read both slices
/// through raw pointers using a single length as the bound, so the
/// lengths are checked unconditionally up front; without the check a
/// shorter `b` would be read out of bounds.
///
/// # Safety
///
/// Caller must guarantee AVX2 and FMA are available. `dot_and_self_dot`
/// checks this before dispatching here.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_and_self_dot_avx2_fma(a: &[f32], b: &[f32]) -> (f32, f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    /// Sums the eight lanes of a ymm register into one scalar.
    ///
    /// Carries the same `target_feature` set as the enclosing kernel so
    /// the AVX intrinsics it calls are compiled in a matching feature
    /// context rather than depending on forced inlining into the
    /// caller.
    #[target_feature(enable = "avx2,fma")]
    #[inline]
    unsafe fn hsum(v: __m256) -> f32 {
        let hi = _mm256_extractf128_ps(v, 1);
        let lo = _mm256_castps256_ps128(v);
        let s = _mm_add_ps(hi, lo);
        let shuf = _mm_movehdup_ps(s); // lanes [1,1,3,3]
        let sums = _mm_add_ps(s, shuf);
        let shuf = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf))
    }

    // Every pointer offset below is justified by `n` being the length
    // of BOTH slices, so enforce that in all build profiles.
    assert_eq!(
        a.len(),
        b.len(),
        "dot_and_self_dot_avx2_fma: slice lengths differ"
    );

    let n = a.len();
    let pa = a.as_ptr();
    let pb = b.as_ptr();

    // Two accumulators per output (dot, self-dot). That already breaks
    // each dependency chain in two while keeping the live vector state
    // (four accumulators plus four loaded vectors) at eight ymm values.
    let mut dot0 = _mm256_setzero_ps();
    let mut dot1 = _mm256_setzero_ps();
    let mut sdot0 = _mm256_setzero_ps();
    let mut sdot1 = _mm256_setzero_ps();

    let mut i = 0usize;
    // 16-float unroll.
    // SAFETY (loads): `i + 16 <= n` and `n == a.len() == b.len()`, so
    // the 8-float loads at offsets `i` and `i + 8` stay inside both
    // slices.
    while i + 16 <= n {
        let va0 = _mm256_loadu_ps(pa.add(i));
        let vb0 = _mm256_loadu_ps(pb.add(i));
        dot0 = _mm256_fmadd_ps(va0, vb0, dot0);
        sdot0 = _mm256_fmadd_ps(vb0, vb0, sdot0);

        let va1 = _mm256_loadu_ps(pa.add(i + 8));
        let vb1 = _mm256_loadu_ps(pb.add(i + 8));
        dot1 = _mm256_fmadd_ps(va1, vb1, dot1);
        sdot1 = _mm256_fmadd_ps(vb1, vb1, sdot1);
        i += 16;
    }
    // At most one whole 8-float vector can remain.
    // SAFETY (loads): `i + 8 <= n` bounds the load in both slices.
    while i + 8 <= n {
        let va = _mm256_loadu_ps(pa.add(i));
        let vb = _mm256_loadu_ps(pb.add(i));
        dot0 = _mm256_fmadd_ps(va, vb, dot0);
        sdot0 = _mm256_fmadd_ps(vb, vb, sdot0);
        i += 8;
    }

    // Merge each accumulator pair, then reduce to scalars.
    let dot_sum = _mm256_add_ps(dot0, dot1);
    let sdot_sum = _mm256_add_ps(sdot0, sdot1);
    let mut dot: f32 = hsum(dot_sum);
    let mut sdot: f32 = hsum(sdot_sum);

    // Scalar tail (< 8 elements).
    // SAFETY (derefs): `i < n`, and `n` is the length of both slices.
    while i < n {
        let ax = *pa.add(i);
        let bx = *pb.add(i);
        dot += ax * bx;
        sdot += bx * bx;
        i += 1;
    }
    (dot, sdot)
}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Naive sequential dot product used as the reference result.
    fn scalar_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// `true` when `got` is within tolerance of `want`.
    ///
    /// The epsilon is generous: FMA and separate mul+add round
    /// differently, and the SIMD kernels sum in a different order
    /// than the sequential reference.
    fn close(got: f32, want: f32) -> bool {
        (got - want).abs() < 1e-3_f32.max(want.abs() * 1e-4)
    }

    /// Every length in `0..=130`, which covers each combination of
    /// unrolled iterations, leftover 8-float vectors and scalar tail
    /// sizes in the kernels, plus a few large sizes.
    fn sweep_lengths() -> impl Iterator<Item = usize> {
        (0..=130usize).chain([1024usize, 4096, 4097])
    }

    /// Compare the dispatched `dot_simd` against the scalar reference
    /// across the full length sweep. The fill is a deterministic
    /// pseudo-random pattern so tails see non-trivial values.
    #[test]
    fn dot_simd_matches_scalar() {
        for n in sweep_lengths() {
            let a: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.123).sin()).collect();
            let b: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.456).cos()).collect();
            let got = dot_simd(&a, &b);
            let want = scalar_dot(&a, &b);
            assert!(close(got, want), "n={n} got={got} want={want}");
        }
    }

    /// Same check aimed directly at the portable fallback, which the
    /// dispatcher never reaches on AVX2+FMA hosts. Lengths `0..=70`
    /// exercise its unrolled loop, its leftover-vector loop and every
    /// scalar tail size; 4096 covers a long input.
    #[test]
    fn dot_portable_matches_scalar() {
        for n in (0..=70usize).chain([4096usize]) {
            let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
            let b: Vec<f32> = (0..n).map(|i| (i as f32) * 0.2).collect();
            let got = dot_portable(&a, &b);
            let want = scalar_dot(&a, &b);
            assert!(close(got, want), "n={n} got={got} want={want}");
        }
    }

    /// Fused dot + self-dot should match the two unfused calls.
    #[test]
    fn dot_and_self_dot_matches_separate_calls() {
        for n in sweep_lengths() {
            let a: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.17).sin()).collect();
            let b: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.31).cos()).collect();
            let (dot_fused, sdot_fused) = dot_and_self_dot(&a, &b);
            let dot_ref = dot_simd(&a, &b);
            let sdot_ref = dot_simd(&b, &b);
            assert!(
                close(dot_fused, dot_ref),
                "dot mismatch n={n} fused={dot_fused} ref={dot_ref}"
            );
            assert!(
                close(sdot_fused, sdot_ref),
                "sdot mismatch n={n} fused={sdot_fused} ref={sdot_ref}"
            );
        }
    }
}