difflib_fast/simjoin.rs
1//! `simjoin` — exact all-pairs **weighted-cosine similarity join**: given a corpus of sparse
2//! non-negative vectors, find every pair `(i, j)` whose cosine similarity is `≥ t`.
3//!
4//! This is the `AllPairs` / `L2AP` family (Bayardo et al. WWW'07; Anastasiu & Karypis ICDE'14): an
5//! inverted index with **prefix filtering** so that the vast majority of vector pairs are never
6//! compared. It is the principled, exact replacement for shingle-candidate + verify-all
7//! near-duplicate detection (e.g. Type-3 code clones = functions × IDF-weighted lines).
8//!
9//! ## Correctness gate
10//! [`cosine_join`] (the indexed algorithm) must return the **exact same** pair set, with
11//! bit-identical similarities, as [`cosine_join_bruteforce`] (the naive `O(n²)` oracle). Both score
12//! a pair with the same [`cos_full`] sorted-merge dot, so the values match to the bit and a pair is
13//! never dropped or gained at the threshold. This equality is asserted on fuzzed corpora — the same
14//! "two implementations, one answer" discipline the RO path uses.
15//!
16//! ## Method (this reference)
17//! Vectors are L2-normalised (so `cos = dot`) and their dimensions relabelled to a global rank by
18//! **increasing max weight** (common low-weight dims first, rare high-weight dims last — so only the
//! rare tail is indexed, keeping postings short). For a probe vector we accumulate a partial dot
//! with each earlier vector over the indexed dims they share, drop every candidate whose
//! Cauchy–Schwarz upper bound on the full cosine stays below `t` (the L2AP `l2` bound on the
//! un-indexed prefix), and verify the survivors with the full dot. When *indexing* a vector we skip
//! the leading prefix whose max possible contribution to any dot stays `< t`:
//! `Σ_{k<b} w_k · maxw[dim_k] < t` ⇒ a pair matching only in that prefix can't reach `t`
//! (weights ≥ 0), so any pair that does reach `t` shares at least one dim the earlier vector
//! indexed. The skipped prefix holds the common dims (huge postings); only the rarer tail is
//! indexed — that is the whole speed-up. The tighter L2AP *index-side* (prefix-norm) bounds are not
//! applied on top of this max-weight prefix.
27
// Vector ids and dim ranks are stored as `u32` (ranks are assigned densely from a sorted dim list),
// so the `as` casts below assume fewer than 2^32 vectors and distinct dims; this is not checked.
30// Dense numerical code: `i`/`d`/`w`/`y`/`a`/`s`/`t` mirror the cosine/prefix-bound formulas and read
31// clearer than verbose names here — allow the single-char-names pedantic lint module-wide.
32#![allow(
33 clippy::cast_possible_truncation,
34 clippy::cast_precision_loss,
35 clippy::cast_sign_loss,
36 clippy::many_single_char_names
37)]
38
39use std::cmp::Ordering;
40use std::collections::{HashMap, HashSet};
41
42use rayon::prelude::*;
43
44use crate::Concurrency;
45
/// A corpus of L2-normalised sparse non-negative vectors in CSR form, with dimensions relabelled to
/// a global rank by **increasing** max weight (common low-weight dims get small ranks, rare
/// high-weight dims get large ranks). Built once; joinable at any threshold.
#[derive(Debug, Clone)]
pub struct Corpus {
    /// Number of vectors (rows).
    n: usize,
    /// Number of distinct dims that carry at least one stored entry (= length of `maxw`).
    ndims: usize,
    /// `n + 1` row offsets into `dims`/`wts`.
    indptr: Vec<usize>,
    /// Relabelled (rank) dimension ids, ascending within each row.
    dims: Vec<u32>,
    /// L2-normalised weights, aligned with `dims`.
    wts: Vec<f64>,
    /// Per relabelled dim: the max weight it takes across the corpus (the prefix-filter bound).
    maxw: Vec<f64>,
}

impl Corpus {
    /// Number of vectors.
    #[must_use]
    pub fn len(&self) -> usize {
        self.n
    }

    /// True if the corpus has no vectors.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.n == 0
    }

    /// `(dims, weights)` of vector `i` — dims ascending by global rank, weights L2-normalised.
    /// Panics if `i >= self.len()` (out-of-range `indptr` index).
    #[must_use]
    fn row(&self, i: usize) -> (&[u32], &[f64]) {
        let (s, e) = (self.indptr[i], self.indptr[i + 1]);
        (&self.dims[s..e], &self.wts[s..e])
    }

    /// CSR view for GPU offload: `(indptr, dims, wts_f32)`, with `indptr` cast to `u32` and the
    /// L2-normalised weights cast to `f32` (Apple GPUs have no `f64`). For the
    /// [`crate::simjoin_gpu`] throughput experiment only — the f32 cast means GPU dots are *not*
    /// bit-identical to the CPU `f64` path.
    #[cfg(all(target_os = "macos", feature = "gpu"))]
    #[must_use]
    pub fn csr_f32(&self) -> (Vec<u32>, Vec<u32>, Vec<f32>) {
        let indptr = self.indptr.iter().map(|&x| x as u32).collect();
        let wts = self.wts.iter().map(|&w| w as f32).collect();
        (indptr, self.dims.clone(), wts)
    }

    /// Build a corpus from **token documents** — each document a list of string tokens — as TF-IDF
    /// sparse vectors: dim = a distinct token, weight = `(token count in doc) × ln(n / df_token)`
    /// (`df` = number of documents containing the token). This is the principled input for a Type-3
    /// code-clone join (documents = functions, tokens = canonicalised lines) and the shape most
    /// callers actually have. A token appearing in every document gets `idf = 0` (contributes
    /// nothing), as expected.
    #[must_use]
    pub fn from_token_docs<S: AsRef<str>>(docs: &[Vec<S>]) -> Corpus {
        let n = docs.len();
        let mut dim: HashMap<&str, u32> = HashMap::new();
        let mut df: Vec<u32> = Vec::new();
        // Assign a dim id to each distinct token (first-seen order) and count document frequency
        // (once per doc/token). `seen` is cleared and reused per document instead of reallocated.
        let mut doc_ids: Vec<Vec<u32>> = Vec::with_capacity(n);
        let mut seen: HashSet<u32> = HashSet::new();
        for doc in docs {
            seen.clear();
            let mut ids = Vec::with_capacity(doc.len());
            for tok in doc {
                let id = *dim.entry(tok.as_ref()).or_insert_with(|| {
                    let i = df.len() as u32;
                    df.push(0);
                    i
                });
                ids.push(id);
                if seen.insert(id) {
                    df[id as usize] += 1;
                }
            }
            doc_ids.push(ids);
        }
        let idf: Vec<f64> = df.iter().map(|&d| (n as f64 / f64::from(d.max(1))).ln()).collect();
        // Emit (dim, idf) once per token occurrence; `from_rows` sums duplicates → tf·idf per dim.
        let rows: Vec<Vec<(u32, f64)>> = doc_ids
            .iter()
            .map(|ids| ids.iter().map(|&id| (id, idf[id as usize])).collect())
            .collect();
        Corpus::from_rows(&rows)
    }

    /// Build a corpus from raw `(dim, weight)` rows. Duplicate dims within a row are summed; each row
    /// is L2-normalised (a row whose norm is not `> 0` becomes empty); dims are relabelled to a
    /// global rank by **increasing** max weight. Weights are expected non-negative (the
    /// prefix-filter bound requires it — IDF weights satisfy this).
    ///
    /// Deterministic: the same `rows` always yield bit-identical normalised weights.
    #[must_use]
    pub fn from_rows(rows: &[Vec<(u32, f64)>]) -> Corpus {
        let n = rows.len();
        // 1. Merge duplicate dims + L2-normalise each row (kept as (orig_dim, weight)). The merged
        //    entries are sorted by original dim *before* the norm is summed: `HashMap` iteration
        //    order is randomised per map, so summing in map order could change the norm (and so
        //    every normalised weight) in the last bits between two builds of the same rows.
        let normed: Vec<Vec<(u32, f64)>> = rows
            .iter()
            .map(|r| {
                let mut m: HashMap<u32, f64> = HashMap::new();
                for &(d, w) in r {
                    *m.entry(d).or_insert(0.0) += w;
                }
                let mut v: Vec<(u32, f64)> = m.into_iter().collect();
                v.sort_unstable_by_key(|&(d, _)| d);
                let norm = v.iter().map(|&(_, w)| w * w).sum::<f64>().sqrt();
                if norm > 0.0 {
                    for e in &mut v {
                        e.1 /= norm;
                    }
                    v
                } else {
                    Vec::new()
                }
            })
            .collect();
        // 2. Max normalised weight per original dim (order-independent: `max` is exact).
        let mut maxw_orig: HashMap<u32, f64> = HashMap::new();
        for v in &normed {
            for &(d, w) in v {
                let e = maxw_orig.entry(d).or_insert(0.0);
                if w > *e {
                    *e = w;
                }
            }
        }
        // 3. Rank dims by (max weight ASC, dim asc) → dense global order. Ascending so the common,
        //    low-weight dims land at the FRONT: they fill the un-indexed prefix (their tiny
        //    `w·maxw` keeps the cumulative bound under `t` for many of them), and only the rare,
        //    high-weight tail is indexed — short postings. (Reversing this indexes the common dims
        //    and their huge postings, which is correct but orders of magnitude slower.)
        let mut dims_sorted: Vec<(u32, f64)> = maxw_orig.into_iter().collect();
        dims_sorted.sort_by(|a, b| {
            a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal).then(a.0.cmp(&b.0))
        });
        let rank: HashMap<u32, u32> =
            dims_sorted.iter().enumerate().map(|(i, &(d, _))| (d, i as u32)).collect();
        let ndims = dims_sorted.len();
        let maxw: Vec<f64> = dims_sorted.iter().map(|&(_, w)| w).collect();
        // 4. CSR with relabelled dims, ascending within each row.
        let mut indptr = Vec::with_capacity(n + 1);
        indptr.push(0);
        let mut dims = Vec::new();
        let mut wts = Vec::new();
        for v in &normed {
            let mut rv: Vec<(u32, f64)> = v.iter().map(|&(d, w)| (rank[&d], w)).collect();
            rv.sort_unstable_by_key(|&(d, _)| d);
            for (d, w) in rv {
                dims.push(d);
                wts.push(w);
            }
            indptr.push(dims.len());
        }
        Corpus { n, ndims, indptr, dims, wts, maxw }
    }
}
193
/// Sorted-merge dot product of two rows (dims ascending). For L2-normalised non-negative vectors
/// this is exactly their cosine similarity. The single scoring routine shared by the indexed join
/// and the brute-force oracle, so both agree to the bit.
///
/// Each argument is a `(dims, weights)` pair as produced by `Corpus::row`. If a pair's two slices
/// ever differ in length (never the case for corpus rows), only their common prefix is read.
#[must_use]
#[cfg_attr(feature = "profiling", inline(never))]
fn cos_full((da, wa): (&[u32], &[f64]), (db, wb): (&[u32], &[f64])) -> f64 {
    // Rows are equal-length dim/weight slices (built by `Corpus::from_rows`).
    debug_assert_eq!(da.len(), wa.len());
    debug_assert_eq!(db.len(), wb.len());
    // Clamp the loop bounds to BOTH slices of each pair so the unchecked loads below are in bounds
    // even in release builds, where the `debug_assert`s above compile away. For real rows the
    // lengths are equal and this is a no-op (same iterations, same result bits).
    let la = da.len().min(wa.len());
    let lb = db.len().min(wb.len());
    let (mut i, mut j) = (0usize, 0usize);
    let mut s = 0.0f64;
    // Branchless sorted-merge: the 3-way `cmp` branches mispredict on random dim order, and the
    // weight loads carry bounds checks the optimiser can't elide. Here we always load both weights
    // and mask the product by dim-equality — `s += 0.0` for unequal dims adds the *same* terms in
    // the *same* increasing-dim order, so the result is bit-identical to the branchy merge while
    // shedding every data-dependent branch.
    while i < la && j < lb {
        // SAFETY: `i < la ≤ min(da.len(), wa.len())` and `j < lb ≤ min(db.len(), wb.len())`.
        let (ai, bj) = unsafe { (*da.get_unchecked(i), *db.get_unchecked(j)) };
        // SAFETY: same bounds as the loads above.
        let (wai, wbj) = unsafe { (*wa.get_unchecked(i), *wb.get_unchecked(j)) };
        let eq = f64::from(u32::from(ai == bj));
        s += eq * wai * wbj;
        i += usize::from(ai <= bj);
        j += usize::from(ai >= bj);
    }
    s
}
222
/// Per-vector prune data, read in the hot verify loop when a vector appears as a candidate. Packed
/// into one array (not two parallel `Vec`s) so a candidate's `pnorm` and `split` come from a single
/// scattered load instead of two — verify is memory-bound on these random accesses.
///
/// One entry per vector id, written by `index_suffix` when the vector is indexed; read by value
/// (the type is `Copy`) by `verify_pruned`, `collect_survivors` and `cosine_join_counts`.
#[derive(Clone, Copy)]
struct Bound {
    /// ‖un-indexed prefix of y‖₂ — Cauchy–Schwarz cap on the dot mass the accumulator misses (the
    /// prefix dims of `y` were never indexed, so never accumulated). `0.0` if `y` skipped nothing.
    pnorm: f64,
    /// Rank of `y`'s first *indexed* dim (`u32::MAX` if `y` indexed nothing). Every prefix dim of
    /// `y` has rank `<` this — lets the probe restrict its norm to that rank range.
    split: u32,
}
235
/// Per-vector data cached as each vector is indexed; read by the prune bound when the vector later
/// appears as a candidate. (One `Cached` for the whole join, built before the probe phase.)
struct Cached {
    /// Indexed by vector id; entry `i` is overwritten by `index_suffix` for vector `i`.
    bound: Vec<Bound>,
}
241
/// Reused scratch buffers: one per parallel `map_init` task (or a single one in the sequential
/// diagnostic), reused across the probes that task handles — never allocated per probe.
struct Scratch {
    /// `acc[y]` = partial dot of the probe with `y` over shared *indexed* dims; `-1.0` = untouched
    /// sentinel (a real partial dot of non-negative weights is always `≥ 0`). Length `n`; every
    /// touched slot is reset to `-1.0` before the next probe.
    acc: Vec<f64>,
    /// Candidate ids the current probe touched (the keys to reset in `acc`).
    touched: Vec<u32>,
    /// Probe prefix L2 norms for this probe: `xpn[k] = ‖wi[..k]‖₂`, length `nnz+1`.
    xpn: Vec<f64>,
}
252
253/// Exact all-pairs cosine join via inverted index + **L2AP** prefix filtering and accumulation-time
254/// pruning. Returns `(j, i, cos)` with `j < i` for every pair with `cos ≥ t`. Bit-identical pair set
255/// to [`cosine_join_bruteforce`].
256///
257/// For each probe we accumulate a partial dot over shared *indexed* dims ([`accumulate`]), then for
258/// each touched candidate compute a Cauchy–Schwarz upper bound on the true cosine and skip the exact
259/// [`cos_full`] when it cannot reach `t` ([`verify_pruned`]). The bound is a filter only — survivors
260/// are scored exactly, so the output is byte-for-byte the brute-force result. On skewed data the
261/// bound prunes the ~99.9 % of candidates that collide on a single rare dim, so `cos_full` (the
262/// former 90 % hotspot) runs only on genuine near-matches.
263///
264/// The full inverted index is built once (postings ascending by id), then every vector is probed in
265/// **parallel**: probe `i` walks each posting only while `y < i` (postings are id-sorted), so it sees
266/// exactly the earlier vectors — each pair `(j, i)` with `j < i` is found once, from the larger id.
267/// This is the same candidate set the sequential index-as-you-go build produces, so the result is
268/// unchanged; the returned `Vec` is in arbitrary order (sort if a canonical order is needed).
269#[must_use]
270pub fn cosine_join(c: &Corpus, t: f64) -> Vec<(usize, usize, f64)> {
271 let n = c.n;
272 // Postings carry the indexed weight `(y, w_y[d])` so the scan can accumulate a partial dot.
273 let mut index: Vec<Vec<(u32, f64)>> = vec![Vec::new(); c.ndims];
274 let mut cached = Cached { bound: vec![Bound { pnorm: 0.0, split: u32::MAX }; n] };
275 for i in 0..n {
276 let (di, wi) = c.row(i);
277 index_suffix(c, i, (di, wi), t, &mut index, &mut cached);
278 }
279 // Probe vectors in parallel; each worker keeps one reusable `Scratch` (an `n`-wide accumulator).
280 // `with_min_len` batches many probes per rayon task so the per-probe work (tiny) isn't dwarfed by
281 // task-splitting / scheduling overhead (`swtch_pri` in the profile).
282 (0..n)
283 .into_par_iter()
284 .with_min_len(256)
285 .map_init(
286 || Scratch { acc: vec![-1.0; n], touched: Vec::new(), xpn: Vec::new() },
287 |scratch, i| {
288 let (di, wi) = c.row(i);
289 accumulate(&index, (di, wi), i as u32, scratch);
290 let mut out = Vec::new();
291 verify_pruned(c, i, t, scratch, &cached, &mut out);
292 out
293 },
294 )
295 .flatten()
296 .collect()
297}
298
299/// Run the cosine join under a chosen [`Concurrency`] backend. Returns `(j, i, cos)` pairs with
300/// `j < i` and `cos ≥ t`, scores as `f64` (the `Gpu` mode's f32 cosines are widened losslessly).
301///
302/// - [`Concurrency::Cpu`] — [`cosine_join`]: exact `f64`, all-CPU, every platform.
303/// - [`Concurrency::GpuPlusCpu`] — exact `f64` hybrid: CPU generates survivor pairs, the GPU f32
304/// cosine *filters* the clear rejects, the CPU recomputes the exact `f64` score on what passes.
305/// **Byte-identical to `Cpu`**; both engines fully used. ~1.7–2× on bandwidth-bound real data.
306/// - [`Concurrency::Gpu`] — GPU-dominant `f32`: CPU generates survivor pairs, the GPU scores them and
307/// the result is emitted directly (no f64 re-verify). Fastest (~2×); differs from the exact answer
308/// only on pairs whose true cosine is within ~`1e-6` of `t` (measured: ≤1 pair in millions).
309///
310/// When the `gpu` feature is off, the target isn't macOS, or no Metal device can be acquired, the GPU
311/// modes transparently fall back to [`cosine_join`] (same as `Rationer`). This convenience entry
312/// **compiles + uploads the GPU corpus on every call** — fine for a one-shot join, but for repeated
313/// joins on one corpus build a [`CosineJoiner`] once and call [`CosineJoiner::join`], which holds the
314/// device + kernel + uploaded CSR across calls (and avoids the driver instability of compiling a
315/// Metal library hundreds of times in a tight loop).
316#[must_use]
317pub fn cosine_join_with(c: &Corpus, t: f64, mode: Concurrency) -> Vec<(usize, usize, f64)> {
318 #[cfg(all(feature = "gpu", target_os = "macos"))]
319 {
320 if matches!(mode, Concurrency::Cpu) {
321 return cosine_join(c, t);
322 }
323 let (indptr, dims, wts) = c.csr_f32();
324 let Some(gpu) = crate::simjoin_gpu::BatchCosineGpu::new(&indptr, &dims, &wts) else {
325 return cosine_join(c, t); // no Metal device → CPU fallback
326 };
327 match mode {
328 Concurrency::GpuPlusCpu => cosine_join_gpu(c, t, &gpu),
329 // `Gpu`: emit the GPU f32 cosines directly (widened to f64), no exact re-verify.
330 Concurrency::Gpu => cosine_join_gpu_f32(c, t, &gpu)
331 .into_iter()
332 .map(|(a, b, s)| (a, b, f64::from(s)))
333 .collect(),
334 Concurrency::Cpu => unreachable!("handled above"),
335 }
336 }
337 #[cfg(not(all(feature = "gpu", target_os = "macos")))]
338 {
339 let _ = mode; // GPU modes degrade to the CPU join when the feature is off / not macOS.
340 cosine_join(c, t)
341 }
342}
343
/// A reusable cosine-join handle that owns the corpus and — under `feature = "gpu"` on macOS — the
/// Metal device, compiled `batch_cosine` kernel, and the corpus CSR uploaded to unified memory, all
/// acquired **once** at construction. Repeated [`join`](CosineJoiner::join)s at different thresholds
/// then skip the per-call kernel compile + CSR upload that [`cosine_join_with`] pays (only the
/// `t`-specific inverted index is rebuilt each call, on the CPU). Always constructible; degrades to
/// the pure-CPU join when the `gpu` feature is off or no Metal device is available — mirroring
/// `Rationer`. This is the right entry point for sweeping thresholds or joining repeatedly.
pub struct CosineJoiner {
    /// The corpus every join runs over (read-only access via [`CosineJoiner::corpus`]).
    corpus: Corpus,
    /// Owned GPU resources (device + kernel + CSR in UMA); `None` when no Metal device was acquired.
    #[cfg(all(feature = "gpu", target_os = "macos"))]
    gpu: Option<crate::simjoin_gpu::BatchCosineGpu>,
}
357
358impl CosineJoiner {
359 /// Build a joiner over `corpus`, acquiring the GPU device + uploading the corpus CSR once if the
360 /// `gpu` feature is on and a Metal device is present.
361 #[must_use]
362 pub fn new(corpus: Corpus) -> Self {
363 #[cfg(all(feature = "gpu", target_os = "macos"))]
364 {
365 let (indptr, dims, wts) = corpus.csr_f32();
366 let gpu = crate::simjoin_gpu::BatchCosineGpu::new(&indptr, &dims, &wts);
367 Self { corpus, gpu }
368 }
369 #[cfg(not(all(feature = "gpu", target_os = "macos")))]
370 {
371 Self { corpus }
372 }
373 }
374
375 /// The owned corpus (e.g. for `len()` or to run other queries against it).
376 #[must_use]
377 pub fn corpus(&self) -> &Corpus {
378 &self.corpus
379 }
380
381 /// Whether a Metal GPU backend was acquired. Always `false` without `feature = "gpu"` on macOS;
382 /// when `false`, every [`join`](CosineJoiner::join) runs on the CPU regardless of `mode`.
383 #[must_use]
384 pub fn has_gpu(&self) -> bool {
385 #[cfg(all(feature = "gpu", target_os = "macos"))]
386 {
387 self.gpu.is_some()
388 }
389 #[cfg(not(all(feature = "gpu", target_os = "macos")))]
390 {
391 false
392 }
393 }
394
395 /// Run the join at threshold `t` under `mode`, reusing the handle's GPU resources. Returns the
396 /// same results as [`cosine_join_with`] (Cpu/GpuPlusCpu exact, Gpu f32→f64); falls back to the
397 /// CPU join when the GPU is unavailable.
398 #[must_use]
399 pub fn join(&self, t: f64, mode: Concurrency) -> Vec<(usize, usize, f64)> {
400 #[cfg(all(feature = "gpu", target_os = "macos"))]
401 {
402 match (mode, self.gpu.as_ref()) {
403 (Concurrency::GpuPlusCpu, Some(g)) => cosine_join_gpu(&self.corpus, t, g),
404 (Concurrency::Gpu, Some(g)) => cosine_join_gpu_f32(&self.corpus, t, g)
405 .into_iter()
406 .map(|(a, b, s)| (a, b, f64::from(s)))
407 .collect(),
408 _ => cosine_join(&self.corpus, t), // Cpu mode, or no Metal device
409 }
410 }
411 #[cfg(not(all(feature = "gpu", target_os = "macos")))]
412 {
413 let _ = mode;
414 cosine_join(&self.corpus, t)
415 }
416 }
417}
418
/// FP slack for the prune bound: the Cauchy–Schwarz upper bound holds in exact arithmetic, but the
/// accumulated dot and the `sqrt` norms each carry rounding error. We only *skip* `cos_full` when
/// the bound is below `t` by more than this slack, so a true pair (exact cosine `≥ t`) is never
/// pruned. Not skipping is always correctness-safe (just a wasted `cos_full`), so the slack trades a
/// negligible number of extra verifies for safety and never changes the emitted pair set.
/// The rounding error of an `m`-term `f64` sum of products of unit-vector weights is on the order of
/// `m · 2⁻⁵³` (`≈ 1e-13` even at `m = 1000`), so `1e-9` leaves orders of magnitude of headroom.
const PRUNE_SLACK: f64 = 1e-9;
426
427/// Phase 1 — accumulate: for each indexed dim of the probe, add `w_probe·w_y` into `acc[y]` for
428/// every **earlier** `y` (`y < cutoff`, the probe's own id) indexing that dim. Postings are id-sorted,
429/// so we `break` at the first `y ≥ cutoff`. Leaves `acc` `-1.0` everywhere except the touched ids
430/// (listed in `touched`, reset in [`verify_pruned`]). One scattered `acc[]` FMA per posting.
431#[cfg_attr(feature = "profiling", inline(never))]
432fn accumulate(index: &[Vec<(u32, f64)>], (di, wi): (&[u32], &[f64]), cutoff: u32, s: &mut Scratch) {
433 s.touched.clear();
434 for (&d, &w) in di.iter().zip(wi) {
435 for &(y, wy) in &index[d as usize] {
436 if y >= cutoff {
437 break;
438 }
439 let yu = y as usize;
440 // SAFETY: `y` is a vector id pushed by `index_suffix`, so `yu < n == acc.len()`.
441 let a = unsafe { s.acc.get_unchecked_mut(yu) };
442 if *a < 0.0 {
443 *a = 0.0;
444 s.touched.push(y);
445 }
446 *a += w * wy;
447 }
448 }
449}
450
/// Phase 2 — prune + verify. For each touched candidate `y`, reset its accumulator and test the
/// **L2AP `l2` bound**: the dot mass missing from `acc[y]` (the dims in `prefix(y)`) is at most
/// `‖x_{rank<split[y]}‖ · ‖prefix(y)‖` by Cauchy–Schwarz, where `x_{rank<split[y]}` is the probe
/// restricted to the rank range `prefix(y)` lives in. When the probe's mass sits in its rare
/// (high-rank) dims, that restricted norm is small — a far tighter cap than the whole-probe `‖x‖=1`.
/// Only if `acc[y] + that bound ≥ t` (minus FP slack) do we score exactly with [`cos_full`]. Filter
/// only — the emitted value is the exact dot, so the pair set is bit-identical to brute force.
///
/// Expects `s.acc`/`s.touched` as left by [`accumulate`] for probe `i`. On return every touched
/// `acc` slot is back at the `-1.0` sentinel, so the same `Scratch` serves the next probe. Appends
/// `(y, i, cos)` to `out` for each `cos ≥ t`. `collect_survivors` and `cosine_join_counts` repeat
/// this exact bound and must stay in lockstep with it.
#[cfg_attr(feature = "profiling", inline(never))]
fn verify_pruned(
    c: &Corpus,
    i: usize,
    t: f64,
    s: &mut Scratch,
    cached: &Cached,
    out: &mut Vec<(usize, usize, f64)>,
) {
    let (di, wi) = c.row(i);
    // Probe prefix L2 norms: xpn[k] = ‖wi[..k]‖₂ (di ascending by rank, so xpn[k] = norm over the
    // probe's k lowest-rank dims). One sqrt per probe dim, reused across all its candidates.
    s.xpn.clear();
    s.xpn.push(0.0);
    let mut sq = 0.0f64;
    for &w in wi {
        sq += w * w;
        s.xpn.push(sq.sqrt());
    }
    // Prune threshold, lowered by `PRUNE_SLACK` so FP error can never prune a true pair.
    let need = t - PRUNE_SLACK;
    let Scratch { acc, touched, xpn } = s;
    // (Software-prefetching the candidate row a few ahead was tried + reverted: no measurable change
    // — with all cores gathering at once the join is memory-*bandwidth*-bound, not per-access
    // latency-bound, so prefetch can't add throughput.)
    for &y in touched.iter() {
        let yu = y as usize;
        // Read the partial dot and restore the sentinel in one step.
        // SAFETY: `yu < n` (same provenance as in `accumulate`).
        let a = unsafe { std::mem::replace(acc.get_unchecked_mut(yu), -1.0) };
        // SAFETY: `yu < n`. One scattered load fetches both prune fields.
        let bd = unsafe { *cached.bound.get_unchecked(yu) };
        // Number of probe dims with rank < split[y] → index into xpn (di sorted ascending).
        let kstar = di.partition_point(|&d| d < bd.split);
        // SAFETY: kstar ≤ di.len() == wi.len() = xpn.len()-1.
        // (An added maxweight cap `min(…, Σ wx·maxw)` was tried + reverted: never tighter than the
        // L2 cap on either synthetic or real data — it sums over all probe-prefix dims, not just the
        // shared ones — so it pruned nothing and cost ~16% on the real PyPI corpus.)
        let bound = a + unsafe { xpn.get_unchecked(kstar) } * bd.pnorm;
        if bound >= need {
            // Survivor: score exactly with the same routine (and argument order) as brute force.
            let cos = cos_full((di, wi), c.row(yu));
            if cos >= t {
                out.push((yu, i, cos));
            }
        }
    }
}
503
504/// Phase 3 — index this vector's suffix: skip the leading prefix whose max possible contribution to
505/// any dot stays `< t` (`Σ w_k·maxw[dim_k] < t`), index only the rarer tail (short postings = the
506/// whole speed-up), and cache `pnorm[i] = ‖prefix‖₂` and `split[i]` = first indexed rank for the
507/// [`verify_pruned`] bound.
508#[cfg_attr(feature = "profiling", inline(never))]
509fn index_suffix(
510 c: &Corpus,
511 i: usize,
512 (di, wi): (&[u32], &[f64]),
513 t: f64,
514 index: &mut [Vec<(u32, f64)>],
515 cached: &mut Cached,
516) {
517 // Largest safe prefix under the maxweight bound: `Σ_{k<b} w_k·maxw[dim_k] < t` ⇒ the prefix
518 // can't contribute `t` to any dot (weights ≥ 0). (A norm-based extension `‖x_{<b}‖ < t` was
519 // tried and reverted: it gave zero candidate reduction in the realistic `t<1` regime — the
520 // maxweight bound already indexes less — and it is FP-fragile at `t=1` where `‖x‖` rounds below
521 // `1` and indexes nothing, dropping exact-duplicate pairs.)
522 let mut rs = 0.0f64;
523 let mut b = 0usize;
524 for k in 0..di.len() {
525 let bound = wi[k] * c.maxw[di[k] as usize];
526 if rs + bound >= t {
527 break;
528 }
529 rs += bound;
530 b = k + 1;
531 }
532 let mut p = 0.0f64;
533 for &w in &wi[..b] {
534 p += w * w;
535 }
536 cached.bound[i] = Bound {
537 pnorm: p.sqrt(),
538 split: if b < di.len() { di[b] } else { u32::MAX },
539 };
540 for k in b..di.len() {
541 index[di[k] as usize].push((i as u32, wi[k]));
542 }
543}
544
/// FP margin for the GPU f32 cosine *filter* in [`cosine_join_gpu`]: a survivor is dropped only when
/// its GPU f32 cosine is below `t` by more than this. The f32 dot's error is `~1e-6` relative, so a
/// `1e-4` absolute margin never drops a true pair; the CPU then recomputes the exact `f64` score on
/// whatever passes, so the emitted pair set + scores stay bit-identical to [`cosine_join`].
///
/// NOTE(review): `~1e-6` is a typical-case figure. The classic worst-case bound for an `m`-term
/// sequential f32 dot of unit vectors is about `m · 2⁻²⁴ ≈ m · 6e-8`, which reaches this margin
/// near `m ≈ 1600` terms. The GPU kernel's summation order is not visible from here — confirm the
/// margin against the longest rows expected in practice.
#[cfg(all(target_os = "macos", feature = "gpu"))]
const GPU_FILTER_MARGIN: f64 = 1e-4;
551
/// Like the verify half of [`verify_pruned`], but instead of scoring, **collects** each surviving
/// `(candidate, probe)` pair (`candidate < probe`) for batch scoring elsewhere. The bound here MUST
/// stay identical to `verify_pruned`'s so the survivor set matches exactly.
///
/// Expects `s.acc`/`s.touched` as left by [`accumulate`] for probe `i`; resets every touched `acc`
/// slot back to the `-1.0` sentinel. Appends `(y, i)` as `u32`s to `out`.
#[cfg(all(target_os = "macos", feature = "gpu"))]
fn collect_survivors(c: &Corpus, i: usize, t: f64, s: &mut Scratch, cached: &Cached, out: &mut Vec<(u32, u32)>) {
    let (di, wi) = c.row(i);
    // Probe prefix L2 norms xpn[k] = ‖wi[..k]‖₂, computed exactly as in `verify_pruned`.
    s.xpn.clear();
    s.xpn.push(0.0);
    let mut sq = 0.0f64;
    for &w in wi {
        sq += w * w;
        s.xpn.push(sq.sqrt());
    }
    let need = t - PRUNE_SLACK;
    let Scratch { acc, touched, xpn } = s;
    for &y in touched.iter() {
        let yu = y as usize;
        // SAFETY: `yu < n` (same provenance as in `accumulate`).
        let a = unsafe { std::mem::replace(acc.get_unchecked_mut(yu), -1.0) };
        // SAFETY: `yu < n == cached.bound.len()`.
        let bd = unsafe { *cached.bound.get_unchecked(yu) };
        let kstar = di.partition_point(|&d| d < bd.split);
        // SAFETY: kstar ≤ di.len() == wi.len() = xpn.len()-1.
        let bound = a + unsafe { xpn.get_unchecked(kstar) } * bd.pnorm;
        if bound >= need {
            out.push((y, i as u32)); // (candidate, probe), candidate < probe
        }
    }
}
579
/// **CPU+GPU hybrid join** (feature `gpu`, macOS). Returns the **exact same** pair set + scores as
/// [`cosine_join`] (and thus the brute-force oracle) — only faster when the verify is bandwidth-bound.
///
/// Pipeline: CPU builds the index and, in parallel, accumulates + bounds every probe to a list of
/// surviving `(candidate, probe)` pairs (no `cos_full`). The GPU then computes an f32 cosine for the
/// whole batch (its memory-level parallelism clears the random-gather dots ~3× faster than the CPU),
/// and the CPU recomputes the exact `f64` `cos_full` **only** on the pairs whose GPU score clears
/// `t − margin` — typically a few percent of survivors. Because the GPU is a *conservative filter*
/// (margin ≫ f32 error, so no true pair is ever dropped) and every emitted score is the exact CPU
/// `f64` value, the output is byte-for-byte identical to [`cosine_join`].
///
/// For `t ≤ 0` the result is [`cosine_join_bruteforce`]: every pair then qualifies, including pairs
/// that share no dim, which the index-based survivor generation can never produce.
#[cfg(all(target_os = "macos", feature = "gpu"))]
#[must_use]
pub fn cosine_join_gpu(
    c: &Corpus,
    t: f64,
    gpu: &crate::simjoin_gpu::BatchCosineGpu,
) -> Vec<(usize, usize, f64)> {
    // Degenerate threshold: zero-overlap pairs qualify and only brute force finds them.
    if t <= 0.0 {
        return cosine_join_bruteforce(c, t);
    }
    let (pa, pb) = survivor_pairs(c, t);
    if pa.is_empty() {
        return Vec::new();
    }
    // GPU phase: f32 cosine over the whole survivor batch (conservative filter).
    let gcos = gpu.cosine_batch(&pa, &pb);
    let need = t - GPU_FILTER_MARGIN;
    // CPU phase: exact f64 re-verify only on pairs the GPU filter passes.
    (0..pa.len())
        .into_par_iter()
        .with_min_len(1024)
        .filter_map(|k| {
            if f64::from(gcos[k]) < need {
                return None;
            }
            let (a, b) = (pa[k] as usize, pb[k] as usize);
            // `b` is the probe, `a` the earlier candidate; `cos_full(row(a), row(b))`.
            let cos = cos_full(c.row(a), c.row(b));
            (cos >= t).then_some((a, b, cos))
        })
        .collect()
}
618
/// **Pure-f32** CPU+GPU join (feature `gpu`, macOS): same survivor generation as [`cosine_join_gpu`]
/// but emits the GPU's **f32** cosine directly, with **no exact f64 re-verify**. Trades byte-parity
/// for speed (no re-verify, and `cos_full` never runs on the GPU survivors). The result differs from
/// [`cosine_join`] only on pairs whose true cosine lies within ~`1e-6` (f32 rounding) of `t` — for a
/// similarity join with an arbitrary threshold that is immaterial. Use when an ε-exact answer is
/// acceptable; use [`cosine_join_gpu`] when bit-exactness is required.
///
/// For `t ≤ 0` the pairs come from [`cosine_join_bruteforce`] (scores narrowed to `f32`): every pair
/// then qualifies, including pairs that share no dim, which the index can never produce.
#[cfg(all(target_os = "macos", feature = "gpu"))]
#[must_use]
pub fn cosine_join_gpu_f32(
    c: &Corpus,
    t: f64,
    gpu: &crate::simjoin_gpu::BatchCosineGpu,
) -> Vec<(usize, usize, f32)> {
    // Degenerate threshold: zero-overlap pairs qualify and only brute force finds them.
    if t <= 0.0 {
        return cosine_join_bruteforce(c, t)
            .into_iter()
            .map(|(a, b, s)| (a, b, s as f32))
            .collect();
    }
    let (pa, pb) = survivor_pairs(c, t);
    if pa.is_empty() {
        return Vec::new();
    }
    let gcos = gpu.cosine_batch(&pa, &pb);
    // Threshold compared in f32, matching the precision of the emitted scores.
    let tf = t as f32;
    (0..pa.len())
        .into_par_iter()
        .with_min_len(1024)
        .filter_map(|k| (gcos[k] >= tf).then_some((pa[k] as usize, pb[k] as usize, gcos[k])))
        .collect()
}
644
/// CPU half shared by the GPU joins. Builds the thresholded inverted index sequentially, then
/// accumulates + bounds every probe in parallel, returning the surviving `(candidate, probe)` pairs
/// (candidate `<` probe) as two parallel `u32` arrays — the layout
/// [`crate::simjoin_gpu::BatchCosineGpu::cosine_batch`] takes.
#[cfg(all(target_os = "macos", feature = "gpu"))]
fn survivor_pairs(c: &Corpus, t: f64) -> (Vec<u32>, Vec<u32>) {
    let n = c.n;
    // Sequential index build: vectors are pushed in id order, so every posting is id-sorted —
    // which `accumulate`'s early stop relies on.
    let mut cached = Cached { bound: vec![Bound { pnorm: 0.0, split: u32::MAX }; n] };
    let mut index: Vec<Vec<(u32, f64)>> = vec![Vec::new(); c.ndims];
    for id in 0..n {
        index_suffix(c, id, c.row(id), t, &mut index, &mut cached);
    }
    // Parallel probe: each task reuses one `n`-wide scratch; 256 probes per task keep scheduling
    // overhead below the (tiny) per-probe work.
    let fresh_scratch = move || Scratch { acc: vec![-1.0; n], touched: Vec::new(), xpn: Vec::new() };
    let survivors: Vec<(u32, u32)> = (0..n)
        .into_par_iter()
        .with_min_len(256)
        .map_init(fresh_scratch, |scratch, probe| {
            accumulate(&index, c.row(probe), probe as u32, scratch);
            let mut found = Vec::new();
            collect_survivors(c, probe, t, scratch, &cached, &mut found);
            found
        })
        .flatten()
        .collect();
    survivors.into_iter().unzip()
}
674
/// Diagnostic (feature `profiling`, off the hot path): counts that quantify the prune. Returns
/// `(candidates, survivors, pairs)` — candidates touched by the accumulator, survivors that pass the
/// Cauchy–Schwarz bound (i.e. the `cos_full` calls actually made), and real pairs. `survivors /
/// candidates` is the prune pass-rate (lower = better); `survivors` is the verify volume we pay for.
///
/// Runs sequentially with a single `Scratch`, repeating `verify_pruned`'s bound verbatim (keep the
/// two in lockstep). Only pairs sharing an indexed dim are ever touched, so for `t ≤ 0` the `pairs`
/// count omits zero-overlap pairs that the brute-force oracle would report.
#[cfg(feature = "profiling")]
#[must_use]
pub fn cosine_join_counts(c: &Corpus, t: f64) -> (u64, u64, u64) {
    let n = c.n;
    // Same index build as `cosine_join`.
    let mut index: Vec<Vec<(u32, f64)>> = vec![Vec::new(); c.ndims];
    let mut cached = Cached { bound: vec![Bound { pnorm: 0.0, split: u32::MAX }; n] };
    for i in 0..n {
        let (di, wi) = c.row(i);
        index_suffix(c, i, (di, wi), t, &mut index, &mut cached);
    }
    let mut s = Scratch { acc: vec![-1.0; n], touched: Vec::new(), xpn: Vec::new() };
    let (mut ncand, mut survivors, mut pairs) = (0u64, 0u64, 0u64);
    let need = t - PRUNE_SLACK;
    for i in 0..n {
        let (di, wi) = c.row(i);
        accumulate(&index, (di, wi), i as u32, &mut s);
        ncand += s.touched.len() as u64;
        // Probe prefix norms, as in `verify_pruned`.
        s.xpn.clear();
        s.xpn.push(0.0);
        let mut sq = 0.0f64;
        for &w in wi {
            sq += w * w;
            s.xpn.push(sq.sqrt());
        }
        for &y in &s.touched {
            let yu = y as usize;
            // Read the partial dot and restore the sentinel for the next probe.
            let a = std::mem::replace(&mut s.acc[yu], -1.0);
            let bd = cached.bound[yu];
            let kstar = di.partition_point(|&d| d < bd.split);
            if a + s.xpn[kstar] * bd.pnorm >= need {
                survivors += 1;
                if cos_full((di, wi), c.row(yu)) >= t {
                    pairs += 1;
                }
            }
        }
    }
    (ncand, survivors, pairs)
}
718
719/// Naive `O(n²)` oracle: score every pair with [`cos_full`], keep `cos ≥ t`. The correctness
720/// reference [`cosine_join`] is validated against.
721#[must_use]
722pub fn cosine_join_bruteforce(c: &Corpus, t: f64) -> Vec<(usize, usize, f64)> {
723 let mut out: Vec<(usize, usize, f64)> = Vec::new();
724 for i in 0..c.n {
725 for j in 0..i {
726 let s = cos_full(c.row(i), c.row(j));
727 if s >= t {
728 out.push((j, i, s));
729 }
730 }
731 }
732 out
733}
734
#[cfg(test)]
mod tests {
    use super::{cosine_join, cosine_join_bruteforce, Corpus};

    /// Tiny deterministic xorshift64 generator, so every fuzz corpus is reproducible.
    fn xorshift(seed: u64) -> impl FnMut() -> u64 {
        let mut state = seed;
        move || {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            state
        }
    }

    /// Draw `n` random rows; each has `next() % max_nnz` entries `(dim < dim_space, weight 1..=10)`.
    /// Consumes `next` in the order: per row the length, then per entry its dim and its weight.
    fn fuzz_rows(
        next: &mut impl FnMut() -> u64,
        n: usize,
        dim_space: u64,
        max_nnz: u64,
    ) -> Vec<Vec<(u32, f64)>> {
        (0..n)
            .map(|_| {
                let nnz = (next() % max_nnz) as usize;
                (0..nnz)
                    .map(|_| ((next() % dim_space) as u32, (next() % 10 + 1) as f64))
                    .collect()
            })
            .collect()
    }

    /// Canonical bit-exact form of a join result: sorted by `(j, i)`, scores as raw `f64` bits.
    fn sort_pairs(v: Vec<(usize, usize, f64)>) -> Vec<(usize, usize, u64)> {
        let mut keyed: Vec<(usize, usize, u64)> =
            v.into_iter().map(|(a, b, s)| (a, b, s.to_bits())).collect();
        keyed.sort_by_key(|&(a, b, _)| (a, b));
        keyed
    }

    #[test]
    fn indexed_join_matches_bruteforce() {
        let mut next = xorshift(0x9e37_79b9_7f4a_7c15);
        for _ in 0..400 {
            let n = (next() % 40 + 2) as usize;
            let dim_space = next() % 15 + 1;
            let c = Corpus::from_rows(&fuzz_rows(&mut next, n, dim_space, 8));
            for t in [0.1_f64, 0.25, 0.5, 0.75, 0.9, 1.0] {
                let got = sort_pairs(cosine_join(&c, t));
                let want = sort_pairs(cosine_join_bruteforce(&c, t));
                assert_eq!(got, want, "n={n} t={t}");
            }
        }
    }

    /// The CPU+GPU hybrid [`super::cosine_join_gpu`] must return bit-identical results to the pure-CPU
    /// [`cosine_join`] on fuzzed corpora — the GPU is only a conservative filter; every emitted score
    /// is the exact CPU `f64` value. Skips when no Metal device is present.
    #[cfg(all(feature = "gpu", target_os = "macos"))]
    #[test]
    fn gpu_hybrid_matches_cpu() {
        use super::CosineJoiner;
        use crate::Concurrency;
        let mut next = xorshift(0x1357_9bdf_0246_8ace);
        for _ in 0..60 {
            let n = (next() % 60 + 4) as usize;
            let dim_space = next() % 20 + 2;
            // One reusable handle per corpus: `join` runs at every threshold below, which also
            // exercises reuse of the handle's GPU resources (no per-call library compile).
            let joiner =
                CosineJoiner::new(Corpus::from_rows(&fuzz_rows(&mut next, n, dim_space, 10)));
            if !joiner.has_gpu() {
                eprintln!("no Metal device — skipping gpu_hybrid_matches_cpu");
                return;
            }
            for t in [0.1_f64, 0.3, 0.5, 0.7, 0.9, 1.0] {
                let want = sort_pairs(cosine_join(joiner.corpus(), t));
                // The exact GPU+CPU hybrid and `Cpu` mode are both byte-identical to the plain join.
                assert_eq!(
                    sort_pairs(joiner.join(t, Concurrency::GpuPlusCpu)),
                    want,
                    "GpuPlusCpu n={n} t={t}"
                );
                assert_eq!(sort_pairs(joiner.join(t, Concurrency::Cpu)), want, "Cpu n={n} t={t}");
            }
        }
    }
}