gam_sae/candidate_index.rs
1//! Sublinear candidate-atom index for active-set proposal (#985 part 1).
2//!
3//! A frontier SAE dictionary holds `K ≈ 10^4–10^5` atoms. The per-row *local*
4//! block — the small linear/Newton system over the atoms that are actually
5//! active in a row — is cheap, because the active set collapses it to a handful
6//! of atoms. The expensive step is *proposing* that active set: a naive scan
7//! scores every one of the `K` atom frames against every row, which is `O(K)`
8//! per row and dominates the whole solve once `K` is large.
9//!
10//! This module builds a **sublinear** candidate index over per-atom *sketches*
11//! of each atom's decoder column-space (its Grassmann frame `U_k`). Given a row
12//! residual direction it returns the top candidate atom ids likely to be
13//! active, touching only `O(log K)`-ish buckets instead of all `K` atoms.
14//!
15//! ## Layering against Track 1
16//!
17//! Track 1 owns the *real* atom frames `U_k` and has not landed yet, so this
//! module is written against an [`AtomFrameSketch`] trait. Any frame source —
19//! the eventual Grassmann frames, or the decoder column blocks `B_k` already
20//! present on [`crate::manifold::SaeManifoldAtom`] — can implement
21//! it. A concrete, dependency-free default
22//! ([`RandomProjectionFrameSketch`]) is provided: a seeded random-projection /
23//! random-hyperplane signature of the atom's orthonormalized column span. The
24//! index ([`SaeCandidateIndex`]) is a deterministic multi-table
25//! random-hyperplane LSH over those sketches.
26//!
27//! ## Recall contract
28//!
29//! Sublinear proposal is only safe if it *almost never* drops a truly-active
30//! atom. [`SaeCandidateIndex::recall_report`] takes a set of planted
31//! truly-active atoms per row, runs the proposal at a stated candidate budget,
32//! and records the rate at which planted atoms appear in the proposed set —
33//! **logging every miss** rather than silently truncating. The returned
34//! [`RecallReport`] carries `recall@budget` and the full miss list so a caller
35//! can widen the budget or fall back to a dense scan for the affected rows.
36//!
37//! Determinism: every random choice is seeded by an explicit index seed; no
38//! clock, no global RNG.
39
40use ndarray::{Array1, Array2, ArrayView1};
41use rand::SeedableRng;
42use rand::rngs::StdRng;
43use std::collections::{HashMap, HashSet};
44
/// Domain-separation salt for the LSH hyperplane banks. It is XORed into the
/// index seed before the per-table mix, so the tables and the default sketch
/// draw from different random streams even when both receive the same base
/// seed.
const INDEX_HYPERPLANE_SALT: u64 = 0x9E37_79B9_7F4A_7C15;

/// Domain-separation salt XORed into the seed of the default sketch's shared
/// random projection matrix.
const SKETCH_PROJECTION_SALT: u64 = 0xC2B2_AE3D_27D4_EB4F;

/// Numerical floor for Euclidean norms: a direction or column whose norm falls
/// at (or below) this scale is treated as zero.
const DIRECTION_NORM_FLOOR: f64 = 1e-12;
55
/// Smallest per-row candidate budget `C` the auto-derivation will return
/// (#985). A smaller proposal set would leave the solver's accepted active set
/// without headroom over the planted/active atom count.
pub const CANDIDATE_BUDGET_MIN: usize = 32;

/// Largest per-row candidate budget `C` the auto-derivation will return
/// (#985). The per-row local block is meant to stay a small dense solve however
/// large the dictionary grows; past this size the proposal step no longer
/// delivers the bottleneck reduction it exists for.
pub const CANDIDATE_BUDGET_MAX: usize = 128;
66
67/// Auto-derive the per-row candidate budget `C` from the dictionary size `K`
68/// (#985): `C = 8·⌈log₂ K⌉`, clamped to
69/// [[`CANDIDATE_BUDGET_MIN`], [`CANDIDATE_BUDGET_MAX`]]. Logarithmic growth
70/// keeps the per-row local block effectively constant-size while giving larger
71/// dictionaries a little more recall headroom; the clamp realizes the issue's
72/// `C ≈ 32–128` band. Magic-by-default: derived from `K` alone, no flag.
73///
74/// Concretely: `K = 64 → 48`, `K = 1024 → 80`, `K = 10⁵ → 128`.
75pub fn auto_candidate_budget(num_atoms: usize) -> usize {
76 let log2 = if num_atoms <= 1 {
77 1
78 } else {
79 (usize::BITS - (num_atoms - 1).leading_zeros()) as usize
80 };
81 (8 * log2).clamp(CANDIDATE_BUDGET_MIN, CANDIDATE_BUDGET_MAX)
82}
83
84// ---------------------------------------------------------------------------
85// Sketch interface
86// ---------------------------------------------------------------------------
87
88/// A low-dimensional sketch of one atom's decoder column-space (its Grassmann
89/// frame `U_k`).
90///
91/// The index never needs the full frame: it only needs (a) the sketch
92/// dimension, shared by every atom in a dictionary, and (b), for any query
93/// direction in output space, the atom's *sketch coordinates* of that direction
94/// — i.e. the projection of the direction onto the atom's column-space,
95/// expressed in the sketch's coordinates. A frame `U_k` (orthonormal columns
96/// spanning the decoder range) yields these as `sketch = R · (U_kᵀ d)` for a
97/// shared random projection `R`; a raw decoder block `B_k` yields them by first
98/// orthonormalizing its columns. Both are valid implementors.
99pub trait AtomFrameSketch {
100 /// Dimension of the sketch vectors this implementor produces. Must be the
101 /// same positive value for every atom in one dictionary so the index can
102 /// build a single hyperplane bank.
103 fn sketch_dim(&self) -> usize;
104
105 /// Dimension of the ambient output space the query directions live in.
106 fn output_dim(&self) -> usize;
107
108 /// Number of atoms this source can sketch.
109 fn num_atoms(&self) -> usize;
110
111 /// Sketch of atom `atom_id`'s *frame itself* (a representative point of the
112 /// atom's column-space on the sphere of sketch space), used to place the
113 /// atom into the LSH tables at build time. Returns a vector of length
114 /// [`AtomFrameSketch::sketch_dim`].
115 fn atom_sketch(&self, atom_id: usize) -> Array1<f64>;
116
117 /// Sketch of a query *direction* `d` (length [`AtomFrameSketch::output_dim`])
118 /// as seen through atom `atom_id`'s frame: the direction's component inside
119 /// the atom's column-space, mapped into sketch coordinates. Used at query
120 /// time to score how strongly a row residual aligns with the atom.
121 fn project_direction(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64>;
122
123 /// Alignment score in `[0, 1]`: the fraction of the query direction's energy
124 /// that lies inside atom `atom_id`'s column-space. `1.0` means the direction
125 /// lies fully in the atom's range, `0.0` means it is orthogonal. Used to
126 /// rank the (small) candidate set the index returns.
127 fn alignment(&self, atom_id: usize, direction: ArrayView1<f64>) -> f64;
128
129 /// Sketch-space **probe** for a raw query direction (length
130 /// [`AtomFrameSketch::sketch_dim`]), comparable to the
131 /// [`AtomFrameSketch::atom_sketch`] representatives the LSH tables were
132 /// built from (#994).
133 ///
134 /// Implementors must return the exact cosine-LSH probe for their sketching
135 /// policy. For the shared-projection sketch this is `normalize(R · d)`,
136 /// `O(p · s)` per query, touching no atom.
137 fn query_sketch(&self, direction: ArrayView1<f64>) -> Array1<f64>;
138}
139
140// ---------------------------------------------------------------------------
141// Default concrete sketch: seeded random projection of the column span
142// ---------------------------------------------------------------------------
143
144/// A concrete [`AtomFrameSketch`] built from raw decoder column blocks `B_k`.
145///
146/// For each atom it orthonormalizes the decoder columns (modified Gram–Schmidt)
147/// to obtain a frame `U_k` with orthonormal columns spanning the decoder range,
148/// then sketches via a single shared seeded Gaussian random projection
149/// `R ∈ ℝ^{s×p}` applied to the in-range component of a direction:
150///
151/// * `atom_sketch(k) = normalize( R · u_k0 )`, the sketch of the atom's first
152/// (dominant) frame column — a stable representative point used to bucket the
153/// atom.
154/// * `project_direction(k, d) = R · (U_k U_kᵀ d)`, the sketch of the part of `d`
155/// that lies in the atom's range.
156/// * `alignment(k, d) = ‖U_kᵀ d‖ / ‖d‖`, the exact in-range energy fraction.
157///
158/// The shared `R` is a Johnson–Lindenstrauss style random projection, so sketch
159/// inner products approximately preserve angles between in-range directions —
160/// exactly what the LSH index needs. Everything is seeded; the same atoms +
161/// seed always produce the same sketches.
162pub struct RandomProjectionFrameSketch {
163 /// Orthonormal frame `U_k` per atom, shape `(p, r_k)` with `r_k` ≤ columns.
164 frames: Vec<Array2<f64>>,
165 /// Shared random projection `R`, shape `(sketch_dim, p)`.
166 projection: Array2<f64>,
167 /// Ambient output dimension `p`.
168 output_dim: usize,
169 /// Sketch dimension `s`.
170 sketch_dim: usize,
171}
172
173impl RandomProjectionFrameSketch {
174 /// Build the sketch from decoder column blocks.
175 ///
176 /// `decoder_blocks[k]` is `B_k` with shape `(p, m_k)`: `p` rows in output
177 /// space, `m_k` decoder columns for atom `k`. (`SaeManifoldAtom` stores the
178 /// transpose `(m_k, p)`; orient it `p`-rows before passing in.) All blocks
179 /// must share the same `p`. `sketch_dim` is the target sketch length `s`;
180 /// `seed` makes the projection deterministic.
181 pub fn from_decoder_blocks(
182 decoder_blocks: &[Array2<f64>],
183 sketch_dim: usize,
184 seed: u64,
185 ) -> Result<Self, String> {
186 if decoder_blocks.is_empty() {
187 return Err("RandomProjectionFrameSketch: need at least one decoder block".into());
188 }
189 if sketch_dim == 0 {
190 return Err("RandomProjectionFrameSketch: sketch_dim must be positive".into());
191 }
192 let output_dim = decoder_blocks[0].nrows();
193 if output_dim == 0 {
194 return Err("RandomProjectionFrameSketch: output dimension must be positive".into());
195 }
196 for (k, block) in decoder_blocks.iter().enumerate() {
197 if block.nrows() != output_dim {
198 return Err(format!(
199 "RandomProjectionFrameSketch: atom {k} has {} output rows, expected {output_dim}",
200 block.nrows()
201 ));
202 }
203 }
204
205 let frames: Vec<Array2<f64>> = decoder_blocks.iter().map(orthonormal_frame).collect();
206
207 let projection = gaussian_projection(sketch_dim, output_dim, seed ^ SKETCH_PROJECTION_SALT);
208
209 Ok(Self {
210 frames,
211 projection,
212 output_dim,
213 sketch_dim,
214 })
215 }
216
217 /// In-range component `U_k U_kᵀ d` of a direction (length `output_dim`).
218 fn in_range_component(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64> {
219 let frame = &self.frames[atom_id];
220 // coords = U_kᵀ d (length r_k)
221 let mut comp = Array1::<f64>::zeros(self.output_dim);
222 for col in 0..frame.ncols() {
223 let u = frame.column(col);
224 let coord: f64 = u.iter().zip(direction.iter()).map(|(&a, &b)| a * b).sum();
225 for (c, &uval) in comp.iter_mut().zip(u.iter()) {
226 *c += coord * uval;
227 }
228 }
229 comp
230 }
231}
232
233impl AtomFrameSketch for RandomProjectionFrameSketch {
234 fn sketch_dim(&self) -> usize {
235 self.sketch_dim
236 }
237
238 fn output_dim(&self) -> usize {
239 self.output_dim
240 }
241
242 fn num_atoms(&self) -> usize {
243 self.frames.len()
244 }
245
246 fn atom_sketch(&self, atom_id: usize) -> Array1<f64> {
247 let frame = &self.frames[atom_id];
248 // Sketch the dominant (first) frame column as the atom's representative.
249 // If the frame is empty (rank-0 atom), fall back to a deterministic
250 // nonzero point so the atom is still bucketed somewhere.
251 if frame.ncols() == 0 {
252 let mut s = self.projection.column(0).to_owned();
253 normalize_in_place(&mut s);
254 return s;
255 }
256 let u0 = frame.column(0);
257 let mut s = mat_vec(&self.projection, u0);
258 normalize_in_place(&mut s);
259 s
260 }
261
262 fn project_direction(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64> {
263 let comp = self.in_range_component(atom_id, direction);
264 mat_vec(&self.projection, comp.view())
265 }
266
267 /// Exact `O(p·s)` probe (#994): every atom shares the one projection `R`,
268 /// and the table representatives are `normalize(R · u_k0)`, so the correct
269 /// cosine-LSH probe for a direction is simply `normalize(R · d)` — no atom
270 /// is touched, and no masked-average approximation is involved.
271 fn query_sketch(&self, direction: ArrayView1<f64>) -> Array1<f64> {
272 let mut s = mat_vec(&self.projection, direction);
273 normalize_in_place(&mut s);
274 s
275 }
276
277 fn alignment(&self, atom_id: usize, direction: ArrayView1<f64>) -> f64 {
278 let dnorm = vec_norm(direction);
279 if dnorm < DIRECTION_NORM_FLOOR {
280 return 0.0;
281 }
282 let comp = self.in_range_component(atom_id, direction);
283 (vec_norm(comp.view()) / dnorm).clamp(0.0, 1.0)
284 }
285}
286
287// ---------------------------------------------------------------------------
288// Sublinear index: multi-table random-hyperplane LSH over sketches
289// ---------------------------------------------------------------------------
290
291/// A deterministic, sublinear candidate index over atom-frame sketches.
292///
293/// The structure is a **random-hyperplane LSH** with `num_tables` independent
294/// tables, each defined by `bits_per_table` seeded random hyperplanes in sketch
295/// space. An atom's sketch is reduced to a `bits_per_table`-bit sign signature
296/// per table (the sign of its dot with each hyperplane), and the atom id is
297/// stored in the bucket keyed by that signature. At query time the query
298/// direction is sketched *through each atom's frame*; we instead hash the *query
299/// sketch* per table and gather the union of atoms in the matching (and, to
300/// improve recall, the Hamming-1 neighbouring) buckets. Because each table
301/// touches only the atoms colliding in one bucket, total work is sublinear in
302/// `K` for well-spread sketches.
303///
304/// The gathered candidates are then ranked by exact
305/// [`AtomFrameSketch::alignment`] and the top `candidate_budget` are returned.
306/// All hyperplanes are seeded; building twice with the same seed yields byte-
307/// identical tables.
308pub struct SaeCandidateIndex {
309 /// Hyperplane banks, one per table: each `(bits_per_table, sketch_dim)`.
310 hyperplanes: Vec<Array2<f64>>,
311 /// Buckets per table: signature -> atom ids.
312 tables: Vec<HashMap<u64, Vec<usize>>>,
313 /// Sketch dimension shared by every atom.
314 sketch_dim: usize,
315 /// Number of atoms indexed.
316 num_atoms: usize,
317}
318
/// Tuning knobs for [`SaeCandidateIndex::build`]. Every field is explicit, so
/// the index never consults global state or CLI flags.
#[derive(Clone, Copy, Debug)]
pub struct IndexConfig {
    /// How many independent LSH tables to build. More tables raise recall at
    /// the cost of more work per query.
    pub num_tables: usize,
    /// Random hyperplanes per table, i.e. the signature bit-width. More bits
    /// give finer buckets (fewer collisions, lower recall per table).
    pub bits_per_table: usize,
    /// Whether queries should also probe a neighbouring bucket per table
    /// (multi-probe LSH). Cheap and a large recall win; on by default.
    pub multiprobe: bool,
    /// Master seed from which all hyperplane banks are derived.
    pub seed: u64,
}

impl IndexConfig {
    /// Default configuration for sketches of dimension `sketch_dim` over
    /// roughly `num_atoms` atoms.
    ///
    /// With `L = ⌈log₂ max(num_atoms, 2)⌉`:
    ///
    /// * `bits_per_table = L`, capped by the sketch dimension and by 63 (the
    ///   signature is packed into a `u64`), and never below 1 — aiming for
    ///   roughly constant expected bucket occupancy;
    /// * `num_tables = L` clamped to `[4, 16]` — a few tables recover recall
    ///   lost to any single table's quantization;
    /// * `multiprobe = true`.
    ///
    /// Both sizes grow only logarithmically in `num_atoms`, which keeps
    /// queries sublinear.
    pub fn auto(sketch_dim: usize, num_atoms: usize, seed: u64) -> Self {
        // ⌈log₂ n⌉ is the bit length of `n − 1`; `n` is at least 2 here, so
        // the subtraction cannot underflow.
        let atoms = num_atoms.max(2);
        let ceil_log2 = (usize::BITS - (atoms - 1).leading_zeros()) as usize;
        let max_bits = sketch_dim.max(1).min(63);

        Self {
            num_tables: ceil_log2.clamp(4, 16),
            bits_per_table: ceil_log2.clamp(1, max_bits),
            multiprobe: true,
            seed,
        }
    }
}
362
363impl SaeCandidateIndex {
364 /// Build the index over every atom of `sketch`.
365 pub fn build<S: AtomFrameSketch>(sketch: &S, config: IndexConfig) -> Result<Self, String> {
366 let sketch_dim = sketch.sketch_dim();
367 if sketch_dim == 0 {
368 return Err("SaeCandidateIndex: sketch_dim must be positive".into());
369 }
370 if config.num_tables == 0 || config.bits_per_table == 0 {
371 return Err("SaeCandidateIndex: num_tables and bits_per_table must be positive".into());
372 }
373 // sign_signature packs bits into a u64 with `1u64 << r` for r in 0..bits_per_table.
374 // Shifting by 64+ is a panic in debug and undefined behaviour in release; cap at 63.
375 if config.bits_per_table > 63 {
376 return Err(format!(
377 "SaeCandidateIndex: bits_per_table {} exceeds 63 (u64 signature limit)",
378 config.bits_per_table
379 ));
380 }
381 let num_atoms = sketch.num_atoms();
382
383 // One seeded hyperplane bank per table; seed is mixed per-table so the
384 // tables are independent yet fully reproducible.
385 let hyperplanes: Vec<Array2<f64>> = (0..config.num_tables)
386 .map(|t| {
387 let table_seed = mix_seed(config.seed ^ INDEX_HYPERPLANE_SALT, t as u64);
388 gaussian_projection(config.bits_per_table, sketch_dim, table_seed)
389 })
390 .collect();
391
392 let mut tables: Vec<HashMap<u64, Vec<usize>>> =
393 (0..config.num_tables).map(|_| HashMap::new()).collect();
394
395 for atom_id in 0..num_atoms {
396 let s = sketch.atom_sketch(atom_id);
397 if s.len() != sketch_dim {
398 return Err(format!(
399 "SaeCandidateIndex: atom {atom_id} sketch length {} != sketch_dim {sketch_dim}",
400 s.len()
401 ));
402 }
403 for (table, bank) in tables.iter_mut().zip(hyperplanes.iter()) {
404 let sig = sign_signature(bank, s.view());
405 table.entry(sig).or_default().push(atom_id);
406 }
407 }
408
409 Ok(Self {
410 hyperplanes,
411 tables,
412 sketch_dim,
413 num_atoms,
414 })
415 }
416
417 /// Number of atoms in the index.
418 pub fn num_atoms(&self) -> usize {
419 self.num_atoms
420 }
421
422 /// Gather the raw candidate atom-id set for a query `direction`, *without*
423 /// ranking or budget truncation. This is the sublinear part: it sketches the
424 /// query once per table (using a frame-agnostic global query sketch — the
425 /// query direction projected by the index's own representative projection)
426 /// and unions the colliding buckets (plus Hamming-1 neighbours when
427 /// multi-probe is enabled).
428 ///
429 /// `query_sketch` is the sketch-space query vector (length `sketch_dim`),
430 /// produced by the caller from the row residual via the
431 /// [`AtomFrameSketch`]. We probe each table with this single vector.
432 pub fn gather_candidates(&self, query_sketch: ArrayView1<f64>, multiprobe: bool) -> Vec<usize> {
433 let mut seen: HashSet<usize> = HashSet::new();
434 for (table, bank) in self.tables.iter().zip(self.hyperplanes.iter()) {
435 let (sig, margins) = sign_signature_with_margins(bank, query_sketch);
436 if let Some(ids) = table.get(&sig) {
437 seen.extend(ids.iter().copied());
438 }
439 if multiprobe {
440 // Flip the lowest-margin bit (the one most likely to be on the
441 // wrong side of its hyperplane) to reach the nearest neighbour
442 // bucket — standard multi-probe LSH, biggest recall win.
443 let flip_bit = lowest_margin_bit(&margins);
444 let neighbour = sig ^ (1u64 << flip_bit);
445 if let Some(ids) = table.get(&neighbour) {
446 seen.extend(ids.iter().copied());
447 }
448 }
449 }
450 let mut out: Vec<usize> = seen.into_iter().collect();
451 out.sort_unstable();
452 out
453 }
454
455 /// Propose the top `candidate_budget` atoms for a row whose residual is
456 /// `direction` (length `sketch.output_dim()`), ranked by exact frame
457 /// alignment.
458 ///
459 /// Pipeline: probe with [`AtomFrameSketch::query_sketch`] (`O(p·s)` for
460 /// shared-projection sketches, #994 — no atom is touched before the
461 /// gather), gather the sublinear candidate union, score each by
462 /// [`AtomFrameSketch::alignment`], and keep the highest-scoring
463 /// `candidate_budget`.
464 ///
465 /// Returns `(proposed_ids, dropped_for_budget)` where the second element
466 /// lists every gathered candidate that was truncated by the budget (never
467 /// silently discarded).
468 pub fn propose<S: AtomFrameSketch>(
469 &self,
470 sketch: &S,
471 direction: ArrayView1<f64>,
472 candidate_budget: usize,
473 config_multiprobe: bool,
474 ) -> Proposal {
475 let query_sketch = sketch.query_sketch(direction);
476 let gathered = if query_sketch.len() == self.sketch_dim {
477 self.gather_candidates(query_sketch.view(), config_multiprobe)
478 } else {
479 // A probe of the wrong dimension cannot be hashed against the
480 // tables; gather nothing rather than hash garbage. The recall
481 // report will then attribute every planted atom to `NotGathered`,
482 // which is the loud, attributable failure mode.
483 Vec::new()
484 };
485
486 // Exact-score every gathered candidate by frame alignment.
487 let mut scored: Vec<(usize, f64)> = gathered
488 .iter()
489 .map(|&id| (id, sketch.alignment(id, direction)))
490 .collect();
491 // Descending by alignment; ties broken by id for determinism.
492 scored.sort_by(|a, b| {
493 b.1.partial_cmp(&a.1)
494 .unwrap_or(std::cmp::Ordering::Equal)
495 .then(a.0.cmp(&b.0))
496 });
497
498 let keep = candidate_budget.min(scored.len());
499 let proposed: Vec<usize> = scored[..keep].iter().map(|&(id, _)| id).collect();
500 let dropped_for_budget: Vec<usize> = scored[keep..].iter().map(|&(id, _)| id).collect();
501
502 Proposal {
503 proposed,
504 dropped_for_budget,
505 gathered_count: gathered.len(),
506 }
507 }
508
509 /// Recall contract. For a set of rows, each with planted truly-active atom
510 /// ids and a residual direction, run [`SaeCandidateIndex::propose`] at the
511 /// given `candidate_budget` and record what fraction of planted atoms
512 /// appear in the proposed set. Every miss is logged — no silent truncation.
513 ///
514 /// `rows` is `(direction, planted_active_ids)` per row.
515 pub fn recall_report<S: AtomFrameSketch>(
516 &self,
517 sketch: &S,
518 rows: &[(Array1<f64>, Vec<usize>)],
519 candidate_budget: usize,
520 multiprobe: bool,
521 ) -> RecallReport {
522 let mut total_planted: usize = 0;
523 let mut total_recovered: usize = 0;
524 let mut misses: Vec<RecallMiss> = Vec::new();
525 let mut total_gathered: usize = 0;
526
527 for (row_idx, (direction, planted)) in rows.iter().enumerate() {
528 let proposal = self.propose(sketch, direction.view(), candidate_budget, multiprobe);
529 total_gathered += proposal.gathered_count;
530 let proposed_set: HashSet<usize> = proposal.proposed.iter().copied().collect();
531 // A candidate that was gathered but truncated by the budget counts
532 // as a miss *attributable to the budget*; one never gathered at all
533 // is a miss *attributable to the index*. We record both, flagged.
534 let dropped_set: HashSet<usize> = proposal.dropped_for_budget.iter().copied().collect();
535
536 for &atom in planted {
537 total_planted += 1;
538 if proposed_set.contains(&atom) {
539 total_recovered += 1;
540 } else {
541 let reason = if dropped_set.contains(&atom) {
542 MissReason::TruncatedByBudget
543 } else {
544 MissReason::NotGathered
545 };
546 misses.push(RecallMiss {
547 row: row_idx,
548 atom,
549 alignment: sketch.alignment(atom, direction.view()),
550 reason,
551 });
552 }
553 }
554 }
555
556 let recall = if total_planted == 0 {
557 1.0
558 } else {
559 total_recovered as f64 / total_planted as f64
560 };
561 let avg_gathered = if rows.is_empty() {
562 0.0
563 } else {
564 total_gathered as f64 / rows.len() as f64
565 };
566
567 RecallReport {
568 candidate_budget,
569 num_rows: rows.len(),
570 total_planted,
571 total_recovered,
572 recall,
573 avg_candidates_gathered: avg_gathered,
574 num_atoms: self.num_atoms,
575 misses,
576 }
577 }
578}
579
/// Outcome of proposing candidates for one row: the budgeted candidate set
/// together with everything the budget cut.
#[derive(Clone, Debug)]
pub struct Proposal {
    /// The kept atom ids — at most `candidate_budget` of them, best frame
    /// alignment first.
    pub proposed: Vec<usize>,
    /// Gathered candidates that the budget truncated. They are reported here
    /// rather than dropped silently.
    pub dropped_for_budget: Vec<usize>,
    /// Size of the candidate union the sublinear gather returned, before any
    /// budgeting.
    pub gathered_count: usize,
}
590
/// The reason a planted atom is absent from a row's proposed candidate set.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MissReason {
    /// The atom never entered the gathered candidate union: an LSH recall
    /// miss. Remedy: widen tables / probes.
    NotGathered,
    /// The atom *was* gathered, but the candidate budget cut it. Remedy: widen
    /// the budget.
    TruncatedByBudget,
}
600
601/// One recorded recall miss.
602#[derive(Clone, Copy, Debug)]
603pub struct RecallMiss {
604 /// Row index in the report's input.
605 pub row: usize,
606 /// The planted atom id that was missed.
607 pub atom: usize,
608 /// The atom's exact frame alignment with the row direction (diagnostic).
609 pub alignment: f64,
610 /// Whether the miss was an index miss or a budget truncation.
611 pub reason: MissReason,
612}
613
614/// Result of [`SaeCandidateIndex::recall_report`].
615#[derive(Clone, Debug)]
616pub struct RecallReport {
617 /// Candidate budget the recall was measured at.
618 pub candidate_budget: usize,
619 /// Number of rows evaluated.
620 pub num_rows: usize,
621 /// Total planted truly-active atoms across all rows.
622 pub total_planted: usize,
623 /// How many of them appeared in the proposed sets.
624 pub total_recovered: usize,
625 /// `recall@candidate_budget` = recovered / planted (1.0 if nothing planted).
626 pub recall: f64,
627 /// Mean number of candidates the sublinear gather returned per row — the
628 /// sublinearity witness; compare against `num_atoms`.
629 pub avg_candidates_gathered: f64,
630 /// Total atoms in the index (for the sublinearity ratio).
631 pub num_atoms: usize,
632 /// Every miss, with its row, atom, alignment, and reason. No silent drops.
633 pub misses: Vec<RecallMiss>,
634}
635
636impl RecallReport {
637 /// Convenience: ratio of mean gathered candidates to dictionary size. A
638 /// value far below `1.0` is the evidence that proposal touched a sublinear
639 /// slice of the dictionary.
640 pub fn sublinearity_ratio(&self) -> f64 {
641 if self.num_atoms == 0 {
642 0.0
643 } else {
644 self.avg_candidates_gathered / self.num_atoms as f64
645 }
646 }
647}
648
649// ---------------------------------------------------------------------------
650// Helpers (deterministic, dependency-light)
651// ---------------------------------------------------------------------------
652
653/// Mix a base seed with an index into a well-spread `u64` (SplitMix64 finalizer
654/// on the sum). Deterministic, no clock.
655#[inline]
656fn mix_seed(base: u64, idx: u64) -> u64 {
657 // Finalize `base + idx·G` with the canonical SplitMix64 step. The stateful
658 // form adds G internally, so pre-subtract one G to land on the same input
659 // and keep the output bit-identical to the previous inlined finalizer.
660 let mut state = base
661 .wrapping_add(idx.wrapping_mul(0x9E37_79B9_7F4A_7C15))
662 .wrapping_sub(0x9E37_79B9_7F4A_7C15);
663 gam_linalg::utils::splitmix64(&mut state)
664}
665
666/// A seeded Gaussian random matrix of shape `(rows, cols)` (rows of hyperplanes
667/// / projection rows). Uses Box–Muller off a seeded `StdRng`.
668fn gaussian_projection(rows: usize, cols: usize, seed: u64) -> Array2<f64> {
669 use rand::RngExt as _;
670 let mut rng = StdRng::seed_from_u64(seed);
671 let mut m = Array2::<f64>::zeros((rows, cols));
672 for r in 0..rows {
673 for c in 0..cols {
674 let u1 = rng.random::<f64>().max(1e-16);
675 let u2 = rng.random::<f64>();
676 m[(r, c)] = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
677 }
678 }
679 m
680}
681
682/// Modified Gram–Schmidt orthonormalization of a decoder block's columns.
683/// Input `block` is `(p, m)`; output `U` is `(p, r)` with orthonormal columns
684/// spanning `range(block)`, `r ≤ m` (rank-deficient columns are dropped).
685fn orthonormal_frame(block: &Array2<f64>) -> Array2<f64> {
686 let p = block.nrows();
687 let m = block.ncols();
688 let mut cols: Vec<Array1<f64>> = Vec::with_capacity(m);
689 for j in 0..m {
690 let mut v = block.column(j).to_owned();
691 for q in &cols {
692 let proj: f64 = q.iter().zip(v.iter()).map(|(&a, &b)| a * b).sum();
693 for (vi, &qi) in v.iter_mut().zip(q.iter()) {
694 *vi -= proj * qi;
695 }
696 }
697 let nrm = vec_norm(v.view());
698 if nrm > DIRECTION_NORM_FLOOR {
699 for vi in v.iter_mut() {
700 *vi /= nrm;
701 }
702 cols.push(v);
703 }
704 }
705 let r = cols.len();
706 let mut u = Array2::<f64>::zeros((p, r));
707 for (j, col) in cols.into_iter().enumerate() {
708 u.column_mut(j).assign(&col);
709 }
710 u
711}
712
713/// `M · v` for `M` shape `(rows, cols)`, `v` length `cols`.
714fn mat_vec(m: &Array2<f64>, v: ArrayView1<f64>) -> Array1<f64> {
715 let mut out = Array1::<f64>::zeros(m.nrows());
716 for r in 0..m.nrows() {
717 let row = m.row(r);
718 out[r] = row.iter().zip(v.iter()).map(|(&a, &b)| a * b).sum();
719 }
720 out
721}
722
723#[inline]
724fn vec_norm(v: ArrayView1<f64>) -> f64 {
725 v.iter().map(|&x| x * x).sum::<f64>().sqrt()
726}
727
728#[inline]
729fn normalize_in_place(v: &mut Array1<f64>) {
730 let n = vec_norm(v.view());
731 if n > DIRECTION_NORM_FLOOR {
732 for x in v.iter_mut() {
733 *x /= n;
734 }
735 }
736}
737
738/// Pack the sign bits of `bank · s` into a `u64` signature. `bank` is
739/// `(bits, sketch_dim)`; `bits ≤ 64` (enforced by config-derived bit widths).
740fn sign_signature(bank: &Array2<f64>, s: ArrayView1<f64>) -> u64 {
741 let mut sig = 0u64;
742 for r in 0..bank.nrows() {
743 let row = bank.row(r);
744 let dot: f64 = row.iter().zip(s.iter()).map(|(&a, &b)| a * b).sum();
745 if dot >= 0.0 {
746 sig |= 1u64 << r;
747 }
748 }
749 sig
750}
751
752/// Signature plus per-bit signed margins (the dot products), used by multi-probe
753/// to find the least-confident bit to flip.
754fn sign_signature_with_margins(bank: &Array2<f64>, s: ArrayView1<f64>) -> (u64, Vec<f64>) {
755 let mut sig = 0u64;
756 let mut margins = Vec::with_capacity(bank.nrows());
757 for r in 0..bank.nrows() {
758 let row = bank.row(r);
759 let dot: f64 = row.iter().zip(s.iter()).map(|(&a, &b)| a * b).sum();
760 if dot >= 0.0 {
761 sig |= 1u64 << r;
762 }
763 margins.push(dot);
764 }
765 (sig, margins)
766}
767
/// Index of the bit whose hyperplane the query lies closest to, i.e. the one
/// with the smallest `|margin|` — the bit most likely to have landed on the
/// wrong side. The first such index wins ties. Margins that are NaN or
/// infinite never win, and an empty slice yields `0`.
fn lowest_margin_bit(margins: &[f64]) -> usize {
    let (closest_bit, _) = margins
        .iter()
        .map(|margin| margin.abs())
        .enumerate()
        .fold(
            (0usize, f64::INFINITY),
            |(best_bit, best_abs), (bit, abs_margin)| {
                // Strict `<` keeps the earliest minimum and skips NaN.
                if abs_margin < best_abs {
                    (bit, abs_margin)
                } else {
                    (best_bit, best_abs)
                }
            },
        );
    closest_bit
}
782
783// ---------------------------------------------------------------------------
784// Tests
785// ---------------------------------------------------------------------------
786
787#[cfg(test)]
788mod tests {
789 use super::*;
790 use rand::RngExt as _;
791 use rand::rngs::StdRng;
792
793 /// Draw a unit vector in `p` dims from a seeded RNG.
794 fn unit_vec(rng: &mut StdRng, p: usize) -> Array1<f64> {
795 let mut v = Array1::<f64>::zeros(p);
796 for x in v.iter_mut() {
797 let u1 = rng.random::<f64>().max(1e-16);
798 let u2 = rng.random::<f64>();
799 *x = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
800 }
801 let n = vec_norm(v.view());
802 if n > DIRECTION_NORM_FLOOR {
803 for x in v.iter_mut() {
804 *x /= n;
805 }
806 }
807 v
808 }
809
810 /// Build a synthetic dictionary of `k` rank-1 atoms: atom `i`'s decoder
811 /// block is the outer-friendly single column `c_i` (a random unit direction
812 /// in output space). Returns the blocks and the list of column directions so
813 /// the planted-atom test can construct directions that lie in chosen atoms.
814 fn synthetic_dictionary(k: usize, p: usize, seed: u64) -> (Vec<Array2<f64>>, Vec<Array1<f64>>) {
815 let mut rng = StdRng::seed_from_u64(seed);
816 let mut blocks = Vec::with_capacity(k);
817 let mut dirs = Vec::with_capacity(k);
818 for _ in 0..k {
819 let c = unit_vec(&mut rng, p);
820 let mut block = Array2::<f64>::zeros((p, 1));
821 block.column_mut(0).assign(&c);
822 blocks.push(block);
823 dirs.push(c);
824 }
825 (blocks, dirs)
826 }
827
/// A direction equal to an atom's only column must align ~1 with that atom,
/// and a different atom's column must align strictly less.
#[test]
fn frame_alignment_is_exact_for_in_range_direction() {
    const ATOM: usize = 3;
    const OTHER: usize = 5;

    let (blocks, dirs) = synthetic_dictionary(8, 16, 11);
    let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 12, 7).unwrap();

    // The atom's own column lies fully inside its range.
    let a = sketch.alignment(ATOM, dirs[ATOM].view());
    assert!(a > 0.999, "in-range alignment should be ~1, got {a}");

    // Another atom's random column is generically close to orthogonal, so
    // it aligns only weakly with `ATOM`.
    let a_off = sketch.alignment(ATOM, dirs[OTHER].view());
    assert!(
        a_off < a,
        "off-atom alignment {a_off} should be below in-range {a}"
    );
}
844
/// Regression for the routing-confidence gate (#1026): a low-alignment LSH
/// route must never be silently certified.
///
/// `certified_encode_with_index` (and its amortized twin) mark a routed row
/// UNCERTIFIED when the best-aligned proposed atom's frame alignment falls
/// below `encode::CANDIDATE_ROUTING_MIN_ALIGNMENT`. That decision reads
/// exactly `sketch.alignment(best_atom, target)`, so this test pins the value
/// with exact, deterministic linear algebra instead of relying on LSH gather
/// luck — a later change to the alignment formula cannot then shift the
/// threshold's meaning unnoticed.
///
/// Setup: two atoms with ORTHOGONAL frames in a 6-dim ambient space (atom 0
/// spans `(e0, e1)`, atom 1 spans `(e2, e3)`, and `(e4, e5)` is covered by
/// neither). A direction inside atom 1's span aligns exactly 1 with atom 1
/// (in-frame, ABOVE the gate) and exactly 0 with atom 0 (off-frame, BELOW the
/// gate) — precisely the mis-route the gate is there to flag.
#[test]
fn routing_confidence_gate_input_separates_off_frame_from_in_frame() {
    use crate::encode::CANDIDATE_ROUTING_MIN_ALIGNMENT as GATE;

    let p = 6usize;
    // A `(p, 2)` frame whose columns are the basis vectors `e_first`, `e_second`.
    let axis_frame = |first: usize, second: usize| {
        let mut frame = Array2::<f64>::zeros((p, 2));
        frame[[first, 0]] = 1.0;
        frame[[second, 1]] = 1.0;
        frame
    };
    let frames = [axis_frame(0, 1), axis_frame(2, 3)];
    let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&frames, 16, 4242).unwrap();

    // Unit direction inside atom 1's (e2, e3) span: 0.6² + 0.8² = 1.
    let mut target = Array1::<f64>::zeros(p);
    target[2] = 0.6;
    target[3] = 0.8;
    let a_right = sketch.alignment(1, target.view());
    let a_wrong = sketch.alignment(0, target.view());

    assert!(
        a_right > 0.999,
        "an in-frame direction must align ~1 with its own atom; got {a_right}"
    );
    assert!(
        a_wrong < 1e-9,
        "an orthogonal-subspace direction must align ~0 with the wrong atom; got {a_wrong}"
    );
    // These two are the per-row predicate the encode gate evaluates: the
    // mis-routed (orthogonal) atom lands BELOW the gate and is sent to the
    // exact fallback; the correctly routed atom sits AT/ABOVE it and is
    // trusted.
    assert!(
        a_wrong < GATE,
        "a mis-routed (orthogonal) atom must fall below the routing gate {GATE}; got {a_wrong}"
    );
    assert!(
        a_right >= GATE,
        "the correctly-routed atom must sit at/above the routing gate {GATE}; got {a_right}"
    );

    // Worst case: a direction in the UNCOVERED (e4, e5) subspace aligns ~0
    // with both atoms, so the gate fires no matter which atom the LSH
    // surfaces — nothing can certify this route.
    let mut uncovered = Array1::<f64>::zeros(p);
    uncovered[4] = 1.0;
    for atom in 0..2 {
        let a = sketch.alignment(atom, uncovered.view());
        assert!(
            a < GATE,
            "an uncovered-subspace direction must fall below the gate for atom {atom}; got {a}"
        );
    }
}
916
/// Two builds from the same decoder blocks and the same seed must agree end
/// to end: identical atom sketches, the same table layout, and identical
/// proposals for every probed direction.
///
/// Comparing per-table sizes alone cannot show that the tables hold the same
/// atoms, and indexing `idx2.tables` by `idx1`'s table count would either
/// panic or silently skip tables if the counts differed. Both gaps are closed
/// with explicit assertions below.
#[test]
fn build_is_deterministic_for_a_fixed_seed() {
    let (blocks, dirs) = synthetic_dictionary(64, 24, 99);
    let s1 = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 16, 5).unwrap();
    let s2 = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 16, 5).unwrap();
    // Same seed → identical representative sketches.
    for i in 0..blocks.len() {
        let a = s1.atom_sketch(i);
        let b = s2.atom_sketch(i);
        let diff = vec_norm((&a - &b).view());
        assert!(
            diff < 1e-12,
            "atom {i} sketch differs across builds: {diff:e}"
        );
    }

    let cfg = IndexConfig::auto(16, blocks.len(), 5);
    let idx1 = SaeCandidateIndex::build(&s1, cfg).unwrap();
    let idx2 = SaeCandidateIndex::build(&s2, cfg).unwrap();

    // Same layout: equal table count first, then equal size per table.
    assert_eq!(
        idx1.tables.len(),
        idx2.tables.len(),
        "rebuild produced a different number of tables"
    );
    for t in 0..idx1.tables.len() {
        assert_eq!(
            idx1.tables[t].len(),
            idx2.tables[t].len(),
            "table {t} size differs across builds"
        );
    }

    // Same contents, observed through the public surface: both indexes must
    // propose the identical candidate list for every atom direction.
    let budget = auto_candidate_budget(blocks.len());
    for (row, direction) in dirs.iter().enumerate() {
        let first = idx1.propose(&s1, direction.view(), budget, cfg.multiprobe);
        let second = idx2.propose(&s2, direction.view(), budget, cfg.multiprobe);
        assert_eq!(
            first.proposed, second.proposed,
            "direction {row}: rebuild must propose identically"
        );
    }
}
940
/// Recall gate at a frontier-ish scale (`K = 2000`, candidate budget 32).
///
/// Checks three things on 200 planted rows: the gather touches less than half
/// of the dictionary, recall of the planted dominant atom is at least 0.80,
/// and the miss list accounts for every unrecovered planted atom.
#[test]
fn planted_atoms_are_recalled_above_floor_at_sublinear_budget() {
    // A frontier-ish dictionary: many atoms, modest output dim.
    let k = 2000usize;
    let p = 48usize;
    let (blocks, dirs) = synthetic_dictionary(k, p, 2026);
    let sketch_dim = 24usize;
    let sketch =
        RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, 4242).unwrap();
    let cfg = IndexConfig::auto(sketch_dim, k, 4242);
    let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();

    // Plant: each row's residual is dominated by one chosen atom's column
    // (plus 0.15 cross-talk from a second); the planted-active set is that
    // dominant atom. The shared `planted_rows` helper draws the rows, so
    // this test and the ladder / cluster tests cannot drift apart on what a
    // "planted row" is.
    let rows = planted_rows(&dirs, 200, 31337);

    // Sublinear candidate budget: << K. We allow the gather to surface a
    // handful, but the *budget* (the per-row local block size) stays small.
    let candidate_budget = 32usize;
    let report = index.recall_report(&sketch, &rows, candidate_budget, cfg.multiprobe);

    // The gather must touch only a sublinear slice of the dictionary.
    assert!(
        report.sublinearity_ratio() < 0.5,
        "gather was not sublinear: avg {} of {} atoms (ratio {:.3})",
        report.avg_candidates_gathered,
        report.num_atoms,
        report.sublinearity_ratio()
    );

    // Recall floor: the LSH index must recover the planted dominant atom for
    // the large majority of rows at this sublinear budget. Misses are
    // logged, never silently dropped.
    let floor = 0.80;
    assert!(
        report.recall >= floor,
        "recall {:.3} below floor {floor}; {} misses logged (first few: {:?})",
        report.recall,
        report.misses.len(),
        report
            .misses
            .iter()
            .take(5)
            .map(|m| (m.row, m.atom, m.reason, m.alignment))
            .collect::<Vec<_>>()
    );

    // Every miss is accounted for with a reason — the no-silent-truncation
    // contract.
    let recovered = report.total_recovered;
    assert_eq!(
        report.total_planted - recovered,
        report.misses.len(),
        "miss list must account for every unrecovered planted atom"
    );
}
1014
/// Pins `auto_candidate_budget` at four dictionary sizes, then checks it is
/// monotone non-decreasing in `K` and never leaves the
/// `[CANDIDATE_BUDGET_MIN, CANDIDATE_BUDGET_MAX]` band.
#[test]
fn auto_candidate_budget_tracks_the_issue_band() {
    let pinned = [
        (2usize, CANDIDATE_BUDGET_MIN),
        (64, 48),
        (1024, 80),
        (100_000, CANDIDATE_BUDGET_MAX),
    ];
    for (atoms, expected) in pinned {
        assert_eq!(auto_candidate_budget(atoms), expected);
    }

    // Monotone non-decreasing in K and always inside the band.
    let ladder = [2usize, 16, 64, 256, 1024, 4096, 65_536, 1_000_000];
    let mut prev = 0usize;
    for c in ladder.into_iter().map(auto_candidate_budget) {
        assert!(c >= prev, "budget must be monotone in K");
        assert!((CANDIDATE_BUDGET_MIN..=CANDIDATE_BUDGET_MAX).contains(&c));
        prev = c;
    }
}
1030
1031 /// Build a planted row set for a dictionary: each row's residual direction
1032 /// is dominated by one chosen atom (plus cross-talk from a second), and
1033 /// the planted-active set is the dominant atom.
1034 fn planted_rows(
1035 dirs: &[Array1<f64>],
1036 n_rows: usize,
1037 seed: u64,
1038 ) -> Vec<(Array1<f64>, Vec<usize>)> {
1039 let k = dirs.len();
1040 let mut rng = StdRng::seed_from_u64(seed);
1041 let mut rows = Vec::with_capacity(n_rows);
1042 for _ in 0..n_rows {
1043 let primary = rng.random_range(0..k);
1044 let secondary = rng.random_range(0..k);
1045 let mut d = dirs[primary].clone();
1046 for (di, &si) in d.iter_mut().zip(dirs[secondary].iter()) {
1047 *di += 0.15 * si;
1048 }
1049 let n = vec_norm(d.view());
1050 for di in d.iter_mut() {
1051 *di /= n;
1052 }
1053 rows.push((d, vec![primary]));
1054 }
1055 rows
1056 }
1057
/// #985 part 2 (index tier): planted ladder over `K = 64` and `K = 1024`.
///
/// Assumptions from the K=2 era say nothing about frontier K, so each rung
/// runs the SAME battery: recall above a stated floor at the auto-derived
/// budget, every miss accounted for, identical proposals from an independent
/// rebuild, and proposals never larger than the budget. Across rungs the
/// gathered fraction of the dictionary must shrink — what may touch half the
/// dictionary at K = 64 must not at K = 1024.
#[test]
fn k_ladder_recall_determinism_and_sublinearity() {
    let p = 48usize;
    let n_rows = 150usize;
    let sketch_dim = 24usize;
    let mut ladder_ratios = Vec::new();

    for k in [64usize, 1024] {
        let (blocks, dirs) = synthetic_dictionary(k, p, 9000 + k as u64);
        let sketch_seed = 71 + k as u64;
        let build_sketch = || {
            RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, sketch_seed)
                .unwrap()
        };
        let sketch = build_sketch();
        let cfg = IndexConfig::auto(sketch_dim, k, sketch_seed);
        let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();

        let rows = planted_rows(&dirs, n_rows, 555 + k as u64);
        let budget = auto_candidate_budget(k);
        let report = index.recall_report(&sketch, &rows, budget, cfg.multiprobe);

        // Per-rung recall floor at the auto-derived budget; the failure
        // message surfaces the first few misses with their reasons.
        let floor = 0.80;
        let first_misses = || {
            report
                .misses
                .iter()
                .take(3)
                .map(|m| (m.row, m.atom, m.reason, m.alignment))
                .collect::<Vec<_>>()
        };
        assert!(
            report.recall >= floor,
            "K={k}: recall {:.3} below floor {floor}; {} misses (first: {:?})",
            report.recall,
            report.misses.len(),
            first_misses()
        );
        // No-silent-truncation contract, per rung.
        assert_eq!(
            report.total_planted - report.total_recovered,
            report.misses.len(),
            "K={k}: miss list must account for every unrecovered planted atom"
        );

        // Search determinism: a second, independent build from the same
        // inputs must propose the identical candidate set for each probe.
        let sketch2 = build_sketch();
        let index2 = SaeCandidateIndex::build(&sketch2, cfg).unwrap();
        let probe_rows = rows.iter().take(20);
        for (direction, _) in probe_rows.clone() {
            let a = index.propose(&sketch, direction.view(), budget, cfg.multiprobe);
            let b = index2.propose(&sketch2, direction.view(), budget, cfg.multiprobe);
            assert_eq!(
                a.proposed, b.proposed,
                "K={k}: rebuild must propose identically"
            );
        }

        // Proposal size is bounded by the budget, never by the dictionary:
        // the per-row local block stays near the planted/active scale.
        for (direction, _) in probe_rows {
            let prop = index.propose(&sketch, direction.view(), budget, cfg.multiprobe);
            assert!(prop.proposed.len() <= budget);
        }

        ladder_ratios.push((k, report.sublinearity_ratio()));
    }

    // Relative sparsity must improve up the ladder (sublinear gather), and
    // at the frontier-shaped rung the gather must be a small slice outright.
    let ratio_small = ladder_ratios[0].1;
    let (k_big, ratio_big) = ladder_ratios[1];
    assert!(
        ratio_big < ratio_small,
        "sublinearity must improve along the ladder: {ladder_ratios:?}"
    );
    assert!(
        ratio_big < 0.25,
        "K={k_big}: gather touched {:.1}% of the dictionary",
        ratio_big * 100.0
    );
}
1145
1146 /// Counting wrapper: delegates everything, counts `project_direction`
1147 /// calls. The #994 acceptance gate: with the exact probe, building the
1148 /// query sketch touches NO atom, so a whole `propose` makes zero
1149 /// `project_direction` calls (scoring goes through `alignment`).
1150 struct CountingSketch<'a> {
1151 inner: &'a RandomProjectionFrameSketch,
1152 project_calls: std::cell::Cell<usize>,
1153 }
1154
1155 impl AtomFrameSketch for CountingSketch<'_> {
1156 fn sketch_dim(&self) -> usize {
1157 self.inner.sketch_dim()
1158 }
1159 fn output_dim(&self) -> usize {
1160 self.inner.output_dim()
1161 }
1162 fn num_atoms(&self) -> usize {
1163 self.inner.num_atoms()
1164 }
1165 fn atom_sketch(&self, atom_id: usize) -> Array1<f64> {
1166 self.inner.atom_sketch(atom_id)
1167 }
1168 fn project_direction(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64> {
1169 self.project_calls.set(self.project_calls.get() + 1);
1170 self.inner.project_direction(atom_id, direction)
1171 }
1172 fn alignment(&self, atom_id: usize, direction: ArrayView1<f64>) -> f64 {
1173 self.inner.alignment(atom_id, direction)
1174 }
1175 fn query_sketch(&self, direction: ArrayView1<f64>) -> Array1<f64> {
1176 self.inner.query_sketch(direction)
1177 }
1178 }
1179
/// Runs one full `propose` through a call-counting sketch wrapper and
/// requires that `project_direction` was never invoked (#994).
#[test]
fn query_probe_touches_no_atom_before_the_gather() {
    let (k, p) = (512usize, 32usize);
    let (blocks, dirs) = synthetic_dictionary(k, p, 77);
    let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 16, 13).unwrap();
    let cfg = IndexConfig::auto(16, k, 13);
    let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();

    // The index was built on the real sketch; only the query goes through
    // the counting wrapper. The proposal itself is irrelevant here.
    let counting = CountingSketch {
        inner: &sketch,
        project_calls: std::cell::Cell::new(0),
    };
    let _ = index.propose(&counting, dirs[5].view(), 32, cfg.multiprobe);

    let projections = counting.project_calls.get();
    assert_eq!(
        projections,
        0,
        "the exact query probe must be independent of K: no per-atom \
         projection before the gather (#994)"
    );
}
1200
1201 /// Build a coherent-cluster dictionary: `n_clusters` random unit centers,
1202 /// each with `cluster_size` atoms drawn as small perturbations of the
1203 /// center (renormalized). Exactly the non-isotropic regime where the old
1204 /// masked-average probe degraded (#994).
1205 fn coherent_cluster_dictionary(
1206 n_clusters: usize,
1207 cluster_size: usize,
1208 p: usize,
1209 spread: f64,
1210 seed: u64,
1211 ) -> (Vec<Array2<f64>>, Vec<Array1<f64>>) {
1212 let mut rng = StdRng::seed_from_u64(seed);
1213 let mut blocks = Vec::with_capacity(n_clusters * cluster_size);
1214 let mut dirs = Vec::with_capacity(n_clusters * cluster_size);
1215 for _ in 0..n_clusters {
1216 let center = unit_vec(&mut rng, p);
1217 for _ in 0..cluster_size {
1218 let noise = unit_vec(&mut rng, p);
1219 let mut c = center.clone();
1220 for (ci, &ni) in c.iter_mut().zip(noise.iter()) {
1221 *ci += spread * ni;
1222 }
1223 let n = vec_norm(c.view());
1224 for ci in c.iter_mut() {
1225 *ci /= n;
1226 }
1227 let mut block = Array2::<f64>::zeros((p, 1));
1228 block.column_mut(0).assign(&c);
1229 blocks.push(block);
1230 dirs.push(c);
1231 }
1232 }
1233 (blocks, dirs)
1234 }
1235
/// Recall on a clustered dictionary: 32 clusters × 32 near-parallel atoms
/// (1024 atoms in total).
///
/// Rows are dominated by one specific cluster member, and the proposal must
/// recover that exact member — not merely its cluster — at the auto budget.
/// Exact alignment scoring separates siblings once the probe lands the gather
/// in the right bucket neighborhood.
#[test]
fn coherent_clusters_are_recalled_with_the_exact_probe() {
    let (n_clusters, cluster_size) = (32usize, 32usize);
    let k = n_clusters * cluster_size;
    let p = 48usize;
    let sketch_dim = 24usize;

    let (blocks, dirs) = coherent_cluster_dictionary(n_clusters, cluster_size, p, 0.25, 4242);
    let sketch =
        RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, 99).unwrap();
    let cfg = IndexConfig::auto(sketch_dim, k, 99);
    let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();

    let rows = planted_rows(&dirs, 150, 31337);
    let report = index.recall_report(&sketch, &rows, auto_candidate_budget(k), cfg.multiprobe);

    let floor = 0.80;
    let first_misses = || {
        report
            .misses
            .iter()
            .take(3)
            .map(|m| (m.row, m.atom, m.reason, m.alignment))
            .collect::<Vec<_>>()
    };
    assert!(
        report.recall >= floor,
        "coherent-cluster recall {:.3} below floor {floor}; {} misses (first: {:?})",
        report.recall,
        report.misses.len(),
        first_misses()
    );

    // Clusters or not, the gather must remain a sublinear slice.
    assert!(
        report.sublinearity_ratio() < 0.5,
        "cluster gather touched {:.1}% of the dictionary",
        report.sublinearity_ratio() * 100.0
    );
}
1277
/// The `query_sketch` override is literally `normalize(R·d)`. Verify it
/// through the public surface: a rank-1 atom whose only column IS the
/// direction has `normalize(R·d)` as its `atom_sketch`, so the two must agree.
#[test]
fn exact_probe_matches_shared_projection_of_the_direction() {
    let p = 16usize;
    let d = {
        let mut rng = StdRng::seed_from_u64(5);
        unit_vec(&mut rng, p)
    };

    // One-atom dictionary whose single decoder column is `d` itself.
    let mut block = Array2::<f64>::zeros((p, 1));
    block.column_mut(0).assign(&d);
    let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&[block], 8, 21).unwrap();

    let via_probe = sketch.query_sketch(d.view());
    let via_atom = sketch.atom_sketch(0);
    let diff = vec_norm((&via_probe - &via_atom).view());
    assert!(
        diff < 1e-10,
        "query_sketch(d) must equal the rank-1 atom representative of d: diff {diff:e}"
    );
}
1297
/// A row with an empty planted set has nothing to miss: the report must show
/// recall exactly `1.0` and an empty miss list.
#[test]
fn empty_planted_rows_report_perfect_recall() {
    let atoms = 32usize;
    let sketch_dim = 12usize;
    let (blocks, dirs) = synthetic_dictionary(atoms, 16, 1);
    let sketch =
        RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, 3).unwrap();
    let index =
        SaeCandidateIndex::build(&sketch, IndexConfig::auto(sketch_dim, atoms, 3)).unwrap();

    // A single probe direction with no planted atoms at all.
    let nothing_planted: Vec<usize> = Vec::new();
    let rows = vec![(dirs[0].clone(), nothing_planted)];
    let report = index.recall_report(&sketch, &rows, 8, true);

    assert_eq!(report.recall, 1.0);
    assert!(report.misses.is_empty());
}
1309}