use rlx_ir::hir::FusionPolicy;
use serde::{Deserialize, Serialize};
/// Complete compilation configuration: fusion, optimization passes, numeric
/// precision and per-backend overrides.
///
/// Deserializable from TOML (see [`CompileProfile::from_toml_str`]); top-level
/// sections omitted from the input fall back to their defaults.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(default)]
pub struct CompileProfile {
    /// Operator-fusion settings (`[fusion]`).
    pub fusion: FusionProfile,
    /// Optimization-pass toggles (`[passes]`).
    pub passes: PassProfile,
    /// Compute / mixed-precision settings (`[precision]`).
    pub precision: PrecisionProfile,
    /// Backend-specific overrides (`[backend]`).
    #[serde(default)]
    pub backend: BackendOverrides,
}

impl Default for CompileProfile {
    /// The default profile is the LLaMA 3.2 prefill profile.
    fn default() -> Self {
        CompileProfile::llama32_prefill()
    }
}
impl CompileProfile {
pub fn llama32_prefill() -> Self {
Self {
fusion: FusionProfile {
policy: FusionPolicyKind::Direct,
target: FusionTargetKind::Auto,
assert_clean: false,
skip: false,
},
passes: PassProfile::default(),
precision: PrecisionProfile::default(),
backend: BackendOverrides::default(),
}
}
pub fn llama32_decode() -> Self {
Self {
fusion: FusionProfile {
policy: FusionPolicyKind::Fusable,
..FusionProfile::default()
},
..Self::llama32_prefill()
}
}
pub fn qwen35_prefill() -> Self {
Self::llama32_prefill()
}
pub fn qwen35_decode() -> Self {
Self::llama32_decode()
}
pub fn qwen3_prefill() -> Self {
Self::llama32_prefill()
}
pub fn qwen3_decode() -> Self {
Self::llama32_decode()
}
pub fn gemma_prefill() -> Self {
Self::llama32_prefill()
}
pub fn gemma_decode() -> Self {
Self::llama32_decode()
}
pub fn flux2() -> Self {
Self::encoder()
}
pub fn sam_encoder() -> Self {
Self::encoder()
}
pub fn sam3() -> Self {
Self::sam_encoder()
}
pub fn sam2() -> Self {
Self::sam_encoder()
}
pub fn sam2_memory_attention() -> Self {
Self {
fusion: FusionProfile {
skip: true,
..FusionProfile::default()
},
..Self::encoder()
}
}
pub fn llada2_diffusion() -> Self {
Self {
fusion: FusionProfile {
skip: true,
..FusionProfile::default()
},
..Self::encoder()
}
}
pub fn encoder() -> Self {
Self {
fusion: FusionProfile {
policy: FusionPolicyKind::Direct,
..FusionProfile::default()
},
passes: PassProfile {
dce: true,
constant_folding: true,
verbose: false,
},
precision: PrecisionProfile::default(),
backend: BackendOverrides::default(),
}
}
pub fn fusion_policy(&self) -> FusionPolicy {
self.fusion.policy.into()
}
pub fn from_toml_str(s: &str) -> anyhow::Result<Self> {
Ok(toml::from_str(s)?)
}
pub fn from_toml_path(path: &std::path::Path) -> anyhow::Result<Self> {
let data = std::fs::read_to_string(path)?;
Self::from_toml_str(&data)
}
pub fn near_weights(weights: &std::path::Path, family: &str, mode: ProfileMode) -> Self {
let default = Self::default_for(family, mode);
let dir = weights
.parent()
.unwrap_or_else(|| std::path::Path::new("."));
let sidecar = dir.join(format!("{family}.rlx.toml"));
Self::from_toml_path(&sidecar).unwrap_or(default)
}
pub fn default_for(family: &str, mode: ProfileMode) -> Self {
match (family, mode) {
("llama32", ProfileMode::Prefill) => Self::llama32_prefill(),
("llama32", ProfileMode::Decode) => Self::llama32_decode(),
("qwen3", ProfileMode::Prefill) => Self::qwen3_prefill(),
("qwen3", ProfileMode::Decode) => Self::qwen3_decode(),
("qwen35", ProfileMode::Prefill) => Self::qwen35_prefill(),
("qwen35", ProfileMode::Decode) => Self::qwen35_decode(),
("gemma", ProfileMode::Prefill) => Self::gemma_prefill(),
("gemma", ProfileMode::Decode) => Self::gemma_decode(),
(_, ProfileMode::Prefill) => Self::llama32_prefill(),
(_, ProfileMode::Decode) => Self::llama32_decode(),
(_, ProfileMode::Encoder) => Self::encoder(),
}
}
}
/// Execution phase used to pick a built-in profile in `CompileProfile::default_for`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProfileMode {
    /// Selects the family's `*_prefill` profile.
    Prefill,
    /// Selects the family's `*_decode` profile.
    Decode,
    /// Selects the generic encoder profile, regardless of family.
    Encoder,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(default)]
pub struct FusionProfile {
pub policy: FusionPolicyKind,
pub target: FusionTargetKind,
pub assert_clean: bool,
pub skip: bool,
}
impl Default for FusionProfile {
fn default() -> Self {
Self {
policy: FusionPolicyKind::Direct,
target: FusionTargetKind::Auto,
assert_clean: false,
skip: false,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum FusionPolicyKind {
#[default]
Direct,
Fusable,
}
impl From<FusionPolicyKind> for FusionPolicy {
fn from(k: FusionPolicyKind) -> Self {
match k {
FusionPolicyKind::Direct => FusionPolicy::Direct,
FusionPolicyKind::Fusable => FusionPolicy::Fusable,
}
}
}
/// Backend a fusion pass is configured for.
///
/// Serialized in lowercase (`"auto"`, `"cpu"`, `"metal"`, `"mlx"`, `"cuda"`,
/// `"rocm"`, `"wgpu"`, `"tpu"`). Defaults to `Auto`.
// Variant order is kept stable on purpose: index-based serializers depend on it.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum FusionTargetKind {
    /// `"auto"`; the default.
    #[default]
    Auto,
    /// `"cpu"`.
    Cpu,
    /// `"metal"`.
    Metal,
    /// `"mlx"`.
    Mlx,
    /// `"cuda"`.
    Cuda,
    /// `"rocm"`.
    Rocm,
    /// `"wgpu"`.
    Wgpu,
    /// `"tpu"`.
    Tpu,
}
/// Optimization-pass toggles (`[passes]` in TOML).
///
/// Omitted fields take the values from `PassProfile::default()`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(default)]
pub struct PassProfile {
    /// Run dead-code elimination.
    pub dce: bool,
    /// Run constant folding.
    pub constant_folding: bool,
    /// Verbose pass output.
    pub verbose: bool,
}

impl Default for PassProfile {
    /// Dead-code elimination and constant folding enabled, verbose disabled.
    fn default() -> Self {
        PassProfile {
            verbose: false,
            dce: true,
            constant_folding: true,
        }
    }
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(default)]
pub struct PrecisionProfile {
pub compute: PrecisionKind,
pub mixed: MixedPrecisionKind,
}
impl Default for PrecisionProfile {
fn default() -> Self {
Self {
compute: PrecisionKind::F32,
mixed: MixedPrecisionKind::None,
}
}
}
/// Floating-point compute precision, spelled `"f32"`, `"f16"` or `"bf16"` in
/// TOML. Defaults to `F32`.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum PrecisionKind {
    /// `"f32"`; the default.
    #[default]
    F32,
    /// `"f16"`.
    F16,
    /// `"bf16"`.
    Bf16,
}
/// Mixed-precision mode, spelled `"none"` or `"auto"` in TOML. Defaults to `None`.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MixedPrecisionKind {
    /// `"none"`; the default.
    #[default]
    None,
    /// `"auto"`.
    Auto,
}
/// Per-backend overrides (`[backend]` in TOML).
///
/// Every sub-table and every field is optional; anything omitted takes its
/// `Default` value (all flags `false`).
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct BackendOverrides {
    /// Metal overrides (`[backend.metal]`).
    #[serde(default)]
    pub metal: MetalBackendProfile,
    /// CPU overrides (`[backend.cpu]`).
    #[serde(default)]
    pub cpu: CpuBackendProfile,
}

/// Metal-specific overrides (`[backend.metal]`).
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
// Container-level default: without it, a partial table such as
// `[backend.metal]` + `skip_fusion = true` fails with
// "missing field `unfuse_regions`". Every other profile section accepts
// partial tables, so this one should too.
#[serde(default)]
pub struct MetalBackendProfile {
    /// Flag `skip_fusion` for the Metal backend.
    pub skip_fusion: bool,
    /// Flag `unfuse_regions` for the Metal backend.
    pub unfuse_regions: bool,
}

/// CPU-specific overrides (`[backend.cpu]`).
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
// Same reason as `MetalBackendProfile`: lets a future field be omitted
// without breaking existing sidecar files.
#[serde(default)]
pub struct CpuBackendProfile {
    /// Flag `unfuse_regions` for the CPU backend.
    pub unfuse_regions: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Profile exercising every section except `[backend]`.
    const PROFILE_TOML: &str = r#"
[fusion]
policy = "direct"
target = "metal"
assert_clean = true
[passes]
dce = true
constant_folding = false
[precision]
compute = "f16"
mixed = "auto"
"#;

    #[test]
    fn parse_profile_toml() {
        let CompileProfile {
            fusion,
            passes,
            precision,
            ..
        } = CompileProfile::from_toml_str(PROFILE_TOML).unwrap();

        // [fusion]
        assert_eq!(fusion.policy, FusionPolicyKind::Direct);
        assert_eq!(fusion.target, FusionTargetKind::Metal);
        assert!(fusion.assert_clean);

        // [passes]
        assert!(!passes.constant_folding);

        // [precision]
        assert_eq!(precision.compute, PrecisionKind::F16);
        assert_eq!(precision.mixed, MixedPrecisionKind::Auto);
    }
}