rlx_cpu/cost.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Cost model — estimates execution time for kernel dispatch decisions.
17//!
18//! Instead of hardcoded thresholds scattered across thunk.rs, all dispatch
19//! decisions are routed through this cost model. The model considers:
20//! - Hardware: cache sizes, core count, AMX availability, NEON throughput
21//! - Workload: matrix dimensions, batch size, sequence length
22//! - Strategy: BLAS vs NEON, parallel vs sequential, fused vs individual
23//!
24//! The model is calibrated at compile time (platform-specific constants)
25//! and refined at runtime (detected hardware + optional autotune).
26
27use crate::config::RuntimeConfig;
28
/// Estimated execution cost of a kernel strategy, expressed in nanoseconds.
#[derive(Debug, Clone, Copy)]
pub struct Cost(pub f64);

impl Cost {
    /// Returns the wrapped estimate, in nanoseconds.
    pub fn ns(self) -> f64 {
        let Cost(nanos) = self;
        nanos
    }
}
38
/// Hardware model used by the dispatch cost heuristics.
///
/// Built by `HwModel::from_config` from compile-time, per-platform constants
/// plus the runtime worker count. Units differ per field: throughputs are
/// FLOP/s, overheads are nanoseconds, cache sizes are bytes.
#[derive(Debug, Clone)]
pub struct HwModel {
    /// NEON throughput: FLOP/s for element-wise (FMA chains)
    pub neon_flops: f64,
    /// BLAS throughput: FLOP/s for sgemm (AMX or optimized NEON)
    pub blas_flops: f64,
    /// BLAS call overhead in nanoseconds (function call + AMX sync)
    pub blas_overhead_ns: f64,
    /// par_for dispatch overhead in nanoseconds
    pub par_for_overhead_ns: f64,
    /// L1 data cache size in bytes
    pub l1_bytes: usize,
    /// L2 cache size in bytes
    pub l2_bytes: usize,
    /// Memory bandwidth (L2 → registers) in bytes/ns (numerically equal to GB/s)
    pub mem_bw: f64,
    /// Number of worker threads
    pub num_threads: usize,
}
58
59impl HwModel {
60 /// Build from runtime config and platform defaults.
61 pub fn from_config(cfg: &RuntimeConfig) -> Self {
62 // Platform-calibrated constants
63 #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
64 let model = HwModel {
65 neon_flops: 72e9, // ~72 GFLOP/s NEON FMA throughput (M4 Pro P-core)
66 blas_flops: 1000e9, // ~1 TFLOP/s AMX peak (effective varies with tile fill)
67 blas_overhead_ns: 500.0, // ~0.5µs per cblas_sgemm call
68 par_for_overhead_ns: 5000.0, // ~5µs spin-wait dispatch
69 l1_bytes: 65536, // 64KB L1d (refined by sysctl at runtime)
70 l2_bytes: 4 * 1024 * 1024, // 4MB L2 per core
71 mem_bw: 50.0, // ~50 GB/s = 50 B/ns
72 num_threads: cfg.pool_workers + 1,
73 };
74
75 #[cfg(not(all(target_arch = "aarch64", target_os = "macos")))]
76 let model = HwModel {
77 neon_flops: 32e9,
78 blas_flops: 200e9,
79 blas_overhead_ns: 300.0,
80 par_for_overhead_ns: 3000.0,
81 l1_bytes: 32768,
82 l2_bytes: 1024 * 1024,
83 mem_bw: 30.0,
84 num_threads: cfg.pool_workers + 1,
85 };
86
87 model
88 }
89
90 // ── Dispatch decisions ──────────────────────────────────────────
91
92 /// Should we use NEON sgemm instead of BLAS for this matrix multiply?
93 /// Returns true when BLAS overhead dominates the compute.
94 pub fn prefer_neon_sgemm(&self, m: usize, k: usize, n: usize) -> bool {
95 let flops = 2.0 * m as f64 * k as f64 * n as f64;
96 let blas_time = flops / self.blas_flops + self.blas_overhead_ns * 1e-9;
97 let neon_time = flops / self.neon_flops;
98 neon_time < blas_time
99 }
100
101 /// Should we use par_for for this element-wise operation?
102 /// Returns true when parallelism benefit exceeds dispatch overhead.
103 pub fn prefer_parallel(&self, total_elements: usize, cost_per_element_ns: f64) -> bool {
104 let seq_time = total_elements as f64 * cost_per_element_ns;
105 let par_time = seq_time / self.num_threads as f64 + self.par_for_overhead_ns;
106 par_time < seq_time
107 }
108
109 /// Should we use strided BLAS for SDPA, or sequential NEON dots?
110 pub fn prefer_blas_sdpa(
111 &self,
112 batch: usize,
113 seq: usize,
114 num_heads: usize,
115 head_dim: usize,
116 ) -> bool {
117 let total_heads = batch * num_heads;
118 // Two sgemm per head (Q@K^T + scores@V)
119 let per_head_flops = 2.0 * seq as f64 * seq as f64 * head_dim as f64 * 2.0;
120 let blas_per_head = per_head_flops / self.blas_flops + 2.0 * self.blas_overhead_ns * 1e-9;
121 let neon_per_head = per_head_flops / self.neon_flops;
122
123 // With par_for, BLAS heads run in parallel
124 let blas_total = blas_per_head * total_heads as f64 / self.num_threads as f64
125 + self.par_for_overhead_ns * 1e-9;
126 let neon_total = neon_per_head * total_heads as f64; // sequential
127
128 blas_total < neon_total
129 }
130
131 /// Should we fuse the entire transformer layer into one thunk?
132 /// True when intermediates fit in L1 and per-thunk overhead dominates.
133 pub fn prefer_fused_layer(
134 &self,
135 batch: usize,
136 seq: usize,
137 hidden: usize,
138 intermediate: usize,
139 ) -> bool {
140 let m = batch * seq;
141 // Estimate intermediate buffer sizes
142 let qkv_bytes = m * 3 * hidden * 4;
143 let attn_bytes = m * hidden * 4;
144 let ffn_bytes = m * intermediate * 4;
145 let total_bytes = qkv_bytes + 2 * attn_bytes + ffn_bytes;
146 // Fuse if total intermediates fit in L2 (L1 would be ideal but tight)
147 total_bytes <= self.l2_bytes / 2
148 }
149}
150
151/// Global hardware model singleton.
152pub fn hw_model() -> &'static HwModel {
153 use std::sync::OnceLock;
154 static MODEL: OnceLock<HwModel> = OnceLock::new();
155 MODEL.get_or_init(|| HwModel::from_config(RuntimeConfig::global()))
156}