1use rlx_ir::{
7 CompilationMode, DimBinding, KernelDispatchConfig, ModelComponent, ModelVariant, QuantScheme,
8};
9
10use crate::composite::LayerComposition;
11use crate::profile::CompileProfile;
12
/// Named execution presets.
///
/// Each variant selects one `CompileProfile` constructor, and its variant
/// name also feeds the preset's profile key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ExecutionPreset {
    /// Prefill preset (`CompileProfile::llama32_prefill`).
    Llama32Prefill,
    /// Decode preset (`CompileProfile::llama32_decode`).
    Llama32Decode,
    /// Prefill preset (`CompileProfile::qwen35_prefill`).
    Qwen35Prefill,
    /// Decode preset (`CompileProfile::qwen35_decode`).
    Qwen35Decode,
    /// Encoder preset (`CompileProfile::encoder`).
    Encoder,
}
22
23impl ExecutionPreset {
24 pub fn profile(&self) -> CompileProfile {
25 match self {
26 Self::Llama32Prefill => CompileProfile::llama32_prefill(),
27 Self::Llama32Decode => CompileProfile::llama32_decode(),
28 Self::Qwen35Prefill => CompileProfile::qwen35_prefill(),
29 Self::Qwen35Decode => CompileProfile::qwen35_decode(),
30 Self::Encoder => CompileProfile::encoder(),
31 }
32 }
33
34 pub fn profile_key(&self) -> u64 {
35 use std::collections::hash_map::DefaultHasher;
36 use std::hash::{Hash, Hasher};
37 let mut h = DefaultHasher::new();
38 format!("{self:?}").hash(&mut h);
39 h.finish()
40 }
41}
42
43#[derive(Debug, Clone)]
45pub struct ModelExecutionConfig {
46 pub component: ModelComponent,
47 pub preset: ExecutionPreset,
48}
49
50impl ModelExecutionConfig {
51 pub fn from_component(component: ModelComponent, preset: ExecutionPreset) -> Self {
52 Self { component, preset }
53 }
54
55 pub fn prefill(batch: usize, seq: usize) -> Self {
56 Self::from_component(
57 ModelComponent::new(ModelVariant::prefill(batch, seq))
58 .with_profile_key(ExecutionPreset::Llama32Prefill.profile_key()),
59 ExecutionPreset::Llama32Prefill,
60 )
61 }
62
63 pub fn decode(batch: usize, past_seq: usize, new_tokens: usize) -> Self {
64 Self::from_component(
65 ModelComponent::new(ModelVariant::decode(batch, past_seq, new_tokens))
66 .with_profile_key(ExecutionPreset::Llama32Decode.profile_key()),
67 ExecutionPreset::Llama32Decode,
68 )
69 }
70
71 pub fn qwen35_prefill(batch: usize, seq: usize) -> Self {
72 Self::from_component(
73 ModelComponent::new(ModelVariant::prefill(batch, seq))
74 .with_profile_key(ExecutionPreset::Qwen35Prefill.profile_key()),
75 ExecutionPreset::Qwen35Prefill,
76 )
77 }
78
79 pub fn qwen35_decode(batch: usize, past_seq: usize) -> Self {
80 Self::from_component(
81 ModelComponent::new(ModelVariant::decode(batch, past_seq, 1))
82 .with_profile_key(ExecutionPreset::Qwen35Decode.profile_key()),
83 ExecutionPreset::Qwen35Decode,
84 )
85 }
86
87 pub fn with_preset(mut self, preset: ExecutionPreset) -> Self {
88 self.preset = preset;
89 self.component.profile_key = preset.profile_key();
90 self
91 }
92
93 pub fn with_kernel_dispatch(mut self, config: KernelDispatchConfig) -> Self {
94 self.component.kernel_dispatch = config;
95 self
96 }
97
98 pub fn with_compilation_mode(mut self, mode: CompilationMode) -> Self {
99 self.component.compilation_mode = mode;
100 self
101 }
102
103 pub fn with_quant(mut self, scheme: QuantScheme) -> Self {
104 self.component.quant = Some(scheme);
105 self
106 }
107
108 pub fn with_layer_composition(mut self, composition: &LayerComposition) -> Self {
109 self.component.layer_composition_key = composition.cache_key();
110 self
111 }
112
113 pub fn cache_key(&self) -> u64 {
114 self.component.cache_key()
115 }
116
117 pub fn dim_binding(&self) -> DimBinding {
118 self.component.dim_binding()
119 }
120
121 pub fn compile_profile(&self) -> CompileProfile {
122 self.preset.profile()
123 }
124
125 pub fn component(&self) -> &ModelComponent {
126 &self.component
127 }
128
129 pub fn variant(&self) -> &ModelVariant {
130 &self.component.variant
131 }
132}