rlx_cpu/autotune.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Auto-tuner — finds the optimal RuntimeConfig for a model on current hardware.
17//!
18//! Runs the model with different parameter combinations, measures each,
19//! picks the fastest. Results are printed and can be saved.
20//!
21//! ```rust,ignore
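//! let best = autotune(
//!     || { compiled.run(&sample_inputs); }, // one forward pass (illustrative call)
//!     &SearchSpace::default(),
//!     3,  // warmup iterations per config
//!     20, // timed trials per config
//! );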
23//! eprintln!("Best config: {:?}", best);
24//! ```
25
26use crate::config::RuntimeConfig;
27use rlx_ir::Tick;
28
29/// Result of one tuning trial.
30#[derive(Debug, Clone)]
31pub struct TuneResult {
32 pub config: RuntimeConfig,
33 pub p50_ms: f64,
34 pub min_ms: f64,
35}
36
/// Search space for auto-tuning.
///
/// Each field lists the candidate values for one tunable; `autotune` tries
/// every combination of the three lists.
#[derive(Debug, Clone)]
pub struct SearchSpace {
    /// Candidate worker-pool sizes.
    pub workers: Vec<usize>,
    /// Candidate values for the `par_threshold` setting.
    pub par_thresholds: Vec<usize>,
    /// Candidate values for the `sdpa_seq_threshold` setting.
    pub sdpa_thresholds: Vec<usize>,
}

impl Default for SearchSpace {
    /// Builds a search space scaled to the machine's available parallelism.
    ///
    /// Worker candidates are 1, 2, and one quarter, one half and three
    /// quarters of the detected CPU count (4 is assumed when detection
    /// fails). Every candidate is at least 1.
    fn default() -> Self {
        let cpus = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4);

        // Integer division rounds the fractional candidates down to 0 on
        // small machines (e.g. `cpus / 4` when `cpus < 4`); a pool with zero
        // workers is not a meaningful candidate, so floor each entry at 1.
        let workers = [1, 2, cpus / 4, cpus / 2, cpus * 3 / 4]
            .iter()
            .map(|&w| w.max(1))
            .collect();

        Self {
            workers,
            par_thresholds: vec![10_000, 20_000, 30_000, 50_000],
            sdpa_thresholds: vec![16, 32, 48],
        }
    }
}
56
57/// Auto-tune by running the model with different configs.
58///
59/// `run_fn` is called for each trial — it should execute one forward pass.
60/// `warmup` iterations are run before timing. `trials` are timed.
61pub fn autotune<F>(mut run_fn: F, search: &SearchSpace, warmup: usize, trials: usize) -> TuneResult
62where
63 F: FnMut(),
64{
65 let mut results: Vec<TuneResult> = Vec::new();
66 let base = RuntimeConfig::auto_detect();
67
68 // Generate all combinations
69 for &w in &search.workers {
70 for &par in &search.par_thresholds {
71 for &sdpa in &search.sdpa_thresholds {
72 let cfg = RuntimeConfig {
73 pool_workers: w.clamp(1, 15),
74 par_threshold: par,
75 sdpa_seq_threshold: sdpa,
76 ..base.clone()
77 };
78
79 // Apply this config (affects global singleton for this process)
80 // Note: pool workers can't be changed after init. Skip if different.
81 // For now, only tune par_threshold and sdpa_threshold.
82 unsafe {
83 // Override the global config pointer
84 set_global_config(cfg.clone());
85 }
86
87 // Warmup
88 for _ in 0..warmup {
89 run_fn();
90 }
91
92 // Measure — direct CNTVCT_EL0 read on Apple Silicon (#66).
93 // Sub-microsecond resolution lets the search distinguish
94 // configs whose wall-clock times differ by a few ticks.
95 let mut times = Vec::with_capacity(trials);
96 for _ in 0..trials {
97 let t = Tick::now();
98 run_fn();
99 times.push(Tick::now().elapsed_ms(t));
100 }
101 times.sort_by(|a, b| a.partial_cmp(b).unwrap());
102 let p50 = times[trials / 2];
103 let min = times[0];
104
105 eprintln!(
106 " workers={w:2} par={par:5} sdpa={sdpa:2} → p50={p50:.2}ms min={min:.2}ms"
107 );
108 results.push(TuneResult {
109 config: cfg,
110 p50_ms: p50,
111 min_ms: min,
112 });
113 }
114 }
115 }
116
117 // Find best by p50
118 results.sort_by(|a, b| a.p50_ms.partial_cmp(&b.p50_ms).unwrap());
119 let best = results[0].clone();
120
121 // Apply best config
122 unsafe {
123 set_global_config(best.config.clone());
124 }
125
126 eprintln!(
127 "[rlx] best: workers={} par={} sdpa={} → {:.2}ms p50",
128 best.config.pool_workers,
129 best.config.par_threshold,
130 best.config.sdpa_seq_threshold,
131 best.p50_ms
132 );
133
134 best
135}
136
137/// Override the global RuntimeConfig.
138/// SAFETY: must only be called during auto-tuning (single-threaded phase).
139unsafe fn set_global_config(cfg: RuntimeConfig) {
140 // The OnceLock pattern doesn't allow re-setting.
141 // For auto-tuning, we use a separate mutable global.
142 TUNE_CONFIG.lock().unwrap().replace(cfg);
143}
144
145/// Get the active tuning config (if set), otherwise fall back to global.
146pub fn active_config() -> RuntimeConfig {
147 if let Some(cfg) = TUNE_CONFIG.lock().unwrap().as_ref() {
148 cfg.clone()
149 } else {
150 RuntimeConfig::global().clone()
151 }
152}
153
/// Tuning-time override for the runtime configuration.
///
/// Starts as `None`; `set_global_config` stores a value here and
/// `active_config` prefers it over `RuntimeConfig::global()` when present.
static TUNE_CONFIG: std::sync::Mutex<Option<RuntimeConfig>> = std::sync::Mutex::new(None);
155
#[cfg(test)]
mod tests {
    use super::*;

    /// The default search space offers at least three worker candidates and
    /// the fixed-size threshold grids.
    #[test]
    fn search_space_default() {
        let ss = <SearchSpace as Default>::default();

        assert!(ss.workers.len() >= 3);
        assert_eq!(ss.par_thresholds.len(), 4);
        assert_eq!(ss.sdpa_thresholds.len(), 3);
    }
}
167}