use ndarray::{Array1, Array2, ArrayView1};
use rand::SeedableRng;
use rand::rngs::StdRng;
use std::collections::{HashMap, HashSet};
/// Salt XOR-ed into `IndexConfig::seed` before per-table hyperplane seeds are derived.
const INDEX_HYPERPLANE_SALT: u64 = 0x9E37_79B9_7F4A_7C15;
/// Salt XOR-ed into the caller's seed for the shared sketch projection matrix.
const SKETCH_PROJECTION_SALT: u64 = 0xC2B2_AE3D_27D4_EB4F;
/// Threshold under which a norm is treated as zero when normalising, building frames, or
/// scoring alignment.
const DIRECTION_NORM_FLOOR: f64 = 1e-12;
/// Smallest value [`auto_candidate_budget`] will return.
pub const CANDIDATE_BUDGET_MIN: usize = 32;
/// Largest value [`auto_candidate_budget`] will return.
pub const CANDIDATE_BUDGET_MAX: usize = 128;

/// Default candidate budget for a dictionary of `num_atoms` atoms.
///
/// Computes `8 * ceil(log2(num_atoms))` (a dictionary of 0 or 1 atoms counts as one bit),
/// clamped to `CANDIDATE_BUDGET_MIN..=CANDIDATE_BUDGET_MAX`. The result is monotone
/// non-decreasing in `num_atoms`.
pub fn auto_candidate_budget(num_atoms: usize) -> usize {
    let ceil_log2 = match num_atoms {
        0 | 1 => 1,
        k => (k - 1).ilog2() as usize + 1,
    };
    (ceil_log2 * 8).clamp(CANDIDATE_BUDGET_MIN, CANDIDATE_BUDGET_MAX)
}
/// Read-only access to a dictionary of atoms and a fixed-width "sketch" space used for
/// hashing them.
pub trait AtomFrameSketch {
    /// Width of every sketch vector this implementation produces.
    fn sketch_dim(&self) -> usize;
    /// Dimension of the output space that query directions live in.
    fn output_dim(&self) -> usize;
    /// Number of atoms; atom ids are `0..num_atoms()`.
    fn num_atoms(&self) -> usize;
    /// Sketch-space representative of one atom, hashed by `SaeCandidateIndex::build`.
    fn atom_sketch(&self, atom_id: usize) -> Array1<f64>;
    /// Sketch-space image of `direction` as seen through one atom.
    fn project_direction(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64>;
    /// Score of how well the atom explains `direction`; `propose` ranks candidates by
    /// this value, highest first.
    fn alignment(&self, atom_id: usize, direction: ArrayView1<f64>) -> f64;
    /// Sketch used to probe the index for `direction`.
    ///
    /// The default averages `project_direction` over atoms `0, step, 2*step, ...` where
    /// `step = max(1, K / clamp(ceil(sqrt(K)), 1, K))`, skips projections whose length is
    /// not `sketch_dim`, and normalises the mean to unit length when its norm exceeds
    /// `DIRECTION_NORM_FLOOR`. An empty dictionary yields the zero vector. Implementations
    /// that can sketch a direction without touching atoms should override this.
    fn query_sketch(&self, direction: ArrayView1<f64>) -> Array1<f64> {
        let width = self.sketch_dim();
        let total = self.num_atoms();
        if total == 0 {
            return Array1::<f64>::zeros(width);
        }
        let target_samples = ((total as f64).sqrt().ceil() as usize).clamp(1, total);
        let step = (total / target_samples).max(1);
        let mut mean = Array1::<f64>::zeros(width);
        let mut used = 0usize;
        for atom_id in (0..total).step_by(step) {
            let projected = self.project_direction(atom_id, direction);
            if projected.len() != width {
                continue;
            }
            mean.iter_mut()
                .zip(projected.iter())
                .for_each(|(slot, &value)| *slot += value);
            used += 1;
        }
        if used > 0 {
            let denom = used as f64;
            mean.iter_mut().for_each(|slot| *slot /= denom);
        }
        normalize_in_place(&mut mean);
        mean
    }
}
/// [`AtomFrameSketch`] backed by an explicit orthonormal frame per atom and one shared
/// Gaussian projection from output space into sketch space.
pub struct RandomProjectionFrameSketch {
    /// Output-space dimension shared by every frame.
    output_dim: usize,
    /// Sketch width; equals the number of rows of `projection`.
    sketch_dim: usize,
    /// Gaussian matrix of shape `(sketch_dim, output_dim)`.
    projection: Array2<f64>,
    /// Per-atom orthonormal basis, stored as columns of an `(output_dim, rank)` matrix.
    frames: Vec<Array2<f64>>,
}
impl RandomProjectionFrameSketch {
    /// Builds a sketch from one decoder block per atom.
    ///
    /// Every block must have the same, non-zero number of rows (the output dimension). Its
    /// columns are orthonormalised into the atom's frame. A single Gaussian matrix of
    /// shape `(sketch_dim, output_dim)`, seeded from `seed ^ SKETCH_PROJECTION_SALT`, maps
    /// output vectors into sketch space.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `decoder_blocks` is empty, `sketch_dim` is zero, the first block
    /// has no rows, or any block's row count differs from the first block's.
    pub fn from_decoder_blocks(
        decoder_blocks: &[Array2<f64>],
        sketch_dim: usize,
        seed: u64,
    ) -> Result<Self, String> {
        let Some(first) = decoder_blocks.first() else {
            return Err("RandomProjectionFrameSketch: need at least one decoder block".into());
        };
        if sketch_dim == 0 {
            return Err("RandomProjectionFrameSketch: sketch_dim must be positive".into());
        }
        let output_dim = first.nrows();
        if output_dim == 0 {
            return Err("RandomProjectionFrameSketch: output dimension must be positive".into());
        }
        let mismatch = decoder_blocks
            .iter()
            .enumerate()
            .find(|(_, block)| block.nrows() != output_dim);
        if let Some((k, block)) = mismatch {
            return Err(format!(
                "RandomProjectionFrameSketch: atom {k} has {} output rows, expected {output_dim}",
                block.nrows()
            ));
        }
        let frames = decoder_blocks
            .iter()
            .map(orthonormal_frame)
            .collect::<Vec<Array2<f64>>>();
        let projection_seed = seed ^ SKETCH_PROJECTION_SALT;
        let projection = gaussian_projection(sketch_dim, output_dim, projection_seed);
        Ok(Self {
            output_dim,
            sketch_dim,
            projection,
            frames,
        })
    }

    /// Projection of `direction` onto the span of atom `atom_id`'s (orthonormal) frame,
    /// returned in output space as a vector of length `output_dim`.
    fn in_range_component(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64> {
        let basis = &self.frames[atom_id];
        let mut projected = Array1::<f64>::zeros(self.output_dim);
        for j in 0..basis.ncols() {
            let axis = basis.column(j);
            let weight: f64 = axis
                .iter()
                .zip(direction.iter())
                .map(|(&a, &b)| a * b)
                .sum();
            projected
                .iter_mut()
                .zip(axis.iter())
                .for_each(|(slot, &a)| *slot += weight * a);
        }
        projected
    }
}
impl AtomFrameSketch for RandomProjectionFrameSketch {
    fn sketch_dim(&self) -> usize {
        self.sketch_dim
    }

    fn output_dim(&self) -> usize {
        self.output_dim
    }

    fn num_atoms(&self) -> usize {
        self.frames.len()
    }

    /// Normalised projection of the atom's first frame vector. An atom whose frame is
    /// empty falls back to the normalised first column of the projection matrix.
    fn atom_sketch(&self, atom_id: usize) -> Array1<f64> {
        let frame = &self.frames[atom_id];
        let mut sketch = if frame.ncols() > 0 {
            mat_vec(&self.projection, frame.column(0))
        } else {
            self.projection.column(0).to_owned()
        };
        normalize_in_place(&mut sketch);
        sketch
    }

    /// Projects the in-frame component of `direction` into sketch space (not normalised).
    fn project_direction(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64> {
        let in_frame = self.in_range_component(atom_id, direction);
        mat_vec(&self.projection, in_frame.view())
    }

    /// Projects `direction` itself with the shared matrix and normalises it, without
    /// touching any atom.
    fn query_sketch(&self, direction: ArrayView1<f64>) -> Array1<f64> {
        let mut sketch = mat_vec(&self.projection, direction);
        normalize_in_place(&mut sketch);
        sketch
    }

    /// Fraction of `direction`'s norm lying in the atom's frame, clamped to `[0, 1]`.
    /// Directions with norm below `DIRECTION_NORM_FLOOR` score `0.0`.
    fn alignment(&self, atom_id: usize, direction: ArrayView1<f64>) -> f64 {
        let direction_norm = vec_norm(direction);
        if direction_norm < DIRECTION_NORM_FLOOR {
            return 0.0;
        }
        let in_frame = self.in_range_component(atom_id, direction);
        let ratio = vec_norm(in_frame.view()) / direction_norm;
        ratio.clamp(0.0, 1.0)
    }
}
/// Multi-table sign-random-projection hash index over atom sketches.
///
/// Table `t` maps a `u64` signature (one bit per row of `hyperplanes[t]`) to the ids of the
/// atoms whose sketch produced that signature.
pub struct SaeCandidateIndex {
    /// Number of atoms hashed at build time.
    num_atoms: usize,
    /// Width of the sketches the hyperplanes act on.
    sketch_dim: usize,
    /// One `(bits_per_table, sketch_dim)` Gaussian hyperplane bank per table.
    hyperplanes: Vec<Array2<f64>>,
    /// One bucket map per table, keyed by sign signature.
    tables: Vec<HashMap<u64, Vec<usize>>>,
}
/// Shape and seeding parameters for [`SaeCandidateIndex::build`].
#[derive(Clone, Copy, Debug)]
pub struct IndexConfig {
    pub num_tables: usize,
    pub bits_per_table: usize,
    pub multiprobe: bool,
    pub seed: u64,
}
impl IndexConfig {
    /// Heuristic configuration for `num_atoms` atoms sketched into `sketch_dim` dimensions.
    ///
    /// With `b = ceil(log2(max(num_atoms, 2)))`, uses `b` bits per table (clamped to
    /// `1..=max(sketch_dim, 1)`) and `b` tables (clamped to `4..=16`), with multiprobe
    /// enabled.
    pub fn auto(sketch_dim: usize, num_atoms: usize, seed: u64) -> Self {
        // `k >= 2`, so `k - 1 >= 1` and `ilog2` is defined; the result is always >= 1.
        let k = num_atoms.max(2);
        let ceil_log2 = (k - 1).ilog2() as usize + 1;
        Self {
            num_tables: ceil_log2.clamp(4, 16),
            bits_per_table: ceil_log2.clamp(1, sketch_dim.max(1)),
            multiprobe: true,
            seed,
        }
    }
}
impl SaeCandidateIndex {
    /// Builds the index by hashing every atom's sketch into `config.num_tables` tables.
    ///
    /// Table `t` gets its own Gaussian hyperplane bank of shape
    /// `(config.bits_per_table, sketch_dim)`, seeded from
    /// `mix_seed(config.seed ^ INDEX_HYPERPLANE_SALT, t)`, so a fixed seed reproduces the
    /// same index. Atom ids are pushed into their buckets in increasing order.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `sketch.sketch_dim()` is zero, when `num_tables` or
    /// `bits_per_table` is zero, when `bits_per_table` exceeds 64 (signatures are packed
    /// into a `u64`), or when an atom sketch's length differs from `sketch_dim`.
    pub fn build<S: AtomFrameSketch>(sketch: &S, config: IndexConfig) -> Result<Self, String> {
        // Signatures set bit `r` with `1u64 << r`, so a bank may have at most 64 rows.
        const MAX_SIGNATURE_BITS: usize = u64::BITS as usize;
        let sketch_dim = sketch.sketch_dim();
        if sketch_dim == 0 {
            return Err("SaeCandidateIndex: sketch_dim must be positive".into());
        }
        if config.num_tables == 0 || config.bits_per_table == 0 {
            return Err("SaeCandidateIndex: num_tables and bits_per_table must be positive".into());
        }
        // Without this check a wider bank overflows the shift: a panic in debug builds, and
        // in release builds the shift amount wraps so distinct hyperplanes alias one bit.
        if config.bits_per_table > MAX_SIGNATURE_BITS {
            return Err(format!(
                "SaeCandidateIndex: bits_per_table {} exceeds the {MAX_SIGNATURE_BITS}-bit signature width",
                config.bits_per_table
            ));
        }
        let num_atoms = sketch.num_atoms();
        let hyperplanes: Vec<Array2<f64>> = (0..config.num_tables)
            .map(|t| {
                let table_seed = mix_seed(config.seed ^ INDEX_HYPERPLANE_SALT, t as u64);
                gaussian_projection(config.bits_per_table, sketch_dim, table_seed)
            })
            .collect();
        let mut tables: Vec<HashMap<u64, Vec<usize>>> =
            (0..config.num_tables).map(|_| HashMap::new()).collect();
        for atom_id in 0..num_atoms {
            let s = sketch.atom_sketch(atom_id);
            if s.len() != sketch_dim {
                return Err(format!(
                    "SaeCandidateIndex: atom {atom_id} sketch length {} != sketch_dim {sketch_dim}",
                    s.len()
                ));
            }
            for (table, bank) in tables.iter_mut().zip(hyperplanes.iter()) {
                let sig = sign_signature(bank, s.view());
                table.entry(sig).or_default().push(atom_id);
            }
        }
        Ok(Self {
            hyperplanes,
            tables,
            sketch_dim,
            num_atoms,
        })
    }

    /// Number of atoms hashed at build time.
    pub fn num_atoms(&self) -> usize {
        self.num_atoms
    }

    /// Returns the sorted, de-duplicated ids of every atom that shares a bucket with
    /// `query_sketch` in any table.
    ///
    /// With `multiprobe`, each table is also probed at the neighbouring bucket obtained by
    /// flipping the single bit whose hyperplane margin has the smallest magnitude.
    /// `query_sketch` is expected to have length `sketch_dim`; if it does not, the dot
    /// products are silently taken over the shorter of the two lengths.
    pub fn gather_candidates(&self, query_sketch: ArrayView1<f64>, multiprobe: bool) -> Vec<usize> {
        let mut seen: HashSet<usize> = HashSet::new();
        for (table, bank) in self.tables.iter().zip(self.hyperplanes.iter()) {
            let (sig, margins) = sign_signature_with_margins(bank, query_sketch);
            if let Some(ids) = table.get(&sig) {
                seen.extend(ids.iter().copied());
            }
            if multiprobe {
                let flip_bit = lowest_margin_bit(&margins);
                let neighbour = sig ^ (1u64 << flip_bit);
                if let Some(ids) = table.get(&neighbour) {
                    seen.extend(ids.iter().copied());
                }
            }
        }
        let mut out: Vec<usize> = seen.into_iter().collect();
        out.sort_unstable();
        out
    }

    /// Proposes at most `candidate_budget` atoms for `direction`.
    ///
    /// The direction is sketched with `sketch.query_sketch` and the index is probed. If the
    /// query sketch has the wrong width, nothing is gathered. Every gathered atom is scored
    /// with `sketch.alignment` and ranked highest score first, ties broken by ascending id;
    /// NaN scores rank below every real score. The first `candidate_budget` are proposed
    /// and the remainder are returned in `dropped_for_budget`.
    pub fn propose<S: AtomFrameSketch>(
        &self,
        sketch: &S,
        direction: ArrayView1<f64>,
        candidate_budget: usize,
        config_multiprobe: bool,
    ) -> Proposal {
        let query_sketch = sketch.query_sketch(direction);
        let gathered = if query_sketch.len() == self.sketch_dim {
            self.gather_candidates(query_sketch.view(), config_multiprobe)
        } else {
            Vec::new()
        };
        let mut scored: Vec<(usize, f64)> = gathered
            .iter()
            .map(|&id| (id, sketch.alignment(id, direction)))
            .collect();
        // Map NaN to -inf so the comparator is a genuine total order. Treating NaN as
        // "equal to everything" is intransitive once real scores are mixed in, which lets
        // `sort_by` panic or return an arbitrary order.
        let rank_key = |score: f64| -> f64 {
            if score.is_nan() {
                f64::NEG_INFINITY
            } else {
                score
            }
        };
        scored.sort_by(|a, b| {
            rank_key(b.1)
                .partial_cmp(&rank_key(a.1))
                .unwrap_or(std::cmp::Ordering::Equal)
                .then(a.0.cmp(&b.0))
        });
        let keep = candidate_budget.min(scored.len());
        let proposed: Vec<usize> = scored[..keep].iter().map(|&(id, _)| id).collect();
        let dropped_for_budget: Vec<usize> = scored[keep..].iter().map(|&(id, _)| id).collect();
        Proposal {
            proposed,
            dropped_for_budget,
            gathered_count: gathered.len(),
        }
    }

    /// Measures how many planted atoms `propose` recovers over a batch of rows.
    ///
    /// Each row pairs a direction with the atom ids planted in it. Every planted atom
    /// missing from the proposal is logged as a [`RecallMiss`]. It is classified as
    /// `TruncatedByBudget` if it was gathered but ranked below the budget, and as
    /// `NotGathered` otherwise. Recall is `1.0` when no atoms are planted.
    pub fn recall_report<S: AtomFrameSketch>(
        &self,
        sketch: &S,
        rows: &[(Array1<f64>, Vec<usize>)],
        candidate_budget: usize,
        multiprobe: bool,
    ) -> RecallReport {
        let mut total_planted: usize = 0;
        let mut total_recovered: usize = 0;
        let mut misses: Vec<RecallMiss> = Vec::new();
        let mut total_gathered: usize = 0;
        for (row_idx, (direction, planted)) in rows.iter().enumerate() {
            let proposal = self.propose(sketch, direction.view(), candidate_budget, multiprobe);
            total_gathered += proposal.gathered_count;
            let proposed_set: HashSet<usize> = proposal.proposed.iter().copied().collect();
            let dropped_set: HashSet<usize> = proposal.dropped_for_budget.iter().copied().collect();
            for &atom in planted {
                total_planted += 1;
                if proposed_set.contains(&atom) {
                    total_recovered += 1;
                } else {
                    let reason = if dropped_set.contains(&atom) {
                        MissReason::TruncatedByBudget
                    } else {
                        MissReason::NotGathered
                    };
                    misses.push(RecallMiss {
                        row: row_idx,
                        atom,
                        alignment: sketch.alignment(atom, direction.view()),
                        reason,
                    });
                }
            }
        }
        let recall = if total_planted == 0 {
            1.0
        } else {
            total_recovered as f64 / total_planted as f64
        };
        let avg_gathered = if rows.is_empty() {
            0.0
        } else {
            total_gathered as f64 / rows.len() as f64
        };
        RecallReport {
            candidate_budget,
            num_rows: rows.len(),
            total_planted,
            total_recovered,
            recall,
            avg_candidates_gathered: avg_gathered,
            num_atoms: self.num_atoms,
            misses,
        }
    }
}
/// Result of `SaeCandidateIndex::propose` for one query direction.
#[derive(Clone, Debug)]
pub struct Proposal {
    /// Atom ids kept within the candidate budget, highest alignment first (ties by id).
    pub proposed: Vec<usize>,
    /// Gathered atom ids that ranked below the budget cut, in the same ranking order.
    pub dropped_for_budget: Vec<usize>,
    /// Number of distinct atoms the hash probe gathered before ranking.
    pub gathered_count: usize,
}
/// Why a planted atom was absent from the proposal in a recall report.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MissReason {
    /// The hash probe never gathered the atom, so it was never scored.
    NotGathered,
    /// The atom was gathered and scored but ranked below the candidate budget.
    TruncatedByBudget,
}
/// One planted atom that `SaeCandidateIndex::recall_report` failed to recover.
#[derive(Clone, Copy, Debug)]
pub struct RecallMiss {
    /// Index of the row in the `rows` slice passed to `recall_report`.
    pub row: usize,
    /// The planted atom id that was missed.
    pub atom: usize,
    /// `alignment(atom, direction)` for that row's direction, to help diagnose the miss.
    pub alignment: f64,
    /// Whether the atom was never gathered or was cut by the budget.
    pub reason: MissReason,
}
/// Aggregate statistics produced by `SaeCandidateIndex::recall_report`.
#[derive(Clone, Debug)]
pub struct RecallReport {
    /// Budget passed to `propose` for every row.
    pub candidate_budget: usize,
    /// Number of rows evaluated.
    pub num_rows: usize,
    /// Total planted atoms across all rows (a repeated id is counted each time).
    pub total_planted: usize,
    /// Planted atoms that appeared in their row's proposal.
    pub total_recovered: usize,
    /// `total_recovered / total_planted`, or `1.0` when nothing was planted.
    pub recall: f64,
    /// Mean `Proposal::gathered_count` over rows (`0.0` when there are no rows).
    pub avg_candidates_gathered: f64,
    /// Number of atoms in the index that produced the report.
    pub num_atoms: usize,
    /// One entry per planted atom that was not recovered.
    pub misses: Vec<RecallMiss>,
}
impl RecallReport {
    /// Average gathered candidates per row as a fraction of the dictionary size
    /// (`avg_candidates_gathered / num_atoms`). An empty dictionary reports `0.0`.
    pub fn sublinearity_ratio(&self) -> f64 {
        match self.num_atoms {
            0 => 0.0,
            atoms => self.avg_candidates_gathered / atoms as f64,
        }
    }
}
/// Derives a well-scrambled 64-bit seed from `base` and an index.
///
/// Offsets `base` by `idx` golden-ratio increments (wrapping), then applies the
/// SplitMix64 finaliser (xor-shift / multiply rounds).
#[inline]
fn mix_seed(base: u64, idx: u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_1: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_2: u64 = 0x94D0_49BB_1331_11EB;
    let state = idx.wrapping_mul(GOLDEN_GAMMA).wrapping_add(base);
    let state = (state ^ (state >> 30)).wrapping_mul(MIX_1);
    let state = (state ^ (state >> 27)).wrapping_mul(MIX_2);
    state ^ (state >> 31)
}
/// Deterministic `rows x cols` matrix of standard-normal draws, seeded by `seed`.
///
/// Each entry uses one Box–Muller cosine sample from two uniforms drawn in row-major
/// order. The first uniform is floored at `1e-16` so the logarithm stays finite.
fn gaussian_projection(rows: usize, cols: usize, seed: u64) -> Array2<f64> {
    use rand::RngExt as _;
    let mut rng = StdRng::seed_from_u64(seed);
    let mut matrix = Array2::<f64>::zeros((rows, cols));
    // `iter_mut` walks a freshly allocated (standard-layout) array in logical row-major
    // order, so the RNG draws land in the same entries as a nested row/column loop.
    for entry in matrix.iter_mut() {
        let radius_draw = rng.random::<f64>().max(1e-16);
        let angle_draw = rng.random::<f64>();
        let radius = (-2.0 * radius_draw.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * angle_draw;
        *entry = radius * angle.cos();
    }
    matrix
}
/// Orthonormal basis, stored as columns, for the column span of `block`.
///
/// Columns are processed left to right with modified Gram–Schmidt, applied twice per
/// column so the output stays orthogonal to working precision even when the block's
/// columns are nearly collinear. A column is kept only when its residual norm exceeds
/// both the absolute floor `DIRECTION_NORM_FLOOR` and `RANK_REL_TOL` times the column's
/// original norm. Otherwise it is treated as dependent on earlier columns and dropped.
/// The result has shape `(block.nrows(), r)` with `r <= block.ncols()`.
fn orthonormal_frame(block: &Array2<f64>) -> Array2<f64> {
    // Residuals below this fraction of the original column norm are round-off, not a new
    // direction. An absolute floor alone misclassifies large-magnitude columns. A column of
    // norm ~1e6 lying in the span of earlier columns leaves round-off on the order of 1e-10,
    // which passes a 1e-12 absolute test and would inject a spurious basis vector.
    const RANK_REL_TOL: f64 = 1e-10;
    // "Twice is enough": a second projection sweep restores orthogonality lost to
    // cancellation in the first.
    const ORTHO_PASSES: usize = 2;
    let p = block.nrows();
    let mut basis: Vec<Array1<f64>> = Vec::with_capacity(block.ncols());
    for j in 0..block.ncols() {
        let mut v = block.column(j).to_owned();
        let original_norm = vec_norm(v.view());
        for _ in 0..ORTHO_PASSES {
            for q in &basis {
                let proj: f64 = q.iter().zip(v.iter()).map(|(&a, &b)| a * b).sum();
                for (vi, &qi) in v.iter_mut().zip(q.iter()) {
                    *vi -= proj * qi;
                }
            }
        }
        let residual = vec_norm(v.view());
        // `f64::max` ignores NaN, so a NaN column falls back to the absolute floor and is
        // then rejected by the comparison below.
        let tol = DIRECTION_NORM_FLOOR.max(RANK_REL_TOL * original_norm);
        if residual > tol {
            v.mapv_inplace(|x| x / residual);
            basis.push(v);
        }
    }
    let mut u = Array2::<f64>::zeros((p, basis.len()));
    for (j, col) in basis.iter().enumerate() {
        u.column_mut(j).assign(col);
    }
    u
}
fn mat_vec(m: &Array2<f64>, v: ArrayView1<f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(m.nrows());
for r in 0..m.nrows() {
let row = m.row(r);
out[r] = row.iter().zip(v.iter()).map(|(&a, &b)| a * b).sum();
}
out
}
/// Euclidean (L2) norm of `v`.
#[inline]
fn vec_norm(v: ArrayView1<f64>) -> f64 {
    let sum_of_squares: f64 = v.iter().copied().map(|x| x * x).sum();
    sum_of_squares.sqrt()
}
/// Scales `v` to unit length in place. Vectors whose norm does not exceed
/// `DIRECTION_NORM_FLOOR` (including NaN norms) are left unchanged.
#[inline]
fn normalize_in_place(v: &mut Array1<f64>) {
    let norm = vec_norm(v.view());
    if norm > DIRECTION_NORM_FLOOR {
        v.mapv_inplace(|x| x / norm);
    }
}
/// Packs the signs of `bank · s` into a bit signature.
///
/// Bit `r` is set when row `r`'s dot product with `s` is `>= 0.0`. A NaN dot product
/// leaves the bit clear.
fn sign_signature(bank: &Array2<f64>, s: ArrayView1<f64>) -> u64 {
    (0..bank.nrows())
        .filter(|&r| {
            let dot: f64 = bank.row(r).iter().zip(s.iter()).map(|(&a, &b)| a * b).sum();
            dot >= 0.0
        })
        .fold(0u64, |sig, r| sig | (1u64 << r))
}
/// Same signature as `sign_signature`, plus each hyperplane's raw dot product with `s`
/// ("margin"), listed in row order.
fn sign_signature_with_margins(bank: &Array2<f64>, s: ArrayView1<f64>) -> (u64, Vec<f64>) {
    let margins: Vec<f64> = (0..bank.nrows())
        .map(|r| {
            bank.row(r)
                .iter()
                .zip(s.iter())
                .map(|(&a, &b)| a * b)
                .sum::<f64>()
        })
        .collect();
    let sig = margins
        .iter()
        .enumerate()
        .filter(|&(_, &m)| m >= 0.0)
        .fold(0u64, |acc, (r, _)| acc | (1u64 << r));
    (sig, margins)
}
/// Index of the margin with the smallest absolute value.
///
/// The first occurrence wins ties. NaN margins are never selected unless every margin is
/// NaN or infinite, and an empty slice returns `0`.
fn lowest_margin_bit(margins: &[f64]) -> usize {
    let (winner, _) = margins
        .iter()
        .map(|m| m.abs())
        .enumerate()
        .fold((0usize, f64::INFINITY), |(best_idx, best_abs), (idx, magnitude)| {
            if magnitude < best_abs {
                (idx, magnitude)
            } else {
                (best_idx, best_abs)
            }
        });
    winner
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::RngExt as _;
    use rand::rngs::StdRng;

    /// Draws a Box–Muller Gaussian vector of length `p` and scales it to unit norm. It is
    /// left unscaled if its norm does not exceed `DIRECTION_NORM_FLOOR`.
    fn unit_vec(rng: &mut StdRng, p: usize) -> Array1<f64> {
        let mut v = Array1::<f64>::zeros(p);
        for x in v.iter_mut() {
            let u1 = rng.random::<f64>().max(1e-16);
            let u2 = rng.random::<f64>();
            *x = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
        }
        let n = vec_norm(v.view());
        if n > DIRECTION_NORM_FLOOR {
            for x in v.iter_mut() {
                *x /= n;
            }
        }
        v
    }

    /// `k` random rank-1 atoms in `p` dimensions. Returns the `(p, 1)` decoder blocks and
    /// the matching unit directions.
    fn synthetic_dictionary(k: usize, p: usize, seed: u64) -> (Vec<Array2<f64>>, Vec<Array1<f64>>) {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut blocks = Vec::with_capacity(k);
        let mut dirs = Vec::with_capacity(k);
        for _ in 0..k {
            let c = unit_vec(&mut rng, p);
            let mut block = Array2::<f64>::zeros((p, 1));
            block.column_mut(0).assign(&c);
            blocks.push(block);
            dirs.push(c);
        }
        (blocks, dirs)
    }

    /// A direction equal to an atom's own column must be fully explained by that atom and
    /// less so by an unrelated one.
    #[test]
    fn frame_alignment_is_exact_for_in_range_direction() {
        let (blocks, dirs) = synthetic_dictionary(8, 16, 11);
        let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 12, 7).unwrap();
        let d = &dirs[3];
        let a = sketch.alignment(3, d.view());
        assert!(a > 0.999, "in-range alignment should be ~1, got {a}");
        let a_off = sketch.alignment(3, dirs[5].view());
        assert!(
            a_off < a,
            "off-atom alignment {a_off} should be below in-range {a}"
        );
    }

    /// Rebuilding from the same blocks and seed must reproduce identical sketches and an
    /// identical index, bucket for bucket.
    #[test]
    fn build_is_deterministic_for_a_fixed_seed() {
        let (blocks, _) = synthetic_dictionary(64, 24, 99);
        let s1 = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 16, 5).unwrap();
        let s2 = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 16, 5).unwrap();
        for i in 0..blocks.len() {
            let a = s1.atom_sketch(i);
            let b = s2.atom_sketch(i);
            let diff = vec_norm((&a - &b).view());
            assert!(
                diff < 1e-12,
                "atom {i} sketch differs across builds: {diff:e}"
            );
        }
        let cfg = IndexConfig::auto(16, blocks.len(), 5);
        let idx1 = SaeCandidateIndex::build(&s1, cfg).unwrap();
        let idx2 = SaeCandidateIndex::build(&s2, cfg).unwrap();
        // Compare bucket contents, not just bucket counts. Two builds that scatter the same
        // atoms into different buckets would pass a count-only check.
        assert_eq!(idx1.tables.len(), idx2.tables.len());
        for (t, (a, b)) in idx1.tables.iter().zip(idx2.tables.iter()).enumerate() {
            assert_eq!(a, b, "table {t} buckets differ across builds");
        }
    }

    /// With a budget of 32 out of 2000 atoms, the gather must touch under half the
    /// dictionary while still recalling at least 80% of planted atoms.
    #[test]
    fn planted_atoms_are_recalled_above_floor_at_sublinear_budget() {
        let k = 2000usize;
        let p = 48usize;
        let (blocks, dirs) = synthetic_dictionary(k, p, 2026);
        let sketch_dim = 24usize;
        let sketch =
            RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, 4242).unwrap();
        let cfg = IndexConfig::auto(sketch_dim, k, 4242);
        let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();
        let n_rows = 200usize;
        let rows = planted_rows(&dirs, n_rows, 31337);
        let candidate_budget = 32usize;
        let report = index.recall_report(&sketch, &rows, candidate_budget, cfg.multiprobe);
        assert!(
            report.sublinearity_ratio() < 0.5,
            "gather was not sublinear: avg {} of {} atoms (ratio {:.3})",
            report.avg_candidates_gathered,
            report.num_atoms,
            report.sublinearity_ratio()
        );
        let floor = 0.80;
        assert!(
            report.recall >= floor,
            "recall {:.3} below floor {floor}; {} misses logged (first few: {:?})",
            report.recall,
            report.misses.len(),
            report
                .misses
                .iter()
                .take(5)
                .map(|m| (m.row, m.atom, m.reason, m.alignment))
                .collect::<Vec<_>>()
        );
        let recovered = report.total_recovered;
        assert_eq!(
            report.total_planted - recovered,
            report.misses.len(),
            "miss list must account for every unrecovered planted atom"
        );
    }

    /// Pins the budget formula at a few sizes and checks monotonicity and bounds.
    #[test]
    fn auto_candidate_budget_tracks_the_issue_band() {
        assert_eq!(auto_candidate_budget(2), CANDIDATE_BUDGET_MIN);
        assert_eq!(auto_candidate_budget(64), 48);
        assert_eq!(auto_candidate_budget(1024), 80);
        assert_eq!(auto_candidate_budget(100_000), CANDIDATE_BUDGET_MAX);
        let mut prev = 0usize;
        for k in [2usize, 16, 64, 256, 1024, 4096, 65_536, 1_000_000] {
            let c = auto_candidate_budget(k);
            assert!(c >= prev, "budget must be monotone in K");
            assert!((CANDIDATE_BUDGET_MIN..=CANDIDATE_BUDGET_MAX).contains(&c));
            prev = c;
        }
    }

    /// `n_rows` query rows. Each direction is a random atom plus `0.15` times another
    /// random atom, renormalised. Only the first atom is recorded as planted.
    fn planted_rows(
        dirs: &[Array1<f64>],
        n_rows: usize,
        seed: u64,
    ) -> Vec<(Array1<f64>, Vec<usize>)> {
        let k = dirs.len();
        let mut rng = StdRng::seed_from_u64(seed);
        let mut rows = Vec::with_capacity(n_rows);
        for _ in 0..n_rows {
            let primary = rng.random_range(0..k);
            let secondary = rng.random_range(0..k);
            let mut d = dirs[primary].clone();
            for (di, &si) in d.iter_mut().zip(dirs[secondary].iter()) {
                *di += 0.15 * si;
            }
            let n = vec_norm(d.view());
            for di in d.iter_mut() {
                *di /= n;
            }
            rows.push((d, vec![primary]));
        }
        rows
    }

    /// Across K = 64 and K = 1024, the test checks three things. Recall must stay above the
    /// floor. Rebuilds must propose identically. The fraction of the dictionary gathered
    /// must shrink as K grows.
    #[test]
    fn k_ladder_recall_determinism_and_sublinearity() {
        let p = 48usize;
        let n_rows = 150usize;
        let mut ladder_ratios = Vec::new();
        for &k in &[64usize, 1024] {
            let (blocks, dirs) = synthetic_dictionary(k, p, 9000 + k as u64);
            let sketch_dim = 24usize;
            let sketch_seed = 71 + k as u64;
            let sketch =
                RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, sketch_seed)
                    .unwrap();
            let cfg = IndexConfig::auto(sketch_dim, k, sketch_seed);
            let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();
            let rows = planted_rows(&dirs, n_rows, 555 + k as u64);
            let budget = auto_candidate_budget(k);
            let report = index.recall_report(&sketch, &rows, budget, cfg.multiprobe);
            let floor = 0.80;
            assert!(
                report.recall >= floor,
                "K={k}: recall {:.3} below floor {floor}; {} misses (first: {:?})",
                report.recall,
                report.misses.len(),
                report
                    .misses
                    .iter()
                    .take(3)
                    .map(|m| (m.row, m.atom, m.reason, m.alignment))
                    .collect::<Vec<_>>()
            );
            assert_eq!(
                report.total_planted - report.total_recovered,
                report.misses.len(),
                "K={k}: miss list must account for every unrecovered planted atom"
            );
            let sketch2 =
                RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, sketch_seed)
                    .unwrap();
            let index2 = SaeCandidateIndex::build(&sketch2, cfg).unwrap();
            for (direction, _) in rows.iter().take(20) {
                let a = index.propose(&sketch, direction.view(), budget, cfg.multiprobe);
                let b = index2.propose(&sketch2, direction.view(), budget, cfg.multiprobe);
                assert_eq!(
                    a.proposed, b.proposed,
                    "K={k}: rebuild must propose identically"
                );
            }
            for (direction, _) in rows.iter().take(20) {
                let prop = index.propose(&sketch, direction.view(), budget, cfg.multiprobe);
                assert!(prop.proposed.len() <= budget);
            }
            ladder_ratios.push((k, report.sublinearity_ratio()));
        }
        let (_, ratio_small) = ladder_ratios[0];
        let (k_big, ratio_big) = ladder_ratios[1];
        assert!(
            ratio_big < ratio_small,
            "sublinearity must improve along the ladder: {ladder_ratios:?}"
        );
        assert!(
            ratio_big < 0.25,
            "K={k_big}: gather touched {:.1}% of the dictionary",
            ratio_big * 100.0
        );
    }

    /// Wrapper that counts `project_direction` calls and forwards everything else,
    /// including `query_sketch`, to the wrapped sketch.
    struct CountingSketch<'a> {
        inner: &'a RandomProjectionFrameSketch,
        project_calls: std::cell::Cell<usize>,
    }
    impl AtomFrameSketch for CountingSketch<'_> {
        fn sketch_dim(&self) -> usize {
            self.inner.sketch_dim()
        }
        fn output_dim(&self) -> usize {
            self.inner.output_dim()
        }
        fn num_atoms(&self) -> usize {
            self.inner.num_atoms()
        }
        fn atom_sketch(&self, atom_id: usize) -> Array1<f64> {
            self.inner.atom_sketch(atom_id)
        }
        fn project_direction(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64> {
            self.project_calls.set(self.project_calls.get() + 1);
            self.inner.project_direction(atom_id, direction)
        }
        fn alignment(&self, atom_id: usize, direction: ArrayView1<f64>) -> f64 {
            self.inner.alignment(atom_id, direction)
        }
        fn query_sketch(&self, direction: ArrayView1<f64>) -> Array1<f64> {
            self.inner.query_sketch(direction)
        }
    }

    /// `propose` must not call `project_direction` on the sketch it is given.
    #[test]
    fn query_probe_touches_no_atom_before_the_gather() {
        let k = 512usize;
        let p = 32usize;
        let (blocks, dirs) = synthetic_dictionary(k, p, 77);
        let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 16, 13).unwrap();
        let cfg = IndexConfig::auto(16, k, 13);
        let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();
        let counting = CountingSketch {
            inner: &sketch,
            project_calls: std::cell::Cell::new(0),
        };
        let _ = index.propose(&counting, dirs[5].view(), 32, cfg.multiprobe);
        assert_eq!(
            counting.project_calls.get(),
            0,
            "the exact query probe must be independent of K: no per-atom \
             projection before the gather (#994)"
        );
    }

    /// `n_clusters` random centres, each with `cluster_size` unit atoms obtained by adding
    /// `spread` times a random unit vector to the centre and renormalising.
    fn coherent_cluster_dictionary(
        n_clusters: usize,
        cluster_size: usize,
        p: usize,
        spread: f64,
        seed: u64,
    ) -> (Vec<Array2<f64>>, Vec<Array1<f64>>) {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut blocks = Vec::with_capacity(n_clusters * cluster_size);
        let mut dirs = Vec::with_capacity(n_clusters * cluster_size);
        for _ in 0..n_clusters {
            let center = unit_vec(&mut rng, p);
            for _ in 0..cluster_size {
                let noise = unit_vec(&mut rng, p);
                let mut c = center.clone();
                for (ci, &ni) in c.iter_mut().zip(noise.iter()) {
                    *ci += spread * ni;
                }
                let n = vec_norm(c.view());
                for ci in c.iter_mut() {
                    *ci /= n;
                }
                let mut block = Array2::<f64>::zeros((p, 1));
                block.column_mut(0).assign(&c);
                blocks.push(block);
                dirs.push(c);
            }
        }
        (blocks, dirs)
    }

    /// Tightly clustered atoms (many near-duplicates per bucket) must still be recalled
    /// above the floor without gathering half the dictionary.
    #[test]
    fn coherent_clusters_are_recalled_with_the_exact_probe() {
        let n_clusters = 32usize;
        let cluster_size = 32usize;
        let k = n_clusters * cluster_size;
        let p = 48usize;
        let (blocks, dirs) = coherent_cluster_dictionary(n_clusters, cluster_size, p, 0.25, 4242);
        let sketch_dim = 24usize;
        let sketch =
            RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, 99).unwrap();
        let cfg = IndexConfig::auto(sketch_dim, k, 99);
        let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();
        let rows = planted_rows(&dirs, 150, 31337);
        let budget = auto_candidate_budget(k);
        let report = index.recall_report(&sketch, &rows, budget, cfg.multiprobe);
        let floor = 0.80;
        assert!(
            report.recall >= floor,
            "coherent-cluster recall {:.3} below floor {floor}; {} misses (first: {:?})",
            report.recall,
            report.misses.len(),
            report
                .misses
                .iter()
                .take(3)
                .map(|m| (m.row, m.atom, m.reason, m.alignment))
                .collect::<Vec<_>>()
        );
        assert!(
            report.sublinearity_ratio() < 0.5,
            "cluster gather touched {:.1}% of the dictionary",
            report.sublinearity_ratio() * 100.0
        );
    }

    /// For a rank-1 atom whose column is `d`, the query sketch of `d` must coincide with
    /// the atom's own sketch.
    #[test]
    fn exact_probe_matches_shared_projection_of_the_direction() {
        let p = 16usize;
        let mut rng = StdRng::seed_from_u64(5);
        let d = unit_vec(&mut rng, p);
        let mut block = Array2::<f64>::zeros((p, 1));
        block.column_mut(0).assign(&d);
        let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&[block], 8, 21).unwrap();
        let via_probe = sketch.query_sketch(d.view());
        let via_atom = sketch.atom_sketch(0);
        let diff = vec_norm((&via_probe - &via_atom).view());
        assert!(
            diff < 1e-10,
            "query_sketch(d) must equal the rank-1 atom representative of d: diff {diff:e}"
        );
    }

    /// Rows with nothing planted must report perfect recall and no misses.
    #[test]
    fn empty_planted_rows_report_perfect_recall() {
        let (blocks, dirs) = synthetic_dictionary(32, 16, 1);
        let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 12, 3).unwrap();
        let cfg = IndexConfig::auto(12, 32, 3);
        let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();
        let rows = vec![(dirs[0].clone(), Vec::<usize>::new())];
        let report = index.recall_report(&sketch, &rows, 8, true);
        assert_eq!(report.recall, 1.0);
        assert!(report.misses.is_empty());
    }
}