Skip to main content

rlx_models_core/
flow_bridge.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Bridge between `rlx-models` loaders/runtime and `rlx-flow`.
17
18use std::path::Path;
19
20use rlx_flow::CompileProfile;
21use rlx_flow::{
22    BuiltModel, FusionTargetKind, MixedPrecisionKind, ModelExecutionConfig, PrecisionKind,
23};
24use rlx_ir::logical_kernel::KernelDispatchConfig;
25use rlx_opt::{FusionOptions, FusionTarget, PrecisionPolicy};
26use rlx_runtime::Device;
27use rlx_runtime::{CompileOptions, ModelCompilePipeline, Precision, Session, stages};
28
29use crate::weight_loader::WeightLoader;
30
31/// Adapt [`WeightLoader`] to [`rlx_flow::WeightSource`].
32pub struct WeightLoaderSource<'a>(pub &'a mut dyn WeightLoader);
33
34impl rlx_flow::WeightSource for WeightLoaderSource<'_> {
35    fn take(&mut self, key: &str, transpose: bool) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
36        if transpose {
37            self.0.take_transposed(key)
38        } else {
39            self.0.take(key)
40        }
41    }
42}
43
44/// Load a tier-1 profile from disk; fall back to `default` when missing or invalid.
45pub fn load_compile_profile(path: &Path, default: CompileProfile) -> CompileProfile {
46    CompileProfile::from_toml_path(path).unwrap_or(default)
47}
48
49/// Load `profile_file` next to `weights` (parent directory); fall back to `default`.
50pub fn profile_near_weights(
51    weights: &Path,
52    profile_file: &str,
53    default: CompileProfile,
54) -> CompileProfile {
55    let dir = weights.parent().unwrap_or_else(|| Path::new("."));
56    load_compile_profile(&dir.join(profile_file), default)
57}
58
59/// Apply tier-1 profile options to runtime compile options.
60pub fn apply_compile_profile(profile: &CompileProfile, opts: &mut CompileOptions) {
61    opts.dce = profile.passes.dce;
62    opts.constant_folding = profile.passes.constant_folding;
63    opts.verbose = profile.passes.verbose;
64    opts.assert_fusion_clean = profile.fusion.assert_clean;
65    opts.fusion_opts = FusionOptions {
66        skip_fusion: profile.fusion.skip,
67        unfuse_elementwise_regions: profile.backend.metal.unfuse_regions
68            || profile.backend.cpu.unfuse_regions,
69        ..FusionOptions::default()
70    };
71    if let Some(target) = fusion_target_from_profile(profile.fusion.target) {
72        opts.fusion_target = Some(target);
73    }
74    opts.precision = match profile.precision.compute {
75        PrecisionKind::F32 => Precision::F32,
76        PrecisionKind::F16 => Precision::F16,
77        PrecisionKind::Bf16 => Precision::F16, // closest supported runtime precision today
78    };
79    opts.policy = match profile.precision.mixed {
80        MixedPrecisionKind::None => None,
81        MixedPrecisionKind::Auto => Some(PrecisionPolicy::AutoMixed),
82    };
83}
84
85/// Dynamic HIR template/specialize — default passes only (matches legacy [`DynamicDimCompileCache`]).
86pub fn compile_options_dynamic(binding: rlx_ir::DimBinding) -> CompileOptions {
87    CompileOptions::new().dim_binding(binding)
88}
89
90/// Build [`CompileOptions`] from a tier-1 profile + device fusion target.
91pub fn compile_options_from_profile(
92    profile: &CompileProfile,
93    device: Device,
94    kernel_dispatch: KernelDispatchConfig,
95) -> CompileOptions {
96    let mut opts = CompileOptions::new();
97    apply_compile_profile(profile, &mut opts);
98    opts.kernel_dispatch = kernel_dispatch;
99    if opts.fusion_target.is_none() {
100        opts.fusion_target = Some(stages::fusion_target_for(device));
101    }
102    opts
103}
104
105/// Tier-1 profile + device (no execution variant binding).
106pub fn compile_options_for_profile(profile: &CompileProfile, device: Device) -> CompileOptions {
107    compile_options_from_profile(profile, device, KernelDispatchConfig::default())
108}
109
110/// SAM encoder / upscale / prompt-mask subgraphs.
111pub fn compile_options_sam_encoder(device: Device) -> CompileOptions {
112    compile_options_for_profile(&CompileProfile::sam_encoder(), device)
113}
114
115/// SAM3 detector encoder/decoder layers.
116pub fn compile_options_sam3(device: Device) -> CompileOptions {
117    compile_options_for_profile(&CompileProfile::sam3(), device)
118}
119
120/// SAM2 memory attention (fusion disabled — matches legacy `compile_opts_no_fusion`).
121pub fn compile_options_sam2_memory_attention(device: Device) -> CompileOptions {
122    compile_options_for_profile(&CompileProfile::sam2_memory_attention(), device)
123}
124
125/// Compile a vision subgraph with explicit tier-1 profile options.
126pub fn compile_graph_with_profile(
127    device: Device,
128    graph: rlx_ir::Graph,
129    profile: &CompileProfile,
130) -> anyhow::Result<rlx_runtime::CompiledGraph> {
131    use rlx_runtime::Session;
132    let opts = compile_options_for_profile(profile, device);
133    Ok(Session::new(device).compile_with(graph, &opts))
134}
135
136/// Compile a SAM/SAM2/SAM3 vision subgraph with tier-1 encoder profile options.
137pub fn compile_graph_sam(
138    device: Device,
139    graph: rlx_ir::Graph,
140) -> anyhow::Result<rlx_runtime::CompiledGraph> {
141    compile_graph_with_profile(device, graph, &CompileProfile::sam_encoder())
142}
143
144/// Bidirectional encoder defaults (BERT, DINOv2, Wav2Vec2, vision towers).
145pub fn compile_graph_encoder(
146    device: Device,
147    graph: rlx_ir::Graph,
148) -> anyhow::Result<rlx_runtime::CompiledGraph> {
149    compile_graph_with_profile(device, graph, &CompileProfile::encoder())
150}
151
152/// Qwen3 prefill / full-sequence graphs.
153pub fn compile_graph_qwen3_prefill(
154    device: Device,
155    graph: rlx_ir::Graph,
156) -> anyhow::Result<rlx_runtime::CompiledGraph> {
157    compile_graph_with_profile(device, graph, &CompileProfile::qwen3_prefill())
158}
159
160/// Qwen3 single-token decode graphs.
161pub fn compile_graph_qwen3_decode(
162    device: Device,
163    graph: rlx_ir::Graph,
164) -> anyhow::Result<rlx_runtime::CompiledGraph> {
165    compile_graph_with_profile(device, graph, &CompileProfile::qwen3_decode())
166}
167
168/// Qwen3.5 prefill-cache / predict graphs.
169pub fn compile_graph_qwen35_prefill(
170    device: Device,
171    graph: rlx_ir::Graph,
172) -> anyhow::Result<rlx_runtime::CompiledGraph> {
173    compile_graph_with_profile(device, graph, &CompileProfile::qwen35_prefill())
174}
175
176/// Qwen3.5 decode-step graphs.
177pub fn compile_graph_qwen35_decode(
178    device: Device,
179    graph: rlx_ir::Graph,
180) -> anyhow::Result<rlx_runtime::CompiledGraph> {
181    compile_graph_with_profile(device, graph, &CompileProfile::qwen35_decode())
182}
183
184/// Gemma / Gemma 2 prefill graphs.
185pub fn compile_graph_gemma_prefill(
186    device: Device,
187    graph: rlx_ir::Graph,
188) -> anyhow::Result<rlx_runtime::CompiledGraph> {
189    compile_graph_with_profile(device, graph, &CompileProfile::gemma_prefill())
190}
191
192/// Gemma / Gemma 2 decode-step graphs.
193pub fn compile_graph_gemma_decode(
194    device: Device,
195    graph: rlx_ir::Graph,
196) -> anyhow::Result<rlx_runtime::CompiledGraph> {
197    compile_graph_with_profile(device, graph, &CompileProfile::gemma_decode())
198}
199
200/// Llama 3.2 prefill graphs.
201pub fn compile_graph_llama32_prefill(
202    device: Device,
203    graph: rlx_ir::Graph,
204) -> anyhow::Result<rlx_runtime::CompiledGraph> {
205    compile_graph_with_profile(device, graph, &CompileProfile::llama32_prefill())
206}
207
208/// Llama 3.2 decode graphs.
209pub fn compile_graph_llama32_decode(
210    device: Device,
211    graph: rlx_ir::Graph,
212) -> anyhow::Result<rlx_runtime::CompiledGraph> {
213    compile_graph_with_profile(device, graph, &CompileProfile::llama32_decode())
214}
215
216/// Unprofiled compile (parity probes / bisect tests).
217pub fn compile_graph_legacy(
218    device: Device,
219    graph: rlx_ir::Graph,
220) -> anyhow::Result<rlx_runtime::CompiledGraph> {
221    use rlx_runtime::{CompileOptions, Session};
222    Ok(Session::new(device).compile_with(graph, &CompileOptions::new()))
223}
224
225/// Compile HIR with SAM/SAM3 tier-1 profile options.
226pub fn compile_hir_sam(
227    device: Device,
228    hir: rlx_ir::hir::HirModule,
229) -> anyhow::Result<rlx_runtime::CompiledGraph> {
230    compile_hir_with_profile(device, hir, &CompileProfile::sam_encoder())
231}
232
233/// Compile HIR with SAM3 tier-1 profile options.
234pub fn compile_hir_sam3(
235    device: Device,
236    hir: rlx_ir::hir::HirModule,
237) -> anyhow::Result<rlx_runtime::CompiledGraph> {
238    compile_hir_with_profile(device, hir, &CompileProfile::sam3())
239}
240
241/// Compile HIR with an explicit tier-1 profile.
242pub fn compile_hir_with_profile(
243    device: Device,
244    hir: rlx_ir::hir::HirModule,
245    profile: &CompileProfile,
246) -> anyhow::Result<rlx_runtime::CompiledGraph> {
247    use rlx_runtime::Session;
248    let opts = compile_options_for_profile(profile, device);
249    Ok(Session::new(device).compile_hir_with(hir, &opts)?)
250}
251
252/// Unified compile options from a [`ModelExecutionConfig`] (variant preset + binding).
253pub fn compile_options_for(config: &ModelExecutionConfig) -> CompileOptions {
254    compile_options_from_profile(
255        &config.compile_profile(),
256        Device::Cpu,
257        config.component().kernel_dispatch,
258    )
259    .dim_binding(config.dim_binding())
260}
261
262/// Profile from config preset + device fusion target (runner dynamic specialize path).
263pub fn compile_options_for_device(config: &ModelExecutionConfig, device: Device) -> CompileOptions {
264    compile_options_from_profile(
265        &config.compile_profile(),
266        device,
267        config.component().kernel_dispatch,
268    )
269    .dim_binding(config.dim_binding())
270}
271
272/// Compile a built flow through [`ModelCompilePipeline`] for one execution variant.
273pub fn compile_built_with_config(
274    pipeline: &mut ModelCompilePipeline,
275    built: BuiltModel,
276    config: &ModelExecutionConfig,
277    options: &CompileOptions,
278) -> anyhow::Result<rlx_runtime::CompiledGraph> {
279    let key = config.cache_key();
280    let binding = config.dim_binding();
281    let device = pipeline.device();
282    let (hir, params) = built.into_parts()?;
283    // Pipeline caches the variant; owned graphs for GPU backends cannot use
284    // `CompiledGraph::clone` (only CPU implements `clone_box` today).
285    if !pipeline.contains(key) {
286        pipeline.get_or_compile(key, &binding, || hir.clone(), options)?;
287    }
288    let mut compiled = if device == Device::Cpu {
289        pipeline
290            .get_or_compile(key, &binding, || hir.clone(), options)?
291            .clone()
292    } else {
293        Session::new(device).compile_hir_with(hir, options)?
294    };
295    for (name, data) in params {
296        compiled.set_param(&name, &data);
297    }
298    Ok(compiled)
299}
300
301fn fusion_target_from_profile(kind: FusionTargetKind) -> Option<FusionTarget> {
302    match kind {
303        FusionTargetKind::Auto => None,
304        FusionTargetKind::Cpu => Some(FusionTarget::Cpu),
305        FusionTargetKind::Metal => Some(FusionTarget::Metal),
306        FusionTargetKind::Mlx => Some(FusionTarget::Mlx),
307        FusionTargetKind::Cuda => Some(FusionTarget::Cuda),
308        FusionTargetKind::Rocm => Some(FusionTarget::Rocm),
309        FusionTargetKind::Wgpu => Some(FusionTarget::Wgpu),
310        FusionTargetKind::Tpu => Some(FusionTarget::Tpu),
311    }
312}