Skip to main content

rlx_models_core/
flow_bridge.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Bridge between `rlx-models` loaders/runtime and `rlx-flow`.
17
18use std::path::Path;
19
20use rlx_flow::CompileProfile;
21use rlx_flow::{
22    BuiltModel, FusionTargetKind, MixedPrecisionKind, ModelExecutionConfig, PrecisionKind,
23};
24use rlx_ir::logical_kernel::KernelDispatchConfig;
25use rlx_opt::{FusionOptions, FusionTarget, PrecisionPolicy};
26use rlx_runtime::Device;
27use rlx_runtime::{CompileOptions, ModelCompilePipeline, Precision, Session, stages};
28
29use crate::weight_loader::WeightLoader;
30
31/// Adapt [`WeightLoader`] to [`rlx_flow::WeightSource`].
32pub struct WeightLoaderSource<'a>(pub &'a mut dyn WeightLoader);
33
34impl rlx_flow::WeightSource for WeightLoaderSource<'_> {
35    fn take(&mut self, key: &str, transpose: bool) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
36        if transpose {
37            self.0.take_transposed(key)
38        } else {
39            self.0.take(key)
40        }
41    }
42}
43
44/// Load a tier-1 profile from disk; fall back to `default` when missing or invalid.
45pub fn load_compile_profile(path: &Path, default: CompileProfile) -> CompileProfile {
46    CompileProfile::from_toml_path(path).unwrap_or(default)
47}
48
49/// Load `profile_file` next to `weights` (parent directory); fall back to `default`.
50pub fn profile_near_weights(
51    weights: &Path,
52    profile_file: &str,
53    default: CompileProfile,
54) -> CompileProfile {
55    let dir = weights.parent().unwrap_or_else(|| Path::new("."));
56    load_compile_profile(&dir.join(profile_file), default)
57}
58
59/// Apply tier-1 profile options to runtime compile options.
60pub fn apply_compile_profile(profile: &CompileProfile, opts: &mut CompileOptions) {
61    opts.dce = profile.passes.dce;
62    opts.constant_folding = profile.passes.constant_folding;
63    opts.verbose = profile.passes.verbose;
64    opts.assert_fusion_clean = profile.fusion.assert_clean;
65    opts.fusion_opts = FusionOptions {
66        skip_fusion: profile.fusion.skip,
67        unfuse_elementwise_regions: profile.backend.metal.unfuse_regions
68            || profile.backend.cpu.unfuse_regions,
69        ..FusionOptions::default()
70    };
71    if let Some(target) = fusion_target_from_profile(profile.fusion.target) {
72        opts.fusion_target = Some(target);
73    }
74    opts.precision = match profile.precision.compute {
75        PrecisionKind::F32 => Precision::F32,
76        PrecisionKind::F16 => Precision::F16,
77        PrecisionKind::Bf16 => Precision::F16, // closest supported runtime precision today
78    };
79    opts.policy = match profile.precision.mixed {
80        MixedPrecisionKind::None => None,
81        MixedPrecisionKind::Auto => Some(PrecisionPolicy::AutoMixed),
82    };
83}
84
85/// Dynamic HIR template/specialize — default passes only (matches legacy [`DynamicDimCompileCache`]).
86pub fn compile_options_dynamic(binding: rlx_ir::DimBinding) -> CompileOptions {
87    CompileOptions::new().dim_binding(binding)
88}
89
90/// Build [`CompileOptions`] from a tier-1 profile + device fusion target.
91pub fn compile_options_from_profile(
92    profile: &CompileProfile,
93    device: Device,
94    kernel_dispatch: KernelDispatchConfig,
95) -> CompileOptions {
96    let mut opts = CompileOptions::new();
97    apply_compile_profile(profile, &mut opts);
98    opts.kernel_dispatch = kernel_dispatch;
99    if opts.fusion_target.is_none() {
100        opts.fusion_target = Some(stages::fusion_target_for(device));
101    }
102    opts
103}
104
105/// Tier-1 profile + device (no execution variant binding).
106pub fn compile_options_for_profile(profile: &CompileProfile, device: Device) -> CompileOptions {
107    compile_options_from_profile(profile, device, KernelDispatchConfig::default())
108}
109
110/// Compile options for packed GGUF K-quant prefill (`Op::DequantMatMul`).
111///
112/// Fusion is disabled on **all** backends: hand-rolled packed graphs emit separate
113/// RMSNorm + `DequantMatMul` nodes, and fusing them into `Op::FusedResidualRmsNorm`
114/// assumes F32 matmul weights — K-quant logits skew badly on CPU (cos ~ -0.2 vs
115/// F32 dequant). GPU backends already needed this for lowering coverage; CPU/Metal/MLX
116/// need it for numerical parity too.
117pub fn compile_options_for_packed_gguf_prefill_with_profile(
118    profile: &CompileProfile,
119    device: Device,
120) -> CompileOptions {
121    let mut profile = profile.clone();
122    profile.fusion.skip = true;
123    compile_options_from_profile(&profile, device, KernelDispatchConfig::default())
124}
125
126/// Llama-shaped LM packed GGUF prefill (MiniCPM5, Llama 3.2, …).
127pub fn compile_options_for_packed_gguf_prefill(device: Device) -> CompileOptions {
128    compile_options_for_packed_gguf_prefill_with_profile(&CompileProfile::llama32_prefill(), device)
129}
130
131/// Backend env overrides while compiling or running packed GGUF graphs.
132///
133/// - **Metal** — `RLX_DISABLE_MPSGRAPH=1` (MPSGraph mishandles GGUF `DequantMatMul`).
134/// - **MLX** — `RLX_MLX_MODE=lazy` (GGUF `DequantMatMul` host-dequant cannot use `mlx::compile`).
135///
136/// MLX mode is baked into the executable at compile time; use this guard around every
137/// `Session::compile_with` / bucketed decode compile for packed GGUF (`rlx-llama32`,
138/// `rlx-qwen3`, `rlx-gemma`, …).
139pub fn packed_gguf_compile_guard<R, F>(device: Device, f: F) -> R
140where
141    F: FnOnce() -> R,
142{
143    with_packed_gguf_backend_env(device, f)
144}
145
146fn with_packed_gguf_backend_env<R, F>(device: Device, f: F) -> R
147where
148    F: FnOnce() -> R,
149{
150    let mlx_prev = if device == Device::Mlx {
151        let prev = rlx_ir::env::var("RLX_MLX_MODE");
152        rlx_ir::env::set("RLX_MLX_MODE", "lazy");
153        prev
154    } else {
155        None
156    };
157    let metal = device == Device::Metal;
158    if metal {
159        rlx_ir::env::set("RLX_DISABLE_MPSGRAPH", "1");
160    }
161    let out = f();
162    if metal {
163        rlx_ir::env::unset("RLX_DISABLE_MPSGRAPH");
164    }
165    if device == Device::Mlx {
166        match mlx_prev {
167            Some(ref v) => rlx_ir::env::set("RLX_MLX_MODE", v),
168            None => rlx_ir::env::unset("RLX_MLX_MODE"),
169        }
170    }
171    out
172}
173
174/// Device used to compile/run packed GGUF graphs.
175///
176/// **CPU**, **Metal**, and **MLX** run natively. MLX uses host-side GGUF dequant per
177/// `DequantMatMul` with lazy eval (see [`packed_gguf_compile_guard`]). **wgpu / CUDA /
178/// ROCm** still fall back to CPU prefill until upstream GPU parity lands.
179pub fn packed_gguf_execution_device(device: Device) -> Device {
180    match device {
181        Device::Cpu | Device::Metal | Device::Mlx => device,
182        Device::Gpu | Device::Cuda | Device::Rocm | Device::Vulkan => Device::Cpu,
183        _ => device,
184    }
185}
186
187/// SAM encoder / upscale / prompt-mask subgraphs.
188pub fn compile_options_sam_encoder(device: Device) -> CompileOptions {
189    compile_options_for_profile(&CompileProfile::sam_encoder(), device)
190}
191
192/// SAM3 detector encoder/decoder layers.
193pub fn compile_options_sam3(device: Device) -> CompileOptions {
194    compile_options_for_profile(&CompileProfile::sam3(), device)
195}
196
197/// SAM2 memory attention (fusion disabled — matches legacy `compile_opts_no_fusion`).
198pub fn compile_options_sam2_memory_attention(device: Device) -> CompileOptions {
199    compile_options_for_profile(&CompileProfile::sam2_memory_attention(), device)
200}
201
202/// Compile a vision subgraph with explicit tier-1 profile options.
203pub fn compile_graph_with_profile(
204    device: Device,
205    graph: rlx_ir::Graph,
206    profile: &CompileProfile,
207) -> anyhow::Result<rlx_runtime::CompiledGraph> {
208    use rlx_runtime::Session;
209    let opts = compile_options_for_profile(profile, device);
210    Ok(Session::new(device).compile_with(graph, &opts))
211}
212
213/// Compile a SAM/SAM2/SAM3 vision subgraph with tier-1 encoder profile options.
214pub fn compile_graph_sam(
215    device: Device,
216    graph: rlx_ir::Graph,
217) -> anyhow::Result<rlx_runtime::CompiledGraph> {
218    compile_graph_with_profile(device, graph, &CompileProfile::sam_encoder())
219}
220
221/// Bidirectional encoder defaults (BERT, DINOv2, Wav2Vec2, vision towers).
222pub fn compile_graph_encoder(
223    device: Device,
224    graph: rlx_ir::Graph,
225) -> anyhow::Result<rlx_runtime::CompiledGraph> {
226    compile_graph_with_profile(device, graph, &CompileProfile::encoder())
227}
228
229/// Qwen3 prefill / full-sequence graphs.
230pub fn compile_graph_qwen3_prefill(
231    device: Device,
232    graph: rlx_ir::Graph,
233) -> anyhow::Result<rlx_runtime::CompiledGraph> {
234    compile_graph_with_profile(device, graph, &CompileProfile::qwen3_prefill())
235}
236
237/// Qwen3 single-token decode graphs.
238pub fn compile_graph_qwen3_decode(
239    device: Device,
240    graph: rlx_ir::Graph,
241) -> anyhow::Result<rlx_runtime::CompiledGraph> {
242    compile_graph_with_profile(device, graph, &CompileProfile::qwen3_decode())
243}
244
245/// Qwen3.5 prefill-cache / predict graphs.
246pub fn compile_graph_qwen35_prefill(
247    device: Device,
248    graph: rlx_ir::Graph,
249) -> anyhow::Result<rlx_runtime::CompiledGraph> {
250    compile_graph_with_profile(device, graph, &CompileProfile::qwen35_prefill())
251}
252
253/// Qwen3.5 decode-step graphs.
254pub fn compile_graph_qwen35_decode(
255    device: Device,
256    graph: rlx_ir::Graph,
257) -> anyhow::Result<rlx_runtime::CompiledGraph> {
258    compile_graph_with_profile(device, graph, &CompileProfile::qwen35_decode())
259}
260
261/// Gemma / Gemma 2 prefill graphs.
262pub fn compile_graph_gemma_prefill(
263    device: Device,
264    graph: rlx_ir::Graph,
265) -> anyhow::Result<rlx_runtime::CompiledGraph> {
266    compile_graph_with_profile(device, graph, &CompileProfile::gemma_prefill())
267}
268
269/// Gemma / Gemma 2 decode-step graphs.
270pub fn compile_graph_gemma_decode(
271    device: Device,
272    graph: rlx_ir::Graph,
273) -> anyhow::Result<rlx_runtime::CompiledGraph> {
274    compile_graph_with_profile(device, graph, &CompileProfile::gemma_decode())
275}
276
277/// Llama 3.2 prefill graphs.
278pub fn compile_graph_llama32_prefill(
279    device: Device,
280    graph: rlx_ir::Graph,
281) -> anyhow::Result<rlx_runtime::CompiledGraph> {
282    compile_graph_with_profile(device, graph, &CompileProfile::llama32_prefill())
283}
284
285/// Llama 3.2 decode graphs.
286pub fn compile_graph_llama32_decode(
287    device: Device,
288    graph: rlx_ir::Graph,
289) -> anyhow::Result<rlx_runtime::CompiledGraph> {
290    compile_graph_with_profile(device, graph, &CompileProfile::llama32_decode())
291}
292
293/// Unprofiled compile (parity probes / bisect tests).
294pub fn compile_graph_legacy(
295    device: Device,
296    graph: rlx_ir::Graph,
297) -> anyhow::Result<rlx_runtime::CompiledGraph> {
298    use rlx_runtime::{CompileOptions, Session};
299    Ok(Session::new(device).compile_with(graph, &CompileOptions::new()))
300}
301
302/// Compile HIR with SAM/SAM3 tier-1 profile options.
303pub fn compile_hir_sam(
304    device: Device,
305    hir: rlx_ir::hir::HirModule,
306) -> anyhow::Result<rlx_runtime::CompiledGraph> {
307    compile_hir_with_profile(device, hir, &CompileProfile::sam_encoder())
308}
309
310/// Compile HIR with SAM3 tier-1 profile options.
311pub fn compile_hir_sam3(
312    device: Device,
313    hir: rlx_ir::hir::HirModule,
314) -> anyhow::Result<rlx_runtime::CompiledGraph> {
315    compile_hir_with_profile(device, hir, &CompileProfile::sam3())
316}
317
318/// Compile HIR with an explicit tier-1 profile.
319pub fn compile_hir_with_profile(
320    device: Device,
321    hir: rlx_ir::hir::HirModule,
322    profile: &CompileProfile,
323) -> anyhow::Result<rlx_runtime::CompiledGraph> {
324    use rlx_runtime::Session;
325    let opts = compile_options_for_profile(profile, device);
326    Ok(Session::new(device).compile_hir_with(hir, &opts)?)
327}
328
329/// Unified compile options from a [`ModelExecutionConfig`] (variant preset + binding).
330pub fn compile_options_for(config: &ModelExecutionConfig) -> CompileOptions {
331    compile_options_from_profile(
332        &config.compile_profile(),
333        Device::Cpu,
334        config.component().kernel_dispatch,
335    )
336    .dim_binding(config.dim_binding())
337}
338
339/// Profile from config preset + device fusion target (runner dynamic specialize path).
340pub fn compile_options_for_device(config: &ModelExecutionConfig, device: Device) -> CompileOptions {
341    compile_options_from_profile(
342        &config.compile_profile(),
343        device,
344        config.component().kernel_dispatch,
345    )
346    .dim_binding(config.dim_binding())
347}
348
349/// Compile a built flow through [`ModelCompilePipeline`] for one execution variant.
350pub fn compile_built_with_config(
351    pipeline: &mut ModelCompilePipeline,
352    built: BuiltModel,
353    config: &ModelExecutionConfig,
354    options: &CompileOptions,
355) -> anyhow::Result<rlx_runtime::CompiledGraph> {
356    let key = config.cache_key();
357    let binding = config.dim_binding();
358    let device = pipeline.device();
359    let (hir, params) = built.into_parts()?;
360    // Pipeline caches the variant; owned graphs for GPU backends cannot use
361    // `CompiledGraph::clone` (only CPU implements `clone_box` today).
362    if !pipeline.contains(key) {
363        pipeline.get_or_compile(key, &binding, || hir.clone(), options)?;
364    }
365    let mut compiled = if device == Device::Cpu {
366        pipeline
367            .get_or_compile(key, &binding, || hir.clone(), options)?
368            .clone()
369    } else {
370        Session::new(device).compile_hir_with(hir, options)?
371    };
372    for (name, data) in params {
373        compiled.set_param(&name, &data);
374    }
375    Ok(compiled)
376}
377
378fn fusion_target_from_profile(kind: FusionTargetKind) -> Option<FusionTarget> {
379    match kind {
380        FusionTargetKind::Auto => None,
381        FusionTargetKind::Cpu => Some(FusionTarget::Cpu),
382        FusionTargetKind::Metal => Some(FusionTarget::Metal),
383        FusionTargetKind::Mlx => Some(FusionTarget::Mlx),
384        FusionTargetKind::Cuda => Some(FusionTarget::Cuda),
385        FusionTargetKind::Rocm => Some(FusionTarget::Rocm),
386        FusionTargetKind::Wgpu => Some(FusionTarget::Wgpu),
387        FusionTargetKind::Tpu => Some(FusionTarget::Tpu),
388    }
389}