difflib_fast/simjoin.rs
1//! `simjoin` — exact all-pairs **weighted-cosine similarity join**: given a corpus of sparse
2//! non-negative vectors, find every pair `(i, j)` whose cosine similarity is `≥ t`.
3//!
4//! This is the `AllPairs` / `L2AP` family (Bayardo et al. WWW'07; Anastasiu & Karypis ICDE'14): an
5//! inverted index with **prefix filtering** so that the vast majority of vector pairs are never
6//! compared. It is the principled, exact replacement for shingle-candidate + verify-all
7//! near-duplicate detection (e.g. Type-3 code clones = functions × IDF-weighted lines).
8//!
9//! ## Correctness gate
10//! [`cosine_join`] (the indexed algorithm) must return the **exact same** pair set, with
11//! bit-identical similarities, as [`cosine_join_bruteforce`] (the naive `O(n²)` oracle). Both score
12//! a pair with the same [`cos_full`] sorted-merge dot, so the values match to the bit and a pair is
13//! never dropped or gained at the threshold. This equality is asserted on fuzzed corpora — the same
14//! "two implementations, one answer" discipline the RO path uses.
15//!
16//! ## Method (this reference)
17//! Vectors are L2-normalised (so `cos = dot`) and their dimensions relabelled to a global rank by
18//! **increasing max weight** (common low-weight dims first, rare high-weight dims last — so only the
19//! rare tail is indexed, keeping postings short). For a probe vector we look up candidates through
20//! the index, then verify with the full dot. When *indexing* a vector we skip the leading prefix
21//! whose max possible contribution to any dot stays `< t`:
//! `Σ_{k<b} w_k · maxw[dim_k] < t` ⇒ a pair matching only in that prefix can't reach `t`
//! (weights ≥ 0), so every pair that does reach `t` shares at least one indexed dim. The skipped
//! prefix holds the common dims (huge postings); only the rarer tail is indexed — that is the whole
//! speed-up. On top of this exact base, each accumulated candidate is checked against an L2AP-style
//! Cauchy–Schwarz (L2-norm) upper bound and scored with the full dot only if it can still reach `t`;
//! the bound is a filter only, so the emitted pairs and scores are unchanged.
27
28// Index/rank ids are bounded by the corpus dimension count (≤ u32 by construction); ranks are
29// assigned densely from a sorted dim list. The `as` casts below are intentional and in-range.
30// Dense numerical code: `i`/`d`/`w`/`y`/`a`/`s`/`t` mirror the cosine/prefix-bound formulas and read
31// clearer than verbose names here — allow the single-char-names pedantic lint module-wide.
32#![allow(
33 clippy::cast_possible_truncation,
34 clippy::cast_precision_loss,
35 clippy::cast_sign_loss,
36 clippy::many_single_char_names
37)]
38
39use std::cmp::Ordering;
40use std::collections::{HashMap, HashSet};
41
42use rayon::prelude::*;
43
44use crate::Concurrency;
45
/// A corpus of L2-normalised sparse non-negative vectors in CSR form, with dimensions relabelled to
/// a global rank by **increasing** max weight (common low-weight dims get the low ranks, rare
/// high-weight dims the high ranks). Built once; joinable at any threshold.
pub struct Corpus {
    /// Number of vectors (rows).
    n: usize,
    /// Number of distinct dimensions present in at least one non-empty row (`== maxw.len()`).
    ndims: usize,
    /// `n + 1` row offsets into `dims`/`wts`.
    indptr: Vec<usize>,
    /// Relabelled (rank) dimension ids, ascending within each row.
    dims: Vec<u32>,
    /// L2-normalised weights, aligned with `dims`.
    wts: Vec<f64>,
    /// Per relabelled dim: the max weight it takes across the corpus (the prefix-filter bound).
    maxw: Vec<f64>,
}
60
61impl Corpus {
62 /// Number of vectors.
63 #[must_use]
64 pub fn len(&self) -> usize {
65 self.n
66 }
67
68 /// True if the corpus has no vectors.
69 #[must_use]
70 pub fn is_empty(&self) -> bool {
71 self.n == 0
72 }
73
74 /// `(dims, weights)` of vector `i` — dims ascending by global rank, weights L2-normalised.
75 #[must_use]
76 fn row(&self, i: usize) -> (&[u32], &[f64]) {
77 let (s, e) = (self.indptr[i], self.indptr[i + 1]);
78 (&self.dims[s..e], &self.wts[s..e])
79 }
80
81 /// CSR view for GPU offload: `(indptr, dims, wts_f32)`, with `indptr` cast to `u32` and the
82 /// L2-normalised weights cast to `f32` (Apple GPUs have no `f64`). For the
83 /// [`crate::simjoin_gpu`] throughput experiment only — the f32 cast means GPU dots are *not*
84 /// bit-identical to the CPU `f64` path.
85 #[cfg(all(target_os = "macos", feature = "gpu"))]
86 #[must_use]
87 pub fn csr_f32(&self) -> (Vec<u32>, Vec<u32>, Vec<f32>) {
88 let indptr = self.indptr.iter().map(|&x| x as u32).collect();
89 let wts = self.wts.iter().map(|&w| w as f32).collect();
90 (indptr, self.dims.clone(), wts)
91 }
92
93 /// Build a corpus from **token documents** — each document a list of string tokens — as TF-IDF
94 /// sparse vectors: dim = a distinct token, weight = `(token count in doc) × ln(n / df_token)`
95 /// (`df` = number of documents containing the token). This is the principled input for a Type-3
96 /// code-clone join (documents = functions, tokens = canonicalised lines) and the shape most
97 /// callers actually have. A token appearing in every document gets `idf = 0` (contributes
98 /// nothing), as expected.
99 #[must_use]
100 pub fn from_token_docs<S: AsRef<str>>(docs: &[Vec<S>]) -> Corpus {
101 let n = docs.len();
102 let mut dim: HashMap<&str, u32> = HashMap::new();
103 let mut df: Vec<u32> = Vec::new();
104 // Assign a dim id to each distinct token and count document frequency (once per doc/token).
105 let mut doc_ids: Vec<Vec<u32>> = Vec::with_capacity(n);
106 for doc in docs {
107 let mut ids = Vec::with_capacity(doc.len());
108 let mut seen: HashSet<u32> = HashSet::new();
109 for tok in doc {
110 let id = *dim.entry(tok.as_ref()).or_insert_with(|| {
111 let i = df.len() as u32;
112 df.push(0);
113 i
114 });
115 ids.push(id);
116 if seen.insert(id) {
117 df[id as usize] += 1;
118 }
119 }
120 doc_ids.push(ids);
121 }
122 let idf: Vec<f64> = df.iter().map(|&d| (n as f64 / f64::from(d.max(1))).ln()).collect();
123 // Emit (dim, idf) once per token occurrence; `from_rows` sums duplicates → tf·idf per dim.
124 let rows: Vec<Vec<(u32, f64)>> = doc_ids
125 .iter()
126 .map(|ids| ids.iter().map(|&id| (id, idf[id as usize])).collect())
127 .collect();
128 Corpus::from_rows(&rows)
129 }
130
131 /// Build a corpus from raw `(dim, weight)` rows. Duplicate dims within a row are summed; each row
132 /// is L2-normalised; dims are relabelled to a global rank by decreasing max weight. Weights are
133 /// expected non-negative (the prefix-filter bound requires it — IDF weights satisfy this).
134 #[must_use]
135 pub fn from_rows(rows: &[Vec<(u32, f64)>]) -> Corpus {
136 let n = rows.len();
137 // 1. Merge duplicate dims + L2-normalise each row (kept as (orig_dim, weight)).
138 let normed: Vec<Vec<(u32, f64)>> = rows
139 .iter()
140 .map(|r| {
141 let mut m: HashMap<u32, f64> = HashMap::new();
142 for &(d, w) in r {
143 *m.entry(d).or_insert(0.0) += w;
144 }
145 let norm = m.values().map(|w| w * w).sum::<f64>().sqrt();
146 if norm > 0.0 {
147 m.into_iter().map(|(d, w)| (d, w / norm)).collect()
148 } else {
149 Vec::new()
150 }
151 })
152 .collect();
153 // 2. Max normalised weight per original dim.
154 let mut maxw_orig: HashMap<u32, f64> = HashMap::new();
155 for v in &normed {
156 for &(d, w) in v {
157 let e = maxw_orig.entry(d).or_insert(0.0);
158 if w > *e {
159 *e = w;
160 }
161 }
162 }
163 // 3. Rank dims by (max weight ASC, dim asc) → dense global order. Ascending so the common,
164 // low-weight dims land at the FRONT: they fill the un-indexed prefix (their tiny
165 // `w·maxw` keeps the cumulative bound under `t` for many of them), and only the rare,
166 // high-weight tail is indexed — short postings. (Reversing this indexes the common dims
167 // and their huge postings, which is correct but orders of magnitude slower.)
168 let mut dims_sorted: Vec<(u32, f64)> = maxw_orig.into_iter().collect();
169 dims_sorted.sort_by(|a, b| {
170 a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal).then(a.0.cmp(&b.0))
171 });
172 let rank: HashMap<u32, u32> =
173 dims_sorted.iter().enumerate().map(|(i, &(d, _))| (d, i as u32)).collect();
174 let ndims = dims_sorted.len();
175 let maxw: Vec<f64> = dims_sorted.iter().map(|&(_, w)| w).collect();
176 // 4. CSR with relabelled dims, ascending within each row.
177 let mut indptr = Vec::with_capacity(n + 1);
178 indptr.push(0);
179 let mut dims = Vec::new();
180 let mut wts = Vec::new();
181 for v in &normed {
182 let mut rv: Vec<(u32, f64)> = v.iter().map(|&(d, w)| (rank[&d], w)).collect();
183 rv.sort_unstable_by_key(|&(d, _)| d);
184 for (d, w) in rv {
185 dims.push(d);
186 wts.push(w);
187 }
188 indptr.push(dims.len());
189 }
190 Corpus { n, ndims, indptr, dims, wts, maxw }
191 }
192}
193
/// Sorted-merge dot product of two rows (dims ascending, one weight per dim). For L2-normalised
/// non-negative vectors this is exactly their cosine similarity. The single scoring routine shared
/// by the indexed join and the brute-force oracle, so both agree to the bit.
///
/// # Panics
/// If a weight slice is shorter than its dim slice. Rows built by `Corpus::from_rows` never are.
#[must_use]
#[cfg_attr(feature = "profiling", inline(never))]
fn cos_full((da, wa): (&[u32], &[f64]), (db, wb): (&[u32], &[f64])) -> f64 {
    // Rows are equal-length dim/weight slices (built by `Corpus::from_rows`).
    debug_assert_eq!(da.len(), wa.len());
    debug_assert_eq!(db.len(), wb.len());
    let (la, lb) = (da.len(), db.len());
    // Re-slice the weights to the dim lengths: one up-front check per call (a panic on a short
    // weight slice) that keeps the unchecked loads below in-bounds in release builds too, where
    // the `debug_assert`s above compile away. Negligible next to the merge loop.
    let (wa, wb) = (&wa[..la], &wb[..lb]);
    let (mut i, mut j) = (0usize, 0usize);
    let mut s = 0.0f64;
    // Branchless sorted-merge: the 3-way `cmp` branches mispredict on random dim order, and the
    // weight loads carry bounds checks the optimiser can't elide. Here we always load both weights
    // and mask the product by dim-equality — `s += 0.0` for unequal dims adds the *same* terms in
    // the *same* increasing-dim order, so the result is bit-identical to the branchy merge while
    // shedding every data-dependent branch.
    while i < la && j < lb {
        // SAFETY: `i < la == da.len() == wa.len()` and `j < lb == db.len() == wb.len()` (the
        // weight slices were re-sliced to exactly the dim lengths above).
        let (ai, bj) = unsafe { (*da.get_unchecked(i), *db.get_unchecked(j)) };
        let (wai, wbj) = unsafe { (*wa.get_unchecked(i), *wb.get_unchecked(j)) };
        let eq = f64::from(u32::from(ai == bj));
        s += eq * wai * wbj;
        i += usize::from(ai <= bj);
        j += usize::from(ai >= bj);
    }
    s
}
222
/// Prune data for one indexed vector `y`, consulted by the hot verify loop whenever `y` turns up as
/// a candidate. Both fields sit in one array element (rather than two parallel `Vec`s) so a single
/// scattered load brings both into cache; verify is dominated by exactly these random accesses.
#[derive(Copy, Clone)]
struct Bound {
    /// `‖prefix(y)‖₂` — L2 norm of `y`'s un-indexed prefix. Those dims were never posted, so the
    /// accumulator never sees them; by Cauchy–Schwarz this caps the dot mass it missed.
    pnorm: f64,
    /// Global rank of `y`'s first *indexed* dim, or `u32::MAX` when `y` indexed nothing. Every
    /// prefix dim of `y` ranks strictly below it, so a probe can restrict its norm to that range.
    split: u32,
}
235
236/// Per-vector data cached as each vector is indexed; read by the prune bound when the vector later
237/// appears as a candidate. (One `Cached` for the whole join.)
238struct Cached {
239 bound: Vec<Bound>,
240}
241
/// Scratch buffers reused from probe to probe (each parallel worker allocates one set up front).
struct Scratch {
    /// Partial dot of the current probe with each candidate `y` over their shared *indexed* dims.
    /// `-1.0` marks an untouched slot: a genuine partial dot of non-negative weights is never `< 0`.
    acc: Vec<f64>,
    /// Ids of the candidates the current probe touched — exactly the `acc` slots to reset later.
    touched: Vec<u32>,
    /// Prefix L2 norms of the current probe: `xpn[k] = ‖wi[..k]‖₂`, holding `nnz + 1` entries.
    xpn: Vec<f64>,
}
252
253/// Exact all-pairs cosine join via inverted index + **L2AP** prefix filtering and accumulation-time
254/// pruning. Returns `(j, i, cos)` with `j < i` for every pair with `cos ≥ t`. Bit-identical pair set
255/// to [`cosine_join_bruteforce`].
256///
257/// For each probe we accumulate a partial dot over shared *indexed* dims ([`accumulate`]), then for
258/// each touched candidate compute a Cauchy–Schwarz upper bound on the true cosine and skip the exact
259/// [`cos_full`] when it cannot reach `t` ([`verify_pruned`]). The bound is a filter only — survivors
260/// are scored exactly, so the output is byte-for-byte the brute-force result. On skewed data the
261/// bound prunes the ~99.9 % of candidates that collide on a single rare dim, so `cos_full` (the
262/// former 90 % hotspot) runs only on genuine near-matches.
263///
264/// The full inverted index is built once (postings ascending by id), then every vector is probed in
265/// **parallel**: probe `i` walks each posting only while `y < i` (postings are id-sorted), so it sees
266/// exactly the earlier vectors — each pair `(j, i)` with `j < i` is found once, from the larger id.
267/// This is the same candidate set the sequential index-as-you-go build produces, so the result is
268/// unchanged; the returned `Vec` is in arbitrary order (sort if a canonical order is needed).
269#[must_use]
270pub fn cosine_join(c: &Corpus, t: f64) -> Vec<(usize, usize, f64)> {
271 let n = c.n;
272 // Postings carry the indexed weight `(y, w_y[d])` so the scan can accumulate a partial dot.
273 let mut index: Vec<Vec<(u32, f64)>> = vec![Vec::new(); c.ndims];
274 let mut cached = Cached { bound: vec![Bound { pnorm: 0.0, split: u32::MAX }; n] };
275 for i in 0..n {
276 let (di, wi) = c.row(i);
277 index_suffix(c, i, (di, wi), t, &mut index, &mut cached);
278 }
279 // Probe vectors in parallel; each worker keeps one reusable `Scratch` (an `n`-wide accumulator).
280 // `with_min_len` batches many probes per rayon task so the per-probe work (tiny) isn't dwarfed by
281 // task-splitting / scheduling overhead (`swtch_pri` in the profile).
282 (0..n)
283 .into_par_iter()
284 .with_min_len(256)
285 .map_init(
286 || Scratch { acc: vec![-1.0; n], touched: Vec::new(), xpn: Vec::new() },
287 |scratch, i| {
288 let (di, wi) = c.row(i);
289 accumulate(&index, (di, wi), i as u32, scratch);
290 let mut out = Vec::new();
291 verify_pruned(c, i, t, scratch, &cached, &mut out);
292 out
293 },
294 )
295 .flatten()
296 .collect()
297}
298
299/// Run the cosine join under a chosen [`Concurrency`] backend. Returns `(j, i, cos)` pairs with
300/// `j < i` and `cos ≥ t`, scores as `f64` (the `Gpu` mode's f32 cosines are widened losslessly).
301///
302/// - [`Concurrency::Cpu`] — [`cosine_join`]: exact `f64`, all-CPU, every platform.
303/// - [`Concurrency::GpuPlusCpu`] — exact `f64` hybrid: CPU generates survivor pairs, the GPU f32
304/// cosine *filters* the clear rejects, the CPU recomputes the exact `f64` score on what passes.
305/// **Byte-identical to `Cpu`**; both engines fully used. ~1.7–2× on bandwidth-bound real data.
306/// - [`Concurrency::Gpu`] — GPU-dominant `f32`: CPU generates survivor pairs, the GPU scores them and
307/// the result is emitted directly (no f64 re-verify). Fastest (~2×); differs from the exact answer
308/// only on pairs whose true cosine is within ~`1e-6` of `t` (measured: ≤1 pair in millions).
309///
310/// When the `gpu` feature is off, the target isn't macOS, or no Metal device can be acquired, the GPU
311/// modes transparently fall back to [`cosine_join`] (same as `Rationer`). This convenience entry
312/// **compiles + uploads the GPU corpus on every call** — fine for a one-shot join, but for repeated
313/// joins on one corpus build a [`CosineJoiner`] once and call [`CosineJoiner::join`], which holds the
314/// device + kernel + uploaded CSR across calls (and avoids the driver instability of compiling a
315/// Metal library hundreds of times in a tight loop).
316#[must_use]
317pub fn cosine_join_with(c: &Corpus, t: f64, mode: Concurrency) -> Vec<(usize, usize, f64)> {
318 #[cfg(all(feature = "gpu", target_os = "macos"))]
319 {
320 if matches!(mode, Concurrency::Cpu) {
321 return cosine_join(c, t);
322 }
323 let (indptr, dims, wts) = c.csr_f32();
324 let Some(gpu) = crate::simjoin_gpu::BatchCosineGpu::new(&indptr, &dims, &wts) else {
325 return cosine_join(c, t); // no Metal device → CPU fallback
326 };
327 match mode {
328 Concurrency::GpuPlusCpu => cosine_join_gpu(c, t, &gpu),
329 // `Gpu`: emit the GPU f32 cosines directly (widened to f64), no exact re-verify.
330 Concurrency::Gpu => cosine_join_gpu_f32(c, t, &gpu)
331 .into_iter()
332 .map(|(a, b, s)| (a, b, f64::from(s)))
333 .collect(),
334 Concurrency::Cpu => unreachable!("handled above"),
335 }
336 }
337 #[cfg(not(all(feature = "gpu", target_os = "macos")))]
338 {
339 let _ = mode; // GPU modes degrade to the CPU join when the feature is off / not macOS.
340 cosine_join(c, t)
341 }
342}
343
344/// A reusable cosine-join handle that owns the corpus and — under `feature = "gpu"` on macOS — the
345/// Metal device, compiled `batch_cosine` kernel, and the corpus CSR uploaded to unified memory, all
346/// acquired **once** at construction. Repeated [`join`](CosineJoiner::join)s at different thresholds
347/// then skip the per-call kernel compile + CSR upload that [`cosine_join_with`] pays (only the
348/// `t`-specific inverted index is rebuilt each call, on the CPU). Always constructible; degrades to
349/// the pure-CPU join when the `gpu` feature is off or no Metal device is available — mirroring
350/// `Rationer`. This is the right entry point for sweeping thresholds or joining repeatedly.
351pub struct CosineJoiner {
352 corpus: Corpus,
353 /// Owned GPU resources (device + kernel + CSR in UMA); `None` when no Metal device was acquired.
354 #[cfg(all(feature = "gpu", target_os = "macos"))]
355 gpu: Option<crate::simjoin_gpu::BatchCosineGpu>,
356}
357
358impl CosineJoiner {
359 /// Build a joiner over `corpus`, acquiring the GPU device + uploading the corpus CSR once if the
360 /// `gpu` feature is on and a Metal device is present.
361 #[must_use]
362 pub fn new(corpus: Corpus) -> Self {
363 #[cfg(all(feature = "gpu", target_os = "macos"))]
364 {
365 let (indptr, dims, wts) = corpus.csr_f32();
366 let gpu = crate::simjoin_gpu::BatchCosineGpu::new(&indptr, &dims, &wts);
367 Self { corpus, gpu }
368 }
369 #[cfg(not(all(feature = "gpu", target_os = "macos")))]
370 {
371 Self { corpus }
372 }
373 }
374
375 /// The owned corpus (e.g. for `len()` or to run other queries against it).
376 #[must_use]
377 pub fn corpus(&self) -> &Corpus {
378 &self.corpus
379 }
380
381 /// Whether a Metal GPU backend was acquired. Always `false` without `feature = "gpu"` on macOS;
382 /// when `false`, every [`join`](CosineJoiner::join) runs on the CPU regardless of `mode`.
383 #[must_use]
384 pub fn has_gpu(&self) -> bool {
385 #[cfg(all(feature = "gpu", target_os = "macos"))]
386 {
387 self.gpu.is_some()
388 }
389 #[cfg(not(all(feature = "gpu", target_os = "macos")))]
390 {
391 false
392 }
393 }
394
395 /// Run the join at threshold `t` under `mode`, reusing the handle's GPU resources. Returns the
396 /// same results as [`cosine_join_with`] (Cpu/GpuPlusCpu exact, Gpu f32→f64); falls back to the
397 /// CPU join when the GPU is unavailable.
398 #[must_use]
399 pub fn join(&self, t: f64, mode: Concurrency) -> Vec<(usize, usize, f64)> {
400 #[cfg(all(feature = "gpu", target_os = "macos"))]
401 {
402 match (mode, self.gpu.as_ref()) {
403 (Concurrency::GpuPlusCpu, Some(g)) => cosine_join_gpu(&self.corpus, t, g),
404 (Concurrency::Gpu, Some(g)) => cosine_join_gpu_f32(&self.corpus, t, g)
405 .into_iter()
406 .map(|(a, b, s)| (a, b, f64::from(s)))
407 .collect(),
408 _ => cosine_join(&self.corpus, t), // Cpu mode, or no Metal device
409 }
410 }
411 #[cfg(not(all(feature = "gpu", target_os = "macos")))]
412 {
413 let _ = mode;
414 cosine_join(&self.corpus, t)
415 }
416 }
417}
418
/// Floating-point slack for the prune bound. Cauchy–Schwarz bounds the true cosine exactly only in
/// exact arithmetic; the accumulated dot and the `sqrt` norms both carry rounding error. `cos_full`
/// is therefore skipped only when the bound falls short of `t` by more than this slack, so a pair
/// whose exact cosine is `≥ t` is never pruned. Declining to skip is always safe (merely one wasted
/// `cos_full`), so the slack costs a negligible number of extra verifies and never alters the
/// emitted pair set. `1e-9` dwarfs the `~1e-15` error accumulated over ~15 terms.
const PRUNE_SLACK: f64 = 1.0e-9;

/// Minimum probe row length (nnz) for which [`verify_pruned`] applies its cheap monotone pre-bound.
/// The pre-bound avoids the per-candidate `partition_point` for candidates that cannot reach `t`,
/// but over a *short* probe row that binary search is already cheap: on sparse corpora (mean
/// nnz ≈ 11, e.g. `PyPI` type3) the extra `fmul`/`fadd`/compare rarely prunes (match-dense ⇒ high
/// survivor rate) and is pure overhead. It pays off only on **dense** rows (e.g. find-dup-defs
/// patternology, mean nnz ≈ 61), where `partition_point` is costly and few candidates survive.
/// The test is on `di.len()`, per probe and loop-invariant, so sparse workloads don't regress.
const PREBOUND_MIN_DIMS: usize = 24;
435
/// Phase 1 — accumulate: for each dim `d` of the probe, add `w_probe[d]·w_y[d]` into `acc[y]` for
/// every **earlier** vector `y` (`y < cutoff`, the probe's own id) that *indexed* `d`. Postings are
/// id-sorted, so we `break` at the first `y ≥ cutoff`. Leaves `acc` `-1.0` everywhere except the
/// touched ids (listed in `touched`, in first-touch order, and reset in [`verify_pruned`]). One
/// scattered `acc[]` multiply-add per posting entry.
///
/// Preconditions (relied on by the `unsafe` below): on entry every `acc` slot holds the `-1.0`
/// sentinel, and `cutoff < acc.len()` — callers pass the probe's own id `i < n`.
#[cfg_attr(feature = "profiling", inline(never))]
fn accumulate(index: &[Vec<(u32, f64)>], (di, wi): (&[u32], &[f64]), cutoff: u32, s: &mut Scratch) {
    // Split the borrows so `acc` and `touched` are independent: otherwise the optimiser, seeing
    // `s.touched.push(...)` go through the same `&mut Scratch`, can't prove the push leaves
    // `s.acc`'s base pointer untouched and reloads `acc.ptr` from the struct on EVERY posting entry
    // (one extra load across the ~5M-entry accumulation on dense corpora). With `acc` a separate
    // `&mut [f64]` local the base stays in a register; a `touched` realloc can't alias it.
    let Scratch { acc, touched, .. } = s;
    let acc = acc.as_mut_slice();
    touched.clear();
    // At most one distinct candidate per acc slot, so `n == acc.len()` slots is enough headroom to
    // append every first-touch without bounds checks or a realloc inside the loop.
    touched.reserve(acc.len());
    let tptr = touched.as_mut_ptr();
    let mut tlen = 0usize;
    for (&d, &w) in di.iter().zip(wi) {
        for &(y, wy) in &index[d as usize] {
            if y >= cutoff {
                break;
            }
            let yu = y as usize;
            // SAFETY: `y` is a vector id pushed by `index_suffix`, so `yu < n == acc.len()`.
            let a = unsafe { acc.get_unchecked_mut(yu) };
            // Branchless first-touch: the `*a < 0.0` test fed a data-dependent branch that mispredicts
            // (first-touch vs repeat interleave unpredictably across the probe's dims). Instead select
            // the base (`0.0` on first touch, current partial dot otherwise) with a conditional move,
            // and append `y` to `touched` by an UNCONDITIONAL store that's only *committed* when the
            // length is bumped — `tlen += first`. Same accumulator values and same touched order as
            // the branchy form, so the result is bit-identical; just no mispredicting branch.
            let first = *a < 0.0;
            let base = if first { 0.0 } else { *a };
            *a = base + w * wy;
            // SAFETY: touched ids are distinct values in `0..cutoff`, so `tlen ≤ cutoff`, and the
            // caller guarantees `cutoff < n == acc.len() ≤ touched.capacity()`. An uncommitted
            // store only ever lands at index `tlen`, so committed slots `0..tlen` are never clobbered.
            unsafe { *tptr.add(tlen) = y };
            tlen += usize::from(first);
        }
    }
    // SAFETY: `tlen` first-touch stores were written into the reserved region, in order.
    unsafe { touched.set_len(tlen) };
}
480
/// Phase 2 — prune + verify. For each touched candidate `y`, reset its accumulator and test the
/// **L2AP `l2` bound**: the dot mass missing from `acc[y]` (the dims in `prefix(y)`) is at most
/// `‖x_{rank<split[y]}‖ · ‖prefix(y)‖` by Cauchy–Schwarz, where `x_{rank<split[y]}` is the probe
/// restricted to the rank range `prefix(y)` lives in. Since the probe's mass sits in its rare
/// (high-rank) dims, that restricted norm is tiny — a far tighter cap than the whole-probe `‖x‖=1`.
/// Only if `acc[y] + that bound ≥ t` (minus FP slack) do we score exactly with [`cos_full`]. Filter
/// only — the emitted value is the exact dot, so the pair set is bit-identical to brute force.
///
/// Appends `(y, i, cos)` to `out` for each passing candidate. On return every touched `acc` slot is
/// back at the `-1.0` sentinel (ready for the next probe); `touched` itself is left as-is and is
/// cleared by the next [`accumulate`]. Must be kept in lockstep with `collect_survivors`' bound.
#[cfg_attr(feature = "profiling", inline(never))]
fn verify_pruned(
    c: &Corpus,
    i: usize,
    t: f64,
    s: &mut Scratch,
    cached: &Cached,
    out: &mut Vec<(usize, usize, f64)>,
) {
    let (di, wi) = c.row(i);
    // Probe prefix L2 norms: xpn[k] = ‖wi[..k]‖₂ (di ascending by rank, so xpn[k] = norm over the
    // probe's k lowest-rank dims). One sqrt per probe dim, reused across all its candidates.
    s.xpn.clear();
    s.xpn.push(0.0);
    let mut sq = 0.0f64;
    for &w in wi {
        sq += w * w;
        s.xpn.push(sq.sqrt());
    }
    // `‖probe‖` = the full prefix norm. Since `xpn` is monotonic, `xpn[kstar] ≤ xnorm` for every
    // candidate, so `a + xnorm·pnorm ≥ a + xpn[kstar]·pnorm` (the exact bound). A candidate failing
    // the cheap `a + xnorm·pnorm < need` test therefore also fails the exact bound — prune it without
    // the per-candidate `partition_point`. The survivor set is bit-identical; only the binary search
    // is skipped for the ~99% of touched candidates that can't reach `t` on dense corpora.
    let xnorm = sq.sqrt();
    let prebound = di.len() >= PREBOUND_MIN_DIMS;
    let need = t - PRUNE_SLACK;
    let Scratch { acc, touched, xpn } = s;
    // (Software-prefetching the candidate row a few ahead was tried + reverted: no measurable change
    // — with all cores gathering at once the join is memory-*bandwidth*-bound, not per-access
    // latency-bound, so prefetch can't add throughput.)
    for &y in touched.iter() {
        let yu = y as usize;
        // Read the partial dot AND restore the sentinel in one go — this happens before any
        // `continue`, so pruned candidates are reset too.
        // SAFETY: `yu < n` (same provenance as in `accumulate`).
        let a = unsafe { std::mem::replace(acc.get_unchecked_mut(yu), -1.0) };
        // SAFETY: `yu < n`. One scattered load fetches both prune fields.
        let bd = unsafe { *cached.bound.get_unchecked(yu) };
        // Cheap monotone pre-bound (dense rows only) — skip the binary search when even `xnorm`
        // can't clear `need`. `prebound` is loop-invariant, so sparse rows pay nothing.
        if prebound && a + xnorm * bd.pnorm < need {
            continue;
        }
        // Number of probe dims with rank < split[y] → index into xpn (di sorted ascending).
        let kstar = di.partition_point(|&d| d < bd.split);
        // SAFETY: kstar ≤ di.len() == wi.len() = xpn.len()-1.
        // (An added maxweight cap `min(…, Σ wx·maxw)` was tried + reverted: never tighter than the
        // L2 cap on either synthetic or real data — it sums over all probe-prefix dims, not just the
        // shared ones — so it pruned nothing and cost ~16% on the real PyPI corpus.)
        let bound = a + unsafe { xpn.get_unchecked(kstar) } * bd.pnorm;
        if bound >= need {
            let cos = cos_full((di, wi), c.row(yu));
            if cos >= t {
                out.push((yu, i, cos));
            }
        }
    }
}
545
546/// Phase 3 — index this vector's suffix: skip the leading prefix whose max possible contribution to
547/// any dot stays `< t` (`Σ w_k·maxw[dim_k] < t`), index only the rarer tail (short postings = the
548/// whole speed-up), and cache `pnorm[i] = ‖prefix‖₂` and `split[i]` = first indexed rank for the
549/// [`verify_pruned`] bound.
550#[cfg_attr(feature = "profiling", inline(never))]
551fn index_suffix(
552 c: &Corpus,
553 i: usize,
554 (di, wi): (&[u32], &[f64]),
555 t: f64,
556 index: &mut [Vec<(u32, f64)>],
557 cached: &mut Cached,
558) {
559 // Largest safe prefix under the maxweight bound: `Σ_{k<b} w_k·maxw[dim_k] < t` ⇒ the prefix
560 // can't contribute `t` to any dot (weights ≥ 0). (A norm-based extension `‖x_{<b}‖ < t` was
561 // tried and reverted: it gave zero candidate reduction in the realistic `t<1` regime — the
562 // maxweight bound already indexes less — and it is FP-fragile at `t=1` where `‖x‖` rounds below
563 // `1` and indexes nothing, dropping exact-duplicate pairs.)
564 let mut rs = 0.0f64;
565 let mut b = 0usize;
566 for k in 0..di.len() {
567 let bound = wi[k] * c.maxw[di[k] as usize];
568 if rs + bound >= t {
569 break;
570 }
571 rs += bound;
572 b = k + 1;
573 }
574 let mut p = 0.0f64;
575 for &w in &wi[..b] {
576 p += w * w;
577 }
578 cached.bound[i] = Bound {
579 pnorm: p.sqrt(),
580 split: if b < di.len() { di[b] } else { u32::MAX },
581 };
582 for k in b..di.len() {
583 index[di[k] as usize].push((i as u32, wi[k]));
584 }
585}
586
/// Absolute margin for the GPU f32 cosine *filter* in [`cosine_join_gpu`]: a survivor is discarded
/// only when its GPU f32 cosine falls below `t` by more than this. f32 dot error is `~1e-6`
/// relative, so `1e-4` never discards a true pair; the CPU then rescores every pass with the exact
/// `f64` routine, keeping the emitted pairs and scores bit-identical to [`cosine_join`].
#[cfg(all(target_os = "macos", feature = "gpu"))]
const GPU_FILTER_MARGIN: f64 = 1.0e-4;
593
/// Like the verify half of [`verify_pruned`], but instead of scoring, **collects** each surviving
/// `(candidate, probe)` pair (`candidate < probe`) for batch scoring elsewhere. The bound here MUST
/// stay identical to `verify_pruned`'s so the survivor set matches exactly.
///
/// Expects `s.acc`/`s.touched` as left by [`accumulate`] for probe `i`; every touched `acc` slot is
/// restored to the `-1.0` sentinel before returning. Survivors are appended to `out`.
#[cfg(all(target_os = "macos", feature = "gpu"))]
fn collect_survivors(c: &Corpus, i: usize, t: f64, s: &mut Scratch, cached: &Cached, out: &mut Vec<(u32, u32)>) {
    let (di, wi) = c.row(i);
    // Probe prefix L2 norms `xpn[k] = ‖wi[..k]‖₂`, computed exactly as in `verify_pruned`.
    s.xpn.clear();
    s.xpn.push(0.0);
    let mut sq = 0.0f64;
    for &w in wi {
        sq += w * w;
        s.xpn.push(sq.sqrt());
    }
    let xnorm = sq.sqrt();
    let prebound = di.len() >= PREBOUND_MIN_DIMS;
    let need = t - PRUNE_SLACK;
    let Scratch { acc, touched, xpn } = s;
    for &y in touched.iter() {
        let yu = y as usize;
        // SAFETY: `yu < n` (same provenance as in `accumulate`).
        let a = unsafe { std::mem::replace(acc.get_unchecked_mut(yu), -1.0) };
        // SAFETY: `yu < n == cached.bound.len()` (one `Bound` per vector).
        let bd = unsafe { *cached.bound.get_unchecked(yu) };
        // Cheap monotone pre-bound (see `verify_pruned`) — same survivor set, skips the binary search.
        if prebound && a + xnorm * bd.pnorm < need {
            continue;
        }
        let kstar = di.partition_point(|&d| d < bd.split);
        // SAFETY: `kstar ≤ di.len() == wi.len() == xpn.len() - 1`.
        let bound = a + unsafe { xpn.get_unchecked(kstar) } * bd.pnorm;
        if bound >= need {
            out.push((y, i as u32)); // (candidate, probe), candidate < probe
        }
    }
}
627
/// **CPU+GPU hybrid join** (feature `gpu`, macOS). Meant to return the **exact same** pairs and
/// scores as [`cosine_join`] — it is only faster, when verify is bandwidth-bound.
///
/// Pipeline: the CPU builds the index and, in parallel, accumulates + bounds every probe into a
/// list of surviving `(candidate, probe)` pairs (no `cos_full`). The GPU then computes one f32
/// cosine for the whole batch (its memory-level parallelism clears the random-gather dots ~3× faster
/// than the CPU), and the CPU recomputes the exact `f64` `cos_full` **only** for pairs whose GPU
/// score clears `t − margin` — typically a few percent of survivors. The GPU acts purely as a
/// *conservative filter* (margin ≫ f32 error, so no true pair is dropped) and every emitted score is
/// the exact CPU `f64` value.
#[cfg(all(target_os = "macos", feature = "gpu"))]
#[must_use]
pub fn cosine_join_gpu(
    c: &Corpus,
    t: f64,
    gpu: &crate::simjoin_gpu::BatchCosineGpu,
) -> Vec<(usize, usize, f64)> {
    let (cands, probes) = survivor_pairs(c, t);
    if cands.is_empty() {
        return Vec::new();
    }
    // GPU phase: approximate f32 cosine for every survivor — used only to reject, never emitted.
    let approx = gpu.cosine_batch(&cands, &probes);
    let floor = t - GPU_FILTER_MARGIN;
    // CPU phase: exact f64 rescoring for the pairs the GPU filter lets through.
    (0..cands.len())
        .into_par_iter()
        .with_min_len(1024)
        .filter_map(|k| {
            if f64::from(approx[k]) < floor {
                return None;
            }
            let cand = cands[k] as usize;
            let probe = probes[k] as usize;
            let cos = cos_full(c.row(cand), c.row(probe));
            if cos >= t {
                Some((cand, probe, cos))
            } else {
                None
            }
        })
        .collect()
}
666
/// **Pure-f32** CPU+GPU join (feature `gpu`, macOS). It generates survivors exactly like
/// [`cosine_join_gpu`], but emits the GPU's **f32** cosine directly with **no exact f64
/// re-verify** (`cos_full` never runs).
///
/// This trades byte-parity with [`cosine_join`] for speed. The pair set can differ only on pairs
/// whose true cosine lies within the GPU's f32 rounding error of `t`. Use this when an ε-exact
/// answer is acceptable; use [`cosine_join_gpu`] when bit-exactness is required.
///
/// Every emitted score satisfies `f64::from(score) >= t`. The threshold test is done in `f64`, so
/// `t` itself is never rounded. A `t as f32` rounds to nearest and can land *below* `t` (e.g.
/// `0.7` → `0.699999988`), which would admit scores strictly under the requested threshold.
#[cfg(all(target_os = "macos", feature = "gpu"))]
#[must_use]
pub fn cosine_join_gpu_f32(
    c: &Corpus,
    t: f64,
    gpu: &crate::simjoin_gpu::BatchCosineGpu,
) -> Vec<(usize, usize, f32)> {
    let (pa, pb) = survivor_pairs(c, t);
    if pa.is_empty() {
        return Vec::new();
    }
    let gcos = gpu.cosine_batch(&pa, &pb);
    (0..pa.len())
        .into_par_iter()
        .with_min_len(1024)
        .filter_map(|k| {
            let score = gcos[k];
            // Widen the score rather than narrowing `t`: the f32 → f64 conversion is exact, so
            // the comparison is against the true threshold. A NaN score still fails it.
            (f64::from(score) >= t).then_some((pa[k] as usize, pb[k] as usize, score))
        })
        .collect()
}
692
/// CPU half of the GPU joins.
///
/// Builds the inverted index serially, then, in parallel over probes, accumulates each probe's
/// candidates and keeps the `(candidate, probe)` pairs (candidate `<` probe) that pass the bounds.
/// The pairs come back as two parallel `u32` arrays: the first holds candidates, the second
/// probes. That is the layout [`crate::simjoin_gpu::BatchCosineGpu::cosine_batch`] takes.
#[cfg(all(target_os = "macos", feature = "gpu"))]
fn survivor_pairs(c: &Corpus, t: f64) -> (Vec<u32>, Vec<u32>) {
    let n = c.n;

    // Index build: each vector contributes only its indexed suffix, and its bound is recorded.
    let mut postings: Vec<Vec<(u32, f64)>> = vec![Vec::new(); c.ndims];
    let mut cached = Cached { bound: vec![Bound { pnorm: 0.0, split: u32::MAX }; n] };
    for row in 0..n {
        index_suffix(c, row, c.row(row), t, &mut postings, &mut cached);
    }

    // Probe phase: read-only over the finished index. `map_init` hands out one `Scratch` per
    // rayon job, so the O(n) accumulator is reused across the probes of that job.
    let survivors: Vec<(u32, u32)> = (0..n)
        .into_par_iter()
        .with_min_len(256)
        .map_init(
            || Scratch { acc: vec![-1.0; n], touched: Vec::new(), xpn: Vec::new() },
            |scratch, probe| {
                accumulate(&postings, c.row(probe), probe as u32, scratch);
                let mut found = Vec::new();
                collect_survivors(c, probe, t, scratch, &cached, &mut found);
                found
            },
        )
        .flatten()
        .collect();
    survivors.into_iter().unzip()
}
722
/// Diagnostic (feature `profiling`, never on the hot path) that quantifies how much the prune
/// saves. Returns `(candidates, survivors, pairs)`:
/// - `candidates`: vectors touched by the accumulator, summed over all probes;
/// - `survivors`: candidates that clear the Cauchy–Schwarz bound, i.e. the `cos_full` calls
///   actually paid for;
/// - `pairs`: survivors whose exact cosine is `≥ t`, i.e. real output pairs.
///
/// `survivors / candidates` is the prune pass-rate (lower is better); `survivors` is the verify
/// volume.
#[cfg(feature = "profiling")]
#[must_use]
pub fn cosine_join_counts(c: &Corpus, t: f64) -> (u64, u64, u64) {
    let n = c.n;

    let mut postings: Vec<Vec<(u32, f64)>> = vec![Vec::new(); c.ndims];
    let mut cached = Cached { bound: vec![Bound { pnorm: 0.0, split: u32::MAX }; n] };
    for row in 0..n {
        index_suffix(c, row, c.row(row), t, &mut postings, &mut cached);
    }

    let mut scratch = Scratch { acc: vec![-1.0; n], touched: Vec::new(), xpn: Vec::new() };
    let need = t - PRUNE_SLACK;
    let mut candidates = 0u64;
    let mut survivors = 0u64;
    let mut pairs = 0u64;

    for probe in 0..n {
        let (dims, weights) = c.row(probe);
        accumulate(&postings, (dims, weights), probe as u32, &mut scratch);
        candidates += scratch.touched.len() as u64;

        // Prefix L2 norms: xpn[k] = sqrt(Σ_{j<k} w_j²), so xpn[0] = 0 and the last entry is ‖probe‖.
        scratch.xpn.clear();
        scratch.xpn.push(0.0);
        let mut sum_sq = 0.0f64;
        scratch.xpn.extend(weights.iter().map(|&w| {
            sum_sq += w * w;
            sum_sq.sqrt()
        }));
        let xnorm = sum_sq.sqrt();
        let prebound = dims.len() >= PREBOUND_MIN_DIMS;

        for &y in &scratch.touched {
            let cand = y as usize;
            // Take the accumulated partial dot and restore the slot to its initial -1.0.
            let a = std::mem::replace(&mut scratch.acc[cand], -1.0);
            let bd = cached.bound[cand];
            // Cheap whole-vector bound first (only for long probes), then the split-aware one.
            if prebound && a + xnorm * bd.pnorm < need {
                continue;
            }
            let kstar = dims.partition_point(|&d| d < bd.split);
            let clears_bound = a + scratch.xpn[kstar] * bd.pnorm >= need;
            if clears_bound {
                survivors += 1;
                pairs += u64::from(cos_full((dims, weights), c.row(cand)) >= t);
            }
        }
    }

    (candidates, survivors, pairs)
}
771
772/// Naive `O(n²)` oracle: score every pair with [`cos_full`], keep `cos ≥ t`. The correctness
773/// reference [`cosine_join`] is validated against.
774#[must_use]
775pub fn cosine_join_bruteforce(c: &Corpus, t: f64) -> Vec<(usize, usize, f64)> {
776 let mut out: Vec<(usize, usize, f64)> = Vec::new();
777 for i in 0..c.n {
778 for j in 0..i {
779 let s = cos_full(c.row(i), c.row(j));
780 if s >= t {
781 out.push((j, i, s));
782 }
783 }
784 }
785 out
786}
787
#[cfg(test)]
mod tests {
    use super::{cosine_join, cosine_join_bruteforce, Corpus};

    /// Small deterministic xorshift64 generator, so every fuzzed corpus is reproducible.
    fn xorshift(seed: u64) -> impl FnMut() -> u64 {
        let mut state = seed;
        move || {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            state
        }
    }

    /// Draws `n` random sparse rows. Each row has `0..max_nnz` entries; each entry has a dim in
    /// `0..dim_space` and an integer weight in `1..=10`. Dims are drawn with replacement, so a row
    /// may repeat a dim. Per entry the dim is drawn before the weight, which keeps the RNG stream
    /// (and thus every corpus) stable.
    fn random_rows(
        next: &mut impl FnMut() -> u64,
        n: usize,
        dim_space: u64,
        max_nnz: u64,
    ) -> Vec<Vec<(u32, f64)>> {
        let mut rows = Vec::with_capacity(n);
        for _ in 0..n {
            let nnz = (next() % max_nnz) as usize;
            let mut row = Vec::with_capacity(nnz);
            for _ in 0..nnz {
                let dim = (next() % dim_space) as u32;
                let weight = (next() % 10 + 1) as f64;
                row.push((dim, weight));
            }
            rows.push(row);
        }
        rows
    }

    /// Orders pairs by `(first, second)` and swaps each score for its raw bit pattern, so the
    /// comparison is exact to the bit.
    fn sort_pairs(mut pairs: Vec<(usize, usize, f64)>) -> Vec<(usize, usize, u64)> {
        pairs.sort_by(|x, y| (x.0, x.1).cmp(&(y.0, y.1)));
        pairs.into_iter().map(|(i, j, score)| (i, j, score.to_bits())).collect()
    }

    #[test]
    fn indexed_join_matches_bruteforce() {
        let mut next = xorshift(0x9e37_79b9_7f4a_7c15);
        for _ in 0..400 {
            let n = (next() % 40 + 2) as usize;
            let dim_space = next() % 15 + 1;
            let c = Corpus::from_rows(&random_rows(&mut next, n, dim_space, 8));
            for t in [0.1_f64, 0.25, 0.5, 0.75, 0.9, 1.0] {
                let got = sort_pairs(cosine_join(&c, t));
                let want = sort_pairs(cosine_join_bruteforce(&c, t));
                assert_eq!(got, want, "n={n} t={t}");
            }
        }
    }

    /// The CPU+GPU hybrid [`super::cosine_join_gpu`] must return bit-identical results to the
    /// pure-CPU [`cosine_join`] on fuzzed corpora. The GPU is only a conservative filter, and every
    /// emitted score is the exact CPU `f64` value. The test skips when no Metal device is present.
    #[cfg(all(feature = "gpu", target_os = "macos"))]
    #[test]
    fn gpu_hybrid_matches_cpu() {
        use super::CosineJoiner;
        use crate::Concurrency;
        let mut next = xorshift(0x1357_9bdf_0246_8ace);
        for _ in 0..60 {
            let n = (next() % 60 + 4) as usize;
            let dim_space = next() % 20 + 2;
            let rows = random_rows(&mut next, n, dim_space, 10);
            // One reusable handle per corpus. Calling `join` repeatedly across thresholds also
            // exercises GPU-resource reuse (no per-call library compile).
            let joiner = CosineJoiner::new(Corpus::from_rows(&rows));
            if !joiner.has_gpu() {
                eprintln!("no Metal device — skipping gpu_hybrid_matches_cpu");
                return;
            }
            for t in [0.1_f64, 0.3, 0.5, 0.7, 0.9, 1.0] {
                let want = sort_pairs(cosine_join(joiner.corpus(), t));
                // Both the exact GPU+CPU hybrid and plain `Cpu` mode must be byte-identical.
                let hybrid = sort_pairs(joiner.join(t, Concurrency::GpuPlusCpu));
                assert_eq!(hybrid, want, "GpuPlusCpu n={n} t={t}");
                let cpu_only = sort_pairs(joiner.join(t, Concurrency::Cpu));
                assert_eq!(cpu_only, want, "Cpu n={n} t={t}");
            }
        }
    }
}