Skip to main content

rlx_models_core/
flow_bridge.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Bridge between `rlx-models` loaders/runtime and `rlx-flow`.
17
18use std::path::Path;
19
20use rlx_flow::CompileProfile;
21use rlx_flow::{
22    BuiltModel, FusionTargetKind, MixedPrecisionKind, ModelExecutionConfig, PrecisionKind,
23};
24use rlx_ir::logical_kernel::KernelDispatchConfig;
25use rlx_opt::{FusionOptions, FusionTarget, PrecisionPolicy};
26use rlx_runtime::Device;
27use rlx_runtime::{CompileOptions, ModelCompilePipeline, Precision, Session, stages};
28
29use crate::weight_loader::WeightLoader;
30
31/// Adapt [`WeightLoader`] to [`rlx_flow::WeightSource`].
32pub struct WeightLoaderSource<'a>(pub &'a mut dyn WeightLoader);
33
34impl rlx_flow::WeightSource for WeightLoaderSource<'_> {
35    fn take(&mut self, key: &str, transpose: bool) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
36        if transpose {
37            self.0.take_transposed(key)
38        } else {
39            self.0.take(key)
40        }
41    }
42}
43
44/// Load a tier-1 profile from disk; fall back to `default` when missing or invalid.
45pub fn load_compile_profile(path: &Path, default: CompileProfile) -> CompileProfile {
46    CompileProfile::from_toml_path(path).unwrap_or(default)
47}
48
49/// Load `profile_file` next to `weights` (parent directory); fall back to `default`.
50pub fn profile_near_weights(
51    weights: &Path,
52    profile_file: &str,
53    default: CompileProfile,
54) -> CompileProfile {
55    let dir = weights.parent().unwrap_or_else(|| Path::new("."));
56    load_compile_profile(&dir.join(profile_file), default)
57}
58
59/// Apply tier-1 profile options to runtime compile options.
60pub fn apply_compile_profile(profile: &CompileProfile, opts: &mut CompileOptions) {
61    opts.dce = profile.passes.dce;
62    opts.constant_folding = profile.passes.constant_folding;
63    opts.verbose = profile.passes.verbose;
64    opts.assert_fusion_clean = profile.fusion.assert_clean;
65    opts.fusion_opts = FusionOptions {
66        skip_fusion: profile.fusion.skip,
67        unfuse_elementwise_regions: profile.backend.metal.unfuse_regions
68            || profile.backend.cpu.unfuse_regions,
69        ..FusionOptions::default()
70    };
71    if let Some(target) = fusion_target_from_profile(profile.fusion.target) {
72        opts.fusion_target = Some(target);
73    }
74    opts.precision = match profile.precision.compute {
75        PrecisionKind::F32 => Precision::F32,
76        PrecisionKind::F16 => Precision::F16,
77        PrecisionKind::Bf16 => Precision::F16, // closest supported runtime precision today
78    };
79    opts.policy = match profile.precision.mixed {
80        MixedPrecisionKind::None => None,
81        MixedPrecisionKind::Auto => Some(PrecisionPolicy::AutoMixed),
82    };
83}
84
85/// Dynamic HIR template/specialize — default passes only (matches legacy [`DynamicDimCompileCache`]).
86pub fn compile_options_dynamic(binding: rlx_ir::DimBinding) -> CompileOptions {
87    CompileOptions::new().dim_binding(binding)
88}
89
90/// Build [`CompileOptions`] from a tier-1 profile + device fusion target.
91pub fn compile_options_from_profile(
92    profile: &CompileProfile,
93    device: Device,
94    kernel_dispatch: KernelDispatchConfig,
95) -> CompileOptions {
96    let mut opts = CompileOptions::new();
97    apply_compile_profile(profile, &mut opts);
98    opts.kernel_dispatch = kernel_dispatch;
99    if opts.fusion_target.is_none() {
100        opts.fusion_target = Some(stages::fusion_target_for(device));
101    }
102    opts
103}
104
105/// Tier-1 profile + device (no execution variant binding).
106pub fn compile_options_for_profile(profile: &CompileProfile, device: Device) -> CompileOptions {
107    compile_options_from_profile(profile, device, KernelDispatchConfig::default())
108}
109
110/// Compile options for packed GGUF K-quant prefill (`Op::DequantMatMul`).
111///
112/// On **wgpu / CUDA / ROCm / Vulkan** (crates.io `rlx-*` 0.2.1), disable fusion so
113/// graphs do not emit `Op::FusedResidualRmsNorm` — those backends only lower a MatMul +
114/// elementwise subset today (same approach as [`CompileProfile::llada2_diffusion`]).
115pub fn compile_options_for_packed_gguf_prefill_with_profile(
116    profile: &CompileProfile,
117    device: Device,
118) -> CompileOptions {
119    let mut profile = profile.clone();
120    if matches!(
121        device,
122        Device::Gpu | Device::Cuda | Device::Rocm | Device::Vulkan
123    ) {
124        profile.fusion.skip = true;
125    }
126    compile_options_from_profile(&profile, device, KernelDispatchConfig::default())
127}
128
129/// Llama-shaped LM packed GGUF prefill (MiniCPM5, Llama 3.2, …).
130pub fn compile_options_for_packed_gguf_prefill(device: Device) -> CompileOptions {
131    compile_options_for_packed_gguf_prefill_with_profile(
132        &CompileProfile::llama32_prefill(),
133        device,
134    )
135}
136
137/// Backend env overrides while compiling packed GGUF graphs.
138///
139/// - **Metal** — `RLX_DISABLE_MPSGRAPH=1` (MPSGraph mishandles GGUF `DequantMatMul`).
140/// - **MLX** — `RLX_MLX_MODE=eager` (`DequantMatMul` host-dequant must not run under `mlx::compile`).
141///
142/// Use this around `Session::compile_with` for every packed GGUF prefill (`rlx-llama32`,
143/// `rlx-qwen3`, `rlx-gemma`, …).
144pub fn packed_gguf_compile_guard<R, F>(device: Device, f: F) -> R
145where
146    F: FnOnce() -> R,
147{
148    with_packed_gguf_backend_env(device, f)
149}
150
151fn with_packed_gguf_backend_env<R, F>(device: Device, f: F) -> R
152where
153    F: FnOnce() -> R,
154{
155    let mlx_prev = if device == Device::Mlx {
156        let prev = rlx_ir::env::var("RLX_MLX_MODE");
157        rlx_ir::env::set("RLX_MLX_MODE", "eager");
158        prev
159    } else {
160        None
161    };
162    let metal = device == Device::Metal;
163    if metal {
164        rlx_ir::env::set("RLX_DISABLE_MPSGRAPH", "1");
165    }
166    let out = f();
167    if metal {
168        rlx_ir::env::unset("RLX_DISABLE_MPSGRAPH");
169    }
170    if device == Device::Mlx {
171        match mlx_prev {
172            Some(ref v) => rlx_ir::env::set("RLX_MLX_MODE", v),
173            None => rlx_ir::env::unset("RLX_MLX_MODE"),
174        }
175    }
176    out
177}
178
179/// Device used to compile/run packed GGUF prefill when the requested GPU backend
180/// is not yet parity-clean on crates.io `rlx` 0.2.1 (MLX / wgpu / CUDA / ROCm).
181pub fn packed_gguf_execution_device(device: Device) -> Device {
182    match device {
183        Device::Cpu | Device::Metal => device,
184        Device::Mlx | Device::Gpu | Device::Cuda | Device::Rocm | Device::Vulkan => Device::Cpu,
185        _ => device,
186    }
187}
188
189/// SAM encoder / upscale / prompt-mask subgraphs.
190pub fn compile_options_sam_encoder(device: Device) -> CompileOptions {
191    compile_options_for_profile(&CompileProfile::sam_encoder(), device)
192}
193
194/// SAM3 detector encoder/decoder layers.
195pub fn compile_options_sam3(device: Device) -> CompileOptions {
196    compile_options_for_profile(&CompileProfile::sam3(), device)
197}
198
199/// SAM2 memory attention (fusion disabled — matches legacy `compile_opts_no_fusion`).
200pub fn compile_options_sam2_memory_attention(device: Device) -> CompileOptions {
201    compile_options_for_profile(&CompileProfile::sam2_memory_attention(), device)
202}
203
204/// Compile a vision subgraph with explicit tier-1 profile options.
205pub fn compile_graph_with_profile(
206    device: Device,
207    graph: rlx_ir::Graph,
208    profile: &CompileProfile,
209) -> anyhow::Result<rlx_runtime::CompiledGraph> {
210    use rlx_runtime::Session;
211    let opts = compile_options_for_profile(profile, device);
212    Ok(Session::new(device).compile_with(graph, &opts))
213}
214
215/// Compile a SAM/SAM2/SAM3 vision subgraph with tier-1 encoder profile options.
216pub fn compile_graph_sam(
217    device: Device,
218    graph: rlx_ir::Graph,
219) -> anyhow::Result<rlx_runtime::CompiledGraph> {
220    compile_graph_with_profile(device, graph, &CompileProfile::sam_encoder())
221}
222
223/// Bidirectional encoder defaults (BERT, DINOv2, Wav2Vec2, vision towers).
224pub fn compile_graph_encoder(
225    device: Device,
226    graph: rlx_ir::Graph,
227) -> anyhow::Result<rlx_runtime::CompiledGraph> {
228    compile_graph_with_profile(device, graph, &CompileProfile::encoder())
229}
230
231/// Qwen3 prefill / full-sequence graphs.
232pub fn compile_graph_qwen3_prefill(
233    device: Device,
234    graph: rlx_ir::Graph,
235) -> anyhow::Result<rlx_runtime::CompiledGraph> {
236    compile_graph_with_profile(device, graph, &CompileProfile::qwen3_prefill())
237}
238
239/// Qwen3 single-token decode graphs.
240pub fn compile_graph_qwen3_decode(
241    device: Device,
242    graph: rlx_ir::Graph,
243) -> anyhow::Result<rlx_runtime::CompiledGraph> {
244    compile_graph_with_profile(device, graph, &CompileProfile::qwen3_decode())
245}
246
247/// Qwen3.5 prefill-cache / predict graphs.
248pub fn compile_graph_qwen35_prefill(
249    device: Device,
250    graph: rlx_ir::Graph,
251) -> anyhow::Result<rlx_runtime::CompiledGraph> {
252    compile_graph_with_profile(device, graph, &CompileProfile::qwen35_prefill())
253}
254
255/// Qwen3.5 decode-step graphs.
256pub fn compile_graph_qwen35_decode(
257    device: Device,
258    graph: rlx_ir::Graph,
259) -> anyhow::Result<rlx_runtime::CompiledGraph> {
260    compile_graph_with_profile(device, graph, &CompileProfile::qwen35_decode())
261}
262
263/// Gemma / Gemma 2 prefill graphs.
264pub fn compile_graph_gemma_prefill(
265    device: Device,
266    graph: rlx_ir::Graph,
267) -> anyhow::Result<rlx_runtime::CompiledGraph> {
268    compile_graph_with_profile(device, graph, &CompileProfile::gemma_prefill())
269}
270
271/// Gemma / Gemma 2 decode-step graphs.
272pub fn compile_graph_gemma_decode(
273    device: Device,
274    graph: rlx_ir::Graph,
275) -> anyhow::Result<rlx_runtime::CompiledGraph> {
276    compile_graph_with_profile(device, graph, &CompileProfile::gemma_decode())
277}
278
279/// Llama 3.2 prefill graphs.
280pub fn compile_graph_llama32_prefill(
281    device: Device,
282    graph: rlx_ir::Graph,
283) -> anyhow::Result<rlx_runtime::CompiledGraph> {
284    compile_graph_with_profile(device, graph, &CompileProfile::llama32_prefill())
285}
286
287/// Llama 3.2 decode graphs.
288pub fn compile_graph_llama32_decode(
289    device: Device,
290    graph: rlx_ir::Graph,
291) -> anyhow::Result<rlx_runtime::CompiledGraph> {
292    compile_graph_with_profile(device, graph, &CompileProfile::llama32_decode())
293}
294
295/// Unprofiled compile (parity probes / bisect tests).
296pub fn compile_graph_legacy(
297    device: Device,
298    graph: rlx_ir::Graph,
299) -> anyhow::Result<rlx_runtime::CompiledGraph> {
300    use rlx_runtime::{CompileOptions, Session};
301    Ok(Session::new(device).compile_with(graph, &CompileOptions::new()))
302}
303
304/// Compile HIR with SAM/SAM3 tier-1 profile options.
305pub fn compile_hir_sam(
306    device: Device,
307    hir: rlx_ir::hir::HirModule,
308) -> anyhow::Result<rlx_runtime::CompiledGraph> {
309    compile_hir_with_profile(device, hir, &CompileProfile::sam_encoder())
310}
311
312/// Compile HIR with SAM3 tier-1 profile options.
313pub fn compile_hir_sam3(
314    device: Device,
315    hir: rlx_ir::hir::HirModule,
316) -> anyhow::Result<rlx_runtime::CompiledGraph> {
317    compile_hir_with_profile(device, hir, &CompileProfile::sam3())
318}
319
320/// Compile HIR with an explicit tier-1 profile.
321pub fn compile_hir_with_profile(
322    device: Device,
323    hir: rlx_ir::hir::HirModule,
324    profile: &CompileProfile,
325) -> anyhow::Result<rlx_runtime::CompiledGraph> {
326    use rlx_runtime::Session;
327    let opts = compile_options_for_profile(profile, device);
328    Ok(Session::new(device).compile_hir_with(hir, &opts)?)
329}
330
331/// Unified compile options from a [`ModelExecutionConfig`] (variant preset + binding).
332pub fn compile_options_for(config: &ModelExecutionConfig) -> CompileOptions {
333    compile_options_from_profile(
334        &config.compile_profile(),
335        Device::Cpu,
336        config.component().kernel_dispatch,
337    )
338    .dim_binding(config.dim_binding())
339}
340
341/// Profile from config preset + device fusion target (runner dynamic specialize path).
342pub fn compile_options_for_device(config: &ModelExecutionConfig, device: Device) -> CompileOptions {
343    compile_options_from_profile(
344        &config.compile_profile(),
345        device,
346        config.component().kernel_dispatch,
347    )
348    .dim_binding(config.dim_binding())
349}
350
351/// Compile a built flow through [`ModelCompilePipeline`] for one execution variant.
352pub fn compile_built_with_config(
353    pipeline: &mut ModelCompilePipeline,
354    built: BuiltModel,
355    config: &ModelExecutionConfig,
356    options: &CompileOptions,
357) -> anyhow::Result<rlx_runtime::CompiledGraph> {
358    let key = config.cache_key();
359    let binding = config.dim_binding();
360    let device = pipeline.device();
361    let (hir, params) = built.into_parts()?;
362    // Pipeline caches the variant; owned graphs for GPU backends cannot use
363    // `CompiledGraph::clone` (only CPU implements `clone_box` today).
364    if !pipeline.contains(key) {
365        pipeline.get_or_compile(key, &binding, || hir.clone(), options)?;
366    }
367    let mut compiled = if device == Device::Cpu {
368        pipeline
369            .get_or_compile(key, &binding, || hir.clone(), options)?
370            .clone()
371    } else {
372        Session::new(device).compile_hir_with(hir, options)?
373    };
374    for (name, data) in params {
375        compiled.set_param(&name, &data);
376    }
377    Ok(compiled)
378}
379
380fn fusion_target_from_profile(kind: FusionTargetKind) -> Option<FusionTarget> {
381    match kind {
382        FusionTargetKind::Auto => None,
383        FusionTargetKind::Cpu => Some(FusionTarget::Cpu),
384        FusionTargetKind::Metal => Some(FusionTarget::Metal),
385        FusionTargetKind::Mlx => Some(FusionTarget::Mlx),
386        FusionTargetKind::Cuda => Some(FusionTarget::Cuda),
387        FusionTargetKind::Rocm => Some(FusionTarget::Rocm),
388        FusionTargetKind::Wgpu => Some(FusionTarget::Wgpu),
389        FusionTargetKind::Tpu => Some(FusionTarget::Tpu),
390    }
391}