use std::path::Path;
use rlx_flow::CompileProfile;
use rlx_flow::{
BuiltModel, FusionTargetKind, MixedPrecisionKind, ModelExecutionConfig, PrecisionKind,
};
use rlx_ir::logical_kernel::KernelDispatchConfig;
use rlx_opt::{FusionOptions, FusionTarget, PrecisionPolicy};
use rlx_runtime::Device;
use rlx_runtime::{CompileOptions, ModelCompilePipeline, Precision, Session, stages};
use crate::weight_loader::WeightLoader;
/// Adapter holding a mutable borrow of a [`WeightLoader`] trait object.
///
/// This file implements `rlx_flow::WeightSource` for the wrapper by forwarding
/// each request to the borrowed loader.
pub struct WeightLoaderSource<'a>(pub &'a mut dyn WeightLoader);
impl rlx_flow::WeightSource for WeightLoaderSource<'_> {
    /// Fetches the entry stored under `key` by delegating to the wrapped loader.
    ///
    /// When `transpose` is `true` the request goes to the loader's
    /// `take_transposed`; otherwise it goes to the loader's `take`. Whatever the
    /// loader returns (success or error) is passed back unchanged.
    // NOTE(review): the returned pair is presumably (values, shape) — confirm
    // against the `WeightLoader` trait.
    fn take(&mut self, key: &str, transpose: bool) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
        if !transpose {
            return self.0.take(key);
        }
        self.0.take_transposed(key)
    }
}
/// Reads a compile profile from `path`, substituting `default` when that fails.
///
/// `path` is handed to [`CompileProfile::from_toml_path`]. If that call does not
/// yield a profile, `default` is returned instead; the failure itself is
/// discarded and not reported to the caller.
pub fn load_compile_profile(path: &Path, default: CompileProfile) -> CompileProfile {
    let loaded = CompileProfile::from_toml_path(path);
    loaded.unwrap_or(default)
}
/// Looks for a profile file named `profile_file` in the directory of `weights`.
///
/// The candidate path is `weights.parent()` joined with `profile_file`; when
/// `weights` has no parent, `"."` is used as the directory. The candidate is then
/// loaded through [`load_compile_profile`], so `default` is returned whenever the
/// file cannot be loaded.
pub fn profile_near_weights(
    weights: &Path,
    profile_file: &str,
    default: CompileProfile,
) -> CompileProfile {
    let candidate = match weights.parent() {
        Some(dir) => dir.join(profile_file),
        None => Path::new(".").join(profile_file),
    };
    load_compile_profile(&candidate, default)
}
/// Copies the settings held in `profile` onto `opts`.
///
/// Overwritten unconditionally: `dce`, `constant_folding`, `verbose`,
/// `assert_fusion_clean`, `fusion_opts`, `precision` and `policy`.
///
/// `fusion_target` is overwritten only when the profile names a concrete target;
/// for `FusionTargetKind::Auto` the value already stored in `opts` is left alone.
pub fn apply_compile_profile(profile: &CompileProfile, opts: &mut CompileOptions) {
    // Pass toggles.
    let passes = &profile.passes;
    opts.dce = passes.dce;
    opts.constant_folding = passes.constant_folding;
    opts.verbose = passes.verbose;

    // Fusion settings: elementwise regions are unfused when either the Metal or
    // the CPU backend section requests it.
    let backend = &profile.backend;
    let unfuse_elementwise_regions = backend.metal.unfuse_regions || backend.cpu.unfuse_regions;
    opts.assert_fusion_clean = profile.fusion.assert_clean;
    opts.fusion_opts = FusionOptions {
        skip_fusion: profile.fusion.skip,
        unfuse_elementwise_regions,
        ..FusionOptions::default()
    };

    // Only an explicit target replaces what `opts` already carries.
    let explicit_target = fusion_target_from_profile(profile.fusion.target);
    if explicit_target.is_some() {
        opts.fusion_target = explicit_target;
    }

    opts.precision = match profile.precision.compute {
        PrecisionKind::F32 => Precision::F32,
        // NOTE(review): Bf16 is mapped to F16 rather than to a dedicated bf16
        // precision — confirm this is intended.
        PrecisionKind::F16 | PrecisionKind::Bf16 => Precision::F16,
    };
    opts.policy = match profile.precision.mixed {
        MixedPrecisionKind::Auto => Some(PrecisionPolicy::AutoMixed),
        MixedPrecisionKind::None => None,
    };
}
/// Creates fresh [`CompileOptions`] and applies `binding` via `dim_binding`.
///
/// No compile profile is applied; everything else keeps the value produced by
/// `CompileOptions::new()`.
pub fn compile_options_dynamic(binding: rlx_ir::DimBinding) -> CompileOptions {
    let base = CompileOptions::new();
    base.dim_binding(binding)
}
/// Builds [`CompileOptions`] from `profile` for the given `device`.
///
/// Starts from `CompileOptions::new()`, applies `profile` with
/// [`apply_compile_profile`], and stores `kernel_dispatch`. If `fusion_target` is
/// still `None` afterwards, it is filled in with
/// `stages::fusion_target_for(device)`.
pub fn compile_options_from_profile(
    profile: &CompileProfile,
    device: Device,
    kernel_dispatch: KernelDispatchConfig,
) -> CompileOptions {
    let mut options = CompileOptions::new();
    apply_compile_profile(profile, &mut options);
    options.kernel_dispatch = kernel_dispatch;
    // Fall back to the device-derived target only when none has been set.
    options
        .fusion_target
        .get_or_insert_with(|| stages::fusion_target_for(device));
    options
}
/// Builds [`CompileOptions`] from `profile` for `device` with the default
/// [`KernelDispatchConfig`].
///
/// Shorthand for [`compile_options_from_profile`].
pub fn compile_options_for_profile(profile: &CompileProfile, device: Device) -> CompileOptions {
    let dispatch = KernelDispatchConfig::default();
    compile_options_from_profile(profile, device, dispatch)
}
/// Builds [`CompileOptions`] for `device` from the
/// [`CompileProfile::sam_encoder`] profile.
pub fn compile_options_sam_encoder(device: Device) -> CompileOptions {
    let profile = CompileProfile::sam_encoder();
    compile_options_for_profile(&profile, device)
}
/// Builds [`CompileOptions`] for `device` from the [`CompileProfile::sam3`]
/// profile.
pub fn compile_options_sam3(device: Device) -> CompileOptions {
    let profile = CompileProfile::sam3();
    compile_options_for_profile(&profile, device)
}
/// Builds [`CompileOptions`] for `device` from the
/// [`CompileProfile::sam2_memory_attention`] profile.
pub fn compile_options_sam2_memory_attention(device: Device) -> CompileOptions {
    let profile = CompileProfile::sam2_memory_attention();
    compile_options_for_profile(&profile, device)
}
/// Compiles `graph` on `device` using options derived from `profile`.
///
/// The options come from [`compile_options_for_profile`]; compilation runs on a
/// new `Session` created for `device`.
///
/// # Errors
///
/// The body always returns `Ok`; the `Result` is part of the signature only.
pub fn compile_graph_with_profile(
    device: Device,
    graph: rlx_ir::Graph,
    profile: &CompileProfile,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    use rlx_runtime::Session;
    let opts = compile_options_for_profile(profile, device);
    let compiled = Session::new(device).compile_with(graph, &opts);
    Ok(compiled)
}
/// Compiles `graph` on `device` with the [`CompileProfile::sam_encoder`] profile.
///
/// Thin wrapper over [`compile_graph_with_profile`].
pub fn compile_graph_sam(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let profile = CompileProfile::sam_encoder();
    compile_graph_with_profile(device, graph, &profile)
}
/// Compiles `graph` on `device` with the [`CompileProfile::encoder`] profile.
///
/// Thin wrapper over [`compile_graph_with_profile`].
pub fn compile_graph_encoder(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let profile = CompileProfile::encoder();
    compile_graph_with_profile(device, graph, &profile)
}
/// Compiles `graph` on `device` with the [`CompileProfile::qwen3_prefill`]
/// profile.
///
/// Thin wrapper over [`compile_graph_with_profile`].
pub fn compile_graph_qwen3_prefill(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let profile = CompileProfile::qwen3_prefill();
    compile_graph_with_profile(device, graph, &profile)
}
/// Compiles `graph` on `device` with the [`CompileProfile::qwen3_decode`]
/// profile.
///
/// Thin wrapper over [`compile_graph_with_profile`].
pub fn compile_graph_qwen3_decode(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let profile = CompileProfile::qwen3_decode();
    compile_graph_with_profile(device, graph, &profile)
}
/// Compiles `graph` on `device` with the [`CompileProfile::qwen35_prefill`]
/// profile.
///
/// Thin wrapper over [`compile_graph_with_profile`].
pub fn compile_graph_qwen35_prefill(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let profile = CompileProfile::qwen35_prefill();
    compile_graph_with_profile(device, graph, &profile)
}
/// Compiles `graph` on `device` with the [`CompileProfile::qwen35_decode`]
/// profile.
///
/// Thin wrapper over [`compile_graph_with_profile`].
pub fn compile_graph_qwen35_decode(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let profile = CompileProfile::qwen35_decode();
    compile_graph_with_profile(device, graph, &profile)
}
/// Compiles `graph` on `device` with the [`CompileProfile::gemma_prefill`]
/// profile.
///
/// Thin wrapper over [`compile_graph_with_profile`].
pub fn compile_graph_gemma_prefill(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let profile = CompileProfile::gemma_prefill();
    compile_graph_with_profile(device, graph, &profile)
}
/// Compiles `graph` on `device` with the [`CompileProfile::gemma_decode`]
/// profile.
///
/// Thin wrapper over [`compile_graph_with_profile`].
pub fn compile_graph_gemma_decode(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let profile = CompileProfile::gemma_decode();
    compile_graph_with_profile(device, graph, &profile)
}
/// Compiles `graph` on `device` with the [`CompileProfile::llama32_prefill`]
/// profile.
///
/// Thin wrapper over [`compile_graph_with_profile`].
pub fn compile_graph_llama32_prefill(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let profile = CompileProfile::llama32_prefill();
    compile_graph_with_profile(device, graph, &profile)
}
/// Compiles `graph` on `device` with the [`CompileProfile::llama32_decode`]
/// profile.
///
/// Thin wrapper over [`compile_graph_with_profile`].
pub fn compile_graph_llama32_decode(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let profile = CompileProfile::llama32_decode();
    compile_graph_with_profile(device, graph, &profile)
}
/// Compiles `graph` on `device` with untouched `CompileOptions::new()` options.
///
/// No compile profile is applied on this path.
///
/// # Errors
///
/// The body always returns `Ok`; the `Result` is part of the signature only.
pub fn compile_graph_legacy(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    use rlx_runtime::{CompileOptions, Session};
    let compiled = Session::new(device).compile_with(graph, &CompileOptions::new());
    Ok(compiled)
}
/// Compiles the HIR module `hir` on `device` with the
/// [`CompileProfile::sam_encoder`] profile.
///
/// Thin wrapper over [`compile_hir_with_profile`].
pub fn compile_hir_sam(
    device: Device,
    hir: rlx_ir::hir::HirModule,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let profile = CompileProfile::sam_encoder();
    compile_hir_with_profile(device, hir, &profile)
}
/// Compiles the HIR module `hir` on `device` with the [`CompileProfile::sam3`]
/// profile.
///
/// Thin wrapper over [`compile_hir_with_profile`].
pub fn compile_hir_sam3(
    device: Device,
    hir: rlx_ir::hir::HirModule,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let profile = CompileProfile::sam3();
    compile_hir_with_profile(device, hir, &profile)
}
/// Compiles the HIR module `hir` on `device` using options derived from
/// `profile`.
///
/// The options come from [`compile_options_for_profile`]; compilation runs on a
/// new `Session` created for `device`.
///
/// # Errors
///
/// Propagates any error returned by `Session::compile_hir_with`.
pub fn compile_hir_with_profile(
    device: Device,
    hir: rlx_ir::hir::HirModule,
    profile: &CompileProfile,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    use rlx_runtime::Session;
    let opts = compile_options_for_profile(profile, device);
    let compiled = Session::new(device).compile_hir_with(hir, &opts)?;
    Ok(compiled)
}
/// Builds [`CompileOptions`] for `config`, always targeting `Device::Cpu`.
///
/// Uses the config's compile profile, its component's `kernel_dispatch`, and its
/// dimension binding. The body mirrors [`compile_options_for_device`] with the
/// device fixed to CPU.
pub fn compile_options_for(config: &ModelExecutionConfig) -> CompileOptions {
    let profile = config.compile_profile();
    let kernel_dispatch = config.component().kernel_dispatch;
    let opts = compile_options_from_profile(&profile, Device::Cpu, kernel_dispatch);
    opts.dim_binding(config.dim_binding())
}
/// Builds [`CompileOptions`] for `config` on the given `device`.
///
/// Uses the config's compile profile, its component's `kernel_dispatch`, and its
/// dimension binding; `device` is forwarded to [`compile_options_from_profile`].
pub fn compile_options_for_device(config: &ModelExecutionConfig, device: Device) -> CompileOptions {
    let profile = config.compile_profile();
    let kernel_dispatch = config.component().kernel_dispatch;
    let opts = compile_options_from_profile(&profile, device, kernel_dispatch);
    opts.dim_binding(config.dim_binding())
}
/// Compiles a built model and loads its parameters into the compiled graph.
///
/// `built` is split into its HIR module and parameters with `into_parts`. The
/// pipeline is asked to compile and store the module under `config.cache_key()`
/// if it does not already contain that key. The graph that is returned then
/// depends on the pipeline's device:
///
/// * `Device::Cpu` — a clone of the pipeline's entry for the key.
/// * any other device — a fresh compile on a new `Session` for that device.
///
/// Finally every `(name, data)` parameter is installed with `set_param`.
///
/// # Errors
///
/// Propagates errors from `BuiltModel::into_parts`,
/// `ModelCompilePipeline::get_or_compile` and `Session::compile_hir_with`.
pub fn compile_built_with_config(
    pipeline: &mut ModelCompilePipeline,
    built: BuiltModel,
    config: &ModelExecutionConfig,
    options: &CompileOptions,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let cache_key = config.cache_key();
    let dims = config.dim_binding();
    let target = pipeline.device();
    let (module, weights) = built.into_parts()?;

    // Ensure the pipeline holds an entry for this key, whatever the device.
    // NOTE(review): on non-CPU devices a cache miss therefore compiles twice
    // (once here, once in the session below) — confirm this is intended.
    if !pipeline.contains(cache_key) {
        pipeline.get_or_compile(cache_key, &dims, || module.clone(), options)?;
    }

    let mut compiled = if target != Device::Cpu {
        Session::new(target).compile_hir_with(module, options)?
    } else {
        pipeline
            .get_or_compile(cache_key, &dims, || module.clone(), options)?
            .clone()
    };

    for (param_name, values) in weights {
        compiled.set_param(&param_name, &values);
    }
    Ok(compiled)
}
/// Translates a profile-level [`FusionTargetKind`] into a [`FusionTarget`].
///
/// `FusionTargetKind::Auto` yields `None`; every other kind maps to the
/// `FusionTarget` variant of the same name.
fn fusion_target_from_profile(kind: FusionTargetKind) -> Option<FusionTarget> {
    let target = match kind {
        // `Auto` means "no explicit choice".
        FusionTargetKind::Auto => return None,
        FusionTargetKind::Cpu => FusionTarget::Cpu,
        FusionTargetKind::Metal => FusionTarget::Metal,
        FusionTargetKind::Mlx => FusionTarget::Mlx,
        FusionTargetKind::Cuda => FusionTarget::Cuda,
        FusionTargetKind::Rocm => FusionTarget::Rocm,
        FusionTargetKind::Wgpu => FusionTarget::Wgpu,
        FusionTargetKind::Tpu => FusionTarget::Tpu,
    };
    Some(target)
}