Skip to main content

gam_sae/
candidate_index.rs

1//! Sublinear candidate-atom index for active-set proposal (#985 part 1).
2//!
3//! A frontier SAE dictionary holds `K ≈ 10^4–10^5` atoms. The per-row *local*
4//! block — the small linear/Newton system over the atoms that are actually
5//! active in a row — is cheap, because the active set collapses it to a handful
6//! of atoms. The expensive step is *proposing* that active set: a naive scan
7//! scores every one of the `K` atom frames against every row, which is `O(K)`
8//! per row and dominates the whole solve once `K` is large.
9//!
10//! This module builds a **sublinear** candidate index over per-atom *sketches*
11//! of each atom's decoder column-space (its Grassmann frame `U_k`). Given a row
12//! residual direction it returns the top candidate atom ids likely to be
13//! active, touching only `O(log K)`-ish buckets instead of all `K` atoms.
14//!
15//! ## Layering against Track 1
16//!
//! Track 1 owns the *real* atom frames `U_k` and has not landed yet, so this
//! module is written against the [`AtomFrameSketch`] trait. Any frame source —
19//! the eventual Grassmann frames, or the decoder column blocks `B_k` already
20//! present on [`crate::manifold::SaeManifoldAtom`] — can implement
21//! it. A concrete, dependency-free default
22//! ([`RandomProjectionFrameSketch`]) is provided: a seeded random-projection /
23//! random-hyperplane signature of the atom's orthonormalized column span. The
24//! index ([`SaeCandidateIndex`]) is a deterministic multi-table
25//! random-hyperplane LSH over those sketches.
26//!
27//! ## Recall contract
28//!
29//! Sublinear proposal is only safe if it *almost never* drops a truly-active
30//! atom. [`SaeCandidateIndex::recall_report`] takes a set of planted
31//! truly-active atoms per row, runs the proposal at a stated candidate budget,
32//! and records the rate at which planted atoms appear in the proposed set —
33//! **logging every miss** rather than silently truncating. The returned
34//! [`RecallReport`] carries `recall@budget` and the full miss list so a caller
35//! can widen the budget or fall back to a dense scan for the affected rows.
36//!
37//! Determinism: every random choice is seeded by an explicit index seed; no
38//! clock, no global RNG.
39
40use ndarray::{Array1, Array2, ArrayView1};
41use rand::SeedableRng;
42use rand::rngs::StdRng;
43use std::collections::{HashMap, HashSet};
44
/// Salt XORed into the index's master seed before per-table hyperplane seeds
/// are derived, so the LSH tables and the default sketch do not draw from the
/// same random stream even when both are handed the same base seed.
const INDEX_HYPERPLANE_SALT: u64 = 0x9E3779B97F4A7C15;

/// Salt XORed into the seed of the default sketch's shared projection matrix.
const SKETCH_PROJECTION_SALT: u64 = 0xC2B2AE3D27D4EB4F;

/// Norm threshold under which a direction or column is treated as zero.
const DIRECTION_NORM_FLOOR: f64 = 1.0e-12;

/// Smallest per-row candidate budget `C` that [`auto_candidate_budget`] will
/// return (#985). A smaller proposal set would leave the solver's accepted
/// active set no headroom over the planted/active atom count.
pub const CANDIDATE_BUDGET_MIN: usize = 32;

/// Largest per-row candidate budget `C` that [`auto_candidate_budget`] will
/// return (#985). Capping it keeps the per-row local block a small dense solve
/// regardless of dictionary size; a larger `C` would erode the very reduction
/// the proposal step exists to provide.
pub const CANDIDATE_BUDGET_MAX: usize = 128;

/// Per-row candidate budget `C` derived from the dictionary size `K` alone
/// (#985): `C = 8·⌈log₂ K⌉`, clamped into
/// [[`CANDIDATE_BUDGET_MIN`], [`CANDIDATE_BUDGET_MAX`]].
///
/// Logarithmic growth keeps the local block effectively constant-size while
/// giving bigger dictionaries slightly more recall headroom; the clamp realizes
/// the issue's `C ≈ 32–128` band. `K ≤ 1` is treated as one bit.
///
/// Examples: `K = 64 → 48`, `K = 1024 → 80`, `K = 10⁵ → 128`.
pub fn auto_candidate_budget(num_atoms: usize) -> usize {
    let ceil_log2 = match num_atoms {
        0 | 1 => 1,
        k => (k - 1).ilog2() as usize + 1,
    };
    (ceil_log2 * 8).clamp(CANDIDATE_BUDGET_MIN, CANDIDATE_BUDGET_MAX)
}
83
84// ---------------------------------------------------------------------------
85// Sketch interface
86// ---------------------------------------------------------------------------
87
88/// A low-dimensional sketch of one atom's decoder column-space (its Grassmann
89/// frame `U_k`).
90///
91/// The index never needs the full frame: it only needs (a) the sketch
92/// dimension, shared by every atom in a dictionary, and (b), for any query
93/// direction in output space, the atom's *sketch coordinates* of that direction
94/// — i.e. the projection of the direction onto the atom's column-space,
95/// expressed in the sketch's coordinates. A frame `U_k` (orthonormal columns
96/// spanning the decoder range) yields these as `sketch = R · (U_kᵀ d)` for a
97/// shared random projection `R`; a raw decoder block `B_k` yields them by first
98/// orthonormalizing its columns. Both are valid implementors.
99pub trait AtomFrameSketch {
100    /// Dimension of the sketch vectors this implementor produces. Must be the
101    /// same positive value for every atom in one dictionary so the index can
102    /// build a single hyperplane bank.
103    fn sketch_dim(&self) -> usize;
104
105    /// Dimension of the ambient output space the query directions live in.
106    fn output_dim(&self) -> usize;
107
108    /// Number of atoms this source can sketch.
109    fn num_atoms(&self) -> usize;
110
111    /// Sketch of atom `atom_id`'s *frame itself* (a representative point of the
112    /// atom's column-space on the sphere of sketch space), used to place the
113    /// atom into the LSH tables at build time. Returns a vector of length
114    /// [`AtomFrameSketch::sketch_dim`].
115    fn atom_sketch(&self, atom_id: usize) -> Array1<f64>;
116
117    /// Sketch of a query *direction* `d` (length [`AtomFrameSketch::output_dim`])
118    /// as seen through atom `atom_id`'s frame: the direction's component inside
119    /// the atom's column-space, mapped into sketch coordinates. Used at query
120    /// time to score how strongly a row residual aligns with the atom.
121    fn project_direction(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64>;
122
123    /// Alignment score in `[0, 1]`: the fraction of the query direction's energy
124    /// that lies inside atom `atom_id`'s column-space. `1.0` means the direction
125    /// lies fully in the atom's range, `0.0` means it is orthogonal. Used to
126    /// rank the (small) candidate set the index returns.
127    fn alignment(&self, atom_id: usize, direction: ArrayView1<f64>) -> f64;
128
129    /// Sketch-space **probe** for a raw query direction (length
130    /// [`AtomFrameSketch::sketch_dim`]), comparable to the
131    /// [`AtomFrameSketch::atom_sketch`] representatives the LSH tables were
132    /// built from (#994).
133    ///
134    /// Implementors must return the exact cosine-LSH probe for their sketching
135    /// policy. For the shared-projection sketch this is `normalize(R · d)`,
136    /// `O(p · s)` per query, touching no atom.
137    fn query_sketch(&self, direction: ArrayView1<f64>) -> Array1<f64>;
138}
139
140// ---------------------------------------------------------------------------
141// Default concrete sketch: seeded random projection of the column span
142// ---------------------------------------------------------------------------
143
144/// A concrete [`AtomFrameSketch`] built from raw decoder column blocks `B_k`.
145///
146/// For each atom it orthonormalizes the decoder columns (modified Gram–Schmidt)
147/// to obtain a frame `U_k` with orthonormal columns spanning the decoder range,
148/// then sketches via a single shared seeded Gaussian random projection
149/// `R ∈ ℝ^{s×p}` applied to the in-range component of a direction:
150///
151/// * `atom_sketch(k)   = normalize( R · u_k0 )`, the sketch of the atom's first
152///   (dominant) frame column — a stable representative point used to bucket the
153///   atom.
154/// * `project_direction(k, d) = R · (U_k U_kᵀ d)`, the sketch of the part of `d`
155///   that lies in the atom's range.
156/// * `alignment(k, d) = ‖U_kᵀ d‖ / ‖d‖`, the exact in-range energy fraction.
157///
158/// The shared `R` is a Johnson–Lindenstrauss style random projection, so sketch
159/// inner products approximately preserve angles between in-range directions —
160/// exactly what the LSH index needs. Everything is seeded; the same atoms +
161/// seed always produce the same sketches.
162pub struct RandomProjectionFrameSketch {
163    /// Orthonormal frame `U_k` per atom, shape `(p, r_k)` with `r_k` ≤ columns.
164    frames: Vec<Array2<f64>>,
165    /// Shared random projection `R`, shape `(sketch_dim, p)`.
166    projection: Array2<f64>,
167    /// Ambient output dimension `p`.
168    output_dim: usize,
169    /// Sketch dimension `s`.
170    sketch_dim: usize,
171}
172
173impl RandomProjectionFrameSketch {
174    /// Build the sketch from decoder column blocks.
175    ///
176    /// `decoder_blocks[k]` is `B_k` with shape `(p, m_k)`: `p` rows in output
177    /// space, `m_k` decoder columns for atom `k`. (`SaeManifoldAtom` stores the
178    /// transpose `(m_k, p)`; orient it `p`-rows before passing in.) All blocks
179    /// must share the same `p`. `sketch_dim` is the target sketch length `s`;
180    /// `seed` makes the projection deterministic.
181    pub fn from_decoder_blocks(
182        decoder_blocks: &[Array2<f64>],
183        sketch_dim: usize,
184        seed: u64,
185    ) -> Result<Self, String> {
186        if decoder_blocks.is_empty() {
187            return Err("RandomProjectionFrameSketch: need at least one decoder block".into());
188        }
189        if sketch_dim == 0 {
190            return Err("RandomProjectionFrameSketch: sketch_dim must be positive".into());
191        }
192        let output_dim = decoder_blocks[0].nrows();
193        if output_dim == 0 {
194            return Err("RandomProjectionFrameSketch: output dimension must be positive".into());
195        }
196        for (k, block) in decoder_blocks.iter().enumerate() {
197            if block.nrows() != output_dim {
198                return Err(format!(
199                    "RandomProjectionFrameSketch: atom {k} has {} output rows, expected {output_dim}",
200                    block.nrows()
201                ));
202            }
203        }
204
205        let frames: Vec<Array2<f64>> = decoder_blocks.iter().map(orthonormal_frame).collect();
206
207        let projection = gaussian_projection(sketch_dim, output_dim, seed ^ SKETCH_PROJECTION_SALT);
208
209        Ok(Self {
210            frames,
211            projection,
212            output_dim,
213            sketch_dim,
214        })
215    }
216
217    /// In-range component `U_k U_kᵀ d` of a direction (length `output_dim`).
218    fn in_range_component(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64> {
219        let frame = &self.frames[atom_id];
220        // coords = U_kᵀ d  (length r_k)
221        let mut comp = Array1::<f64>::zeros(self.output_dim);
222        for col in 0..frame.ncols() {
223            let u = frame.column(col);
224            let coord: f64 = u.iter().zip(direction.iter()).map(|(&a, &b)| a * b).sum();
225            for (c, &uval) in comp.iter_mut().zip(u.iter()) {
226                *c += coord * uval;
227            }
228        }
229        comp
230    }
231}
232
233impl AtomFrameSketch for RandomProjectionFrameSketch {
234    fn sketch_dim(&self) -> usize {
235        self.sketch_dim
236    }
237
238    fn output_dim(&self) -> usize {
239        self.output_dim
240    }
241
242    fn num_atoms(&self) -> usize {
243        self.frames.len()
244    }
245
246    fn atom_sketch(&self, atom_id: usize) -> Array1<f64> {
247        let frame = &self.frames[atom_id];
248        // Sketch the dominant (first) frame column as the atom's representative.
249        // If the frame is empty (rank-0 atom), fall back to a deterministic
250        // nonzero point so the atom is still bucketed somewhere.
251        if frame.ncols() == 0 {
252            let mut s = self.projection.column(0).to_owned();
253            normalize_in_place(&mut s);
254            return s;
255        }
256        let u0 = frame.column(0);
257        let mut s = mat_vec(&self.projection, u0);
258        normalize_in_place(&mut s);
259        s
260    }
261
262    fn project_direction(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64> {
263        let comp = self.in_range_component(atom_id, direction);
264        mat_vec(&self.projection, comp.view())
265    }
266
267    /// Exact `O(p·s)` probe (#994): every atom shares the one projection `R`,
268    /// and the table representatives are `normalize(R · u_k0)`, so the correct
269    /// cosine-LSH probe for a direction is simply `normalize(R · d)` — no atom
270    /// is touched, and no masked-average approximation is involved.
271    fn query_sketch(&self, direction: ArrayView1<f64>) -> Array1<f64> {
272        let mut s = mat_vec(&self.projection, direction);
273        normalize_in_place(&mut s);
274        s
275    }
276
277    fn alignment(&self, atom_id: usize, direction: ArrayView1<f64>) -> f64 {
278        let dnorm = vec_norm(direction);
279        if dnorm < DIRECTION_NORM_FLOOR {
280            return 0.0;
281        }
282        let comp = self.in_range_component(atom_id, direction);
283        (vec_norm(comp.view()) / dnorm).clamp(0.0, 1.0)
284    }
285}
286
287// ---------------------------------------------------------------------------
288// Sublinear index: multi-table random-hyperplane LSH over sketches
289// ---------------------------------------------------------------------------
290
291/// A deterministic, sublinear candidate index over atom-frame sketches.
292///
293/// The structure is a **random-hyperplane LSH** with `num_tables` independent
294/// tables, each defined by `bits_per_table` seeded random hyperplanes in sketch
295/// space. An atom's sketch is reduced to a `bits_per_table`-bit sign signature
296/// per table (the sign of its dot with each hyperplane), and the atom id is
297/// stored in the bucket keyed by that signature. At query time the query
298/// direction is sketched *through each atom's frame*; we instead hash the *query
299/// sketch* per table and gather the union of atoms in the matching (and, to
300/// improve recall, the Hamming-1 neighbouring) buckets. Because each table
301/// touches only the atoms colliding in one bucket, total work is sublinear in
302/// `K` for well-spread sketches.
303///
304/// The gathered candidates are then ranked by exact
305/// [`AtomFrameSketch::alignment`] and the top `candidate_budget` are returned.
306/// All hyperplanes are seeded; building twice with the same seed yields byte-
307/// identical tables.
308pub struct SaeCandidateIndex {
309    /// Hyperplane banks, one per table: each `(bits_per_table, sketch_dim)`.
310    hyperplanes: Vec<Array2<f64>>,
311    /// Buckets per table: signature -> atom ids.
312    tables: Vec<HashMap<u64, Vec<usize>>>,
313    /// Sketch dimension shared by every atom.
314    sketch_dim: usize,
315    /// Number of atoms indexed.
316    num_atoms: usize,
317}
318
/// Tuning for [`SaeCandidateIndex::build`].
///
/// All fields are explicit, so the index never reads global state; there are
/// no CLI flags. The type is plain data, so it is comparable and hashable
/// (useful for caching indexes by configuration).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct IndexConfig {
    /// Number of independent LSH tables. More tables give higher recall at the
    /// cost of more work. Must be positive.
    pub num_tables: usize,
    /// Random hyperplanes per table (signature bit-width). More bits give finer
    /// buckets: fewer collisions, lower recall per table.
    ///
    /// Must lie in `1..=63`, because signatures are packed into a `u64`;
    /// [`SaeCandidateIndex::build`] rejects anything else.
    pub bits_per_table: usize,
    /// Whether to also probe Hamming-distance-1 neighbouring buckets per table
    /// (multi-probe LSH). Cheap, and a large recall win.
    ///
    /// [`SaeCandidateIndex::build`] does not read this flag. Pass it to
    /// [`SaeCandidateIndex::propose`] / [`SaeCandidateIndex::recall_report`].
    pub multiprobe: bool,
    /// Master seed for all hyperplane banks.
    pub seed: u64,
}

impl IndexConfig {
    /// A default configuration for a sketch of dimension `sketch_dim` and
    /// roughly `num_atoms` atoms.
    ///
    /// `bits_per_table = ⌈log₂ num_atoms⌉`, capped by the sketch dimension and
    /// by 63 (the `u64` signature limit). When the cap does not bind, this
    /// gives about one atom per bucket.
    ///
    /// `num_tables = ⌈log₂ num_atoms⌉` clamped to `4..=16`; a few tables
    /// recover the recall lost to any single table's quantization.
    ///
    /// Both grow only logarithmically in `num_atoms`, keeping queries
    /// sublinear. Multi-probe is enabled. `num_atoms ≤ 2` and `sketch_dim = 0`
    /// are treated as `2` and `1` respectively.
    pub fn auto(sketch_dim: usize, num_atoms: usize, seed: u64) -> Self {
        // ⌈log₂ n⌉ for n ≥ 2 (the argument is clamped to ≥ 2 below).
        let ceil_log2 = |n: usize| (n - 1).ilog2() as usize + 1;
        let log_k = ceil_log2(num_atoms.max(2));
        Self {
            num_tables: log_k.clamp(4, 16),
            bits_per_table: log_k.clamp(1, sketch_dim.clamp(1, 63)),
            multiprobe: true,
            seed,
        }
    }
}
362
363impl SaeCandidateIndex {
364    /// Build the index over every atom of `sketch`.
365    pub fn build<S: AtomFrameSketch>(sketch: &S, config: IndexConfig) -> Result<Self, String> {
366        let sketch_dim = sketch.sketch_dim();
367        if sketch_dim == 0 {
368            return Err("SaeCandidateIndex: sketch_dim must be positive".into());
369        }
370        if config.num_tables == 0 || config.bits_per_table == 0 {
371            return Err("SaeCandidateIndex: num_tables and bits_per_table must be positive".into());
372        }
373        // sign_signature packs bits into a u64 with `1u64 << r` for r in 0..bits_per_table.
374        // Shifting by 64+ is a panic in debug and undefined behaviour in release; cap at 63.
375        if config.bits_per_table > 63 {
376            return Err(format!(
377                "SaeCandidateIndex: bits_per_table {} exceeds 63 (u64 signature limit)",
378                config.bits_per_table
379            ));
380        }
381        let num_atoms = sketch.num_atoms();
382
383        // One seeded hyperplane bank per table; seed is mixed per-table so the
384        // tables are independent yet fully reproducible.
385        let hyperplanes: Vec<Array2<f64>> = (0..config.num_tables)
386            .map(|t| {
387                let table_seed = mix_seed(config.seed ^ INDEX_HYPERPLANE_SALT, t as u64);
388                gaussian_projection(config.bits_per_table, sketch_dim, table_seed)
389            })
390            .collect();
391
392        let mut tables: Vec<HashMap<u64, Vec<usize>>> =
393            (0..config.num_tables).map(|_| HashMap::new()).collect();
394
395        for atom_id in 0..num_atoms {
396            let s = sketch.atom_sketch(atom_id);
397            if s.len() != sketch_dim {
398                return Err(format!(
399                    "SaeCandidateIndex: atom {atom_id} sketch length {} != sketch_dim {sketch_dim}",
400                    s.len()
401                ));
402            }
403            for (table, bank) in tables.iter_mut().zip(hyperplanes.iter()) {
404                let sig = sign_signature(bank, s.view());
405                table.entry(sig).or_default().push(atom_id);
406            }
407        }
408
409        Ok(Self {
410            hyperplanes,
411            tables,
412            sketch_dim,
413            num_atoms,
414        })
415    }
416
417    /// Number of atoms in the index.
418    pub fn num_atoms(&self) -> usize {
419        self.num_atoms
420    }
421
422    /// Gather the raw candidate atom-id set for a query `direction`, *without*
423    /// ranking or budget truncation. This is the sublinear part: it sketches the
424    /// query once per table (using a frame-agnostic global query sketch — the
425    /// query direction projected by the index's own representative projection)
426    /// and unions the colliding buckets (plus Hamming-1 neighbours when
427    /// multi-probe is enabled).
428    ///
429    /// `query_sketch` is the sketch-space query vector (length `sketch_dim`),
430    /// produced by the caller from the row residual via the
431    /// [`AtomFrameSketch`]. We probe each table with this single vector.
432    pub fn gather_candidates(&self, query_sketch: ArrayView1<f64>, multiprobe: bool) -> Vec<usize> {
433        let mut seen: HashSet<usize> = HashSet::new();
434        for (table, bank) in self.tables.iter().zip(self.hyperplanes.iter()) {
435            let (sig, margins) = sign_signature_with_margins(bank, query_sketch);
436            if let Some(ids) = table.get(&sig) {
437                seen.extend(ids.iter().copied());
438            }
439            if multiprobe {
440                // Flip the lowest-margin bit (the one most likely to be on the
441                // wrong side of its hyperplane) to reach the nearest neighbour
442                // bucket — standard multi-probe LSH, biggest recall win.
443                let flip_bit = lowest_margin_bit(&margins);
444                let neighbour = sig ^ (1u64 << flip_bit);
445                if let Some(ids) = table.get(&neighbour) {
446                    seen.extend(ids.iter().copied());
447                }
448            }
449        }
450        let mut out: Vec<usize> = seen.into_iter().collect();
451        out.sort_unstable();
452        out
453    }
454
455    /// Propose the top `candidate_budget` atoms for a row whose residual is
456    /// `direction` (length `sketch.output_dim()`), ranked by exact frame
457    /// alignment.
458    ///
459    /// Pipeline: probe with [`AtomFrameSketch::query_sketch`] (`O(p·s)` for
460    /// shared-projection sketches, #994 — no atom is touched before the
461    /// gather), gather the sublinear candidate union, score each by
462    /// [`AtomFrameSketch::alignment`], and keep the highest-scoring
463    /// `candidate_budget`.
464    ///
465    /// Returns `(proposed_ids, dropped_for_budget)` where the second element
466    /// lists every gathered candidate that was truncated by the budget (never
467    /// silently discarded).
468    pub fn propose<S: AtomFrameSketch>(
469        &self,
470        sketch: &S,
471        direction: ArrayView1<f64>,
472        candidate_budget: usize,
473        config_multiprobe: bool,
474    ) -> Proposal {
475        let query_sketch = sketch.query_sketch(direction);
476        let gathered = if query_sketch.len() == self.sketch_dim {
477            self.gather_candidates(query_sketch.view(), config_multiprobe)
478        } else {
479            // A probe of the wrong dimension cannot be hashed against the
480            // tables; gather nothing rather than hash garbage. The recall
481            // report will then attribute every planted atom to `NotGathered`,
482            // which is the loud, attributable failure mode.
483            Vec::new()
484        };
485
486        // Exact-score every gathered candidate by frame alignment.
487        let mut scored: Vec<(usize, f64)> = gathered
488            .iter()
489            .map(|&id| (id, sketch.alignment(id, direction)))
490            .collect();
491        // Descending by alignment; ties broken by id for determinism.
492        scored.sort_by(|a, b| {
493            b.1.partial_cmp(&a.1)
494                .unwrap_or(std::cmp::Ordering::Equal)
495                .then(a.0.cmp(&b.0))
496        });
497
498        let keep = candidate_budget.min(scored.len());
499        let proposed: Vec<usize> = scored[..keep].iter().map(|&(id, _)| id).collect();
500        let dropped_for_budget: Vec<usize> = scored[keep..].iter().map(|&(id, _)| id).collect();
501
502        Proposal {
503            proposed,
504            dropped_for_budget,
505            gathered_count: gathered.len(),
506        }
507    }
508
509    /// Recall contract. For a set of rows, each with planted truly-active atom
510    /// ids and a residual direction, run [`SaeCandidateIndex::propose`] at the
511    /// given `candidate_budget` and record what fraction of planted atoms
512    /// appear in the proposed set. Every miss is logged — no silent truncation.
513    ///
514    /// `rows` is `(direction, planted_active_ids)` per row.
515    pub fn recall_report<S: AtomFrameSketch>(
516        &self,
517        sketch: &S,
518        rows: &[(Array1<f64>, Vec<usize>)],
519        candidate_budget: usize,
520        multiprobe: bool,
521    ) -> RecallReport {
522        let mut total_planted: usize = 0;
523        let mut total_recovered: usize = 0;
524        let mut misses: Vec<RecallMiss> = Vec::new();
525        let mut total_gathered: usize = 0;
526
527        for (row_idx, (direction, planted)) in rows.iter().enumerate() {
528            let proposal = self.propose(sketch, direction.view(), candidate_budget, multiprobe);
529            total_gathered += proposal.gathered_count;
530            let proposed_set: HashSet<usize> = proposal.proposed.iter().copied().collect();
531            // A candidate that was gathered but truncated by the budget counts
532            // as a miss *attributable to the budget*; one never gathered at all
533            // is a miss *attributable to the index*. We record both, flagged.
534            let dropped_set: HashSet<usize> = proposal.dropped_for_budget.iter().copied().collect();
535
536            for &atom in planted {
537                total_planted += 1;
538                if proposed_set.contains(&atom) {
539                    total_recovered += 1;
540                } else {
541                    let reason = if dropped_set.contains(&atom) {
542                        MissReason::TruncatedByBudget
543                    } else {
544                        MissReason::NotGathered
545                    };
546                    misses.push(RecallMiss {
547                        row: row_idx,
548                        atom,
549                        alignment: sketch.alignment(atom, direction.view()),
550                        reason,
551                    });
552                }
553            }
554        }
555
556        let recall = if total_planted == 0 {
557            1.0
558        } else {
559            total_recovered as f64 / total_planted as f64
560        };
561        let avg_gathered = if rows.is_empty() {
562            0.0
563        } else {
564            total_gathered as f64 / rows.len() as f64
565        };
566
567        RecallReport {
568            candidate_budget,
569            num_rows: rows.len(),
570            total_planted,
571            total_recovered,
572            recall,
573            avg_candidates_gathered: avg_gathered,
574            num_atoms: self.num_atoms,
575            misses,
576        }
577    }
578}
579
/// One row's proposal from [`SaeCandidateIndex::propose`]: the budgeted
/// candidate set plus everything the budget dropped.
///
/// `proposed` and `dropped_for_budget` partition the gathered candidates, so
/// `proposed.len() + dropped_for_budget.len() == gathered_count`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Proposal {
    /// The top `candidate_budget` atom ids, in descending frame-alignment order
    /// (ties broken by ascending id).
    pub proposed: Vec<usize>,
    /// Gathered candidates truncated by the budget, continuing the same order.
    /// Logged, never silently dropped.
    pub dropped_for_budget: Vec<usize>,
    /// How many candidates the sublinear gather returned before budgeting.
    pub gathered_count: usize,
}
590
/// Why a planted atom failed to appear in a row's proposed candidate set.
///
/// Hashable, so misses can be tallied per reason in a `HashMap`/`HashSet`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MissReason {
    /// The index never gathered this atom into the candidate union: an LSH
    /// recall miss (widen tables / probes).
    NotGathered,
    /// The atom *was* gathered, but the budget truncated it (widen the budget).
    TruncatedByBudget,
}
600
601/// One recorded recall miss.
602#[derive(Clone, Copy, Debug)]
603pub struct RecallMiss {
604    /// Row index in the report's input.
605    pub row: usize,
606    /// The planted atom id that was missed.
607    pub atom: usize,
608    /// The atom's exact frame alignment with the row direction (diagnostic).
609    pub alignment: f64,
610    /// Whether the miss was an index miss or a budget truncation.
611    pub reason: MissReason,
612}
613
614/// Result of [`SaeCandidateIndex::recall_report`].
615#[derive(Clone, Debug)]
616pub struct RecallReport {
617    /// Candidate budget the recall was measured at.
618    pub candidate_budget: usize,
619    /// Number of rows evaluated.
620    pub num_rows: usize,
621    /// Total planted truly-active atoms across all rows.
622    pub total_planted: usize,
623    /// How many of them appeared in the proposed sets.
624    pub total_recovered: usize,
625    /// `recall@candidate_budget` = recovered / planted (1.0 if nothing planted).
626    pub recall: f64,
627    /// Mean number of candidates the sublinear gather returned per row — the
628    /// sublinearity witness; compare against `num_atoms`.
629    pub avg_candidates_gathered: f64,
630    /// Total atoms in the index (for the sublinearity ratio).
631    pub num_atoms: usize,
632    /// Every miss, with its row, atom, alignment, and reason. No silent drops.
633    pub misses: Vec<RecallMiss>,
634}
635
636impl RecallReport {
637    /// Convenience: ratio of mean gathered candidates to dictionary size. A
638    /// value far below `1.0` is the evidence that proposal touched a sublinear
639    /// slice of the dictionary.
640    pub fn sublinearity_ratio(&self) -> f64 {
641        if self.num_atoms == 0 {
642            0.0
643        } else {
644            self.avg_candidates_gathered / self.num_atoms as f64
645        }
646    }
647}
648
649// ---------------------------------------------------------------------------
650// Helpers (deterministic, dependency-light)
651// ---------------------------------------------------------------------------
652
653/// Mix a base seed with an index into a well-spread `u64` (SplitMix64 finalizer
654/// on the sum). Deterministic, no clock.
655#[inline]
656fn mix_seed(base: u64, idx: u64) -> u64 {
657    // Finalize `base + idx·G` with the canonical SplitMix64 step. The stateful
658    // form adds G internally, so pre-subtract one G to land on the same input
659    // and keep the output bit-identical to the previous inlined finalizer.
660    let mut state = base
661        .wrapping_add(idx.wrapping_mul(0x9E37_79B9_7F4A_7C15))
662        .wrapping_sub(0x9E37_79B9_7F4A_7C15);
663    gam_linalg::utils::splitmix64(&mut state)
664}
665
666/// A seeded Gaussian random matrix of shape `(rows, cols)` (rows of hyperplanes
667/// / projection rows). Uses Box–Muller off a seeded `StdRng`.
668fn gaussian_projection(rows: usize, cols: usize, seed: u64) -> Array2<f64> {
669    use rand::RngExt as _;
670    let mut rng = StdRng::seed_from_u64(seed);
671    let mut m = Array2::<f64>::zeros((rows, cols));
672    for r in 0..rows {
673        for c in 0..cols {
674            let u1 = rng.random::<f64>().max(1e-16);
675            let u2 = rng.random::<f64>();
676            m[(r, c)] = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
677        }
678    }
679    m
680}
681
682/// Modified Gram–Schmidt orthonormalization of a decoder block's columns.
683/// Input `block` is `(p, m)`; output `U` is `(p, r)` with orthonormal columns
684/// spanning `range(block)`, `r ≤ m` (rank-deficient columns are dropped).
685fn orthonormal_frame(block: &Array2<f64>) -> Array2<f64> {
686    let p = block.nrows();
687    let m = block.ncols();
688    let mut cols: Vec<Array1<f64>> = Vec::with_capacity(m);
689    for j in 0..m {
690        let mut v = block.column(j).to_owned();
691        for q in &cols {
692            let proj: f64 = q.iter().zip(v.iter()).map(|(&a, &b)| a * b).sum();
693            for (vi, &qi) in v.iter_mut().zip(q.iter()) {
694                *vi -= proj * qi;
695            }
696        }
697        let nrm = vec_norm(v.view());
698        if nrm > DIRECTION_NORM_FLOOR {
699            for vi in v.iter_mut() {
700                *vi /= nrm;
701            }
702            cols.push(v);
703        }
704    }
705    let r = cols.len();
706    let mut u = Array2::<f64>::zeros((p, r));
707    for (j, col) in cols.into_iter().enumerate() {
708        u.column_mut(j).assign(&col);
709    }
710    u
711}
712
713/// `M · v` for `M` shape `(rows, cols)`, `v` length `cols`.
714fn mat_vec(m: &Array2<f64>, v: ArrayView1<f64>) -> Array1<f64> {
715    let mut out = Array1::<f64>::zeros(m.nrows());
716    for r in 0..m.nrows() {
717        let row = m.row(r);
718        out[r] = row.iter().zip(v.iter()).map(|(&a, &b)| a * b).sum();
719    }
720    out
721}
722
723#[inline]
724fn vec_norm(v: ArrayView1<f64>) -> f64 {
725    v.iter().map(|&x| x * x).sum::<f64>().sqrt()
726}
727
728#[inline]
729fn normalize_in_place(v: &mut Array1<f64>) {
730    let n = vec_norm(v.view());
731    if n > DIRECTION_NORM_FLOOR {
732        for x in v.iter_mut() {
733            *x /= n;
734        }
735    }
736}
737
738/// Pack the sign bits of `bank · s` into a `u64` signature. `bank` is
739/// `(bits, sketch_dim)`; `bits ≤ 64` (enforced by config-derived bit widths).
740fn sign_signature(bank: &Array2<f64>, s: ArrayView1<f64>) -> u64 {
741    let mut sig = 0u64;
742    for r in 0..bank.nrows() {
743        let row = bank.row(r);
744        let dot: f64 = row.iter().zip(s.iter()).map(|(&a, &b)| a * b).sum();
745        if dot >= 0.0 {
746            sig |= 1u64 << r;
747        }
748    }
749    sig
750}
751
752/// Signature plus per-bit signed margins (the dot products), used by multi-probe
753/// to find the least-confident bit to flip.
754fn sign_signature_with_margins(bank: &Array2<f64>, s: ArrayView1<f64>) -> (u64, Vec<f64>) {
755    let mut sig = 0u64;
756    let mut margins = Vec::with_capacity(bank.nrows());
757    for r in 0..bank.nrows() {
758        let row = bank.row(r);
759        let dot: f64 = row.iter().zip(s.iter()).map(|(&a, &b)| a * b).sum();
760        if dot >= 0.0 {
761            sig |= 1u64 << r;
762        }
763        margins.push(dot);
764    }
765    (sig, margins)
766}
767
/// Index of the hyperplane the query sits closest to (smallest `|margin|`).
/// That bit is the one most likely to have landed in the wrong bucket.
///
/// Ties resolve to the earliest index. NaN margins never win, because the
/// strict `<` comparison is false for NaN. An empty slice, or one whose
/// margins are all infinite or NaN, yields `0`.
fn lowest_margin_bit(margins: &[f64]) -> usize {
    let (closest, _) = margins.iter().enumerate().fold(
        (0usize, f64::INFINITY),
        |(best_idx, best_abs), (idx, &margin)| {
            let candidate = margin.abs();
            if candidate < best_abs {
                (idx, candidate)
            } else {
                (best_idx, best_abs)
            }
        },
    );
    closest
}
782
783// ---------------------------------------------------------------------------
784// Tests
785// ---------------------------------------------------------------------------
786
#[cfg(test)]
mod tests {
    use super::*;
    use rand::RngExt as _;
    use rand::rngs::StdRng;

    /// Draw a unit vector in `p` dims from a seeded RNG.
    ///
    /// Entries are Box–Muller (cosine-branch) normal draws. The vector is
    /// normalized with the module's own `normalize_in_place`, so the fixtures
    /// share the production norm-floor rule instead of re-deriving it.
    fn unit_vec(rng: &mut StdRng, p: usize) -> Array1<f64> {
        let mut v = Array1::<f64>::zeros(p);
        for x in v.iter_mut() {
            let u1 = rng.random::<f64>().max(1e-16);
            let u2 = rng.random::<f64>();
            *x = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
        }
        normalize_in_place(&mut v);
        v
    }

    /// Build a synthetic dictionary of `k` rank-1 atoms.
    ///
    /// Atom `i`'s decoder block is a single column `c_i`, a random unit
    /// direction in output space, so its frame is exactly `span(c_i)`.
    /// Returns the blocks and the column directions, so the planted-atom tests
    /// can construct directions that lie in chosen atoms.
    fn synthetic_dictionary(k: usize, p: usize, seed: u64) -> (Vec<Array2<f64>>, Vec<Array1<f64>>) {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut blocks = Vec::with_capacity(k);
        let mut dirs = Vec::with_capacity(k);
        for _ in 0..k {
            let c = unit_vec(&mut rng, p);
            let mut block = Array2::<f64>::zeros((p, 1));
            block.column_mut(0).assign(&c);
            blocks.push(block);
            dirs.push(c);
        }
        (blocks, dirs)
    }

    #[test]
    fn frame_alignment_is_exact_for_in_range_direction() {
        let (blocks, dirs) = synthetic_dictionary(8, 16, 11);
        let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 12, 7).unwrap();
        // A direction equal to atom 3's column lies fully in its range.
        let d = &dirs[3];
        let a = sketch.alignment(3, d.view());
        assert!(a > 0.999, "in-range alignment should be ~1, got {a}");
        // An orthogonal-ish direction (atom 5's column is generically nearly
        // orthogonal to atom 3) aligns weakly with atom 3.
        let a_off = sketch.alignment(3, dirs[5].view());
        assert!(
            a_off < a,
            "off-atom alignment {a_off} should be below in-range {a}"
        );
    }

    /// Regression for the routing-confidence gate (#1026): a low-alignment LSH
    /// route must never be silently certified. `certified_encode_with_index`
    /// (and the amortized twin) flag a routed row UNCERTIFIED whenever the
    /// best-aligned proposed atom's frame alignment is below
    /// `encode::CANDIDATE_ROUTING_MIN_ALIGNMENT`. The gate's decision input is
    /// exactly `sketch.alignment(best_atom, target)`. Pin it here with exact,
    /// deterministic linear algebra rather than LSH gather luck, so a future
    /// change to the frame-alignment formula cannot silently shift the
    /// threshold's meaning out from under the gate.
    ///
    /// Two atoms whose decoder frames span ORTHOGONAL subspaces of a 6-dim
    /// ambient (atom 0: `span(e0,e1)`, atom 1: `span(e2,e3)`; dims `e4,e5`
    /// covered by neither). A direction wholly inside atom 1's subspace has
    /// alignment exactly 1 with atom 1 (in-frame, ABOVE the gate) and exactly 0
    /// with atom 0 (off-frame, BELOW the gate) — the mis-route the gate exists
    /// to flag.
    #[test]
    fn routing_confidence_gate_input_separates_off_frame_from_in_frame() {
        use crate::encode::CANDIDATE_ROUTING_MIN_ALIGNMENT as GATE;

        let p = 6usize;
        let mut block_a = Array2::<f64>::zeros((p, 2));
        block_a[[0, 0]] = 1.0; // e0
        block_a[[1, 1]] = 1.0; // e1
        let mut block_b = Array2::<f64>::zeros((p, 2));
        block_b[[2, 0]] = 1.0; // e2
        block_b[[3, 1]] = 1.0; // e3
        let sketch =
            RandomProjectionFrameSketch::from_decoder_blocks(&[block_a, block_b], 16, 4242)
                .unwrap();

        // A unit direction wholly inside atom 1's (e2,e3) subspace.
        let mut in_frame_b = Array1::<f64>::zeros(p);
        in_frame_b[2] = 0.6;
        in_frame_b[3] = 0.8; // unit norm (0.6² + 0.8² = 1)
        let a_right = sketch.alignment(1, in_frame_b.view());
        let a_wrong = sketch.alignment(0, in_frame_b.view());

        assert!(
            a_right > 0.999,
            "an in-frame direction must align ~1 with its own atom; got {a_right}"
        );
        assert!(
            a_wrong < 1e-9,
            "an orthogonal-subspace direction must align ~0 with the wrong atom; got {a_wrong}"
        );
        // The exact predicate the encode gate evaluates per row: a mis-routed
        // (orthogonal) atom falls BELOW the gate → flagged for the exact
        // fallback; the correctly-routed atom sits AT/ABOVE the gate → trusted.
        assert!(
            a_wrong < GATE,
            "a mis-routed (orthogonal) atom must fall below the routing gate {GATE}; got {a_wrong}"
        );
        assert!(
            a_right >= GATE,
            "the correctly-routed atom must sit at/above the routing gate {GATE}; got {a_right}"
        );

        // A direction in the UNCOVERED (e4,e5) subspace aligns ~0 with BOTH
        // atoms, so whichever atom the LSH surfaces, the gate fires: no atom can
        // certify this route. This is the worst-case the gate exists to catch.
        let mut uncovered = Array1::<f64>::zeros(p);
        uncovered[4] = 1.0;
        for atom in 0..2 {
            let a = sketch.alignment(atom, uncovered.view());
            assert!(
                a < GATE,
                "an uncovered-subspace direction must fall below the gate for atom {atom}; got {a}"
            );
        }
    }

    #[test]
    fn build_is_deterministic_for_a_fixed_seed() {
        let (blocks, dirs) = synthetic_dictionary(64, 24, 99);
        let s1 = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 16, 5).unwrap();
        let s2 = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 16, 5).unwrap();
        // Same seed → identical representative sketches.
        for i in 0..blocks.len() {
            let a = s1.atom_sketch(i);
            let b = s2.atom_sketch(i);
            let diff = vec_norm((&a - &b).view());
            assert!(
                diff < 1e-12,
                "atom {i} sketch differs across builds: {diff:e}"
            );
        }
        let cfg = IndexConfig::auto(16, blocks.len(), 5);
        let idx1 = SaeCandidateIndex::build(&s1, cfg).unwrap();
        let idx2 = SaeCandidateIndex::build(&s2, cfg).unwrap();
        // Same number of tables, each of the same size.
        assert_eq!(
            idx1.tables.len(),
            idx2.tables.len(),
            "table count differs across builds"
        );
        for t in 0..idx1.tables.len() {
            assert_eq!(
                idx1.tables[t].len(),
                idx2.tables[t].len(),
                "table {t} size differs across builds"
            );
        }
        // Sizes alone cannot see atoms that moved between buckets. Identical
        // hyperplane banks and bucket contents must also show up as identical
        // proposals for the same query directions.
        let budget = auto_candidate_budget(blocks.len());
        for (i, d) in dirs.iter().enumerate().take(16) {
            let a = idx1.propose(&s1, d.view(), budget, cfg.multiprobe);
            let b = idx2.propose(&s2, d.view(), budget, cfg.multiprobe);
            assert_eq!(
                a.proposed, b.proposed,
                "direction {i}: rebuild must propose identically"
            );
        }
    }

    #[test]
    fn planted_atoms_are_recalled_above_floor_at_sublinear_budget() {
        // A frontier-ish dictionary: many atoms, modest output dim.
        let k = 2000usize;
        let p = 48usize;
        let (blocks, dirs) = synthetic_dictionary(k, p, 2026);
        let sketch_dim = 24usize;
        let sketch =
            RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, 4242).unwrap();
        let cfg = IndexConfig::auto(sketch_dim, k, 4242);
        let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();

        // Plant: each row's residual is dominated by one chosen atom's column
        // (1.0 · c_primary + 0.15 · c_secondary, renormalized). The
        // planted-active set is that dominant atom. `planted_rows` is the
        // shared deterministic generator used by the ladder and cluster tests.
        let rows = planted_rows(&dirs, 200, 31337);

        // Sublinear candidate budget: << K. We allow the gather to surface a
        // handful, but the *budget* (the per-row local block size) stays small.
        let candidate_budget = 32usize;
        let report = index.recall_report(&sketch, &rows, candidate_budget, cfg.multiprobe);

        // The gather must touch only a sublinear slice of the dictionary.
        assert!(
            report.sublinearity_ratio() < 0.5,
            "gather was not sublinear: avg {} of {} atoms (ratio {:.3})",
            report.avg_candidates_gathered,
            report.num_atoms,
            report.sublinearity_ratio()
        );

        // Recall floor: the LSH index must recover the planted dominant atom for
        // the large majority of rows at this sublinear budget. Misses are
        // logged, never silently dropped.
        let floor = 0.80;
        assert!(
            report.recall >= floor,
            "recall {:.3} below floor {floor}; {} misses logged (first few: {:?})",
            report.recall,
            report.misses.len(),
            report
                .misses
                .iter()
                .take(5)
                .map(|m| (m.row, m.atom, m.reason, m.alignment))
                .collect::<Vec<_>>()
        );

        // Every miss is accounted for with a reason — the no-silent-truncation
        // contract.
        let recovered = report.total_recovered;
        assert_eq!(
            report.total_planted - recovered,
            report.misses.len(),
            "miss list must account for every unrecovered planted atom"
        );
    }

    #[test]
    fn auto_candidate_budget_tracks_the_issue_band() {
        assert_eq!(auto_candidate_budget(2), CANDIDATE_BUDGET_MIN);
        assert_eq!(auto_candidate_budget(64), 48);
        assert_eq!(auto_candidate_budget(1024), 80);
        assert_eq!(auto_candidate_budget(100_000), CANDIDATE_BUDGET_MAX);
        // Monotone non-decreasing in K and always inside the band.
        let mut prev = 0usize;
        for k in [2usize, 16, 64, 256, 1024, 4096, 65_536, 1_000_000] {
            let c = auto_candidate_budget(k);
            assert!(c >= prev, "budget must be monotone in K");
            assert!((CANDIDATE_BUDGET_MIN..=CANDIDATE_BUDGET_MAX).contains(&c));
            prev = c;
        }
    }

    /// Build a planted row set for a dictionary.
    ///
    /// Each row's residual direction is `c_primary + 0.15 · c_secondary`,
    /// renormalized, with both atoms drawn uniformly from `dirs` by a
    /// `seed`-ed RNG. The planted-active set is the dominant (primary) atom.
    fn planted_rows(
        dirs: &[Array1<f64>],
        n_rows: usize,
        seed: u64,
    ) -> Vec<(Array1<f64>, Vec<usize>)> {
        let k = dirs.len();
        let mut rng = StdRng::seed_from_u64(seed);
        let mut rows = Vec::with_capacity(n_rows);
        for _ in 0..n_rows {
            let primary = rng.random_range(0..k);
            let secondary = rng.random_range(0..k);
            let mut d = dirs[primary].clone();
            for (di, &si) in d.iter_mut().zip(dirs[secondary].iter()) {
                *di += 0.15 * si;
            }
            let n = vec_norm(d.view());
            for di in d.iter_mut() {
                *di /= n;
            }
            rows.push((d, vec![primary]));
        }
        rows
    }

    #[test]
    fn k_ladder_recall_determinism_and_sublinearity() {
        // #985 part 2 (index tier): the K=2-era assumptions say nothing about
        // frontier K, so gate the proposal machinery on a planted ladder at
        // K = 64 and K = 1024 with the SAME battery per rung — recall above a
        // stated floor at the auto-derived budget, every miss accounted for,
        // and byte-identical proposals across two independent builds. The
        // gather must also become *relatively* sparser as K grows (the
        // sublinearity witness): what is allowed to touch half the dictionary
        // at K = 64 must not at K = 1024.
        let p = 48usize;
        let n_rows = 150usize;
        let mut ladder_ratios = Vec::new();
        for &k in &[64usize, 1024] {
            let (blocks, dirs) = synthetic_dictionary(k, p, 9000 + k as u64);
            let sketch_dim = 24usize;
            let sketch_seed = 71 + k as u64;
            let sketch =
                RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, sketch_seed)
                    .unwrap();
            let cfg = IndexConfig::auto(sketch_dim, k, sketch_seed);
            let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();

            let rows = planted_rows(&dirs, n_rows, 555 + k as u64);
            let budget = auto_candidate_budget(k);
            let report = index.recall_report(&sketch, &rows, budget, cfg.multiprobe);

            // Recall floor at the auto-derived budget, with every miss carrying
            // a reason — the no-silent-truncation contract, per rung.
            let floor = 0.80;
            assert!(
                report.recall >= floor,
                "K={k}: recall {:.3} below floor {floor}; {} misses (first: {:?})",
                report.recall,
                report.misses.len(),
                report
                    .misses
                    .iter()
                    .take(3)
                    .map(|m| (m.row, m.atom, m.reason, m.alignment))
                    .collect::<Vec<_>>()
            );
            assert_eq!(
                report.total_planted - report.total_recovered,
                report.misses.len(),
                "K={k}: miss list must account for every unrecovered planted atom"
            );

            // Search determinism: an independent rebuild from the same inputs
            // proposes the identical candidate set for every probed row. The
            // proposal size is also the budget, never the dictionary, so the
            // per-row local block stays near the planted/active scale.
            let sketch2 =
                RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, sketch_seed)
                    .unwrap();
            let index2 = SaeCandidateIndex::build(&sketch2, cfg).unwrap();
            for (direction, _) in rows.iter().take(20) {
                let a = index.propose(&sketch, direction.view(), budget, cfg.multiprobe);
                let b = index2.propose(&sketch2, direction.view(), budget, cfg.multiprobe);
                assert_eq!(
                    a.proposed, b.proposed,
                    "K={k}: rebuild must propose identically"
                );
                assert!(
                    a.proposed.len() <= budget,
                    "K={k}: proposal exceeded the budget {budget}"
                );
            }

            ladder_ratios.push((k, report.sublinearity_ratio()));
        }
        // Relative sparsity must improve up the ladder: the gathered fraction
        // of the dictionary shrinks as K grows (sublinear gather), and at the
        // frontier-shaped rung it must be a small slice outright.
        let (_, ratio_small) = ladder_ratios[0];
        let (k_big, ratio_big) = ladder_ratios[1];
        assert!(
            ratio_big < ratio_small,
            "sublinearity must improve along the ladder: {ladder_ratios:?}"
        );
        assert!(
            ratio_big < 0.25,
            "K={k_big}: gather touched {:.1}% of the dictionary",
            ratio_big * 100.0
        );
    }

    /// Counting wrapper: delegates everything and counts `project_direction`
    /// calls. This is the #994 acceptance gate: with the exact probe, building
    /// the query sketch touches NO atom, so a whole `propose` makes zero
    /// `project_direction` calls (scoring goes through `alignment`).
    struct CountingSketch<'a> {
        inner: &'a RandomProjectionFrameSketch,
        project_calls: std::cell::Cell<usize>,
    }

    impl AtomFrameSketch for CountingSketch<'_> {
        fn sketch_dim(&self) -> usize {
            self.inner.sketch_dim()
        }
        fn output_dim(&self) -> usize {
            self.inner.output_dim()
        }
        fn num_atoms(&self) -> usize {
            self.inner.num_atoms()
        }
        fn atom_sketch(&self, atom_id: usize) -> Array1<f64> {
            self.inner.atom_sketch(atom_id)
        }
        fn project_direction(&self, atom_id: usize, direction: ArrayView1<f64>) -> Array1<f64> {
            self.project_calls.set(self.project_calls.get() + 1);
            self.inner.project_direction(atom_id, direction)
        }
        fn alignment(&self, atom_id: usize, direction: ArrayView1<f64>) -> f64 {
            self.inner.alignment(atom_id, direction)
        }
        fn query_sketch(&self, direction: ArrayView1<f64>) -> Array1<f64> {
            self.inner.query_sketch(direction)
        }
    }

    #[test]
    fn query_probe_touches_no_atom_before_the_gather() {
        let k = 512usize;
        let p = 32usize;
        let (blocks, dirs) = synthetic_dictionary(k, p, 77);
        let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 16, 13).unwrap();
        let cfg = IndexConfig::auto(16, k, 13);
        let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();
        let counting = CountingSketch {
            inner: &sketch,
            project_calls: std::cell::Cell::new(0),
        };
        drop(index.propose(&counting, dirs[5].view(), 32, cfg.multiprobe));
        assert_eq!(
            counting.project_calls.get(),
            0,
            "the exact query probe must be independent of K: no per-atom \
             projection before the gather (#994)"
        );
    }

    /// Build a coherent-cluster dictionary: `n_clusters` random unit centers,
    /// each with `cluster_size` atoms drawn as small perturbations of the
    /// center (renormalized). This is exactly the non-isotropic regime where
    /// the old masked-average probe degraded (#994).
    fn coherent_cluster_dictionary(
        n_clusters: usize,
        cluster_size: usize,
        p: usize,
        spread: f64,
        seed: u64,
    ) -> (Vec<Array2<f64>>, Vec<Array1<f64>>) {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut blocks = Vec::with_capacity(n_clusters * cluster_size);
        let mut dirs = Vec::with_capacity(n_clusters * cluster_size);
        for _ in 0..n_clusters {
            let center = unit_vec(&mut rng, p);
            for _ in 0..cluster_size {
                let noise = unit_vec(&mut rng, p);
                let mut c = center.clone();
                for (ci, &ni) in c.iter_mut().zip(noise.iter()) {
                    *ci += spread * ni;
                }
                let n = vec_norm(c.view());
                for ci in c.iter_mut() {
                    *ci /= n;
                }
                let mut block = Array2::<f64>::zeros((p, 1));
                block.column_mut(0).assign(&c);
                blocks.push(block);
                dirs.push(c);
            }
        }
        (blocks, dirs)
    }

    #[test]
    fn coherent_clusters_are_recalled_with_the_exact_probe() {
        // 32 clusters × 32 near-parallel atoms = 1024 atoms. Rows are
        // dominated by one specific cluster member; the proposal must recover
        // that exact member (not merely its cluster) at the auto budget —
        // exact alignment scoring separates siblings once the probe lands the
        // gather in the right bucket neighborhood.
        let n_clusters = 32usize;
        let cluster_size = 32usize;
        let k = n_clusters * cluster_size;
        let p = 48usize;
        let (blocks, dirs) = coherent_cluster_dictionary(n_clusters, cluster_size, p, 0.25, 4242);
        let sketch_dim = 24usize;
        let sketch =
            RandomProjectionFrameSketch::from_decoder_blocks(&blocks, sketch_dim, 99).unwrap();
        let cfg = IndexConfig::auto(sketch_dim, k, 99);
        let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();

        let rows = planted_rows(&dirs, 150, 31337);
        let budget = auto_candidate_budget(k);
        let report = index.recall_report(&sketch, &rows, budget, cfg.multiprobe);
        let floor = 0.80;
        assert!(
            report.recall >= floor,
            "coherent-cluster recall {:.3} below floor {floor}; {} misses (first: {:?})",
            report.recall,
            report.misses.len(),
            report
                .misses
                .iter()
                .take(3)
                .map(|m| (m.row, m.atom, m.reason, m.alignment))
                .collect::<Vec<_>>()
        );
        // Still a sublinear slice of the dictionary, clusters or not.
        assert!(
            report.sublinearity_ratio() < 0.5,
            "cluster gather touched {:.1}% of the dictionary",
            report.sublinearity_ratio() * 100.0
        );
    }

    #[test]
    fn exact_probe_matches_shared_projection_of_the_direction() {
        // The override is literally normalize(R·d): verify against a manual
        // computation through the public surface (atom_sketch of a rank-1 atom
        // whose only column IS the direction gives normalize(R·d) too).
        let p = 16usize;
        let mut rng = StdRng::seed_from_u64(5);
        let d = unit_vec(&mut rng, p);
        let mut block = Array2::<f64>::zeros((p, 1));
        block.column_mut(0).assign(&d);
        let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&[block], 8, 21).unwrap();
        let via_probe = sketch.query_sketch(d.view());
        let via_atom = sketch.atom_sketch(0);
        let diff = vec_norm((&via_probe - &via_atom).view());
        assert!(
            diff < 1e-10,
            "query_sketch(d) must equal the rank-1 atom representative of d: diff {diff:e}"
        );
    }

    #[test]
    fn empty_planted_rows_report_perfect_recall() {
        let (blocks, dirs) = synthetic_dictionary(32, 16, 1);
        let sketch = RandomProjectionFrameSketch::from_decoder_blocks(&blocks, 12, 3).unwrap();
        let cfg = IndexConfig::auto(12, 32, 3);
        let index = SaeCandidateIndex::build(&sketch, cfg).unwrap();
        let rows = vec![(dirs[0].clone(), Vec::<usize>::new())];
        let report = index.recall_report(&sketch, &rows, 8, true);
        assert_eq!(report.recall, 1.0);
        assert!(report.misses.is_empty());
    }
}