Skip to main content

rlx_runtime/
cost.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Cross-backend cost interface.
17//!
18//! Each backend implements `BackendCostModel` to expose its execution
19//! characteristics (kernel throughput, dispatch overhead, memory bw).
20//! The runtime can then estimate the cost of running a graph on each
21//! available backend and pick the fastest.
22//!
23//! This is what enables "auto device" — given a graph, pick CPU or
24//! Metal automatically based on which is faster for THIS workload on
25//! THIS hardware.
26
27use crate::Device;
28use rlx_ir::{Graph, Node, Op};
29
30/// Hardware-aware cost characteristics for a backend on the current machine.
31pub trait BackendCostModel: Send + Sync {
32    /// Identify which device this model is for.
33    fn device(&self) -> Device;
34
35    /// Effective f32 sgemm throughput in GFLOP/s for the most-used kernel
36    /// path at the given dimensions. Backends should return their best
37    /// sustained rate (not peak).
38    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64;
39
40    /// Cost to dispatch one kernel (function call, BLAS setup, etc.) in ns.
41    fn dispatch_overhead_ns(&self) -> f64;
42
43    /// Cost to commit + wait for a command buffer / forward pass in ns.
44    /// Roughly amortized per-forward overhead independent of kernel count.
45    fn roundtrip_overhead_ns(&self) -> f64;
46
47    /// Memory bandwidth in bytes/ns (== GB/s).
48    fn memory_bw(&self) -> f64;
49
50    /// Number of compute threads available.
51    fn num_threads(&self) -> usize;
52}
53
54/// Estimate forward-pass time (ns) for a graph on the given backend.
55/// Uses node-level cost contributions; conservative — actual time may
56/// be lower due to hardware parallelism we don't model.
57pub fn estimate_graph_cost(graph: &Graph, model: &dyn BackendCostModel) -> f64 {
58    let mut total = model.roundtrip_overhead_ns();
59    for node in graph.nodes() {
60        total += node_cost(node, graph, model);
61    }
62    total
63}
64
65fn node_cost(node: &Node, graph: &Graph, model: &dyn BackendCostModel) -> f64 {
66    let dispatch = model.dispatch_overhead_ns();
67    match &node.op {
68        Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0.0,
69        Op::MatMul | Op::FusedMatMulBiasAct { .. } => {
70            let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
71            let total = node.shape.num_elements().unwrap_or(0);
72            let m = total / n.max(1);
73            let a_total = graph.node(node.inputs[0]).shape.num_elements().unwrap_or(0);
74            let k = a_total / m.max(1);
75            let flops = 2.0 * m as f64 * k as f64 * n as f64;
76            flops / (model.sgemm_gflops(m, k, n) + 1.0) + dispatch
77        }
78        Op::Attention {
79            num_heads,
80            head_dim,
81            ..
82        } => {
83            let q_shape = &graph.node(node.inputs[0]).shape;
84            let seq = q_shape.dim(q_shape.rank() - 2).unwrap_static();
85            let batch = q_shape.num_elements().unwrap_or(0) / (seq * num_heads * head_dim).max(1);
86            let flops = (batch * num_heads * seq * seq * head_dim * 2) as f64;
87            flops / (model.sgemm_gflops(seq, *head_dim, seq) + 1.0) + dispatch
88        }
89        // Element-wise + small ops: bounded by memory bandwidth.
90        _ => {
91            let bytes = node.shape.num_elements().unwrap_or(0) * 4;
92            (bytes as f64) / model.memory_bw().max(1.0) + dispatch
93        }
94    }
95}
96
97/// Pick the device with the lowest predicted cost for this graph.
98pub fn pick_best_device(graph: &Graph, models: &[&dyn BackendCostModel]) -> Device {
99    let mut best = (Device::Cpu, f64::INFINITY);
100    for &m in models {
101        let cost = estimate_graph_cost(graph, m);
102        if cost < best.1 {
103            best = (m.device(), cost);
104        }
105    }
106    best.0
107}
108
109/// Pick the fastest backend for `graph` on this host.
110pub fn fastest_device_for(graph: &Graph) -> Device {
111    fastest_device_for_with_policy(graph, &crate::device_policy::DevicePolicy::default())
112}
113
114/// Like [`fastest_device_for`] but respects a [`crate::DevicePolicy`] allow-list.
115pub fn fastest_device_for_with_policy(
116    graph: &Graph,
117    policy: &crate::device_policy::DevicePolicy,
118) -> Device {
119    let candidates = crate::device_policy::devices_for_with_policy(graph, policy);
120    if candidates.is_empty() {
121        return crate::device_ext::fastest_among(&policy.apply(crate::available_devices()));
122    }
123
124    #[cfg(feature = "cpu")]
125    let cpu = CpuCostModel::new();
126    #[cfg(feature = "metal")]
127    let metal = MetalCostModel::new();
128    #[cfg(all(feature = "mlx", rlx_mlx_host))]
129    let mlx = MlxCostModel::new();
130    #[cfg(feature = "cuda")]
131    let cuda = CudaCostModel::new();
132    #[cfg(feature = "rocm")]
133    let rocm = RocmCostModel::new();
134    #[cfg(feature = "gpu")]
135    let wgpu = WgpuCostModel::new();
136
137    let mut models: Vec<&dyn BackendCostModel> = Vec::new();
138    #[cfg(feature = "cpu")]
139    if candidates.contains(&Device::Cpu) {
140        models.push(&cpu);
141    }
142    #[cfg(feature = "metal")]
143    if candidates.contains(&Device::Metal) {
144        models.push(&metal);
145    }
146    #[cfg(all(feature = "mlx", rlx_mlx_host))]
147    if candidates.contains(&Device::Mlx) {
148        models.push(&mlx);
149    }
150    #[cfg(feature = "cuda")]
151    if candidates.contains(&Device::Cuda) {
152        models.push(&cuda);
153    }
154    #[cfg(feature = "rocm")]
155    if candidates.contains(&Device::Rocm) {
156        models.push(&rocm);
157    }
158    #[cfg(feature = "gpu")]
159    if candidates.contains(&Device::Gpu) {
160        models.push(&wgpu);
161    }
162
163    if models.len() >= 2 {
164        pick_best_device(graph, &models)
165    } else if let Some(m) = models.first() {
166        m.device()
167    } else {
168        crate::device_ext::fastest_among(&candidates)
169    }
170}
171
// ── Backend adapters (plan #29) ─────────────────────────────────
//
// Each backend crate exposes its own cost or calibration data
// (`rlx_cpu::cost`, `rlx_{metal,mlx,cuda,rocm,wgpu}::calibrate`). These
// thin adapters wrap that data in `BackendCostModel` so
// `pick_best_device` can compare every compiled-in backend through a
// single uniform interface.
178
/// [`BackendCostModel`] adapter over the CPU crate's
/// `rlx_cpu::cost::HwModel`, built from the global CPU runtime config.
#[cfg(feature = "cpu")]
pub struct CpuCostModel(rlx_cpu::cost::HwModel);

#[cfg(feature = "cpu")]
impl CpuCostModel {
    /// Derive the hardware model from `RuntimeConfig::global()`.
    pub fn new() -> Self {
        Self(rlx_cpu::cost::HwModel::from_config(
            rlx_cpu::config::RuntimeConfig::global(),
        ))
    }
}

#[cfg(feature = "cpu")]
impl Default for CpuCostModel {
    fn default() -> Self {
        CpuCostModel::new()
    }
}

#[cfg(feature = "cpu")]
impl BackendCostModel for CpuCostModel {
    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Rate of the faster of the NEON and BLAS paths.
    ///
    /// The FLOP count cancels out of the time ratio. For any non-empty
    /// GEMM the result is therefore (up to rounding) the larger of the two
    /// clamped rates divided by 1e9, independent of shape. An empty GEMM
    /// yields `0.0`.
    /// NOTE(review): the `1e9` scaling assumes `HwModel`'s `*_flops` fields
    /// are in FLOP/s; confirm against `rlx_cpu::cost`.
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        let hw = &self.0;
        let work = 2.0 * m as f64 * k as f64 * n as f64;
        let via_neon = work / hw.neon_flops.max(1.0);
        let via_blas = work / hw.blas_flops.max(1.0);
        let best_time = via_neon.min(via_blas);
        if best_time > 0.0 {
            work / (best_time * 1e9)
        } else {
            0.0
        }
    }

    /// Per-call BLAS overhead stands in for kernel dispatch.
    fn dispatch_overhead_ns(&self) -> f64 {
        self.0.blas_overhead_ns
    }

    /// The parallel-for overhead is used as the per-forward fixed cost.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.0.par_for_overhead_ns
    }

    fn memory_bw(&self) -> f64 {
        self.0.mem_bw
    }

    fn num_threads(&self) -> usize {
        self.0.num_threads
    }
}
228
/// [`BackendCostModel`] adapter for Metal, populated from the Metal crate's
/// on-disk calibration (`rlx_metal::calibrate::Calibration`). The sgemm rate
/// and round-trip cost therefore reflect what this machine measured.
#[cfg(feature = "metal")]
pub struct MetalCostModel {
    // Fastest of the calibrated sgemm kernel variants, used for every shape.
    // NOTE(review): the calibration fields are named `*_flops` yet are
    // returned as GFLOP/s without scaling (the CPU adapter divides its
    // `*_flops` by 1e9). Confirm the calibrator's units.
    sgemm_gflops_best: f64,
    // Calibrated per-forward round-trip overhead, ns.
    roundtrip_ns: f64,
    // Assumed bandwidth in bytes/ns (== GB/s); not measured, see `new`.
    memory_bw: f64,
}

#[cfg(feature = "metal")]
impl MetalCostModel {
    /// Apple Silicon unified-memory bandwidth floor (rough): ~200 GB/s on
    /// M-series base parts, much higher on Pro/Max. The calibrator doesn't
    /// measure pure memory bandwidth yet; this floor keeps memory-bound ops
    /// from looking free.
    const UNIFIED_MEMORY_BW_FLOOR: f64 = 200.0;

    /// Per-kernel encode cost in ns; small relative to the round-trip.
    const ENCODE_OVERHEAD_NS: f64 = 2_000.0;

    /// Build from `Calibration::load_or_measure()`.
    pub fn new() -> Self {
        let cal = rlx_metal::calibrate::Calibration::load_or_measure();
        let simd_best = f64::max(cal.sgemm_simd_4x4_flops, cal.sgemm_simd_flops);
        Self {
            sgemm_gflops_best: f64::max(simd_best, cal.sgemm_padded_flops),
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: Self::UNIFIED_MEMORY_BW_FLOOR,
        }
    }
}

#[cfg(feature = "metal")]
impl Default for MetalCostModel {
    fn default() -> Self {
        MetalCostModel::new()
    }
}

#[cfg(feature = "metal")]
impl BackendCostModel for MetalCostModel {
    fn device(&self) -> Device {
        Device::Metal
    }

    /// Shape-independent: always the best calibrated rate.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops_best
    }

    fn dispatch_overhead_ns(&self) -> f64 {
        Self::ENCODE_OVERHEAD_NS
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// A single command queue.
    fn num_threads(&self) -> usize {
        1
    }
}
289
/// [`BackendCostModel`] adapter for MLX, populated from the MLX crate's
/// on-disk calibration (`rlx_mlx::calibrate::Calibration`). The first
/// construction on a fresh machine pays a one-time measurement cost (tens of
/// ms); later constructions read the cache.
#[cfg(all(feature = "mlx", rlx_mlx_host))]
pub struct MlxCostModel {
    // Calibrated rate for GEMMs at or above the small-shape cutoff.
    sgemm_large_flops: f64,
    // Calibrated rate for GEMMs below the small-shape cutoff.
    sgemm_small_flops: f64,
    // Calibrated per-forward round-trip overhead, ns.
    roundtrip_ns: f64,
    // Measured bandwidth when available, otherwise the fallback floor.
    memory_bw: f64,
}

#[cfg(all(feature = "mlx", rlx_mlx_host))]
impl MlxCostModel {
    /// Apple-Silicon unified-memory floor (GB/s) used when the calibration
    /// cache predates bandwidth measurement, so old caches stay sane.
    const FALLBACK_MEMORY_BW: f64 = 200.0;

    /// `m·k·n` below which a GEMM counts as "small". This is a rough cutoff
    /// intended to match the calibrator's small/large probe sizes.
    const SMALL_GEMM_WORK: f64 = 32_768.0;

    /// MLX's lazy evaluation keeps per-op encode cost low. Trace
    /// construction in Rust is the dominant per-op cost.
    const PER_OP_OVERHEAD_NS: f64 = 2_000.0;

    /// Build from `Calibration::load_or_measure()`.
    pub fn new() -> Self {
        let cal = rlx_mlx::calibrate::Calibration::load_or_measure();
        // Post-PR16 calibrators record a positive bandwidth; anything else
        // falls back to the floor.
        let memory_bw = Some(cal.memory_bw_gbps)
            .filter(|&bw| bw > 0.0)
            .unwrap_or(Self::FALLBACK_MEMORY_BW);
        Self {
            sgemm_large_flops: cal.sgemm_large_flops,
            sgemm_small_flops: cal.sgemm_small_flops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw,
        }
    }
}

#[cfg(all(feature = "mlx", rlx_mlx_host))]
impl Default for MlxCostModel {
    fn default() -> Self {
        MlxCostModel::new()
    }
}

#[cfg(all(feature = "mlx", rlx_mlx_host))]
impl BackendCostModel for MlxCostModel {
    fn device(&self) -> Device {
        Device::Mlx
    }

    /// Two-regime crossover: small shapes pay per-op overhead; large shapes
    /// hit the optimized path.
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        let work = m as f64 * k as f64 * n as f64;
        if work >= Self::SMALL_GEMM_WORK {
            self.sgemm_large_flops
        } else {
            self.sgemm_small_flops
        }
    }

    fn dispatch_overhead_ns(&self) -> f64 {
        Self::PER_OP_OVERHEAD_NS
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    fn num_threads(&self) -> usize {
        1
    }
}
362
/// [`BackendCostModel`] for CUDA.
///
/// When `crate::is_available(Device::Cuda)` is true, the numbers come from
/// `rlx_cuda::calibrate::Calibration::load_or_measure()`. Otherwise fixed
/// heuristic defaults are used so the model can still be constructed.
/// NOTE(review): unlike the MLX adapter, calibrated values are used as-is,
/// with no fallback for zero or negative readings.
#[cfg(feature = "cuda")]
pub struct CudaCostModel {
    // Sustained sgemm rate, GFLOP/s.
    sgemm_gflops: f64,
    // Per-forward round-trip overhead, ns.
    roundtrip_ns: f64,
    // Memory bandwidth, bytes/ns (== GB/s).
    memory_bw: f64,
}

#[cfg(feature = "cuda")]
impl CudaCostModel {
    pub fn new() -> Self {
        if !crate::is_available(crate::Device::Cuda) {
            return Self::heuristic();
        }
        let cal = rlx_cuda::calibrate::Calibration::load_or_measure();
        Self {
            sgemm_gflops: cal.sgemm_gflops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: cal.memory_bw_gbps,
        }
    }

    /// Uncalibrated defaults used when no CUDA device is available.
    fn heuristic() -> Self {
        Self {
            sgemm_gflops: 12_000.0,
            roundtrip_ns: 35_000.0,
            memory_bw: 900.0,
        }
    }
}

#[cfg(feature = "cuda")]
impl Default for CudaCostModel {
    fn default() -> Self {
        CudaCostModel::new()
    }
}

#[cfg(feature = "cuda")]
impl BackendCostModel for CudaCostModel {
    fn device(&self) -> Device {
        Device::Cuda
    }

    /// Shape-independent.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops
    }

    /// Fixed per-kernel launch estimate, ns.
    fn dispatch_overhead_ns(&self) -> f64 {
        3_000.0
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    fn num_threads(&self) -> usize {
        1
    }
}
418
/// [`BackendCostModel`] for ROCm.
///
/// When `crate::is_available(Device::Rocm)` is true, the numbers come from
/// `rlx_rocm::calibrate::Calibration::load_or_measure()`. Otherwise fixed
/// heuristic defaults in the same class as CUDA's are used.
/// NOTE(review): calibrated values are used as-is, with no fallback for
/// zero or negative readings.
#[cfg(feature = "rocm")]
pub struct RocmCostModel {
    // Sustained sgemm rate, GFLOP/s.
    sgemm_gflops: f64,
    // Per-forward round-trip overhead, ns.
    roundtrip_ns: f64,
    // Memory bandwidth, bytes/ns (== GB/s).
    memory_bw: f64,
}

#[cfg(feature = "rocm")]
impl RocmCostModel {
    pub fn new() -> Self {
        if !crate::is_available(crate::Device::Rocm) {
            return Self::heuristic();
        }
        let cal = rlx_rocm::calibrate::Calibration::load_or_measure();
        Self {
            sgemm_gflops: cal.sgemm_gflops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: cal.memory_bw_gbps,
        }
    }

    /// Uncalibrated defaults used when no ROCm device is available.
    fn heuristic() -> Self {
        Self {
            sgemm_gflops: 10_000.0,
            roundtrip_ns: 40_000.0,
            memory_bw: 800.0,
        }
    }
}

#[cfg(feature = "rocm")]
impl Default for RocmCostModel {
    fn default() -> Self {
        RocmCostModel::new()
    }
}

#[cfg(feature = "rocm")]
impl BackendCostModel for RocmCostModel {
    fn device(&self) -> Device {
        Device::Rocm
    }

    /// Shape-independent.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops
    }

    /// Fixed per-kernel launch estimate, ns.
    fn dispatch_overhead_ns(&self) -> f64 {
        3_000.0
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    fn num_threads(&self) -> usize {
        1
    }
}
474
/// [`BackendCostModel`] for the wgpu backend (`Device::Gpu`).
///
/// When `rlx_wgpu::is_available()` is true, the numbers come from
/// `rlx_wgpu::calibrate::Calibration::load_or_measure()`. Otherwise fixed
/// heuristic defaults are used.
/// NOTE(review): calibrated values are used as-is, with no fallback for
/// zero or negative readings.
#[cfg(feature = "gpu")]
pub struct WgpuCostModel {
    // Sustained sgemm rate, GFLOP/s.
    sgemm_gflops: f64,
    // Per-forward round-trip overhead, ns.
    roundtrip_ns: f64,
    // Memory bandwidth, bytes/ns (== GB/s).
    memory_bw: f64,
}

#[cfg(feature = "gpu")]
impl WgpuCostModel {
    pub fn new() -> Self {
        if !rlx_wgpu::is_available() {
            return Self::heuristic();
        }
        let cal = rlx_wgpu::calibrate::Calibration::load_or_measure();
        Self {
            sgemm_gflops: cal.sgemm_gflops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: cal.memory_bw_gbps,
        }
    }

    /// Uncalibrated defaults used when no wgpu adapter is available.
    fn heuristic() -> Self {
        Self {
            sgemm_gflops: 2_500.0,
            roundtrip_ns: 80_000.0,
            memory_bw: 120.0,
        }
    }
}

#[cfg(feature = "gpu")]
impl Default for WgpuCostModel {
    fn default() -> Self {
        WgpuCostModel::new()
    }
}

#[cfg(feature = "gpu")]
impl BackendCostModel for WgpuCostModel {
    fn device(&self) -> Device {
        Device::Gpu
    }

    /// Shape-independent.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops
    }

    /// Fixed per-kernel dispatch estimate, ns.
    fn dispatch_overhead_ns(&self) -> f64 {
        5_000.0
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    fn num_threads(&self) -> usize {
        1
    }
}
530
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Graph, Shape};

    /// Whatever device the auto-picker chooses for a tiny 4x4 matmul must
    /// be available on this host and able to run the graph. (The name says
    /// "CPU", but the assertions only check availability and support.)
    #[test]
    fn fastest_device_for_falls_back_to_cpu_for_simple_graph() {
        let square = || Shape::new(&[4, 4], DType::F32);

        let mut graph = Graph::new("mm");
        let lhs = graph.input("x", square());
        let rhs = graph.param("w", square());
        let out = graph.matmul(lhs, rhs, square());
        graph.set_outputs(vec![out]);

        let chosen = fastest_device_for(&graph);
        assert!(crate::is_available(chosen));
        assert!(crate::devices_for(&graph).contains(&chosen));
    }
}