gam_sae/sparse_dict/scoring.rs
1//! Tiled scoring and online top-`s` selection.
2//!
3//! The router must rank every row against all `K` atoms without ever holding an
4//! `N×K` score matrix. We do this by GEMM-ing the minibatch against the
5//! dictionary one **column tile** at a time (`atoms_tile` of shape `tile × P`),
6//! producing a `rows × tile` score block, folding that block into a per-row
7//! online top-`s` selector, and discarding it. Peak score memory is therefore
8//! `rows × tile`, independent of `K`.
9
10use ndarray::{ArrayView1, ArrayView2, Axis};
11
/// Online "keep the `s` largest-magnitude scores seen so far" selector for a
/// single row.
///
/// Selection is by `|score|` (the dictionary atoms are unit-norm, so `|xᵀd|`
/// is the magnitude of the optimal 1-atom projection); ties break by smaller
/// atom index for determinism. A NaN score ranks below every real score: it is
/// retained only while nothing better has been offered, so a single bad dot
/// product can neither pin a slot nor corrupt the final ordering.
#[derive(Clone, Debug)]
pub struct TopSSelector {
    /// `(atom_index, score, rank)`, length ≤ `capacity`, kept unsorted.
    /// `rank` is `|score|`, or `-∞` when `score` is NaN, so it is never NaN.
    heap: Vec<(u32, f32, f32)>,
    /// Maximum number of candidates retained; always ≥ 1.
    capacity: usize,
}

impl TopSSelector {
    /// Create an empty selector that retains at most `capacity` candidates.
    ///
    /// A `capacity` of 0 is clamped to 1.
    pub fn new(capacity: usize) -> Self {
        let capacity = capacity.max(1);
        Self {
            heap: Vec::with_capacity(capacity),
            capacity,
        }
    }

    /// Selection key for a score: its magnitude, with NaN mapped to `-∞` so
    /// that every key is comparable and NaN always loses.
    #[inline]
    fn rank(score: f32) -> f32 {
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score.abs()
        }
    }

    /// `true` when candidate `a` is strictly preferred over candidate `b`:
    /// higher rank first, smaller atom index on equal rank.
    #[inline]
    fn outranks(a_atom: u32, a_rank: f32, b_atom: u32, b_rank: f32) -> bool {
        a_rank > b_rank || (a_rank == b_rank && a_atom < b_atom)
    }

    /// Offer one `(atom, score)` candidate.
    ///
    /// While fewer than `capacity` candidates are held the candidate is always
    /// stored. Afterwards it replaces the current weakest entry only if it is
    /// strictly preferred over it.
    #[inline]
    pub fn offer(&mut self, atom: u32, score: f32) {
        let rank = Self::rank(score);
        if self.heap.len() < self.capacity {
            self.heap.push((atom, score, rank));
            return;
        }
        // Find the current weakest slot: lowest rank, larger atom index on tie.
        let mut worst = 0usize;
        for (k, &(k_atom, _, k_rank)) in self.heap.iter().enumerate().skip(1) {
            let (w_atom, _, w_rank) = self.heap[worst];
            if Self::outranks(w_atom, w_rank, k_atom, k_rank) {
                worst = k;
            }
        }
        let (w_atom, _, w_rank) = self.heap[worst];
        if Self::outranks(atom, rank, w_atom, w_rank) {
            self.heap[worst] = (atom, score, rank);
        }
    }

    /// Finalise, returning `(atom, score)` pairs sorted by descending `|score|`
    /// (ties by ascending atom index). Entries whose score is NaN sort last.
    pub fn finish(mut self) -> Vec<(u32, f32)> {
        // Ranks are never NaN, so `total_cmp` is an ordinary numeric comparison
        // here and the comparator is a genuine total order.
        self.heap
            .sort_by(|a, b| b.2.total_cmp(&a.2).then_with(|| a.0.cmp(&b.0)));
        self.heap.into_iter().map(|(a, s, _)| (a, s)).collect()
    }
}
65
66/// Score a row against a tile of atoms (`atoms_tile`, `tile × P`, rows are
67/// atoms) and fold every score into `sel`. `atom_offset` is the global index of
68/// the tile's first atom.
69#[inline]
70pub fn score_row_tile(
71 row: ArrayView1<'_, f32>,
72 atoms_tile: ArrayView2<'_, f32>,
73 atom_offset: usize,
74 sel: &mut TopSSelector,
75) {
76 let p = row.len();
77 for (local, atom) in atoms_tile.outer_iter().enumerate() {
78 let mut acc = 0.0f32;
79 for c in 0..p {
80 acc += row[c] * atom[c];
81 }
82 sel.offer((atom_offset + local) as u32, acc);
83 }
84}
85
86/// Convenience: full top-`s` selection of one row against the entire decoder,
87/// tiled internally. Returns `(atom, score)` pairs, ≤ `s` of them, sorted by
88/// descending `|score|`. Used by tests and by the router's per-row path.
89pub fn top_s_online(
90 row: ArrayView1<'_, f32>,
91 decoder: ArrayView2<'_, f32>,
92 s: usize,
93 tile: usize,
94) -> Vec<(u32, f32)> {
95 let k = decoder.nrows();
96 let tile = tile.max(1);
97 let mut sel = TopSSelector::new(s);
98 let mut start = 0usize;
99 while start < k {
100 let end = (start + tile).min(k);
101 let block = decoder.slice(ndarray::s![start..end, ..]);
102 score_row_tile(row, block, start, &mut sel);
103 start = end;
104 }
105 sel.finish()
106}
107
108/// A reusable tiled scorer over a fixed decoder. Holds the tile width so the
109/// router can score a whole minibatch with one object, never materialising an
110/// `N×K` block.
111#[derive(Clone, Copy, Debug)]
112pub struct TileScorer {
113 pub tile: usize,
114 pub active: usize,
115}
116
117impl TileScorer {
118 pub fn new(active: usize, tile: usize) -> Self {
119 Self {
120 tile: tile.max(1),
121 active: active.max(1),
122 }
123 }
124
125 /// Top-`active` atoms for `row` against `decoder`.
126 pub fn route_row(
127 &self,
128 row: ArrayView1<'_, f32>,
129 decoder: ArrayView2<'_, f32>,
130 ) -> Vec<(u32, f32)> {
131 top_s_online(row, decoder, self.active, self.tile)
132 }
133
134 /// Top-`active` atoms for every row of a minibatch `rows` (`B × P`) against
135 /// `decoder` (`K × P`), scored a column tile at a time via a batched GEMM.
136 ///
137 /// This is the implementation that delivers the module's promise: the score
138 /// block formed at any instant is `B × tile` (peak `rows × tile`,
139 /// independent of `K`), and each tile is a single `(B × P)·(P × tile)`
140 /// matrix multiply rather than `B × tile` scalar dot loops. The online
141 /// top-`s` selector sees the atoms in the same global order as
142 /// [`Self::route_row`] (tile 0 first, ascending atom index). The GEMM
143 /// contracts the same `P` terms but `matrixmultiply` may accumulate them in
144 /// a blocked order, so the per-atom scores agree with the row-at-a-time path
145 /// only to f32 rounding; where two atoms tie within that rounding the two
146 /// paths may select different members of the tie (interchangeable for the
147 /// reconstruction, which is why the fit stays minibatch-invariant rather
148 /// than bit-identical). Returns one `(atom, score)` shortlist per row, in row
149 /// order.
150 pub fn route_minibatch(
151 &self,
152 rows: ArrayView2<'_, f32>,
153 decoder: ArrayView2<'_, f32>,
154 ) -> Vec<Vec<(u32, f32)>> {
155 let b = rows.nrows();
156 let k = decoder.nrows();
157 let mut selectors: Vec<TopSSelector> =
158 (0..b).map(|_| TopSSelector::new(self.active)).collect();
159
160 let mut start = 0usize;
161 while start < k {
162 let end = (start + self.tile).min(k);
163 // `decoder` tile is `tile × P`; transpose to `P × tile` so the GEMM
164 // produces the `B × tile` score block directly (rows × atoms).
165 let tile_block = decoder.slice(ndarray::s![start..end, ..]);
166 let scores = rows.dot(&tile_block.t()); // (B × P)·(P × tile) = B × tile
167 for (local, score_col) in scores.axis_iter(Axis(1)).enumerate() {
168 let atom = (start + local) as u32;
169 for (row_idx, &sc) in score_col.iter().enumerate() {
170 selectors[row_idx].offer(atom, sc);
171 }
172 }
173 start = end;
174 }
175 selectors.into_iter().map(TopSSelector::finish).collect()
176 }
177}