Skip to main content

rlx_cpu/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Runtime configuration — compile-time platform defaults + runtime hardware detection.
17//!
18//! Compile-time: target arch/OS sets optimal defaults (cache line, SIMD strategy).
//! Runtime: sysctl (macOS) / sysfs (Linux) refine values (P-core count, cache line, cache sizes).
20//! Env vars: `RLX_*` overrides for manual tuning — or set the same keys in
21//! code via [`rlx_ir::env::set`] / [`RuntimeConfig::install`].
22//!
23//! ```bash
//! RLX_WORKERS=8           # thread pool size (0 = auto, capped at 15)
25//! RLX_PAR_THRESHOLD=20000 # min elements for parallel dispatch
26//! RLX_SDPA_THRESHOLD=32   # seq len: NEON dots (≤) vs BLAS sgemm (>)
27//! RLX_ARENA_ALIGN=128     # arena buffer alignment in bytes
28//! RLX_VERBOSE=0           # 0=quiet, 1=fusion passes, 2=full graph dump
29//! ```
30
31use std::sync::OnceLock;
32
33// ── Compile-time platform defaults ──────────────────────────────────────
34
/// Cache line size in bytes, fixed per target at compile time.
///
/// Apple Silicon (aarch64 + macOS) uses 128-byte L1 lines; other aarch64
/// targets (Graviton, Ampere) and x86_64 typically use 64.
const PLATFORM_CACHE_LINE: usize = if cfg!(all(target_arch = "aarch64", target_os = "macos")) {
    128
} else {
    64
};

/// Default minimum element count before element-wise work is dispatched in parallel.
///
/// On Apple Silicon, AMX handles BLAS internally, so our `par_for` only sees
/// element-wise ops (LayerNorm, GELU, ...). A lower threshold there parallelizes
/// those ops on smaller tensors.
const PLATFORM_PAR_THRESHOLD: usize =
    if cfg!(all(target_arch = "aarch64", target_os = "macos")) {
        16_384
    } else {
        30_000
    };
53
54// ── Runtime hardware detection ──────────────────────────────────────────
55
56/// Detect hardware properties at runtime.
57struct HwInfo {
58    total_cpus: usize,
59    perf_cores: usize, // P-cores (0 = unknown, use total_cpus)
60    l1d_cache: usize,  // L1 data cache bytes (0 = unknown)
61    l2_cache: usize,   // L2 cache bytes (0 = unknown)
62    cache_line: usize, // actual cache line from OS (0 = use compile-time default)
63}
64
65impl HwInfo {
66    fn detect() -> Self {
67        let total = std::thread::available_parallelism()
68            .map(|n| n.get())
69            .unwrap_or(2);
70
71        let mut info = HwInfo {
72            total_cpus: total,
73            perf_cores: 0,
74            l1d_cache: 0,
75            l2_cache: 0,
76            cache_line: 0,
77        };
78
79        #[cfg(target_os = "macos")]
80        {
81            info.perf_cores = sysctl_usize("hw.perflevel0.physicalcpu").unwrap_or(0);
82            info.l1d_cache = sysctl_usize("hw.l1dcachesize").unwrap_or(0);
83            info.l2_cache = sysctl_usize("hw.l2cachesize").unwrap_or(0);
84            info.cache_line = sysctl_usize("hw.cachelinesize").unwrap_or(0);
85        }
86
87        #[cfg(target_os = "linux")]
88        {
89            // /sys/devices/system/cpu/cpu0/cache/index0/coherency_line_size
90            if let Ok(v) = std::fs::read_to_string(
91                "/sys/devices/system/cpu/cpu0/cache/index0/coherency_line_size",
92            ) {
93                info.cache_line = v.trim().parse().unwrap_or(0);
94            }
95            if let Ok(v) = std::fs::read_to_string("/sys/devices/system/cpu/cpu0/cache/index0/size")
96            {
97                // Parse "32K" or "32768"
98                let s = v.trim().to_uppercase();
99                if s.ends_with('K') {
100                    info.l1d_cache = s[..s.len() - 1].parse::<usize>().unwrap_or(0) * 1024;
101                } else {
102                    info.l1d_cache = s.parse().unwrap_or(0);
103                }
104            }
105        }
106
107        info
108    }
109
110    /// Optimal worker count: P-cores/2 (avoids E-cores + AMX cache thrashing).
111    fn optimal_workers(&self) -> usize {
112        let base = if self.perf_cores > 0 {
113            self.perf_cores / 2 // Use half of P-cores
114        } else {
115            self.total_cpus / 2 // Fallback: half of all CPUs
116        };
117        base.clamp(1, 15)
118    }
119
120    /// Cache line: prefer runtime-detected, fall back to compile-time.
121    fn cache_line(&self) -> usize {
122        if self.cache_line > 0 {
123            self.cache_line
124        } else {
125            PLATFORM_CACHE_LINE
126        }
127    }
128
129    /// Fusion threshold: intermediates must fit in L1 for monolithic kernels.
130    #[allow(dead_code)]
131    fn fuse_attn_threshold(&self) -> usize {
132        if self.l1d_cache > 0 {
133            // L1 budget: ~60% for intermediates (rest for weights being streamed)
134            // Each fused layer needs: qkv(m×3h) + attn(m×h) + res(m×h) + normed(m×h) + ffn(m×int)
135            // ≈ m × 7h floats for BERT. Must fit in 60% of L1.
136            // Solve for m: m = 0.6 * L1 / (7 * 768 * 4) ≈ L1/36000
137            // For L1=64KB: m ≈ 1.8 → batch*seq ≤ ~2 → threshold ~64 is about right
138            64
139        } else {
140            64
141        }
142    }
143}
144
/// Read a numeric `sysctl` value by name (macOS only).
///
/// Both 32-bit (`int`) and 64-bit (`quad`) sysctl values are handled; the
/// value is decoded according to the width the kernel reports back.
///
/// Returns `None` in any of these cases:
/// * `name` contains an interior NUL;
/// * the call fails (e.g. unknown key);
/// * the reported width is neither 4 nor 8 bytes.
#[cfg(target_os = "macos")]
fn sysctl_usize(name: &str) -> Option<usize> {
    use std::ffi::CString;

    unsafe extern "C" {
        fn sysctlbyname(
            name: *const i8,
            oldp: *mut u8,
            oldlenp: *mut usize,
            newp: *const u8,
            newlen: usize,
        ) -> i32;
    }

    let cname = CString::new(name).ok()?;
    let mut buf = [0u8; 8];
    let mut len = buf.len();
    // SAFETY:
    // * `cname` is NUL-terminated and outlives the call.
    // * `buf` is a writable buffer of exactly `len` bytes, and the kernel never
    //   writes past the length passed in `oldlenp`.
    // * `newp` is null with `newlen` 0, so no value is written to the kernel.
    let ret = unsafe {
        sysctlbyname(cname.as_ptr(), buf.as_mut_ptr(), &mut len, std::ptr::null(), 0)
    };
    if ret != 0 {
        return None;
    }
    // `len` now holds the actual width of the value the kernel copied into `buf`.
    match len {
        4 => usize::try_from(u32::from_ne_bytes([buf[0], buf[1], buf[2], buf[3]])).ok(),
        8 => usize::try_from(u64::from_ne_bytes(buf)).ok(),
        _ => None,
    }
}
171
/// Runtime configuration for the RLX CPU backend.
///
/// Ways to obtain one:
/// * [`RuntimeConfig::auto_detect`]: hardware detection plus fixed defaults.
/// * [`RuntimeConfig::from_env`]: the same, then `RLX_*` overrides.
/// * [`RuntimeConfig::global`]: a process-wide instance.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    // ── Thread pool ─────────────────────────────────────────
    /// Worker thread pool size (`RLX_WORKERS`).
    pub pool_workers: usize,

    // ── Parallelization ─────────────────────────────────────
    /// Minimum element count before work is dispatched in parallel (`RLX_PAR_THRESHOLD`).
    pub par_threshold: usize,
    /// Presumably the minimum number of rows each worker receives when row-wise
    /// work is split across threads — TODO confirm at the use sites.
    pub min_rows_per_thread: usize,

    // ── SDPA dispatch ───────────────────────────────────────
    /// SDPA sequence-length cutoff: NEON dot products at or below it, BLAS
    /// sgemm above it (`RLX_SDPA_THRESHOLD`).
    pub sdpa_seq_threshold: usize,

    // ── Memory planning ─────────────────────────────────────
    /// Arena buffer alignment in bytes (`RLX_ARENA_ALIGN`).
    pub arena_alignment: usize,

    // ── Numerical constants ─────────────────────────────────
    /// Presumably the default LayerNorm epsilon — confirm at the use sites.
    pub ln_eps_default: f32,
    /// Large negative value, presumably substituted for -inf at masked
    /// attention positions — confirm at the use sites.
    pub attn_mask_neg_inf: f32,
    /// NOTE(review): likely a score magnitude below which work is skipped;
    /// the exact semantics are not visible in this file.
    pub score_skip_threshold: f32,
    /// NOTE(review): likely the cutoff for treating mask values as binary 0/1;
    /// the exact semantics are not visible in this file.
    pub mask_binary_threshold: f32,

    // ── Diagnostics ─────────────────────────────────────────
    /// Verbosity (`RLX_VERBOSE`): 0 = quiet, 1 = fusion passes, 2 = full graph dump.
    pub verbose: u8,
}
197
198impl Default for RuntimeConfig {
199    fn default() -> Self {
200        Self::auto_detect()
201    }
202}
203
204impl RuntimeConfig {
205    /// Auto-detect hardware and apply optimal defaults.
206    pub fn auto_detect() -> Self {
207        let hw = HwInfo::detect();
208
209        Self {
210            pool_workers: hw.optimal_workers(),
211            par_threshold: PLATFORM_PAR_THRESHOLD,
212            min_rows_per_thread: 4,
213            sdpa_seq_threshold: 32,
214            arena_alignment: hw.cache_line(),
215            ln_eps_default: 1e-12,
216            attn_mask_neg_inf: -1e9,
217            score_skip_threshold: 1e-8,
218            mask_binary_threshold: 0.5,
219            verbose: 0,
220        }
221    }
222
223    /// Auto-detect then override from `RLX_*` environment variables.
224    pub fn from_env() -> Self {
225        let mut cfg = Self::auto_detect();
226
227        if let Some(v) = rlx_ir::env::var("RLX_WORKERS")
228            && let Ok(n) = v.parse::<usize>()
229        {
230            cfg.pool_workers = if n == 0 { cfg.pool_workers } else { n.min(15) };
231        }
232        if let Some(v) = rlx_ir::env::var("RLX_PAR_THRESHOLD")
233            && let Ok(n) = v.parse()
234        {
235            cfg.par_threshold = n;
236        }
237        if let Some(v) = rlx_ir::env::var("RLX_SDPA_THRESHOLD")
238            && let Ok(n) = v.parse()
239        {
240            cfg.sdpa_seq_threshold = n;
241        }
242        if let Some(v) = rlx_ir::env::var("RLX_ARENA_ALIGN")
243            && let Ok(n) = v.parse()
244        {
245            cfg.arena_alignment = n;
246        }
247        if let Some(v) = rlx_ir::env::var("RLX_VERBOSE")
248            && let Ok(n) = v.parse()
249        {
250            cfg.verbose = n;
251        }
252
253        if cfg.verbose >= 1 {
254            let hw = HwInfo::detect();
255            eprintln!(
256                "[rlx] hw: {} CPUs ({} P-cores), L1={}KB, L2={}KB, cacheline={}B",
257                hw.total_cpus,
258                hw.perf_cores,
259                hw.l1d_cache / 1024,
260                hw.l2_cache / 1024,
261                hw.cache_line()
262            );
263            eprintln!(
264                "[rlx] config: workers={}, par_thr={}, sdpa_thr={}, align={}",
265                cfg.pool_workers, cfg.par_threshold, cfg.sdpa_seq_threshold, cfg.arena_alignment
266            );
267        }
268
269        cfg
270    }
271
272    /// Push this config into the global [`rlx_ir::env`] override map so all
273    /// RLX backends see the same knobs without setting process env vars.
274    pub fn install(&self) {
275        rlx_ir::env::set("RLX_WORKERS", self.pool_workers.to_string());
276        rlx_ir::env::set("RLX_PAR_THRESHOLD", self.par_threshold.to_string());
277        rlx_ir::env::set("RLX_SDPA_THRESHOLD", self.sdpa_seq_threshold.to_string());
278        rlx_ir::env::set("RLX_ARENA_ALIGN", self.arena_alignment.to_string());
279        rlx_ir::env::set("RLX_VERBOSE", self.verbose.to_string());
280    }
281
282    /// Get or initialize the global singleton config.
283    pub fn global() -> &'static RuntimeConfig {
284        static CONFIG: OnceLock<RuntimeConfig> = OnceLock::new();
285        CONFIG.get_or_init(RuntimeConfig::from_env)
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn auto_detect_sane_defaults() {
        let detected = RuntimeConfig::auto_detect();
        // Detection clamps the worker count into 1..=15.
        assert!((1..=15).contains(&detected.pool_workers));
        // Alignment follows the cache line, expected to be at least 64 bytes
        // on the supported platforms.
        assert!(detected.arena_alignment >= 64);
        // Auto-detection never turns diagnostics on.
        assert_eq!(detected.verbose, 0);
    }

    #[test]
    fn global_is_consistent() {
        let (first, second) = (RuntimeConfig::global(), RuntimeConfig::global());
        assert_eq!(first.pool_workers, second.pool_workers);
    }

    #[test]
    fn hw_detection() {
        let info = HwInfo::detect();
        assert!(info.total_cpus >= 1);
        // macOS exposes the cache line through sysctl, so detection should find it.
        if cfg!(target_os = "macos") {
            assert!(
                info.cache_line > 0,
                "expected sysctl to return cache line size"
            );
        }
    }
}