Skip to main content

rlx_flow/
profile.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Compile profile — tier-1 config for fusion, passes, precision, backends.
5
6use rlx_ir::hir::FusionPolicy;
7use serde::{Deserialize, Serialize};
8
9/// Tier-1 compile configuration. Load from `*.rlx.toml` or use Rust presets.
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11#[serde(default)]
12pub struct CompileProfile {
13    pub fusion: FusionProfile,
14    pub passes: PassProfile,
15    pub precision: PrecisionProfile,
16    #[serde(default)]
17    pub backend: BackendOverrides,
18}
19
20impl Default for CompileProfile {
21    fn default() -> Self {
22        Self::llama32_prefill()
23    }
24}
25
26impl CompileProfile {
27    /// Fusion-first prefill defaults (Direct lowering, fusion passes on).
28    pub fn llama32_prefill() -> Self {
29        Self {
30            fusion: FusionProfile {
31                policy: FusionPolicyKind::Direct,
32                target: FusionTargetKind::Auto,
33                assert_clean: false,
34                skip: false,
35            },
36            passes: PassProfile::default(),
37            precision: PrecisionProfile::default(),
38            backend: BackendOverrides::default(),
39        }
40    }
41
42    /// Decode graphs: Fusable lowering so KV-cache concat patterns fuse cleanly.
43    pub fn llama32_decode() -> Self {
44        Self {
45            fusion: FusionProfile {
46                policy: FusionPolicyKind::Fusable,
47                ..FusionProfile::default()
48            },
49            ..Self::llama32_prefill()
50        }
51    }
52
53    /// Qwen3.5 prefill — same fusion-first defaults as LLaMA prefill.
54    pub fn qwen35_prefill() -> Self {
55        Self::llama32_prefill()
56    }
57
58    /// Qwen3.5 decode — fusable policy for GDN / full-attn KV patterns.
59    pub fn qwen35_decode() -> Self {
60        Self::llama32_decode()
61    }
62
63    /// Qwen3 dense LM prefill (GQA + SwiGLU).
64    pub fn qwen3_prefill() -> Self {
65        Self::llama32_prefill()
66    }
67
68    /// Qwen3 decode — fusable policy for bucketed KV-cache graphs.
69    pub fn qwen3_decode() -> Self {
70        Self::llama32_decode()
71    }
72
73    /// FLUX.2 diffusion transformer + VAE/text-encoder graphs.
74    pub fn flux2() -> Self {
75        Self::encoder()
76    }
77
78    /// SAM / SAM2 image encoder and mask-decoder subgraphs (ConvNeXt-style stacks).
79    pub fn sam_encoder() -> Self {
80        Self::encoder()
81    }
82
83    /// SAM3 detector encoder/decoder layers (ViT + deformable-style decoder).
84    pub fn sam3() -> Self {
85        Self::sam_encoder()
86    }
87
88    /// SAM2 image + mask-decoder + memory subgraphs (Hiera encoder uses same tier-1 knobs).
89    pub fn sam2() -> Self {
90        Self::sam_encoder()
91    }
92
93    /// SAM2 memory-attention layers — fusion off (host RoPE between subgraphs).
94    pub fn sam2_memory_attention() -> Self {
95        Self {
96            fusion: FusionProfile {
97                skip: true,
98                ..FusionProfile::default()
99            },
100            ..Self::encoder()
101        }
102    }
103
104    /// LLaDA2 / TIDE block-diffusion MoE (bidirectional attention + grouped MoE).
105    ///
106    /// Fusion is off so graphs legalize on wgpu/CUDA without unfused
107    /// `FusedResidualRmsNorm` lowerings.
108    pub fn llada2_diffusion() -> Self {
109        Self {
110            fusion: FusionProfile {
111                skip: true,
112                ..FusionProfile::default()
113            },
114            ..Self::encoder()
115        }
116    }
117
118    /// Bidirectional encoder defaults (BERT, NomicBERT, vision encoders).
119    pub fn encoder() -> Self {
120        Self {
121            fusion: FusionProfile {
122                policy: FusionPolicyKind::Direct,
123                ..FusionProfile::default()
124            },
125            passes: PassProfile {
126                dce: true,
127                constant_folding: true,
128                verbose: false,
129            },
130            precision: PrecisionProfile::default(),
131            backend: BackendOverrides::default(),
132        }
133    }
134
135    pub fn fusion_policy(&self) -> FusionPolicy {
136        self.fusion.policy.into()
137    }
138
139    pub fn from_toml_str(s: &str) -> anyhow::Result<Self> {
140        Ok(toml::from_str(s)?)
141    }
142
143    pub fn from_toml_path(path: &std::path::Path) -> anyhow::Result<Self> {
144        let data = std::fs::read_to_string(path)?;
145        Self::from_toml_str(&data)
146    }
147}
148
149#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
150#[serde(default)]
151pub struct FusionProfile {
152    pub policy: FusionPolicyKind,
153    pub target: FusionTargetKind,
154    pub assert_clean: bool,
155    pub skip: bool,
156}
157
158impl Default for FusionProfile {
159    fn default() -> Self {
160        Self {
161            policy: FusionPolicyKind::Direct,
162            target: FusionTargetKind::Auto,
163            assert_clean: false,
164            skip: false,
165        }
166    }
167}
168
169#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
170#[serde(rename_all = "lowercase")]
171pub enum FusionPolicyKind {
172    #[default]
173    Direct,
174    Fusable,
175}
176
177impl From<FusionPolicyKind> for FusionPolicy {
178    fn from(k: FusionPolicyKind) -> Self {
179        match k {
180            FusionPolicyKind::Direct => FusionPolicy::Direct,
181            FusionPolicyKind::Fusable => FusionPolicy::Fusable,
182        }
183    }
184}
185
186#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
187#[serde(rename_all = "lowercase")]
188pub enum FusionTargetKind {
189    #[default]
190    Auto,
191    Cpu,
192    Metal,
193    Mlx,
194    Cuda,
195    Rocm,
196    Wgpu,
197    Tpu,
198}
199
200#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
201#[serde(default)]
202pub struct PassProfile {
203    pub dce: bool,
204    pub constant_folding: bool,
205    pub verbose: bool,
206}
207
208impl Default for PassProfile {
209    fn default() -> Self {
210        Self {
211            dce: true,
212            constant_folding: true,
213            verbose: false,
214        }
215    }
216}
217
218#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
219#[serde(default)]
220pub struct PrecisionProfile {
221    pub compute: PrecisionKind,
222    pub mixed: MixedPrecisionKind,
223}
224
225impl Default for PrecisionProfile {
226    fn default() -> Self {
227        Self {
228            compute: PrecisionKind::F32,
229            mixed: MixedPrecisionKind::None,
230        }
231    }
232}
233
234#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
235#[serde(rename_all = "lowercase")]
236pub enum PrecisionKind {
237    #[default]
238    F32,
239    F16,
240    Bf16,
241}
242
243#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
244#[serde(rename_all = "snake_case")]
245pub enum MixedPrecisionKind {
246    #[default]
247    None,
248    Auto,
249}
250
251/// Per-backend hint table (env-style toggles without touching IR).
252#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
253pub struct BackendOverrides {
254    #[serde(default)]
255    pub metal: MetalBackendProfile,
256    #[serde(default)]
257    pub cpu: CpuBackendProfile,
258}
259
260#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
261pub struct MetalBackendProfile {
262    pub skip_fusion: bool,
263    pub unfuse_regions: bool,
264}
265
266#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
267pub struct CpuBackendProfile {
268    pub unfuse_regions: bool,
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_profile_toml() {
        let toml = r#"
[fusion]
policy = "direct"
target = "metal"
assert_clean = true

[passes]
dce = true
constant_folding = false

[precision]
compute = "f16"
mixed = "auto"
"#;
        let p = CompileProfile::from_toml_str(toml).unwrap();
        assert_eq!(p.fusion.policy, FusionPolicyKind::Direct);
        assert_eq!(p.fusion.target, FusionTargetKind::Metal);
        assert!(p.fusion.assert_clean);
        assert!(!p.passes.constant_folding);
        assert_eq!(p.precision.compute, PrecisionKind::F16);
        assert_eq!(p.precision.mixed, MixedPrecisionKind::Auto);
        // Keys and tables not mentioned keep their defaults.
        assert!(!p.fusion.skip);
        assert!(!p.passes.verbose);
        assert_eq!(p.backend, BackendOverrides::default());
    }

    #[test]
    fn empty_toml_yields_default_profile() {
        let p = CompileProfile::from_toml_str("").unwrap();
        assert_eq!(p, CompileProfile::default());
        assert_eq!(p, CompileProfile::llama32_prefill());
    }

    #[test]
    fn partial_fusion_table_matches_decode_preset() {
        let p = CompileProfile::from_toml_str("[fusion]\npolicy = \"fusable\"\n").unwrap();
        assert_eq!(p, CompileProfile::llama32_decode());
        assert!(matches!(
            p.fusion_policy(),
            rlx_ir::hir::FusionPolicy::Fusable
        ));
    }

    #[test]
    fn fusion_off_presets_set_skip() {
        assert!(CompileProfile::sam2_memory_attention().fusion.skip);
        assert!(CompileProfile::llada2_diffusion().fusion.skip);
        assert!(!CompileProfile::encoder().fusion.skip);
        assert!(!CompileProfile::llama32_decode().fusion.skip);
    }
}