1use rlx_ir::hir::FusionPolicy;
19use serde::{Deserialize, Serialize};
20
21#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23#[serde(default)]
24pub struct CompileProfile {
25 pub fusion: FusionProfile,
26 pub passes: PassProfile,
27 pub precision: PrecisionProfile,
28 #[serde(default)]
29 pub backend: BackendOverrides,
30}
31
32impl Default for CompileProfile {
33 fn default() -> Self {
34 Self::llama32_prefill()
35 }
36}
37
38impl CompileProfile {
39 pub fn llama32_prefill() -> Self {
41 Self {
42 fusion: FusionProfile {
43 policy: FusionPolicyKind::Direct,
44 target: FusionTargetKind::Auto,
45 assert_clean: false,
46 skip: false,
47 },
48 passes: PassProfile::default(),
49 precision: PrecisionProfile::default(),
50 backend: BackendOverrides::default(),
51 }
52 }
53
54 pub fn llama32_decode() -> Self {
56 Self {
57 fusion: FusionProfile {
58 policy: FusionPolicyKind::Fusable,
59 ..FusionProfile::default()
60 },
61 ..Self::llama32_prefill()
62 }
63 }
64
65 pub fn qwen35_prefill() -> Self {
67 Self::llama32_prefill()
68 }
69
70 pub fn qwen35_decode() -> Self {
72 Self::llama32_decode()
73 }
74
75 pub fn qwen3_prefill() -> Self {
77 Self::llama32_prefill()
78 }
79
80 pub fn qwen3_decode() -> Self {
82 Self::llama32_decode()
83 }
84
85 pub fn gemma_prefill() -> Self {
87 Self::llama32_prefill()
88 }
89
90 pub fn gemma_decode() -> Self {
92 Self::llama32_decode()
93 }
94
95 pub fn flux2() -> Self {
97 Self::encoder()
98 }
99
100 pub fn sam_encoder() -> Self {
102 Self::encoder()
103 }
104
105 pub fn sam3() -> Self {
107 Self::sam_encoder()
108 }
109
110 pub fn sam2() -> Self {
112 Self::sam_encoder()
113 }
114
115 pub fn sam2_memory_attention() -> Self {
117 Self {
118 fusion: FusionProfile {
119 skip: true,
120 ..FusionProfile::default()
121 },
122 ..Self::encoder()
123 }
124 }
125
126 pub fn llada2_diffusion() -> Self {
131 Self {
132 fusion: FusionProfile {
133 skip: true,
134 ..FusionProfile::default()
135 },
136 ..Self::encoder()
137 }
138 }
139
140 pub fn encoder() -> Self {
142 Self {
143 fusion: FusionProfile {
144 policy: FusionPolicyKind::Direct,
145 ..FusionProfile::default()
146 },
147 passes: PassProfile {
148 dce: true,
149 constant_folding: true,
150 verbose: false,
151 },
152 precision: PrecisionProfile::default(),
153 backend: BackendOverrides::default(),
154 }
155 }
156
157 pub fn fusion_policy(&self) -> FusionPolicy {
158 self.fusion.policy.into()
159 }
160
161 pub fn from_toml_str(s: &str) -> anyhow::Result<Self> {
162 Ok(toml::from_str(s)?)
163 }
164
165 pub fn from_toml_path(path: &std::path::Path) -> anyhow::Result<Self> {
166 let data = std::fs::read_to_string(path)?;
167 Self::from_toml_str(&data)
168 }
169
170 pub fn near_weights(weights: &std::path::Path, family: &str, mode: ProfileMode) -> Self {
179 let default = Self::default_for(family, mode);
180 let dir = weights
181 .parent()
182 .unwrap_or_else(|| std::path::Path::new("."));
183 let sidecar = dir.join(format!("{family}.rlx.toml"));
184 Self::from_toml_path(&sidecar).unwrap_or(default)
185 }
186
187 pub fn default_for(family: &str, mode: ProfileMode) -> Self {
191 match (family, mode) {
192 ("llama32", ProfileMode::Prefill) => Self::llama32_prefill(),
193 ("llama32", ProfileMode::Decode) => Self::llama32_decode(),
194 ("qwen3", ProfileMode::Prefill) => Self::qwen3_prefill(),
195 ("qwen3", ProfileMode::Decode) => Self::qwen3_decode(),
196 ("qwen35", ProfileMode::Prefill) => Self::qwen35_prefill(),
197 ("qwen35", ProfileMode::Decode) => Self::qwen35_decode(),
198 ("gemma", ProfileMode::Prefill) => Self::gemma_prefill(),
199 ("gemma", ProfileMode::Decode) => Self::gemma_decode(),
200 (_, ProfileMode::Prefill) => Self::llama32_prefill(),
201 (_, ProfileMode::Decode) => Self::llama32_decode(),
202 (_, ProfileMode::Encoder) => Self::encoder(),
203 }
204 }
205}
206
/// Selects which preset variant [`CompileProfile::default_for`] returns.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProfileMode {
    /// The family's `*_prefill` preset.
    Prefill,
    /// The family's `*_decode` preset.
    Decode,
    /// [`CompileProfile::encoder`], independent of family.
    Encoder,
}
216
217#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
218#[serde(default)]
219pub struct FusionProfile {
220 pub policy: FusionPolicyKind,
221 pub target: FusionTargetKind,
222 pub assert_clean: bool,
223 pub skip: bool,
224}
225
226impl Default for FusionProfile {
227 fn default() -> Self {
228 Self {
229 policy: FusionPolicyKind::Direct,
230 target: FusionTargetKind::Auto,
231 assert_clean: false,
232 skip: false,
233 }
234 }
235}
236
237#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
238#[serde(rename_all = "lowercase")]
239pub enum FusionPolicyKind {
240 #[default]
241 Direct,
242 Fusable,
243}
244
245impl From<FusionPolicyKind> for FusionPolicy {
246 fn from(k: FusionPolicyKind) -> Self {
247 match k {
248 FusionPolicyKind::Direct => FusionPolicy::Direct,
249 FusionPolicyKind::Fusable => FusionPolicy::Fusable,
250 }
251 }
252}
253
254#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
255#[serde(rename_all = "lowercase")]
256pub enum FusionTargetKind {
257 #[default]
258 Auto,
259 Cpu,
260 Metal,
261 Mlx,
262 Cuda,
263 Rocm,
264 Wgpu,
265 Tpu,
266}
267
268#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
269#[serde(default)]
270pub struct PassProfile {
271 pub dce: bool,
272 pub constant_folding: bool,
273 pub verbose: bool,
274}
275
276impl Default for PassProfile {
277 fn default() -> Self {
278 Self {
279 dce: true,
280 constant_folding: true,
281 verbose: false,
282 }
283 }
284}
285
286#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
287#[serde(default)]
288pub struct PrecisionProfile {
289 pub compute: PrecisionKind,
290 pub mixed: MixedPrecisionKind,
291}
292
293impl Default for PrecisionProfile {
294 fn default() -> Self {
295 Self {
296 compute: PrecisionKind::F32,
297 mixed: MixedPrecisionKind::None,
298 }
299 }
300}
301
302#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
303#[serde(rename_all = "lowercase")]
304pub enum PrecisionKind {
305 #[default]
306 F32,
307 F16,
308 Bf16,
309}
310
311#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
312#[serde(rename_all = "snake_case")]
313pub enum MixedPrecisionKind {
314 #[default]
315 None,
316 Auto,
317}
318
319#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
321pub struct BackendOverrides {
322 #[serde(default)]
323 pub metal: MetalBackendProfile,
324 #[serde(default)]
325 pub cpu: CpuBackendProfile,
326}
327
328#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
329pub struct MetalBackendProfile {
330 pub skip_fusion: bool,
331 pub unfuse_regions: bool,
332}
333
334#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
335pub struct CpuBackendProfile {
336 pub unfuse_regions: bool,
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully specified document parses into the expected field values.
    #[test]
    fn parse_profile_toml() {
        let toml = r#"
[fusion]
policy = "direct"
target = "metal"
assert_clean = true

[passes]
dce = true
constant_folding = false

[precision]
compute = "f16"
mixed = "auto"
"#;
        let p = CompileProfile::from_toml_str(toml).unwrap();
        assert_eq!(p.fusion.policy, FusionPolicyKind::Direct);
        assert_eq!(p.fusion.target, FusionTargetKind::Metal);
        assert!(p.fusion.assert_clean);
        assert!(p.passes.dce);
        assert!(!p.passes.constant_folding);
        assert_eq!(p.precision.compute, PrecisionKind::F16);
        assert_eq!(p.precision.mixed, MixedPrecisionKind::Auto);
    }

    /// With container-level `#[serde(default)]`, an empty document yields
    /// `CompileProfile::default()`.
    #[test]
    fn empty_toml_yields_default_profile() {
        let p = CompileProfile::from_toml_str("").unwrap();
        assert_eq!(p, CompileProfile::default());
    }

    /// Unknown families fall back to the llama32 presets. Encoder mode always
    /// yields the encoder preset.
    #[test]
    fn default_for_fallbacks() {
        assert_eq!(
            CompileProfile::default_for("unknown", ProfileMode::Prefill),
            CompileProfile::llama32_prefill()
        );
        assert_eq!(
            CompileProfile::default_for("unknown", ProfileMode::Decode),
            CompileProfile::llama32_decode()
        );
        assert_eq!(
            CompileProfile::default_for("llama32", ProfileMode::Encoder),
            CompileProfile::encoder()
        );
        assert_eq!(
            CompileProfile::llama32_decode().fusion.policy,
            FusionPolicyKind::Fusable
        );
    }
}