1use rlx_ir::hir::FusionPolicy;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11#[serde(default)]
12pub struct CompileProfile {
13 pub fusion: FusionProfile,
14 pub passes: PassProfile,
15 pub precision: PrecisionProfile,
16 #[serde(default)]
17 pub backend: BackendOverrides,
18}
19
20impl Default for CompileProfile {
21 fn default() -> Self {
22 Self::llama32_prefill()
23 }
24}
25
26impl CompileProfile {
27 pub fn llama32_prefill() -> Self {
29 Self {
30 fusion: FusionProfile {
31 policy: FusionPolicyKind::Direct,
32 target: FusionTargetKind::Auto,
33 assert_clean: false,
34 skip: false,
35 },
36 passes: PassProfile::default(),
37 precision: PrecisionProfile::default(),
38 backend: BackendOverrides::default(),
39 }
40 }
41
42 pub fn llama32_decode() -> Self {
44 Self {
45 fusion: FusionProfile {
46 policy: FusionPolicyKind::Fusable,
47 ..FusionProfile::default()
48 },
49 ..Self::llama32_prefill()
50 }
51 }
52
53 pub fn qwen35_prefill() -> Self {
55 Self::llama32_prefill()
56 }
57
58 pub fn qwen35_decode() -> Self {
60 Self::llama32_decode()
61 }
62
63 pub fn qwen3_prefill() -> Self {
65 Self::llama32_prefill()
66 }
67
68 pub fn qwen3_decode() -> Self {
70 Self::llama32_decode()
71 }
72
73 pub fn gemma_prefill() -> Self {
75 Self::llama32_prefill()
76 }
77
78 pub fn gemma_decode() -> Self {
80 Self::llama32_decode()
81 }
82
83 pub fn flux2() -> Self {
85 Self::encoder()
86 }
87
88 pub fn sam_encoder() -> Self {
90 Self::encoder()
91 }
92
93 pub fn sam3() -> Self {
95 Self::sam_encoder()
96 }
97
98 pub fn sam2() -> Self {
100 Self::sam_encoder()
101 }
102
103 pub fn sam2_memory_attention() -> Self {
105 Self {
106 fusion: FusionProfile {
107 skip: true,
108 ..FusionProfile::default()
109 },
110 ..Self::encoder()
111 }
112 }
113
114 pub fn llada2_diffusion() -> Self {
119 Self {
120 fusion: FusionProfile {
121 skip: true,
122 ..FusionProfile::default()
123 },
124 ..Self::encoder()
125 }
126 }
127
128 pub fn encoder() -> Self {
130 Self {
131 fusion: FusionProfile {
132 policy: FusionPolicyKind::Direct,
133 ..FusionProfile::default()
134 },
135 passes: PassProfile {
136 dce: true,
137 constant_folding: true,
138 verbose: false,
139 },
140 precision: PrecisionProfile::default(),
141 backend: BackendOverrides::default(),
142 }
143 }
144
145 pub fn fusion_policy(&self) -> FusionPolicy {
146 self.fusion.policy.into()
147 }
148
149 pub fn from_toml_str(s: &str) -> anyhow::Result<Self> {
150 Ok(toml::from_str(s)?)
151 }
152
153 pub fn from_toml_path(path: &std::path::Path) -> anyhow::Result<Self> {
154 let data = std::fs::read_to_string(path)?;
155 Self::from_toml_str(&data)
156 }
157}
158
159#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
160#[serde(default)]
161pub struct FusionProfile {
162 pub policy: FusionPolicyKind,
163 pub target: FusionTargetKind,
164 pub assert_clean: bool,
165 pub skip: bool,
166}
167
168impl Default for FusionProfile {
169 fn default() -> Self {
170 Self {
171 policy: FusionPolicyKind::Direct,
172 target: FusionTargetKind::Auto,
173 assert_clean: false,
174 skip: false,
175 }
176 }
177}
178
179#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
180#[serde(rename_all = "lowercase")]
181pub enum FusionPolicyKind {
182 #[default]
183 Direct,
184 Fusable,
185}
186
187impl From<FusionPolicyKind> for FusionPolicy {
188 fn from(k: FusionPolicyKind) -> Self {
189 match k {
190 FusionPolicyKind::Direct => FusionPolicy::Direct,
191 FusionPolicyKind::Fusable => FusionPolicy::Fusable,
192 }
193 }
194}
195
196#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
197#[serde(rename_all = "lowercase")]
198pub enum FusionTargetKind {
199 #[default]
200 Auto,
201 Cpu,
202 Metal,
203 Mlx,
204 Cuda,
205 Rocm,
206 Wgpu,
207 Tpu,
208}
209
210#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
211#[serde(default)]
212pub struct PassProfile {
213 pub dce: bool,
214 pub constant_folding: bool,
215 pub verbose: bool,
216}
217
218impl Default for PassProfile {
219 fn default() -> Self {
220 Self {
221 dce: true,
222 constant_folding: true,
223 verbose: false,
224 }
225 }
226}
227
228#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
229#[serde(default)]
230pub struct PrecisionProfile {
231 pub compute: PrecisionKind,
232 pub mixed: MixedPrecisionKind,
233}
234
235impl Default for PrecisionProfile {
236 fn default() -> Self {
237 Self {
238 compute: PrecisionKind::F32,
239 mixed: MixedPrecisionKind::None,
240 }
241 }
242}
243
244#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
245#[serde(rename_all = "lowercase")]
246pub enum PrecisionKind {
247 #[default]
248 F32,
249 F16,
250 Bf16,
251}
252
253#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
254#[serde(rename_all = "snake_case")]
255pub enum MixedPrecisionKind {
256 #[default]
257 None,
258 Auto,
259}
260
261#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
263pub struct BackendOverrides {
264 #[serde(default)]
265 pub metal: MetalBackendProfile,
266 #[serde(default)]
267 pub cpu: CpuBackendProfile,
268}
269
270#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
271pub struct MetalBackendProfile {
272 pub skip_fusion: bool,
273 pub unfuse_regions: bool,
274}
275
276#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
277pub struct CpuBackendProfile {
278 pub unfuse_regions: bool,
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// A TOML document with `[fusion]`, `[passes]` and `[precision]` tables
    /// parses, and the explicitly set values are reflected in the profile.
    #[test]
    fn parse_profile_toml() {
        let src = r#"
[fusion]
policy = "direct"
target = "metal"
assert_clean = true

[passes]
dce = true
constant_folding = false

[precision]
compute = "f16"
mixed = "auto"
"#;
        let CompileProfile {
            fusion,
            passes,
            precision,
            ..
        } = CompileProfile::from_toml_str(src).unwrap();

        // [fusion]
        assert_eq!(fusion.policy, FusionPolicyKind::Direct);
        assert_eq!(fusion.target, FusionTargetKind::Metal);
        assert!(fusion.assert_clean);

        // [passes]
        assert!(!passes.constant_folding);

        // [precision]
        assert_eq!(precision.compute, PrecisionKind::F16);
        assert_eq!(precision.mixed, MixedPrecisionKind::Auto);
    }
}