1use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12
13use crate::logical_kernel::KernelDispatchConfig;
14use crate::quant::QuantScheme;
15use crate::variant::ModelVariant;
16
/// Compilation strategy selected for a [`ModelComponent`].
///
/// The mode is hashed into [`ModelComponent::cache_key`], so two components
/// that differ only in their mode get different cache keys.
///
/// NOTE(review): what each mode does at runtime is decided elsewhere in the
/// crate; the per-variant notes below are read off the names — confirm.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum CompilationMode {
    /// The default mode (presumably: compile up front).
    #[default]
    Eager,
    /// Presumably: defer compilation until first use.
    Lazy,
    /// Presumably: ahead-of-time compilation.
    Aot,
}
28
29#[derive(Debug, Clone)]
31pub struct ModelComponent {
32 pub variant: ModelVariant,
33 pub kernel_dispatch: KernelDispatchConfig,
34 pub compilation_mode: CompilationMode,
35 pub profile_key: u64,
37 pub quant: Option<QuantScheme>,
39 pub layer_composition_key: u64,
41}
42
43impl ModelComponent {
44 pub fn new(variant: ModelVariant) -> Self {
45 Self {
46 variant,
47 kernel_dispatch: KernelDispatchConfig::default(),
48 compilation_mode: CompilationMode::Eager,
49 profile_key: 0,
50 quant: None,
51 layer_composition_key: 0,
52 }
53 }
54
55 pub fn with_kernel_dispatch(mut self, config: KernelDispatchConfig) -> Self {
56 self.kernel_dispatch = config;
57 self
58 }
59
60 pub fn with_compilation_mode(mut self, mode: CompilationMode) -> Self {
61 self.compilation_mode = mode;
62 self
63 }
64
65 pub fn with_profile_key(mut self, key: u64) -> Self {
66 self.profile_key = key;
67 self
68 }
69
70 pub fn with_quant(mut self, scheme: QuantScheme) -> Self {
71 self.quant = Some(scheme);
72 self
73 }
74
75 pub fn with_layer_composition_key(mut self, key: u64) -> Self {
76 self.layer_composition_key = key;
77 self
78 }
79
80 pub fn cache_key(&self) -> u64 {
82 let mut h = DefaultHasher::new();
83 self.variant.cache_key().hash(&mut h);
84 (self.kernel_dispatch.policy as u8).hash(&mut h);
85 for k in self.kernel_dispatch.force_common_kinds.iter() {
86 k.hash(&mut h);
87 }
88 for k in self.kernel_dispatch.force_native_kinds.iter() {
89 k.hash(&mut h);
90 }
91 self.compilation_mode.hash(&mut h);
92 self.profile_key.hash(&mut h);
93 if let Some(q) = &self.quant {
94 format!("{q:?}").hash(&mut h);
95 }
96 self.layer_composition_key.hash(&mut h);
97 h.finish()
98 }
99
100 pub fn dim_binding(&self) -> crate::DimBinding {
101 self.variant.dim_binding()
102 }
103
104 pub fn aot_disk_base(&self) -> String {
106 format!("rlx_{:016x}", self.cache_key())
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ModelVariant;
    use crate::logical_kernel::KernelDispatchPolicy;

    /// A component with a non-default compilation mode, and one with a
    /// non-default profile key plus dispatch policy, must each produce a cache
    /// key different from the baseline component's.
    #[test]
    fn cache_key_changes_with_mode_and_profile() {
        let variant = ModelVariant::prefill(1, 8);

        let baseline_key = ModelComponent::new(variant.clone()).cache_key();

        let lazy_key = ModelComponent::new(variant.clone())
            .with_compilation_mode(CompilationMode::Lazy)
            .cache_key();
        assert_ne!(baseline_key, lazy_key);

        let force_common = KernelDispatchConfig::new(KernelDispatchPolicy::ForceCommon);
        let profiled_key = ModelComponent::new(variant)
            .with_profile_key(42)
            .with_kernel_dispatch(force_common)
            .cache_key();
        assert_ne!(baseline_key, profiled_key);
    }
}