Skip to main content

gam_sae/sparse_dict/
scoring.rs

1//! Tiled scoring and online top-`s` selection.
2//!
3//! The router must rank every row against all `K` atoms without ever holding an
4//! `N×K` score matrix. We do this by GEMM-ing the minibatch against the
5//! dictionary one **column tile** at a time (`atoms_tile` of shape `tile × P`),
6//! producing a `rows × tile` score block, folding that block into a per-row
7//! online top-`s` selector, and discarding it. Peak score memory is therefore
8//! `rows × tile`, independent of `K`.
9
10use ndarray::{ArrayView1, ArrayView2, Axis};
11
/// Online "keep the `s` largest-magnitude scores seen so far" selector for a
/// single row.
///
/// Selection is by `|score|` (the dictionary atoms are unit-norm, so `|xᵀd|`
/// is the magnitude of the optimal 1-atom projection); ties break by smaller
/// atom index for determinism.
///
/// A NaN score has no usable magnitude, so it ranks below every non-NaN score:
/// it is the first thing evicted once the selector is full and it sorts last
/// in [`TopSSelector::finish`]. It is only returned when fewer than `capacity`
/// non-NaN candidates were offered.
#[derive(Clone, Debug)]
pub struct TopSSelector {
    /// `(atom_index, score, rank_key)`, length ≤ `capacity`, kept unsorted.
    /// `rank_key` is `|score|`, or `-∞` when `score` is NaN, so it is never
    /// NaN and always compares consistently.
    heap: Vec<(u32, f32, f32)>,
    /// Maximum number of candidates retained; always at least 1.
    capacity: usize,
}

impl TopSSelector {
    /// Create an empty selector that retains at most `capacity` candidates.
    ///
    /// A `capacity` of zero is clamped to one, so a selector always keeps at
    /// least the single best candidate.
    pub fn new(capacity: usize) -> Self {
        let capacity = capacity.max(1);
        Self {
            heap: Vec::with_capacity(capacity),
            capacity,
        }
    }

    /// Ranking key for a score: its magnitude, with NaN mapped to `-∞` so that
    /// a NaN can neither block eviction nor break the sort order.
    #[inline]
    fn rank_key(score: f32) -> f32 {
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score.abs()
        }
    }

    /// `true` when candidate `a` ranks strictly ahead of candidate `b`, each
    /// given as `(atom_index, rank_key)`: larger key first, then smaller atom
    /// index.
    #[inline]
    fn outranks(a: (u32, f32), b: (u32, f32)) -> bool {
        a.1 > b.1 || (a.1 == b.1 && a.0 < b.0)
    }

    /// Offer one `(atom, score)` candidate.
    ///
    /// While the selector holds fewer than `capacity` entries the candidate is
    /// always kept. Once full, the candidate replaces the current weakest
    /// entry only if it ranks strictly ahead of it.
    #[inline]
    pub fn offer(&mut self, atom: u32, score: f32) {
        let key = Self::rank_key(score);
        if self.heap.len() < self.capacity {
            self.heap.push((atom, score, key));
            return;
        }
        // Find the current weakest slot: the entry that every other entry
        // outranks. `capacity >= 1`, so slot 0 exists whenever we get here.
        let mut worst = 0usize;
        for k in 1..self.heap.len() {
            let (k_atom, _, k_key) = self.heap[k];
            let (w_atom, _, w_key) = self.heap[worst];
            if Self::outranks((w_atom, w_key), (k_atom, k_key)) {
                worst = k;
            }
        }
        let (w_atom, _, w_key) = self.heap[worst];
        if Self::outranks((atom, key), (w_atom, w_key)) {
            self.heap[worst] = (atom, score, key);
        }
    }

    /// Finalise, returning `(atom, score)` pairs sorted by descending `|score|`
    /// (ties by ascending atom index). Entries whose score is NaN come last.
    pub fn finish(mut self) -> Vec<(u32, f32)> {
        // Rank keys are never NaN, so `total_cmp` yields a genuine total order
        // (a comparator that is not a total order leaves the result
        // unspecified and may make the sort panic).
        self.heap
            .sort_by(|a, b| b.2.total_cmp(&a.2).then_with(|| a.0.cmp(&b.0)));
        self.heap
            .into_iter()
            .map(|(atom, score, _)| (atom, score))
            .collect()
    }
}
65
66/// Score a row against a tile of atoms (`atoms_tile`, `tile × P`, rows are
67/// atoms) and fold every score into `sel`. `atom_offset` is the global index of
68/// the tile's first atom.
69#[inline]
70pub fn score_row_tile(
71    row: ArrayView1<'_, f32>,
72    atoms_tile: ArrayView2<'_, f32>,
73    atom_offset: usize,
74    sel: &mut TopSSelector,
75) {
76    let p = row.len();
77    for (local, atom) in atoms_tile.outer_iter().enumerate() {
78        let mut acc = 0.0f32;
79        for c in 0..p {
80            acc += row[c] * atom[c];
81        }
82        sel.offer((atom_offset + local) as u32, acc);
83    }
84}
85
86/// Convenience: full top-`s` selection of one row against the entire decoder,
87/// tiled internally. Returns `(atom, score)` pairs, ≤ `s` of them, sorted by
88/// descending `|score|`. Used by tests and by the router's per-row path.
89pub fn top_s_online(
90    row: ArrayView1<'_, f32>,
91    decoder: ArrayView2<'_, f32>,
92    s: usize,
93    tile: usize,
94) -> Vec<(u32, f32)> {
95    let k = decoder.nrows();
96    let tile = tile.max(1);
97    let mut sel = TopSSelector::new(s);
98    let mut start = 0usize;
99    while start < k {
100        let end = (start + tile).min(k);
101        let block = decoder.slice(ndarray::s![start..end, ..]);
102        score_row_tile(row, block, start, &mut sel);
103        start = end;
104    }
105    sel.finish()
106}
107
108/// A reusable tiled scorer over a fixed decoder. Holds the tile width so the
109/// router can score a whole minibatch with one object, never materialising an
110/// `N×K` block.
111#[derive(Clone, Copy, Debug)]
112pub struct TileScorer {
113    pub tile: usize,
114    pub active: usize,
115}
116
117impl TileScorer {
118    pub fn new(active: usize, tile: usize) -> Self {
119        Self {
120            tile: tile.max(1),
121            active: active.max(1),
122        }
123    }
124
125    /// Top-`active` atoms for `row` against `decoder`.
126    pub fn route_row(
127        &self,
128        row: ArrayView1<'_, f32>,
129        decoder: ArrayView2<'_, f32>,
130    ) -> Vec<(u32, f32)> {
131        top_s_online(row, decoder, self.active, self.tile)
132    }
133
134    /// Top-`active` atoms for every row of a minibatch `rows` (`B × P`) against
135    /// `decoder` (`K × P`), scored a column tile at a time via a batched GEMM.
136    ///
137    /// This is the implementation that delivers the module's promise: the score
138    /// block formed at any instant is `B × tile` (peak `rows × tile`,
139    /// independent of `K`), and each tile is a single `(B × P)·(P × tile)`
140    /// matrix multiply rather than `B × tile` scalar dot loops. The online
141    /// top-`s` selector sees the atoms in the same global order as
142    /// [`Self::route_row`] (tile 0 first, ascending atom index). The GEMM
143    /// contracts the same `P` terms but `matrixmultiply` may accumulate them in
144    /// a blocked order, so the per-atom scores agree with the row-at-a-time path
145    /// only to f32 rounding; where two atoms tie within that rounding the two
146    /// paths may select different members of the tie (interchangeable for the
147    /// reconstruction, which is why the fit stays minibatch-invariant rather
148    /// than bit-identical). Returns one `(atom, score)` shortlist per row, in row
149    /// order.
150    pub fn route_minibatch(
151        &self,
152        rows: ArrayView2<'_, f32>,
153        decoder: ArrayView2<'_, f32>,
154    ) -> Vec<Vec<(u32, f32)>> {
155        let b = rows.nrows();
156        let k = decoder.nrows();
157        let mut selectors: Vec<TopSSelector> =
158            (0..b).map(|_| TopSSelector::new(self.active)).collect();
159
160        let mut start = 0usize;
161        while start < k {
162            let end = (start + self.tile).min(k);
163            // `decoder` tile is `tile × P`; transpose to `P × tile` so the GEMM
164            // produces the `B × tile` score block directly (rows × atoms).
165            let tile_block = decoder.slice(ndarray::s![start..end, ..]);
166            let scores = rows.dot(&tile_block.t()); // (B × P)·(P × tile) = B × tile
167            for (local, score_col) in scores.axis_iter(Axis(1)).enumerate() {
168                let atom = (start + local) as u32;
169                for (row_idx, &sc) in score_col.iter().enumerate() {
170                    selectors[row_idx].offer(atom, sc);
171                }
172            }
173            start = end;
174        }
175        selectors.into_iter().map(TopSSelector::finish).collect()
176    }
177}