pub fn cosine_join(c: &Corpus, t: f64) -> Vec<(usize, usize, f64)>
Expand description
Exact all-pairs cosine join via inverted index + L2AP prefix filtering and accumulation-time
pruning. Returns (j, i, cos) with j < i for every pair with cos ≥ t. Bit-identical pair set
to cosine_join_bruteforce.
For each probe we accumulate a partial dot over shared indexed dims ([accumulate]), then for
each touched candidate compute a Cauchy–Schwarz upper bound on the true cosine and skip the exact
[cos_full] when it cannot reach t ([verify_pruned]). The bound is a filter only — survivors
are scored exactly, so the output pair set is bit-identical to the brute-force result (only the order of the returned Vec may differ). On skewed data the
bound prunes the ~99.9 % of candidates that collide on a single rare dim, so cos_full (the
former 90 % hotspot) runs only on genuine near-matches.
The full inverted index is built once (postings ascending by id), then every vector is probed in
parallel: probe i walks each posting list only while the stored vector id y satisfies y < i (postings are id-sorted), so it sees
exactly the earlier vectors — each pair (j, i) with j < i is found once, from the larger id.
This is the same candidate set the sequential index-as-you-go build produces, so the result is
unchanged; the returned Vec is in arbitrary order (sort if a canonical order is needed).
Examples found in repository?
53fn main() {
54 let path: String = arg(1, "perf-local/pypi-type3.simjoin.bin".to_string());
55 let t: f64 = arg(2, 0.8);
56 let reps: usize = arg(3, 3);
57
58 let rows = load(&path);
59 let n = rows.len();
60 let nnz: usize = rows.iter().map(Vec::len).sum();
61
62 let b0 = Instant::now();
63 let corpus = Corpus::from_rows(&rows);
64 let build_ms = b0.elapsed().as_secs_f64() * 1000.0;
65
66 #[cfg(feature = "profiling")]
67 if std::env::var("STATS").is_ok() {
68 let (ncand, survivors, pairs) = cosine_join_counts(&corpus, t);
69 eprintln!(
70 "STATS n={n} t={t} | candidates={ncand} survivors(cos_full)={survivors} pairs={pairs} \
71 | prune_pass={:.4} survivor_precision={:.3}",
72 survivors as f64 / ncand.max(1) as f64,
73 pairs as f64 / survivors.max(1) as f64,
74 );
75 }
76
77 let mut ms: Vec<f64> = Vec::with_capacity(reps);
78 let mut npairs = 0usize;
79 for _ in 0..reps {
80 let t0 = Instant::now();
81 let pairs = cosine_join(&corpus, t);
82 ms.push(t0.elapsed().as_secs_f64() * 1000.0);
83 npairs = pairs.len();
84 std::hint::black_box(&pairs);
85 }
86 ms.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
87 eprintln!(
88 "pypi-type3 n={n} nnz_total={nnz} mean_nnz={:.1} t={t} | build={build_ms:.0}ms | \
89 join: min={:.1}ms median={:.1}ms | pairs={npairs}",
90 nnz as f64 / n as f64,
91 ms[0],
92 ms[reps / 2],
93 );
94}More examples
74fn main() {
75 let n: usize = arg(1, 100_000);
76 let nnz: usize = arg(2, 14);
77 let ndims: usize = arg(3, 20_000);
78 let t: f64 = arg(4, 0.7);
79 let reps: usize = arg(5, 3);
80
81 let rows = gen(n, nnz, ndims, 0x1234_5678_9abc_def1);
82 let build0 = Instant::now();
83 let corpus = Corpus::from_rows(&rows);
84 let build_ms = build0.elapsed().as_secs_f64() * 1000.0;
85
86 // Strategy diagnostic (profiling builds only): posting touches / candidates / pairs. The
87 // candidates-per-pair ratio decides whether to prune harder or speed the dot up.
88 #[cfg(feature = "profiling")]
89 if std::env::var("STATS").is_ok() {
90 let (ncand, survivors, pairs) = cosine_join_counts(&corpus, t);
91 eprintln!(
92 "STATS n={n} t={t} | candidates={ncand} survivors(cos_full)={survivors} pairs={pairs} \
93 | prune_pass={:.4} cos_full_saved={:.4} survivor_precision={:.3}",
94 survivors as f64 / ncand.max(1) as f64,
95 1.0 - survivors as f64 / ncand.max(1) as f64,
96 pairs as f64 / survivors.max(1) as f64,
97 );
98 }
99
100 let mut ms: Vec<f64> = Vec::with_capacity(reps);
101 let mut npairs = 0usize;
102 for _ in 0..reps {
103 let t0 = Instant::now();
104 let pairs = cosine_join(&corpus, t);
105 ms.push(t0.elapsed().as_secs_f64() * 1000.0);
106 npairs = pairs.len();
107 std::hint::black_box(&pairs);
108 }
109 ms.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
110 eprintln!(
111 "n={n} nnz={nnz} ndims={ndims} t={t} | build={build_ms:.0}ms | join: min={:.1}ms median={:.1}ms | pairs={npairs}",
112 ms[0],
113 ms[reps / 2],
114 );
115}