1use anyhow::Result;
19use rlx_flow::{BuiltModel, CompileProfile, WeightSource};
20use rlx_ir::{Graph, HirModule};
21use rlx_runtime::compile_cache::{BucketedCompileCache, CompileCache};
22use rlx_runtime::{CompileOptions, CompiledGraph, Device, Session};
23
24use crate::weight_map::WeightMap;
25
26pub struct WeightMapSource<'a>(pub &'a mut WeightMap);
28
29impl WeightSource for WeightMapSource<'_> {
30 fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
31 let (data, shape) = self.0.take(key)?;
32 if !transpose {
33 return Ok((data, shape));
34 }
35 if shape.len() != 2 {
36 anyhow::bail!("transpose requires rank-2 weight: {key}");
37 }
38 let rows = shape[0];
39 let cols = shape[1];
40 let mut out = vec![0f32; rows * cols];
41 for r in 0..rows {
42 for c in 0..cols {
43 out[c * rows + r] = data[r * cols + c];
44 }
45 }
46 Ok((out, vec![cols, rows]))
47 }
48
49 fn has(&self, key: &str) -> bool {
50 self.0.has(key)
51 }
52}
53
54pub fn built_from_hir(
55 hir: HirModule,
56 params: std::collections::HashMap<String, Vec<f32>>,
57) -> Result<BuiltModel> {
58 BuiltModel::from_hir(hir, params)
59}
60
61pub fn built_from_graph(
62 graph: Graph,
63 params: std::collections::HashMap<String, Vec<f32>>,
64) -> Result<BuiltModel> {
65 BuiltModel::from_graph(graph, params)
66}
67
68pub fn built_from_hir_with_profile(
69 hir: HirModule,
70 params: std::collections::HashMap<String, Vec<f32>>,
71 profile: CompileProfile,
72) -> Result<BuiltModel> {
73 let mut built = BuiltModel::from_hir(hir, params)?;
74 built.profile = profile;
75 Ok(built)
76}
77
78pub fn graph_from_built(
80 built: BuiltModel,
81) -> Result<(Graph, std::collections::HashMap<String, Vec<f32>>)> {
82 built.into_graph_parts()
83}
84
85pub fn graph_from_hir(
87 hir: HirModule,
88 params: std::collections::HashMap<String, Vec<f32>>,
89) -> Result<(Graph, std::collections::HashMap<String, Vec<f32>>)> {
90 graph_from_built(built_from_hir(hir, params)?)
91}
92
93pub fn build_graph<F>(
95 mut build: F,
96 weights: &mut WeightMap,
97) -> Result<(Graph, std::collections::HashMap<String, Vec<f32>>)>
98where
99 F: FnMut(&mut WeightMapSource<'_>) -> Result<BuiltModel>,
100{
101 let built = build(&mut WeightMapSource(weights))?;
102 graph_from_built(built)
103}
104
105pub fn compile_from_flow<F>(
107 mut build: F,
108 weights: &mut WeightMap,
109 configure: impl FnOnce(Session) -> Session,
110) -> Result<CompiledGraph>
111where
112 F: FnMut(&mut WeightMapSource<'_>) -> Result<BuiltModel>,
113{
114 let built = build(&mut WeightMapSource(weights))?;
115 let profile = built.profile().clone();
116 let typed = built.typed_params.clone();
117 let (graph, params) = built.into_graph_parts()?;
118 let options = crate::flow_bridge::compile_options_for_profile(&profile, Device::Cpu);
119 let session = configure(Session::new(Device::Cpu));
120 let mut compiled = session.compile_with(graph, &options);
121 attach_built_params(&mut compiled, params, &typed);
122 Ok(compiled)
123}
124
125pub fn attach_built_params(
127 compiled: &mut CompiledGraph,
128 params: std::collections::HashMap<String, Vec<f32>>,
129 typed_params: &[(String, Vec<u8>, rlx_ir::DType)],
130) {
131 for (name, data) in params {
132 compiled.set_param(&name, &data);
133 }
134 for (name, data, dtype) in typed_params {
135 compiled.set_param_typed(name, data, *dtype);
136 }
137}
138
139pub fn compile_built(built: BuiltModel, device: Device) -> Result<CompiledGraph> {
141 let profile = built.profile().clone();
142 let typed = built.typed_params.clone();
143 let (graph, params) = built.into_graph_parts()?;
144 let options = crate::flow_bridge::compile_options_for_profile(&profile, device);
145 let mut compiled = Session::new(device).compile_with(graph, &options);
146 attach_built_params(&mut compiled, params, &typed);
147 Ok(compiled)
148}
149
150pub fn compile_built_cpu(built: BuiltModel) -> Result<CompiledGraph> {
152 compile_built(built, Device::Cpu)
153}
154
155pub fn compile_graph_legacy_with_params(
157 device: Device,
158 graph: Graph,
159 params: std::collections::HashMap<String, Vec<f32>>,
160) -> Result<CompiledGraph> {
161 let mut compiled = crate::flow_bridge::compile_graph_legacy(device, graph)?;
162 for (name, data) in params {
163 compiled.set_param(&name, data.as_slice());
164 }
165 Ok(compiled)
166}
167
168pub fn compile_graph_gemma_prefill_with_params(
170 device: Device,
171 graph: Graph,
172 params: std::collections::HashMap<String, Vec<f32>>,
173) -> Result<CompiledGraph> {
174 compile_graph_profile(device, graph, params, &CompileProfile::gemma_prefill())
175}
176
177pub fn compile_graph_gemma_decode_with_params(
178 device: Device,
179 graph: Graph,
180 params: std::collections::HashMap<String, Vec<f32>>,
181) -> Result<CompiledGraph> {
182 compile_graph_profile(device, graph, params, &CompileProfile::gemma_decode())
183}
184
185pub fn compile_graph_llama32_prefill_with_params(
186 device: Device,
187 graph: Graph,
188 params: std::collections::HashMap<String, Vec<f32>>,
189) -> Result<CompiledGraph> {
190 compile_graph_profile(device, graph, params, &CompileProfile::llama32_prefill())
191}
192
193pub fn compile_graph_llama32_decode_with_params(
195 device: Device,
196 graph: Graph,
197 params: std::collections::HashMap<String, Vec<f32>>,
198) -> Result<CompiledGraph> {
199 compile_graph_profile(device, graph, params, &CompileProfile::llama32_decode())
200}
201
202pub fn compile_graph_default_with_params(
204 device: Device,
205 graph: Graph,
206 params: std::collections::HashMap<String, Vec<f32>>,
207) -> Result<CompiledGraph> {
208 compile_graph_profile(device, graph, params, &CompileProfile::default())
209}
210
211pub fn compile_graph_profile(
213 device: Device,
214 graph: Graph,
215 params: std::collections::HashMap<String, Vec<f32>>,
216 profile: &CompileProfile,
217) -> Result<CompiledGraph> {
218 let mut compiled = crate::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
219 for (name, data) in params {
220 compiled.set_param(&name, data.as_slice());
221 }
222 Ok(compiled)
223}
224
225pub fn compile_graph_encoder_with_params(
227 device: Device,
228 graph: Graph,
229 params: std::collections::HashMap<String, Vec<f32>>,
230) -> Result<CompiledGraph> {
231 compile_graph_profile(device, graph, params, &CompileProfile::encoder())
232}
233
234pub fn compile_graph_sam_with_params(
236 device: Device,
237 graph: Graph,
238 params: std::collections::HashMap<String, Vec<f32>>,
239) -> Result<CompiledGraph> {
240 compile_graph_profile(device, graph, params, &CompileProfile::sam_encoder())
241}
242
243pub fn compile_graph_qwen3_prefill_with_params(
245 device: Device,
246 graph: Graph,
247 params: std::collections::HashMap<String, Vec<f32>>,
248) -> Result<CompiledGraph> {
249 compile_graph_profile(device, graph, params, &CompileProfile::qwen3_prefill())
250}
251
252pub fn compile_graph_qwen35_prefill_with_params(
254 device: Device,
255 graph: Graph,
256 params: std::collections::HashMap<String, Vec<f32>>,
257) -> Result<CompiledGraph> {
258 compile_graph_profile(device, graph, params, &CompileProfile::qwen35_prefill())
259}
260
261pub fn compile_graph_qwen35_decode_with_params(
263 device: Device,
264 graph: Graph,
265 params: std::collections::HashMap<String, Vec<f32>>,
266) -> Result<CompiledGraph> {
267 compile_graph_profile(device, graph, params, &CompileProfile::qwen35_decode())
268}
269
270pub fn compile_graph_with_kv_export_params(
272 device: Device,
273 graph: Graph,
274 params: std::collections::HashMap<String, Vec<f32>>,
275 profile: &CompileProfile,
276) -> Result<CompiledGraph> {
277 use rlx_runtime::Session;
278 let mut compiled = Session::new(device).compile_with(
279 graph,
280 &crate::flow_bridge::compile_options_for_profile(profile, device),
281 );
282 for (name, data) in params {
283 compiled.set_param(&name, data.as_slice());
284 }
285 Ok(compiled)
286}
287
288pub fn compile_cache_ensure_built(
290 cache: &mut CompileCache,
291 key: u64,
292 built: BuiltModel,
293) -> Result<&mut CompiledGraph> {
294 if !cache.contains(key) {
295 let (graph, params) = graph_from_built(built)?;
296 let compiled = cache.get_or_compile(key, || graph);
297 attach_built_params(compiled, params, &[]);
298 }
299 Ok(cache.get_or_compile(key, || {
300 panic!("compile_cache_ensure_built: missing entry for key {key}")
301 }))
302}
303
304pub fn bucket_cache_ensure_built<'a, F>(
306 cache: &'a mut BucketedCompileCache,
307 key: u64,
308 build: F,
309 options: &CompileOptions,
310) -> Option<(u64, &'a mut CompiledGraph)>
311where
312 F: FnOnce(u64) -> Result<BuiltModel>,
313{
314 cache.ensure_graph_with_params(
315 key,
316 |upper| {
317 let built = build(upper).expect("bucket_cache_ensure_built build failed");
318 graph_from_built(built).expect("bucket_cache_ensure_built lower failed")
319 },
320 options,
321 )
322}