/// Default admission floor, in score elements, that
/// [`DictionaryScoreRoutePlan::default_for_shape`] passes as
/// `device_min_score_elems`: `1 << 20` (1,048,576) elements.
pub const DEFAULT_DICTIONARY_SCORE_MIN_ELEMS: usize = 1 << 20;
/// Default per-tile budget, in score elements, that
/// [`DictionaryScoreRoutePlan::default_for_shape`] passes as
/// `max_tile_score_elems`.
///
/// Computed as the default library resource policy's
/// `row_chunk_target_bytes` divided by `size_of::<f32>()`, i.e. that byte
/// target expressed as a count of `f32` values.
pub const DEFAULT_DICTIONARY_SCORE_TILE_ELEMS: usize =
gam_runtime::resource::ResourcePolicy::default_library().row_chunk_target_bytes
/ std::mem::size_of::<f32>();
/// Tiling and admission plan for an `n_rows * n_items` block of scores.
///
/// Every field is filled in by [`DictionaryScoreRoutePlan::with_limits`];
/// the first five simply echo that constructor's arguments and the
/// remaining five are derived from them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DictionaryScoreRoutePlan {
/// Row count of the problem, as supplied by the caller.
pub n_rows: usize,
/// Item count of the problem, as supplied by the caller.
pub n_items: usize,
/// Feature dimension, as supplied by the caller. It only influences the
/// degeneracy check and `dot_flops_lower_bound`.
pub feature_dim: usize,
/// Smallest `n_rows * n_items` (saturating) for which `device_admitted`
/// can be `true`.
pub device_min_score_elems: usize,
/// Per-tile budget, in score elements, used to derive `tile_items`.
pub max_tile_score_elems: usize,
/// Items per tile: `max_tile_score_elems / n_rows` clamped to
/// `1..=n_items`, or `0` when any of the three dimensions is zero.
pub tile_items: usize,
/// Number of tiles: `n_items` divided by `tile_items`, rounded up, or `0`
/// when `tile_items` is `0`.
pub tile_count: usize,
/// `true` when no dimension is zero and `n_rows * n_items` (saturating)
/// is at least `device_min_score_elems`.
pub device_admitted: bool,
/// `n_rows * tile_items * size_of::<f32>()`, computed with saturating
/// multiplication. Because `tile_items` never drops below 1 for a
/// nondegenerate shape, this can exceed the byte size implied by
/// `max_tile_score_elems` when that budget is smaller than `n_rows`.
pub peak_score_bytes: usize,
/// `2 * n_rows * n_items * feature_dim`, computed in `u128` with
/// saturating multiplication.
pub dot_flops_lower_bound: u128,
}
impl DictionaryScoreRoutePlan {
    /// Builds a plan for the given shape under explicit limits.
    ///
    /// * `n_rows`, `n_items`, `feature_dim` - problem shape; if any of them is
    ///   zero the plan is degenerate (`tile_items`, `tile_count` and
    ///   `peak_score_bytes` are all `0`, and `device_admitted` is `false`).
    /// * `device_min_score_elems` - minimum `n_rows * n_items` (saturating)
    ///   required for `device_admitted` to be `true`.
    /// * `max_tile_score_elems` - per-tile score-element budget. The tile
    ///   width is `max_tile_score_elems / n_rows`, clamped to `1..=n_items`,
    ///   so at least one item is always scheduled per tile even if the
    ///   budget is smaller than a single column of `n_rows` scores.
    ///
    /// All size arithmetic saturates instead of overflowing.
    #[must_use]
    pub fn with_limits(
        n_rows: usize,
        n_items: usize,
        feature_dim: usize,
        device_min_score_elems: usize,
        max_tile_score_elems: usize,
    ) -> Self {
        let has_work = n_rows != 0 && n_items != 0 && feature_dim != 0;

        // Tile width and tile count are derived together: a degenerate shape
        // gets neither, otherwise the width is at least one item and the
        // count is the rounded-up quotient.
        let (tile_items, tile_count) = if has_work {
            let per_tile = (max_tile_score_elems / n_rows).clamp(1, n_items);
            (per_tile, n_items.div_ceil(per_tile))
        } else {
            (0, 0)
        };

        // `tile_items` is already bounded by `n_items` (or is zero), so it is
        // directly the widest tile that will ever be materialised.
        let peak_score_bytes = n_rows
            .saturating_mul(tile_items)
            .saturating_mul(std::mem::size_of::<f32>());

        // 2 * n_rows * n_items * feature_dim, widened to u128 and saturating.
        let dot_flops_lower_bound = [n_rows, n_items, feature_dim]
            .iter()
            .fold(2u128, |acc, &dim| acc.saturating_mul(dim as u128));

        let device_admitted =
            has_work && n_rows.saturating_mul(n_items) >= device_min_score_elems;

        Self {
            n_rows,
            n_items,
            feature_dim,
            device_min_score_elems,
            max_tile_score_elems,
            tile_items,
            tile_count,
            device_admitted,
            peak_score_bytes,
            dot_flops_lower_bound,
        }
    }

    /// Builds a plan for the given shape using
    /// `DEFAULT_DICTIONARY_SCORE_MIN_ELEMS` as the admission floor and
    /// `DEFAULT_DICTIONARY_SCORE_TILE_ELEMS` as the per-tile budget.
    #[must_use]
    pub fn default_for_shape(n_rows: usize, n_items: usize, feature_dim: usize) -> Self {
        let admission_floor = DEFAULT_DICTIONARY_SCORE_MIN_ELEMS;
        let tile_budget = DEFAULT_DICTIONARY_SCORE_TILE_ELEMS;
        Self::with_limits(n_rows, n_items, feature_dim, admission_floor, tile_budget)
    }

    /// Returns `true` when `n_rows`, `n_items` or `feature_dim` is zero.
    #[must_use]
    pub const fn is_degenerate(self) -> bool {
        matches!(
            (self.n_rows, self.n_items, self.feature_dim),
            (0, _, _) | (_, 0, _) | (_, _, 0)
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The target 256 x 32k x 64 shape is admitted and split into four
    /// equally sized tiles of 8 MiB each.
    #[test]
    fn target_k32k_shape_is_admitted_and_memory_bounded() {
        const ROWS: usize = 256;
        const ITEMS: usize = 32_768;
        const DIM: usize = 64;
        let expected_flops = 2u128 * (ROWS as u128) * (ITEMS as u128) * (DIM as u128);

        let plan = DictionaryScoreRoutePlan::default_for_shape(ROWS, ITEMS, DIM);

        assert!(plan.device_admitted);
        assert_eq!(plan.tile_items, 8_192);
        assert_eq!(plan.tile_count, 4);
        assert_eq!(plan.peak_score_bytes, 8 * 1024 * 1024);
        assert_eq!(plan.dot_flops_lower_bound, expected_flops);
    }

    /// Widening the dictionary adds tiles but leaves the per-tile footprint
    /// unchanged.
    #[test]
    fn peak_score_memory_does_not_grow_with_dictionary_width() {
        let narrow = DictionaryScoreRoutePlan::default_for_shape(512, 4_096, 48);
        let wide = DictionaryScoreRoutePlan::default_for_shape(512, 131_072, 48);

        assert_eq!(narrow.tile_items, wide.tile_items);
        assert_eq!(narrow.peak_score_bytes, wide.peak_score_bytes);
        assert!(wide.tile_count > narrow.tile_count);
    }

    /// Shapes below the admission floor, and shapes with a zero dimension,
    /// are never device-admitted.
    #[test]
    fn sub_floor_and_degenerate_shapes_stay_on_host() {
        let below_floor = DictionaryScoreRoutePlan::default_for_shape(16, 1024, 64);
        assert!(!below_floor.device_admitted);
        assert_eq!(below_floor.tile_count, 1);

        // One shape per dimension that can be zero.
        let degenerate_shapes: [(usize, usize, usize); 3] =
            [(0, 1024, 64), (16, 0, 64), (16, 1024, 0)];
        for &(rows, items, dim) in &degenerate_shapes {
            let plan = DictionaryScoreRoutePlan::default_for_shape(rows, items, dim);
            assert!(plan.is_degenerate());
            assert!(!plan.device_admitted);
            assert_eq!(plan.tile_items, 0);
            assert_eq!(plan.tile_count, 0);
            assert_eq!(plan.peak_score_bytes, 0);
        }
    }

    /// A tile budget smaller than one column of scores still yields
    /// one-item tiles rather than an empty plan.
    #[test]
    fn tiny_tile_budget_still_makes_forward_progress() {
        let starved = DictionaryScoreRoutePlan::with_limits(512, 1000, 32, 1, 7);

        assert_eq!(starved.tile_items, 1);
        assert_eq!(starved.tile_count, 1000);
        assert_eq!(starved.peak_score_bytes, 512 * std::mem::size_of::<f32>());
    }
}