Skip to main content

rlx_flow/
execution.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Unified execution configuration — variant + compile preset + cache key.
5
6use rlx_ir::{
7    CompilationMode, DimBinding, KernelDispatchConfig, ModelComponent, ModelVariant, QuantScheme,
8};
9
10use crate::composite::LayerComposition;
11use crate::profile::CompileProfile;
12
/// Named compile presets (fusion policy, precision, pass toggles).
///
/// Each variant resolves to one `CompileProfile` constructor through
/// `ExecutionPreset::profile`. The variant's `Debug` name also feeds
/// `ExecutionPreset::profile_key`, so renaming a variant changes its key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ExecutionPreset {
    /// Resolves to `CompileProfile::llama32_prefill()`.
    Llama32Prefill,
    /// Resolves to `CompileProfile::llama32_decode()`.
    Llama32Decode,
    /// Resolves to `CompileProfile::qwen35_prefill()`.
    Qwen35Prefill,
    /// Resolves to `CompileProfile::qwen35_decode()`.
    Qwen35Decode,
    /// Resolves to `CompileProfile::encoder()`.
    Encoder,
}
22
23impl ExecutionPreset {
24    pub fn profile(&self) -> CompileProfile {
25        match self {
26            Self::Llama32Prefill => CompileProfile::llama32_prefill(),
27            Self::Llama32Decode => CompileProfile::llama32_decode(),
28            Self::Qwen35Prefill => CompileProfile::qwen35_prefill(),
29            Self::Qwen35Decode => CompileProfile::qwen35_decode(),
30            Self::Encoder => CompileProfile::encoder(),
31        }
32    }
33
34    pub fn profile_key(&self) -> u64 {
35        use std::collections::hash_map::DefaultHasher;
36        use std::hash::{Hash, Hasher};
37        let mut h = DefaultHasher::new();
38        format!("{self:?}").hash(&mut h);
39        h.finish()
40    }
41}
42
43/// Shader-component-style bundle: one object for specialize + compile + cache.
44#[derive(Debug, Clone)]
45pub struct ModelExecutionConfig {
46    pub component: ModelComponent,
47    pub preset: ExecutionPreset,
48}
49
50impl ModelExecutionConfig {
51    pub fn from_component(component: ModelComponent, preset: ExecutionPreset) -> Self {
52        Self { component, preset }
53    }
54
55    pub fn prefill(batch: usize, seq: usize) -> Self {
56        Self::from_component(
57            ModelComponent::new(ModelVariant::prefill(batch, seq))
58                .with_profile_key(ExecutionPreset::Llama32Prefill.profile_key()),
59            ExecutionPreset::Llama32Prefill,
60        )
61    }
62
63    pub fn decode(batch: usize, past_seq: usize, new_tokens: usize) -> Self {
64        Self::from_component(
65            ModelComponent::new(ModelVariant::decode(batch, past_seq, new_tokens))
66                .with_profile_key(ExecutionPreset::Llama32Decode.profile_key()),
67            ExecutionPreset::Llama32Decode,
68        )
69    }
70
71    pub fn qwen35_prefill(batch: usize, seq: usize) -> Self {
72        Self::from_component(
73            ModelComponent::new(ModelVariant::prefill(batch, seq))
74                .with_profile_key(ExecutionPreset::Qwen35Prefill.profile_key()),
75            ExecutionPreset::Qwen35Prefill,
76        )
77    }
78
79    pub fn qwen35_decode(batch: usize, past_seq: usize) -> Self {
80        Self::from_component(
81            ModelComponent::new(ModelVariant::decode(batch, past_seq, 1))
82                .with_profile_key(ExecutionPreset::Qwen35Decode.profile_key()),
83            ExecutionPreset::Qwen35Decode,
84        )
85    }
86
87    pub fn with_preset(mut self, preset: ExecutionPreset) -> Self {
88        self.preset = preset;
89        self.component.profile_key = preset.profile_key();
90        self
91    }
92
93    pub fn with_kernel_dispatch(mut self, config: KernelDispatchConfig) -> Self {
94        self.component.kernel_dispatch = config;
95        self
96    }
97
98    pub fn with_compilation_mode(mut self, mode: CompilationMode) -> Self {
99        self.component.compilation_mode = mode;
100        self
101    }
102
103    pub fn with_quant(mut self, scheme: QuantScheme) -> Self {
104        self.component.quant = Some(scheme);
105        self
106    }
107
108    pub fn with_layer_composition(mut self, composition: &LayerComposition) -> Self {
109        self.component.layer_composition_key = composition.cache_key();
110        self
111    }
112
113    pub fn cache_key(&self) -> u64 {
114        self.component.cache_key()
115    }
116
117    pub fn dim_binding(&self) -> DimBinding {
118        self.component.dim_binding()
119    }
120
121    pub fn compile_profile(&self) -> CompileProfile {
122        self.preset.profile()
123    }
124
125    pub fn component(&self) -> &ModelComponent {
126        &self.component
127    }
128
129    pub fn variant(&self) -> &ModelVariant {
130        &self.component.variant
131    }
132}