Skip to main content

rlx_runtime/
cost.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Cross-backend cost interface.
17//!
18//! Each backend implements `BackendCostModel` to expose its execution
19//! characteristics (kernel throughput, dispatch overhead, memory bw).
20//! The runtime can then estimate the cost of running a graph on each
21//! available backend and pick the fastest.
22//!
23//! This is what enables "auto device" — given a graph, pick CPU or
24//! Metal automatically based on which is faster for THIS workload on
25//! THIS hardware.
26
27use crate::Device;
28use rlx_ir::{Graph, Node, Op};
29
30/// Hardware-aware cost characteristics for a backend on the current machine.
31pub trait BackendCostModel: Send + Sync {
32    /// Identify which device this model is for.
33    fn device(&self) -> Device;
34
35    /// Effective f32 sgemm throughput in GFLOP/s for the most-used kernel
36    /// path at the given dimensions. Backends should return their best
37    /// sustained rate (not peak).
38    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64;
39
40    /// Cost to dispatch one kernel (function call, BLAS setup, etc.) in ns.
41    fn dispatch_overhead_ns(&self) -> f64;
42
43    /// Cost to commit + wait for a command buffer / forward pass in ns.
44    /// Roughly amortized per-forward overhead independent of kernel count.
45    fn roundtrip_overhead_ns(&self) -> f64;
46
47    /// Memory bandwidth in bytes/ns (== GB/s).
48    fn memory_bw(&self) -> f64;
49
50    /// Effective host readback bandwidth (bytes/ns). Defaults to [`Self::memory_bw`].
51    fn host_readback_bw(&self) -> f64 {
52        self.memory_bw()
53    }
54
55    /// True when device and host share the same physical memory (Apple Silicon Metal).
56    fn unified_memory(&self) -> bool {
57        false
58    }
59
60    /// Number of compute threads available.
61    fn num_threads(&self) -> usize;
62}
63
64/// Estimate forward-pass time (ns) for a graph on the given backend.
65/// Uses node-level cost contributions; conservative — actual time may
66/// be lower due to hardware parallelism we don't model.
67pub fn estimate_graph_cost(graph: &Graph, model: &dyn BackendCostModel) -> f64 {
68    estimate_graph_cost_with_io(graph, model, &crate::graph_io::profile_graph_io(graph))
69}
70
71/// IO-aware cost: compute + device traffic + host readback + sync points.
72pub fn estimate_graph_cost_with_io(
73    graph: &Graph,
74    model: &dyn BackendCostModel,
75    io: &crate::graph_io::GraphIoProfile,
76) -> f64 {
77    let mut total = model.roundtrip_overhead_ns();
78    for node in graph.nodes() {
79        total += node_cost(node, graph, model);
80    }
81    total += io.device_traffic_bytes as f64 / model.memory_bw().max(1.0);
82    total +=
83        io.host_readback_bytes(model.unified_memory()) as f64 / model.host_readback_bw().max(1.0);
84    total += io.sync_points as f64 * model.roundtrip_overhead_ns();
85    total
86}
87
88fn node_cost(node: &Node, graph: &Graph, model: &dyn BackendCostModel) -> f64 {
89    let dispatch = model.dispatch_overhead_ns();
90    match &node.op {
91        Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0.0,
92        Op::MatMul | Op::FusedMatMulBiasAct { .. } => {
93            let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
94            let total = node.shape.num_elements().unwrap_or(0);
95            let m = total / n.max(1);
96            let a_total = graph.node(node.inputs[0]).shape.num_elements().unwrap_or(0);
97            let k = a_total / m.max(1);
98            let flops = 2.0 * m as f64 * k as f64 * n as f64;
99            flops / (model.sgemm_gflops(m, k, n) + 1.0) + dispatch
100        }
101        Op::Attention {
102            num_heads,
103            head_dim,
104            ..
105        } => {
106            let q_shape = &graph.node(node.inputs[0]).shape;
107            let seq = q_shape.dim(q_shape.rank() - 2).unwrap_static();
108            let batch = q_shape.num_elements().unwrap_or(0) / (seq * num_heads * head_dim).max(1);
109            let flops = (batch * num_heads * seq * seq * head_dim * 2) as f64;
110            flops / (model.sgemm_gflops(seq, *head_dim, seq) + 1.0) + dispatch
111        }
112        // Element-wise + small ops: bounded by memory bandwidth.
113        _ => {
114            let bytes = node.shape.num_elements().unwrap_or(0) * 4;
115            (bytes as f64) / model.memory_bw().max(1.0) + dispatch
116        }
117    }
118}
119
120/// Pick the device with the lowest predicted cost for this graph.
121pub fn pick_best_device(graph: &Graph, models: &[&dyn BackendCostModel]) -> Device {
122    let mut best = (Device::Cpu, f64::INFINITY);
123    for &m in models {
124        let cost = estimate_graph_cost(graph, m);
125        if cost < best.1 {
126            best = (m.device(), cost);
127        }
128    }
129    best.0
130}
131
132/// Pick the fastest backend for `graph` on this host.
133pub fn fastest_device_for(graph: &Graph) -> Device {
134    fastest_device_for_with_policy(graph, &crate::device_policy::DevicePolicy::default())
135}
136
137/// Like [`fastest_device_for`] but respects a [`crate::DevicePolicy`] allow-list.
138pub fn fastest_device_for_with_policy(
139    graph: &Graph,
140    policy: &crate::device_policy::DevicePolicy,
141) -> Device {
142    let candidates = crate::device_policy::devices_for_with_policy(graph, policy);
143    if candidates.is_empty() {
144        return crate::device_ext::fastest_among(&policy.apply(crate::available_devices()));
145    }
146
147    #[cfg(feature = "cpu")]
148    let cpu = CpuCostModel::new();
149    #[cfg(feature = "metal")]
150    let metal = MetalCostModel::new();
151    #[cfg(all(feature = "mlx", rlx_mlx_host))]
152    let mlx = MlxCostModel::new();
153    #[cfg(feature = "cuda")]
154    let cuda = CudaCostModel::new();
155    #[cfg(feature = "rocm")]
156    let rocm = RocmCostModel::new();
157    #[cfg(feature = "gpu")]
158    let wgpu = WgpuCostModel::new();
159
160    let mut models: Vec<&dyn BackendCostModel> = Vec::new();
161    #[cfg(feature = "cpu")]
162    if candidates.contains(&Device::Cpu) {
163        models.push(&cpu);
164    }
165    #[cfg(feature = "metal")]
166    if candidates.contains(&Device::Metal) {
167        models.push(&metal);
168    }
169    #[cfg(all(feature = "mlx", rlx_mlx_host))]
170    if candidates.contains(&Device::Mlx) {
171        models.push(&mlx);
172    }
173    #[cfg(feature = "cuda")]
174    if candidates.contains(&Device::Cuda) {
175        models.push(&cuda);
176    }
177    #[cfg(feature = "rocm")]
178    if candidates.contains(&Device::Rocm) {
179        models.push(&rocm);
180    }
181    #[cfg(feature = "gpu")]
182    if candidates.contains(&Device::Gpu) {
183        models.push(&wgpu);
184    }
185
186    if models.len() >= 2 {
187        pick_best_device(graph, &models)
188    } else if let Some(m) = models.first() {
189        m.device()
190    } else {
191        crate::device_ext::fastest_among(&candidates)
192    }
193}
194
195// ── Backend adapters (plan #29) ─────────────────────────────────
196//
197// The CPU and Metal crates own their own internal cost models for
198// kernel-selection decisions. These thin adapters wrap them in
199// `BackendCostModel` so `pick_best_device` can compare both with a
200// single uniform interface.
201
/// Adapter that exposes `rlx_cpu::cost::HwModel` through
/// [`BackendCostModel`].
#[cfg(feature = "cpu")]
pub struct CpuCostModel(rlx_cpu::cost::HwModel);

#[cfg(feature = "cpu")]
impl CpuCostModel {
    /// Builds the model from the configuration returned by
    /// `rlx_cpu::config::RuntimeConfig::global()`.
    pub fn new() -> Self {
        Self(rlx_cpu::cost::HwModel::from_config(
            rlx_cpu::config::RuntimeConfig::global(),
        ))
    }
}

#[cfg(feature = "cpu")]
impl Default for CpuCostModel {
    /// Same as [`CpuCostModel::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "cpu")]
impl BackendCostModel for CpuCostModel {
    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Throughput of whichever of the NEON / BLAS paths finishes the
    /// `m × k × n` problem sooner. Returns `0.0` when the problem has
    /// no work (any dimension is zero).
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        let hw = &self.0;
        let work = 2.0 * m as f64 * k as f64 * n as f64;
        // Time each path would need; rates are clamped so we never divide by zero.
        let quickest = (work / hw.neon_flops.max(1.0)).min(work / hw.blas_flops.max(1.0));
        if quickest > 0.0 {
            // NOTE(review): the 1e9 factor presumably converts a FLOP/s
            // rate to GFLOP/s — confirm the units of `HwModel`'s fields.
            // The work term cancels, so the result only depends on the
            // shape through the zero-work case.
            work / (quickest * 1e9)
        } else {
            0.0
        }
    }

    /// Reported as the BLAS call overhead of the underlying model.
    fn dispatch_overhead_ns(&self) -> f64 {
        self.0.blas_overhead_ns
    }

    /// Reported as the parallel-for overhead of the underlying model.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.0.par_for_overhead_ns
    }

    fn memory_bw(&self) -> f64 {
        self.0.mem_bw
    }

    fn num_threads(&self) -> usize {
        self.0.num_threads
    }
}
251
/// Metal adapter for [`BackendCostModel`], populated from
/// `rlx_metal::calibrate::Calibration::load_or_measure()` so the
/// figures reflect measurements taken on this machine.
#[cfg(feature = "metal")]
pub struct MetalCostModel {
    /// Best calibrated sgemm rate; returned for every shape.
    sgemm_gflops_avg: f64,
    /// Calibrated round-trip overhead in nanoseconds.
    roundtrip_ns: f64,
    /// Memory bandwidth in bytes/ns.
    memory_bw: f64,
}

#[cfg(feature = "metal")]
impl MetalCostModel {
    /// Placeholder memory bandwidth (bytes/ns, i.e. roughly 200 GB/s).
    /// The calibrator does not measure raw memory bandwidth yet, so a
    /// fixed value is used to keep memory-bound ops from looking free.
    const ASSUMED_MEMORY_BW: f64 = 200.0;

    /// Fixed per-kernel encode cost in nanoseconds.
    const ENCODE_OVERHEAD_NS: f64 = 2_000.0;

    /// Loads (or measures) the Metal calibration and keeps the fastest
    /// of the calibrated sgemm paths as the effective throughput.
    pub fn new() -> Self {
        let cal = rlx_metal::calibrate::Calibration::load_or_measure();
        // NOTE(review): these calibration fields are named `*_flops`
        // but end up in a GFLOP/s slot — confirm the calibrator's units.
        let simd_best = f64::max(cal.sgemm_simd_4x4_flops, cal.sgemm_simd_flops);
        let overall_best = f64::max(simd_best, cal.sgemm_padded_flops);
        Self {
            sgemm_gflops_avg: overall_best,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: Self::ASSUMED_MEMORY_BW,
        }
    }
}

#[cfg(feature = "metal")]
impl Default for MetalCostModel {
    /// Same as [`MetalCostModel::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "metal")]
impl BackendCostModel for MetalCostModel {
    fn device(&self) -> Device {
        Device::Metal
    }

    /// Shape-independent: always the best calibrated rate.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops_avg
    }

    /// Per-kernel encode cost — small relative to the round-trip.
    fn dispatch_overhead_ns(&self) -> f64 {
        Self::ENCODE_OVERHEAD_NS
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Metal is reported as sharing memory with the host.
    fn unified_memory(&self) -> bool {
        true
    }

    /// Modelled as a single command queue.
    fn num_threads(&self) -> usize {
        1
    }
}
315
/// MLX adapter for [`BackendCostModel`], populated from
/// `rlx_mlx::calibrate::Calibration::load_or_measure()`. On a fresh
/// machine the first construction pays a one-time measurement cost
/// (tens of ms); later constructions read the on-disk cache.
///
/// NOTE(review): `unified_memory` is left at the trait default
/// (`false`), unlike `MetalCostModel` — confirm that charging MLX for
/// host readback is intended.
#[cfg(all(feature = "mlx", rlx_mlx_host))]
pub struct MlxCostModel {
    /// Rate used for problems at or above the small/large cutoff.
    sgemm_large_flops: f64,
    /// Rate used for problems below the small/large cutoff.
    sgemm_small_flops: f64,
    /// Calibrated round-trip overhead in nanoseconds.
    roundtrip_ns: f64,
    /// Memory bandwidth in bytes/ns.
    memory_bw: f64,
}

#[cfg(all(feature = "mlx", rlx_mlx_host))]
impl MlxCostModel {
    /// Bandwidth assumed when the calibration has no usable measurement.
    const FALLBACK_MEMORY_BW: f64 = 200.0;

    /// `m·k·n` volume below which the "small" sgemm rate applies.
    const SMALL_PROBLEM_VOLUME: f64 = 32_768.0;

    /// Fixed per-op cost in nanoseconds.
    const PER_OP_OVERHEAD_NS: f64 = 2_000.0;

    /// Loads (or measures) the MLX calibration.
    pub fn new() -> Self {
        let cal = rlx_mlx::calibrate::Calibration::load_or_measure();
        // Newer calibrators (post-PR16) record memory bandwidth; caches
        // written before that carry no positive value, so fall back to
        // the Apple-Silicon unified-memory floor to keep numbers sane.
        let memory_bw = match cal.memory_bw_gbps {
            measured if measured > 0.0 => measured,
            _ => Self::FALLBACK_MEMORY_BW,
        };
        Self {
            sgemm_large_flops: cal.sgemm_large_flops,
            sgemm_small_flops: cal.sgemm_small_flops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw,
        }
    }
}

#[cfg(all(feature = "mlx", rlx_mlx_host))]
impl Default for MlxCostModel {
    /// Same as [`MlxCostModel::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(all(feature = "mlx", rlx_mlx_host))]
impl BackendCostModel for MlxCostModel {
    fn device(&self) -> Device {
        Device::Mlx
    }

    /// Two-level crossover: small problems are dominated by per-op
    /// overhead, large ones reach the optimized path. The cutoff is
    /// rough and mirrors the calibrator's "small" / "large" probes.
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        let volume = m as f64 * k as f64 * n as f64;
        if volume >= Self::SMALL_PROBLEM_VOLUME {
            self.sgemm_large_flops
        } else {
            self.sgemm_small_flops
        }
    }

    /// MLX's lazy evaluation keeps per-op encode cost low; building the
    /// trace in Rust is the dominant per-op cost.
    fn dispatch_overhead_ns(&self) -> f64 {
        Self::PER_OP_OVERHEAD_NS
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    fn num_threads(&self) -> usize {
        1
    }
}
388
/// CUDA cost model. Uses `rlx_cuda` calibration data when the CUDA
/// device is available and fixed heuristic figures otherwise.
#[cfg(feature = "cuda")]
pub struct CudaCostModel {
    /// sgemm throughput in GFLOP/s; returned for every shape.
    sgemm_gflops: f64,
    /// Round-trip overhead in nanoseconds.
    roundtrip_ns: f64,
    /// Memory bandwidth in bytes/ns.
    memory_bw: f64,
}

#[cfg(feature = "cuda")]
impl CudaCostModel {
    /// Builds the model, preferring calibrated numbers over heuristics.
    pub fn new() -> Self {
        if !crate::is_available(crate::Device::Cuda) {
            // No usable device: fall back to hard-coded estimates.
            return Self {
                sgemm_gflops: 12_000.0,
                roundtrip_ns: 35_000.0,
                memory_bw: 900.0,
            };
        }
        let cal = rlx_cuda::calibrate::Calibration::load_or_measure();
        Self {
            sgemm_gflops: cal.sgemm_gflops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: cal.memory_bw_gbps,
        }
    }
}

#[cfg(feature = "cuda")]
impl Default for CudaCostModel {
    /// Same as [`CudaCostModel::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "cuda")]
impl BackendCostModel for CudaCostModel {
    fn device(&self) -> Device {
        Device::Cuda
    }

    /// Shape-independent throughput.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops
    }

    /// Fixed per-kernel estimate (3 µs).
    fn dispatch_overhead_ns(&self) -> f64 {
        3_000.0
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    fn num_threads(&self) -> usize {
        1
    }
}
444
/// ROCm cost model. Uses `rlx_rocm` calibration data when the ROCm
/// device is available and fixed heuristic figures otherwise.
#[cfg(feature = "rocm")]
pub struct RocmCostModel {
    /// sgemm throughput in GFLOP/s; returned for every shape.
    sgemm_gflops: f64,
    /// Round-trip overhead in nanoseconds.
    roundtrip_ns: f64,
    /// Memory bandwidth in bytes/ns.
    memory_bw: f64,
}

#[cfg(feature = "rocm")]
impl RocmCostModel {
    /// Builds the model, preferring calibrated numbers over heuristics.
    pub fn new() -> Self {
        if !crate::is_available(crate::Device::Rocm) {
            // No usable device: fall back to hard-coded estimates.
            return Self {
                sgemm_gflops: 10_000.0,
                roundtrip_ns: 40_000.0,
                memory_bw: 800.0,
            };
        }
        let cal = rlx_rocm::calibrate::Calibration::load_or_measure();
        Self {
            sgemm_gflops: cal.sgemm_gflops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: cal.memory_bw_gbps,
        }
    }
}

#[cfg(feature = "rocm")]
impl Default for RocmCostModel {
    /// Same as [`RocmCostModel::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "rocm")]
impl BackendCostModel for RocmCostModel {
    fn device(&self) -> Device {
        Device::Rocm
    }

    /// Shape-independent throughput.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops
    }

    /// Fixed per-kernel estimate (3 µs).
    fn dispatch_overhead_ns(&self) -> f64 {
        3_000.0
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    fn num_threads(&self) -> usize {
        1
    }
}
500
/// wgpu (`Device::Gpu`) cost model. Uses `rlx_wgpu` calibration data
/// when `rlx_wgpu::is_available()` reports a device and fixed
/// heuristic figures otherwise.
#[cfg(feature = "gpu")]
pub struct WgpuCostModel {
    /// sgemm throughput in GFLOP/s; returned for every shape.
    sgemm_gflops: f64,
    /// Round-trip overhead in nanoseconds.
    roundtrip_ns: f64,
    /// Memory bandwidth in bytes/ns.
    memory_bw: f64,
}

#[cfg(feature = "gpu")]
impl WgpuCostModel {
    /// Builds the model, preferring calibrated numbers over heuristics.
    pub fn new() -> Self {
        if !rlx_wgpu::is_available() {
            // No usable adapter: fall back to hard-coded estimates.
            return Self {
                sgemm_gflops: 2_500.0,
                roundtrip_ns: 80_000.0,
                memory_bw: 120.0,
            };
        }
        let cal = rlx_wgpu::calibrate::Calibration::load_or_measure();
        Self {
            sgemm_gflops: cal.sgemm_gflops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: cal.memory_bw_gbps,
        }
    }
}

#[cfg(feature = "gpu")]
impl Default for WgpuCostModel {
    /// Same as [`WgpuCostModel::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "gpu")]
impl BackendCostModel for WgpuCostModel {
    fn device(&self) -> Device {
        Device::Gpu
    }

    /// Shape-independent throughput.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops
    }

    /// Fixed per-kernel estimate (5 µs).
    fn dispatch_overhead_ns(&self) -> f64 {
        5_000.0
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    fn num_threads(&self) -> usize {
        1
    }
}
556
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Graph, Shape};

    /// Smoke test on a 4×4 matmul graph: whatever device is chosen must
    /// be available and must be one of the devices offered for the graph.
    #[test]
    fn fastest_device_for_falls_back_to_cpu_for_simple_graph() {
        // Every tensor in this graph is a 4×4 f32 matrix.
        let square = || Shape::new(&[4, 4], DType::F32);

        let mut graph = Graph::new("mm");
        let input = graph.input("x", square());
        let weight = graph.param("w", square());
        let product = graph.matmul(input, weight, square());
        graph.set_outputs(vec![product]);

        let chosen = fastest_device_for(&graph);
        assert!(crate::is_available(chosen));
        assert!(crate::devices_for(&graph).contains(&chosen));
    }
}