1use rlx_ir::{
19 CompilationMode, DimBinding, KernelDispatchConfig, ModelComponent, ModelVariant, QuantScheme,
20};
21
22use crate::composite::LayerComposition;
23use crate::profile::CompileProfile;
24
/// Named execution presets.
///
/// Each variant selects one `CompileProfile` constructor (see
/// `ExecutionPreset::profile`). The variant's `Debug` text is also the input
/// to `ExecutionPreset::profile_key`, so renaming a variant changes the key
/// derived from it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ExecutionPreset {
    /// Resolves to `CompileProfile::llama32_prefill()`.
    Llama32Prefill,
    /// Resolves to `CompileProfile::llama32_decode()`.
    Llama32Decode,
    /// Resolves to `CompileProfile::qwen35_prefill()`.
    Qwen35Prefill,
    /// Resolves to `CompileProfile::qwen35_decode()`.
    Qwen35Decode,
    /// Resolves to `CompileProfile::encoder()`.
    Encoder,
}
34
35impl ExecutionPreset {
36 pub fn profile(&self) -> CompileProfile {
37 match self {
38 Self::Llama32Prefill => CompileProfile::llama32_prefill(),
39 Self::Llama32Decode => CompileProfile::llama32_decode(),
40 Self::Qwen35Prefill => CompileProfile::qwen35_prefill(),
41 Self::Qwen35Decode => CompileProfile::qwen35_decode(),
42 Self::Encoder => CompileProfile::encoder(),
43 }
44 }
45
46 pub fn profile_key(&self) -> u64 {
47 use std::collections::hash_map::DefaultHasher;
48 use std::hash::{Hash, Hasher};
49 let mut h = DefaultHasher::new();
50 format!("{self:?}").hash(&mut h);
51 h.finish()
52 }
53}
54
55#[derive(Debug, Clone)]
57pub struct ModelExecutionConfig {
58 pub component: ModelComponent,
59 pub preset: ExecutionPreset,
60}
61
62impl ModelExecutionConfig {
63 pub fn from_component(component: ModelComponent, preset: ExecutionPreset) -> Self {
64 Self { component, preset }
65 }
66
67 pub fn prefill(batch: usize, seq: usize) -> Self {
68 Self::from_component(
69 ModelComponent::new(ModelVariant::prefill(batch, seq))
70 .with_profile_key(ExecutionPreset::Llama32Prefill.profile_key()),
71 ExecutionPreset::Llama32Prefill,
72 )
73 }
74
75 pub fn decode(batch: usize, past_seq: usize, new_tokens: usize) -> Self {
76 Self::from_component(
77 ModelComponent::new(ModelVariant::decode(batch, past_seq, new_tokens))
78 .with_profile_key(ExecutionPreset::Llama32Decode.profile_key()),
79 ExecutionPreset::Llama32Decode,
80 )
81 }
82
83 pub fn qwen35_prefill(batch: usize, seq: usize) -> Self {
84 Self::from_component(
85 ModelComponent::new(ModelVariant::prefill(batch, seq))
86 .with_profile_key(ExecutionPreset::Qwen35Prefill.profile_key()),
87 ExecutionPreset::Qwen35Prefill,
88 )
89 }
90
91 pub fn qwen35_decode(batch: usize, past_seq: usize) -> Self {
92 Self::from_component(
93 ModelComponent::new(ModelVariant::decode(batch, past_seq, 1))
94 .with_profile_key(ExecutionPreset::Qwen35Decode.profile_key()),
95 ExecutionPreset::Qwen35Decode,
96 )
97 }
98
99 pub fn with_preset(mut self, preset: ExecutionPreset) -> Self {
100 self.preset = preset;
101 self.component.profile_key = preset.profile_key();
102 self
103 }
104
105 pub fn with_kernel_dispatch(mut self, config: KernelDispatchConfig) -> Self {
106 self.component.kernel_dispatch = config;
107 self
108 }
109
110 pub fn with_compilation_mode(mut self, mode: CompilationMode) -> Self {
111 self.component.compilation_mode = mode;
112 self
113 }
114
115 pub fn with_quant(mut self, scheme: QuantScheme) -> Self {
116 self.component.quant = Some(scheme);
117 self
118 }
119
120 pub fn with_layer_composition(mut self, composition: &LayerComposition) -> Self {
121 self.component.layer_composition_key = composition.cache_key();
122 self
123 }
124
125 pub fn cache_key(&self) -> u64 {
126 self.component.cache_key()
127 }
128
129 pub fn dim_binding(&self) -> DimBinding {
130 self.component.dim_binding()
131 }
132
133 pub fn compile_profile(&self) -> CompileProfile {
134 self.preset.profile()
135 }
136
137 pub fn component(&self) -> &ModelComponent {
138 &self.component
139 }
140
141 pub fn variant(&self) -> &ModelVariant {
142 &self.component.variant
143 }
144}