// gam_gpu/dictionary_score.rs

/// Default device-admission floor for dictionary scoring.
///
/// A shape is only admitted to the device when its total score count
/// (`n_rows * n_items`, saturating) is at least this many elements.
pub const DEFAULT_DICTIONARY_SCORE_MIN_ELEMS: usize = 1024 * 1024;
13pub const DEFAULT_DICTIONARY_SCORE_TILE_ELEMS: usize =
17 gam_runtime::resource::ResourcePolicy::default_library().row_chunk_target_bytes
18 / std::mem::size_of::<f32>();
19
20#[derive(Clone, Copy, Debug, Eq, PartialEq)]
23pub struct DictionaryScoreRoutePlan {
24 pub n_rows: usize,
26 pub n_items: usize,
28 pub feature_dim: usize,
30 pub device_min_score_elems: usize,
32 pub max_tile_score_elems: usize,
34 pub tile_items: usize,
36 pub tile_count: usize,
38 pub device_admitted: bool,
40 pub peak_score_bytes: usize,
42 pub dot_flops_lower_bound: u128,
45}
46
47impl DictionaryScoreRoutePlan {
48 #[must_use]
52 pub fn with_limits(
53 n_rows: usize,
54 n_items: usize,
55 feature_dim: usize,
56 device_min_score_elems: usize,
57 max_tile_score_elems: usize,
58 ) -> Self {
59 let total_score_elems = n_rows.saturating_mul(n_items);
60 let nondegenerate = n_rows > 0 && n_items > 0 && feature_dim > 0;
61 let tile_items = if !nondegenerate {
62 0
63 } else {
64 (max_tile_score_elems / n_rows).clamp(1, n_items)
65 };
66 let tile_count = if tile_items == 0 {
67 0
68 } else {
69 n_items.div_ceil(tile_items)
70 };
71 let peak_tile_items = tile_items.min(n_items);
72 let peak_score_elems = n_rows.saturating_mul(peak_tile_items);
73 let dot_flops_lower_bound = 2u128
74 .saturating_mul(n_rows as u128)
75 .saturating_mul(n_items as u128)
76 .saturating_mul(feature_dim as u128);
77
78 Self {
79 n_rows,
80 n_items,
81 feature_dim,
82 device_min_score_elems,
83 max_tile_score_elems,
84 tile_items,
85 tile_count,
86 device_admitted: nondegenerate && total_score_elems >= device_min_score_elems,
87 peak_score_bytes: peak_score_elems.saturating_mul(std::mem::size_of::<f32>()),
88 dot_flops_lower_bound,
89 }
90 }
91
92 #[must_use]
94 pub fn default_for_shape(n_rows: usize, n_items: usize, feature_dim: usize) -> Self {
95 Self::with_limits(
96 n_rows,
97 n_items,
98 feature_dim,
99 DEFAULT_DICTIONARY_SCORE_MIN_ELEMS,
100 DEFAULT_DICTIONARY_SCORE_TILE_ELEMS,
101 )
102 }
103
104 #[must_use]
106 pub const fn is_degenerate(self) -> bool {
107 self.n_rows == 0 || self.n_items == 0 || self.feature_dim == 0
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn target_k32k_shape_is_admitted_and_memory_bounded() {
        let plan = DictionaryScoreRoutePlan::default_for_shape(256, 32_768, 64);
        assert!(plan.device_admitted);
        assert_eq!(plan.tile_items, 8_192);
        assert_eq!(plan.tile_count, 4);
        assert_eq!(plan.peak_score_bytes, 8 * 1024 * 1024);
        assert_eq!(
            plan.dot_flops_lower_bound,
            2u128 * 256u128 * 32_768u128 * 64u128
        );
    }

    #[test]
    fn peak_score_memory_does_not_grow_with_dictionary_width() {
        let small = DictionaryScoreRoutePlan::default_for_shape(512, 4_096, 48);
        let large = DictionaryScoreRoutePlan::default_for_shape(512, 131_072, 48);
        assert_eq!(small.tile_items, large.tile_items);
        assert_eq!(small.peak_score_bytes, large.peak_score_bytes);
        assert!(large.tile_count > small.tile_count);
    }

    #[test]
    fn sub_floor_and_degenerate_shapes_stay_on_host() {
        let tiny = DictionaryScoreRoutePlan::default_for_shape(16, 1024, 64);
        assert!(!tiny.device_admitted);
        assert_eq!(tiny.tile_count, 1);

        for plan in [
            DictionaryScoreRoutePlan::default_for_shape(0, 1024, 64),
            DictionaryScoreRoutePlan::default_for_shape(16, 0, 64),
            DictionaryScoreRoutePlan::default_for_shape(16, 1024, 0),
        ] {
            assert!(plan.is_degenerate());
            assert!(!plan.device_admitted);
            assert_eq!(plan.tile_items, 0);
            assert_eq!(plan.tile_count, 0);
            assert_eq!(plan.peak_score_bytes, 0);
        }
    }

    #[test]
    fn tiny_tile_budget_still_makes_forward_progress() {
        let plan = DictionaryScoreRoutePlan::with_limits(512, 1000, 32, 1, 7);
        assert_eq!(plan.tile_items, 1);
        assert_eq!(plan.tile_count, 1000);
        assert_eq!(plan.peak_score_bytes, 512 * std::mem::size_of::<f32>());
    }

    /// A budget that does not divide `n_items` evenly must round the tile count up so the
    /// last, partial tile is still covered.
    #[test]
    fn uneven_item_count_rounds_tile_count_up() {
        let plan = DictionaryScoreRoutePlan::with_limits(3, 10, 1, 0, 12);
        assert!(!plan.is_degenerate());
        assert!(plan.device_admitted);
        assert_eq!(plan.tile_items, 4);
        assert_eq!(plan.tile_count, 3);
        assert_eq!(plan.peak_score_bytes, 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(plan.dot_flops_lower_bound, 2 * 3 * 10);
    }

    /// A budget larger than the whole score matrix yields a single tile capped at `n_items`.
    #[test]
    fn generous_budget_uses_single_tile_capped_at_item_count() {
        let plan = DictionaryScoreRoutePlan::with_limits(4, 10, 2, 100, 1_000);
        // 4 * 10 = 40 scores is below the floor of 100.
        assert!(!plan.device_admitted);
        assert_eq!(plan.tile_items, 10);
        assert_eq!(plan.tile_count, 1);
        assert_eq!(plan.peak_score_bytes, 4 * 10 * std::mem::size_of::<f32>());
    }

    /// Extreme shapes must saturate rather than overflow or panic.
    #[test]
    fn extreme_shapes_saturate_instead_of_overflowing() {
        let plan = DictionaryScoreRoutePlan::with_limits(
            usize::MAX,
            usize::MAX,
            usize::MAX,
            0,
            usize::MAX,
        );
        assert!(plan.device_admitted);
        assert_eq!(plan.tile_items, 1);
        assert_eq!(plan.tile_count, usize::MAX);
        assert_eq!(plan.peak_score_bytes, usize::MAX);
        // Only 64-bit targets are wide enough for 2 * MAX^3 to exceed `u128::MAX`.
        if usize::BITS == 64 {
            assert_eq!(plan.dot_flops_lower_bound, u128::MAX);
        }
    }
}