gam_sae/candidate_index.rs
1//! Sublinear candidate-atom index for active-set proposal (#985 part 1).
2//!
3//! A frontier SAE dictionary holds `K ≈ 10^4–10^5` atoms. The per-row *local*
4//! block — the small linear/Newton system over the atoms that are actually
5//! active in a row — is cheap, because the active set collapses it to a handful
6//! of atoms. The expensive step is *proposing* that active set: a naive scan
7//! scores every one of the `K` atom frames against every row, which is `O(K)`
8//! per row and dominates the whole solve once `K` is large.
9//!
10//! This module builds a **sublinear** candidate index over per-atom *sketches*
11//! of each atom's decoder column-space (its Grassmann frame `U_k`). Given a row
12//! residual direction it returns the top candidate atom ids likely to be
13//! active, touching only `O(log K)`-ish buckets instead of all `K` atoms.
14//!
15//! ## Layering against Track 1
16//!
17//! Track 1 owns the *real* atom frames `U_k` and has not landed yet, so this
//! module is written against an [`AtomFrameSketch`] trait. Any frame source —
19//! the eventual Grassmann frames, or the decoder column blocks `B_k` already
20//! present on [`crate::manifold::SaeManifoldAtom`] — can implement
21//! it. A concrete, dependency-free default
22//! ([`RandomProjectionFrameSketch`]) is provided: a seeded random-projection /
23//! random-hyperplane signature of the atom's orthonormalized column span. The
24//! index ([`SaeCandidateIndex`]) is a deterministic multi-table
25//! random-hyperplane LSH over those sketches.
26//!
27//! ## Recall contract
28//!
29//! Sublinear proposal is only safe if it *almost never* drops a truly-active
30//! atom. [`SaeCandidateIndex::recall_report`] takes a set of planted
31//! truly-active atoms per row, runs the proposal at a stated candidate budget,
32//! and records the rate at which planted atoms appear in the proposed set —
33//! **logging every miss** rather than silently truncating. The returned
34//! [`RecallReport`] carries `recall@budget` and the full miss list so a caller
35//! can widen the budget or fall back to a dense scan for the affected rows.
36//!
37//! Determinism: every random choice is seeded by an explicit index seed; no
38//! clock, no global RNG.
39
40use ndarray::{Array1, Array2, ArrayView1};
41use rand::SeedableRng;
42use rand::rngs::StdRng;
43use std::collections::{HashMap, HashSet};
44
/// Salt mixed into the per-table hyperplane seed so the index tables and the
/// default sketch never share a random stream even when handed the same base
/// seed. It is XORed into `config.seed` in [`SaeCandidateIndex::build`] before
/// the per-table `mix_seed` step.
const INDEX_HYPERPLANE_SALT: u64 = 0x9E37_79B9_7F4A_7C15;

/// Salt for the default random-projection sketch's projection matrix. It is
/// XORed into the caller's seed in
/// [`RandomProjectionFrameSketch::from_decoder_blocks`].
const SKETCH_PROJECTION_SALT: u64 = 0xC2B2_AE3D_27D4_EB4F;

/// Numerical floor below which a direction / column is treated as zero.
/// [`RandomProjectionFrameSketch`]'s `alignment` returns `0.0` for any query
/// direction whose norm falls below this value.
const DIRECTION_NORM_FLOOR: f64 = 1e-12;
55
/// Smallest per-row candidate budget `C` the auto-derivation returns (#985).
/// A proposal set below this size leaves the solver's accepted active set no
/// headroom over the planted/active atom count.
pub const CANDIDATE_BUDGET_MIN: usize = 32;

/// Largest per-row candidate budget `C` the auto-derivation returns (#985).
/// The per-row local block is meant to stay a small dense solve however large
/// the dictionary grows; past this size the proposal step no longer delivers
/// the bottleneck reduction it exists for.
pub const CANDIDATE_BUDGET_MAX: usize = 128;

/// Derive the per-row candidate budget `C` from the dictionary size `K` alone
/// (#985): `C = 8·⌈log₂ K⌉`, clamped into
/// [[`CANDIDATE_BUDGET_MIN`], [`CANDIDATE_BUDGET_MAX`]].
///
/// The logarithmic growth keeps the per-row local block effectively
/// constant-size while still granting larger dictionaries some extra recall
/// headroom; the clamp realizes the issue's `C ≈ 32–128` band. There is no
/// flag: the budget is a pure function of `num_atoms`.
///
/// For `num_atoms <= 1` the logarithm term is taken as `1`, so the result is
/// the lower clamp bound.
///
/// Examples: `K = 64 → 48`, `K = 1024 → 80`, `K = 10⁵ → 128`.
pub fn auto_candidate_budget(num_atoms: usize) -> usize {
    // ⌈log₂ K⌉ is the bit-width of `K - 1`; degenerate dictionaries use 1.
    let ceil_log2 = match num_atoms {
        0 | 1 => 1,
        k => (usize::BITS - (k - 1).leading_zeros()) as usize,
    };
    let unclamped = 8 * ceil_log2;
    unclamped
        .max(CANDIDATE_BUDGET_MIN)
        .min(CANDIDATE_BUDGET_MAX)
}
83
84// ---------------------------------------------------------------------------
85// Sketch interface
86// ---------------------------------------------------------------------------
87
/// A low-dimensional sketch of one atom's decoder column-space (its Grassmann
/// frame `U_k`).
///
/// The index never needs the full frame: it only needs (a) the sketch
/// dimension, shared by every atom in a dictionary, and (b), for any query
/// direction in output space, the atom's *sketch coordinates* of that direction
/// — i.e. the projection of the direction onto the atom's column-space,
/// expressed in the sketch's coordinates. A frame `U_k` (orthonormal columns
/// spanning the decoder range) yields these as `sketch = R · (U_kᵀ d)` for a
/// shared random projection `R`; a raw decoder block `B_k` yields them by first
/// orthonormalizing its columns. Both are valid implementors.
///
/// Throughout, `atom_id` is an index in `0..num_atoms()`. Query directions have
/// length [`AtomFrameSketch::output_dim`]; sketch-space vectors have length
/// [`AtomFrameSketch::sketch_dim`].
pub trait AtomFrameSketch {
    /// Dimension of the sketch vectors this implementor produces. Must be the
    /// same positive value for every atom in one dictionary so the index can
    /// build a single hyperplane bank. [`SaeCandidateIndex::build`] rejects a
    /// value of zero.
    fn sketch_dim(&self) -> usize;

    /// Dimension of the ambient output space the query directions live in.
    fn output_dim(&self) -> usize;

    /// Number of atoms this source can sketch.
    fn num_atoms(&self) -> usize;

    /// Sketch of atom `atom_id`'s *frame itself* (a representative point of the
    /// atom's column-space on the sphere of sketch space), used to place the
    /// atom into the LSH tables at build time. Returns a vector of length
    /// [`AtomFrameSketch::sketch_dim`]; [`SaeCandidateIndex::build`] fails if
    /// the returned length differs.
    fn atom_sketch(&self, atom_id: usize) -> Array1<f64>;

    /// Sketch of a query *direction* `d` (length [`AtomFrameSketch::output_dim`])
    /// as seen through atom `atom_id`'s frame: the direction's component inside
    /// the atom's column-space, mapped into sketch coordinates. Used at query
    /// time to score how strongly a row residual aligns with the atom.
    fn project_direction(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64>;

    /// Alignment score in `[0, 1]`: the fraction of the query direction's energy
    /// that lies inside atom `atom_id`'s column-space. `1.0` means the direction
    /// lies fully in the atom's range, `0.0` means it is orthogonal. Used to
    /// rank the (small) candidate set the index returns, and as the routing
    /// score maximized by [`SaeCandidateIndex::route_exact`] and
    /// [`brute_force_best_atom`].
    fn alignment(&self, atom_id: usize, direction: ArrayView1<f64>) -> f64;

    /// Sketch-space **probe** for a raw query direction (#994).
    ///
    /// The input `direction` lives in output space (length
    /// [`AtomFrameSketch::output_dim`]); the returned probe lives in sketch
    /// space (length [`AtomFrameSketch::sketch_dim`]) and is comparable to the
    /// [`AtomFrameSketch::atom_sketch`] representatives the LSH tables were
    /// built from. [`SaeCandidateIndex::propose`] gathers nothing when the
    /// returned probe has any other length.
    ///
    /// Implementors must return the exact cosine-LSH probe for their sketching
    /// policy. For the shared-projection sketch this is `normalize(R · d)`,
    /// `O(p · s)` per query, touching no atom.
    fn query_sketch(&self, direction: ArrayView1<f64>) -> Array1<f64>;
}
139
140// ---------------------------------------------------------------------------
141// Default concrete sketch: seeded random projection of the column span
142// ---------------------------------------------------------------------------
143
/// A concrete [`AtomFrameSketch`] built from raw decoder column blocks `B_k`.
///
/// For each atom it orthonormalizes the decoder columns (modified Gram–Schmidt)
/// to obtain a frame `U_k` with orthonormal columns spanning the decoder range,
/// then sketches via a single shared seeded Gaussian random projection
/// `R ∈ ℝ^{s×p}` applied to the in-range component of a direction:
///
/// * `atom_sketch(k) = normalize( R · u_k0 )`, the sketch of the atom's first
///   frame column — a stable representative point used to bucket the atom. A
///   rank-0 atom (empty frame) falls back to the normalized first column of
///   `R` so it is still bucketed somewhere.
/// * `project_direction(k, d) = R · (U_k U_kᵀ d)`, the sketch of the part of `d`
///   that lies in the atom's range.
/// * `alignment(k, d) = ‖U_k U_kᵀ d‖ / ‖d‖`, the norm of the in-range component
///   relative to the norm of `d`, clamped to `[0, 1]` (its square is the
///   in-range energy fraction). Directions with norm below
///   `DIRECTION_NORM_FLOOR` score `0.0`.
/// * `query_sketch(d) = normalize( R · d )`, the atom-independent LSH probe.
///
/// The shared `R` is a Johnson–Lindenstrauss style random projection, so sketch
/// inner products approximately preserve angles between in-range directions —
/// exactly what the LSH index needs. Everything is seeded; the same atoms +
/// seed always produce the same sketches.
pub struct RandomProjectionFrameSketch {
    /// Orthonormal frame `U_k` per atom, shape `(p, r_k)` with `r_k` ≤ columns.
    /// Indexed by atom id; its length is the number of atoms.
    frames: Vec<Array2<f64>>,
    /// Shared random projection `R`, shape `(sketch_dim, p)`.
    projection: Array2<f64>,
    /// Ambient output dimension `p`.
    output_dim: usize,
    /// Sketch dimension `s`.
    sketch_dim: usize,
}
172
173impl RandomProjectionFrameSketch {
174 /// Build the sketch from decoder column blocks.
175 ///
176 /// `decoder_blocks[k]` is `B_k` with shape `(p, m_k)`: `p` rows in output
177 /// space, `m_k` decoder columns for atom `k`. (`SaeManifoldAtom` stores the
178 /// transpose `(m_k, p)`; orient it `p`-rows before passing in.) All blocks
179 /// must share the same `p`. `sketch_dim` is the target sketch length `s`;
180 /// `seed` makes the projection deterministic.
181 pub fn from_decoder_blocks(
182 decoder_blocks: &[Array2<f64>],
183 sketch_dim: usize,
184 seed: u64,
185 ) -> Result<Self, String> {
186 if decoder_blocks.is_empty() {
187 return Err("RandomProjectionFrameSketch: need at least one decoder block".into());
188 }
189 if sketch_dim == 0 {
190 return Err("RandomProjectionFrameSketch: sketch_dim must be positive".into());
191 }
192 let output_dim = decoder_blocks[0].nrows();
193 if output_dim == 0 {
194 return Err("RandomProjectionFrameSketch: output dimension must be positive".into());
195 }
196 for (k, block) in decoder_blocks.iter().enumerate() {
197 if block.nrows() != output_dim {
198 return Err(format!(
199 "RandomProjectionFrameSketch: atom {k} has {} output rows, expected {output_dim}",
200 block.nrows()
201 ));
202 }
203 }
204
205 let frames: Vec<Array2<f64>> = decoder_blocks.iter().map(orthonormal_frame).collect();
206
207 let projection = gaussian_projection(sketch_dim, output_dim, seed ^ SKETCH_PROJECTION_SALT);
208
209 Ok(Self {
210 frames,
211 projection,
212 output_dim,
213 sketch_dim,
214 })
215 }
216
217 /// In-range component `U_k U_kᵀ d` of a direction (length `output_dim`).
218 fn in_range_component(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64> {
219 let frame = &self.frames[atom_id];
220 // coords = U_kᵀ d (length r_k)
221 let mut comp = Array1::<f64>::zeros(self.output_dim);
222 for col in 0..frame.ncols() {
223 let u = frame.column(col);
224 let coord: f64 = u.iter().zip(direction.iter()).map(|(&a, &b)| a * b).sum();
225 for (c, &uval) in comp.iter_mut().zip(u.iter()) {
226 *c += coord * uval;
227 }
228 }
229 comp
230 }
231}
232
233impl AtomFrameSketch for RandomProjectionFrameSketch {
234 fn sketch_dim(&self) -> usize {
235 self.sketch_dim
236 }
237
238 fn output_dim(&self) -> usize {
239 self.output_dim
240 }
241
242 fn num_atoms(&self) -> usize {
243 self.frames.len()
244 }
245
246 fn atom_sketch(&self, atom_id: usize) -> Array1<f64> {
247 let frame = &self.frames[atom_id];
248 // Sketch the dominant (first) frame column as the atom's representative.
249 // If the frame is empty (rank-0 atom), fall back to a deterministic
250 // nonzero point so the atom is still bucketed somewhere.
251 if frame.ncols() == 0 {
252 let mut s = self.projection.column(0).to_owned();
253 normalize_in_place(&mut s);
254 return s;
255 }
256 let u0 = frame.column(0);
257 let mut s = mat_vec(&self.projection, u0);
258 normalize_in_place(&mut s);
259 s
260 }
261
262 fn project_direction(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64> {
263 let comp = self.in_range_component(atom_id, direction);
264 mat_vec(&self.projection, comp.view())
265 }
266
267 /// Exact `O(p·s)` probe (#994): every atom shares the one projection `R`,
268 /// and the table representatives are `normalize(R · u_k0)`, so the correct
269 /// cosine-LSH probe for a direction is simply `normalize(R · d)` — no atom
270 /// is touched, and no masked-average approximation is involved.
271 fn query_sketch(&self, direction: ArrayView1<f64>) -> Array1<f64> {
272 let mut s = mat_vec(&self.projection, direction);
273 normalize_in_place(&mut s);
274 s
275 }
276
277 fn alignment(&self, atom_id: usize, direction: ArrayView1<f64>) -> f64 {
278 let dnorm = vec_norm(direction);
279 if dnorm < DIRECTION_NORM_FLOOR {
280 return 0.0;
281 }
282 let comp = self.in_range_component(atom_id, direction);
283 (vec_norm(comp.view()) / dnorm).clamp(0.0, 1.0)
284 }
285}
286
287// ---------------------------------------------------------------------------
288// Sublinear index: multi-table random-hyperplane LSH over sketches
289// ---------------------------------------------------------------------------
290
/// A deterministic, sublinear candidate index over atom-frame sketches.
///
/// The structure is a **random-hyperplane LSH** with `num_tables` independent
/// tables, each defined by `bits_per_table` seeded random hyperplanes in sketch
/// space. An atom's sketch is reduced to a `bits_per_table`-bit sign signature
/// per table (the sign of its dot with each hyperplane), and the atom id is
/// stored in the bucket keyed by that signature.
///
/// At query time the direction is *not* sketched through each atom's frame
/// (that would touch every atom). Instead a single *query sketch* is hashed
/// once per table, and we gather the union of atoms in the matching bucket
/// plus, when multi-probe is enabled, one Hamming-1 neighbouring bucket per
/// table: the one reached by flipping the signature bit with the lowest
/// margin. Because each table touches only the atoms colliding in those
/// buckets, total work is sublinear in `K` for well-spread sketches.
///
/// The gathered candidates are then ranked by exact
/// [`AtomFrameSketch::alignment`] and the top `candidate_budget` are returned.
/// All hyperplanes are seeded; building twice with the same seed yields byte-
/// identical tables.
pub struct SaeCandidateIndex {
    /// Hyperplane banks, one per table: each `(bits_per_table, sketch_dim)`.
    hyperplanes: Vec<Array2<f64>>,
    /// Buckets per table: signature -> atom ids. Parallel to `hyperplanes`
    /// (table `t` is keyed by signatures computed against bank `t`).
    tables: Vec<HashMap<u64, Vec<usize>>>,
    /// Sketch dimension shared by every atom.
    sketch_dim: usize,
    /// Number of atoms indexed.
    num_atoms: usize,
}
318
/// Tuning knobs for [`SaeCandidateIndex::build`]. Every field is explicit, so
/// the index never consults global state and needs no CLI flags.
#[derive(Clone, Copy, Debug)]
pub struct IndexConfig {
    /// How many independent LSH tables to build. More tables raise recall at
    /// the cost of more work per query.
    pub num_tables: usize,
    /// Random hyperplanes per table, i.e. the signature bit-width. More bits
    /// give finer buckets (fewer collisions, lower recall per table).
    pub bits_per_table: usize,
    /// Whether each table also probes a Hamming-distance-1 neighbouring bucket
    /// (multi-probe LSH). Cheap and a large recall win; on by default.
    pub multiprobe: bool,
    /// Master seed from which every hyperplane bank is derived.
    pub seed: u64,
}

impl IndexConfig {
    /// Default configuration for a sketch of dimension `sketch_dim` over
    /// roughly `num_atoms` atoms.
    ///
    /// With `L = ⌈log₂ max(num_atoms, 2)⌉`:
    /// * `bits_per_table = min(L, sketch_dim, 63)`, and at least `1`, so the
    ///   expected bucket occupancy is a small constant;
    /// * `num_tables = L` clamped into `[4, 16]`, a handful of tables to
    ///   recover recall lost to any single table's quantization.
    ///
    /// Both grow only logarithmically in `num_atoms`, keeping queries
    /// sublinear. Multi-probe is enabled and `seed` is stored unchanged.
    pub fn auto(sketch_dim: usize, num_atoms: usize, seed: u64) -> Self {
        // Treat tiny dictionaries as size 2 so the logarithm is at least 1.
        let population = num_atoms.max(2);
        let ceil_log2 = (usize::BITS - (population - 1).leading_zeros()) as usize;
        // The signature is packed into a u64 and the index accepts at most 63
        // bits per table; a zero sketch dimension still gets one bit.
        let max_bits = sketch_dim.clamp(1, 63);
        Self {
            num_tables: ceil_log2.clamp(4, 16),
            bits_per_table: ceil_log2.min(max_bits),
            multiprobe: true,
            seed,
        }
    }
}
362
363impl SaeCandidateIndex {
364 /// Build the index over every atom of `sketch`.
365 pub fn build<S: AtomFrameSketch>(sketch: &S, config: IndexConfig) -> Result<Self, String> {
366 let sketch_dim = sketch.sketch_dim();
367 if sketch_dim == 0 {
368 return Err("SaeCandidateIndex: sketch_dim must be positive".into());
369 }
370 if config.num_tables == 0 || config.bits_per_table == 0 {
371 return Err("SaeCandidateIndex: num_tables and bits_per_table must be positive".into());
372 }
373 // sign_signature packs bits into a u64 with `1u64 << r` for r in 0..bits_per_table.
374 // Shifting by 64+ is a panic in debug and undefined behaviour in release; cap at 63.
375 if config.bits_per_table > 63 {
376 return Err(format!(
377 "SaeCandidateIndex: bits_per_table {} exceeds 63 (u64 signature limit)",
378 config.bits_per_table
379 ));
380 }
381 let num_atoms = sketch.num_atoms();
382
383 // One seeded hyperplane bank per table; seed is mixed per-table so the
384 // tables are independent yet fully reproducible.
385 let hyperplanes: Vec<Array2<f64>> = (0..config.num_tables)
386 .map(|t| {
387 let table_seed = mix_seed(config.seed ^ INDEX_HYPERPLANE_SALT, t as u64);
388 gaussian_projection(config.bits_per_table, sketch_dim, table_seed)
389 })
390 .collect();
391
392 let mut tables: Vec<HashMap<u64, Vec<usize>>> =
393 (0..config.num_tables).map(|_| HashMap::new()).collect();
394
395 for atom_id in 0..num_atoms {
396 let s = sketch.atom_sketch(atom_id);
397 if s.len() != sketch_dim {
398 return Err(format!(
399 "SaeCandidateIndex: atom {atom_id} sketch length {} != sketch_dim {sketch_dim}",
400 s.len()
401 ));
402 }
403 for (table, bank) in tables.iter_mut().zip(hyperplanes.iter()) {
404 let sig = sign_signature(bank, s.view());
405 table.entry(sig).or_default().push(atom_id);
406 }
407 }
408
409 Ok(Self {
410 hyperplanes,
411 tables,
412 sketch_dim,
413 num_atoms,
414 })
415 }
416
417 /// Number of atoms in the index.
418 pub fn num_atoms(&self) -> usize {
419 self.num_atoms
420 }
421
422 /// Gather the raw candidate atom-id set for a query `direction`, *without*
423 /// ranking or budget truncation. This is the sublinear part: it sketches the
424 /// query once per table (using a frame-agnostic global query sketch — the
425 /// query direction projected by the index's own representative projection)
426 /// and unions the colliding buckets (plus Hamming-1 neighbours when
427 /// multi-probe is enabled).
428 ///
429 /// `query_sketch` is the sketch-space query vector (length `sketch_dim`),
430 /// produced by the caller from the row residual via the
431 /// [`AtomFrameSketch`]. We probe each table with this single vector.
432 pub fn gather_candidates(&self, query_sketch: ArrayView1<f64>, multiprobe: bool) -> Vec<usize> {
433 let mut seen: HashSet<usize> = HashSet::new();
434 for (table, bank) in self.tables.iter().zip(self.hyperplanes.iter()) {
435 let (sig, margins) = sign_signature_with_margins(bank, query_sketch);
436 if let Some(ids) = table.get(&sig) {
437 seen.extend(ids.iter().copied());
438 }
439 if multiprobe {
440 // Flip the lowest-margin bit (the one most likely to be on the
441 // wrong side of its hyperplane) to reach the nearest neighbour
442 // bucket — standard multi-probe LSH, biggest recall win.
443 let flip_bit = lowest_margin_bit(&margins);
444 let neighbour = sig ^ (1u64 << flip_bit);
445 if let Some(ids) = table.get(&neighbour) {
446 seen.extend(ids.iter().copied());
447 }
448 }
449 }
450 let mut out: Vec<usize> = seen.into_iter().collect();
451 out.sort_unstable();
452 out
453 }
454
455 /// Propose the top `candidate_budget` atoms for a row whose residual is
456 /// `direction` (length `sketch.output_dim()`), ranked by exact frame
457 /// alignment.
458 ///
459 /// Pipeline: probe with [`AtomFrameSketch::query_sketch`] (`O(p·s)` for
460 /// shared-projection sketches, #994 — no atom is touched before the
461 /// gather), gather the sublinear candidate union, score each by
462 /// [`AtomFrameSketch::alignment`], and keep the highest-scoring
463 /// `candidate_budget`.
464 ///
465 /// Returns `(proposed_ids, dropped_for_budget)` where the second element
466 /// lists every gathered candidate that was truncated by the budget (never
467 /// silently discarded).
468 pub fn propose<S: AtomFrameSketch>(
469 &self,
470 sketch: &S,
471 direction: ArrayView1<f64>,
472 candidate_budget: usize,
473 config_multiprobe: bool,
474 ) -> Proposal {
475 let query_sketch = sketch.query_sketch(direction);
476 let gathered = if query_sketch.len() == self.sketch_dim {
477 self.gather_candidates(query_sketch.view(), config_multiprobe)
478 } else {
479 // A probe of the wrong dimension cannot be hashed against the
480 // tables; gather nothing rather than hash garbage. The recall
481 // report will then attribute every planted atom to `NotGathered`,
482 // which is the loud, attributable failure mode.
483 Vec::new()
484 };
485
486 // Exact-score every gathered candidate by frame alignment.
487 let mut scored: Vec<(usize, f64)> = gathered
488 .iter()
489 .map(|&id| (id, sketch.alignment(id, direction)))
490 .collect();
491 // Descending by alignment; ties broken by id for determinism.
492 scored.sort_by(|a, b| {
493 b.1.partial_cmp(&a.1)
494 .unwrap_or(std::cmp::Ordering::Equal)
495 .then(a.0.cmp(&b.0))
496 });
497
498 let keep = candidate_budget.min(scored.len());
499 let proposed: Vec<usize> = scored[..keep].iter().map(|&(id, _)| id).collect();
500 let dropped_for_budget: Vec<usize> = scored[keep..].iter().map(|&(id, _)| id).collect();
501
502 Proposal {
503 proposed,
504 dropped_for_budget,
505 gathered_count: gathered.len(),
506 }
507 }
508
509 /// Recall contract. For a set of rows, each with planted truly-active atom
510 /// ids and a residual direction, run [`SaeCandidateIndex::propose`] at the
511 /// given `candidate_budget` and record what fraction of planted atoms
512 /// appear in the proposed set. Every miss is logged — no silent truncation.
513 ///
514 /// `rows` is `(direction, planted_active_ids)` per row.
515 pub fn recall_report<S: AtomFrameSketch>(
516 &self,
517 sketch: &S,
518 rows: &[(Array1<f64>, Vec<usize>)],
519 candidate_budget: usize,
520 multiprobe: bool,
521 ) -> RecallReport {
522 let mut total_planted: usize = 0;
523 let mut total_recovered: usize = 0;
524 let mut misses: Vec<RecallMiss> = Vec::new();
525 let mut total_gathered: usize = 0;
526
527 for (row_idx, (direction, planted)) in rows.iter().enumerate() {
528 let proposal = self.propose(sketch, direction.view(), candidate_budget, multiprobe);
529 total_gathered += proposal.gathered_count;
530 let proposed_set: HashSet<usize> = proposal.proposed.iter().copied().collect();
531 // A candidate that was gathered but truncated by the budget counts
532 // as a miss *attributable to the budget*; one never gathered at all
533 // is a miss *attributable to the index*. We record both, flagged.
534 let dropped_set: HashSet<usize> = proposal.dropped_for_budget.iter().copied().collect();
535
536 for &atom in planted {
537 total_planted += 1;
538 if proposed_set.contains(&atom) {
539 total_recovered += 1;
540 } else {
541 let reason = if dropped_set.contains(&atom) {
542 MissReason::TruncatedByBudget
543 } else {
544 MissReason::NotGathered
545 };
546 misses.push(RecallMiss {
547 row: row_idx,
548 atom,
549 alignment: sketch.alignment(atom, direction.view()),
550 reason,
551 });
552 }
553 }
554 }
555
556 let recall = if total_planted == 0 {
557 1.0
558 } else {
559 total_recovered as f64 / total_planted as f64
560 };
561 let avg_gathered = if rows.is_empty() {
562 0.0
563 } else {
564 total_gathered as f64 / rows.len() as f64
565 };
566
567 RecallReport {
568 candidate_budget,
569 num_rows: rows.len(),
570 total_planted,
571 total_recovered,
572 recall,
573 avg_candidates_gathered: avg_gathered,
574 num_atoms: self.num_atoms,
575 misses,
576 }
577 }
578
579 /// EXACT routing (#1777 / roadmap "real exact-routing guarantee"): return the
580 /// **global argmax** of the routing score over the WHOLE dictionary — the atom
581 /// whose frame best aligns with `direction` — with a guarantee that no
582 /// ungathered atom is silently better.
583 ///
584 /// The sublinear [`Self::propose`] gather is only a HEURISTIC: a gathered atom
585 /// at alignment `0.6` does not rule out an *ungathered* atom at `1.0`, because
586 /// the gather's alignment is a lower bound on the selected atom, never an upper
587 /// bound on the atoms it skipped. So `propose` alone can silently miss the true
588 /// best atom. This method closes that hole and is the path the encode router
589 /// uses.
590 ///
591 /// Correctness mechanism (sound, not heuristic):
592 /// * **LSH fast path with a TRUE upper bound.** The routing score is the frame
593 /// alignment `‖U_kᵀ d‖ / ‖d‖ ∈ [0, 1]`, so [`ROUTING_ALIGNMENT_UPPER_BOUND`]
594 /// (`1.0`) is a hard ceiling for *every* atom — gathered or not. If the best
595 /// gathered candidate already sits within [`ROUTING_CERT_EPS`] of that
596 /// ceiling, no ungathered atom can beat it: the gathered best is a certified
597 /// global score-maximizer and we return it WITHOUT a full scan
598 /// ([`ExactRoute::lsh_certified`] = `true`).
599 /// * **Exact fallback otherwise.** When the gathered best is not certified by
600 /// that bound (no tighter sound bound on ungathered atoms is available), run
601 /// the full [`brute_force_best_atom`] scan and return its argmax. This is the
602 /// ground truth — correctness over speed, exactly the roadmap contract.
603 ///
604 /// In both branches the returned atom has no atom of strictly greater routing
605 /// score anywhere in the dictionary (no silent miss). Returns `None` only for
606 /// an empty dictionary or an all-non-finite scan (a degenerate sketch).
607 pub fn route_exact<S: AtomFrameSketch>(
608 &self,
609 sketch: &S,
610 direction: ArrayView1<f64>,
611 candidate_budget: usize,
612 multiprobe: bool,
613 ) -> Option<ExactRoute> {
614 // Heuristic LSH gather first (sublinear) — the speed fast path.
615 let proposal = self.propose(sketch, direction, candidate_budget, multiprobe);
616 let lsh_best = proposal
617 .proposed
618 .first()
619 .copied()
620 .map(|id| (id, sketch.alignment(id, direction)));
621
622 if let Some((b, a_b)) = lsh_best {
623 if a_b.is_finite() && a_b >= ROUTING_ALIGNMENT_UPPER_BOUND - ROUTING_CERT_EPS {
624 // Universal-bound certificate: the routing score is capped at 1.0
625 // for EVERY atom, so a gathered atom already at the ceiling cannot
626 // be beaten by any ungathered one. Sound global optimality with no
627 // full scan.
628 return Some(ExactRoute {
629 atom: b,
630 alignment: a_b,
631 lsh_certified: true,
632 lsh_agreed: true,
633 did_full_scan: false,
634 });
635 }
636 }
637
638 // Not certified by the bound ⇒ the gather might have missed a better
639 // ungathered atom. The only sound recourse without a tighter per-atom upper
640 // bound is the exact full scan: it IS the global argmax.
641 let (atom, alignment) = brute_force_best_atom(sketch, direction)?;
642 let lsh_agreed = lsh_best.is_some_and(|(b, _)| b == atom);
643 Some(ExactRoute {
644 atom,
645 alignment,
646 lsh_certified: false,
647 lsh_agreed,
648 did_full_scan: true,
649 })
650 }
651}
652
/// Hard upper bound on the routing score (frame alignment) of ANY atom: the
/// alignment `‖U_kᵀ d‖ / ‖d‖` is the fraction of a direction's energy inside the
/// atom's column-space, so it lies in `[0, 1]` for every atom, gathered or not.
/// This is the *true* upper bound that makes [`SaeCandidateIndex::route_exact`]'s
/// LSH fast path sound: a gathered atom at the ceiling cannot be beaten.
///
/// The bound relies on implementors honouring the `[0, 1]` range documented on
/// [`AtomFrameSketch::alignment`].
pub const ROUTING_ALIGNMENT_UPPER_BOUND: f64 = 1.0;

/// Tolerance for certifying the LSH fast path against [`ROUTING_ALIGNMENT_UPPER_BOUND`].
/// A gathered best within this of the ceiling is treated as a certified global
/// maximizer (floating-point slack on the `‖·‖`/`‖·‖` ratio). The certificate
/// test in `route_exact` is
/// `alignment >= ROUTING_ALIGNMENT_UPPER_BOUND - ROUTING_CERT_EPS`.
pub const ROUTING_CERT_EPS: f64 = 1e-12;
664
665/// Brute-force EXACT global argmax of the routing score (frame alignment) over the
666/// WHOLE dictionary: scan every atom, return `(atom_id, alignment)` of the highest
667/// scorer. Ties break to the LOWEST id (a strict `>` replacement keeps the first
668/// maximizer), matching [`SaeCandidateIndex::propose`]'s id-ascending tie-break so
669/// the two agree atom-for-atom. Non-finite alignments are skipped. Returns `None`
670/// for an empty dictionary (or one whose every atom scored non-finite).
671///
672/// This is `O(K)` per call and is the ground truth [`SaeCandidateIndex::route_exact`]
673/// falls back to whenever the LSH gather is not certified optimal.
674pub fn brute_force_best_atom<S: AtomFrameSketch>(
675 sketch: &S,
676 direction: ArrayView1<f64>,
677) -> Option<(usize, f64)> {
678 let mut best: Option<(usize, f64)> = None;
679 for id in 0..sketch.num_atoms() {
680 let a = sketch.alignment(id, direction);
681 if !a.is_finite() {
682 continue;
683 }
684 match best {
685 Some((_, ba)) if a <= ba => {}
686 _ => best = Some((id, a)),
687 }
688 }
689 best
690}
691
/// Result of [`SaeCandidateIndex::route_exact`]: the certified-or-exact global
/// argmax of the routing score for one row, plus how it was obtained.
///
/// `lsh_certified` and `did_full_scan` are mutually exclusive: `route_exact`
/// sets exactly one of them to `true`.
#[derive(Clone, Copy, Debug)]
pub struct ExactRoute {
    /// The chosen atom id — a GLOBAL routing-score argmax (no atom in the
    /// dictionary has a strictly greater score). No silent miss. On the
    /// certified path this holds up to [`ROUTING_CERT_EPS`] of the `1.0`
    /// ceiling.
    pub atom: usize,
    /// The chosen atom's exact frame alignment with the row direction.
    pub alignment: f64,
    /// `true` ⇒ the LSH fast path certified optimality via the universal upper
    /// bound (gathered best at the `1.0` ceiling); no full scan was needed.
    pub lsh_certified: bool,
    /// Whether the LSH gather's best candidate equalled the returned argmax.
    /// `true` whenever `lsh_certified`; a diagnostic of the gather's recall.
    /// `false` when the gather proposed nothing.
    pub lsh_agreed: bool,
    /// `true` ⇒ the exact `O(K)` fallback scan ran (the LSH bound did not certify).
    pub did_full_scan: bool,
}
710
/// One row's proposal: the budgeted candidate set plus what the budget dropped.
/// Produced by [`SaeCandidateIndex::propose`].
#[derive(Clone, Debug)]
pub struct Proposal {
    /// The top `candidate_budget` atom ids by frame alignment, in rank order
    /// (descending alignment, ties by ascending id).
    pub proposed: Vec<usize>,
    /// Gathered candidates truncated by the budget — logged, never silent.
    /// Listed in the same rank order, continuing where `proposed` stops.
    pub dropped_for_budget: Vec<usize>,
    /// How many candidates the sublinear gather returned before budgeting.
    pub gathered_count: usize,
}
721
/// The cause recorded for a planted atom that is absent from a row's proposed
/// candidate set.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MissReason {
    /// The atom never entered the gathered candidate union: an LSH recall
    /// miss. Remedy: widen tables / probes.
    NotGathered,
    /// The atom *was* gathered, but the budget cut it. Remedy: widen the
    /// budget.
    TruncatedByBudget,
}

/// A single recorded recall miss.
#[derive(Clone, Copy, Debug)]
pub struct RecallMiss {
    /// Index of the row within the report's input.
    pub row: usize,
    /// Id of the planted atom that was missed.
    pub atom: usize,
    /// Exact frame alignment of that atom with the row direction (diagnostic).
    pub alignment: f64,
    /// Index miss versus budget truncation.
    pub reason: MissReason,
}

/// Outcome of [`SaeCandidateIndex::recall_report`].
#[derive(Clone, Debug)]
pub struct RecallReport {
    /// Candidate budget at which recall was measured.
    pub candidate_budget: usize,
    /// How many rows were evaluated.
    pub num_rows: usize,
    /// Count of planted truly-active atoms summed over all rows.
    pub total_planted: usize,
    /// How many planted atoms showed up in the proposed sets.
    pub total_recovered: usize,
    /// `recall@candidate_budget` = recovered / planted (1.0 if nothing planted).
    pub recall: f64,
    /// Mean number of candidates the sublinear gather returned per row — the
    /// sublinearity witness; compare it against `num_atoms`.
    pub avg_candidates_gathered: f64,
    /// Total atoms in the index (denominator of the sublinearity ratio).
    pub num_atoms: usize,
    /// Every miss with its row, atom, alignment, and reason. No silent drops.
    pub misses: Vec<RecallMiss>,
}

impl RecallReport {
    /// Mean gathered candidates divided by the dictionary size. A value far
    /// below `1.0` is the evidence that proposal touched only a sublinear
    /// slice of the dictionary. An empty index (`num_atoms == 0`) reports
    /// `0.0` instead of dividing by zero.
    pub fn sublinearity_ratio(&self) -> f64 {
        match self.num_atoms {
            0 => 0.0,
            total => self.avg_candidates_gathered / total as f64,
        }
    }
}
779
780// ---------------------------------------------------------------------------
781// Helpers (deterministic, dependency-light)
782// ---------------------------------------------------------------------------
783
784/// Mix a base seed with an index into a well-spread `u64` (SplitMix64 finalizer
785/// on the sum). Deterministic, no clock.
786#[inline]
787fn mix_seed(base: u64, idx: u64) -> u64 {
788 // Finalize `base + idx·G` with the canonical SplitMix64 step. The stateful
789 // form adds G internally, so pre-subtract one G to land on the same input
790 // and keep the output bit-identical to the previous inlined finalizer.
791 let mut state = base
792 .wrapping_add(idx.wrapping_mul(0x9E37_79B9_7F4A_7C15))
793 .wrapping_sub(0x9E37_79B9_7F4A_7C15);
794 gam_linalg::utils::splitmix64(&mut state)
795}
796
797/// A seeded Gaussian random matrix of shape `(rows, cols)` (rows of hyperplanes
798/// / projection rows). Uses Box–Muller off a seeded `StdRng`.
799fn gaussian_projection(rows: usize, cols: usize, seed: u64) -> Array2<f64> {
800 use rand::RngExt as _;
801 let mut rng = StdRng::seed_from_u64(seed);
802 let mut m = Array2::<f64>::zeros((rows, cols));
803 for r in 0..rows {
804 for c in 0..cols {
805 let u1 = rng.random::<f64>().max(1e-16);
806 let u2 = rng.random::<f64>();
807 m[(r, c)] = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
808 }
809 }
810 m
811}
812
813/// Modified Gram–Schmidt orthonormalization of a decoder block's columns.
814/// Input `block` is `(p, m)`; output `U` is `(p, r)` with orthonormal columns
815/// spanning `range(block)`, `r ≤ m` (rank-deficient columns are dropped).
816fn orthonormal_frame(block: &Array2<f64>) -> Array2<f64> {
817 let p = block.nrows();
818 let m = block.ncols();
819 let mut cols: Vec<Array1<f64>> = Vec::with_capacity(m);
820 for j in 0..m {
821 let mut v = block.column(j).to_owned();
822 for q in &cols {
823 let proj: f64 = q.iter().zip(v.iter()).map(|(&a, &b)| a * b).sum();
824 for (vi, &qi) in v.iter_mut().zip(q.iter()) {
825 *vi -= proj * qi;
826 }
827 }
828 let nrm = vec_norm(v.view());
829 if nrm > DIRECTION_NORM_FLOOR {
830 for vi in v.iter_mut() {
831 *vi /= nrm;
832 }
833 cols.push(v);
834 }
835 }
836 let r = cols.len();
837 let mut u = Array2::<f64>::zeros((p, r));
838 for (j, col) in cols.into_iter().enumerate() {
839 u.column_mut(j).assign(&col);
840 }
841 u
842}
843
844/// `M · v` for `M` shape `(rows, cols)`, `v` length `cols`.
845fn mat_vec(m: &Array2<f64>, v: ArrayView1<f64>) -> Array1<f64> {
846 let mut out = Array1::<f64>::zeros(m.nrows());
847 for r in 0..m.nrows() {
848 let row = m.row(r);
849 out[r] = row.iter().zip(v.iter()).map(|(&a, &b)| a * b).sum();
850 }
851 out
852}
853
854#[inline]
855fn vec_norm(v: ArrayView1<f64>) -> f64 {
856 v.iter().map(|&x| x * x).sum::<f64>().sqrt()
857}
858
859#[inline]
860fn normalize_in_place(v: &mut Array1<f64>) {
861 let n = vec_norm(v.view());
862 if n > DIRECTION_NORM_FLOOR {
863 for x in v.iter_mut() {
864 *x /= n;
865 }
866 }
867}
868
869/// Pack the sign bits of `bank · s` into a `u64` signature. `bank` is
870/// `(bits, sketch_dim)`; `bits ≤ 64` (enforced by config-derived bit widths).
871fn sign_signature(bank: &Array2<f64>, s: ArrayView1<f64>) -> u64 {
872 let mut sig = 0u64;
873 for r in 0..bank.nrows() {
874 let row = bank.row(r);
875 let dot: f64 = row.iter().zip(s.iter()).map(|(&a, &b)| a * b).sum();
876 if dot >= 0.0 {
877 sig |= 1u64 << r;
878 }
879 }
880 sig
881}
882
883/// Signature plus per-bit signed margins (the dot products), used by multi-probe
884/// to find the least-confident bit to flip.
885fn sign_signature_with_margins(bank: &Array2<f64>, s: ArrayView1<f64>) -> (u64, Vec<f64>) {
886 let mut sig = 0u64;
887 let mut margins = Vec::with_capacity(bank.nrows());
888 for r in 0..bank.nrows() {
889 let row = bank.row(r);
890 let dot: f64 = row.iter().zip(s.iter()).map(|(&a, &b)| a * b).sum();
891 if dot >= 0.0 {
892 sig |= 1u64 << r;
893 }
894 margins.push(dot);
895 }
896 (sig, margins)
897}
898
/// Index of the bit whose hyperplane the query sits closest to (smallest
/// `|margin|`) — the bit most likely to have landed in the wrong bucket.
///
/// Ties resolve to the earliest index. Returns `0` when `margins` is empty or
/// when no entry has a finite-comparable magnitude below `+∞` (NaN entries
/// never win the comparison).
fn lowest_margin_bit(margins: &[f64]) -> usize {
    margins
        .iter()
        .map(|margin| margin.abs())
        .enumerate()
        .fold(
            (0usize, f64::INFINITY),
            |(best_bit, best_abs), (bit, magnitude)| {
                // Strict `<` keeps the first of any tied minima.
                if magnitude < best_abs {
                    (bit, magnitude)
                } else {
                    (best_bit, best_abs)
                }
            },
        )
        .0
}
913
914// ---------------------------------------------------------------------------
915// Tests
916// ---------------------------------------------------------------------------
917
918#[cfg(test)]
919mod tests {
920 use super::*;
921 use rand::RngExt as _;
922 use rand::rngs::StdRng;
923
924 /// Draw a unit vector in `p` dims from a seeded RNG.
925 fn unit_vec(rng: &mut StdRng, p: usize) -> Array1<f64> {
926 let mut v = Array1::<f64>::zeros(p);
927 for x in v.iter_mut() {
928 let u1 = rng.random::<f64>().max(1e-16);
929 let u2 = rng.random::<f64>();
930 *x = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
931 }
932 let n = vec_norm(v.view());
933 if n > DIRECTION_NORM_FLOOR {
934 for x in v.iter_mut() {
935 *x /= n;
936 }
937 }
938 v
939 }
940
941 /// Build a synthetic dictionary of `k` rank-1 atoms: atom `i`'s decoder
942 /// block is the outer-friendly single column `c_i` (a random unit direction
943 /// in output space). Returns the blocks and the list of column directions so
944 /// the planted-atom test can construct directions that lie in chosen atoms.
945 fn synthetic_dictionary(k: usize, p: usize, seed: u64) -> (Vec<Array2<f64>>, Vec<Array1<f64>>) {
946 let mut rng = StdRng::seed_from_u64(seed);
947 let mut blocks = Vec::with_capacity(k);
948 let mut dirs = Vec::with_capacity(k);
949 for _ in 0..k {
950 let c = unit_vec(&mut rng, p);
951 let mut block = Array2::<f64>::zeros((p, 1));
952 block.column_mut(0).assign(&c);
953 blocks.push(block);
954 dirs.push(c);
955 }
956 (blocks, dirs)
957 }
958
/// A direction equal to an atom's own column must align ~1 with that atom,
/// and a different atom's column must align strictly less.
#[test]
fn frame_alignment_is_exact_for_in_range_direction() {
    let (blocks, dirs) = synthetic_dictionary(8, 16, 11);
    let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 12, 7).unwrap();
    let atom = 3usize;

    // Atom 3's own column lies fully inside its range.
    let a = sketch.alignment(atom, dirs[atom].view());
    assert!(a > 0.999, "in-range alignment should be ~1, got {a}");

    // Atom 5's column is generically close to orthogonal to atom 3's, so it
    // aligns only weakly with atom 3.
    let a_off = sketch.alignment(atom, dirs[5].view());
    assert!(
        a_off < a,
        "off-atom alignment {a_off} should be below in-range {a}"
    );
}
975
/// Regression for the routing-confidence gate (#1026): a low-alignment LSH
/// route must be flagged for the exact fallback, never trusted by the
/// heuristic gate. This gate is a confidence/quality proxy, NOT a
/// global-optimality certificate — being at/above the threshold means the
/// chosen atom is itself a reasonable fit, it does not prove no better
/// ungathered atom exists. `certified_encode_with_index` (and the amortized
/// twin) flag a routed row UNCERTIFIED whenever the best-aligned proposed
/// atom's frame alignment is below
/// `encode::CANDIDATE_ROUTING_MIN_ALIGNMENT`. The gate's decision input is
/// exactly `sketch.alignment(best_atom, target)`; pin it here — with exact,
/// deterministic linear algebra rather than LSH gather luck — so a future
/// change to the frame-alignment formula cannot silently shift the
/// threshold's meaning out from under the gate.
///
/// Two atoms whose decoder frames span ORTHOGONAL subspaces of a 6-dim
/// ambient (atom 0: `span(e0,e1)`, atom 1: `span(e2,e3)`; dims `e4,e5`
/// covered by neither). A direction wholly inside atom 1's subspace has
/// alignment exactly 1 with atom 1 (in-frame, ABOVE the gate) and exactly 0
/// with atom 0 (off-frame, BELOW the gate) — the mis-route the gate exists
/// to flag.
///
/// The test never builds an index: it exercises only `alignment`, so its
/// outcome does not depend on LSH bucket placement.
#[test]
fn routing_confidence_gate_input_separates_off_frame_from_in_frame() {
    use crate::encode::CANDIDATE_ROUTING_MIN_ALIGNMENT as GATE;

    // Two axis-aligned rank-2 decoder blocks in a 6-dim ambient space.
    let p = 6usize;
    let mut block_a = Array2::<f64>::zeros((p, 2));
    block_a[[0, 0]] = 1.0; // e0
    block_a[[1, 1]] = 1.0; // e1
    let mut block_b = Array2::<f64>::zeros((p, 2));
    block_b[[2, 0]] = 1.0; // e2
    block_b[[3, 1]] = 1.0; // e3
    let sketch =
        RandomProjectionFrameSketch::from_decoder_blocks(&[block_a, block_b], 16, 4242)
            .unwrap();

    // A unit direction wholly inside atom 1's (e2,e3) subspace.
    let mut in_frame_b = Array1::<f64>::zeros(p);
    in_frame_b[2] = 0.6;
    in_frame_b[3] = 0.8; // unit norm (0.6² + 0.8² = 1)
    let a_right = sketch.alignment(1, in_frame_b.view());
    let a_wrong = sketch.alignment(0, in_frame_b.view());

    assert!(
        a_right > 0.999,
        "an in-frame direction must align ~1 with its own atom; got {a_right}"
    );
    assert!(
        a_wrong < 1e-9,
        "an orthogonal-subspace direction must align ~0 with the wrong atom; got {a_wrong}"
    );
    // The exact predicate the encode gate evaluates per row: a mis-routed
    // (orthogonal) atom falls BELOW the gate → flagged for the exact
    // fallback; the correctly-routed atom sits AT/ABOVE the gate → trusted by
    // the heuristic gate (a confidence proxy, not a global-optimality
    // certificate).
    assert!(
        a_wrong < GATE,
        "a mis-routed (orthogonal) atom must fall below the routing gate {GATE}; got {a_wrong}"
    );
    assert!(
        a_right >= GATE,
        "the correctly-routed atom must sit at/above the routing gate {GATE}; got {a_right}"
    );

    // A direction in the UNCOVERED (e4,e5) subspace aligns ~0 with BOTH
    // atoms, so whichever atom the LSH surfaces, the gate fires: no atom can
    // certify this route. This is the worst-case the gate exists to catch.
    let mut uncovered = Array1::<f64>::zeros(p);
    uncovered[4] = 1.0;
    for atom in 0..2 {
        let a = sketch.alignment(atom, uncovered.view());
        assert!(
            a < GATE,
            "an uncovered-subspace direction must fall below the gate for atom {atom}; got {a}"
        );
    }
}
1053
/// Two builds from the same decoder blocks and the same seed must agree: the
/// per-atom representative sketches match to within `1e-12`, and the two
/// indexes have the same number of tables with the same per-table size.
#[test]
fn build_is_deterministic_for_a_fixed_seed() {
    let (blocks, _) = synthetic_dictionary(64, 24, 99);
    let s1 = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 16, 5).unwrap();
    let s2 = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 16, 5).unwrap();
    // Same seed → identical representative sketches.
    for i in 0..blocks.len() {
        let a = s1.atom_sketch(i);
        let b = s2.atom_sketch(i);
        let diff = vec_norm((&a - &b).view());
        assert!(
            diff < 1e-12,
            "atom {i} sketch differs across builds: {diff:e}"
        );
    }
    let cfg = IndexConfig::auto(16, blocks.len(), 5);
    let idx1 = SaeCandidateIndex::build(&s1, cfg).unwrap();
    let idx2 = SaeCandidateIndex::build(&s2, cfg).unwrap();
    // The table count must match first: otherwise the per-table loop below
    // would either index out of bounds or silently ignore extra tables.
    assert_eq!(
        idx1.tables.len(),
        idx2.tables.len(),
        "rebuild produced a different number of LSH tables"
    );
    // Per-table sizes must match. This compares sizes only, not the contents
    // of each table.
    for t in 0..idx1.tables.len() {
        assert_eq!(
            idx1.tables[t].len(),
            idx2.tables[t].len(),
            "table {t} differs in size across builds"
        );
    }
}
1077
/// Planted-recall gate at `K = 2000`: with a candidate budget of 32 the
/// gather must stay a sublinear slice of the dictionary, recall of the
/// planted dominant atom must reach the stated floor, and every unrecovered
/// planted atom must appear in the miss list.
#[test]
fn planted_atoms_are_recalled_above_floor_at_sublinear_budget() {
    // A frontier-ish dictionary: many atoms, modest output dim.
    let k = 2000usize;
    let p = 48usize;
    let (blocks, dirs) = synthetic_dictionary(k, p, 2026);
    let sketch_dim = 24usize;
    let sketch =
        RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, 4242).unwrap();
    let cfg = IndexConfig::auto(sketch_dim, k, 4242);
    let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();

    // Plant: each row's residual is dominated by one chosen atom's column
    // (plus a little cross-talk from a second). The planted-active set is
    // that dominant atom. The shared `planted_rows` helper draws the rows
    // deterministically from the seed, in the same order the previously
    // inlined loop did.
    let n_rows = 200usize;
    let rows: Vec<(Array1<f64>, Vec<usize>)> = planted_rows(&dirs, n_rows, 31337);

    // Sublinear candidate budget: << K. We allow the gather to surface a
    // handful, but the *budget* (the per-row local block size) stays small.
    let candidate_budget = 32usize;
    let report = index.recall_report(&sketch, &rows, candidate_budget, cfg.multiprobe);

    // The gather must touch only a sublinear slice of the dictionary.
    assert!(
        report.sublinearity_ratio() < 0.5,
        "gather was not sublinear: avg {} of {} atoms (ratio {:.3})",
        report.avg_candidates_gathered,
        report.num_atoms,
        report.sublinearity_ratio()
    );

    // Recall floor: the LSH index must recover the planted dominant atom for
    // the large majority of rows at this sublinear budget. Misses are
    // logged, never silently dropped.
    let floor = 0.80;
    assert!(
        report.recall >= floor,
        "recall {:.3} below floor {floor}; {} misses logged (first few: {:?})",
        report.recall,
        report.misses.len(),
        report
            .misses
            .iter()
            .take(5)
            .map(|m| (m.row, m.atom, m.reason, m.alignment))
            .collect::<Vec<_>>()
    );

    // Every miss is accounted for with a reason — the no-silent-truncation
    // contract.
    let recovered = report.total_recovered;
    assert_eq!(
        report.total_planted - recovered,
        report.misses.len(),
        "miss list must account for every unrecovered planted atom"
    );
}
1151
/// `auto_candidate_budget` hits the pinned spot values, never leaves the
/// `[CANDIDATE_BUDGET_MIN, CANDIDATE_BUDGET_MAX]` band, and never decreases
/// as `K` grows.
#[test]
fn auto_candidate_budget_tracks_the_issue_band() {
    // Pinned spot values, including both ends of the band.
    let pinned = [
        (2usize, CANDIDATE_BUDGET_MIN),
        (64, 48),
        (1024, 80),
        (100_000, CANDIDATE_BUDGET_MAX),
    ];
    for (k, expected) in pinned {
        assert_eq!(auto_candidate_budget(k), expected);
    }

    // Monotone non-decreasing in K and always inside the band.
    let band = CANDIDATE_BUDGET_MIN..=CANDIDATE_BUDGET_MAX;
    let mut previous = 0usize;
    for k in [2usize, 16, 64, 256, 1024, 4096, 65_536, 1_000_000] {
        let budget = auto_candidate_budget(k);
        assert!(budget >= previous, "budget must be monotone in K");
        assert!(band.contains(&budget));
        previous = budget;
    }
}
1167
1168 /// Build a planted row set for a dictionary: each row's residual direction
1169 /// is dominated by one chosen atom (plus cross-talk from a second), and
1170 /// the planted-active set is the dominant atom.
1171 fn planted_rows(
1172 dirs: &[Array1<f64>],
1173 n_rows: usize,
1174 seed: u64,
1175 ) -> Vec<(Array1<f64>, Vec<usize>)> {
1176 let k = dirs.len();
1177 let mut rng = StdRng::seed_from_u64(seed);
1178 let mut rows = Vec::with_capacity(n_rows);
1179 for _ in 0..n_rows {
1180 let primary = rng.random_range(0..k);
1181 let secondary = rng.random_range(0..k);
1182 let mut d = dirs[primary].clone();
1183 for (di, &si) in d.iter_mut().zip(dirs[secondary].iter()) {
1184 *di += 0.15 * si;
1185 }
1186 let n = vec_norm(d.view());
1187 for di in d.iter_mut() {
1188 *di /= n;
1189 }
1190 rows.push((d, vec![primary]));
1191 }
1192 rows
1193 }
1194
/// Planted ladder over `K = 64` and `K = 1024`, running the same battery on
/// each rung: recall floor at the auto-derived budget, full miss accounting,
/// identical proposals from an independent rebuild, and proposal size bounded
/// by the budget. After the loop, the gathered fraction of the dictionary
/// must shrink from the small rung to the large one.
#[test]
fn k_ladder_recall_determinism_and_sublinearity() {
    // #985 part 2 (index tier): the K=2-era assumptions say nothing about
    // frontier K, so gate the proposal machinery on a planted ladder at
    // K = 64 and K = 1024 with the SAME battery per rung — recall above a
    // stated floor at the auto-derived budget, every miss accounted for,
    // and byte-identical proposals across two independent builds. The
    // gather must also become *relatively* sparser as K grows (the
    // sublinearity witness): what is allowed to touch half the dictionary
    // at K = 64 must not at K = 1024.
    let p = 48usize;
    let n_rows = 150usize;
    // One `(K, sublinearity_ratio)` entry per rung, in ladder order.
    let mut ladder_ratios = Vec::new();
    for &k in &[64usize, 1024] {
        // Seeds are offset by K so each rung gets its own dictionary, sketch
        // and row set.
        let (blocks, dirs) = synthetic_dictionary(k, p, 9000 + k as u64);
        let sketch_dim = 24usize;
        let sketch_seed = 71 + k as u64;
        let sketch =
            RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, sketch_seed)
                .unwrap();
        let cfg = IndexConfig::auto(sketch_dim, k, sketch_seed);
        let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();

        let rows = planted_rows(&dirs, n_rows, 555 + k as u64);
        let budget = auto_candidate_budget(k);
        let report = index.recall_report(&sketch, &rows, budget, cfg.multiprobe);

        // Recall floor at the auto-derived budget, with every miss carrying
        // a reason — the no-silent-truncation contract, per rung.
        let floor = 0.80;
        assert!(
            report.recall >= floor,
            "K={k}: recall {:.3} below floor {floor}; {} misses (first: {:?})",
            report.recall,
            report.misses.len(),
            report
                .misses
                .iter()
                .take(3)
                .map(|m| (m.row, m.atom, m.reason, m.alignment))
                .collect::<Vec<_>>()
        );
        assert_eq!(
            report.total_planted - report.total_recovered,
            report.misses.len(),
            "K={k}: miss list must account for every unrecovered planted atom"
        );

        // Search determinism: an independent rebuild from the same inputs
        // proposes the identical candidate set for every probed row.
        let sketch2 =
            RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, sketch_seed)
                .unwrap();
        let index2 = SaeCandidateIndex::build(&sketch2, cfg).unwrap();
        // Only the first 20 rows are probed.
        for (direction, _) in rows.iter().take(20) {
            let a = index.propose(&sketch, direction.view(), budget, cfg.multiprobe);
            let b = index2.propose(&sketch2, direction.view(), budget, cfg.multiprobe);
            assert_eq!(
                a.proposed, b.proposed,
                "K={k}: rebuild must propose identically"
            );
        }

        // Proposal size is the budget, never the dictionary: the per-row
        // local block stays near the planted/active scale.
        for (direction, _) in rows.iter().take(20) {
            let prop = index.propose(&sketch, direction.view(), budget, cfg.multiprobe);
            assert!(prop.proposed.len() <= budget);
        }

        ladder_ratios.push((k, report.sublinearity_ratio()));
    }
    // Relative sparsity must improve up the ladder: the gathered fraction
    // of the dictionary shrinks as K grows (sublinear gather), and at the
    // frontier-shaped rung it must be a small slice outright.
    let (_, ratio_small) = ladder_ratios[0];
    let (k_big, ratio_big) = ladder_ratios[1];
    assert!(
        ratio_big < ratio_small,
        "sublinearity must improve along the ladder: {ladder_ratios:?}"
    );
    assert!(
        ratio_big < 0.25,
        "K={k_big}: gather touched {:.1}% of the dictionary",
        ratio_big * 100.0
    );
}
1282
1283 /// Counting wrapper: delegates everything, counts `project_direction`
1284 /// calls. The #994 acceptance gate: with the exact probe, building the
1285 /// query sketch touches NO atom, so a whole `propose` makes zero
1286 /// `project_direction` calls (scoring goes through `alignment`).
1287 struct CountingSketch<'a> {
1288 inner: &'a RandomProjectionFrameSketch,
1289 project_calls: std::cell::Cell<usize>,
1290 }
1291
1292 impl AtomFrameSketch for CountingSketch<'_> {
1293 fn sketch_dim(&self) -> usize {
1294 self.inner.sketch_dim()
1295 }
1296 fn output_dim(&self) -> usize {
1297 self.inner.output_dim()
1298 }
1299 fn num_atoms(&self) -> usize {
1300 self.inner.num_atoms()
1301 }
1302 fn atom_sketch(&self, atom_id: usize) -> Array1<f64> {
1303 self.inner.atom_sketch(atom_id)
1304 }
1305 fn project_direction(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64> {
1306 self.project_calls.set(self.project_calls.get() + 1);
1307 self.inner.project_direction(atom_id, direction)
1308 }
1309 fn alignment(&self, atom_id: usize, direction: ArrayView1<f64>) -> f64 {
1310 self.inner.alignment(atom_id, direction)
1311 }
1312 fn query_sketch(&self, direction: ArrayView1<f64>) -> Array1<f64> {
1313 self.inner.query_sketch(direction)
1314 }
1315 }
1316
/// A full `propose` through the counting wrapper must make zero
/// `project_direction` calls (#994).
#[test]
fn query_probe_touches_no_atom_before_the_gather() {
    let k = 512usize;
    let p = 32usize;
    let (blocks, dirs) = synthetic_dictionary(k, p, 77);
    let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 16, 13).unwrap();
    let cfg = IndexConfig::auto(16, k, 13);
    let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();

    let counting = CountingSketch {
        inner: &sketch,
        project_calls: std::cell::Cell::new(0),
    };
    // Only the side effect on the counter matters; the proposal is discarded.
    let _discarded = index.propose(&counting, dirs[5].view(), 32, cfg.multiprobe);

    let observed_calls = counting.project_calls.get();
    assert_eq!(
        observed_calls,
        0,
        "the exact query probe must be independent of K: no per-atom \
         projection before the gather (#994)"
    );
}
1337
1338 /// Build a coherent-cluster dictionary: `n_clusters` random unit centers,
1339 /// each with `cluster_size` atoms drawn as small perturbations of the
1340 /// center (renormalized). Exactly the non-isotropic regime where the old
1341 /// masked-average probe degraded (#994).
1342 fn coherent_cluster_dictionary(
1343 n_clusters: usize,
1344 cluster_size: usize,
1345 p: usize,
1346 spread: f64,
1347 seed: u64,
1348 ) -> (Vec<Array2<f64>>, Vec<Array1<f64>>) {
1349 let mut rng = StdRng::seed_from_u64(seed);
1350 let mut blocks = Vec::with_capacity(n_clusters * cluster_size);
1351 let mut dirs = Vec::with_capacity(n_clusters * cluster_size);
1352 for _ in 0..n_clusters {
1353 let center = unit_vec(&mut rng, p);
1354 for _ in 0..cluster_size {
1355 let noise = unit_vec(&mut rng, p);
1356 let mut c = center.clone();
1357 for (ci, &ni) in c.iter_mut().zip(noise.iter()) {
1358 *ci += spread * ni;
1359 }
1360 let n = vec_norm(c.view());
1361 for ci in c.iter_mut() {
1362 *ci /= n;
1363 }
1364 let mut block = Array2::<f64>::zeros((p, 1));
1365 block.column_mut(0).assign(&c);
1366 blocks.push(block);
1367 dirs.push(c);
1368 }
1369 }
1370 (blocks, dirs)
1371 }
1372
/// Recall gate on a clustered (non-isotropic) dictionary: the planted
/// cluster member itself must be recovered at the auto budget, and the
/// gather must still touch less than half of the dictionary.
#[test]
fn coherent_clusters_are_recalled_with_the_exact_probe() {
    // 32 clusters × 32 near-parallel atoms = 1024 atoms. Rows are dominated
    // by one specific cluster member; the proposal must recover that exact
    // member (not merely its cluster) at the auto budget — exact alignment
    // scoring separates siblings once the probe lands the gather in the
    // right bucket neighborhood.
    let (n_clusters, cluster_size) = (32usize, 32usize);
    let k = n_clusters * cluster_size;
    let p = 48usize;
    let sketch_dim = 24usize;

    let (blocks, dirs) = coherent_cluster_dictionary(n_clusters, cluster_size, p, 0.25, 4242);
    let sketch =
        RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, 99).unwrap();
    let cfg = IndexConfig::auto(sketch_dim, k, 99);
    let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();

    let rows = planted_rows(&dirs, 150, 31337);
    let budget = auto_candidate_budget(k);
    let report = index.recall_report(&sketch, &rows, budget, cfg.multiprobe);

    let floor = 0.80;
    assert!(
        report.recall >= floor,
        "coherent-cluster recall {:.3} below floor {floor}; {} misses (first: {:?})",
        report.recall,
        report.misses.len(),
        report
            .misses
            .iter()
            .take(3)
            .map(|m| (m.row, m.atom, m.reason, m.alignment))
            .collect::<Vec<_>>()
    );

    // Still a sublinear slice of the dictionary, clusters or not.
    assert!(
        report.sublinearity_ratio() < 0.5,
        "cluster gather touched {:.1}% of the dictionary",
        report.sublinearity_ratio() * 100.0
    );
}
1414
/// `query_sketch(d)` must coincide with the representative sketch of a
/// rank-1 atom whose only column is `d`.
#[test]
fn exact_probe_matches_shared_projection_of_the_direction() {
    // The override is literally normalize(R·d): verify against a manual
    // computation through the public surface (atom_sketch of a rank-1 atom
    // whose only column IS the direction gives normalize(R·d) too).
    let p = 16usize;
    let mut rng = StdRng::seed_from_u64(5);
    let d = unit_vec(&mut rng, p);

    // A single-atom dictionary whose lone column is `d`.
    let mut single_column = Array2::<f64>::zeros((p, 1));
    single_column.column_mut(0).assign(&d);
    let sketch =
        RandomProjectionFrameSketch::from_decoder_blocks(&[single_column], 8, 21).unwrap();

    let via_atom = sketch.atom_sketch(0);
    let via_probe = sketch.query_sketch(d.view());
    let diff = vec_norm((&via_probe - &via_atom).view());
    assert!(
        diff < 1e-10,
        "query_sketch(d) must equal the rank-1 atom representative of d: diff {diff:e}"
    );
}
1434
/// The exact-routing guarantee (#1777 / roadmap): for EVERY row,
/// [`SaeCandidateIndex::route_exact`] selects the SAME atom as the brute-force
/// full-scan global argmax of the routing score — no silent miss — even on the
/// rows where the sublinear LSH gather alone picks a worse atom. This is the
/// acceptance contract: production routing == brute-force argmax, by
/// construction of the exact fallback.
///
/// Arm A drives the exact-scan fallback with random directions; Arm B drives
/// the certified fast path with a direction equal to one atom's own column.
#[test]
fn route_exact_matches_brute_force_argmax_with_no_silent_miss() {
    // ── Arm A: exact-fallback path. Random unit-direction queries against a
    // frontier-shaped dictionary. A random direction lies fully in no atom's
    // rank-1 range, so the alignment is < 1 everywhere: the universal-bound
    // fast path never fires and route_exact must run the exact scan. ──────────
    let k = 1500usize;
    let p = 32usize;
    let sketch_dim = 24usize;
    let (blocks, dirs) = synthetic_dictionary(k, p, 2027);
    let sketch =
        RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, 909).unwrap();
    let cfg = IndexConfig::auto(sketch_dim, k, 909);
    let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();
    let budget = auto_candidate_budget(k);

    let n_rows = 300usize;
    let mut rng = StdRng::seed_from_u64(8675309);
    // Rows where the LSH gather alone would have picked a worse atom.
    let mut lsh_only_misses = 0usize;
    for _ in 0..n_rows {
        let d = unit_vec(&mut rng, p);

        // Ground truth: brute-force global argmax.
        let (truth_atom, truth_align) = brute_force_best_atom(&sketch, d.view())
            .expect("non-empty dictionary has an argmax");

        // No-silent-miss invariant: nothing in the dictionary beats the truth.
        for id in 0..k {
            let a = sketch.alignment(id, d.view());
            assert!(
                a <= truth_align + 1e-12,
                "brute force is not the argmax: atom {id} scores {a} > {truth_align}"
            );
        }

        // Production exact router must equal the brute-force argmax, exactly.
        let route = index
            .route_exact(&sketch, d.view(), budget, cfg.multiprobe)
            .expect("route_exact returns an argmax for a non-empty dictionary");
        assert_eq!(
            route.atom, truth_atom,
            "route_exact must select the brute-force global argmax"
        );
        assert!(
            (route.alignment - truth_align).abs() < 1e-12,
            "route_exact alignment {} != brute force {truth_align}",
            route.alignment
        );
        assert!(
            route.did_full_scan && !route.lsh_certified,
            "a sub-ceiling row must take the exact-scan fallback, not the bound fast path"
        );

        // What the OLD heuristic (LSH gather best) alone would have chosen.
        // An empty proposal counts as a miss, as does a different top atom.
        let proposal = index.propose(&sketch, d.view(), budget, cfg.multiprobe);
        if proposal.proposed.first() != Some(&truth_atom) {
            lsh_only_misses += 1;
        }
    }
    // The fallback is doing real work: the sublinear gather alone DOES silently
    // miss the global best on a meaningful fraction of rows, and route_exact
    // recovered every one of them (asserted above, per row).
    assert!(
        lsh_only_misses > 0,
        "test is vacuous: LSH-alone never missed, so the exact fallback was never exercised"
    );

    // ── Arm B: certified fast path. A query equal to a unique atom's own
    // column aligns exactly 1.0 with it (the ceiling) and < 1 with all others,
    // so route_exact certifies optimality via the universal bound — no scan —
    // and still returns the brute-force argmax. ─────────────────────────────
    let j = 777usize;
    let dj = &dirs[j];
    let (truth_atom, _) = brute_force_best_atom(&sketch, dj.view()).unwrap();
    assert_eq!(truth_atom, j, "the in-frame column is its own unique argmax");

    let route = index
        .route_exact(&sketch, dj.view(), budget, cfg.multiprobe)
        .unwrap();
    assert_eq!(route.atom, j, "fast path must still return the global argmax");
    assert!(
        route.lsh_certified && !route.did_full_scan,
        "a ceiling-alignment row must be certified by the universal bound, no scan"
    );
    assert!(
        route.alignment >= ROUTING_ALIGNMENT_UPPER_BOUND - ROUTING_CERT_EPS,
        "certified alignment must sit at the universal ceiling; got {}",
        route.alignment
    );
}
1534
/// A row with no planted atoms must report recall `1.0` and an empty miss
/// list (the "nothing planted" convention of `RecallReport::recall`).
#[test]
fn empty_planted_rows_report_perfect_recall() {
    let (blocks, dirs) = synthetic_dictionary(32, 16, 1);
    let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 12, 3).unwrap();
    let index = SaeCandidateIndex::build(&sketch, IndexConfig::auto(12, 32, 3)).unwrap();

    // One row, nothing planted in it.
    let nothing_planted: Vec<usize> = Vec::new();
    let rows = vec![(dirs[0].clone(), nothing_planted)];

    let report = index.recall_report(&sketch, &rows, 8, true);
    assert_eq!(report.recall, 1.0);
    assert!(report.misses.is_empty());
}
1546}