// rlx_cpu/cost.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Cost model — estimates execution time for kernel dispatch decisions.
//!
//! Instead of hardcoded thresholds scattered across thunk.rs, all dispatch
//! decisions are routed through this cost model. The model considers:
//! - Hardware: cache sizes, core count, AMX availability, NEON throughput
//! - Workload: matrix dimensions, batch size, sequence length
//! - Strategy: BLAS vs NEON, parallel vs sequential, fused vs individual
//!
//! The model is calibrated at compile time (platform-specific constants)
//! and refined at runtime (detected hardware + optional autotune).

27use crate::config::RuntimeConfig;
28
/// Estimated execution time, in nanoseconds, of one kernel execution strategy.
///
/// A newtype over `f64`; read the raw value back with [`Cost::ns`].
#[derive(Debug, Clone, Copy)]
pub struct Cost(pub f64);

impl Cost {
    /// Returns the wrapped estimate in nanoseconds.
    pub fn ns(self) -> f64 {
        let Cost(nanos) = self;
        nanos
    }
}

39/// Hardware model — derived from RuntimeConfig + platform detection.
40pub struct HwModel {
41    /// NEON throughput: FLOP/s for element-wise (FMA chains)
42    pub neon_flops: f64,
43    /// BLAS throughput: FLOP/s for sgemm (AMX or optimized NEON)
44    pub blas_flops: f64,
45    /// BLAS call overhead in nanoseconds (function call + AMX sync)
46    pub blas_overhead_ns: f64,
47    /// par_for dispatch overhead in nanoseconds
48    pub par_for_overhead_ns: f64,
49    /// L1 data cache size in bytes
50    pub l1_bytes: usize,
51    /// L2 cache size in bytes
52    pub l2_bytes: usize,
53    /// Memory bandwidth (L2 → registers) in bytes/ns
54    pub mem_bw: f64,
55    /// Number of worker threads
56    pub num_threads: usize,
57}
58
59impl HwModel {
60    /// Build from runtime config and platform defaults.
61    pub fn from_config(cfg: &RuntimeConfig) -> Self {
62        // Platform-calibrated constants
63        #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
64        let model = HwModel {
65            neon_flops: 72e9,            // ~72 GFLOP/s NEON FMA throughput (M4 Pro P-core)
66            blas_flops: 1000e9,          // ~1 TFLOP/s AMX peak (effective varies with tile fill)
67            blas_overhead_ns: 500.0,     // ~0.5µs per cblas_sgemm call
68            par_for_overhead_ns: 5000.0, // ~5µs spin-wait dispatch
69            l1_bytes: 65536,             // 64KB L1d (refined by sysctl at runtime)
70            l2_bytes: 4 * 1024 * 1024,   // 4MB L2 per core
71            mem_bw: 50.0,                // ~50 GB/s = 50 B/ns
72            num_threads: cfg.pool_workers + 1,
73        };
74
75        #[cfg(not(all(target_arch = "aarch64", target_os = "macos")))]
76        let model = HwModel {
77            neon_flops: 32e9,
78            blas_flops: 200e9,
79            blas_overhead_ns: 300.0,
80            par_for_overhead_ns: 3000.0,
81            l1_bytes: 32768,
82            l2_bytes: 1024 * 1024,
83            mem_bw: 30.0,
84            num_threads: cfg.pool_workers + 1,
85        };
86
87        model
88    }
89
90    // ── Dispatch decisions ──────────────────────────────────────────
91
92    /// Should we use NEON sgemm instead of BLAS for this matrix multiply?
93    /// Returns true when BLAS overhead dominates the compute.
94    pub fn prefer_neon_sgemm(&self, m: usize, k: usize, n: usize) -> bool {
95        let flops = 2.0 * m as f64 * k as f64 * n as f64;
96        let blas_time = flops / self.blas_flops + self.blas_overhead_ns * 1e-9;
97        let neon_time = flops / self.neon_flops;
98        neon_time < blas_time
99    }
100
101    /// Should we use par_for for this element-wise operation?
102    /// Returns true when parallelism benefit exceeds dispatch overhead.
103    pub fn prefer_parallel(&self, total_elements: usize, cost_per_element_ns: f64) -> bool {
104        let seq_time = total_elements as f64 * cost_per_element_ns;
105        let par_time = seq_time / self.num_threads as f64 + self.par_for_overhead_ns;
106        par_time < seq_time
107    }
108
109    /// Should we use strided BLAS for SDPA, or sequential NEON dots?
110    pub fn prefer_blas_sdpa(
111        &self,
112        batch: usize,
113        seq: usize,
114        num_heads: usize,
115        head_dim: usize,
116    ) -> bool {
117        let total_heads = batch * num_heads;
118        // Two sgemm per head (Q@K^T + scores@V)
119        let per_head_flops = 2.0 * seq as f64 * seq as f64 * head_dim as f64 * 2.0;
120        let blas_per_head = per_head_flops / self.blas_flops + 2.0 * self.blas_overhead_ns * 1e-9;
121        let neon_per_head = per_head_flops / self.neon_flops;
122
123        // With par_for, BLAS heads run in parallel
124        let blas_total = blas_per_head * total_heads as f64 / self.num_threads as f64
125            + self.par_for_overhead_ns * 1e-9;
126        let neon_total = neon_per_head * total_heads as f64; // sequential
127
128        blas_total < neon_total
129    }
130
131    /// Should we fuse the entire transformer layer into one thunk?
132    /// True when intermediates fit in L1 and per-thunk overhead dominates.
133    pub fn prefer_fused_layer(
134        &self,
135        batch: usize,
136        seq: usize,
137        hidden: usize,
138        intermediate: usize,
139    ) -> bool {
140        let m = batch * seq;
141        // Estimate intermediate buffer sizes
142        let qkv_bytes = m * 3 * hidden * 4;
143        let attn_bytes = m * hidden * 4;
144        let ffn_bytes = m * intermediate * 4;
145        let total_bytes = qkv_bytes + 2 * attn_bytes + ffn_bytes;
146        // Fuse if total intermediates fit in L2 (L1 would be ideal but tight)
147        total_bytes <= self.l2_bytes / 2
148    }
149}
150
151/// Global hardware model singleton.
152pub fn hw_model() -> &'static HwModel {
153    use std::sync::OnceLock;
154    static MODEL: OnceLock<HwModel> = OnceLock::new();
155    MODEL.get_or_init(|| HwModel::from_config(RuntimeConfig::global()))
156}