1use std::path::Path;
19
20use rlx_flow::CompileProfile;
21use rlx_flow::{
22 BuiltModel, FusionTargetKind, MixedPrecisionKind, ModelExecutionConfig, PrecisionKind,
23};
24use rlx_ir::logical_kernel::KernelDispatchConfig;
25use rlx_opt::{FusionOptions, FusionTarget, PrecisionPolicy};
26use rlx_runtime::Device;
27use rlx_runtime::{CompileOptions, ModelCompilePipeline, Precision, Session, stages};
28
29use crate::weight_loader::WeightLoader;
30
31pub struct WeightLoaderSource<'a>(pub &'a mut dyn WeightLoader);
33
34impl rlx_flow::WeightSource for WeightLoaderSource<'_> {
35 fn take(&mut self, key: &str, transpose: bool) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
36 if transpose {
37 self.0.take_transposed(key)
38 } else {
39 self.0.take(key)
40 }
41 }
42}
43
44pub fn load_compile_profile(path: &Path, default: CompileProfile) -> CompileProfile {
46 CompileProfile::from_toml_path(path).unwrap_or(default)
47}
48
49pub fn profile_near_weights(
51 weights: &Path,
52 profile_file: &str,
53 default: CompileProfile,
54) -> CompileProfile {
55 let dir = weights.parent().unwrap_or_else(|| Path::new("."));
56 load_compile_profile(&dir.join(profile_file), default)
57}
58
59pub fn apply_compile_profile(profile: &CompileProfile, opts: &mut CompileOptions) {
61 opts.dce = profile.passes.dce;
62 opts.constant_folding = profile.passes.constant_folding;
63 opts.verbose = profile.passes.verbose;
64 opts.assert_fusion_clean = profile.fusion.assert_clean;
65 opts.fusion_opts = FusionOptions {
66 skip_fusion: profile.fusion.skip,
67 unfuse_elementwise_regions: profile.backend.metal.unfuse_regions
68 || profile.backend.cpu.unfuse_regions,
69 ..FusionOptions::default()
70 };
71 if let Some(target) = fusion_target_from_profile(profile.fusion.target) {
72 opts.fusion_target = Some(target);
73 }
74 opts.precision = match profile.precision.compute {
75 PrecisionKind::F32 => Precision::F32,
76 PrecisionKind::F16 => Precision::F16,
77 PrecisionKind::Bf16 => Precision::F16, };
79 opts.policy = match profile.precision.mixed {
80 MixedPrecisionKind::None => None,
81 MixedPrecisionKind::Auto => Some(PrecisionPolicy::AutoMixed),
82 };
83}
84
85pub fn compile_options_dynamic(binding: rlx_ir::DimBinding) -> CompileOptions {
87 CompileOptions::new().dim_binding(binding)
88}
89
90pub fn compile_options_from_profile(
92 profile: &CompileProfile,
93 device: Device,
94 kernel_dispatch: KernelDispatchConfig,
95) -> CompileOptions {
96 let mut opts = CompileOptions::new();
97 apply_compile_profile(profile, &mut opts);
98 opts.kernel_dispatch = kernel_dispatch;
99 if opts.fusion_target.is_none() {
100 opts.fusion_target = Some(stages::fusion_target_for(device));
101 }
102 opts
103}
104
105pub fn compile_options_for_profile(profile: &CompileProfile, device: Device) -> CompileOptions {
107 compile_options_from_profile(profile, device, KernelDispatchConfig::default())
108}
109
110pub fn compile_options_for_packed_gguf_prefill_with_profile(
116 profile: &CompileProfile,
117 device: Device,
118) -> CompileOptions {
119 let mut profile = profile.clone();
120 if matches!(
121 device,
122 Device::Gpu | Device::Cuda | Device::Rocm | Device::Vulkan
123 ) {
124 profile.fusion.skip = true;
125 }
126 compile_options_from_profile(&profile, device, KernelDispatchConfig::default())
127}
128
129pub fn compile_options_for_packed_gguf_prefill(device: Device) -> CompileOptions {
131 compile_options_for_packed_gguf_prefill_with_profile(
132 &CompileProfile::llama32_prefill(),
133 device,
134 )
135}
136
137pub fn packed_gguf_compile_guard<R, F>(device: Device, f: F) -> R
145where
146 F: FnOnce() -> R,
147{
148 with_packed_gguf_backend_env(device, f)
149}
150
151fn with_packed_gguf_backend_env<R, F>(device: Device, f: F) -> R
152where
153 F: FnOnce() -> R,
154{
155 let mlx_prev = if device == Device::Mlx {
156 let prev = rlx_ir::env::var("RLX_MLX_MODE");
157 rlx_ir::env::set("RLX_MLX_MODE", "eager");
158 prev
159 } else {
160 None
161 };
162 let metal = device == Device::Metal;
163 if metal {
164 rlx_ir::env::set("RLX_DISABLE_MPSGRAPH", "1");
165 }
166 let out = f();
167 if metal {
168 rlx_ir::env::unset("RLX_DISABLE_MPSGRAPH");
169 }
170 if device == Device::Mlx {
171 match mlx_prev {
172 Some(ref v) => rlx_ir::env::set("RLX_MLX_MODE", v),
173 None => rlx_ir::env::unset("RLX_MLX_MODE"),
174 }
175 }
176 out
177}
178
179pub fn packed_gguf_execution_device(device: Device) -> Device {
182 match device {
183 Device::Cpu | Device::Metal => device,
184 Device::Mlx | Device::Gpu | Device::Cuda | Device::Rocm | Device::Vulkan => Device::Cpu,
185 _ => device,
186 }
187}
188
189pub fn compile_options_sam_encoder(device: Device) -> CompileOptions {
191 compile_options_for_profile(&CompileProfile::sam_encoder(), device)
192}
193
194pub fn compile_options_sam3(device: Device) -> CompileOptions {
196 compile_options_for_profile(&CompileProfile::sam3(), device)
197}
198
199pub fn compile_options_sam2_memory_attention(device: Device) -> CompileOptions {
201 compile_options_for_profile(&CompileProfile::sam2_memory_attention(), device)
202}
203
204pub fn compile_graph_with_profile(
206 device: Device,
207 graph: rlx_ir::Graph,
208 profile: &CompileProfile,
209) -> anyhow::Result<rlx_runtime::CompiledGraph> {
210 use rlx_runtime::Session;
211 let opts = compile_options_for_profile(profile, device);
212 Ok(Session::new(device).compile_with(graph, &opts))
213}
214
215pub fn compile_graph_sam(
217 device: Device,
218 graph: rlx_ir::Graph,
219) -> anyhow::Result<rlx_runtime::CompiledGraph> {
220 compile_graph_with_profile(device, graph, &CompileProfile::sam_encoder())
221}
222
223pub fn compile_graph_encoder(
225 device: Device,
226 graph: rlx_ir::Graph,
227) -> anyhow::Result<rlx_runtime::CompiledGraph> {
228 compile_graph_with_profile(device, graph, &CompileProfile::encoder())
229}
230
231pub fn compile_graph_qwen3_prefill(
233 device: Device,
234 graph: rlx_ir::Graph,
235) -> anyhow::Result<rlx_runtime::CompiledGraph> {
236 compile_graph_with_profile(device, graph, &CompileProfile::qwen3_prefill())
237}
238
239pub fn compile_graph_qwen3_decode(
241 device: Device,
242 graph: rlx_ir::Graph,
243) -> anyhow::Result<rlx_runtime::CompiledGraph> {
244 compile_graph_with_profile(device, graph, &CompileProfile::qwen3_decode())
245}
246
247pub fn compile_graph_qwen35_prefill(
249 device: Device,
250 graph: rlx_ir::Graph,
251) -> anyhow::Result<rlx_runtime::CompiledGraph> {
252 compile_graph_with_profile(device, graph, &CompileProfile::qwen35_prefill())
253}
254
255pub fn compile_graph_qwen35_decode(
257 device: Device,
258 graph: rlx_ir::Graph,
259) -> anyhow::Result<rlx_runtime::CompiledGraph> {
260 compile_graph_with_profile(device, graph, &CompileProfile::qwen35_decode())
261}
262
263pub fn compile_graph_gemma_prefill(
265 device: Device,
266 graph: rlx_ir::Graph,
267) -> anyhow::Result<rlx_runtime::CompiledGraph> {
268 compile_graph_with_profile(device, graph, &CompileProfile::gemma_prefill())
269}
270
271pub fn compile_graph_gemma_decode(
273 device: Device,
274 graph: rlx_ir::Graph,
275) -> anyhow::Result<rlx_runtime::CompiledGraph> {
276 compile_graph_with_profile(device, graph, &CompileProfile::gemma_decode())
277}
278
279pub fn compile_graph_llama32_prefill(
281 device: Device,
282 graph: rlx_ir::Graph,
283) -> anyhow::Result<rlx_runtime::CompiledGraph> {
284 compile_graph_with_profile(device, graph, &CompileProfile::llama32_prefill())
285}
286
287pub fn compile_graph_llama32_decode(
289 device: Device,
290 graph: rlx_ir::Graph,
291) -> anyhow::Result<rlx_runtime::CompiledGraph> {
292 compile_graph_with_profile(device, graph, &CompileProfile::llama32_decode())
293}
294
295pub fn compile_graph_legacy(
297 device: Device,
298 graph: rlx_ir::Graph,
299) -> anyhow::Result<rlx_runtime::CompiledGraph> {
300 use rlx_runtime::{CompileOptions, Session};
301 Ok(Session::new(device).compile_with(graph, &CompileOptions::new()))
302}
303
304pub fn compile_hir_sam(
306 device: Device,
307 hir: rlx_ir::hir::HirModule,
308) -> anyhow::Result<rlx_runtime::CompiledGraph> {
309 compile_hir_with_profile(device, hir, &CompileProfile::sam_encoder())
310}
311
312pub fn compile_hir_sam3(
314 device: Device,
315 hir: rlx_ir::hir::HirModule,
316) -> anyhow::Result<rlx_runtime::CompiledGraph> {
317 compile_hir_with_profile(device, hir, &CompileProfile::sam3())
318}
319
320pub fn compile_hir_with_profile(
322 device: Device,
323 hir: rlx_ir::hir::HirModule,
324 profile: &CompileProfile,
325) -> anyhow::Result<rlx_runtime::CompiledGraph> {
326 use rlx_runtime::Session;
327 let opts = compile_options_for_profile(profile, device);
328 Ok(Session::new(device).compile_hir_with(hir, &opts)?)
329}
330
331pub fn compile_options_for(config: &ModelExecutionConfig) -> CompileOptions {
333 compile_options_from_profile(
334 &config.compile_profile(),
335 Device::Cpu,
336 config.component().kernel_dispatch,
337 )
338 .dim_binding(config.dim_binding())
339}
340
341pub fn compile_options_for_device(config: &ModelExecutionConfig, device: Device) -> CompileOptions {
343 compile_options_from_profile(
344 &config.compile_profile(),
345 device,
346 config.component().kernel_dispatch,
347 )
348 .dim_binding(config.dim_binding())
349}
350
351pub fn compile_built_with_config(
353 pipeline: &mut ModelCompilePipeline,
354 built: BuiltModel,
355 config: &ModelExecutionConfig,
356 options: &CompileOptions,
357) -> anyhow::Result<rlx_runtime::CompiledGraph> {
358 let key = config.cache_key();
359 let binding = config.dim_binding();
360 let device = pipeline.device();
361 let (hir, params) = built.into_parts()?;
362 if !pipeline.contains(key) {
365 pipeline.get_or_compile(key, &binding, || hir.clone(), options)?;
366 }
367 let mut compiled = if device == Device::Cpu {
368 pipeline
369 .get_or_compile(key, &binding, || hir.clone(), options)?
370 .clone()
371 } else {
372 Session::new(device).compile_hir_with(hir, options)?
373 };
374 for (name, data) in params {
375 compiled.set_param(&name, &data);
376 }
377 Ok(compiled)
378}
379
380fn fusion_target_from_profile(kind: FusionTargetKind) -> Option<FusionTarget> {
381 match kind {
382 FusionTargetKind::Auto => None,
383 FusionTargetKind::Cpu => Some(FusionTarget::Cpu),
384 FusionTargetKind::Metal => Some(FusionTarget::Metal),
385 FusionTargetKind::Mlx => Some(FusionTarget::Mlx),
386 FusionTargetKind::Cuda => Some(FusionTarget::Cuda),
387 FusionTargetKind::Rocm => Some(FusionTarget::Rocm),
388 FusionTargetKind::Wgpu => Some(FusionTarget::Wgpu),
389 FusionTargetKind::Tpu => Some(FusionTarget::Tpu),
390 }
391}