mnemosyne-graph-core 0.1.0

Shared graph kernels for the Mnemosyne memory substrate: force-directed simulation, R-tree viewport index, and the SIMD primitives both the native PyO3 crate and the WASM sub-crate depend on.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
//! Shared SIMD utility helpers for the graph kernels.
//!
//! This module was physically moved here from `mnemosyne_rs::util` in
//! Phase 223 so both the native PyO3 crate and the forthcoming
//! `graph_wasm` sub-crate can share the dot-product primitives
//! without re-rolling their own SIMD. The native crate re-exports
//! these symbols through a thin bridge module at `mnemosyne_rs::util`,
//! so the existing `crate::util::dot_simd` / `dot_and_self_dot` call
//! sites keep working unchanged.
//!
//! ## Why a hand-rolled AVX2+FMA path?
//!
//! Phase 221 shipped `wide::f32x8` here. The Phase 222 design-gate
//! investigation showed that at ≥4096 floats the `wide`-backed dot
//! product was up to 6× slower than NumPy's OpenBLAS `sdot`. The
//! root cause is the rustc default target baseline, which on
//! `x86_64-unknown-linux-gnu` is just `sse2`. That forces LLVM to
//! lower `wide::f32x8` (256-bit) into *pairs* of 128-bit SSE2
//! instructions with no FMA — about half the throughput of native
//! AVX2+FMA. See `benchmark_results/cosine_simd_investigation.md`
//! for the disassembly and measurements.
//!
//! The fix is a runtime-dispatched hot path:
//!
//! * On x86 and x86_64 we check `is_x86_feature_detected!` for `avx2`
//!   + `fma` at call time. If both are present we jump into an
//!   `#[target_feature(enable = "avx2,fma")]` function that emits
//!   256-bit `vmovups ymm` / `vfmadd231ps ymm` and gets us to BLAS
//!   throughput for the per-pair dot product.
//! * On every other configuration (x86 CPUs lacking AVX2 or FMA,
//!   aarch64, and any other target including `wasm32`) we fall back to
//!   the portable `wide::f32x8` path. `wide` emits NEON on aarch64 and
//!   SIMD128 on wasm32 so those targets are unaffected.
//!
//! The dispatch adds one cached atomic load per call (the
//! `is_x86_feature_detected!` macro memoises its result). For
//! 4096-float dot products that overhead is ~2 ns out of ~3 µs —
//! well under 0.1%.
//!
//! We also expose a **fused** `dot_and_self_dot` that computes both
//! the cross-dot product (with a query) and the self-dot product
//! (for the row's L2 norm) in a single pass. The per-query-batch
//! kernels are memory-bound at ≥10k candidates × 4096 dims (160 MB
//! is far larger than any single-core L3), so cutting candidate-row
//! traffic in half by computing both accumulators during the same
//! row sweep is the second big win on top of AVX2+FMA.
//!
//! No new crates, no baseline bumps, no C deps — the whole fix is
//! ~100 lines of `unsafe` behind a CPU-feature gate.

/// Euclidean (L2) norm of `v`, i.e. `sqrt(Σ v[i]²)`.
///
/// Computed as the square root of `v · v`, so the norm goes through the
/// same dispatched SIMD inner loop as [`dot_simd`] instead of keeping a
/// second hand-written loop around.
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares = dot_simd(v, v);
    sum_of_squares.sqrt()
}

/// SIMD-accelerated dot product `Σ a[i] * b[i]`.
///
/// At runtime dispatches to the AVX2+FMA fast path on capable x86 /
/// x86_64 CPUs and falls back to the portable `wide::f32x8`
/// implementation everywhere else.
///
/// The portable path uses four independent accumulators so CPUs with
/// several multiply/FMA ports can keep multiple multiply-adds in
/// flight; a single accumulator would serialise every add on one
/// dependency chain.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`. This is a hard check rather than a
/// `debug_assert!` because the AVX2 kernel reads both slices through raw
/// pointers: a release build must never reach it with mismatched
/// lengths.
#[inline]
pub fn dot_simd(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert (not debug-only): the unsafe AVX2 kernel indexes `b`
    // up to `a.len()` via raw pointers, so skipping this check in
    // release builds would turn a caller bug into an out-of-bounds read.
    assert_eq!(a.len(), b.len(), "dot_simd: slice lengths differ");

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            // SAFETY: the `is_x86_feature_detected!` checks above
            // guarantee AVX2 + FMA are available on this CPU (the only
            // extensions `dot_avx2_fma` uses), and the assert at the top
            // guarantees equal lengths, so every pointer read is in bounds.
            return unsafe { dot_avx2_fma(a, b) };
        }
    }

    dot_portable(a, b)
}

/// Fused `(dot(a, b), dot(b, b))`: the cross-dot product with `a` and the
/// self-dot product of `b`, computed in a single pass over `b`.
///
/// This is the per-row inner loop for cosine query-batch kernels. The
/// plain two-pass implementation reads every candidate row twice (once
/// for `l2_norm`, once for `dot(q, row)`), which on memory-bound
/// workloads (row data larger than L3) doubles the DRAM bandwidth
/// required. The fused AVX2 kernel issues the same number of
/// multiply-adds as the two-pass version but reads each element of `b`
/// from memory only once, halving candidate-matrix bandwidth.
///
/// Returns `(dot(a, b), dot(b, b))`.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`. The check is unconditional because the
/// AVX2 kernel reads both slices through raw pointers.
#[inline]
pub fn dot_and_self_dot(a: &[f32], b: &[f32]) -> (f32, f32) {
    // Hard assert (not debug-only) so release builds can never hand
    // mismatched slices to the raw-pointer AVX2 kernel.
    assert_eq!(a.len(), b.len(), "dot_and_self_dot: slice lengths differ");

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            // SAFETY: `is_x86_feature_detected!` above guarantees AVX2
            // and FMA are available, and the assert at the top guarantees
            // equal lengths, so every pointer read is in bounds.
            return unsafe { dot_and_self_dot_avx2_fma(a, b) };
        }
    }

    // Portable fallback: two separate passes, so `b` is read twice.
    // Taken on every non-x86 target (e.g. aarch64, wasm32) and on x86
    // CPUs without AVX2+FMA.
    (dot_portable(a, b), dot_portable(b, b))
}

/// Portable `wide::f32x8` dot product. Kept as the fallback path for
/// non-AVX2 x86 targets and for non-x86 targets (e.g. aarch64, where
/// `wide` emits NEON already, and wasm32 where it emits SIMD128).
///
/// Not re-exported at the crate root — callers should go through
/// [`dot_simd`] which selects the best implementation for the
/// running CPU.
#[inline]
pub(crate) fn dot_portable(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();

    let mut acc0 = wide::f32x8::ZERO;
    let mut acc1 = wide::f32x8::ZERO;
    let mut acc2 = wide::f32x8::ZERO;
    let mut acc3 = wide::f32x8::ZERO;

    // Unroll 4× f32x8 = 32 floats per iteration.
    let chunks_32 = n / 32;
    let mut i = 0usize;
    for _ in 0..chunks_32 {
        let va0 = wide::f32x8::from(&a[i..i + 8]);
        let vb0 = wide::f32x8::from(&b[i..i + 8]);
        acc0 += va0 * vb0;

        let va1 = wide::f32x8::from(&a[i + 8..i + 16]);
        let vb1 = wide::f32x8::from(&b[i + 8..i + 16]);
        acc1 += va1 * vb1;

        let va2 = wide::f32x8::from(&a[i + 16..i + 24]);
        let vb2 = wide::f32x8::from(&b[i + 16..i + 24]);
        acc2 += va2 * vb2;

        let va3 = wide::f32x8::from(&a[i + 24..i + 32]);
        let vb3 = wide::f32x8::from(&b[i + 24..i + 32]);
        acc3 += va3 * vb3;

        i += 32;
    }

    // Remaining 8-aligned chunks (at most three of them).
    while i + 8 <= n {
        let va = wide::f32x8::from(&a[i..i + 8]);
        let vb = wide::f32x8::from(&b[i..i + 8]);
        acc0 += va * vb;
        i += 8;
    }

    // Reduce the four lanes to a single scalar, then tail.
    let mut total: f32 = (acc0 + acc1 + acc2 + acc3).reduce_add();
    while i < n {
        total += a[i] * b[i];
        i += 1;
    }
    total
}

/// AVX2 + FMA dot product: 256-bit `ymm` registers, four independent
/// FMA accumulators, 32 floats per main-loop iteration.
///
/// Computes `Σ a[i] * b[i]` over the first `min(a.len(), b.len())`
/// elements. Callers are expected to pass equal-length slices
/// ([`dot_simd`] asserts this). The clamp to the shorter length only
/// guarantees that no raw-pointer read can ever leave either slice.
///
/// # Safety
///
/// The caller must ensure the current CPU supports both AVX2 and FMA
/// (x86 features `avx2` and `fma`), e.g. via `is_x86_feature_detected!`,
/// which is what [`dot_simd`] checks before dispatching here.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_avx2_fma(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Clamp to the common length so every `pa.add(i)` / `pb.add(i)`
    // below is in bounds for *both* slices, even if a caller violates
    // the equal-length expectation.
    let n = a.len().min(b.len());
    let pa = a.as_ptr();
    let pb = b.as_ptr();

    // Four independent accumulators. Zen 3+ and Skylake+ have two FMA
    // ports, so four accumulators comfortably saturate the pipeline
    // while breaking the per-add dependency chain.
    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();
    let mut acc2 = _mm256_setzero_ps();
    let mut acc3 = _mm256_setzero_ps();

    let mut i = 0usize;
    // 32-float unroll. Reads are in bounds: i + 32 <= n <= len of both.
    while i + 32 <= n {
        let va0 = _mm256_loadu_ps(pa.add(i));
        let vb0 = _mm256_loadu_ps(pb.add(i));
        acc0 = _mm256_fmadd_ps(va0, vb0, acc0);

        let va1 = _mm256_loadu_ps(pa.add(i + 8));
        let vb1 = _mm256_loadu_ps(pb.add(i + 8));
        acc1 = _mm256_fmadd_ps(va1, vb1, acc1);

        let va2 = _mm256_loadu_ps(pa.add(i + 16));
        let vb2 = _mm256_loadu_ps(pb.add(i + 16));
        acc2 = _mm256_fmadd_ps(va2, vb2, acc2);

        let va3 = _mm256_loadu_ps(pa.add(i + 24));
        let vb3 = _mm256_loadu_ps(pb.add(i + 24));
        acc3 = _mm256_fmadd_ps(va3, vb3, acc3);

        i += 32;
    }

    // Remaining whole 8-float vectors (at most three of them).
    while i + 8 <= n {
        let va = _mm256_loadu_ps(pa.add(i));
        let vb = _mm256_loadu_ps(pb.add(i));
        acc0 = _mm256_fmadd_ps(va, vb, acc0);
        i += 8;
    }

    // Horizontal reduction: sum the four ymm accumulators, then
    // ymm -> xmm, then xmm -> scalar.
    let s01 = _mm256_add_ps(acc0, acc1);
    let s23 = _mm256_add_ps(acc2, acc3);
    let s = _mm256_add_ps(s01, s23);

    // Split the ymm into two xmm halves and sum them.
    let hi = _mm256_extractf128_ps(s, 1);
    let lo = _mm256_castps256_ps128(s);
    let sum128 = _mm_add_ps(hi, lo);

    // Sum the four lanes of the 128-bit vector.
    let shuf = _mm_movehdup_ps(sum128); // [1,1,3,3]
    let sums = _mm_add_ps(sum128, shuf);
    let shuf = _mm_movehl_ps(shuf, sums);
    let sums = _mm_add_ss(sums, shuf);
    let mut total: f32 = _mm_cvtss_f32(sums);

    // Scalar tail (< 8 elements).
    while i < n {
        total += *pa.add(i) * *pb.add(i);
        i += 1;
    }
    total
}

/// AVX2+FMA fused dot + self-dot. Two FMA accumulator chains per
/// product, one set for `sum(a*b)` and one for `sum(b*b)`, both driven
/// by the same `b` load so each element of `b` is fetched once.
///
/// Works over the first `min(a.len(), b.len())` elements. Callers are
/// expected to pass equal-length slices ([`dot_and_self_dot`] asserts
/// this). The clamp only guarantees no raw-pointer read can leave either
/// slice.
///
/// # Safety
///
/// Caller must guarantee AVX2 and FMA are available on the running CPU.
/// [`dot_and_self_dot`] checks this with `is_x86_feature_detected!`
/// before dispatching here.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_and_self_dot_avx2_fma(a: &[f32], b: &[f32]) -> (f32, f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Clamp to the common length so every raw read below is in bounds
    // for both slices.
    let n = a.len().min(b.len());
    let pa = a.as_ptr();
    let pb = b.as_ptr();

    // Two accumulators per product (dot, self_dot) break the dependency
    // chain on CPUs with two FMA ports. Four per product would mean eight
    // live accumulators before counting loaded operands. That is half of
    // x86_64's 16 ymm registers and all 8 on 32-bit x86.
    let mut dot0 = _mm256_setzero_ps();
    let mut dot1 = _mm256_setzero_ps();
    let mut sdot0 = _mm256_setzero_ps();
    let mut sdot1 = _mm256_setzero_ps();

    let mut i = 0usize;
    while i + 16 <= n {
        let va0 = _mm256_loadu_ps(pa.add(i));
        let vb0 = _mm256_loadu_ps(pb.add(i));
        dot0 = _mm256_fmadd_ps(va0, vb0, dot0);
        sdot0 = _mm256_fmadd_ps(vb0, vb0, sdot0);

        let va1 = _mm256_loadu_ps(pa.add(i + 8));
        let vb1 = _mm256_loadu_ps(pb.add(i + 8));
        dot1 = _mm256_fmadd_ps(va1, vb1, dot1);
        sdot1 = _mm256_fmadd_ps(vb1, vb1, sdot1);
        i += 16;
    }
    // At most one remaining whole 8-float vector.
    while i + 8 <= n {
        let va = _mm256_loadu_ps(pa.add(i));
        let vb = _mm256_loadu_ps(pb.add(i));
        dot0 = _mm256_fmadd_ps(va, vb, dot0);
        sdot0 = _mm256_fmadd_ps(vb, vb, sdot0);
        i += 8;
    }

    // Horizontal reduction helper: sum the 8 lanes of a ymm. Nested
    // functions do not inherit the outer `target_feature`, so enable AVX
    // explicitly; the AVX intrinsics below can then be inlined into it,
    // and it can be inlined into the AVX2+FMA caller.
    #[target_feature(enable = "avx")]
    #[inline]
    unsafe fn hsum(v: __m256) -> f32 {
        let hi = _mm256_extractf128_ps(v, 1);
        let lo = _mm256_castps256_ps128(v);
        let s = _mm_add_ps(hi, lo);
        let shuf = _mm_movehdup_ps(s);
        let sums = _mm_add_ps(s, shuf);
        let shuf = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf))
    }

    let dot_sum = _mm256_add_ps(dot0, dot1);
    let sdot_sum = _mm256_add_ps(sdot0, sdot1);
    let mut dot: f32 = hsum(dot_sum);
    let mut sdot: f32 = hsum(sdot_sum);

    // Scalar tail (< 8 elements).
    while i < n {
        let ax = *pa.add(i);
        let bx = *pb.add(i);
        dot += ax * bx;
        sdot += bx * bx;
        i += 1;
    }
    (dot, sdot)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Naive scalar reference: straightforward `Σ a[i] * b[i]`.
    fn scalar_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// Comparison tolerance: absolute 1e-3 for small results, relative
    /// 1e-4 for large ones. FMA vs separate mul+add, and the different
    /// accumulation orders of the kernels, round differently.
    fn tolerance(reference: f32) -> f32 {
        1e-3_f32.max(reference.abs() * 1e-4)
    }

    /// Compare the dispatched `dot_simd` against the scalar reference
    /// across a range of lengths: exact 32-float multiples plus a spread
    /// of tail sizes. Any AVX2/FMA codegen bug would surface here.
    #[test]
    fn dot_simd_matches_scalar() {
        // Deterministic pseudo-random fill so tail handling gets
        // non-trivial values (otherwise fmadd-vs-mul-add rounding is
        // invisible).
        for &n in &[0usize, 1, 7, 8, 9, 15, 16, 31, 32, 33, 63, 64, 127, 128, 4096, 4097] {
            let a: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.123).sin()).collect();
            let b: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.456).cos()).collect();
            let got = dot_simd(&a, &b);
            let want = scalar_dot(&a, &b);
            assert!(
                (got - want).abs() < tolerance(want),
                "n={n} got={got} want={want}"
            );
        }
    }

    /// Exercise the portable fallback directly. On AVX2+FMA machines
    /// `dot_simd` never reaches it, so without this test the fallback
    /// would go untested there.
    ///
    /// Lengths cover every code path: empty, scalar tail only (1, 7),
    /// single 8-vector (8), vector + tail (9, 31), one 32-block (32),
    /// block + tail (33), block + three 8-vectors + 7-tail (63), and a
    /// large aligned input (4096).
    #[test]
    fn dot_portable_matches_scalar() {
        for &n in &[0usize, 1, 7, 8, 9, 31, 32, 33, 63, 4096] {
            let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
            let b: Vec<f32> = (0..n).map(|i| (i as f32) * 0.2).collect();
            let got = dot_portable(&a, &b);
            let want = scalar_dot(&a, &b);
            assert!(
                (got - want).abs() < tolerance(want),
                "n={n} got={got} want={want}"
            );
        }
    }

    /// `l2_norm` on exactly representable inputs, so every kernel must
    /// produce the exact answer.
    #[test]
    fn l2_norm_exact_values() {
        assert_eq!(l2_norm(&[]), 0.0);
        assert_eq!(l2_norm(&[3.0, 4.0]), 5.0);
        // 36 elements: one 32-float block plus a 4-element tail;
        // 36 * 2² = 144, sqrt = 12.
        assert_eq!(l2_norm(&[2.0; 36]), 12.0);
    }

    /// Fused dot + self-dot should match the unfused calls.
    #[test]
    fn dot_and_self_dot_matches_separate_calls() {
        for &n in &[0usize, 1, 7, 8, 9, 15, 16, 17, 31, 32, 33, 1024, 4096, 4097] {
            let a: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.17).sin()).collect();
            let b: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.31).cos()).collect();
            let (dot_fused, sdot_fused) = dot_and_self_dot(&a, &b);
            let dot_ref = dot_simd(&a, &b);
            let sdot_ref = dot_simd(&b, &b);
            assert!(
                (dot_fused - dot_ref).abs() < tolerance(dot_ref),
                "dot mismatch n={n} fused={dot_fused} ref={dot_ref}"
            );
            assert!(
                (sdot_fused - sdot_ref).abs() < tolerance(sdot_ref),
                "sdot mismatch n={n} fused={sdot_fused} ref={sdot_ref}"
            );
        }
    }
}