1use rlx_ir::hir::FusionPolicy;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11#[serde(default)]
12pub struct CompileProfile {
13 pub fusion: FusionProfile,
14 pub passes: PassProfile,
15 pub precision: PrecisionProfile,
16 #[serde(default)]
17 pub backend: BackendOverrides,
18}
19
20impl Default for CompileProfile {
21 fn default() -> Self {
22 Self::llama32_prefill()
23 }
24}
25
26impl CompileProfile {
27 pub fn llama32_prefill() -> Self {
29 Self {
30 fusion: FusionProfile {
31 policy: FusionPolicyKind::Direct,
32 target: FusionTargetKind::Auto,
33 assert_clean: false,
34 skip: false,
35 },
36 passes: PassProfile::default(),
37 precision: PrecisionProfile::default(),
38 backend: BackendOverrides::default(),
39 }
40 }
41
42 pub fn llama32_decode() -> Self {
44 Self {
45 fusion: FusionProfile {
46 policy: FusionPolicyKind::Fusable,
47 ..FusionProfile::default()
48 },
49 ..Self::llama32_prefill()
50 }
51 }
52
53 pub fn qwen35_prefill() -> Self {
55 Self::llama32_prefill()
56 }
57
58 pub fn qwen35_decode() -> Self {
60 Self::llama32_decode()
61 }
62
63 pub fn qwen3_prefill() -> Self {
65 Self::llama32_prefill()
66 }
67
68 pub fn qwen3_decode() -> Self {
70 Self::llama32_decode()
71 }
72
73 pub fn flux2() -> Self {
75 Self::encoder()
76 }
77
78 pub fn sam_encoder() -> Self {
80 Self::encoder()
81 }
82
83 pub fn sam3() -> Self {
85 Self::sam_encoder()
86 }
87
88 pub fn sam2() -> Self {
90 Self::sam_encoder()
91 }
92
93 pub fn sam2_memory_attention() -> Self {
95 Self {
96 fusion: FusionProfile {
97 skip: true,
98 ..FusionProfile::default()
99 },
100 ..Self::encoder()
101 }
102 }
103
104 pub fn llada2_diffusion() -> Self {
109 Self {
110 fusion: FusionProfile {
111 skip: true,
112 ..FusionProfile::default()
113 },
114 ..Self::encoder()
115 }
116 }
117
118 pub fn encoder() -> Self {
120 Self {
121 fusion: FusionProfile {
122 policy: FusionPolicyKind::Direct,
123 ..FusionProfile::default()
124 },
125 passes: PassProfile {
126 dce: true,
127 constant_folding: true,
128 verbose: false,
129 },
130 precision: PrecisionProfile::default(),
131 backend: BackendOverrides::default(),
132 }
133 }
134
135 pub fn fusion_policy(&self) -> FusionPolicy {
136 self.fusion.policy.into()
137 }
138
139 pub fn from_toml_str(s: &str) -> anyhow::Result<Self> {
140 Ok(toml::from_str(s)?)
141 }
142
143 pub fn from_toml_path(path: &std::path::Path) -> anyhow::Result<Self> {
144 let data = std::fs::read_to_string(path)?;
145 Self::from_toml_str(&data)
146 }
147}
148
149#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
150#[serde(default)]
151pub struct FusionProfile {
152 pub policy: FusionPolicyKind,
153 pub target: FusionTargetKind,
154 pub assert_clean: bool,
155 pub skip: bool,
156}
157
158impl Default for FusionProfile {
159 fn default() -> Self {
160 Self {
161 policy: FusionPolicyKind::Direct,
162 target: FusionTargetKind::Auto,
163 assert_clean: false,
164 skip: false,
165 }
166 }
167}
168
169#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
170#[serde(rename_all = "lowercase")]
171pub enum FusionPolicyKind {
172 #[default]
173 Direct,
174 Fusable,
175}
176
177impl From<FusionPolicyKind> for FusionPolicy {
178 fn from(k: FusionPolicyKind) -> Self {
179 match k {
180 FusionPolicyKind::Direct => FusionPolicy::Direct,
181 FusionPolicyKind::Fusable => FusionPolicy::Fusable,
182 }
183 }
184}
185
186#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
187#[serde(rename_all = "lowercase")]
188pub enum FusionTargetKind {
189 #[default]
190 Auto,
191 Cpu,
192 Metal,
193 Mlx,
194 Cuda,
195 Rocm,
196 Wgpu,
197 Tpu,
198}
199
200#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
201#[serde(default)]
202pub struct PassProfile {
203 pub dce: bool,
204 pub constant_folding: bool,
205 pub verbose: bool,
206}
207
208impl Default for PassProfile {
209 fn default() -> Self {
210 Self {
211 dce: true,
212 constant_folding: true,
213 verbose: false,
214 }
215 }
216}
217
218#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
219#[serde(default)]
220pub struct PrecisionProfile {
221 pub compute: PrecisionKind,
222 pub mixed: MixedPrecisionKind,
223}
224
225impl Default for PrecisionProfile {
226 fn default() -> Self {
227 Self {
228 compute: PrecisionKind::F32,
229 mixed: MixedPrecisionKind::None,
230 }
231 }
232}
233
234#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
235#[serde(rename_all = "lowercase")]
236pub enum PrecisionKind {
237 #[default]
238 F32,
239 F16,
240 Bf16,
241}
242
243#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
244#[serde(rename_all = "snake_case")]
245pub enum MixedPrecisionKind {
246 #[default]
247 None,
248 Auto,
249}
250
251#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
253pub struct BackendOverrides {
254 #[serde(default)]
255 pub metal: MetalBackendProfile,
256 #[serde(default)]
257 pub cpu: CpuBackendProfile,
258}
259
260#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
261pub struct MetalBackendProfile {
262 pub skip_fusion: bool,
263 pub unfuse_regions: bool,
264}
265
266#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
267pub struct CpuBackendProfile {
268 pub unfuse_regions: bool,
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully specified `[fusion]`/`[passes]`/`[precision]` document parses into the
    /// expected typed values.
    #[test]
    fn parse_profile_toml() {
        const PROFILE_TOML: &str = r#"
[fusion]
policy = "direct"
target = "metal"
assert_clean = true

[passes]
dce = true
constant_folding = false

[precision]
compute = "f16"
mixed = "auto"
"#;
        let CompileProfile {
            fusion,
            passes,
            precision,
            ..
        } = CompileProfile::from_toml_str(PROFILE_TOML).unwrap();

        assert_eq!(fusion.policy, FusionPolicyKind::Direct);
        assert_eq!(fusion.target, FusionTargetKind::Metal);
        assert!(fusion.assert_clean);

        assert!(!passes.constant_folding);

        assert_eq!(precision.compute, PrecisionKind::F16);
        assert_eq!(precision.mixed, MixedPrecisionKind::Auto);
    }
}