Skip to main content

difflib_fast/
simjoin.rs

1//! `simjoin` — exact all-pairs **weighted-cosine similarity join**: given a corpus of sparse
2//! non-negative vectors, find every pair `(i, j)` whose cosine similarity is `≥ t`.
3//!
4//! This is the `AllPairs` / `L2AP` family (Bayardo et al. WWW'07; Anastasiu & Karypis ICDE'14): an
5//! inverted index with **prefix filtering** so that the vast majority of vector pairs are never
6//! compared. It is the principled, exact replacement for shingle-candidate + verify-all
7//! near-duplicate detection (e.g. Type-3 code clones = functions × IDF-weighted lines).
8//!
9//! ## Correctness gate
10//! [`cosine_join`] (the indexed algorithm) must return the **exact same** pair set, with
11//! bit-identical similarities, as [`cosine_join_bruteforce`] (the naive `O(n²)` oracle). Both score
12//! a pair with the same [`cos_full`] sorted-merge dot, so the values match to the bit and a pair is
13//! never dropped or gained at the threshold. This equality is asserted on fuzzed corpora — the same
14//! "two implementations, one answer" discipline the RO path uses.
15//!
16//! ## Method (this reference)
17//! Vectors are L2-normalised (so `cos = dot`) and their dimensions relabelled to a global rank by
18//! **increasing max weight** (common low-weight dims first, rare high-weight dims last — so only the
19//! rare tail is indexed, keeping postings short). For a probe vector we look up candidates through
20//! the index, then verify with the full dot. When *indexing* a vector we skip the leading prefix
21//! whose max possible contribution to any dot stays `< t`:
//! `Σ_{k<b} w_k · maxw[dim_k] < t` ⇒ a pair that matches only inside that prefix can't reach `t`
//! (weights ≥ 0), so every pair that does reach `t` is guaranteed to share an indexed dim. The
//! skipped prefix holds the common low-weight dims (huge postings); only the rarer tail is indexed —
//! that is the whole speed-up. On top of this exact base, [`cosine_join`] also prunes every
//! accumulated candidate with an L2AP-style Cauchy–Schwarz (L2-norm) bound before scoring it
//! exactly; the indexing prefix itself still uses only the maxweight bound above.
27
28// Index/rank ids are bounded by the corpus dimension count (≤ u32 by construction); ranks are
29// assigned densely from a sorted dim list. The `as` casts below are intentional and in-range.
30// Dense numerical code: `i`/`d`/`w`/`y`/`a`/`s`/`t` mirror the cosine/prefix-bound formulas and read
31// clearer than verbose names here — allow the single-char-names pedantic lint module-wide.
32#![allow(
33    clippy::cast_possible_truncation,
34    clippy::cast_precision_loss,
35    clippy::cast_sign_loss,
36    clippy::many_single_char_names
37)]
38
39use std::cmp::Ordering;
40use std::collections::{HashMap, HashSet};
41
42use rayon::prelude::*;
43
44use crate::Concurrency;
45
/// A corpus of L2-normalised sparse non-negative vectors in CSR form, with dimensions relabelled to
/// a global rank by **increasing** max weight (common low-weight dims get the low ranks, rare
/// high-weight dims the high ones — see `Corpus::from_rows`). Built once; joinable at any threshold.
pub struct Corpus {
    /// Number of vectors (rows), including rows that ended up empty.
    n: usize,
    /// Number of distinct dimensions occurring in some non-empty row (equals `maxw.len()`).
    ndims: usize,
    /// `n + 1` row offsets into `dims`/`wts`.
    indptr: Vec<usize>,
    /// Relabelled (rank) dimension ids, ascending within each row.
    dims: Vec<u32>,
    /// L2-normalised weights, aligned with `dims`.
    wts: Vec<f64>,
    /// Per relabelled dim: the max weight it takes across the corpus (the prefix-filter bound).
    /// Indexed by rank, so it is non-decreasing.
    maxw: Vec<f64>,
}
60
61impl Corpus {
62    /// Number of vectors.
63    #[must_use]
64    pub fn len(&self) -> usize {
65        self.n
66    }
67
68    /// True if the corpus has no vectors.
69    #[must_use]
70    pub fn is_empty(&self) -> bool {
71        self.n == 0
72    }
73
74    /// `(dims, weights)` of vector `i` — dims ascending by global rank, weights L2-normalised.
75    #[must_use]
76    fn row(&self, i: usize) -> (&[u32], &[f64]) {
77        let (s, e) = (self.indptr[i], self.indptr[i + 1]);
78        (&self.dims[s..e], &self.wts[s..e])
79    }
80
81    /// CSR view for GPU offload: `(indptr, dims, wts_f32)`, with `indptr` cast to `u32` and the
82    /// L2-normalised weights cast to `f32` (Apple GPUs have no `f64`). For the
83    /// [`crate::simjoin_gpu`] throughput experiment only — the f32 cast means GPU dots are *not*
84    /// bit-identical to the CPU `f64` path.
85    #[cfg(all(target_os = "macos", feature = "gpu"))]
86    #[must_use]
87    pub fn csr_f32(&self) -> (Vec<u32>, Vec<u32>, Vec<f32>) {
88        let indptr = self.indptr.iter().map(|&x| x as u32).collect();
89        let wts = self.wts.iter().map(|&w| w as f32).collect();
90        (indptr, self.dims.clone(), wts)
91    }
92
93    /// Build a corpus from **token documents** — each document a list of string tokens — as TF-IDF
94    /// sparse vectors: dim = a distinct token, weight = `(token count in doc) × ln(n / df_token)`
95    /// (`df` = number of documents containing the token). This is the principled input for a Type-3
96    /// code-clone join (documents = functions, tokens = canonicalised lines) and the shape most
97    /// callers actually have. A token appearing in every document gets `idf = 0` (contributes
98    /// nothing), as expected.
99    #[must_use]
100    pub fn from_token_docs<S: AsRef<str>>(docs: &[Vec<S>]) -> Corpus {
101        let n = docs.len();
102        let mut dim: HashMap<&str, u32> = HashMap::new();
103        let mut df: Vec<u32> = Vec::new();
104        // Assign a dim id to each distinct token and count document frequency (once per doc/token).
105        let mut doc_ids: Vec<Vec<u32>> = Vec::with_capacity(n);
106        for doc in docs {
107            let mut ids = Vec::with_capacity(doc.len());
108            let mut seen: HashSet<u32> = HashSet::new();
109            for tok in doc {
110                let id = *dim.entry(tok.as_ref()).or_insert_with(|| {
111                    let i = df.len() as u32;
112                    df.push(0);
113                    i
114                });
115                ids.push(id);
116                if seen.insert(id) {
117                    df[id as usize] += 1;
118                }
119            }
120            doc_ids.push(ids);
121        }
122        let idf: Vec<f64> = df.iter().map(|&d| (n as f64 / f64::from(d.max(1))).ln()).collect();
123        // Emit (dim, idf) once per token occurrence; `from_rows` sums duplicates → tf·idf per dim.
124        let rows: Vec<Vec<(u32, f64)>> = doc_ids
125            .iter()
126            .map(|ids| ids.iter().map(|&id| (id, idf[id as usize])).collect())
127            .collect();
128        Corpus::from_rows(&rows)
129    }
130
131    /// Build a corpus from raw `(dim, weight)` rows. Duplicate dims within a row are summed; each row
132    /// is L2-normalised; dims are relabelled to a global rank by decreasing max weight. Weights are
133    /// expected non-negative (the prefix-filter bound requires it — IDF weights satisfy this).
134    #[must_use]
135    pub fn from_rows(rows: &[Vec<(u32, f64)>]) -> Corpus {
136        let n = rows.len();
137        // 1. Merge duplicate dims + L2-normalise each row (kept as (orig_dim, weight)).
138        let normed: Vec<Vec<(u32, f64)>> = rows
139            .iter()
140            .map(|r| {
141                let mut m: HashMap<u32, f64> = HashMap::new();
142                for &(d, w) in r {
143                    *m.entry(d).or_insert(0.0) += w;
144                }
145                let norm = m.values().map(|w| w * w).sum::<f64>().sqrt();
146                if norm > 0.0 {
147                    m.into_iter().map(|(d, w)| (d, w / norm)).collect()
148                } else {
149                    Vec::new()
150                }
151            })
152            .collect();
153        // 2. Max normalised weight per original dim.
154        let mut maxw_orig: HashMap<u32, f64> = HashMap::new();
155        for v in &normed {
156            for &(d, w) in v {
157                let e = maxw_orig.entry(d).or_insert(0.0);
158                if w > *e {
159                    *e = w;
160                }
161            }
162        }
163        // 3. Rank dims by (max weight ASC, dim asc) → dense global order. Ascending so the common,
164        //    low-weight dims land at the FRONT: they fill the un-indexed prefix (their tiny
165        //    `w·maxw` keeps the cumulative bound under `t` for many of them), and only the rare,
166        //    high-weight tail is indexed — short postings. (Reversing this indexes the common dims
167        //    and their huge postings, which is correct but orders of magnitude slower.)
168        let mut dims_sorted: Vec<(u32, f64)> = maxw_orig.into_iter().collect();
169        dims_sorted.sort_by(|a, b| {
170            a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal).then(a.0.cmp(&b.0))
171        });
172        let rank: HashMap<u32, u32> =
173            dims_sorted.iter().enumerate().map(|(i, &(d, _))| (d, i as u32)).collect();
174        let ndims = dims_sorted.len();
175        let maxw: Vec<f64> = dims_sorted.iter().map(|&(_, w)| w).collect();
176        // 4. CSR with relabelled dims, ascending within each row.
177        let mut indptr = Vec::with_capacity(n + 1);
178        indptr.push(0);
179        let mut dims = Vec::new();
180        let mut wts = Vec::new();
181        for v in &normed {
182            let mut rv: Vec<(u32, f64)> = v.iter().map(|&(d, w)| (rank[&d], w)).collect();
183            rv.sort_unstable_by_key(|&(d, _)| d);
184            for (d, w) in rv {
185                dims.push(d);
186                wts.push(w);
187            }
188            indptr.push(dims.len());
189        }
190        Corpus { n, ndims, indptr, dims, wts, maxw }
191    }
192}
193
/// Sorted-merge dot product of two rows (dims ascending). For L2-normalised non-negative vectors
/// this is exactly their cosine similarity. The single scoring routine shared by the indexed join
/// and the brute-force oracle, so both agree to the bit.
///
/// Precondition: each row's dim and weight slices have equal length. The unchecked weight loads
/// below rely on it, but it is only `debug_assert`ed. Rows taken from `Corpus::row` satisfy it,
/// since both slices are cut from the same `indptr` range; the visible callers all pass such rows.
#[must_use]
#[cfg_attr(feature = "profiling", inline(never))]
fn cos_full((da, wa): (&[u32], &[f64]), (db, wb): (&[u32], &[f64])) -> f64 {
    // Rows are equal-length dim/weight slices (built by `Corpus::from_rows`).
    debug_assert_eq!(da.len(), wa.len());
    debug_assert_eq!(db.len(), wb.len());
    let (la, lb) = (da.len(), db.len());
    let (mut i, mut j) = (0usize, 0usize);
    let mut s = 0.0f64;
    // Branchless sorted-merge: the 3-way `cmp` branches mispredict on random dim order, and the
    // weight loads carry bounds checks the optimiser can't elide. Here we always load both weights
    // (`i<la=wa.len()`, `j<lb=wb.len()`) and mask the product by dim-equality — `s += 0.0` for
    // unequal dims adds the *same* terms in the *same* increasing-dim order, so the result is
    // bit-identical to the branchy merge while shedding every data-dependent branch.
    // (That argument assumes finite, non-negative weights: a masked term is then exactly `+0.0`,
    // and adding `+0.0` to `s ≥ +0.0` leaves it unchanged. An infinite weight would give `0·∞ = NaN`.)
    while i < la && j < lb {
        // SAFETY: `i < la == wa.len()` and `j < lb == wb.len()`.
        let (ai, bj) = unsafe { (*da.get_unchecked(i), *db.get_unchecked(j)) };
        // SAFETY: as above — needs the equal-length precondition documented on this fn.
        let (wai, wbj) = unsafe { (*wa.get_unchecked(i), *wb.get_unchecked(j)) };
        let eq = f64::from(u32::from(ai == bj));
        s += eq * wai * wbj;
        // Advance whichever side holds the smaller dim; on a match both advance.
        i += usize::from(ai <= bj);
        j += usize::from(ai >= bj);
    }
    s
}
222
/// Per-vector prune data, read in the hot verify loop when a vector appears as a candidate. Packed
/// into one array (not two parallel `Vec`s) so a candidate's `pnorm` and `split` come from a single
/// scattered cache line instead of two — verify is memory-latency-bound on these random accesses.
/// Written once per vector by `index_suffix`.
#[derive(Clone, Copy)]
struct Bound {
    /// ‖un-indexed prefix of y‖₂ — Cauchy–Schwarz cap on the dot mass the accumulator misses (the
    /// prefix dims of `y` were never indexed, so never accumulated).
    pnorm: f64,
    /// Rank of `y`'s first *indexed* dim (`u32::MAX` if `y` indexed nothing). Every prefix dim of
    /// `y` has rank `<` this — lets the probe restrict its norm to that rank range.
    split: u32,
}
235
/// Per-vector data cached as each vector is indexed; read by the prune bound when the vector later
/// appears as a candidate. (One `Cached` for the whole join.)
struct Cached {
    /// `bound[y]` = prune data of vector `y` (indexed by vector id, length `n`).
    bound: Vec<Bound>,
}
241
/// Reused scratch buffers (allocated once per worker for the whole join, not per probe).
struct Scratch {
    /// `acc[y]` = partial dot of the probe with `y` over shared *indexed* dims; `-1.0` = untouched
    /// sentinel (a real partial dot of non-negative weights is always `≥ 0`). Length `n`; every
    /// slot is back at `-1.0` once a probe's verify pass has run.
    acc: Vec<f64>,
    /// Candidate ids the current probe touched (the keys to reset in `acc`).
    touched: Vec<u32>,
    /// Probe prefix L2 norms for this probe: `xpn[k] = ‖wi[..k]‖₂`, length `nnz+1`.
    xpn: Vec<f64>,
}
252
253/// Exact all-pairs cosine join via inverted index + **L2AP** prefix filtering and accumulation-time
254/// pruning. Returns `(j, i, cos)` with `j < i` for every pair with `cos ≥ t`. Bit-identical pair set
255/// to [`cosine_join_bruteforce`].
256///
257/// For each probe we accumulate a partial dot over shared *indexed* dims ([`accumulate`]), then for
258/// each touched candidate compute a Cauchy–Schwarz upper bound on the true cosine and skip the exact
259/// [`cos_full`] when it cannot reach `t` ([`verify_pruned`]). The bound is a filter only — survivors
260/// are scored exactly, so the output is byte-for-byte the brute-force result. On skewed data the
261/// bound prunes the ~99.9 % of candidates that collide on a single rare dim, so `cos_full` (the
262/// former 90 % hotspot) runs only on genuine near-matches.
263///
264/// The full inverted index is built once (postings ascending by id), then every vector is probed in
265/// **parallel**: probe `i` walks each posting only while `y < i` (postings are id-sorted), so it sees
266/// exactly the earlier vectors — each pair `(j, i)` with `j < i` is found once, from the larger id.
267/// This is the same candidate set the sequential index-as-you-go build produces, so the result is
268/// unchanged; the returned `Vec` is in arbitrary order (sort if a canonical order is needed).
269#[must_use]
270pub fn cosine_join(c: &Corpus, t: f64) -> Vec<(usize, usize, f64)> {
271    let n = c.n;
272    // Postings carry the indexed weight `(y, w_y[d])` so the scan can accumulate a partial dot.
273    let mut index: Vec<Vec<(u32, f64)>> = vec![Vec::new(); c.ndims];
274    let mut cached = Cached { bound: vec![Bound { pnorm: 0.0, split: u32::MAX }; n] };
275    for i in 0..n {
276        let (di, wi) = c.row(i);
277        index_suffix(c, i, (di, wi), t, &mut index, &mut cached);
278    }
279    // Probe vectors in parallel; each worker keeps one reusable `Scratch` (an `n`-wide accumulator).
280    // `with_min_len` batches many probes per rayon task so the per-probe work (tiny) isn't dwarfed by
281    // task-splitting / scheduling overhead (`swtch_pri` in the profile).
282    (0..n)
283        .into_par_iter()
284        .with_min_len(256)
285        .map_init(
286            || Scratch { acc: vec![-1.0; n], touched: Vec::new(), xpn: Vec::new() },
287            |scratch, i| {
288                let (di, wi) = c.row(i);
289                accumulate(&index, (di, wi), i as u32, scratch);
290                let mut out = Vec::new();
291                verify_pruned(c, i, t, scratch, &cached, &mut out);
292                out
293            },
294        )
295        .flatten()
296        .collect()
297}
298
299/// Run the cosine join under a chosen [`Concurrency`] backend. Returns `(j, i, cos)` pairs with
300/// `j < i` and `cos ≥ t`, scores as `f64` (the `Gpu` mode's f32 cosines are widened losslessly).
301///
302/// - [`Concurrency::Cpu`] — [`cosine_join`]: exact `f64`, all-CPU, every platform.
303/// - [`Concurrency::GpuPlusCpu`] — exact `f64` hybrid: CPU generates survivor pairs, the GPU f32
304///   cosine *filters* the clear rejects, the CPU recomputes the exact `f64` score on what passes.
305///   **Byte-identical to `Cpu`**; both engines fully used. ~1.7–2× on bandwidth-bound real data.
306/// - [`Concurrency::Gpu`] — GPU-dominant `f32`: CPU generates survivor pairs, the GPU scores them and
307///   the result is emitted directly (no f64 re-verify). Fastest (~2×); differs from the exact answer
308///   only on pairs whose true cosine is within ~`1e-6` of `t` (measured: ≤1 pair in millions).
309///
310/// When the `gpu` feature is off, the target isn't macOS, or no Metal device can be acquired, the GPU
311/// modes transparently fall back to [`cosine_join`] (same as `Rationer`). This convenience entry
312/// **compiles + uploads the GPU corpus on every call** — fine for a one-shot join, but for repeated
313/// joins on one corpus build a [`CosineJoiner`] once and call [`CosineJoiner::join`], which holds the
314/// device + kernel + uploaded CSR across calls (and avoids the driver instability of compiling a
315/// Metal library hundreds of times in a tight loop).
316#[must_use]
317pub fn cosine_join_with(c: &Corpus, t: f64, mode: Concurrency) -> Vec<(usize, usize, f64)> {
318    #[cfg(all(feature = "gpu", target_os = "macos"))]
319    {
320        if matches!(mode, Concurrency::Cpu) {
321            return cosine_join(c, t);
322        }
323        let (indptr, dims, wts) = c.csr_f32();
324        let Some(gpu) = crate::simjoin_gpu::BatchCosineGpu::new(&indptr, &dims, &wts) else {
325            return cosine_join(c, t); // no Metal device → CPU fallback
326        };
327        match mode {
328            Concurrency::GpuPlusCpu => cosine_join_gpu(c, t, &gpu),
329            // `Gpu`: emit the GPU f32 cosines directly (widened to f64), no exact re-verify.
330            Concurrency::Gpu => cosine_join_gpu_f32(c, t, &gpu)
331                .into_iter()
332                .map(|(a, b, s)| (a, b, f64::from(s)))
333                .collect(),
334            Concurrency::Cpu => unreachable!("handled above"),
335        }
336    }
337    #[cfg(not(all(feature = "gpu", target_os = "macos")))]
338    {
339        let _ = mode; // GPU modes degrade to the CPU join when the feature is off / not macOS.
340        cosine_join(c, t)
341    }
342}
343
/// A reusable cosine-join handle that owns the corpus and — under `feature = "gpu"` on macOS — the
/// Metal device, compiled `batch_cosine` kernel, and the corpus CSR uploaded to unified memory, all
/// acquired **once** at construction. Repeated [`join`](CosineJoiner::join)s at different thresholds
/// then skip the per-call kernel compile + CSR upload that [`cosine_join_with`] pays (only the
/// `t`-specific inverted index is rebuilt each call, on the CPU). Always constructible; degrades to
/// the pure-CPU join when the `gpu` feature is off or no Metal device is available — mirroring
/// `Rationer`. This is the right entry point for sweeping thresholds or joining repeatedly.
pub struct CosineJoiner {
    /// The corpus every join runs against (also exposed via `CosineJoiner::corpus`).
    corpus: Corpus,
    /// Owned GPU resources (device + kernel + CSR in UMA); `None` when no Metal device was acquired.
    #[cfg(all(feature = "gpu", target_os = "macos"))]
    gpu: Option<crate::simjoin_gpu::BatchCosineGpu>,
}
357
358impl CosineJoiner {
359    /// Build a joiner over `corpus`, acquiring the GPU device + uploading the corpus CSR once if the
360    /// `gpu` feature is on and a Metal device is present.
361    #[must_use]
362    pub fn new(corpus: Corpus) -> Self {
363        #[cfg(all(feature = "gpu", target_os = "macos"))]
364        {
365            let (indptr, dims, wts) = corpus.csr_f32();
366            let gpu = crate::simjoin_gpu::BatchCosineGpu::new(&indptr, &dims, &wts);
367            Self { corpus, gpu }
368        }
369        #[cfg(not(all(feature = "gpu", target_os = "macos")))]
370        {
371            Self { corpus }
372        }
373    }
374
375    /// The owned corpus (e.g. for `len()` or to run other queries against it).
376    #[must_use]
377    pub fn corpus(&self) -> &Corpus {
378        &self.corpus
379    }
380
381    /// Whether a Metal GPU backend was acquired. Always `false` without `feature = "gpu"` on macOS;
382    /// when `false`, every [`join`](CosineJoiner::join) runs on the CPU regardless of `mode`.
383    #[must_use]
384    pub fn has_gpu(&self) -> bool {
385        #[cfg(all(feature = "gpu", target_os = "macos"))]
386        {
387            self.gpu.is_some()
388        }
389        #[cfg(not(all(feature = "gpu", target_os = "macos")))]
390        {
391            false
392        }
393    }
394
395    /// Run the join at threshold `t` under `mode`, reusing the handle's GPU resources. Returns the
396    /// same results as [`cosine_join_with`] (Cpu/GpuPlusCpu exact, Gpu f32→f64); falls back to the
397    /// CPU join when the GPU is unavailable.
398    #[must_use]
399    pub fn join(&self, t: f64, mode: Concurrency) -> Vec<(usize, usize, f64)> {
400        #[cfg(all(feature = "gpu", target_os = "macos"))]
401        {
402            match (mode, self.gpu.as_ref()) {
403                (Concurrency::GpuPlusCpu, Some(g)) => cosine_join_gpu(&self.corpus, t, g),
404                (Concurrency::Gpu, Some(g)) => cosine_join_gpu_f32(&self.corpus, t, g)
405                    .into_iter()
406                    .map(|(a, b, s)| (a, b, f64::from(s)))
407                    .collect(),
408                _ => cosine_join(&self.corpus, t), // Cpu mode, or no Metal device
409            }
410        }
411        #[cfg(not(all(feature = "gpu", target_os = "macos")))]
412        {
413            let _ = mode;
414            cosine_join(&self.corpus, t)
415        }
416    }
417}
418
/// FP slack for the prune bound: the Cauchy–Schwarz upper bound holds in exact arithmetic, but the
/// accumulated dot and the `sqrt` norms each carry rounding error. We only *skip* `cos_full` when
/// the bound is below `t` by more than this slack, so a true pair (exact cosine `≥ t`) is never
/// pruned. Not skipping is always correctness-safe (just a wasted `cos_full`), so the slack trades a
/// negligible number of extra verifies for safety and never changes the emitted pair set.
/// `1e-9 ≫` the `~1e-15` accumulated error over ~15 terms.
/// Applied as `need = t - PRUNE_SLACK` in both `verify_pruned` and `collect_survivors`.
const PRUNE_SLACK: f64 = 1e-9;
426
/// Probe row length at/above which the cheap monotone pre-bound in [`verify_pruned`] is worth it.
/// The pre-bound skips the per-candidate `partition_point` for candidates that can't reach `t`, but
/// `partition_point` over a *short* probe row is already cheap, so on sparse corpora (mean nnz ≈ 11,
/// e.g. `PyPI` type3) the pre-bound's own `fmul`/`fadd`/compare is pure overhead that rarely prunes
/// (match-dense ⇒ high survivor rate). It only pays on **dense** rows (e.g. find-dup-defs
/// patternology, mean nnz ≈ 61) where `partition_point` is costly and the survivor rate is tiny.
/// Gating on `di.len()` (a per-probe, loop-invariant test) keeps the sparse regime regression-free.
/// The gate changes only speed, never the survivor set: the pre-bound is implied by the exact bound.
const PREBOUND_MIN_DIMS: usize = 24;
435
/// Phase 1 — accumulate: for each indexed dim of the probe, add `w_probe·w_y` into `acc[y]` for
/// every **earlier** `y` (`y < cutoff`, the probe's own id) indexing that dim. Postings are id-sorted,
/// so we `break` at the first `y ≥ cutoff`. Leaves `acc` `-1.0` everywhere except the touched ids
/// (listed in `touched`, reset in [`verify_pruned`]). One scattered `acc[]` FMA per posting.
///
/// Requirements: on entry every `acc` slot holds the `-1.0` sentinel, and `cutoff < acc.len()`.
/// The second bounds the number of distinct candidates by `acc.len() - 1`, which keeps the
/// unconditional `touched` store below in bounds. The visible callers pass the probe id `i < n`.
#[cfg_attr(feature = "profiling", inline(never))]
fn accumulate(index: &[Vec<(u32, f64)>], (di, wi): (&[u32], &[f64]), cutoff: u32, s: &mut Scratch) {
    // Split the borrows so `acc` and `touched` are independent: otherwise the optimiser, seeing
    // `s.touched.push(...)` go through the same `&mut Scratch`, can't prove the push leaves
    // `s.acc`'s base pointer untouched and reloads `acc.ptr` from the struct on EVERY posting entry
    // (one extra load across the ~5M-entry accumulation on dense corpora). With `acc` a separate
    // `&mut [f64]` local the base stays in a register; a `touched` realloc can't alias it.
    let Scratch { acc, touched, .. } = s;
    let acc = acc.as_mut_slice();
    touched.clear();
    // At most one distinct candidate per acc slot, so `n == acc.len()` slots is enough headroom to
    // append every first-touch without bounds checks or a realloc inside the loop.
    touched.reserve(acc.len());
    let tptr = touched.as_mut_ptr();
    let mut tlen = 0usize;
    for (&d, &w) in di.iter().zip(wi) {
        for &(y, wy) in &index[d as usize] {
            if y >= cutoff {
                break;
            }
            let yu = y as usize;
            // SAFETY: `y` is a vector id pushed by `index_suffix`, so `yu < n == acc.len()`.
            let a = unsafe { acc.get_unchecked_mut(yu) };
            // Branchless first-touch: the `*a < 0.0` test fed a data-dependent branch that mispredicts
            // (first-touch vs repeat interleave unpredictably across the probe's dims). Instead select
            // the base (`0.0` on first touch, current partial dot otherwise) with a conditional move,
            // and append `y` to `touched` by an UNCONDITIONAL store that's only *committed* when the
            // length is bumped — `tlen += first`. Same accumulator values and same touched order as
            // the branchy form, so the result is bit-identical; just no mispredicting branch.
            let first = *a < 0.0;
            let base = if first { 0.0 } else { *a };
            *a = base + w * wy;
            // SAFETY: `tlen ≤ distinct candidates so far < acc.len() ≤ touched.capacity()`.
            // (Distinct candidates are ids `< cutoff < acc.len()`, so `tlen ≤ cutoff`.)
            unsafe { *tptr.add(tlen) = y };
            tlen += usize::from(first);
        }
    }
    // SAFETY: `tlen` first-touch stores were written into the reserved region, in order.
    unsafe { touched.set_len(tlen) };
}
480
/// Phase 2 — prune + verify. For each touched candidate `y`, reset its accumulator and test the
/// **L2AP `l2` bound**. The dot mass missing from `acc[y]` (the dims in `prefix(y)`) is at most
/// `‖x_{rank<split[y]}‖ · ‖prefix(y)‖` by Cauchy–Schwarz, where `x_{rank<split[y]}` is the probe
/// restricted to the rank range `prefix(y)` lives in. Since the probe's mass sits in its rare
/// (high-rank) dims, that restricted norm is tiny — a far tighter cap than the whole-probe `‖x‖=1`.
/// Only if `acc[y] + that bound ≥ t` (minus FP slack) do we score exactly with [`cos_full`].
/// This is a filter only: the emitted value is the exact dot, so the pair set is bit-identical
/// to brute force.
///
/// Pushes `(y, i, cos)` onto `out` for every verified pair. Consumes the candidate list left in
/// `s.touched` and restores each touched `acc` slot to the `-1.0` sentinel, so the scratch is
/// ready for the next probe. The bound arithmetic here must stay identical to
/// `collect_survivors`'s.
#[cfg_attr(feature = "profiling", inline(never))]
fn verify_pruned(
    c: &Corpus,
    i: usize,
    t: f64,
    s: &mut Scratch,
    cached: &Cached,
    out: &mut Vec<(usize, usize, f64)>,
) {
    let (di, wi) = c.row(i);
    // Probe prefix L2 norms: xpn[k] = ‖wi[..k]‖₂ (di ascending by rank, so xpn[k] = norm over the
    // probe's k lowest-rank dims). One sqrt per probe dim, reused across all its candidates.
    s.xpn.clear();
    s.xpn.push(0.0);
    let mut sq = 0.0f64;
    for &w in wi {
        sq += w * w;
        s.xpn.push(sq.sqrt());
    }
    // `‖probe‖` = the full prefix norm. Since `xpn` is monotonic, `xpn[kstar] ≤ xnorm` for every
    // candidate, so `a + xnorm·pnorm ≥ a + xpn[kstar]·pnorm` (the exact bound). A candidate failing
    // the cheap `a + xnorm·pnorm < need` test therefore also fails the exact bound — prune it without
    // the per-candidate `partition_point`. The survivor set is bit-identical; only the binary search
    // is skipped for the ~99% of touched candidates that can't reach `t` on dense corpora.
    let xnorm = sq.sqrt();
    let prebound = di.len() >= PREBOUND_MIN_DIMS;
    let need = t - PRUNE_SLACK;
    let Scratch { acc, touched, xpn } = s;
    // (Software-prefetching the candidate row a few ahead was tried + reverted: no measurable change
    // — with all cores gathering at once the join is memory-*bandwidth*-bound, not per-access
    // latency-bound, so prefetch can't add throughput.)
    for &y in touched.iter() {
        let yu = y as usize;
        // SAFETY: `yu < n` (same provenance as in `accumulate`).
        // Read the partial dot and restore the untouched sentinel in one step.
        let a = unsafe { std::mem::replace(acc.get_unchecked_mut(yu), -1.0) };
        // SAFETY: `yu < n`. One scattered load fetches both prune fields.
        let bd = unsafe { *cached.bound.get_unchecked(yu) };
        // Cheap monotone pre-bound (dense rows only) — skip the binary search when even `xnorm`
        // can't clear `need`. `prebound` is loop-invariant, so sparse rows pay nothing.
        if prebound && a + xnorm * bd.pnorm < need {
            continue;
        }
        // Number of probe dims with rank < split[y] → index into xpn (di sorted ascending).
        let kstar = di.partition_point(|&d| d < bd.split);
        // SAFETY: kstar ≤ di.len() == wi.len() = xpn.len()-1.
        // (An added maxweight cap `min(…, Σ wx·maxw)` was tried + reverted: never tighter than the
        // L2 cap on either synthetic or real data — it sums over all probe-prefix dims, not just the
        // shared ones — so it pruned nothing and cost ~16% on the real PyPI corpus.)
        let bound = a + unsafe { xpn.get_unchecked(kstar) } * bd.pnorm;
        if bound >= need {
            let cos = cos_full((di, wi), c.row(yu));
            // The exact threshold test (no slack) decides the output.
            if cos >= t {
                out.push((yu, i, cos));
            }
        }
    }
}
545
546/// Phase 3 — index this vector's suffix: skip the leading prefix whose max possible contribution to
547/// any dot stays `< t` (`Σ w_k·maxw[dim_k] < t`), index only the rarer tail (short postings = the
548/// whole speed-up), and cache `pnorm[i] = ‖prefix‖₂` and `split[i]` = first indexed rank for the
549/// [`verify_pruned`] bound.
550#[cfg_attr(feature = "profiling", inline(never))]
551fn index_suffix(
552    c: &Corpus,
553    i: usize,
554    (di, wi): (&[u32], &[f64]),
555    t: f64,
556    index: &mut [Vec<(u32, f64)>],
557    cached: &mut Cached,
558) {
559    // Largest safe prefix under the maxweight bound: `Σ_{k<b} w_k·maxw[dim_k] < t` ⇒ the prefix
560    // can't contribute `t` to any dot (weights ≥ 0). (A norm-based extension `‖x_{<b}‖ < t` was
561    // tried and reverted: it gave zero candidate reduction in the realistic `t<1` regime — the
562    // maxweight bound already indexes less — and it is FP-fragile at `t=1` where `‖x‖` rounds below
563    // `1` and indexes nothing, dropping exact-duplicate pairs.)
564    let mut rs = 0.0f64;
565    let mut b = 0usize;
566    for k in 0..di.len() {
567        let bound = wi[k] * c.maxw[di[k] as usize];
568        if rs + bound >= t {
569            break;
570        }
571        rs += bound;
572        b = k + 1;
573    }
574    let mut p = 0.0f64;
575    for &w in &wi[..b] {
576        p += w * w;
577    }
578    cached.bound[i] = Bound {
579        pnorm: p.sqrt(),
580        split: if b < di.len() { di[b] } else { u32::MAX },
581    };
582    for k in b..di.len() {
583        index[di[k] as usize].push((i as u32, wi[k]));
584    }
585}
586
/// FP margin for the GPU f32 cosine *filter* in [`cosine_join_gpu`]: a survivor is dropped only when
/// its GPU f32 cosine is below `t` by more than this. The f32 dot's error is `~1e-6` relative, so a
/// `1e-4` absolute margin never drops a true pair; the CPU then recomputes the exact `f64` score on
/// whatever passes, so the emitted pair set + scores stay bit-identical to [`cosine_join`].
/// Only compiled on macOS with the `gpu` feature.
#[cfg(all(target_os = "macos", feature = "gpu"))]
const GPU_FILTER_MARGIN: f64 = 1e-4;
593
594/// Like the verify half of [`verify_pruned`], but instead of scoring, **collects** each surviving
595/// `(candidate, probe)` pair (`candidate < probe`) for batch scoring elsewhere. The bound here MUST
596/// stay identical to `verify_pruned`'s so the survivor set matches exactly.
597#[cfg(all(target_os = "macos", feature = "gpu"))]
598fn collect_survivors(c: &Corpus, i: usize, t: f64, s: &mut Scratch, cached: &Cached, out: &mut Vec<(u32, u32)>) {
599    let (di, wi) = c.row(i);
600    s.xpn.clear();
601    s.xpn.push(0.0);
602    let mut sq = 0.0f64;
603    for &w in wi {
604        sq += w * w;
605        s.xpn.push(sq.sqrt());
606    }
607    let xnorm = sq.sqrt();
608    let prebound = di.len() >= PREBOUND_MIN_DIMS;
609    let need = t - PRUNE_SLACK;
610    let Scratch { acc, touched, xpn } = s;
611    for &y in touched.iter() {
612        let yu = y as usize;
613        // SAFETY: `yu < n` (same provenance as in `accumulate`).
614        let a = unsafe { std::mem::replace(acc.get_unchecked_mut(yu), -1.0) };
615        let bd = unsafe { *cached.bound.get_unchecked(yu) };
616        // Cheap monotone pre-bound (see `verify_pruned`) — same survivor set, skips the binary search.
617        if prebound && a + xnorm * bd.pnorm < need {
618            continue;
619        }
620        let kstar = di.partition_point(|&d| d < bd.split);
621        let bound = a + unsafe { xpn.get_unchecked(kstar) } * bd.pnorm;
622        if bound >= need {
623            out.push((y, i as u32)); // (candidate, probe), candidate < probe
624        }
625    }
626}
627
628/// **CPU+GPU hybrid join** (feature `gpu`, macOS). Returns the **exact same** pair set + scores as
629/// [`cosine_join`] (and thus the brute-force oracle) — only faster when the verify is bandwidth-bound.
630///
631/// Pipeline: CPU builds the index and, in parallel, accumulates + bounds every probe to a list of
632/// surviving `(candidate, probe)` pairs (no `cos_full`). The GPU then computes an f32 cosine for the
633/// whole batch (its memory-level parallelism clears the random-gather dots ~3× faster than the CPU),
634/// and the CPU recomputes the exact `f64` `cos_full` **only** on the pairs whose GPU score clears
635/// `t − margin` — typically a few percent of survivors. Because the GPU is a *conservative filter*
636/// (margin ≫ f32 error, so no true pair is ever dropped) and every emitted score is the exact CPU
637/// `f64` value, the output is byte-for-byte identical to [`cosine_join`].
638#[cfg(all(target_os = "macos", feature = "gpu"))]
639#[must_use]
640pub fn cosine_join_gpu(
641    c: &Corpus,
642    t: f64,
643    gpu: &crate::simjoin_gpu::BatchCosineGpu,
644) -> Vec<(usize, usize, f64)> {
645    let (pa, pb) = survivor_pairs(c, t);
646    if pa.is_empty() {
647        return Vec::new();
648    }
649    // GPU phase: f32 cosine over the whole survivor batch (conservative filter).
650    let gcos = gpu.cosine_batch(&pa, &pb);
651    let need = t - GPU_FILTER_MARGIN;
652    // CPU phase: exact f64 re-verify only on pairs the GPU filter passes.
653    (0..pa.len())
654        .into_par_iter()
655        .with_min_len(1024)
656        .filter_map(|k| {
657            if f64::from(gcos[k]) < need {
658                return None;
659            }
660            let (a, b) = (pa[k] as usize, pb[k] as usize);
661            let cos = cos_full(c.row(a), c.row(b));
662            (cos >= t).then_some((a, b, cos))
663        })
664        .collect()
665}
666
667/// **Pure-f32** CPU+GPU join (feature `gpu`, macOS): same survivor generation as [`cosine_join_gpu`]
668/// but emits the GPU's **f32** cosine directly, with **no exact f64 re-verify**. Trades byte-parity
669/// for speed (no re-verify, and `cos_full` never runs on the GPU survivors). The result differs from
670/// [`cosine_join`] only on pairs whose true cosine lies within ~`1e-6` (f32 rounding) of `t` — for a
671/// similarity join with an arbitrary threshold that is immaterial. Use when an ε-exact answer is
672/// acceptable; use [`cosine_join_gpu`] when bit-exactness is required.
673#[cfg(all(target_os = "macos", feature = "gpu"))]
674#[must_use]
675pub fn cosine_join_gpu_f32(
676    c: &Corpus,
677    t: f64,
678    gpu: &crate::simjoin_gpu::BatchCosineGpu,
679) -> Vec<(usize, usize, f32)> {
680    let (pa, pb) = survivor_pairs(c, t);
681    if pa.is_empty() {
682        return Vec::new();
683    }
684    let gcos = gpu.cosine_batch(&pa, &pb);
685    let tf = t as f32;
686    (0..pa.len())
687        .into_par_iter()
688        .with_min_len(1024)
689        .filter_map(|k| (gcos[k] >= tf).then_some((pa[k] as usize, pb[k] as usize, gcos[k])))
690        .collect()
691}
692
693/// CPU half shared by the GPU joins: build the index, then accumulate + bound every probe in
694/// parallel to the list of surviving `(candidate, probe)` pairs (candidate `<` probe), split into
695/// two `u32` arrays ready for [`crate::simjoin_gpu::BatchCosineGpu::cosine_batch`].
696#[cfg(all(target_os = "macos", feature = "gpu"))]
697fn survivor_pairs(c: &Corpus, t: f64) -> (Vec<u32>, Vec<u32>) {
698    let n = c.n;
699    let mut index: Vec<Vec<(u32, f64)>> = vec![Vec::new(); c.ndims];
700    let mut cached = Cached { bound: vec![Bound { pnorm: 0.0, split: u32::MAX }; n] };
701    for i in 0..n {
702        let (di, wi) = c.row(i);
703        index_suffix(c, i, (di, wi), t, &mut index, &mut cached);
704    }
705    let pairs: Vec<(u32, u32)> = (0..n)
706        .into_par_iter()
707        .with_min_len(256)
708        .map_init(
709            || Scratch { acc: vec![-1.0; n], touched: Vec::new(), xpn: Vec::new() },
710            |scratch, i| {
711                let (di, wi) = c.row(i);
712                accumulate(&index, (di, wi), i as u32, scratch);
713                let mut out = Vec::new();
714                collect_survivors(c, i, t, scratch, &cached, &mut out);
715                out
716            },
717        )
718        .flatten()
719        .collect();
720    pairs.into_iter().unzip()
721}
722
723/// Diagnostic (feature `profiling`, off the hot path): counts that quantify the prune. Returns
724/// `(candidates, survivors, pairs)` — candidates touched by the accumulator, survivors that pass the
725/// Cauchy–Schwarz bound (i.e. the `cos_full` calls actually made), and real pairs. `survivors /
726/// candidates` is the prune pass-rate (lower = better); `survivors` is the verify volume we pay for.
727#[cfg(feature = "profiling")]
728#[must_use]
729pub fn cosine_join_counts(c: &Corpus, t: f64) -> (u64, u64, u64) {
730    let n = c.n;
731    let mut index: Vec<Vec<(u32, f64)>> = vec![Vec::new(); c.ndims];
732    let mut cached = Cached { bound: vec![Bound { pnorm: 0.0, split: u32::MAX }; n] };
733    for i in 0..n {
734        let (di, wi) = c.row(i);
735        index_suffix(c, i, (di, wi), t, &mut index, &mut cached);
736    }
737    let mut s = Scratch { acc: vec![-1.0; n], touched: Vec::new(), xpn: Vec::new() };
738    let (mut ncand, mut survivors, mut pairs) = (0u64, 0u64, 0u64);
739    let need = t - PRUNE_SLACK;
740    for i in 0..n {
741        let (di, wi) = c.row(i);
742        accumulate(&index, (di, wi), i as u32, &mut s);
743        ncand += s.touched.len() as u64;
744        s.xpn.clear();
745        s.xpn.push(0.0);
746        let mut sq = 0.0f64;
747        for &w in wi {
748            sq += w * w;
749            s.xpn.push(sq.sqrt());
750        }
751        let xnorm = sq.sqrt();
752        let prebound = di.len() >= PREBOUND_MIN_DIMS;
753        for &y in &s.touched {
754            let yu = y as usize;
755            let a = std::mem::replace(&mut s.acc[yu], -1.0);
756            let bd = cached.bound[yu];
757            if prebound && a + xnorm * bd.pnorm < need {
758                continue;
759            }
760            let kstar = di.partition_point(|&d| d < bd.split);
761            if a + s.xpn[kstar] * bd.pnorm >= need {
762                survivors += 1;
763                if cos_full((di, wi), c.row(yu)) >= t {
764                    pairs += 1;
765                }
766            }
767        }
768    }
769    (ncand, survivors, pairs)
770}
771
772/// Naive `O(n²)` oracle: score every pair with [`cos_full`], keep `cos ≥ t`. The correctness
773/// reference [`cosine_join`] is validated against.
774#[must_use]
775pub fn cosine_join_bruteforce(c: &Corpus, t: f64) -> Vec<(usize, usize, f64)> {
776    let mut out: Vec<(usize, usize, f64)> = Vec::new();
777    for i in 0..c.n {
778        for j in 0..i {
779            let s = cos_full(c.row(i), c.row(j));
780            if s >= t {
781                out.push((j, i, s));
782            }
783        }
784    }
785    out
786}
787
788#[cfg(test)]
789mod tests {
790    use super::{cosine_join, cosine_join_bruteforce, Corpus};
791
792    fn xorshift(seed: u64) -> impl FnMut() -> u64 {
793        let mut s = seed;
794        move || {
795            s ^= s << 13;
796            s ^= s >> 7;
797            s ^= s << 17;
798            s
799        }
800    }
801
802    fn sort_pairs(mut v: Vec<(usize, usize, f64)>) -> Vec<(usize, usize, u64)> {
803        v.sort_by_key(|a| (a.0, a.1));
804        v.into_iter().map(|(a, b, s)| (a, b, s.to_bits())).collect()
805    }
806
807    #[test]
808    fn indexed_join_matches_bruteforce() {
809        let mut next = xorshift(0x9e37_79b9_7f4a_7c15);
810        for _ in 0..400 {
811            let n = (next() % 40 + 2) as usize;
812            let dim_space = next() % 15 + 1;
813            let rows: Vec<Vec<(u32, f64)>> = (0..n)
814                .map(|_| {
815                    let nnz = (next() % 8) as usize;
816                    (0..nnz)
817                        .map(|_| ((next() % dim_space) as u32, (next() % 10 + 1) as f64))
818                        .collect()
819                })
820                .collect();
821            let c = Corpus::from_rows(&rows);
822            for &t in &[0.1_f64, 0.25, 0.5, 0.75, 0.9, 1.0] {
823                let got = sort_pairs(cosine_join(&c, t));
824                let want = sort_pairs(cosine_join_bruteforce(&c, t));
825                assert_eq!(got, want, "n={n} t={t}");
826            }
827        }
828    }
829
830    /// The CPU+GPU hybrid [`super::cosine_join_gpu`] must return bit-identical results to the pure-CPU
831    /// [`cosine_join`] on fuzzed corpora — the GPU is only a conservative filter; every emitted score
832    /// is the exact CPU `f64` value. Skips when no Metal device is present.
833    #[cfg(all(feature = "gpu", target_os = "macos"))]
834    #[test]
835    fn gpu_hybrid_matches_cpu() {
836        use super::CosineJoiner;
837        use crate::Concurrency;
838        let mut next = xorshift(0x1357_9bdf_0246_8ace);
839        for _ in 0..60 {
840            let n = (next() % 60 + 4) as usize;
841            let dim_space = next() % 20 + 2;
842            let rows: Vec<Vec<(u32, f64)>> = (0..n)
843                .map(|_| {
844                    let nnz = (next() % 10) as usize;
845                    (0..nnz)
846                        .map(|_| ((next() % dim_space) as u32, (next() % 10 + 1) as f64))
847                        .collect()
848                })
849                .collect();
850            // One reusable handle per corpus — `join` is called repeatedly across thresholds, which
851            // also exercises that the handle reuses its GPU resources (no per-call library compile).
852            let joiner = CosineJoiner::new(Corpus::from_rows(&rows));
853            if !joiner.has_gpu() {
854                eprintln!("no Metal device — skipping gpu_hybrid_matches_cpu");
855                return;
856            }
857            for &t in &[0.1_f64, 0.3, 0.5, 0.7, 0.9, 1.0] {
858                let want = sort_pairs(cosine_join(joiner.corpus(), t));
859                // Exact GPU+CPU hybrid is byte-identical to the plain join; `Cpu` mode too.
860                assert_eq!(
861                    sort_pairs(joiner.join(t, Concurrency::GpuPlusCpu)),
862                    want,
863                    "GpuPlusCpu n={n} t={t}"
864                );
865                assert_eq!(sort_pairs(joiner.join(t, Concurrency::Cpu)), want, "Cpu n={n} t={t}");
866            }
867        }
868    }
869}