use std::path::Path;
use rlx_flow::CompileProfile;
use rlx_flow::{
BuiltModel, FusionTargetKind, MixedPrecisionKind, ModelExecutionConfig, PrecisionKind,
};
use rlx_ir::logical_kernel::KernelDispatchConfig;
use rlx_opt::{FusionOptions, FusionTarget, PrecisionPolicy};
use rlx_runtime::Device;
use rlx_runtime::{CompileOptions, ModelCompilePipeline, Precision, Session, stages};
use crate::weight_loader::WeightLoader;
/// Adapter wrapping a mutably borrowed [`WeightLoader`] so it can be used
/// wherever an `rlx_flow::WeightSource` is expected.
pub struct WeightLoaderSource<'a>(pub &'a mut dyn WeightLoader);
impl rlx_flow::WeightSource for WeightLoaderSource<'_> {
    /// Fetches the weight stored under `key` from the wrapped loader.
    ///
    /// When `transpose` is set the request is forwarded to
    /// `WeightLoader::take_transposed`, otherwise to `WeightLoader::take`.
    /// The result is passed through untouched (presumably the `f32` values
    /// plus their shape — confirm against the `WeightLoader` trait).
    fn take(&mut self, key: &str, transpose: bool) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
        if !transpose {
            return self.0.take(key);
        }
        self.0.take_transposed(key)
    }
}
/// Loads a [`CompileProfile`] from the TOML file at `path`.
///
/// If `CompileProfile::from_toml_path` does not yield a profile, the failure
/// is discarded and `default` is returned instead; no error is surfaced to
/// the caller.
pub fn load_compile_profile(path: &Path, default: CompileProfile) -> CompileProfile {
    let loaded = CompileProfile::from_toml_path(path);
    loaded.unwrap_or(default)
}
/// Loads the profile file named `profile_file` from the directory that
/// contains `weights`.
///
/// When `weights` has no parent component, the current directory (`"."`) is
/// used. Falls back to `default` exactly as [`load_compile_profile`] does.
pub fn profile_near_weights(
    weights: &Path,
    profile_file: &str,
    default: CompileProfile,
) -> CompileProfile {
    let profile_path = match weights.parent() {
        Some(dir) => dir.join(profile_file),
        None => Path::new(".").join(profile_file),
    };
    load_compile_profile(&profile_path, default)
}
/// Copies the settings of `profile` into `opts`.
///
/// Overwrites the pass toggles, fusion options, precision and mixed-precision
/// policy. `opts.fusion_target` is only overwritten when the profile names an
/// explicit target; with `FusionTargetKind::Auto` the existing value is kept.
pub fn apply_compile_profile(profile: &CompileProfile, opts: &mut CompileOptions) {
    // Optimisation pass toggles.
    let passes = &profile.passes;
    opts.dce = passes.dce;
    opts.constant_folding = passes.constant_folding;
    opts.verbose = passes.verbose;

    // Fusion settings. Region un-fusing is enabled if either the Metal or the
    // CPU backend section asks for it, regardless of the device in use.
    let fusion = &profile.fusion;
    let unfuse_regions = profile.backend.metal.unfuse_regions || profile.backend.cpu.unfuse_regions;
    opts.assert_fusion_clean = fusion.assert_clean;
    opts.fusion_opts = FusionOptions {
        skip_fusion: fusion.skip,
        unfuse_elementwise_regions: unfuse_regions,
        ..FusionOptions::default()
    };
    if let explicit @ Some(_) = fusion_target_from_profile(fusion.target) {
        opts.fusion_target = explicit;
    }

    // NOTE(review): Bf16 is lowered to F16 here — presumably because the
    // runtime `Precision` has no bf16 variant; confirm this is intended.
    opts.precision = match profile.precision.compute {
        PrecisionKind::F32 => Precision::F32,
        PrecisionKind::F16 | PrecisionKind::Bf16 => Precision::F16,
    };
    opts.policy = match profile.precision.mixed {
        MixedPrecisionKind::Auto => Some(PrecisionPolicy::AutoMixed),
        MixedPrecisionKind::None => None,
    };
}
/// Builds fresh [`CompileOptions`] (via `CompileOptions::new`) with only the
/// given dimension `binding` applied.
pub fn compile_options_dynamic(binding: rlx_ir::DimBinding) -> CompileOptions {
    let with_binding = CompileOptions::new().dim_binding(binding);
    with_binding
}
/// Builds [`CompileOptions`] from `profile` for the given `device`.
///
/// Starts from `CompileOptions::new`, applies the profile with
/// [`apply_compile_profile`], installs `kernel_dispatch`, and — if no fusion
/// target has been set by then — falls back to
/// `stages::fusion_target_for(device)`.
pub fn compile_options_from_profile(
    profile: &CompileProfile,
    device: Device,
    kernel_dispatch: KernelDispatchConfig,
) -> CompileOptions {
    let mut opts = CompileOptions::new();
    apply_compile_profile(profile, &mut opts);
    // Only consult the device default when the profile left the target open.
    opts.fusion_target
        .get_or_insert_with(|| stages::fusion_target_for(device));
    opts.kernel_dispatch = kernel_dispatch;
    opts
}
/// Same as [`compile_options_from_profile`], using the default
/// [`KernelDispatchConfig`].
pub fn compile_options_for_profile(profile: &CompileProfile, device: Device) -> CompileOptions {
    let dispatch = KernelDispatchConfig::default();
    compile_options_from_profile(profile, device, dispatch)
}
/// Builds prefill [`CompileOptions`] for packed-GGUF models from `profile`.
///
/// Works on a clone of `profile`; the caller's profile is not modified. For
/// `Gpu`, `Cuda`, `Rocm` and `Vulkan` devices fusion is forced off
/// (`fusion.skip = true`) before the options are derived. The default
/// [`KernelDispatchConfig`] is used.
pub fn compile_options_for_packed_gguf_prefill_with_profile(
    profile: &CompileProfile,
    device: Device,
) -> CompileOptions {
    let disable_fusion = matches!(
        device,
        Device::Gpu | Device::Cuda | Device::Rocm | Device::Vulkan
    );
    let mut adjusted = profile.clone();
    if disable_fusion {
        adjusted.fusion.skip = true;
    }
    compile_options_from_profile(&adjusted, device, KernelDispatchConfig::default())
}
/// Packed-GGUF prefill options based on the `CompileProfile::llama32_prefill()`
/// preset; see [`compile_options_for_packed_gguf_prefill_with_profile`].
pub fn compile_options_for_packed_gguf_prefill(device: Device) -> CompileOptions {
    let preset = CompileProfile::llama32_prefill();
    compile_options_for_packed_gguf_prefill_with_profile(&preset, device)
}
/// Runs `f` with the backend environment overrides that packed-GGUF
/// compilation needs for `device`, and returns whatever `f` returns.
///
/// Public entry point for the private `with_packed_gguf_backend_env` helper.
pub fn packed_gguf_compile_guard<R, F: FnOnce() -> R>(device: Device, f: F) -> R {
    let result = with_packed_gguf_backend_env(device, f);
    result
}
/// Runs `f` with a device-specific environment override in place and restores
/// the previous state afterwards.
///
/// * `Device::Mlx`   — `RLX_MLX_MODE` is set to `"eager"`.
/// * `Device::Metal` — `RLX_DISABLE_MPSGRAPH` is set to `"1"`.
/// * any other device — `f` is called with the environment untouched.
///
/// The value the variable had before the call is captured and put back once
/// `f` has finished; if it was not set, it is unset again. Restoration is done
/// by a drop guard, so it also happens when `f` panics and unwinds.
///
/// NOTE(review): these overrides go through `rlx_ir::env`, which looks
/// process-wide; concurrent callers could observe each other's overrides —
/// confirm against the `rlx_ir::env` implementation.
fn with_packed_gguf_backend_env<R, F>(device: Device, f: F) -> R
where
    F: FnOnce() -> R,
{
    /// Runs the stored closure exactly once when dropped.
    struct RunOnDrop<G: FnOnce()>(Option<G>);

    impl<G: FnOnce()> Drop for RunOnDrop<G> {
        fn drop(&mut self) {
            if let Some(restore) = self.0.take() {
                restore();
            }
        }
    }

    // Pick the single variable (if any) that has to be forced for this device.
    let (name, forced) = if device == Device::Mlx {
        ("RLX_MLX_MODE", "eager")
    } else if device == Device::Metal {
        ("RLX_DISABLE_MPSGRAPH", "1")
    } else {
        return f();
    };

    // Remember the caller's value so a pre-existing setting is not clobbered.
    let previous = rlx_ir::env::var(name);
    rlx_ir::env::set(name, forced);
    let _restore = RunOnDrop(Some(move || match previous {
        Some(ref value) => rlx_ir::env::set(name, value),
        None => rlx_ir::env::unset(name),
    }));

    f()
}
/// Maps a requested `device` to the device packed-GGUF graphs are run on.
///
/// `Mlx`, `Gpu`, `Cuda`, `Rocm` and `Vulkan` are redirected to `Device::Cpu`;
/// every other device (including `Cpu` and `Metal`) is returned unchanged.
pub fn packed_gguf_execution_device(device: Device) -> Device {
    let falls_back_to_cpu = matches!(
        device,
        Device::Mlx | Device::Gpu | Device::Cuda | Device::Rocm | Device::Vulkan
    );
    if falls_back_to_cpu {
        Device::Cpu
    } else {
        device
    }
}
/// [`CompileOptions`] for `device` derived from the
/// `CompileProfile::sam_encoder()` preset.
pub fn compile_options_sam_encoder(device: Device) -> CompileOptions {
    let preset = CompileProfile::sam_encoder();
    compile_options_for_profile(&preset, device)
}
/// [`CompileOptions`] for `device` derived from the `CompileProfile::sam3()`
/// preset.
pub fn compile_options_sam3(device: Device) -> CompileOptions {
    let preset = CompileProfile::sam3();
    compile_options_for_profile(&preset, device)
}
/// [`CompileOptions`] for `device` derived from the
/// `CompileProfile::sam2_memory_attention()` preset.
pub fn compile_options_sam2_memory_attention(device: Device) -> CompileOptions {
    let preset = CompileProfile::sam2_memory_attention();
    compile_options_for_profile(&preset, device)
}
/// Compiles `graph` for `device` with options derived from `profile`.
///
/// The options come from [`compile_options_for_profile`]; compilation is done
/// by a fresh `Session` for `device`. As written this always returns `Ok`:
/// the result of `compile_with` is wrapped without any error conversion.
pub fn compile_graph_with_profile(
    device: Device,
    graph: rlx_ir::Graph,
    profile: &CompileProfile,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let opts = compile_options_for_profile(profile, device);
    let compiled = rlx_runtime::Session::new(device).compile_with(graph, &opts);
    Ok(compiled)
}
/// Compiles `graph` for `device` with the `CompileProfile::sam_encoder()`
/// preset.
pub fn compile_graph_sam(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let preset = CompileProfile::sam_encoder();
    compile_graph_with_profile(device, graph, &preset)
}
/// Compiles `graph` for `device` with the `CompileProfile::encoder()` preset.
pub fn compile_graph_encoder(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let preset = CompileProfile::encoder();
    compile_graph_with_profile(device, graph, &preset)
}
/// Compiles `graph` for `device` with the `CompileProfile::qwen3_prefill()`
/// preset.
pub fn compile_graph_qwen3_prefill(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let preset = CompileProfile::qwen3_prefill();
    compile_graph_with_profile(device, graph, &preset)
}
/// Compiles `graph` for `device` with the `CompileProfile::qwen3_decode()`
/// preset.
pub fn compile_graph_qwen3_decode(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let preset = CompileProfile::qwen3_decode();
    compile_graph_with_profile(device, graph, &preset)
}
/// Compiles `graph` for `device` with the `CompileProfile::qwen35_prefill()`
/// preset.
pub fn compile_graph_qwen35_prefill(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let preset = CompileProfile::qwen35_prefill();
    compile_graph_with_profile(device, graph, &preset)
}
/// Compiles `graph` for `device` with the `CompileProfile::qwen35_decode()`
/// preset.
pub fn compile_graph_qwen35_decode(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let preset = CompileProfile::qwen35_decode();
    compile_graph_with_profile(device, graph, &preset)
}
/// Compiles `graph` for `device` with the `CompileProfile::gemma_prefill()`
/// preset.
pub fn compile_graph_gemma_prefill(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let preset = CompileProfile::gemma_prefill();
    compile_graph_with_profile(device, graph, &preset)
}
/// Compiles `graph` for `device` with the `CompileProfile::gemma_decode()`
/// preset.
pub fn compile_graph_gemma_decode(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let preset = CompileProfile::gemma_decode();
    compile_graph_with_profile(device, graph, &preset)
}
/// Compiles `graph` for `device` with the `CompileProfile::llama32_prefill()`
/// preset.
pub fn compile_graph_llama32_prefill(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let preset = CompileProfile::llama32_prefill();
    compile_graph_with_profile(device, graph, &preset)
}
/// Compiles `graph` for `device` with the `CompileProfile::llama32_decode()`
/// preset.
pub fn compile_graph_llama32_decode(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let preset = CompileProfile::llama32_decode();
    compile_graph_with_profile(device, graph, &preset)
}
/// Compiles `graph` for `device` with plain `CompileOptions::new()` — no
/// compile profile is applied.
///
/// As written this always returns `Ok`; the result of `compile_with` is
/// wrapped without any error conversion.
pub fn compile_graph_legacy(
    device: Device,
    graph: rlx_ir::Graph,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let compiled = rlx_runtime::Session::new(device)
        .compile_with(graph, &rlx_runtime::CompileOptions::new());
    Ok(compiled)
}
/// Compiles the HIR module `hir` for `device` with the
/// `CompileProfile::sam_encoder()` preset.
pub fn compile_hir_sam(
    device: Device,
    hir: rlx_ir::hir::HirModule,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let preset = CompileProfile::sam_encoder();
    compile_hir_with_profile(device, hir, &preset)
}
/// Compiles the HIR module `hir` for `device` with the
/// `CompileProfile::sam3()` preset.
pub fn compile_hir_sam3(
    device: Device,
    hir: rlx_ir::hir::HirModule,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let preset = CompileProfile::sam3();
    compile_hir_with_profile(device, hir, &preset)
}
/// Compiles the HIR module `hir` for `device` with options derived from
/// `profile`.
///
/// # Errors
///
/// Propagates any error returned by `Session::compile_hir_with`, converted
/// into an `anyhow::Error`.
pub fn compile_hir_with_profile(
    device: Device,
    hir: rlx_ir::hir::HirModule,
    profile: &CompileProfile,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let opts = compile_options_for_profile(profile, device);
    let compiled = rlx_runtime::Session::new(device).compile_hir_with(hir, &opts)?;
    Ok(compiled)
}
pub fn compile_options_for(config: &ModelExecutionConfig) -> CompileOptions {
compile_options_from_profile(
&config.compile_profile(),
Device::Cpu,
config.component().kernel_dispatch,
)
.dim_binding(config.dim_binding())
}
/// Builds [`CompileOptions`] for `config` targeting `device`.
///
/// Uses the config's compile profile and its component's kernel-dispatch
/// settings, then attaches the config's dimension binding.
pub fn compile_options_for_device(config: &ModelExecutionConfig, device: Device) -> CompileOptions {
    let profile = config.compile_profile();
    let dispatch = config.component().kernel_dispatch;
    let base = compile_options_from_profile(&profile, device, dispatch);
    base.dim_binding(config.dim_binding())
}
/// Compiles a [`BuiltModel`] under `config` and binds its parameters to the
/// resulting graph.
///
/// Steps:
/// 1. Split `built` into its HIR module and parameter list.
/// 2. If `pipeline` has no entry for the config's cache key, compile one.
/// 3. When the pipeline's device is `Device::Cpu`, clone the pipeline's entry;
///    for any other device, compile the HIR directly with a fresh `Session`.
/// 4. Call `set_param` on the compiled graph for every `(name, data)` pair.
///
/// NOTE(review): on CPU step 2 looks redundant with the `get_or_compile` call
/// in step 3, and on other devices the pipeline entry is populated but the
/// returned graph comes from a second, separate compile — confirm both are
/// intended. Whatever `set_param` returns is discarded.
///
/// # Errors
///
/// Propagates failures from `BuiltModel::into_parts`,
/// `ModelCompilePipeline::get_or_compile` and `Session::compile_hir_with`.
pub fn compile_built_with_config(
    pipeline: &mut ModelCompilePipeline,
    built: BuiltModel,
    config: &ModelExecutionConfig,
    options: &CompileOptions,
) -> anyhow::Result<rlx_runtime::CompiledGraph> {
    let cache_key = config.cache_key();
    let dims = config.dim_binding();
    let target = pipeline.device();
    let (hir, params) = built.into_parts()?;

    // Warm the pipeline cache for this key if it has not been compiled yet.
    if !pipeline.contains(cache_key) {
        pipeline.get_or_compile(cache_key, &dims, || hir.clone(), options)?;
    }

    let mut compiled = if target != Device::Cpu {
        // Non-CPU devices get their own compile through a fresh session.
        Session::new(target).compile_hir_with(hir, options)?
    } else {
        pipeline
            .get_or_compile(cache_key, &dims, || hir.clone(), options)?
            .clone()
    };

    for (param_name, values) in params {
        compiled.set_param(&param_name, &values);
    }
    Ok(compiled)
}
/// Translates a profile-level [`FusionTargetKind`] into an optimizer
/// [`FusionTarget`].
///
/// `FusionTargetKind::Auto` yields `None`, leaving the choice to the caller;
/// every other kind maps to the same-named `FusionTarget` variant.
fn fusion_target_from_profile(kind: FusionTargetKind) -> Option<FusionTarget> {
    let target = match kind {
        FusionTargetKind::Auto => return None,
        FusionTargetKind::Cpu => FusionTarget::Cpu,
        FusionTargetKind::Metal => FusionTarget::Metal,
        FusionTargetKind::Mlx => FusionTarget::Mlx,
        FusionTargetKind::Cuda => FusionTarget::Cuda,
        FusionTargetKind::Rocm => FusionTarget::Rocm,
        FusionTargetKind::Wgpu => FusionTarget::Wgpu,
        FusionTargetKind::Tpu => FusionTarget::Tpu,
    };
    Some(target)
}