// rlx_cpu/autotune.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Auto-tuner — finds the optimal RuntimeConfig for a model on current hardware.
//!
//! Runs the model with different parameter combinations, measures each,
//! picks the fastest. Results are printed to stderr and the winning
//! configuration is returned so the caller can save it.
//!
//! ```rust,ignore
//! // Arguments: closure running one forward pass, search space, warmup runs, timed trials.
//! let best = autotune(|| run_forward_pass(), &SearchSpace::default(), 3, 20);
//! eprintln!("Best config: {:?}", best);
//! ```
25
26use crate::config::RuntimeConfig;
27use rlx_ir::Tick;
28
29/// Result of one tuning trial.
30#[derive(Debug, Clone)]
31pub struct TuneResult {
32    pub config: RuntimeConfig,
33    pub p50_ms: f64,
34    pub min_ms: f64,
35}
36
/// Search space for auto-tuning.
///
/// Each field lists the candidate values for one tunable parameter.
#[derive(Debug, Clone)]
pub struct SearchSpace {
    /// Candidate worker-pool sizes.
    pub workers: Vec<usize>,
    /// Candidate values for the `par_threshold` setting.
    pub par_thresholds: Vec<usize>,
    /// Candidate values for the `sdpa_seq_threshold` setting.
    pub sdpa_thresholds: Vec<usize>,
}

impl Default for SearchSpace {
    /// Builds a search space scaled to the machine's available parallelism.
    ///
    /// When the parallelism cannot be queried, 4 CPUs are assumed. Worker
    /// candidates derived from fractions of the CPU count are floored at 1, so
    /// machines with fewer than 4 CPUs never produce a candidate of zero
    /// workers.
    fn default() -> Self {
        let cpus = std::thread::available_parallelism().map_or(4, |n| n.get());

        // Integer division rounds small CPU counts down to 0 (e.g. 2 / 4),
        // which is not a usable pool size — hence the `.max(1)` floor.
        let workers = [1, 2, cpus / 4, cpus / 2, cpus * 3 / 4]
            .iter()
            .map(|&w| w.max(1))
            .collect();

        Self {
            workers,
            par_thresholds: vec![10_000, 20_000, 30_000, 50_000],
            sdpa_thresholds: vec![16, 32, 48],
        }
    }
}
56
57/// Auto-tune by running the model with different configs.
58///
59/// `run_fn` is called for each trial — it should execute one forward pass.
60/// `warmup` iterations are run before timing. `trials` are timed.
61pub fn autotune<F>(mut run_fn: F, search: &SearchSpace, warmup: usize, trials: usize) -> TuneResult
62where
63    F: FnMut(),
64{
65    let mut results: Vec<TuneResult> = Vec::new();
66    let base = RuntimeConfig::auto_detect();
67
68    // Generate all combinations
69    for &w in &search.workers {
70        for &par in &search.par_thresholds {
71            for &sdpa in &search.sdpa_thresholds {
72                let cfg = RuntimeConfig {
73                    pool_workers: w.clamp(1, 15),
74                    par_threshold: par,
75                    sdpa_seq_threshold: sdpa,
76                    ..base.clone()
77                };
78
79                // Apply this config (affects global singleton for this process)
80                // Note: pool workers can't be changed after init. Skip if different.
81                // For now, only tune par_threshold and sdpa_threshold.
82                unsafe {
83                    // Override the global config pointer
84                    set_global_config(cfg.clone());
85                }
86
87                // Warmup
88                for _ in 0..warmup {
89                    run_fn();
90                }
91
92                // Measure — direct CNTVCT_EL0 read on Apple Silicon (#66).
93                // Sub-microsecond resolution lets the search distinguish
94                // configs whose wall-clock times differ by a few ticks.
95                let mut times = Vec::with_capacity(trials);
96                for _ in 0..trials {
97                    let t = Tick::now();
98                    run_fn();
99                    times.push(Tick::now().elapsed_ms(t));
100                }
101                times.sort_by(|a, b| a.partial_cmp(b).unwrap());
102                let p50 = times[trials / 2];
103                let min = times[0];
104
105                eprintln!(
106                    "  workers={w:2} par={par:5} sdpa={sdpa:2} → p50={p50:.2}ms min={min:.2}ms"
107                );
108                results.push(TuneResult {
109                    config: cfg,
110                    p50_ms: p50,
111                    min_ms: min,
112                });
113            }
114        }
115    }
116
117    // Find best by p50
118    results.sort_by(|a, b| a.p50_ms.partial_cmp(&b.p50_ms).unwrap());
119    let best = results[0].clone();
120
121    // Apply best config
122    unsafe {
123        set_global_config(best.config.clone());
124    }
125
126    eprintln!(
127        "[rlx] best: workers={} par={} sdpa={} → {:.2}ms p50",
128        best.config.pool_workers,
129        best.config.par_threshold,
130        best.config.sdpa_seq_threshold,
131        best.p50_ms
132    );
133
134    best
135}
136
137/// Override the global RuntimeConfig.
138/// SAFETY: must only be called during auto-tuning (single-threaded phase).
139unsafe fn set_global_config(cfg: RuntimeConfig) {
140    // The OnceLock pattern doesn't allow re-setting.
141    // For auto-tuning, we use a separate mutable global.
142    TUNE_CONFIG.lock().unwrap().replace(cfg);
143}
144
145/// Get the active tuning config (if set), otherwise fall back to global.
146pub fn active_config() -> RuntimeConfig {
147    if let Some(cfg) = TUNE_CONFIG.lock().unwrap().as_ref() {
148        cfg.clone()
149    } else {
150        RuntimeConfig::global().clone()
151    }
152}
153
154static TUNE_CONFIG: std::sync::Mutex<Option<RuntimeConfig>> = std::sync::Mutex::new(None);
155
#[cfg(test)]
mod tests {
    use super::*;

    /// The default search space offers at least three worker counts, four
    /// parallel thresholds and three SDPA thresholds.
    #[test]
    fn search_space_default() {
        let SearchSpace {
            workers,
            par_thresholds,
            sdpa_thresholds,
        } = SearchSpace::default();

        assert!(workers.len() >= 3);
        assert_eq!(par_thresholds.len(), 4);
        assert_eq!(sdpa_thresholds.len(), 3);
    }
}
167}