Skip to main content

rlx_flow/
profile.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Compile profile — tier-1 config for fusion, passes, precision, backends.
5
6use rlx_ir::hir::FusionPolicy;
7use serde::{Deserialize, Serialize};
8
9/// Tier-1 compile configuration. Load from `*.rlx.toml` or use Rust presets.
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11#[serde(default)]
12pub struct CompileProfile {
13    pub fusion: FusionProfile,
14    pub passes: PassProfile,
15    pub precision: PrecisionProfile,
16    #[serde(default)]
17    pub backend: BackendOverrides,
18}
19
20impl Default for CompileProfile {
21    fn default() -> Self {
22        Self::llama32_prefill()
23    }
24}
25
26impl CompileProfile {
27    /// Fusion-first prefill defaults (Direct lowering, fusion passes on).
28    pub fn llama32_prefill() -> Self {
29        Self {
30            fusion: FusionProfile {
31                policy: FusionPolicyKind::Direct,
32                target: FusionTargetKind::Auto,
33                assert_clean: false,
34                skip: false,
35            },
36            passes: PassProfile::default(),
37            precision: PrecisionProfile::default(),
38            backend: BackendOverrides::default(),
39        }
40    }
41
42    /// Decode graphs: Fusable lowering so KV-cache concat patterns fuse cleanly.
43    pub fn llama32_decode() -> Self {
44        Self {
45            fusion: FusionProfile {
46                policy: FusionPolicyKind::Fusable,
47                ..FusionProfile::default()
48            },
49            ..Self::llama32_prefill()
50        }
51    }
52
53    /// Qwen3.5 prefill — same fusion-first defaults as LLaMA prefill.
54    pub fn qwen35_prefill() -> Self {
55        Self::llama32_prefill()
56    }
57
58    /// Qwen3.5 decode — fusable policy for GDN / full-attn KV patterns.
59    pub fn qwen35_decode() -> Self {
60        Self::llama32_decode()
61    }
62
63    /// Qwen3 dense LM prefill (GQA + SwiGLU).
64    pub fn qwen3_prefill() -> Self {
65        Self::llama32_prefill()
66    }
67
68    /// Qwen3 decode — fusable policy for bucketed KV-cache graphs.
69    pub fn qwen3_decode() -> Self {
70        Self::llama32_decode()
71    }
72
73    /// Gemma 2 / Gemma 3 causal LM prefill (GQA + RMSNorm + softcap).
74    pub fn gemma_prefill() -> Self {
75        Self::llama32_prefill()
76    }
77
78    /// Gemma decode — fusable policy for bucketed KV-cache graphs.
79    pub fn gemma_decode() -> Self {
80        Self::llama32_decode()
81    }
82
83    /// FLUX.2 diffusion transformer + VAE/text-encoder graphs.
84    pub fn flux2() -> Self {
85        Self::encoder()
86    }
87
88    /// SAM / SAM2 image encoder and mask-decoder subgraphs (ConvNeXt-style stacks).
89    pub fn sam_encoder() -> Self {
90        Self::encoder()
91    }
92
93    /// SAM3 detector encoder/decoder layers (ViT + deformable-style decoder).
94    pub fn sam3() -> Self {
95        Self::sam_encoder()
96    }
97
98    /// SAM2 image + mask-decoder + memory subgraphs (Hiera encoder uses same tier-1 knobs).
99    pub fn sam2() -> Self {
100        Self::sam_encoder()
101    }
102
103    /// SAM2 memory-attention layers — fusion off (host RoPE between subgraphs).
104    pub fn sam2_memory_attention() -> Self {
105        Self {
106            fusion: FusionProfile {
107                skip: true,
108                ..FusionProfile::default()
109            },
110            ..Self::encoder()
111        }
112    }
113
114    /// LLaDA2 / TIDE block-diffusion MoE (bidirectional attention + grouped MoE).
115    ///
116    /// Fusion is off so graphs legalize on wgpu/CUDA without unfused
117    /// `FusedResidualRmsNorm` lowerings.
118    pub fn llada2_diffusion() -> Self {
119        Self {
120            fusion: FusionProfile {
121                skip: true,
122                ..FusionProfile::default()
123            },
124            ..Self::encoder()
125        }
126    }
127
128    /// Bidirectional encoder defaults (BERT, NomicBERT, vision encoders).
129    pub fn encoder() -> Self {
130        Self {
131            fusion: FusionProfile {
132                policy: FusionPolicyKind::Direct,
133                ..FusionProfile::default()
134            },
135            passes: PassProfile {
136                dce: true,
137                constant_folding: true,
138                verbose: false,
139            },
140            precision: PrecisionProfile::default(),
141            backend: BackendOverrides::default(),
142        }
143    }
144
145    pub fn fusion_policy(&self) -> FusionPolicy {
146        self.fusion.policy.into()
147    }
148
149    pub fn from_toml_str(s: &str) -> anyhow::Result<Self> {
150        Ok(toml::from_str(s)?)
151    }
152
153    pub fn from_toml_path(path: &std::path::Path) -> anyhow::Result<Self> {
154        let data = std::fs::read_to_string(path)?;
155        Self::from_toml_str(&data)
156    }
157}
158
159#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
160#[serde(default)]
161pub struct FusionProfile {
162    pub policy: FusionPolicyKind,
163    pub target: FusionTargetKind,
164    pub assert_clean: bool,
165    pub skip: bool,
166}
167
168impl Default for FusionProfile {
169    fn default() -> Self {
170        Self {
171            policy: FusionPolicyKind::Direct,
172            target: FusionTargetKind::Auto,
173            assert_clean: false,
174            skip: false,
175        }
176    }
177}
178
179#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
180#[serde(rename_all = "lowercase")]
181pub enum FusionPolicyKind {
182    #[default]
183    Direct,
184    Fusable,
185}
186
187impl From<FusionPolicyKind> for FusionPolicy {
188    fn from(k: FusionPolicyKind) -> Self {
189        match k {
190            FusionPolicyKind::Direct => FusionPolicy::Direct,
191            FusionPolicyKind::Fusable => FusionPolicy::Fusable,
192        }
193    }
194}
195
196#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
197#[serde(rename_all = "lowercase")]
198pub enum FusionTargetKind {
199    #[default]
200    Auto,
201    Cpu,
202    Metal,
203    Mlx,
204    Cuda,
205    Rocm,
206    Wgpu,
207    Tpu,
208}
209
210#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
211#[serde(default)]
212pub struct PassProfile {
213    pub dce: bool,
214    pub constant_folding: bool,
215    pub verbose: bool,
216}
217
218impl Default for PassProfile {
219    fn default() -> Self {
220        Self {
221            dce: true,
222            constant_folding: true,
223            verbose: false,
224        }
225    }
226}
227
228#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
229#[serde(default)]
230pub struct PrecisionProfile {
231    pub compute: PrecisionKind,
232    pub mixed: MixedPrecisionKind,
233}
234
235impl Default for PrecisionProfile {
236    fn default() -> Self {
237        Self {
238            compute: PrecisionKind::F32,
239            mixed: MixedPrecisionKind::None,
240        }
241    }
242}
243
244#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
245#[serde(rename_all = "lowercase")]
246pub enum PrecisionKind {
247    #[default]
248    F32,
249    F16,
250    Bf16,
251}
252
253#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
254#[serde(rename_all = "snake_case")]
255pub enum MixedPrecisionKind {
256    #[default]
257    None,
258    Auto,
259}
260
261/// Per-backend hint table (env-style toggles without touching IR).
262#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
263pub struct BackendOverrides {
264    #[serde(default)]
265    pub metal: MetalBackendProfile,
266    #[serde(default)]
267    pub cpu: CpuBackendProfile,
268}
269
270#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
271pub struct MetalBackendProfile {
272    pub skip_fusion: bool,
273    pub unfuse_regions: bool,
274}
275
276#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
277pub struct CpuBackendProfile {
278    pub unfuse_regions: bool,
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_profile_toml() {
        let toml = r#"
[fusion]
policy = "direct"
target = "metal"
assert_clean = true

[passes]
dce = true
constant_folding = false

[precision]
compute = "f16"
mixed = "auto"
"#;
        let p = CompileProfile::from_toml_str(toml).unwrap();
        assert_eq!(p.fusion.policy, FusionPolicyKind::Direct);
        assert_eq!(p.fusion.target, FusionTargetKind::Metal);
        assert!(p.fusion.assert_clean);
        assert!(!p.passes.constant_folding);
        assert_eq!(p.precision.compute, PrecisionKind::F16);
        assert_eq!(p.precision.mixed, MixedPrecisionKind::Auto);
        assert_eq!(p.backend, BackendOverrides::default());
    }

    /// An empty document must resolve entirely to the container default.
    #[test]
    fn empty_toml_yields_default_profile() {
        let p = CompileProfile::from_toml_str("").unwrap();
        assert_eq!(p, CompileProfile::default());
    }

    /// A table that sets one key keeps the remaining keys (and tables) at their defaults.
    #[test]
    fn partial_fusion_table_keeps_other_defaults() {
        let p = CompileProfile::from_toml_str("[fusion]\npolicy = \"fusable\"\n").unwrap();
        assert_eq!(p.fusion.policy, FusionPolicyKind::Fusable);
        assert_eq!(p.fusion.target, FusionTargetKind::Auto);
        assert!(!p.fusion.assert_clean);
        assert!(!p.fusion.skip);
        assert_eq!(p.passes, PassProfile::default());
        assert_eq!(p.precision, PrecisionProfile::default());
    }

    /// Decode presets differ from prefill only in the fusion policy.
    #[test]
    fn decode_presets_only_switch_fusion_policy() {
        let mut expected = CompileProfile::llama32_prefill();
        expected.fusion.policy = FusionPolicyKind::Fusable;
        assert_eq!(CompileProfile::llama32_decode(), expected);
        assert_eq!(CompileProfile::qwen3_decode(), expected);
        assert_eq!(CompileProfile::gemma_decode(), expected);
    }

    /// Fusion-off presets differ from the encoder preset only in `fusion.skip`.
    #[test]
    fn fusion_off_presets_only_set_skip() {
        let mut expected = CompileProfile::encoder();
        expected.fusion.skip = true;
        assert_eq!(CompileProfile::sam2_memory_attention(), expected);
        assert_eq!(CompileProfile::llada2_diffusion(), expected);
    }
}
309}