Skip to main content

rlx_models_core/
flow_util.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Shared helpers for tier-0 model flow migration.
17
18use anyhow::Result;
19use rlx_flow::{BuiltModel, CompileProfile, WeightSource};
20use rlx_ir::{Graph, HirModule};
21use rlx_runtime::compile_cache::{BucketedCompileCache, CompileCache};
22use rlx_runtime::{CompileOptions, CompiledGraph, Device, Session};
23
24use crate::weight_map::WeightMap;
25
26/// Adapt in-memory [`WeightMap`] to [`WeightSource`].
27pub struct WeightMapSource<'a>(pub &'a mut WeightMap);
28
29impl WeightSource for WeightMapSource<'_> {
30    fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
31        let (data, shape) = self.0.take(key)?;
32        if !transpose {
33            return Ok((data, shape));
34        }
35        if shape.len() != 2 {
36            anyhow::bail!("transpose requires rank-2 weight: {key}");
37        }
38        let rows = shape[0];
39        let cols = shape[1];
40        let mut out = vec![0f32; rows * cols];
41        for r in 0..rows {
42            for c in 0..cols {
43                out[c * rows + r] = data[r * cols + c];
44            }
45        }
46        Ok((out, vec![cols, rows]))
47    }
48
49    fn has(&self, key: &str) -> bool {
50        self.0.has(key)
51    }
52}
53
54pub fn built_from_hir(
55    hir: HirModule,
56    params: std::collections::HashMap<String, Vec<f32>>,
57) -> Result<BuiltModel> {
58    BuiltModel::from_hir(hir, params)
59}
60
61pub fn built_from_graph(
62    graph: Graph,
63    params: std::collections::HashMap<String, Vec<f32>>,
64) -> Result<BuiltModel> {
65    BuiltModel::from_graph(graph, params)
66}
67
68pub fn built_from_hir_with_profile(
69    hir: HirModule,
70    params: std::collections::HashMap<String, Vec<f32>>,
71    profile: CompileProfile,
72) -> Result<BuiltModel> {
73    let mut built = BuiltModel::from_hir(hir, params)?;
74    built.profile = profile;
75    Ok(built)
76}
77
78/// Build a flow and return `(Graph, params)` — preferred compile entry point.
79pub fn graph_from_built(
80    built: BuiltModel,
81) -> Result<(Graph, std::collections::HashMap<String, Vec<f32>>)> {
82    built.into_graph_parts()
83}
84
85/// Lower an existing HIR module through [`BuiltModel`] (utility for HIR-first builders).
86pub fn graph_from_hir(
87    hir: HirModule,
88    params: std::collections::HashMap<String, Vec<f32>>,
89) -> Result<(Graph, std::collections::HashMap<String, Vec<f32>>)> {
90    graph_from_built(built_from_hir(hir, params)?)
91}
92
93/// Build via flow and lower to MIR graph + params.
94pub fn build_graph<F>(
95    mut build: F,
96    weights: &mut WeightMap,
97) -> Result<(Graph, std::collections::HashMap<String, Vec<f32>>)>
98where
99    F: FnMut(&mut WeightMapSource<'_>) -> Result<BuiltModel>,
100{
101    let built = build(&mut WeightMapSource(weights))?;
102    graph_from_built(built)
103}
104
105/// Compile helper — build graph + params from a flow, then compile with a configured session.
106pub fn compile_from_flow<F>(
107    mut build: F,
108    weights: &mut WeightMap,
109    configure: impl FnOnce(Session) -> Session,
110) -> Result<CompiledGraph>
111where
112    F: FnMut(&mut WeightMapSource<'_>) -> Result<BuiltModel>,
113{
114    let built = build(&mut WeightMapSource(weights))?;
115    let profile = built.profile().clone();
116    let typed = built.typed_params.clone();
117    let (graph, params) = built.into_graph_parts()?;
118    let options = crate::flow_bridge::compile_options_for_profile(&profile, Device::Cpu);
119    let session = configure(Session::new(Device::Cpu));
120    let mut compiled = session.compile_with(graph, &options);
121    attach_built_params(&mut compiled, params, &typed);
122    Ok(compiled)
123}
124
125/// Attach f32 and typed (U8 packed GGUF) params after compile.
126pub fn attach_built_params(
127    compiled: &mut CompiledGraph,
128    params: std::collections::HashMap<String, Vec<f32>>,
129    typed_params: &[(String, Vec<u8>, rlx_ir::DType)],
130) {
131    for (name, data) in params {
132        compiled.set_param(&name, &data);
133    }
134    for (name, data, dtype) in typed_params {
135        compiled.set_param_typed(name, data, *dtype);
136    }
137}
138
139/// Compile a [`BuiltModel`] on the given device using its embedded profile.
140pub fn compile_built(built: BuiltModel, device: Device) -> Result<CompiledGraph> {
141    let profile = built.profile().clone();
142    let typed = built.typed_params.clone();
143    let (graph, params) = built.into_graph_parts()?;
144    let options = crate::flow_bridge::compile_options_for_profile(&profile, device);
145    let mut compiled = Session::new(device).compile_with(graph, &options);
146    attach_built_params(&mut compiled, params, &typed);
147    Ok(compiled)
148}
149
150/// Compile a [`BuiltModel`] on CPU with default options (embedding quick-check tests).
151pub fn compile_built_cpu(built: BuiltModel) -> Result<CompiledGraph> {
152    compile_built(built, Device::Cpu)
153}
154
155/// Unprofiled compile + params (layer probes; matches historical `Session::compile`).
156pub fn compile_graph_legacy_with_params(
157    device: Device,
158    graph: Graph,
159    params: std::collections::HashMap<String, Vec<f32>>,
160) -> Result<CompiledGraph> {
161    let mut compiled = crate::flow_bridge::compile_graph_legacy(device, graph)?;
162    for (name, data) in params {
163        compiled.set_param(&name, data.as_slice());
164    }
165    Ok(compiled)
166}
167
168/// Llama 3.2 prefill + params.
169pub fn compile_graph_gemma_prefill_with_params(
170    device: Device,
171    graph: Graph,
172    params: std::collections::HashMap<String, Vec<f32>>,
173) -> Result<CompiledGraph> {
174    compile_graph_profile(device, graph, params, &CompileProfile::gemma_prefill())
175}
176
177pub fn compile_graph_gemma_decode_with_params(
178    device: Device,
179    graph: Graph,
180    params: std::collections::HashMap<String, Vec<f32>>,
181) -> Result<CompiledGraph> {
182    compile_graph_profile(device, graph, params, &CompileProfile::gemma_decode())
183}
184
185pub fn compile_graph_llama32_prefill_with_params(
186    device: Device,
187    graph: Graph,
188    params: std::collections::HashMap<String, Vec<f32>>,
189) -> Result<CompiledGraph> {
190    compile_graph_profile(device, graph, params, &CompileProfile::llama32_prefill())
191}
192
193/// Llama 3.2 decode + params.
194pub fn compile_graph_llama32_decode_with_params(
195    device: Device,
196    graph: Graph,
197    params: std::collections::HashMap<String, Vec<f32>>,
198) -> Result<CompiledGraph> {
199    compile_graph_profile(device, graph, params, &CompileProfile::llama32_decode())
200}
201
202/// Legacy default compile options (plumbing tests with hand-built graphs).
203pub fn compile_graph_default_with_params(
204    device: Device,
205    graph: Graph,
206    params: std::collections::HashMap<String, Vec<f32>>,
207) -> Result<CompiledGraph> {
208    compile_graph_profile(device, graph, params, &CompileProfile::default())
209}
210
211/// Lower a graph with a tier-1 profile and attach params (tests / examples).
212pub fn compile_graph_profile(
213    device: Device,
214    graph: Graph,
215    params: std::collections::HashMap<String, Vec<f32>>,
216    profile: &CompileProfile,
217) -> Result<CompiledGraph> {
218    let mut compiled = crate::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
219    for (name, data) in params {
220        compiled.set_param(&name, data.as_slice());
221    }
222    Ok(compiled)
223}
224
225/// [`CompileProfile::encoder`] + params.
226pub fn compile_graph_encoder_with_params(
227    device: Device,
228    graph: Graph,
229    params: std::collections::HashMap<String, Vec<f32>>,
230) -> Result<CompiledGraph> {
231    compile_graph_profile(device, graph, params, &CompileProfile::encoder())
232}
233
234/// [`CompileProfile::sam_encoder`] + params.
235pub fn compile_graph_sam_with_params(
236    device: Device,
237    graph: Graph,
238    params: std::collections::HashMap<String, Vec<f32>>,
239) -> Result<CompiledGraph> {
240    compile_graph_profile(device, graph, params, &CompileProfile::sam_encoder())
241}
242
243/// [`CompileProfile::qwen3_prefill`] + params.
244pub fn compile_graph_qwen3_prefill_with_params(
245    device: Device,
246    graph: Graph,
247    params: std::collections::HashMap<String, Vec<f32>>,
248) -> Result<CompiledGraph> {
249    compile_graph_profile(device, graph, params, &CompileProfile::qwen3_prefill())
250}
251
252/// [`CompileProfile::qwen35_prefill`] + params.
253pub fn compile_graph_qwen35_prefill_with_params(
254    device: Device,
255    graph: Graph,
256    params: std::collections::HashMap<String, Vec<f32>>,
257) -> Result<CompiledGraph> {
258    compile_graph_profile(device, graph, params, &CompileProfile::qwen35_prefill())
259}
260
261/// [`CompileProfile::qwen35_decode`] + params.
262pub fn compile_graph_qwen35_decode_with_params(
263    device: Device,
264    graph: Graph,
265    params: std::collections::HashMap<String, Vec<f32>>,
266) -> Result<CompiledGraph> {
267    compile_graph_profile(device, graph, params, &CompileProfile::qwen35_decode())
268}
269
270/// Tier-1 profile + params (including graphs that export KV side outputs).
271pub fn compile_graph_with_kv_export_params(
272    device: Device,
273    graph: Graph,
274    params: std::collections::HashMap<String, Vec<f32>>,
275    profile: &CompileProfile,
276) -> Result<CompiledGraph> {
277    use rlx_runtime::Session;
278    let mut compiled = Session::new(device).compile_with(
279        graph,
280        &crate::flow_bridge::compile_options_for_profile(profile, device),
281    );
282    for (name, data) in params {
283        compiled.set_param(&name, data.as_slice());
284    }
285    Ok(compiled)
286}
287
288/// Insert a [`BuiltModel`] into an LRU [`CompileCache`] (compile + params on first `key`).
289pub fn compile_cache_ensure_built(
290    cache: &mut CompileCache,
291    key: u64,
292    built: BuiltModel,
293) -> Result<&mut CompiledGraph> {
294    if !cache.contains(key) {
295        let (graph, params) = graph_from_built(built)?;
296        let compiled = cache.get_or_compile(key, || graph);
297        attach_built_params(compiled, params, &[]);
298    }
299    Ok(cache.get_or_compile(key, || {
300        panic!("compile_cache_ensure_built: missing entry for key {key}")
301    }))
302}
303
304/// Compile a decode bucket once and attach params (see [`BucketedCompileCache::ensure_graph_with_params`]).
305pub fn bucket_cache_ensure_built<'a, F>(
306    cache: &'a mut BucketedCompileCache,
307    key: u64,
308    build: F,
309    options: &CompileOptions,
310) -> Option<(u64, &'a mut CompiledGraph)>
311where
312    F: FnOnce(u64) -> Result<BuiltModel>,
313{
314    cache.ensure_graph_with_params(
315        key,
316        |upper| {
317            let built = build(upper).expect("bucket_cache_ensure_built build failed");
318            graph_from_built(built).expect("bucket_cache_ensure_built lower failed")
319        },
320        options,
321    )
322}