use anyhow::Result;
use rlx_flow::{BuiltModel, CompileProfile, WeightSource};
use rlx_ir::{Graph, HirModule};
use rlx_runtime::compile_cache::{BucketedCompileCache, CompileCache};
use rlx_runtime::{CompileOptions, CompiledGraph, Device, Session};
use crate::weight_map::WeightMap;
/// Adapter that exposes a mutably borrowed [`WeightMap`] through the
/// [`WeightSource`] trait.
pub struct WeightMapSource<'a>(pub &'a mut WeightMap);

impl WeightSource for WeightMapSource<'_> {
    /// Fetches the weight stored under `key` via `WeightMap::take` and returns
    /// its data together with its shape.
    ///
    /// When `transpose` is `true` the weight must be rank-2. Its data is read
    /// as a row-major `[rows, cols]` matrix and returned as the row-major
    /// `[cols, rows]` transpose.
    ///
    /// # Errors
    ///
    /// * Propagates any error from `WeightMap::take`.
    /// * `transpose` was requested for a weight whose rank is not 2.
    /// * `transpose` was requested and the data length does not equal
    ///   `rows * cols` (previously this panicked on an out-of-bounds index or
    ///   silently dropped trailing elements).
    fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
        let (data, shape) = self.0.take(key)?;
        if !transpose {
            return Ok((data, shape));
        }
        if shape.len() != 2 {
            anyhow::bail!("transpose requires rank-2 weight: {key}");
        }
        let (rows, cols) = (shape[0], shape[1]);
        // Validate before indexing: a mismatched (or overflowing) element count
        // would otherwise panic or yield a truncated transpose.
        if rows.checked_mul(cols) != Some(data.len()) {
            anyhow::bail!(
                "transpose shape {shape:?} does not match data length {} for weight: {key}",
                data.len()
            );
        }
        // Emit the output row by row; output row `c` is input column `c`.
        let mut out = Vec::with_capacity(data.len());
        for c in 0..cols {
            out.extend((0..rows).map(|r| data[r * cols + c]));
        }
        Ok((out, vec![cols, rows]))
    }

    /// Reports whether the wrapped map has an entry for `key`
    /// (delegates to `WeightMap::has`).
    fn has(&self, key: &str) -> bool {
        self.0.has(key)
    }
}
/// Builds a [`BuiltModel`] from an HIR module and its named `f32` parameters.
///
/// Thin wrapper around `BuiltModel::from_hir`.
///
/// # Errors
///
/// Propagates any error returned by `BuiltModel::from_hir`.
pub fn built_from_hir(
    hir: HirModule,
    params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<BuiltModel> {
    let built = BuiltModel::from_hir(hir, params)?;
    Ok(built)
}
/// Builds a [`BuiltModel`] from an already-lowered graph and its named `f32`
/// parameters.
///
/// Thin wrapper around `BuiltModel::from_graph`.
///
/// # Errors
///
/// Propagates any error returned by `BuiltModel::from_graph`.
pub fn built_from_graph(
    graph: Graph,
    params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<BuiltModel> {
    let built = BuiltModel::from_graph(graph, params)?;
    Ok(built)
}
/// Builds a [`BuiltModel`] from an HIR module, then overwrites its `profile`
/// field with the supplied `profile`.
///
/// # Errors
///
/// Propagates any error returned by `BuiltModel::from_hir`.
pub fn built_from_hir_with_profile(
    hir: HirModule,
    params: std::collections::HashMap<String, Vec<f32>>,
    profile: CompileProfile,
) -> Result<BuiltModel> {
    BuiltModel::from_hir(hir, params).map(|mut model| {
        model.profile = profile;
        model
    })
}
/// Consumes a [`BuiltModel`] and returns its graph together with its named
/// `f32` parameters.
///
/// Thin wrapper around `BuiltModel::into_graph_parts`.
///
/// # Errors
///
/// Propagates any error returned by `BuiltModel::into_graph_parts`.
pub fn graph_from_built(
    built: BuiltModel,
) -> Result<(Graph, std::collections::HashMap<String, Vec<f32>>)> {
    let (graph, params) = built.into_graph_parts()?;
    Ok((graph, params))
}
pub fn graph_from_hir(
hir: HirModule,
params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<(Graph, std::collections::HashMap<String, Vec<f32>>)> {
graph_from_built(built_from_hir(hir, params)?)
}
pub fn build_graph<F>(
mut build: F,
weights: &mut WeightMap,
) -> Result<(Graph, std::collections::HashMap<String, Vec<f32>>)>
where
F: FnMut(&mut WeightMapSource<'_>) -> Result<BuiltModel>,
{
let built = build(&mut WeightMapSource(weights))?;
graph_from_built(built)
}
/// Builds a model from `weights`, compiles it on the CPU device and attaches
/// the model's parameters to the compiled graph.
///
/// Steps, in order:
/// 1. `build` is called once with a [`WeightMapSource`] over `weights`.
/// 2. The model's profile and typed parameters are cloned, because
///    `into_graph_parts` consumes the model.
/// 3. Compile options are derived from the profile for `Device::Cpu`.
/// 4. A fresh CPU [`Session`] is passed through `configure` and used to
///    compile the graph.
/// 5. Both the `f32` and the typed parameters are attached with
///    [`attach_built_params`].
///
/// # Errors
///
/// Propagates errors from `build` and from `BuiltModel::into_graph_parts`.
pub fn compile_from_flow<F>(
    mut build: F,
    weights: &mut WeightMap,
    configure: impl FnOnce(Session) -> Session,
) -> Result<CompiledGraph>
where
    F: FnMut(&mut WeightMapSource<'_>) -> Result<BuiltModel>,
{
    let mut source = WeightMapSource(weights);
    let model = build(&mut source)?;

    // Capture what is needed from the model before it is consumed below.
    let profile = model.profile().clone();
    let typed_params = model.typed_params.clone();
    let (graph, f32_params) = model.into_graph_parts()?;

    let options = crate::flow_bridge::compile_options_for_profile(&profile, Device::Cpu);
    let session = configure(Session::new(Device::Cpu));
    let mut compiled = session.compile_with(graph, &options);

    attach_built_params(&mut compiled, f32_params, &typed_params);
    Ok(compiled)
}
/// Attaches a model's parameters to a compiled graph.
///
/// Every entry of `params` is passed to `CompiledGraph::set_param` first, then
/// every entry of `typed_params` (name, raw bytes, dtype) is passed to
/// `CompiledGraph::set_param_typed`, in slice order.
pub fn attach_built_params(
    compiled: &mut CompiledGraph,
    params: std::collections::HashMap<String, Vec<f32>>,
    typed_params: &[(String, Vec<u8>, rlx_ir::DType)],
) {
    params.into_iter().for_each(|(name, values)| {
        compiled.set_param(&name, &values);
    });
    typed_params.iter().for_each(|(name, bytes, dtype)| {
        compiled.set_param_typed(name, bytes, *dtype);
    });
}
/// Compiles a [`BuiltModel`] for `device` and attaches its parameters.
///
/// Compile options are derived from the model's own profile. Both the `f32`
/// and the typed parameters are attached to the result via
/// [`attach_built_params`].
///
/// # Errors
///
/// Propagates any error from `BuiltModel::into_graph_parts`.
pub fn compile_built(built: BuiltModel, device: Device) -> Result<CompiledGraph> {
    // Clone what outlives the model; `into_graph_parts` takes it by value.
    let model_profile = built.profile().clone();
    let typed_params = built.typed_params.clone();
    let (graph, f32_params) = built.into_graph_parts()?;

    let options = crate::flow_bridge::compile_options_for_profile(&model_profile, device);
    let mut compiled = Session::new(device).compile_with(graph, &options);

    attach_built_params(&mut compiled, f32_params, &typed_params);
    Ok(compiled)
}
/// Convenience wrapper: [`compile_built`] with the device fixed to
/// `Device::Cpu`.
///
/// # Errors
///
/// Propagates any error from [`compile_built`].
pub fn compile_built_cpu(built: BuiltModel) -> Result<CompiledGraph> {
    let device = Device::Cpu;
    compile_built(built, device)
}
/// Compiles `graph` through `flow_bridge::compile_graph_legacy` and then sets
/// every entry of `params` on the compiled graph.
///
/// # Errors
///
/// Propagates any error from `flow_bridge::compile_graph_legacy`.
pub fn compile_graph_legacy_with_params(
    device: Device,
    graph: Graph,
    params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<CompiledGraph> {
    let mut compiled = crate::flow_bridge::compile_graph_legacy(device, graph)?;
    params.into_iter().for_each(|(name, values)| {
        compiled.set_param(&name, values.as_slice());
    });
    Ok(compiled)
}
/// Compiles `graph` with `CompileProfile::gemma_prefill()` and binds `params`.
///
/// Delegates to [`compile_graph_profile`].
///
/// # Errors
///
/// Propagates any error from [`compile_graph_profile`].
pub fn compile_graph_gemma_prefill_with_params(
    device: Device,
    graph: Graph,
    params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<CompiledGraph> {
    let profile = CompileProfile::gemma_prefill();
    compile_graph_profile(device, graph, params, &profile)
}
/// Compiles `graph` with `CompileProfile::gemma_decode()` and binds `params`.
///
/// Delegates to [`compile_graph_profile`].
///
/// # Errors
///
/// Propagates any error from [`compile_graph_profile`].
pub fn compile_graph_gemma_decode_with_params(
    device: Device,
    graph: Graph,
    params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<CompiledGraph> {
    let profile = CompileProfile::gemma_decode();
    compile_graph_profile(device, graph, params, &profile)
}
/// Compiles `graph` with `CompileProfile::llama32_prefill()` and binds
/// `params`.
///
/// Delegates to [`compile_graph_profile`].
///
/// # Errors
///
/// Propagates any error from [`compile_graph_profile`].
pub fn compile_graph_llama32_prefill_with_params(
    device: Device,
    graph: Graph,
    params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<CompiledGraph> {
    let profile = CompileProfile::llama32_prefill();
    compile_graph_profile(device, graph, params, &profile)
}
/// Compiles `graph` with `CompileProfile::llama32_decode()` and binds
/// `params`.
///
/// Delegates to [`compile_graph_profile`].
///
/// # Errors
///
/// Propagates any error from [`compile_graph_profile`].
pub fn compile_graph_llama32_decode_with_params(
    device: Device,
    graph: Graph,
    params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<CompiledGraph> {
    let profile = CompileProfile::llama32_decode();
    compile_graph_profile(device, graph, params, &profile)
}
/// Compiles `graph` with `CompileProfile::default()` and binds `params`.
///
/// Delegates to [`compile_graph_profile`].
///
/// # Errors
///
/// Propagates any error from [`compile_graph_profile`].
pub fn compile_graph_default_with_params(
    device: Device,
    graph: Graph,
    params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<CompiledGraph> {
    let profile = CompileProfile::default();
    compile_graph_profile(device, graph, params, &profile)
}
/// Compiles `graph` for `device` under the given `profile` and then sets every
/// entry of `params` on the compiled graph.
///
/// Compilation is delegated to `flow_bridge::compile_graph_with_profile`.
///
/// # Errors
///
/// Propagates any error from `flow_bridge::compile_graph_with_profile`.
pub fn compile_graph_profile(
    device: Device,
    graph: Graph,
    params: std::collections::HashMap<String, Vec<f32>>,
    profile: &CompileProfile,
) -> Result<CompiledGraph> {
    let mut compiled = crate::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
    params.into_iter().for_each(|(name, values)| {
        compiled.set_param(&name, values.as_slice());
    });
    Ok(compiled)
}
/// Compiles `graph` with `CompileProfile::encoder()` and binds `params`.
///
/// Delegates to [`compile_graph_profile`].
///
/// # Errors
///
/// Propagates any error from [`compile_graph_profile`].
pub fn compile_graph_encoder_with_params(
    device: Device,
    graph: Graph,
    params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<CompiledGraph> {
    let profile = CompileProfile::encoder();
    compile_graph_profile(device, graph, params, &profile)
}
/// Compiles `graph` with `CompileProfile::sam_encoder()` and binds `params`.
///
/// Delegates to [`compile_graph_profile`].
///
/// # Errors
///
/// Propagates any error from [`compile_graph_profile`].
pub fn compile_graph_sam_with_params(
    device: Device,
    graph: Graph,
    params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<CompiledGraph> {
    let profile = CompileProfile::sam_encoder();
    compile_graph_profile(device, graph, params, &profile)
}
/// Compiles `graph` with `CompileProfile::qwen3_prefill()` and binds `params`.
///
/// Delegates to [`compile_graph_profile`].
///
/// # Errors
///
/// Propagates any error from [`compile_graph_profile`].
pub fn compile_graph_qwen3_prefill_with_params(
    device: Device,
    graph: Graph,
    params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<CompiledGraph> {
    let profile = CompileProfile::qwen3_prefill();
    compile_graph_profile(device, graph, params, &profile)
}
/// Compiles `graph` with `CompileProfile::qwen35_prefill()` and binds
/// `params`.
///
/// Delegates to [`compile_graph_profile`].
///
/// # Errors
///
/// Propagates any error from [`compile_graph_profile`].
pub fn compile_graph_qwen35_prefill_with_params(
    device: Device,
    graph: Graph,
    params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<CompiledGraph> {
    let profile = CompileProfile::qwen35_prefill();
    compile_graph_profile(device, graph, params, &profile)
}
/// Compiles `graph` with `CompileProfile::qwen35_decode()` and binds `params`.
///
/// Delegates to [`compile_graph_profile`].
///
/// # Errors
///
/// Propagates any error from [`compile_graph_profile`].
pub fn compile_graph_qwen35_decode_with_params(
    device: Device,
    graph: Graph,
    params: std::collections::HashMap<String, Vec<f32>>,
) -> Result<CompiledGraph> {
    let profile = CompileProfile::qwen35_decode();
    compile_graph_profile(device, graph, params, &profile)
}
pub fn compile_graph_with_kv_export_params(
device: Device,
graph: Graph,
params: std::collections::HashMap<String, Vec<f32>>,
profile: &CompileProfile,
) -> Result<CompiledGraph> {
use rlx_runtime::Session;
let mut compiled = Session::new(device).compile_with(
graph,
&crate::flow_bridge::compile_options_for_profile(profile, device),
);
for (name, data) in params {
compiled.set_param(&name, data.as_slice());
}
Ok(compiled)
}
/// Ensures `cache` holds a compiled graph for `key` and returns a mutable
/// reference to it.
///
/// On a cache miss, `built` is lowered to graph parts, compiled through
/// `CompileCache::get_or_compile`, and its parameters are attached with
/// [`attach_built_params`]. The typed parameters are attached along with the
/// `f32` ones, matching [`compile_built`]; previously they were dropped on
/// this path.
///
/// On a cache hit, `built` is dropped unused.
///
/// # Errors
///
/// Propagates any error from [`graph_from_built`] on a cache miss.
///
/// # Panics
///
/// Panics if the entry for `key` is absent at the final lookup. That lookup
/// runs after the miss branch has inserted it, so reaching the panic would
/// indicate a broken cache invariant.
pub fn compile_cache_ensure_built(
    cache: &mut CompileCache,
    key: u64,
    built: BuiltModel,
) -> Result<&mut CompiledGraph> {
    if !cache.contains(key) {
        // Clone before lowering: `graph_from_built` consumes the model.
        let typed = built.typed_params.clone();
        let (graph, params) = graph_from_built(built)?;
        let compiled = cache.get_or_compile(key, || graph);
        attach_built_params(compiled, params, &typed);
    }
    // Second lookup hands back a reference tied to `cache`; the closure only
    // runs if the entry is missing.
    Ok(cache.get_or_compile(key, || {
        panic!("compile_cache_ensure_built: missing entry for key {key}")
    }))
}
/// Ensures the bucketed cache holds a compiled graph for `key`, building and
/// lowering a model on demand.
///
/// The work is delegated to `BucketedCompileCache::ensure_graph_with_params`.
/// The closure handed to it calls `build` with the `u64` it receives and
/// converts the resulting model with [`graph_from_built`]. The return value of
/// `ensure_graph_with_params` is passed through unchanged.
///
/// # Panics
///
/// If the closure is invoked and either `build` or [`graph_from_built`]
/// returns an error, this function panics; the `Option` return type leaves no
/// channel to report the failure.
pub fn bucket_cache_ensure_built<'a, F>(
    cache: &'a mut BucketedCompileCache,
    key: u64,
    build: F,
    options: &CompileOptions,
) -> Option<(u64, &'a mut CompiledGraph)>
where
    F: FnOnce(u64) -> Result<BuiltModel>,
{
    let lower = |bucket: u64| {
        let model = build(bucket).expect("bucket_cache_ensure_built build failed");
        graph_from_built(model).expect("bucket_cache_ensure_built lower failed")
    };
    cache.ensure_graph_with_params(key, lower, options)
}