use crate::Device;
use rlx_ir::{Graph, Node, Op};
/// Performance characteristics of one execution backend.
///
/// `estimate_graph_cost` and `pick_best_device` consume implementations of
/// this trait to compare how expensive a graph is on different devices.
pub trait BackendCostModel: Send + Sync {
/// The device this model describes.
fn device(&self) -> Device;
/// Matrix-multiply throughput for an `m x k` by `k x n` product.
///
/// `node_cost` divides a FLOP count by this value (plus 1.0) and adds the
/// quotient to a nanosecond overhead, so the value is presumably expressed
/// in GFLOP/s as the name says — confirm when adding a backend.
fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64;
/// Fixed overhead that `node_cost` charges for every node other than
/// inputs, parameters and constants (nanoseconds, per the name).
fn dispatch_overhead_ns(&self) -> f64;
/// Fixed overhead that `estimate_graph_cost` charges once per graph
/// (nanoseconds, per the name).
fn roundtrip_overhead_ns(&self) -> f64;
/// Memory bandwidth; `node_cost` divides an output byte count by this value
/// for nodes that are not matmul or attention.
/// Presumably GB/s (i.e. bytes per nanosecond) — TODO confirm.
fn memory_bw(&self) -> f64;
/// Thread count reported by the backend.
/// NOTE(review): `estimate_graph_cost` and `node_cost` do not read this.
fn num_threads(&self) -> usize;
}
/// Estimates the total cost of running `graph` on the backend described by
/// `model`.
///
/// The estimate starts from the backend's one-off round-trip overhead and adds
/// the `node_cost` of every node yielded by `graph.nodes()`, accumulating in
/// iteration order.
pub fn estimate_graph_cost(graph: &Graph, model: &dyn BackendCostModel) -> f64 {
    // Read the fixed overhead first so it seeds the accumulation.
    let roundtrip = model.roundtrip_overhead_ns();
    graph
        .nodes()
        .into_iter()
        .fold(roundtrip, |acc, node| acc + node_cost(node, graph, model))
}
/// Estimates the cost of executing one `node` of `graph` on the backend
/// described by `model`.
///
/// * Inputs, parameters and constants cost nothing.
/// * MatMul-like nodes are costed as `2 * m * k * n` FLOPs divided by the
///   backend's reported SGEMM throughput, plus the dispatch overhead.
/// * Attention nodes are costed as `2 * batch * heads * seq^2 * head_dim`
///   FLOPs, using the throughput reported for a `seq x head_dim x seq` matmul.
/// * Every other node is costed by its output size (4 bytes per element)
///   divided by the memory bandwidth, plus the dispatch overhead.
///
/// The result is on the same scale as `dispatch_overhead_ns`; this is
/// consistent if `sgemm_gflops` is in GFLOP/s and `memory_bw` in GB/s
/// (presumed from the names — confirm).
///
/// # Panics
///
/// Indexes `node.inputs[0]` for matmul and attention nodes, and subtracts from
/// `rank()` without a check, so malformed nodes (no inputs, too-low rank) are
/// not handled gracefully. `unwrap_static()` presumably panics on a dynamic
/// dimension — TODO confirm.
fn node_cost(node: &Node, graph: &Graph, model: &dyn BackendCostModel) -> f64 {
    let dispatch = model.dispatch_overhead_ns();
    match &node.op {
        // Nothing is executed for graph inputs, parameters or constants.
        Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0.0,
        Op::MatMul | Op::FusedMatMulBiasAct { .. } => {
            // n is the last output dimension; all leading dimensions fold into m.
            let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
            let out_elems = node.shape.num_elements().unwrap_or(0);
            let m = out_elems / n.max(1);
            // k is recovered from the element count of the first operand (m * k).
            let a_elems = graph.node(node.inputs[0]).shape.num_elements().unwrap_or(0);
            let k = a_elems / m.max(1);
            let flops = 2.0 * m as f64 * k as f64 * n as f64;
            // The `+ 1.0` keeps the division finite when a model reports zero throughput.
            flops / (model.sgemm_gflops(m, k, n) + 1.0) + dispatch
        }
        Op::Attention {
            num_heads,
            head_dim,
            ..
        } => {
            let (heads, dim) = (*num_heads, *head_dim);
            let q_shape = &graph.node(node.inputs[0]).shape;
            // Sequence length is the second-to-last dimension of the first input.
            let seq = q_shape.dim(q_shape.rank() - 2).unwrap_static();
            let batch = q_shape.num_elements().unwrap_or(0) / (seq * heads * dim).max(1);
            // Multiply in f64, like the matmul arm: the integer product
            // `batch * heads * seq * seq * dim * 2` overflows `usize` on 32-bit
            // targets for realistic shapes (panic in debug, wrap in release).
            let flops =
                2.0 * batch as f64 * heads as f64 * seq as f64 * seq as f64 * dim as f64;
            flops / (model.sgemm_gflops(seq, dim, seq) + 1.0) + dispatch
        }
        _ => {
            // Bandwidth-bound estimate from the output size only; counts
            // 4 bytes per element (presumably f32 — confirm).
            let bytes = node.shape.num_elements().unwrap_or(0) * 4;
            (bytes as f64) / model.memory_bw().max(1.0) + dispatch
        }
    }
}
/// Returns the device of the model with the lowest estimated cost for `graph`.
///
/// Models are evaluated in slice order and a later model only replaces the
/// current choice when its cost is strictly lower, so ties keep the earlier
/// entry. If `models` is empty, or no model produces a cost below
/// `f64::INFINITY` (including NaN costs), the result is `Device::Cpu`.
pub fn pick_best_device(graph: &Graph, models: &[&dyn BackendCostModel]) -> Device {
    let (device, _cost) = models.iter().fold(
        (Device::Cpu, f64::INFINITY),
        |current, candidate| {
            let cost = estimate_graph_cost(graph, *candidate);
            if cost < current.1 {
                (candidate.device(), cost)
            } else {
                current
            }
        },
    );
    device
}
/// Cost model for the CPU backend, wrapping an `rlx_cpu::cost::HwModel`.
///
/// Only compiled when the `cpu` feature is enabled.
#[cfg(feature = "cpu")]
pub struct CpuCostModel(rlx_cpu::cost::HwModel);
#[cfg(feature = "cpu")]
impl CpuCostModel {
    /// Builds the model by passing `RuntimeConfig::global()` to
    /// `HwModel::from_config`.
    pub fn new() -> Self {
        Self(rlx_cpu::cost::HwModel::from_config(
            rlx_cpu::config::RuntimeConfig::global(),
        ))
    }
}
#[cfg(feature = "cpu")]
impl Default for CpuCostModel {
/// Equivalent to `CpuCostModel::new()`.
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "cpu")]
impl BackendCostModel for CpuCostModel {
    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Throughput of the faster of the two CPU matmul paths for this problem.
    ///
    /// Computes the time each path would need for `2 * m * k * n` FLOPs
    /// (rates are clamped to at least 1.0), takes the shorter one, and
    /// converts back to a rate scaled by `1e9`. Returns `0.0` when the
    /// problem has no work (any dimension is zero).
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        let hw = &self.0;
        let work = 2.0 * m as f64 * k as f64 * n as f64;
        let fastest = (work / hw.neon_flops.max(1.0)).min(work / hw.blas_flops.max(1.0));
        if fastest > 0.0 {
            work / (fastest * 1e9)
        } else {
            0.0
        }
    }

    /// Per-node overhead: the hardware model's BLAS call overhead.
    fn dispatch_overhead_ns(&self) -> f64 {
        self.0.blas_overhead_ns
    }

    /// Per-graph overhead: the hardware model's parallel-for overhead.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.0.par_for_overhead_ns
    }

    fn memory_bw(&self) -> f64 {
        self.0.mem_bw
    }

    fn num_threads(&self) -> usize {
        self.0.num_threads
    }
}
/// Cost model for the Metal backend.
///
/// Only compiled when the `metal` feature is enabled.
#[cfg(feature = "metal")]
pub struct MetalCostModel {
// Returned by `sgemm_gflops` for every problem size. Despite the `_avg`
// suffix, `new` stores the largest of three calibrated rates here.
sgemm_gflops_avg: f64,
// Returned by `roundtrip_overhead_ns`; taken from calibration in `new`.
roundtrip_ns: f64,
// Returned by `memory_bw`; `new` hard-codes it to 200.0.
memory_bw: f64,
}
#[cfg(feature = "metal")]
impl MetalCostModel {
    /// Builds the model from `Calibration::load_or_measure()`.
    ///
    /// The SGEMM rate is the largest of the three calibrated kernel rates and
    /// the round-trip overhead is taken from calibration. The memory bandwidth
    /// is a fixed `200.0` rather than a calibrated value.
    ///
    /// NOTE(review): the calibration fields are named `*_flops` but end up
    /// being returned from `sgemm_gflops` — confirm they are already in GFLOP/s.
    pub fn new() -> Self {
        /// Bandwidth figure used for every Metal device.
        const ASSUMED_MEMORY_BW: f64 = 200.0;

        let cal = rlx_metal::calibrate::Calibration::load_or_measure();
        let simd_best = cal.sgemm_simd_4x4_flops.max(cal.sgemm_simd_flops);
        let fastest_kernel = simd_best.max(cal.sgemm_padded_flops);
        Self {
            sgemm_gflops_avg: fastest_kernel,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: ASSUMED_MEMORY_BW,
        }
    }
}
#[cfg(feature = "metal")]
impl Default for MetalCostModel {
/// Equivalent to `MetalCostModel::new()`.
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "metal")]
impl BackendCostModel for MetalCostModel {
    fn device(&self) -> Device {
        Device::Metal
    }

    /// Returns the single stored rate; the problem dimensions are ignored.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops_avg
    }

    /// Constant per-node overhead; not derived from calibration.
    fn dispatch_overhead_ns(&self) -> f64 {
        const DISPATCH_NS: f64 = 2_000.0;
        DISPATCH_NS
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Always reports a single thread.
    fn num_threads(&self) -> usize {
        1
    }
}
/// Cost model for the MLX backend.
///
/// Only compiled when the `mlx` feature is enabled and the target OS is macOS.
#[cfg(all(feature = "mlx", target_os = "macos"))]
pub struct MlxCostModel {
// Rate returned by `sgemm_gflops` when `m * k * n` is at least 32_768.
sgemm_large_flops: f64,
// Rate returned by `sgemm_gflops` when `m * k * n` is below 32_768.
sgemm_small_flops: f64,
// Returned by `roundtrip_overhead_ns`; taken from calibration in `new`.
roundtrip_ns: f64,
// Returned by `memory_bw`; calibrated value, or 200.0 when calibration
// reports a non-positive bandwidth.
memory_bw: f64,
}
#[cfg(all(feature = "mlx", target_os = "macos"))]
impl MlxCostModel {
    /// Builds the model from `Calibration::load_or_measure()`.
    ///
    /// The calibrated memory bandwidth is used only when it is strictly
    /// positive; otherwise (zero, negative or NaN) a fallback of `200.0` is
    /// substituted.
    ///
    /// NOTE(review): the calibration fields are named `*_flops` but end up
    /// being returned from `sgemm_gflops` — confirm they are already in GFLOP/s.
    pub fn new() -> Self {
        /// Bandwidth used when calibration has no usable measurement.
        const FALLBACK_MEMORY_BW_GBPS: f64 = 200.0;

        let cal = rlx_mlx::calibrate::Calibration::load_or_measure();
        let memory_bw = match cal.memory_bw_gbps {
            measured if measured > 0.0 => measured,
            _ => FALLBACK_MEMORY_BW_GBPS,
        };
        Self {
            sgemm_large_flops: cal.sgemm_large_flops,
            sgemm_small_flops: cal.sgemm_small_flops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw,
        }
    }
}
#[cfg(all(feature = "mlx", target_os = "macos"))]
impl Default for MlxCostModel {
/// Equivalent to `MlxCostModel::new()`.
fn default() -> Self {
Self::new()
}
}
#[cfg(all(feature = "mlx", target_os = "macos"))]
impl BackendCostModel for MlxCostModel {
    fn device(&self) -> Device {
        Device::Mlx
    }

    /// Selects between the two calibrated rates by problem volume.
    ///
    /// Problems with `m * k * n` of at least 32_768 (32 cubed) use the "large"
    /// rate; anything smaller uses the "small" rate.
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        const LARGE_PROBLEM_VOLUME: f64 = 32_768.0;

        let volume = m as f64 * k as f64 * n as f64;
        if volume >= LARGE_PROBLEM_VOLUME {
            self.sgemm_large_flops
        } else {
            self.sgemm_small_flops
        }
    }

    /// Constant per-node overhead; not derived from calibration.
    fn dispatch_overhead_ns(&self) -> f64 {
        const DISPATCH_NS: f64 = 2_000.0;
        DISPATCH_NS
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Always reports a single thread.
    fn num_threads(&self) -> usize {
        1
    }
}