// gam_gpu/dictionary_score.rs

//! Shape and memory planning for high-`K` dictionary score routing.
//!
//! Sparse SAE dictionary routers all have the same hot loop: score a minibatch
//! of `n_rows` residual rows against `n_items` candidate atoms/blocks, keep a
//! tiny online top-`s`, and never materialize the full `n_rows x n_items` score
//! matrix. This module owns the reusable admission and tile-size invariants for
//! that pattern. Domain crates still own their kernels and selection semantics.

/// Minimum `n_rows * n_items` score elements before a cold device route is worth
/// its launch and host/device transfer cost.
///
/// The value is `1 << 20` (1,048,576) score elements.
/// [`DictionaryScoreRoutePlan::default_for_shape`] passes it as the
/// `device_min_score_elems` admission floor.
pub const DEFAULT_DICTIONARY_SCORE_MIN_ELEMS: usize = 1 << 20;

13/// Maximum score elements per device launch. With `f32` scores this is 8 MiB,
14/// matching the library row-chunk target and keeping peak score memory bounded
15/// independent of dictionary width.
16pub const DEFAULT_DICTIONARY_SCORE_TILE_ELEMS: usize =
17    gam_runtime::resource::ResourcePolicy::default_library().row_chunk_target_bytes
18        / std::mem::size_of::<f32>();
19
/// Device admission and tile geometry for one minibatch-by-dictionary score
/// route.
///
/// The first five fields echo the inputs given to
/// [`DictionaryScoreRoutePlan::with_limits`]; the remaining fields are derived
/// from them there. All fields are public plain data, so a value assembled by
/// hand instead of through a constructor is not checked for consistency.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DictionaryScoreRoutePlan {
    /// Minibatch rows scored together.
    pub n_rows: usize,
    /// Candidate atoms/blocks scored for each row.
    pub n_items: usize,
    /// Dot-product width for one score.
    pub feature_dim: usize,
    /// Minimum `n_rows * n_items` elements required for device admission.
    pub device_min_score_elems: usize,
    /// Target ceiling on the `n_rows * tile_items` score elements held by one
    /// device launch. A tile always carries at least one item, so when `n_rows`
    /// alone exceeds this budget the real peak is `n_rows` elements.
    pub max_tile_score_elems: usize,
    /// Candidate items per launch tile: `0` for a degenerate plan, otherwise a
    /// value in `1..=n_items`.
    pub tile_items: usize,
    /// Number of candidate tiles covering `0..n_items`; `0` for a degenerate
    /// plan. The last tile may hold fewer than `tile_items` items.
    pub tile_count: usize,
    /// True when the total route work is large enough to use the device, i.e.
    /// the shape is non-degenerate and `n_rows * n_items` reaches
    /// `device_min_score_elems`.
    pub device_admitted: bool,
    /// Peak score-block bytes for a full tile: `n_rows * tile_items` scores of
    /// `f32` width, saturating at `usize::MAX`.
    pub peak_score_bytes: usize,
    /// Lower-bound arithmetic for dispatch diagnostics: one multiply and one add
    /// per `(row, item, feature)` score term.
    pub dot_flops_lower_bound: u128,
}

47impl DictionaryScoreRoutePlan {
48    /// Build a plan with explicit thresholds. The function is pure and
49    /// allocation-free so call sites can test routing decisions without a CUDA
50    /// runtime.
51    #[must_use]
52    pub fn with_limits(
53        n_rows: usize,
54        n_items: usize,
55        feature_dim: usize,
56        device_min_score_elems: usize,
57        max_tile_score_elems: usize,
58    ) -> Self {
59        let total_score_elems = n_rows.saturating_mul(n_items);
60        let nondegenerate = n_rows > 0 && n_items > 0 && feature_dim > 0;
61        let tile_items = if !nondegenerate {
62            0
63        } else {
64            (max_tile_score_elems / n_rows).clamp(1, n_items)
65        };
66        let tile_count = if tile_items == 0 {
67            0
68        } else {
69            n_items.div_ceil(tile_items)
70        };
71        let peak_tile_items = tile_items.min(n_items);
72        let peak_score_elems = n_rows.saturating_mul(peak_tile_items);
73        let dot_flops_lower_bound = 2u128
74            .saturating_mul(n_rows as u128)
75            .saturating_mul(n_items as u128)
76            .saturating_mul(feature_dim as u128);
77
78        Self {
79            n_rows,
80            n_items,
81            feature_dim,
82            device_min_score_elems,
83            max_tile_score_elems,
84            tile_items,
85            tile_count,
86            device_admitted: nondegenerate && total_score_elems >= device_min_score_elems,
87            peak_score_bytes: peak_score_elems.saturating_mul(std::mem::size_of::<f32>()),
88            dot_flops_lower_bound,
89        }
90    }
91
92    /// Build a plan with the library defaults used by sparse dictionary routers.
93    #[must_use]
94    pub fn default_for_shape(n_rows: usize, n_items: usize, feature_dim: usize) -> Self {
95        Self::with_limits(
96            n_rows,
97            n_items,
98            feature_dim,
99            DEFAULT_DICTIONARY_SCORE_MIN_ELEMS,
100            DEFAULT_DICTIONARY_SCORE_TILE_ELEMS,
101        )
102    }
103
104    /// True when the plan covers no route work.
105    #[must_use]
106    pub const fn is_degenerate(self) -> bool {
107        self.n_rows == 0 || self.n_items == 0 || self.feature_dim == 0
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::size_of;

    /// The motivating 256-row, 32k-item shape is admitted and splits into four
    /// tiles of exactly the 8 MiB score budget.
    #[test]
    fn target_k32k_shape_is_admitted_and_memory_bounded() {
        let plan = DictionaryScoreRoutePlan::default_for_shape(256, 32_768, 64);
        assert!(plan.device_admitted);
        assert_eq!(plan.tile_items, 8_192);
        assert_eq!(plan.tile_count, 4);
        assert_eq!(plan.peak_score_bytes, 8 * 1024 * 1024);
        assert_eq!(
            plan.dot_flops_lower_bound,
            2u128 * 256u128 * 32_768u128 * 64u128
        );
    }

    /// Widening the dictionary adds tiles but leaves the per-tile footprint
    /// unchanged.
    #[test]
    fn peak_score_memory_does_not_grow_with_dictionary_width() {
        let small = DictionaryScoreRoutePlan::default_for_shape(512, 4_096, 48);
        let large = DictionaryScoreRoutePlan::default_for_shape(512, 131_072, 48);
        assert_eq!(small.tile_items, large.tile_items);
        assert_eq!(small.peak_score_bytes, large.peak_score_bytes);
        assert!(large.tile_count > small.tile_count);
    }

    /// Shapes below the admission floor, and shapes with any zero dimension,
    /// are never routed to the device.
    #[test]
    fn sub_floor_and_degenerate_shapes_stay_on_host() {
        let tiny = DictionaryScoreRoutePlan::default_for_shape(16, 1024, 64);
        assert!(!tiny.device_admitted);
        assert_eq!(tiny.tile_count, 1);

        for plan in [
            DictionaryScoreRoutePlan::default_for_shape(0, 1024, 64),
            DictionaryScoreRoutePlan::default_for_shape(16, 0, 64),
            DictionaryScoreRoutePlan::default_for_shape(16, 1024, 0),
        ] {
            assert!(plan.is_degenerate());
            assert!(!plan.device_admitted);
            assert_eq!(plan.tile_items, 0);
            assert_eq!(plan.tile_count, 0);
            assert_eq!(plan.peak_score_bytes, 0);
        }
    }

    /// A budget smaller than one column of scores still produces one-item
    /// tiles, so the route always advances.
    #[test]
    fn tiny_tile_budget_still_makes_forward_progress() {
        let plan = DictionaryScoreRoutePlan::with_limits(512, 1000, 32, 1, 7);
        assert_eq!(plan.tile_items, 1);
        assert_eq!(plan.tile_count, 1000);
        assert_eq!(plan.peak_score_bytes, 512 * size_of::<f32>());

        // A zero budget behaves the same way rather than dividing the work
        // into zero-width tiles.
        let zero_budget = DictionaryScoreRoutePlan::with_limits(512, 1000, 32, 1, 0);
        assert_eq!(zero_budget.tile_items, 1);
        assert_eq!(zero_budget.tile_count, 1000);
    }

    /// When `n_items` is not a multiple of `tile_items`, the partial last tile
    /// is counted and the peak is sized by a full tile.
    #[test]
    fn ragged_final_tile_is_counted() {
        let plan = DictionaryScoreRoutePlan::with_limits(10, 25, 3, 100, 70);
        assert!(plan.device_admitted);
        assert_eq!(plan.tile_items, 7);
        assert_eq!(plan.tile_count, 4);
        assert_eq!(plan.peak_score_bytes, 10 * 7 * size_of::<f32>());
    }
}