Skip to main content

rlx_flow/
profile.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Compile profile — tier-1 config for fusion, passes, precision, backends.
17
18use rlx_ir::hir::FusionPolicy;
19use serde::{Deserialize, Serialize};
20
21/// Tier-1 compile configuration. Load from `*.rlx.toml` or use Rust presets.
22#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23#[serde(default)]
24pub struct CompileProfile {
25    pub fusion: FusionProfile,
26    pub passes: PassProfile,
27    pub precision: PrecisionProfile,
28    #[serde(default)]
29    pub backend: BackendOverrides,
30}
31
32impl Default for CompileProfile {
33    fn default() -> Self {
34        Self::llama32_prefill()
35    }
36}
37
38impl CompileProfile {
39    /// Fusion-first prefill defaults (Direct lowering, fusion passes on).
40    pub fn llama32_prefill() -> Self {
41        Self {
42            fusion: FusionProfile {
43                policy: FusionPolicyKind::Direct,
44                target: FusionTargetKind::Auto,
45                assert_clean: false,
46                skip: false,
47            },
48            passes: PassProfile::default(),
49            precision: PrecisionProfile::default(),
50            backend: BackendOverrides::default(),
51        }
52    }
53
54    /// Decode graphs: Fusable lowering so KV-cache concat patterns fuse cleanly.
55    pub fn llama32_decode() -> Self {
56        Self {
57            fusion: FusionProfile {
58                policy: FusionPolicyKind::Fusable,
59                ..FusionProfile::default()
60            },
61            ..Self::llama32_prefill()
62        }
63    }
64
65    /// Qwen3.5 prefill — same fusion-first defaults as LLaMA prefill.
66    pub fn qwen35_prefill() -> Self {
67        Self::llama32_prefill()
68    }
69
70    /// Qwen3.5 decode — fusable policy for GDN / full-attn KV patterns.
71    pub fn qwen35_decode() -> Self {
72        Self::llama32_decode()
73    }
74
75    /// Qwen3 dense LM prefill (GQA + SwiGLU).
76    pub fn qwen3_prefill() -> Self {
77        Self::llama32_prefill()
78    }
79
80    /// Qwen3 decode — fusable policy for bucketed KV-cache graphs.
81    pub fn qwen3_decode() -> Self {
82        Self::llama32_decode()
83    }
84
85    /// Gemma 2 / Gemma 3 causal LM prefill (GQA + RMSNorm + softcap).
86    pub fn gemma_prefill() -> Self {
87        Self::llama32_prefill()
88    }
89
90    /// Gemma decode — fusable policy for bucketed KV-cache graphs.
91    pub fn gemma_decode() -> Self {
92        Self::llama32_decode()
93    }
94
95    /// FLUX.2 diffusion transformer + VAE/text-encoder graphs.
96    pub fn flux2() -> Self {
97        Self::encoder()
98    }
99
100    /// SAM / SAM2 image encoder and mask-decoder subgraphs (ConvNeXt-style stacks).
101    pub fn sam_encoder() -> Self {
102        Self::encoder()
103    }
104
105    /// SAM3 detector encoder/decoder layers (ViT + deformable-style decoder).
106    pub fn sam3() -> Self {
107        Self::sam_encoder()
108    }
109
110    /// SAM2 image + mask-decoder + memory subgraphs (Hiera encoder uses same tier-1 knobs).
111    pub fn sam2() -> Self {
112        Self::sam_encoder()
113    }
114
115    /// SAM2 memory-attention layers — fusion off (host RoPE between subgraphs).
116    pub fn sam2_memory_attention() -> Self {
117        Self {
118            fusion: FusionProfile {
119                skip: true,
120                ..FusionProfile::default()
121            },
122            ..Self::encoder()
123        }
124    }
125
126    /// LLaDA2 / TIDE block-diffusion MoE (bidirectional attention + grouped MoE).
127    ///
128    /// Fusion is off so graphs legalize on wgpu/CUDA without unfused
129    /// `FusedResidualRmsNorm` lowerings.
130    pub fn llada2_diffusion() -> Self {
131        Self {
132            fusion: FusionProfile {
133                skip: true,
134                ..FusionProfile::default()
135            },
136            ..Self::encoder()
137        }
138    }
139
140    /// Bidirectional encoder defaults (BERT, NomicBERT, vision encoders).
141    pub fn encoder() -> Self {
142        Self {
143            fusion: FusionProfile {
144                policy: FusionPolicyKind::Direct,
145                ..FusionProfile::default()
146            },
147            passes: PassProfile {
148                dce: true,
149                constant_folding: true,
150                verbose: false,
151            },
152            precision: PrecisionProfile::default(),
153            backend: BackendOverrides::default(),
154        }
155    }
156
157    pub fn fusion_policy(&self) -> FusionPolicy {
158        self.fusion.policy.into()
159    }
160
161    pub fn from_toml_str(s: &str) -> anyhow::Result<Self> {
162        Ok(toml::from_str(s)?)
163    }
164
165    pub fn from_toml_path(path: &std::path::Path) -> anyhow::Result<Self> {
166        let data = std::fs::read_to_string(path)?;
167        Self::from_toml_str(&data)
168    }
169
170    /// Load `<family>.rlx.toml` from the directory containing
171    /// `weights`. Falls back to the built-in preset for `(family, mode)`
172    /// when the sidecar is missing or unreadable.
173    ///
174    /// Replaces the per-crate `*_profile_near_weights` helpers (one per
175    /// family today: `llama32_profile_near_weights`, `qwen3_profile_*`,
176    /// `gemma_profile_*`). New families only need to register their
177    /// built-in presets via [`Self::default_for`].
178    pub fn near_weights(weights: &std::path::Path, family: &str, mode: ProfileMode) -> Self {
179        let default = Self::default_for(family, mode);
180        let dir = weights
181            .parent()
182            .unwrap_or_else(|| std::path::Path::new("."));
183        let sidecar = dir.join(format!("{family}.rlx.toml"));
184        Self::from_toml_path(&sidecar).unwrap_or(default)
185    }
186
187    /// Built-in preset for `(family, mode)`. Unknown families fall back
188    /// to the Llama 3.2 presets — same behavior the per-crate helpers
189    /// used before this method existed.
190    pub fn default_for(family: &str, mode: ProfileMode) -> Self {
191        match (family, mode) {
192            ("llama32", ProfileMode::Prefill) => Self::llama32_prefill(),
193            ("llama32", ProfileMode::Decode) => Self::llama32_decode(),
194            ("qwen3", ProfileMode::Prefill) => Self::qwen3_prefill(),
195            ("qwen3", ProfileMode::Decode) => Self::qwen3_decode(),
196            ("qwen35", ProfileMode::Prefill) => Self::qwen35_prefill(),
197            ("qwen35", ProfileMode::Decode) => Self::qwen35_decode(),
198            ("gemma", ProfileMode::Prefill) => Self::gemma_prefill(),
199            ("gemma", ProfileMode::Decode) => Self::gemma_decode(),
200            (_, ProfileMode::Prefill) => Self::llama32_prefill(),
201            (_, ProfileMode::Decode) => Self::llama32_decode(),
202            (_, ProfileMode::Encoder) => Self::encoder(),
203        }
204    }
205}
206
/// Kind of graph being compiled, used to pick a built-in preset in
/// [`CompileProfile::near_weights`] / [`CompileProfile::default_for`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProfileMode {
    /// Prompt-processing (prefill) pass.
    Prefill,
    /// Token-by-token decode pass.
    Decode,
    /// Encoder-style pass; always maps to `CompileProfile::encoder()`.
    Encoder,
}
216
217#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
218#[serde(default)]
219pub struct FusionProfile {
220    pub policy: FusionPolicyKind,
221    pub target: FusionTargetKind,
222    pub assert_clean: bool,
223    pub skip: bool,
224}
225
226impl Default for FusionProfile {
227    fn default() -> Self {
228        Self {
229            policy: FusionPolicyKind::Direct,
230            target: FusionTargetKind::Auto,
231            assert_clean: false,
232            skip: false,
233        }
234    }
235}
236
237#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
238#[serde(rename_all = "lowercase")]
239pub enum FusionPolicyKind {
240    #[default]
241    Direct,
242    Fusable,
243}
244
245impl From<FusionPolicyKind> for FusionPolicy {
246    fn from(k: FusionPolicyKind) -> Self {
247        match k {
248            FusionPolicyKind::Direct => FusionPolicy::Direct,
249            FusionPolicyKind::Fusable => FusionPolicy::Fusable,
250        }
251    }
252}
253
254#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
255#[serde(rename_all = "lowercase")]
256pub enum FusionTargetKind {
257    #[default]
258    Auto,
259    Cpu,
260    Metal,
261    Mlx,
262    Cuda,
263    Rocm,
264    Wgpu,
265    Tpu,
266}
267
268#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
269#[serde(default)]
270pub struct PassProfile {
271    pub dce: bool,
272    pub constant_folding: bool,
273    pub verbose: bool,
274}
275
276impl Default for PassProfile {
277    fn default() -> Self {
278        Self {
279            dce: true,
280            constant_folding: true,
281            verbose: false,
282        }
283    }
284}
285
286#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
287#[serde(default)]
288pub struct PrecisionProfile {
289    pub compute: PrecisionKind,
290    pub mixed: MixedPrecisionKind,
291}
292
293impl Default for PrecisionProfile {
294    fn default() -> Self {
295        Self {
296            compute: PrecisionKind::F32,
297            mixed: MixedPrecisionKind::None,
298        }
299    }
300}
301
302#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
303#[serde(rename_all = "lowercase")]
304pub enum PrecisionKind {
305    #[default]
306    F32,
307    F16,
308    Bf16,
309}
310
311#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
312#[serde(rename_all = "snake_case")]
313pub enum MixedPrecisionKind {
314    #[default]
315    None,
316    Auto,
317}
318
319/// Per-backend hint table (env-style toggles without touching IR).
320#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
321pub struct BackendOverrides {
322    #[serde(default)]
323    pub metal: MetalBackendProfile,
324    #[serde(default)]
325    pub cpu: CpuBackendProfile,
326}
327
328#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
329pub struct MetalBackendProfile {
330    pub skip_fusion: bool,
331    pub unfuse_regions: bool,
332}
333
334#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
335pub struct CpuBackendProfile {
336    pub unfuse_regions: bool,
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_profile_toml() {
        let toml = r#"
[fusion]
policy = "direct"
target = "metal"
assert_clean = true

[passes]
dce = true
constant_folding = false

[precision]
compute = "f16"
mixed = "auto"
"#;
        let p = CompileProfile::from_toml_str(toml).unwrap();
        assert_eq!(p.fusion.policy, FusionPolicyKind::Direct);
        assert_eq!(p.fusion.target, FusionTargetKind::Metal);
        assert!(p.fusion.assert_clean);
        assert!(p.passes.dce);
        assert!(!p.passes.constant_folding);
        assert_eq!(p.precision.compute, PrecisionKind::F16);
        assert_eq!(p.precision.mixed, MixedPrecisionKind::Auto);
    }

    /// Every section is optional: an empty document yields the default profile.
    #[test]
    fn empty_toml_yields_default_profile() {
        let p = CompileProfile::from_toml_str("").unwrap();
        assert_eq!(p, CompileProfile::default());
    }

    #[test]
    fn default_for_maps_families_and_modes() {
        assert_eq!(
            CompileProfile::default_for("qwen3", ProfileMode::Decode).fusion.policy,
            FusionPolicyKind::Fusable
        );
        // Unknown families reuse the Llama 3.2 presets.
        assert_eq!(
            CompileProfile::default_for("unknown-family", ProfileMode::Prefill),
            CompileProfile::llama32_prefill()
        );
        // Encoder mode ignores the family.
        assert_eq!(
            CompileProfile::default_for("llama32", ProfileMode::Encoder),
            CompileProfile::encoder()
        );
    }

    #[test]
    fn near_weights_falls_back_when_sidecar_missing() {
        let weights = std::env::temp_dir()
            .join("rlx-profile-test-missing-dir")
            .join("model.safetensors");
        let p = CompileProfile::near_weights(&weights, "no_such_family", ProfileMode::Decode);
        assert_eq!(p, CompileProfile::llama32_decode());
    }

    #[test]
    fn toml_roundtrip_preserves_profile() {
        let original = CompileProfile::llama32_decode();
        let text = toml::to_string(&original).unwrap();
        assert_eq!(CompileProfile::from_toml_str(&text).unwrap(), original);
    }
}