Skip to main content

rlx_runtime/
cost.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Cross-backend cost interface.
17//!
18//! Each backend implements `BackendCostModel` to expose its execution
19//! characteristics (kernel throughput, dispatch overhead, memory bw).
20//! The runtime can then estimate the cost of running a graph on each
21//! available backend and pick the fastest.
22//!
23//! This is what enables "auto device" — given a graph, pick CPU or
24//! Metal automatically based on which is faster for THIS workload on
25//! THIS hardware.
26
27use crate::Device;
28use rlx_ir::{Graph, Node, Op};
29
30/// Hardware-aware cost characteristics for a backend on the current machine.
31pub trait BackendCostModel: Send + Sync {
32    /// Identify which device this model is for.
33    fn device(&self) -> Device;
34
35    /// Effective f32 sgemm throughput in GFLOP/s for the most-used kernel
36    /// path at the given dimensions. Backends should return their best
37    /// sustained rate (not peak).
38    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64;
39
40    /// Cost to dispatch one kernel (function call, BLAS setup, etc.) in ns.
41    fn dispatch_overhead_ns(&self) -> f64;
42
43    /// Cost to commit + wait for a command buffer / forward pass in ns.
44    /// Roughly amortized per-forward overhead independent of kernel count.
45    fn roundtrip_overhead_ns(&self) -> f64;
46
47    /// Memory bandwidth in bytes/ns (== GB/s).
48    fn memory_bw(&self) -> f64;
49
50    /// Number of compute threads available.
51    fn num_threads(&self) -> usize;
52}
53
54/// Estimate forward-pass time (ns) for a graph on the given backend.
55/// Uses node-level cost contributions; conservative — actual time may
56/// be lower due to hardware parallelism we don't model.
57pub fn estimate_graph_cost(graph: &Graph, model: &dyn BackendCostModel) -> f64 {
58    let mut total = model.roundtrip_overhead_ns();
59    for node in graph.nodes() {
60        total += node_cost(node, graph, model);
61    }
62    total
63}
64
65fn node_cost(node: &Node, graph: &Graph, model: &dyn BackendCostModel) -> f64 {
66    let dispatch = model.dispatch_overhead_ns();
67    match &node.op {
68        Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0.0,
69        Op::MatMul | Op::FusedMatMulBiasAct { .. } => {
70            let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
71            let total = node.shape.num_elements().unwrap_or(0);
72            let m = total / n.max(1);
73            let a_total = graph.node(node.inputs[0]).shape.num_elements().unwrap_or(0);
74            let k = a_total / m.max(1);
75            let flops = 2.0 * m as f64 * k as f64 * n as f64;
76            flops / (model.sgemm_gflops(m, k, n) + 1.0) + dispatch
77        }
78        Op::Attention {
79            num_heads,
80            head_dim,
81            ..
82        } => {
83            let q_shape = &graph.node(node.inputs[0]).shape;
84            let seq = q_shape.dim(q_shape.rank() - 2).unwrap_static();
85            let batch = q_shape.num_elements().unwrap_or(0) / (seq * num_heads * head_dim).max(1);
86            let flops = (batch * num_heads * seq * seq * head_dim * 2) as f64;
87            flops / (model.sgemm_gflops(seq, *head_dim, seq) + 1.0) + dispatch
88        }
89        // Element-wise + small ops: bounded by memory bandwidth.
90        _ => {
91            let bytes = node.shape.num_elements().unwrap_or(0) * 4;
92            (bytes as f64) / model.memory_bw().max(1.0) + dispatch
93        }
94    }
95}
96
97/// Pick the device with the lowest predicted cost for this graph.
98pub fn pick_best_device(graph: &Graph, models: &[&dyn BackendCostModel]) -> Device {
99    let mut best = (Device::Cpu, f64::INFINITY);
100    for &m in models {
101        let cost = estimate_graph_cost(graph, m);
102        if cost < best.1 {
103            best = (m.device(), cost);
104        }
105    }
106    best.0
107}
108
109// ── Backend adapters (plan #29) ─────────────────────────────────
110//
111// The CPU and Metal crates own their own internal cost models for
112// kernel-selection decisions. These thin adapters wrap them in
113// `BackendCostModel` so `pick_best_device` can compare both with a
114// single uniform interface.
115
/// CPU adapter: exposes `rlx_cpu::cost::HwModel` through the
/// `BackendCostModel` interface.
#[cfg(feature = "cpu")]
pub struct CpuCostModel(rlx_cpu::cost::HwModel);

#[cfg(feature = "cpu")]
impl CpuCostModel {
    /// Build the adapter from `rlx_cpu::config::RuntimeConfig::global()`.
    pub fn new() -> Self {
        Self(rlx_cpu::cost::HwModel::from_config(
            rlx_cpu::config::RuntimeConfig::global(),
        ))
    }
}

#[cfg(feature = "cpu")]
impl Default for CpuCostModel {
    /// Same as `CpuCostModel::new()`.
    fn default() -> Self {
        CpuCostModel::new()
    }
}
134
#[cfg(feature = "cpu")]
impl BackendCostModel for CpuCostModel {
    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Faster of the NEON and BLAS sgemm paths, in GFLOP/s.
    ///
    /// `HwModel` exposes a single rate per path, so the result does not
    /// depend on the GEMM's shape. The one exception is an empty GEMM (any
    /// dimension zero), which reports `0.0`. Each rate is floored at `1.0`,
    /// and the best one is divided by `1e9`, which treats the rates as
    /// FLOP/s.
    ///
    /// NOTE(review): confirm the unit of `neon_flops` / `blas_flops` in
    /// `rlx_cpu::cost::HwModel`.
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        if m == 0 || k == 0 || n == 0 {
            return 0.0;
        }
        // Equivalent to `flops / (min(flops / neon, flops / blas) * 1e9)`:
        // the FLOP count cancels, leaving just the faster rate.
        let best_flops_per_s = self.0.neon_flops.max(1.0).max(self.0.blas_flops.max(1.0));
        best_flops_per_s / 1e9
    }

    /// Per-kernel overhead: `HwModel::blas_overhead_ns`.
    fn dispatch_overhead_ns(&self) -> f64 {
        self.0.blas_overhead_ns
    }

    /// Per-pass overhead: `HwModel::par_for_overhead_ns`.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.0.par_for_overhead_ns
    }

    fn memory_bw(&self) -> f64 {
        self.0.mem_bw
    }

    fn num_threads(&self) -> usize {
        self.0.num_threads
    }
}
165
/// Metal adapter for `BackendCostModel`, populated from
/// `rlx_metal::calibrate::Calibration::load_or_measure()` so the figures
/// reflect what this machine measured (the calibration is presumably read
/// from an on-disk cache when one exists — see `load_or_measure`).
#[cfg(feature = "metal")]
pub struct MetalCostModel {
    /// Fastest calibrated sgemm rate. Despite the name, this is the maximum
    /// of the calibrated kernels, not an average.
    ///
    /// NOTE(review): assumes the calibration `*_flops` fields are already in
    /// GFLOP/s (the CPU adapter divides its rates by 1e9); confirm that the
    /// units agree across backends.
    sgemm_gflops_avg: f64,
    /// Calibrated per-pass round-trip overhead, in ns.
    roundtrip_ns: f64,
    /// Assumed memory bandwidth, in GB/s (== bytes/ns).
    memory_bw: f64,
}

#[cfg(feature = "metal")]
impl MetalCostModel {
    /// Load (or run) the Metal calibration and keep the figures the cost
    /// model needs.
    pub fn new() -> Self {
        let cal = rlx_metal::calibrate::Calibration::load_or_measure();
        // Fastest of the three calibrated sgemm kernels.
        let fastest_sgemm = f64::max(
            f64::max(cal.sgemm_simd_4x4_flops, cal.sgemm_simd_flops),
            cal.sgemm_padded_flops,
        );
        // The calibrator does not measure raw memory bandwidth yet. 200 GB/s
        // is a rough figure for base M-series unified memory (Pro/Max parts
        // are much faster); it mainly keeps memory-bound ops from looking free.
        let assumed_memory_bw = 200.0;
        Self {
            sgemm_gflops_avg: fastest_sgemm,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: assumed_memory_bw,
        }
    }
}

#[cfg(feature = "metal")]
impl Default for MetalCostModel {
    /// Same as `MetalCostModel::new()`.
    fn default() -> Self {
        MetalCostModel::new()
    }
}
203
#[cfg(feature = "metal")]
impl BackendCostModel for MetalCostModel {
    fn device(&self) -> Device {
        Device::Metal
    }

    /// Shape-independent: always the stored calibrated sgemm rate.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops_avg
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Fixed per-kernel encode estimate, small relative to the round-trip.
    fn dispatch_overhead_ns(&self) -> f64 {
        const ENCODE_NS: f64 = 2_000.0;
        ENCODE_NS
    }

    /// The calibrated round-trip overhead.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    /// Work is submitted through a single command queue.
    fn num_threads(&self) -> usize {
        1
    }
}
226
/// MLX adapter for `BackendCostModel`, populated from
/// `rlx_mlx::calibrate::Calibration::load_or_measure()`. On a fresh machine,
/// the first construction pays a one-time measurement cost (tens of ms);
/// later constructions read the on-disk calibration cache.
#[cfg(all(feature = "mlx", target_os = "macos"))]
pub struct MlxCostModel {
    /// Calibrated sgemm rate used for large GEMMs.
    sgemm_large_flops: f64,
    /// Calibrated sgemm rate used for small GEMMs.
    sgemm_small_flops: f64,
    /// Calibrated per-pass round-trip overhead, in ns.
    roundtrip_ns: f64,
    /// Memory bandwidth, in GB/s (== bytes/ns).
    memory_bw: f64,
}

#[cfg(all(feature = "mlx", target_os = "macos"))]
impl MlxCostModel {
    /// Load (or run) the MLX calibration and keep the figures the cost model
    /// needs.
    pub fn new() -> Self {
        let cal = rlx_mlx::calibrate::Calibration::load_or_measure();
        Self {
            sgemm_large_flops: cal.sgemm_large_flops,
            sgemm_small_flops: cal.sgemm_small_flops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            // Newer (post-PR16) calibrators record measured bandwidth. When
            // the recorded value is not positive (old caches), fall back to
            // the ~200 GB/s Apple-Silicon unified-memory floor so those
            // caches still produce sane numbers.
            memory_bw: Some(cal.memory_bw_gbps)
                .filter(|&gbps| gbps > 0.0)
                .unwrap_or(200.0),
        }
    }
}

#[cfg(all(feature = "mlx", target_os = "macos"))]
impl Default for MlxCostModel {
    /// Same as `MlxCostModel::new()`.
    fn default() -> Self {
        MlxCostModel::new()
    }
}
267
#[cfg(all(feature = "mlx", target_os = "macos"))]
impl BackendCostModel for MlxCostModel {
    fn device(&self) -> Device {
        Device::Mlx
    }

    /// Two-rate crossover. GEMMs whose `m·k·n` is below 32 768 pay the
    /// per-op overhead, so they get the calibrated "small" rate; larger ones
    /// hit the optimized path and get the "large" rate. The cutoff roughly
    /// matches the calibrator's small/large probe sizes.
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        const SMALL_GEMM_VOLUME: f64 = 32_768.0;
        let volume = m as f64 * k as f64 * n as f64;
        if volume >= SMALL_GEMM_VOLUME {
            self.sgemm_large_flops
        } else {
            self.sgemm_small_flops
        }
    }

    /// MLX's lazy evaluation keeps per-op encode cost low; building the
    /// trace in Rust is the dominant per-op cost.
    fn dispatch_overhead_ns(&self) -> f64 {
        const TRACE_OP_NS: f64 = 2_000.0;
        TRACE_OP_NS
    }

    /// The calibrated round-trip overhead.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    fn num_threads(&self) -> usize {
        1
    }
}