use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use crate::logical_kernel::KernelDispatchConfig;
use crate::quant::QuantScheme;
use crate::variant::ModelVariant;
/// Compilation strategy recorded on a [`ModelComponent`].
///
/// The mode is hashed into [`ModelComponent::cache_key`], so it is part of a
/// component's identity. [`CompilationMode::Eager`] is the default.
///
/// The discriminants are written out explicitly because the derived `Hash`
/// hashes them: reordering or renumbering variants would change every cache
/// key derived from a component.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum CompilationMode {
    #[default]
    Eager = 0,
    Lazy = 1,
    Aot = 2,
}
/// A model variant bundled with the settings that shape how it is dispatched
/// and compiled.
///
/// Build one with [`ModelComponent::new`] and adjust it with the `with_*`
/// builder methods; [`ModelComponent::cache_key`] folds every field into a
/// single `u64`.
#[derive(Clone, Debug)]
pub struct ModelComponent {
    /// The wrapped model variant.
    pub variant: ModelVariant,
    /// Kernel dispatch configuration (policy plus the forced common/native
    /// kernel-kind lists).
    pub kernel_dispatch: KernelDispatchConfig,
    /// Compilation strategy; `Eager` when built via `new`.
    pub compilation_mode: CompilationMode,
    /// Caller-supplied key mixed into the cache key; `0` when built via `new`.
    pub profile_key: u64,
    /// Optional quantization scheme; `None` when built via `new`.
    pub quant: Option<QuantScheme>,
    /// Caller-supplied layer-composition key mixed into the cache key; `0`
    /// when built via `new`.
    pub layer_composition_key: u64,
}
impl ModelComponent {
    /// Creates a component wrapping `variant` with default settings.
    ///
    /// Defaults: `KernelDispatchConfig::default()`, [`CompilationMode::Eager`],
    /// `profile_key = 0`, no quantization scheme, `layer_composition_key = 0`.
    pub fn new(variant: ModelVariant) -> Self {
        Self {
            variant,
            kernel_dispatch: KernelDispatchConfig::default(),
            compilation_mode: CompilationMode::Eager,
            profile_key: 0,
            quant: None,
            layer_composition_key: 0,
        }
    }

    /// Replaces the kernel dispatch configuration.
    #[must_use]
    pub fn with_kernel_dispatch(mut self, config: KernelDispatchConfig) -> Self {
        self.kernel_dispatch = config;
        self
    }

    /// Sets the compilation mode.
    #[must_use]
    pub fn with_compilation_mode(mut self, mode: CompilationMode) -> Self {
        self.compilation_mode = mode;
        self
    }

    /// Sets the caller-supplied profile key that is mixed into [`Self::cache_key`].
    #[must_use]
    pub fn with_profile_key(mut self, key: u64) -> Self {
        self.profile_key = key;
        self
    }

    /// Sets the quantization scheme.
    #[must_use]
    pub fn with_quant(mut self, scheme: QuantScheme) -> Self {
        self.quant = Some(scheme);
        self
    }

    /// Sets the caller-supplied layer-composition key that is mixed into
    /// [`Self::cache_key`].
    #[must_use]
    pub fn with_layer_composition_key(mut self, key: u64) -> Self {
        self.layer_composition_key = key;
        self
    }

    /// Computes a 64-bit key summarising this component's whole configuration.
    ///
    /// The following are hashed, in order:
    /// - the variant's own cache key;
    /// - the dispatch policy;
    /// - both forced-kind lists, each length-prefixed;
    /// - the compilation mode and the profile key;
    /// - the quantization scheme, as an `Option` of its `Debug` text;
    /// - the layer-composition key.
    ///
    /// The length prefixes and the `Option` discriminant keep the hashed byte
    /// stream prefix-free. Without them, values in adjacent fields could shift
    /// into one another and collide. For example, a kind forced to "common"
    /// and the same kind forced to "native" would hash identically.
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is unspecified across Rust
    /// releases, and the quant scheme is keyed by its `Debug` output. Keys
    /// are deterministic for one build but should not be assumed stable
    /// across toolchain or formatting changes. This matters if
    /// [`Self::aot_disk_base`] names persist between builds.
    ///
    /// NOTE(review): assumes the forced-kind collections iterate in a
    /// deterministic order (e.g. `Vec`/`BTreeSet`). A `HashSet` would make
    /// equal configs hash differently. TODO confirm.
    pub fn cache_key(&self) -> u64 {
        let mut h = DefaultHasher::new();
        self.variant.cache_key().hash(&mut h);

        let dispatch = &self.kernel_dispatch;
        (dispatch.policy as u8).hash(&mut h);
        // Length-prefix each list so elements cannot migrate between the two
        // lists without changing the hashed byte stream.
        dispatch.force_common_kinds.len().hash(&mut h);
        for k in dispatch.force_common_kinds.iter() {
            k.hash(&mut h);
        }
        dispatch.force_native_kinds.len().hash(&mut h);
        for k in dispatch.force_native_kinds.iter() {
            k.hash(&mut h);
        }

        self.compilation_mode.hash(&mut h);
        self.profile_key.hash(&mut h);
        // Hashing the `Option` itself includes the Some/None discriminant, so
        // "no scheme" is explicitly distinguished from any scheme's text.
        self.quant
            .as_ref()
            .map(|q| format!("{q:?}"))
            .hash(&mut h);
        self.layer_composition_key.hash(&mut h);
        h.finish()
    }

    /// Returns the dimension binding of the wrapped variant.
    pub fn dim_binding(&self) -> crate::DimBinding {
        self.variant.dim_binding()
    }

    /// Returns `"rlx_"` followed by [`Self::cache_key`] as 16 zero-padded,
    /// lowercase hex digits, e.g. `rlx_00000000000000ff`.
    ///
    /// NOTE(review): the name suggests this is the base name for on-disk AOT
    /// artifacts; confirm with callers. See the stability note on `cache_key`.
    pub fn aot_disk_base(&self) -> String {
        format!("rlx_{:016x}", self.cache_key())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ModelVariant;
    use crate::logical_kernel::KernelDispatchPolicy;

    /// Baseline component with all defaults, shared by the tests below.
    fn base() -> ModelComponent {
        ModelComponent::new(ModelVariant::prefill(1, 8))
    }

    #[test]
    fn cache_key_changes_with_mode_and_profile() {
        let v = ModelVariant::prefill(1, 8);
        let a = ModelComponent::new(v.clone()).cache_key();
        let b = ModelComponent::new(v.clone())
            .with_compilation_mode(CompilationMode::Lazy)
            .cache_key();
        // Changes both profile key and dispatch policy at once; see
        // `cache_key_changes_with_profile_key_alone` for the isolated check.
        let c = ModelComponent::new(v)
            .with_profile_key(42)
            .with_kernel_dispatch(KernelDispatchConfig::new(KernelDispatchPolicy::ForceCommon))
            .cache_key();
        assert_ne!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn cache_key_is_deterministic_for_equal_configs() {
        assert_eq!(base().cache_key(), base().cache_key());
        assert_eq!(
            base().with_profile_key(9).cache_key(),
            base().with_profile_key(9).cache_key()
        );
    }

    #[test]
    fn cache_key_changes_with_profile_key_alone() {
        assert_ne!(base().cache_key(), base().with_profile_key(42).cache_key());
    }

    #[test]
    fn cache_key_changes_with_layer_composition_key() {
        assert_ne!(
            base().cache_key(),
            base().with_layer_composition_key(7).cache_key()
        );
    }

    #[test]
    fn aot_disk_base_is_prefixed_fixed_width_hex() {
        let c = base();
        let name = c.aot_disk_base();
        assert_eq!(name, format!("rlx_{:016x}", c.cache_key()));
        assert_eq!(name.len(), "rlx_".len() + 16);
        assert!(name["rlx_".len()..]
            .chars()
            .all(|ch| ch.is_ascii_digit() || ('a'..='f').contains(&ch)));
    }
}