use crate::config::RuntimeConfig;
use rlx_ir::Tick;
/// Timing summary for one benchmarked runtime configuration.
#[derive(Debug, Clone)]
pub struct TuneResult {
/// The configuration that was active while the timings were taken.
pub config: RuntimeConfig,
/// Median of the timed runs, as reported by `Tick::elapsed_ms`.
pub p50_ms: f64,
/// Fastest of the timed runs, as reported by `Tick::elapsed_ms`.
pub min_ms: f64,
}
/// Candidate values explored by [`autotune`], one list per tunable knob.
///
/// `autotune` benchmarks every combination of the three lists.
#[derive(Debug, Clone)]
pub struct SearchSpace {
    /// Candidate values for `RuntimeConfig::pool_workers`.
    pub workers: Vec<usize>,
    /// Candidate values for `RuntimeConfig::par_threshold`.
    pub par_thresholds: Vec<usize>,
    /// Candidate values for `RuntimeConfig::sdpa_seq_threshold`.
    pub sdpa_thresholds: Vec<usize>,
}
impl Default for SearchSpace {
    /// Builds a search space scaled to the machine's available parallelism.
    ///
    /// Worker candidates are `1`, `2` and 1/4, 1/2, 3/4 of the CPU count.
    /// The threshold lists are fixed.
    fn default() -> Self {
        // Assume 4 CPUs when the platform cannot report its parallelism.
        let cpus = std::thread::available_parallelism().map_or(4, |n| n.get());
        Self {
            // The fractional candidates round down to 0 on machines with
            // fewer than four CPUs; zero workers is not a usable pool size,
            // so every candidate is floored at 1.
            workers: vec![
                1,
                2,
                (cpus / 4).max(1),
                (cpus / 2).max(1),
                (cpus * 3 / 4).max(1),
            ],
            par_thresholds: vec![10_000, 20_000, 30_000, 50_000],
            sdpa_thresholds: vec![16, 32, 48],
        }
    }
}
/// Grid-searches `search` for the configuration under which `run_fn` runs fastest.
///
/// For every combination of worker count, parallel threshold and SDPA
/// threshold, the candidate is installed as the tuning override, `run_fn` is
/// called `warmup` times untimed and then `trials` times timed. The candidate
/// with the lowest median time is installed as the override and returned.
/// Per-candidate timings and the final winner are logged to stderr.
///
/// Edge cases:
/// * `trials == 0` is treated as a single trial, since a median needs a sample.
/// * An empty list in `search` leaves that knob at the value from
///   `RuntimeConfig::auto_detect()` instead of panicking.
/// * Worker counts are clamped to `1..=15`; candidates that collapse to the
///   same value after clamping are benchmarked only once.
/// * Timings are ordered with `f64::total_cmp`, so a NaN sample cannot panic.
///
/// # Panics
///
/// Panics if `run_fn` panics or if the `TUNE_CONFIG` mutex is poisoned.
pub fn autotune<F>(mut run_fn: F, search: &SearchSpace, warmup: usize, trials: usize) -> TuneResult
where
    F: FnMut(),
{
    /// Returns `values`, or just `fallback` when `values` is empty.
    fn or_fallback(values: &[usize], fallback: usize) -> Vec<usize> {
        if values.is_empty() {
            vec![fallback]
        } else {
            values.to_vec()
        }
    }

    // Indexing the sorted timings below requires at least one sample.
    let trials = trials.max(1);
    let base = RuntimeConfig::auto_detect();

    // Clamp first, then drop repeats (keeping first-seen order) so identical
    // configurations are not measured more than once.
    let mut workers: Vec<usize> = Vec::new();
    for w in or_fallback(&search.workers, base.pool_workers) {
        let w = w.clamp(1, 15);
        if !workers.contains(&w) {
            workers.push(w);
        }
    }
    let par_thresholds = or_fallback(&search.par_thresholds, base.par_threshold);
    let sdpa_thresholds = or_fallback(&search.sdpa_thresholds, base.sdpa_seq_threshold);

    let mut results: Vec<TuneResult> =
        Vec::with_capacity(workers.len() * par_thresholds.len() * sdpa_thresholds.len());

    for &w in &workers {
        for &par in &par_thresholds {
            for &sdpa in &sdpa_thresholds {
                let cfg = RuntimeConfig {
                    pool_workers: w,
                    par_threshold: par,
                    sdpa_seq_threshold: sdpa,
                    ..base.clone()
                };
                // SAFETY: `set_global_config` only replaces the value stored
                // behind the `TUNE_CONFIG` mutex; there is no additional
                // invariant for the caller to uphold.
                unsafe {
                    set_global_config(cfg.clone());
                }

                for _ in 0..warmup {
                    run_fn();
                }

                let mut times: Vec<f64> = Vec::with_capacity(trials);
                for _ in 0..trials {
                    let t = Tick::now();
                    run_fn();
                    times.push(Tick::now().elapsed_ms(t));
                }
                times.sort_by(|a, b| a.total_cmp(b));
                // For an even number of trials this is the upper median.
                let p50 = times[trials / 2];
                let min = times[0];

                eprintln!(
                    " workers={w:2} par={par:5} sdpa={sdpa:2} → p50={p50:.2}ms min={min:.2}ms"
                );
                results.push(TuneResult {
                    config: cfg,
                    p50_ms: p50,
                    min_ms: min,
                });
            }
        }
    }

    // `min_by` returns the first of several equal minima, so ties resolve to
    // the earliest candidate in search order.
    let best = results
        .into_iter()
        .min_by(|a, b| a.p50_ms.total_cmp(&b.p50_ms))
        .expect("every search dimension holds at least one candidate");

    // SAFETY: as above — the call only swaps the value behind `TUNE_CONFIG`.
    unsafe {
        set_global_config(best.config.clone());
    }
    eprintln!(
        "[rlx] best: workers={} par={} sdpa={} → {:.2}ms p50",
        best.config.pool_workers,
        best.config.par_threshold,
        best.config.sdpa_seq_threshold,
        best.p50_ms
    );
    best
}
/// Stores `cfg` in `TUNE_CONFIG`, replacing any previously stored override.
///
/// # Safety
///
/// NOTE(review): the body performs no unsafe operation — it only locks
/// `TUNE_CONFIG` and overwrites its contents — so there is no visible
/// contract for callers to uphold. Confirm whether the `unsafe` qualifier is
/// still needed.
///
/// # Panics
///
/// Panics if the `TUNE_CONFIG` mutex is poisoned.
unsafe fn set_global_config(cfg: RuntimeConfig) {
    let mut slot = TUNE_CONFIG.lock().unwrap();
    *slot = Some(cfg);
}
/// Returns the runtime configuration currently in effect.
///
/// This is a clone of the override stored in `TUNE_CONFIG` when one has been
/// set, and a clone of `RuntimeConfig::global()` otherwise.
///
/// # Panics
///
/// Panics if the `TUNE_CONFIG` mutex is poisoned.
pub fn active_config() -> RuntimeConfig {
    let tuned = TUNE_CONFIG.lock().unwrap().clone();
    tuned.unwrap_or_else(|| RuntimeConfig::global().clone())
}
/// Tuning override shared across the process: `None` until
/// `set_global_config` stores a configuration; read by `active_config`.
static TUNE_CONFIG: std::sync::Mutex<Option<RuntimeConfig>> = std::sync::Mutex::new(None);
#[cfg(test)]
mod tests {
    use super::*;

    /// The default search space offers the expected number of candidates per knob.
    #[test]
    fn search_space_default() {
        let SearchSpace {
            workers,
            par_thresholds,
            sdpa_thresholds,
        } = SearchSpace::default();

        assert!(workers.len() >= 3);
        assert_eq!(par_thresholds.len(), 4);
        assert_eq!(sdpa_thresholds.len(), 3);
    }
}