1use std::path::Path;
19
20use rlx_flow::CompileProfile;
21use rlx_flow::{
22 BuiltModel, FusionTargetKind, MixedPrecisionKind, ModelExecutionConfig, PrecisionKind,
23};
24use rlx_ir::logical_kernel::KernelDispatchConfig;
25use rlx_opt::{FusionOptions, FusionTarget, PrecisionPolicy};
26use rlx_runtime::Device;
27use rlx_runtime::{CompileOptions, ModelCompilePipeline, Precision, Session, stages};
28
29use crate::weight_loader::WeightLoader;
30
31pub struct WeightLoaderSource<'a>(pub &'a mut dyn WeightLoader);
33
34impl rlx_flow::WeightSource for WeightLoaderSource<'_> {
35 fn take(&mut self, key: &str, transpose: bool) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
36 if transpose {
37 self.0.take_transposed(key)
38 } else {
39 self.0.take(key)
40 }
41 }
42}
43
44pub fn load_compile_profile(path: &Path, default: CompileProfile) -> CompileProfile {
46 CompileProfile::from_toml_path(path).unwrap_or(default)
47}
48
49pub fn profile_near_weights(
51 weights: &Path,
52 profile_file: &str,
53 default: CompileProfile,
54) -> CompileProfile {
55 let dir = weights.parent().unwrap_or_else(|| Path::new("."));
56 load_compile_profile(&dir.join(profile_file), default)
57}
58
59pub fn apply_compile_profile(profile: &CompileProfile, opts: &mut CompileOptions) {
61 opts.dce = profile.passes.dce;
62 opts.constant_folding = profile.passes.constant_folding;
63 opts.verbose = profile.passes.verbose;
64 opts.assert_fusion_clean = profile.fusion.assert_clean;
65 opts.fusion_opts = FusionOptions {
66 skip_fusion: profile.fusion.skip,
67 unfuse_elementwise_regions: profile.backend.metal.unfuse_regions
68 || profile.backend.cpu.unfuse_regions,
69 ..FusionOptions::default()
70 };
71 if let Some(target) = fusion_target_from_profile(profile.fusion.target) {
72 opts.fusion_target = Some(target);
73 }
74 opts.precision = match profile.precision.compute {
75 PrecisionKind::F32 => Precision::F32,
76 PrecisionKind::F16 => Precision::F16,
77 PrecisionKind::Bf16 => Precision::F16, };
79 opts.policy = match profile.precision.mixed {
80 MixedPrecisionKind::None => None,
81 MixedPrecisionKind::Auto => Some(PrecisionPolicy::AutoMixed),
82 };
83}
84
85pub fn compile_options_dynamic(binding: rlx_ir::DimBinding) -> CompileOptions {
87 CompileOptions::new().dim_binding(binding)
88}
89
90pub fn compile_options_from_profile(
92 profile: &CompileProfile,
93 device: Device,
94 kernel_dispatch: KernelDispatchConfig,
95) -> CompileOptions {
96 let mut opts = CompileOptions::new();
97 apply_compile_profile(profile, &mut opts);
98 opts.kernel_dispatch = kernel_dispatch;
99 if opts.fusion_target.is_none() {
100 opts.fusion_target = Some(stages::fusion_target_for(device));
101 }
102 opts
103}
104
105pub fn compile_options_for_profile(profile: &CompileProfile, device: Device) -> CompileOptions {
107 compile_options_from_profile(profile, device, KernelDispatchConfig::default())
108}
109
110pub fn compile_options_for_packed_gguf_prefill_with_profile(
118 profile: &CompileProfile,
119 device: Device,
120) -> CompileOptions {
121 let mut profile = profile.clone();
122 profile.fusion.skip = true;
123 compile_options_from_profile(&profile, device, KernelDispatchConfig::default())
124}
125
126pub fn compile_options_for_packed_gguf_prefill(device: Device) -> CompileOptions {
128 compile_options_for_packed_gguf_prefill_with_profile(&CompileProfile::llama32_prefill(), device)
129}
130
131pub fn packed_gguf_compile_guard<R, F>(device: Device, f: F) -> R
140where
141 F: FnOnce() -> R,
142{
143 with_packed_gguf_backend_env(device, f)
144}
145
146fn with_packed_gguf_backend_env<R, F>(device: Device, f: F) -> R
147where
148 F: FnOnce() -> R,
149{
150 let mlx_prev = if device == Device::Mlx {
151 let prev = rlx_ir::env::var("RLX_MLX_MODE");
152 rlx_ir::env::set("RLX_MLX_MODE", "lazy");
153 prev
154 } else {
155 None
156 };
157 let metal = device == Device::Metal;
158 if metal {
159 rlx_ir::env::set("RLX_DISABLE_MPSGRAPH", "1");
160 }
161 let out = f();
162 if metal {
163 rlx_ir::env::unset("RLX_DISABLE_MPSGRAPH");
164 }
165 if device == Device::Mlx {
166 match mlx_prev {
167 Some(ref v) => rlx_ir::env::set("RLX_MLX_MODE", v),
168 None => rlx_ir::env::unset("RLX_MLX_MODE"),
169 }
170 }
171 out
172}
173
174pub fn packed_gguf_execution_device(device: Device) -> Device {
180 match device {
181 Device::Cpu | Device::Metal | Device::Mlx => device,
182 Device::Gpu | Device::Cuda | Device::Rocm | Device::Vulkan => Device::Cpu,
183 _ => device,
184 }
185}
186
187pub fn compile_options_sam_encoder(device: Device) -> CompileOptions {
189 compile_options_for_profile(&CompileProfile::sam_encoder(), device)
190}
191
192pub fn compile_options_sam3(device: Device) -> CompileOptions {
194 compile_options_for_profile(&CompileProfile::sam3(), device)
195}
196
197pub fn compile_options_sam2_memory_attention(device: Device) -> CompileOptions {
199 compile_options_for_profile(&CompileProfile::sam2_memory_attention(), device)
200}
201
202pub fn compile_graph_with_profile(
204 device: Device,
205 graph: rlx_ir::Graph,
206 profile: &CompileProfile,
207) -> anyhow::Result<rlx_runtime::CompiledGraph> {
208 use rlx_runtime::Session;
209 let opts = compile_options_for_profile(profile, device);
210 Ok(Session::new(device).compile_with(graph, &opts))
211}
212
213pub fn compile_graph_sam(
215 device: Device,
216 graph: rlx_ir::Graph,
217) -> anyhow::Result<rlx_runtime::CompiledGraph> {
218 compile_graph_with_profile(device, graph, &CompileProfile::sam_encoder())
219}
220
221pub fn compile_graph_encoder(
223 device: Device,
224 graph: rlx_ir::Graph,
225) -> anyhow::Result<rlx_runtime::CompiledGraph> {
226 compile_graph_with_profile(device, graph, &CompileProfile::encoder())
227}
228
229pub fn compile_graph_qwen3_prefill(
231 device: Device,
232 graph: rlx_ir::Graph,
233) -> anyhow::Result<rlx_runtime::CompiledGraph> {
234 compile_graph_with_profile(device, graph, &CompileProfile::qwen3_prefill())
235}
236
237pub fn compile_graph_qwen3_decode(
239 device: Device,
240 graph: rlx_ir::Graph,
241) -> anyhow::Result<rlx_runtime::CompiledGraph> {
242 compile_graph_with_profile(device, graph, &CompileProfile::qwen3_decode())
243}
244
245pub fn compile_graph_qwen35_prefill(
247 device: Device,
248 graph: rlx_ir::Graph,
249) -> anyhow::Result<rlx_runtime::CompiledGraph> {
250 compile_graph_with_profile(device, graph, &CompileProfile::qwen35_prefill())
251}
252
253pub fn compile_graph_qwen35_decode(
255 device: Device,
256 graph: rlx_ir::Graph,
257) -> anyhow::Result<rlx_runtime::CompiledGraph> {
258 compile_graph_with_profile(device, graph, &CompileProfile::qwen35_decode())
259}
260
261pub fn compile_graph_gemma_prefill(
263 device: Device,
264 graph: rlx_ir::Graph,
265) -> anyhow::Result<rlx_runtime::CompiledGraph> {
266 compile_graph_with_profile(device, graph, &CompileProfile::gemma_prefill())
267}
268
269pub fn compile_graph_gemma_decode(
271 device: Device,
272 graph: rlx_ir::Graph,
273) -> anyhow::Result<rlx_runtime::CompiledGraph> {
274 compile_graph_with_profile(device, graph, &CompileProfile::gemma_decode())
275}
276
277pub fn compile_graph_llama32_prefill(
279 device: Device,
280 graph: rlx_ir::Graph,
281) -> anyhow::Result<rlx_runtime::CompiledGraph> {
282 compile_graph_with_profile(device, graph, &CompileProfile::llama32_prefill())
283}
284
285pub fn compile_graph_llama32_decode(
287 device: Device,
288 graph: rlx_ir::Graph,
289) -> anyhow::Result<rlx_runtime::CompiledGraph> {
290 compile_graph_with_profile(device, graph, &CompileProfile::llama32_decode())
291}
292
293pub fn compile_graph_legacy(
295 device: Device,
296 graph: rlx_ir::Graph,
297) -> anyhow::Result<rlx_runtime::CompiledGraph> {
298 use rlx_runtime::{CompileOptions, Session};
299 Ok(Session::new(device).compile_with(graph, &CompileOptions::new()))
300}
301
302pub fn compile_hir_sam(
304 device: Device,
305 hir: rlx_ir::hir::HirModule,
306) -> anyhow::Result<rlx_runtime::CompiledGraph> {
307 compile_hir_with_profile(device, hir, &CompileProfile::sam_encoder())
308}
309
310pub fn compile_hir_sam3(
312 device: Device,
313 hir: rlx_ir::hir::HirModule,
314) -> anyhow::Result<rlx_runtime::CompiledGraph> {
315 compile_hir_with_profile(device, hir, &CompileProfile::sam3())
316}
317
318pub fn compile_hir_with_profile(
320 device: Device,
321 hir: rlx_ir::hir::HirModule,
322 profile: &CompileProfile,
323) -> anyhow::Result<rlx_runtime::CompiledGraph> {
324 use rlx_runtime::Session;
325 let opts = compile_options_for_profile(profile, device);
326 Ok(Session::new(device).compile_hir_with(hir, &opts)?)
327}
328
329pub fn compile_options_for(config: &ModelExecutionConfig) -> CompileOptions {
331 compile_options_from_profile(
332 &config.compile_profile(),
333 Device::Cpu,
334 config.component().kernel_dispatch,
335 )
336 .dim_binding(config.dim_binding())
337}
338
339pub fn compile_options_for_device(config: &ModelExecutionConfig, device: Device) -> CompileOptions {
341 compile_options_from_profile(
342 &config.compile_profile(),
343 device,
344 config.component().kernel_dispatch,
345 )
346 .dim_binding(config.dim_binding())
347}
348
/// Compiles `built` on the pipeline's device and binds its parameters.
///
/// Steps:
/// 1. `built` is split into its HIR module and parameter list.
/// 2. If `pipeline` has no entry under `config.cache_key()`, one is compiled
///    with `options` and `config.dim_binding()`.
/// 3. On `Device::Cpu` the cached graph is cloned. On any other device a fresh
///    graph is compiled with `Session::compile_hir_with`.
/// 4. Every `(name, data)` parameter is written into the returned graph with
///    `set_param`.
///
/// # Errors
/// Propagates errors from `BuiltModel::into_parts`, the pipeline's
/// `get_or_compile`, and `Session::compile_hir_with`.
pub fn compile_built_with_config(
    pipeline: &mut ModelCompilePipeline,
    built: BuiltModel,
    config: &ModelExecutionConfig,
    options: &CompileOptions,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let key = config.cache_key();
    let binding = config.dim_binding();
    let device = pipeline.device();
    let (hir, params) = built.into_parts()?;
    if !pipeline.contains(key) {
        // NOTE(review): on non-CPU devices this fills the cache, but the branch
        // below compiles the same HIR again through `Session` and never returns
        // the cached entry. The reason for warming the cache here is not visible
        // in this file — confirm it is intended and not a wasted compile.
        pipeline.get_or_compile(key, &binding, || hir.clone(), options)?;
    }
    let mut compiled = if device == Device::Cpu {
        // The cached graph is cloned, presumably so `set_param` below writes into
        // a private copy rather than the cache entry — confirm.
        pipeline
            .get_or_compile(key, &binding, || hir.clone(), options)?
            .clone()
    } else {
        Session::new(device).compile_hir_with(hir, options)?
    };
    for (name, data) in params {
        compiled.set_param(&name, &data);
    }
    Ok(compiled)
}
377
378fn fusion_target_from_profile(kind: FusionTargetKind) -> Option<FusionTarget> {
379 match kind {
380 FusionTargetKind::Auto => None,
381 FusionTargetKind::Cpu => Some(FusionTarget::Cpu),
382 FusionTargetKind::Metal => Some(FusionTarget::Metal),
383 FusionTargetKind::Mlx => Some(FusionTarget::Mlx),
384 FusionTargetKind::Cuda => Some(FusionTarget::Cuda),
385 FusionTargetKind::Rocm => Some(FusionTarget::Rocm),
386 FusionTargetKind::Wgpu => Some(FusionTarget::Wgpu),
387 FusionTargetKind::Tpu => Some(FusionTarget::Tpu),
388 }
389}