1use std::collections::hash_map::DefaultHasher;
23use std::hash::{Hash, Hasher};
24
25use crate::logical_kernel::KernelDispatchConfig;
26use crate::quant::QuantScheme;
27use crate::variant::ModelVariant;
28
/// When a model component gets compiled.
///
/// The selected mode is hashed into [`ModelComponent::cache_key`], so two
/// otherwise identical components with different modes get different keys.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum CompilationMode {
    /// The default mode.
    // NOTE(review): presumably "compile immediately" — confirm with the consumer.
    #[default]
    Eager,
    // NOTE(review): presumably "compile on first use" — confirm with the consumer.
    /// Lazy compilation mode.
    Lazy,
    // NOTE(review): presumably "ahead-of-time"; see `aot_disk_base` — confirm.
    /// AOT compilation mode.
    Aot,
}
40
41#[derive(Debug, Clone)]
43pub struct ModelComponent {
44 pub variant: ModelVariant,
45 pub kernel_dispatch: KernelDispatchConfig,
46 pub compilation_mode: CompilationMode,
47 pub profile_key: u64,
49 pub quant: Option<QuantScheme>,
51 pub layer_composition_key: u64,
53}
54
55impl ModelComponent {
56 pub fn new(variant: ModelVariant) -> Self {
57 Self {
58 variant,
59 kernel_dispatch: KernelDispatchConfig::default(),
60 compilation_mode: CompilationMode::Eager,
61 profile_key: 0,
62 quant: None,
63 layer_composition_key: 0,
64 }
65 }
66
67 pub fn with_kernel_dispatch(mut self, config: KernelDispatchConfig) -> Self {
68 self.kernel_dispatch = config;
69 self
70 }
71
72 pub fn with_compilation_mode(mut self, mode: CompilationMode) -> Self {
73 self.compilation_mode = mode;
74 self
75 }
76
77 pub fn with_profile_key(mut self, key: u64) -> Self {
78 self.profile_key = key;
79 self
80 }
81
82 pub fn with_quant(mut self, scheme: QuantScheme) -> Self {
83 self.quant = Some(scheme);
84 self
85 }
86
87 pub fn with_layer_composition_key(mut self, key: u64) -> Self {
88 self.layer_composition_key = key;
89 self
90 }
91
92 pub fn cache_key(&self) -> u64 {
94 let mut h = DefaultHasher::new();
95 self.variant.cache_key().hash(&mut h);
96 (self.kernel_dispatch.policy as u8).hash(&mut h);
97 for k in self.kernel_dispatch.force_common_kinds.iter() {
98 k.hash(&mut h);
99 }
100 for k in self.kernel_dispatch.force_native_kinds.iter() {
101 k.hash(&mut h);
102 }
103 self.compilation_mode.hash(&mut h);
104 self.profile_key.hash(&mut h);
105 if let Some(q) = &self.quant {
106 format!("{q:?}").hash(&mut h);
107 }
108 self.layer_composition_key.hash(&mut h);
109 h.finish()
110 }
111
112 pub fn dim_binding(&self) -> crate::DimBinding {
113 self.variant.dim_binding()
114 }
115
116 pub fn aot_disk_base(&self) -> String {
118 format!("rlx_{:016x}", self.cache_key())
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ModelVariant;
    use crate::logical_kernel::KernelDispatchPolicy;

    /// A freshly built component must not share its key with one that uses a
    /// different compilation mode, nor with one that changes both the profile
    /// key and the dispatch policy.
    #[test]
    fn cache_key_changes_with_mode_and_profile() {
        let variant = ModelVariant::prefill(1, 8);
        let forced_common = KernelDispatchConfig::new(KernelDispatchPolicy::ForceCommon);

        let baseline_key = ModelComponent::new(variant.clone()).cache_key();

        let lazy_component =
            ModelComponent::new(variant.clone()).with_compilation_mode(CompilationMode::Lazy);
        let lazy_key = lazy_component.cache_key();

        let tuned_component = ModelComponent::new(variant)
            .with_profile_key(42)
            .with_kernel_dispatch(forced_common);
        let tuned_key = tuned_component.cache_key();

        assert_ne!(baseline_key, lazy_key);
        assert_ne!(baseline_key, tuned_key);
    }
}