Skip to main content

difflib_fast/
simjoin.rs

1//! `simjoin` — exact all-pairs **weighted-cosine similarity join**: given a corpus of sparse
2//! non-negative vectors, find every pair `(i, j)` whose cosine similarity is `≥ t`.
3//!
4//! This is the `AllPairs` / `L2AP` family (Bayardo et al. WWW'07; Anastasiu & Karypis ICDE'14): an
5//! inverted index with **prefix filtering** so that the vast majority of vector pairs are never
6//! compared. It is the principled, exact replacement for shingle-candidate + verify-all
7//! near-duplicate detection (e.g. Type-3 code clones = functions × IDF-weighted lines).
8//!
9//! ## Correctness gate
10//! [`cosine_join`] (the indexed algorithm) must return the **exact same** pair set, with
11//! bit-identical similarities, as [`cosine_join_bruteforce`] (the naive `O(n²)` oracle). Both score
12//! a pair with the same [`cos_full`] sorted-merge dot, so the values match to the bit and a pair is
13//! never dropped or gained at the threshold. This equality is asserted on fuzzed corpora — the same
14//! "two implementations, one answer" discipline the RO path uses.
15//!
16//! ## Method (this reference)
17//! Vectors are L2-normalised (so `cos = dot`) and their dimensions relabelled to a global rank by
18//! **increasing max weight** (common low-weight dims first, rare high-weight dims last — so only the
19//! rare tail is indexed, keeping postings short). For a probe vector we look up candidates through
20//! the index, then verify with the full dot. When *indexing* a vector we skip the leading prefix
21//! whose max possible contribution to any dot stays `< t`:
22//! `Σ_{k<b} w_k · maxw[dim_k] < t` ⇒ a pair matching only in that prefix can't reach `t`
23//! (weights ≥ 0), so it is guaranteed to also share an indexed dim. The skipped prefix holds the few
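//! common dims (huge postings); only the rarer tail is indexed — that is the whole speed-up. On
//! top of this base, [`cosine_join`] prunes candidates before the exact verify with a
//! Cauchy–Schwarz upper bound on the prefix mass the index never sees (L2AP's `l2`-style bound);
//! it is a filter only, so the emitted pairs are unchanged. L2AP's norm-based *prefix* bound is
//! not used.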
27
// Vector ids and dim ranks are stored as `u32`: ranks are assigned densely from a sorted dim list
// (so every rank is `< ndims`) and vector ids are `< n`. The `as` casts below are intentional and
// assume both counts fit in `u32`; nothing here checks that at runtime.
30// Dense numerical code: `i`/`d`/`w`/`y`/`a`/`s`/`t` mirror the cosine/prefix-bound formulas and read
31// clearer than verbose names here — allow the single-char-names pedantic lint module-wide.
32#![allow(
33    clippy::cast_possible_truncation,
34    clippy::cast_precision_loss,
35    clippy::cast_sign_loss,
36    clippy::many_single_char_names
37)]
38
39use std::cmp::Ordering;
40use std::collections::{HashMap, HashSet};
41
42use rayon::prelude::*;
43
44use crate::Concurrency;
45
/// A corpus of L2-normalised sparse non-negative vectors in CSR form, with dimensions relabelled to
/// a global rank by **increasing** max weight (common low-weight dims get the low ranks, rare
/// high-weight dims the high ones). Built once; joinable at any threshold.
pub struct Corpus {
    /// Number of vectors (rows), including rows that ended up empty (zero norm).
    n: usize,
    /// Number of distinct dims occurring in at least one non-empty row (= number of ranks).
    ndims: usize,
    /// `n + 1` row offsets into `dims`/`wts`.
    indptr: Vec<usize>,
    /// Relabelled (rank) dimension ids, strictly ascending within each row.
    dims: Vec<u32>,
    /// L2-normalised weights, aligned with `dims`.
    wts: Vec<f64>,
    /// Per rank: the max normalised weight that dim takes across the corpus (the prefix-filter
    /// bound). Non-decreasing in rank, because ranks are assigned in ascending max-weight order.
    maxw: Vec<f64>,
}

impl Corpus {
    /// Number of vectors.
    #[must_use]
    pub fn len(&self) -> usize {
        self.n
    }

    /// True if the corpus has no vectors.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.n == 0
    }

    /// `(dims, weights)` of vector `i` — dims ascending by global rank, weights L2-normalised.
    ///
    /// # Panics
    /// If `i >= self.len()`.
    #[must_use]
    fn row(&self, i: usize) -> (&[u32], &[f64]) {
        let (s, e) = (self.indptr[i], self.indptr[i + 1]);
        (&self.dims[s..e], &self.wts[s..e])
    }

    /// CSR view for GPU offload: `(indptr, dims, wts_f32)`, with `indptr` cast to `u32` and the
    /// L2-normalised weights cast to `f32` (Apple GPUs have no `f64`). For the
    /// [`crate::simjoin_gpu`] throughput experiment only — the f32 cast means GPU dots are *not*
    /// bit-identical to the CPU `f64` path.
    #[cfg(all(target_os = "macos", feature = "gpu"))]
    #[must_use]
    pub fn csr_f32(&self) -> (Vec<u32>, Vec<u32>, Vec<f32>) {
        let indptr = self.indptr.iter().map(|&x| x as u32).collect();
        let wts = self.wts.iter().map(|&w| w as f32).collect();
        (indptr, self.dims.clone(), wts)
    }

    /// Build a corpus from **token documents** — each document a list of string tokens — as TF-IDF
    /// sparse vectors: dim = a distinct token, weight = `(token count in doc) × ln(n / df_token)`
    /// (`df` = number of documents containing the token). This is the principled input for a Type-3
    /// code-clone join (documents = functions, tokens = canonicalised lines) and the shape most
    /// callers actually have. A token appearing in every document gets `idf = 0` (contributes
    /// nothing); a document whose tokens all have `idf = 0` becomes an empty row.
    #[must_use]
    pub fn from_token_docs<S: AsRef<str>>(docs: &[Vec<S>]) -> Corpus {
        let n = docs.len();
        let mut dim: HashMap<&str, u32> = HashMap::new();
        let mut df: Vec<u32> = Vec::new();
        // Assign a dim id to each distinct token and count document frequency (once per doc/token).
        let mut doc_ids: Vec<Vec<u32>> = Vec::with_capacity(n);
        // One `seen` set reused across documents (cleared per doc) instead of one allocation each.
        let mut seen: HashSet<u32> = HashSet::new();
        for doc in docs {
            seen.clear();
            let mut ids = Vec::with_capacity(doc.len());
            for tok in doc {
                let id = *dim.entry(tok.as_ref()).or_insert_with(|| {
                    let i = df.len() as u32;
                    df.push(0);
                    i
                });
                ids.push(id);
                if seen.insert(id) {
                    df[id as usize] += 1;
                }
            }
            doc_ids.push(ids);
        }
        let idf: Vec<f64> = df.iter().map(|&d| (n as f64 / f64::from(d.max(1))).ln()).collect();
        // Emit (dim, idf) once per token occurrence; `from_rows` sums duplicates → tf·idf per dim.
        let rows: Vec<Vec<(u32, f64)>> = doc_ids
            .iter()
            .map(|ids| ids.iter().map(|&id| (id, idf[id as usize])).collect())
            .collect();
        Corpus::from_rows(&rows)
    }

    /// Build a corpus from raw `(dim, weight)` rows. Duplicate dims within a row are summed (in
    /// input order); each row is L2-normalised (a row with zero norm becomes empty); dims are
    /// relabelled to a global rank by **increasing** max weight. Weights are expected
    /// non-negative (the prefix-filter bound requires it — IDF weights satisfy this).
    ///
    /// The result is deterministic: the same rows always produce bit-identical weights, regardless
    /// of the order dims are listed in within a row.
    #[must_use]
    pub fn from_rows(rows: &[Vec<(u32, f64)>]) -> Corpus {
        let n = rows.len();
        // 1. Merge duplicate dims + L2-normalise each row (kept as (orig_dim, weight), dim asc).
        let normed: Vec<Vec<(u32, f64)>> = rows
            .iter()
            .map(|r| {
                // A stable sort keeps duplicates adjacent *and* in input order; `dedup_by` then
                // folds each later duplicate into the kept first one (w1 + w2 + …). Unlike a
                // `HashMap` (randomly seeded per process), this fixes the order the norm is summed
                // in, so the normalised weights are reproducible bit for bit across runs.
                let mut merged = r.clone();
                merged.sort_by_key(|&(d, _)| d);
                merged.dedup_by(|later, kept| {
                    let same = later.0 == kept.0;
                    if same {
                        kept.1 += later.1;
                    }
                    same
                });
                let norm = merged.iter().map(|&(_, w)| w * w).sum::<f64>().sqrt();
                if norm > 0.0 {
                    for e in &mut merged {
                        e.1 /= norm;
                    }
                    merged
                } else {
                    Vec::new()
                }
            })
            .collect();
        // 2. Max normalised weight per original dim (starts at 0.0, only raised by `>`, so it is
        //    never NaN and the comparator below always gets `Some`).
        let mut maxw_orig: HashMap<u32, f64> = HashMap::new();
        for v in &normed {
            for &(d, w) in v {
                let e = maxw_orig.entry(d).or_insert(0.0);
                if w > *e {
                    *e = w;
                }
            }
        }
        // 3. Rank dims by (max weight ASC, dim asc) → dense global order. Ascending so the common,
        //    low-weight dims land at the FRONT: they fill the un-indexed prefix (their tiny
        //    `w·maxw` keeps the cumulative bound under `t` for many of them), and only the rare,
        //    high-weight tail is indexed — short postings. (Reversing this indexes the common dims
        //    and their huge postings, which is correct but orders of magnitude slower.)
        let mut dims_sorted: Vec<(u32, f64)> = maxw_orig.into_iter().collect();
        dims_sorted.sort_by(|a, b| {
            a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal).then(a.0.cmp(&b.0))
        });
        let rank: HashMap<u32, u32> =
            dims_sorted.iter().enumerate().map(|(i, &(d, _))| (d, i as u32)).collect();
        let ndims = dims_sorted.len();
        let maxw: Vec<f64> = dims_sorted.iter().map(|&(_, w)| w).collect();
        // 4. CSR with relabelled dims, ascending within each row (buffers sized up front).
        let nnz: usize = normed.iter().map(Vec::len).sum();
        let mut indptr = Vec::with_capacity(n + 1);
        indptr.push(0);
        let mut dims = Vec::with_capacity(nnz);
        let mut wts = Vec::with_capacity(nnz);
        for v in &normed {
            let mut rv: Vec<(u32, f64)> = v.iter().map(|&(d, w)| (rank[&d], w)).collect();
            rv.sort_unstable_by_key(|&(d, _)| d);
            for (d, w) in rv {
                dims.push(d);
                wts.push(w);
            }
            indptr.push(dims.len());
        }
        Corpus { n, ndims, indptr, dims, wts, maxw }
    }
}
193
194/// Sorted-merge dot product of two rows (dims ascending). For L2-normalised non-negative vectors
195/// this is exactly their cosine similarity. The single scoring routine shared by the indexed join
196/// and the brute-force oracle, so both agree to the bit.
197#[must_use]
198#[cfg_attr(feature = "profiling", inline(never))]
199fn cos_full((da, wa): (&[u32], &[f64]), (db, wb): (&[u32], &[f64])) -> f64 {
200    // Rows are equal-length dim/weight slices (built by `Corpus::from_rows`).
201    debug_assert_eq!(da.len(), wa.len());
202    debug_assert_eq!(db.len(), wb.len());
203    let (la, lb) = (da.len(), db.len());
204    let (mut i, mut j) = (0usize, 0usize);
205    let mut s = 0.0f64;
206    // Branchless sorted-merge: the 3-way `cmp` branches mispredict on random dim order, and the
207    // weight loads carry bounds checks the optimiser can't elide. Here we always load both weights
208    // (`i<la=wa.len()`, `j<lb=wb.len()`) and mask the product by dim-equality — `s += 0.0` for
209    // unequal dims adds the *same* terms in the *same* increasing-dim order, so the result is
210    // bit-identical to the branchy merge while shedding every data-dependent branch.
211    while i < la && j < lb {
212        // SAFETY: `i < la == wa.len()` and `j < lb == wb.len()`.
213        let (ai, bj) = unsafe { (*da.get_unchecked(i), *db.get_unchecked(j)) };
214        let (wai, wbj) = unsafe { (*wa.get_unchecked(i), *wb.get_unchecked(j)) };
215        let eq = f64::from(u32::from(ai == bj));
216        s += eq * wai * wbj;
217        i += usize::from(ai <= bj);
218        j += usize::from(ai >= bj);
219    }
220    s
221}
222
/// Prune record kept for every indexed vector `y` and read in the hot verify loop whenever `y`
/// turns up as a candidate. Both fields live in one small `Copy` record, rather than two parallel
/// `Vec`s, so a candidate costs one scattered cache-line fetch instead of two; verify is
/// memory-bound on exactly these random accesses.
#[derive(Clone, Copy)]
struct Bound {
    /// Rank of `y`'s first *indexed* dim, or `u32::MAX` when `y` indexed nothing. Every un-indexed
    /// (prefix) dim of `y` ranks strictly below it, which lets a probe restrict its own norm to
    /// exactly that rank range.
    split: u32,
    /// L2 norm of `y`'s un-indexed prefix: the Cauchy–Schwarz factor for the dot mass the
    /// accumulator can never see (those dims were not indexed, so they were never accumulated).
    pnorm: f64,
}

/// The per-vector [`Bound`]s for a whole join: filled in as each vector is indexed, consulted when
/// that vector later appears as a candidate.
struct Cached {
    /// `bound[y]` for every vector id `y`.
    bound: Vec<Bound>,
}
241
/// Per-worker scratch space, allocated once and reused for every probe that worker handles (never
/// reallocated per probe).
struct Scratch {
    /// Prefix L2 norms of the current probe: `xpn[k]` is the norm of its first `k` weights, with
    /// `xpn[0] = 0`, so the length is `nnz + 1`.
    xpn: Vec<f64>,
    /// Ids of the candidates the current probe touched — exactly the `acc` slots to reset.
    touched: Vec<u32>,
    /// `acc[y]` holds the probe's partial dot with `y` over their shared *indexed* dims, or the
    /// sentinel `-1.0` when untouched (a partial dot of non-negative weights is never negative).
    acc: Vec<f64>,
}
252
253/// Exact all-pairs cosine join via inverted index + **L2AP** prefix filtering and accumulation-time
254/// pruning. Returns `(j, i, cos)` with `j < i` for every pair with `cos ≥ t`. Bit-identical pair set
255/// to [`cosine_join_bruteforce`].
256///
257/// For each probe we accumulate a partial dot over shared *indexed* dims ([`accumulate`]), then for
258/// each touched candidate compute a Cauchy–Schwarz upper bound on the true cosine and skip the exact
259/// [`cos_full`] when it cannot reach `t` ([`verify_pruned`]). The bound is a filter only — survivors
260/// are scored exactly, so the output is byte-for-byte the brute-force result. On skewed data the
261/// bound prunes the ~99.9 % of candidates that collide on a single rare dim, so `cos_full` (the
262/// former 90 % hotspot) runs only on genuine near-matches.
263///
264/// The full inverted index is built once (postings ascending by id), then every vector is probed in
265/// **parallel**: probe `i` walks each posting only while `y < i` (postings are id-sorted), so it sees
266/// exactly the earlier vectors — each pair `(j, i)` with `j < i` is found once, from the larger id.
267/// This is the same candidate set the sequential index-as-you-go build produces, so the result is
268/// unchanged; the returned `Vec` is in arbitrary order (sort if a canonical order is needed).
269#[must_use]
270pub fn cosine_join(c: &Corpus, t: f64) -> Vec<(usize, usize, f64)> {
271    let n = c.n;
272    // Postings carry the indexed weight `(y, w_y[d])` so the scan can accumulate a partial dot.
273    let mut index: Vec<Vec<(u32, f64)>> = vec![Vec::new(); c.ndims];
274    let mut cached = Cached { bound: vec![Bound { pnorm: 0.0, split: u32::MAX }; n] };
275    for i in 0..n {
276        let (di, wi) = c.row(i);
277        index_suffix(c, i, (di, wi), t, &mut index, &mut cached);
278    }
279    // Probe vectors in parallel; each worker keeps one reusable `Scratch` (an `n`-wide accumulator).
280    // `with_min_len` batches many probes per rayon task so the per-probe work (tiny) isn't dwarfed by
281    // task-splitting / scheduling overhead (`swtch_pri` in the profile).
282    (0..n)
283        .into_par_iter()
284        .with_min_len(256)
285        .map_init(
286            || Scratch { acc: vec![-1.0; n], touched: Vec::new(), xpn: Vec::new() },
287            |scratch, i| {
288                let (di, wi) = c.row(i);
289                accumulate(&index, (di, wi), i as u32, scratch);
290                let mut out = Vec::new();
291                verify_pruned(c, i, t, scratch, &cached, &mut out);
292                out
293            },
294        )
295        .flatten()
296        .collect()
297}
298
299/// Run the cosine join under a chosen [`Concurrency`] backend. Returns `(j, i, cos)` pairs with
300/// `j < i` and `cos ≥ t`, scores as `f64` (the `Gpu` mode's f32 cosines are widened losslessly).
301///
302/// - [`Concurrency::Cpu`] — [`cosine_join`]: exact `f64`, all-CPU, every platform.
303/// - [`Concurrency::GpuPlusCpu`] — exact `f64` hybrid: CPU generates survivor pairs, the GPU f32
304///   cosine *filters* the clear rejects, the CPU recomputes the exact `f64` score on what passes.
305///   **Byte-identical to `Cpu`**; both engines fully used. ~1.7–2× on bandwidth-bound real data.
306/// - [`Concurrency::Gpu`] — GPU-dominant `f32`: CPU generates survivor pairs, the GPU scores them and
307///   the result is emitted directly (no f64 re-verify). Fastest (~2×); differs from the exact answer
308///   only on pairs whose true cosine is within ~`1e-6` of `t` (measured: ≤1 pair in millions).
309///
310/// When the `gpu` feature is off, the target isn't macOS, or no Metal device can be acquired, the GPU
311/// modes transparently fall back to [`cosine_join`] (same as `Rationer`). This convenience entry
312/// **compiles + uploads the GPU corpus on every call** — fine for a one-shot join, but for repeated
313/// joins on one corpus build a [`CosineJoiner`] once and call [`CosineJoiner::join`], which holds the
314/// device + kernel + uploaded CSR across calls (and avoids the driver instability of compiling a
315/// Metal library hundreds of times in a tight loop).
316#[must_use]
317pub fn cosine_join_with(c: &Corpus, t: f64, mode: Concurrency) -> Vec<(usize, usize, f64)> {
318    #[cfg(all(feature = "gpu", target_os = "macos"))]
319    {
320        if matches!(mode, Concurrency::Cpu) {
321            return cosine_join(c, t);
322        }
323        let (indptr, dims, wts) = c.csr_f32();
324        let Some(gpu) = crate::simjoin_gpu::BatchCosineGpu::new(&indptr, &dims, &wts) else {
325            return cosine_join(c, t); // no Metal device → CPU fallback
326        };
327        match mode {
328            Concurrency::GpuPlusCpu => cosine_join_gpu(c, t, &gpu),
329            // `Gpu`: emit the GPU f32 cosines directly (widened to f64), no exact re-verify.
330            Concurrency::Gpu => cosine_join_gpu_f32(c, t, &gpu)
331                .into_iter()
332                .map(|(a, b, s)| (a, b, f64::from(s)))
333                .collect(),
334            Concurrency::Cpu => unreachable!("handled above"),
335        }
336    }
337    #[cfg(not(all(feature = "gpu", target_os = "macos")))]
338    {
339        let _ = mode; // GPU modes degrade to the CPU join when the feature is off / not macOS.
340        cosine_join(c, t)
341    }
342}
343
344/// A reusable cosine-join handle that owns the corpus and — under `feature = "gpu"` on macOS — the
345/// Metal device, compiled `batch_cosine` kernel, and the corpus CSR uploaded to unified memory, all
346/// acquired **once** at construction. Repeated [`join`](CosineJoiner::join)s at different thresholds
347/// then skip the per-call kernel compile + CSR upload that [`cosine_join_with`] pays (only the
348/// `t`-specific inverted index is rebuilt each call, on the CPU). Always constructible; degrades to
349/// the pure-CPU join when the `gpu` feature is off or no Metal device is available — mirroring
350/// `Rationer`. This is the right entry point for sweeping thresholds or joining repeatedly.
351pub struct CosineJoiner {
352    corpus: Corpus,
353    /// Owned GPU resources (device + kernel + CSR in UMA); `None` when no Metal device was acquired.
354    #[cfg(all(feature = "gpu", target_os = "macos"))]
355    gpu: Option<crate::simjoin_gpu::BatchCosineGpu>,
356}
357
358impl CosineJoiner {
359    /// Build a joiner over `corpus`, acquiring the GPU device + uploading the corpus CSR once if the
360    /// `gpu` feature is on and a Metal device is present.
361    #[must_use]
362    pub fn new(corpus: Corpus) -> Self {
363        #[cfg(all(feature = "gpu", target_os = "macos"))]
364        {
365            let (indptr, dims, wts) = corpus.csr_f32();
366            let gpu = crate::simjoin_gpu::BatchCosineGpu::new(&indptr, &dims, &wts);
367            Self { corpus, gpu }
368        }
369        #[cfg(not(all(feature = "gpu", target_os = "macos")))]
370        {
371            Self { corpus }
372        }
373    }
374
375    /// The owned corpus (e.g. for `len()` or to run other queries against it).
376    #[must_use]
377    pub fn corpus(&self) -> &Corpus {
378        &self.corpus
379    }
380
381    /// Whether a Metal GPU backend was acquired. Always `false` without `feature = "gpu"` on macOS;
382    /// when `false`, every [`join`](CosineJoiner::join) runs on the CPU regardless of `mode`.
383    #[must_use]
384    pub fn has_gpu(&self) -> bool {
385        #[cfg(all(feature = "gpu", target_os = "macos"))]
386        {
387            self.gpu.is_some()
388        }
389        #[cfg(not(all(feature = "gpu", target_os = "macos")))]
390        {
391            false
392        }
393    }
394
395    /// Run the join at threshold `t` under `mode`, reusing the handle's GPU resources. Returns the
396    /// same results as [`cosine_join_with`] (Cpu/GpuPlusCpu exact, Gpu f32→f64); falls back to the
397    /// CPU join when the GPU is unavailable.
398    #[must_use]
399    pub fn join(&self, t: f64, mode: Concurrency) -> Vec<(usize, usize, f64)> {
400        #[cfg(all(feature = "gpu", target_os = "macos"))]
401        {
402            match (mode, self.gpu.as_ref()) {
403                (Concurrency::GpuPlusCpu, Some(g)) => cosine_join_gpu(&self.corpus, t, g),
404                (Concurrency::Gpu, Some(g)) => cosine_join_gpu_f32(&self.corpus, t, g)
405                    .into_iter()
406                    .map(|(a, b, s)| (a, b, f64::from(s)))
407                    .collect(),
408                _ => cosine_join(&self.corpus, t), // Cpu mode, or no Metal device
409            }
410        }
411        #[cfg(not(all(feature = "gpu", target_os = "macos")))]
412        {
413            let _ = mode;
414            cosine_join(&self.corpus, t)
415        }
416    }
417}
418
/// Floating-point slack for the prune bound in `verify_pruned`. The Cauchy–Schwarz cap is exact
/// in real arithmetic, but the accumulated partial dot, the `sqrt` prefix norms and the final
/// exact score all round. A candidate is skipped only when its bound misses `t` by more than this
/// slack, so rounding can never prune a pair whose score reaches `t`. Leaning toward "don't skip"
/// costs at most a redundant exact score and never changes the emitted pairs. The rounding
/// involved is on the order of `(terms) × 1e-16`, far below `1e-9` for realistic row lengths.
const PRUNE_SLACK: f64 = 1.0e-9;
426
427/// Phase 1 — accumulate: for each indexed dim of the probe, add `w_probe·w_y` into `acc[y]` for
428/// every **earlier** `y` (`y < cutoff`, the probe's own id) indexing that dim. Postings are id-sorted,
429/// so we `break` at the first `y ≥ cutoff`. Leaves `acc` `-1.0` everywhere except the touched ids
430/// (listed in `touched`, reset in [`verify_pruned`]). One scattered `acc[]` FMA per posting.
431#[cfg_attr(feature = "profiling", inline(never))]
432fn accumulate(index: &[Vec<(u32, f64)>], (di, wi): (&[u32], &[f64]), cutoff: u32, s: &mut Scratch) {
433    s.touched.clear();
434    for (&d, &w) in di.iter().zip(wi) {
435        for &(y, wy) in &index[d as usize] {
436            if y >= cutoff {
437                break;
438            }
439            let yu = y as usize;
440            // SAFETY: `y` is a vector id pushed by `index_suffix`, so `yu < n == acc.len()`.
441            let a = unsafe { s.acc.get_unchecked_mut(yu) };
442            if *a < 0.0 {
443                *a = 0.0;
444                s.touched.push(y);
445            }
446            *a += w * wy;
447        }
448    }
449}
450
451/// Phase 2 — prune + verify. For each touched candidate `y`, reset its accumulator and test the
452/// **L2AP `l2` bound**: the dot mass missing from `acc[y]` (the dims in `prefix(y)`) is at most
453/// `‖x_{rank<split[y]}‖ · ‖prefix(y)‖` by Cauchy–Schwarz, where `x_{rank<split[y]}` is the probe
454/// restricted to the rank range `prefix(y)` lives in. Since the probe's mass sits in its rare
455/// (high-rank) dims, that restricted norm is tiny — a far tighter cap than the whole-probe `‖x‖=1`.
456/// Only if `acc[y] + that bound ≥ t` (minus FP slack) do we score exactly with [`cos_full`]. Filter
457/// only — the emitted value is the exact dot, so the pair set is bit-identical to brute force.
458#[cfg_attr(feature = "profiling", inline(never))]
459fn verify_pruned(
460    c: &Corpus,
461    i: usize,
462    t: f64,
463    s: &mut Scratch,
464    cached: &Cached,
465    out: &mut Vec<(usize, usize, f64)>,
466) {
467    let (di, wi) = c.row(i);
468    // Probe prefix L2 norms: xpn[k] = ‖wi[..k]‖₂ (di ascending by rank, so xpn[k] = norm over the
469    // probe's k lowest-rank dims). One sqrt per probe dim, reused across all its candidates.
470    s.xpn.clear();
471    s.xpn.push(0.0);
472    let mut sq = 0.0f64;
473    for &w in wi {
474        sq += w * w;
475        s.xpn.push(sq.sqrt());
476    }
477    let need = t - PRUNE_SLACK;
478    let Scratch { acc, touched, xpn } = s;
479    // (Software-prefetching the candidate row a few ahead was tried + reverted: no measurable change
480    // — with all cores gathering at once the join is memory-*bandwidth*-bound, not per-access
481    // latency-bound, so prefetch can't add throughput.)
482    for &y in touched.iter() {
483        let yu = y as usize;
484        // SAFETY: `yu < n` (same provenance as in `accumulate`).
485        let a = unsafe { std::mem::replace(acc.get_unchecked_mut(yu), -1.0) };
486        // SAFETY: `yu < n`. One scattered load fetches both prune fields.
487        let bd = unsafe { *cached.bound.get_unchecked(yu) };
488        // Number of probe dims with rank < split[y] → index into xpn (di sorted ascending).
489        let kstar = di.partition_point(|&d| d < bd.split);
490        // SAFETY: kstar ≤ di.len() == wi.len() = xpn.len()-1.
491        // (An added maxweight cap `min(…, Σ wx·maxw)` was tried + reverted: never tighter than the
492        // L2 cap on either synthetic or real data — it sums over all probe-prefix dims, not just the
493        // shared ones — so it pruned nothing and cost ~16% on the real PyPI corpus.)
494        let bound = a + unsafe { xpn.get_unchecked(kstar) } * bd.pnorm;
495        if bound >= need {
496            let cos = cos_full((di, wi), c.row(yu));
497            if cos >= t {
498                out.push((yu, i, cos));
499            }
500        }
501    }
502}
503
504/// Phase 3 — index this vector's suffix: skip the leading prefix whose max possible contribution to
505/// any dot stays `< t` (`Σ w_k·maxw[dim_k] < t`), index only the rarer tail (short postings = the
506/// whole speed-up), and cache `pnorm[i] = ‖prefix‖₂` and `split[i]` = first indexed rank for the
507/// [`verify_pruned`] bound.
508#[cfg_attr(feature = "profiling", inline(never))]
509fn index_suffix(
510    c: &Corpus,
511    i: usize,
512    (di, wi): (&[u32], &[f64]),
513    t: f64,
514    index: &mut [Vec<(u32, f64)>],
515    cached: &mut Cached,
516) {
517    // Largest safe prefix under the maxweight bound: `Σ_{k<b} w_k·maxw[dim_k] < t` ⇒ the prefix
518    // can't contribute `t` to any dot (weights ≥ 0). (A norm-based extension `‖x_{<b}‖ < t` was
519    // tried and reverted: it gave zero candidate reduction in the realistic `t<1` regime — the
520    // maxweight bound already indexes less — and it is FP-fragile at `t=1` where `‖x‖` rounds below
521    // `1` and indexes nothing, dropping exact-duplicate pairs.)
522    let mut rs = 0.0f64;
523    let mut b = 0usize;
524    for k in 0..di.len() {
525        let bound = wi[k] * c.maxw[di[k] as usize];
526        if rs + bound >= t {
527            break;
528        }
529        rs += bound;
530        b = k + 1;
531    }
532    let mut p = 0.0f64;
533    for &w in &wi[..b] {
534        p += w * w;
535    }
536    cached.bound[i] = Bound {
537        pnorm: p.sqrt(),
538        split: if b < di.len() { di[b] } else { u32::MAX },
539    };
540    for k in b..di.len() {
541        index[di[k] as usize].push((i as u32, wi[k]));
542    }
543}
544
/// Absolute margin for the GPU f32 cosine *filter* in [`cosine_join_gpu`]. A survivor is dropped
/// only when its f32 score falls below `t` by more than this amount.
///
/// The sizing assumes the f32 dot's error is `~1e-6` relative, leaving `1e-4` as a wide safety
/// band. NOTE(review): the band narrows as rows get very long; confirm against the kernel's
/// accumulation order. Whatever passes is re-scored exactly in `f64` on the CPU, so the emitted
/// pairs and scores stay those of [`cosine_join`].
#[cfg(all(target_os = "macos", feature = "gpu"))]
const GPU_FILTER_MARGIN: f64 = 1.0e-4;
551
/// Survivor collection for the GPU joins: the prune half of `verify_pruned` without the exact
/// score. Every touched candidate `y` has its accumulator slot reset to `-1.0`. Candidates whose
/// Cauchy–Schwarz upper bound reaches `t − PRUNE_SLACK` are appended to `out` as
/// `(candidate, probe)`, with candidate `<` probe, for batch scoring elsewhere. The bound
/// arithmetic MUST stay identical to `verify_pruned`'s so both paths keep the same survivor set.
#[cfg(all(target_os = "macos", feature = "gpu"))]
fn collect_survivors(
    c: &Corpus,
    i: usize,
    t: f64,
    s: &mut Scratch,
    cached: &Cached,
    out: &mut Vec<(u32, u32)>,
) {
    let (probe_dims, probe_wts) = c.row(i);
    let Scratch { acc, touched, xpn } = s;
    // xpn[k] = ‖probe_wts[..k]‖₂, computed exactly as in `verify_pruned`.
    xpn.clear();
    xpn.push(0.0);
    xpn.extend(probe_wts.iter().scan(0.0f64, |sq, &w| {
        *sq += w * w;
        Some(sq.sqrt())
    }));
    let need = t - PRUNE_SLACK;
    let probe = i as u32;
    for &y in touched.iter() {
        let cand = y as usize;
        // SAFETY: `cand < n == acc.len()` (same provenance as in `accumulate`).
        let partial = unsafe { std::mem::replace(acc.get_unchecked_mut(cand), -1.0) };
        // SAFETY: `cand < n == cached.bound.len()`.
        let Bound { pnorm, split } = unsafe { *cached.bound.get_unchecked(cand) };
        let below = probe_dims.partition_point(|&d| d < split);
        // SAFETY: `below <= probe_dims.len() == probe_wts.len() == xpn.len() - 1`.
        let upper = partial + unsafe { *xpn.get_unchecked(below) } * pnorm;
        if upper >= need {
            out.push((y, probe)); // (candidate, probe), candidate < probe
        }
    }
}
579
/// **CPU+GPU hybrid join** (feature `gpu`, macOS). Returns the **exact same** pair set + scores as
/// [`cosine_join`] (and thus the brute-force oracle). It is only faster, and only when the verify
/// is bandwidth-bound.
///
/// Pipeline:
/// 1. The CPU builds the index and, in parallel, accumulates + bounds every probe into a list of
///    surviving `(candidate, probe)` pairs (no `cos_full`).
/// 2. The GPU computes an f32 cosine for the whole batch. Its memory-level parallelism clears the
///    random-gather dots ~3× faster than the CPU.
/// 3. The CPU recomputes the exact `f64` `cos_full` **only** on pairs whose GPU score clears
///    `t − margin`, typically a few percent of survivors.
///
/// The GPU acts as a *conservative filter* (margin ≫ f32 error, so no true pair is dropped), and
/// every emitted score is the exact CPU `f64` value. The output is therefore byte-for-byte
/// identical to [`cosine_join`].
///
/// For `t ≤ 0` every pair qualifies, including pairs with no shared dim that index-driven survivor
/// generation cannot surface. That case delegates to [`cosine_join`], the reference this function
/// must match.
#[cfg(all(target_os = "macos", feature = "gpu"))]
#[must_use]
pub fn cosine_join_gpu(
    c: &Corpus,
    t: f64,
    gpu: &crate::simjoin_gpu::BatchCosineGpu,
) -> Vec<(usize, usize, f64)> {
    if t <= 0.0 {
        // Parity target by construction; survivor generation only sees pairs sharing a dim.
        return cosine_join(c, t);
    }
    let (pa, pb) = survivor_pairs(c, t);
    if pa.is_empty() {
        return Vec::new();
    }
    // GPU phase: f32 cosine over the whole survivor batch (conservative filter).
    let gcos = gpu.cosine_batch(&pa, &pb);
    let need = t - GPU_FILTER_MARGIN;
    // CPU phase: exact f64 re-verify only on pairs the GPU filter passes.
    (0..pa.len())
        .into_par_iter()
        .with_min_len(1024)
        .filter_map(|k| {
            if f64::from(gcos[k]) < need {
                return None;
            }
            let (a, b) = (pa[k] as usize, pb[k] as usize);
            let cos = cos_full(c.row(a), c.row(b));
            (cos >= t).then_some((a, b, cos))
        })
        .collect()
}
618
/// **Pure-f32** CPU+GPU join (feature `gpu`, macOS). Survivor generation is the same as in
/// [`cosine_join_gpu`], but the GPU's **f32** cosine is emitted directly, with **no exact f64
/// re-verify**. This trades byte-parity for speed: `cos_full` never runs on the GPU survivors.
///
/// The result differs from [`cosine_join`] only on pairs whose true cosine lies within ~`1e-6`
/// (f32 rounding) of `t`. For a similarity join with an arbitrary threshold that is immaterial.
/// Use this when an ε-exact answer is acceptable; use [`cosine_join_gpu`] when bit-exactness is
/// required.
///
/// For `t ≤ 0` every pair qualifies, including pairs with no shared dim that index-driven survivor
/// generation cannot surface. That case delegates to [`cosine_join`] and narrows its exact scores
/// to `f32`.
#[cfg(all(target_os = "macos", feature = "gpu"))]
#[must_use]
pub fn cosine_join_gpu_f32(
    c: &Corpus,
    t: f64,
    gpu: &crate::simjoin_gpu::BatchCosineGpu,
) -> Vec<(usize, usize, f32)> {
    if t <= 0.0 {
        return cosine_join(c, t).into_iter().map(|(a, b, s)| (a, b, s as f32)).collect();
    }
    let (pa, pb) = survivor_pairs(c, t);
    if pa.is_empty() {
        return Vec::new();
    }
    let gcos = gpu.cosine_batch(&pa, &pb);
    let tf = t as f32;
    (0..pa.len())
        .into_par_iter()
        .with_min_len(1024)
        .filter_map(|k| (gcos[k] >= tf).then_some((pa[k] as usize, pb[k] as usize, gcos[k])))
        .collect()
}
644
/// CPU half shared by the GPU joins: build the prefix-filtered index, then accumulate + bound every
/// probe in parallel. Returns the surviving `(candidate, probe)` pairs (candidate `<` probe) as two
/// index-aligned `u32` arrays, ready for [`crate::simjoin_gpu::BatchCosineGpu::cosine_batch`].
///
/// The pairs are unzipped straight out of the parallel iterator into the two output arrays. They
/// are never first collected into an intermediate `Vec<(u32, u32)>`, which would briefly double
/// peak memory and add a serial copy pass.
#[cfg(all(target_os = "macos", feature = "gpu"))]
fn survivor_pairs(c: &Corpus, t: f64) -> (Vec<u32>, Vec<u32>) {
    let n = c.n;
    let mut index: Vec<Vec<(u32, f64)>> = vec![Vec::new(); c.ndims];
    let mut cached = Cached { bound: vec![Bound { pnorm: 0.0, split: u32::MAX }; n] };
    // Index phase (sequential): `index_suffix` writes each row's postings and bound into
    // `index` / `cached`.
    for i in 0..n {
        let (di, wi) = c.row(i);
        index_suffix(c, i, (di, wi), t, &mut index, &mut cached);
    }
    // Probe phase: `index` and `cached` are only read here. `map_init` gives each rayon work split
    // its own `Scratch`, so the `n`-sized accumulator is reused across that split's probes.
    (0..n)
        .into_par_iter()
        .with_min_len(256)
        .map_init(
            || Scratch { acc: vec![-1.0; n], touched: Vec::new(), xpn: Vec::new() },
            |scratch, i| {
                let (di, wi) = c.row(i);
                accumulate(&index, (di, wi), i as u32, scratch);
                let mut out = Vec::new();
                collect_survivors(c, i, t, scratch, &cached, &mut out);
                out
            },
        )
        .flatten()
        // Parallel unzip: each `(candidate, probe)` item feeds both arrays from the same split, so
        // `pa[k]` and `pb[k]` stay paired.
        .unzip()
}
674
/// Diagnostic (feature `profiling`, off the hot path): counts that quantify the prune. Returns
/// `(candidates, survivors, pairs)`:
/// - `candidates`: candidates touched by the accumulator;
/// - `survivors`: candidates that pass the Cauchy–Schwarz bound, i.e. the `cos_full` calls
///   actually made;
/// - `pairs`: real pairs.
///
/// `survivors / candidates` is the prune pass-rate (lower = better); `survivors` is the verify
/// volume we pay for.
///
/// Counts are summed over all probes. Each probe's `touched` list is drained as it is scored, and
/// every drained candidate's `acc` slot is reset to `-1.0`. The next probe therefore starts from
/// clean scratch state without relying on what `accumulate` does on entry.
#[cfg(feature = "profiling")]
#[must_use]
pub fn cosine_join_counts(c: &Corpus, t: f64) -> (u64, u64, u64) {
    let n = c.n;
    // Build the prefix-filtered index the same way the real joins do.
    let mut index: Vec<Vec<(u32, f64)>> = vec![Vec::new(); c.ndims];
    let mut cached = Cached { bound: vec![Bound { pnorm: 0.0, split: u32::MAX }; n] };
    for i in 0..n {
        let (di, wi) = c.row(i);
        index_suffix(c, i, (di, wi), t, &mut index, &mut cached);
    }
    let mut s = Scratch { acc: vec![-1.0; n], touched: Vec::new(), xpn: Vec::new() };
    let (mut ncand, mut survivors, mut pairs) = (0u64, 0u64, 0u64);
    let need = t - PRUNE_SLACK;
    for i in 0..n {
        let (di, wi) = c.row(i);
        accumulate(&index, (di, wi), i as u32, &mut s);
        ncand += s.touched.len() as u64;
        // xpn[k] = L2 norm of the probe's first k weights (xpn[0] = 0).
        s.xpn.clear();
        s.xpn.push(0.0);
        let mut sq = 0.0f64;
        for &w in wi {
            sq += w * w;
            s.xpn.push(sq.sqrt());
        }
        // Drain rather than just iterate, so `touched` is empty again after this probe.
        for y in s.touched.drain(..) {
            let yu = y as usize;
            // Take the accumulated partial dot and reset the slot to the `-1.0` sentinel.
            let a = std::mem::replace(&mut s.acc[yu], -1.0);
            let bd = cached.bound[yu];
            // Count of probe dims below the candidate's `split`. `xpn[kstar]` is the norm of the
            // probe's weights on those dims, i.e. the part of the dot the index could not see.
            let kstar = di.partition_point(|&d| d < bd.split);
            if a + s.xpn[kstar] * bd.pnorm >= need {
                survivors += 1;
                if cos_full((di, wi), c.row(yu)) >= t {
                    pairs += 1;
                }
            }
        }
    }
    (ncand, survivors, pairs)
}
718
719/// Naive `O(n²)` oracle: score every pair with [`cos_full`], keep `cos ≥ t`. The correctness
720/// reference [`cosine_join`] is validated against.
721#[must_use]
722pub fn cosine_join_bruteforce(c: &Corpus, t: f64) -> Vec<(usize, usize, f64)> {
723    let mut out: Vec<(usize, usize, f64)> = Vec::new();
724    for i in 0..c.n {
725        for j in 0..i {
726            let s = cos_full(c.row(i), c.row(j));
727            if s >= t {
728                out.push((j, i, s));
729            }
730        }
731    }
732    out
733}
734
#[cfg(test)]
mod tests {
    use super::{cosine_join, cosine_join_bruteforce, Corpus};

    /// Deterministic xorshift64 generator, so the fuzz corpora are reproducible without an RNG crate.
    fn xorshift(seed: u64) -> impl FnMut() -> u64 {
        let mut state = seed;
        move || {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            state
        }
    }

    /// `n` random sparse rows. Each row has `rng() % nnz_bound` entries; each entry's dim is drawn
    /// from `0..dim_space` (repeats allowed) and its weight is an integer in `1..=10`.
    fn random_rows(
        rng: &mut impl FnMut() -> u64,
        n: usize,
        dim_space: u64,
        nnz_bound: u64,
    ) -> Vec<Vec<(u32, f64)>> {
        let mut rows = Vec::with_capacity(n);
        for _ in 0..n {
            let nnz = rng() % nnz_bound;
            let mut row = Vec::new();
            for _ in 0..nnz {
                // Dim is drawn before weight: the same RNG call order as the tuple it replaces.
                let dim = (rng() % dim_space) as u32;
                let weight = (rng() % 10 + 1) as f64;
                row.push((dim, weight));
            }
            rows.push(row);
        }
        rows
    }

    /// Canonical form for comparison. Pairs are sorted by `(lo, hi)` and the similarity is kept as
    /// raw bits, so equality is bit-exact rather than approximate.
    fn sort_pairs(pairs: Vec<(usize, usize, f64)>) -> Vec<(usize, usize, u64)> {
        let mut keyed: Vec<(usize, usize, u64)> =
            pairs.into_iter().map(|(lo, hi, sim)| (lo, hi, sim.to_bits())).collect();
        keyed.sort_by_key(|&(lo, hi, _)| (lo, hi));
        keyed
    }

    #[test]
    fn indexed_join_matches_bruteforce() {
        let mut rng = xorshift(0x9e37_79b9_7f4a_7c15);
        for _ in 0..400 {
            let n = (rng() % 40 + 2) as usize;
            let dim_space = rng() % 15 + 1;
            let rows = random_rows(&mut rng, n, dim_space, 8);
            let corpus = Corpus::from_rows(&rows);
            for &t in &[0.1_f64, 0.25, 0.5, 0.75, 0.9, 1.0] {
                let got = sort_pairs(cosine_join(&corpus, t));
                let want = sort_pairs(cosine_join_bruteforce(&corpus, t));
                assert_eq!(got, want, "n={n} t={t}");
            }
        }
    }

    /// The CPU+GPU hybrid [`super::cosine_join_gpu`] must return bit-identical results to the
    /// pure-CPU [`cosine_join`] on fuzzed corpora. The GPU is only a conservative filter; every
    /// emitted score is the exact CPU `f64` value. Skips when no Metal device is present.
    #[cfg(all(feature = "gpu", target_os = "macos"))]
    #[test]
    fn gpu_hybrid_matches_cpu() {
        use super::CosineJoiner;
        use crate::Concurrency;
        let mut rng = xorshift(0x1357_9bdf_0246_8ace);
        for _ in 0..60 {
            let n = (rng() % 60 + 4) as usize;
            let dim_space = rng() % 20 + 2;
            let rows = random_rows(&mut rng, n, dim_space, 10);
            // One reusable handle per corpus. `join` is called repeatedly across thresholds, which
            // also exercises that the handle reuses its GPU resources (no per-call library compile).
            let joiner = CosineJoiner::new(Corpus::from_rows(&rows));
            if !joiner.has_gpu() {
                eprintln!("no Metal device — skipping gpu_hybrid_matches_cpu");
                return;
            }
            for &t in &[0.1_f64, 0.3, 0.5, 0.7, 0.9, 1.0] {
                let want = sort_pairs(cosine_join(joiner.corpus(), t));
                // Exact GPU+CPU hybrid is byte-identical to the plain join; `Cpu` mode too.
                assert_eq!(
                    sort_pairs(joiner.join(t, Concurrency::GpuPlusCpu)),
                    want,
                    "GpuPlusCpu n={n} t={t}"
                );
                assert_eq!(sort_pairs(joiner.join(t, Concurrency::Cpu)), want, "Cpu n={n} t={t}");
            }
        }
    }
}