use crate::kernel_config::{CpuArch, KernelConfig, OpClass, kernel_config_for};
use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
/// A forced value for one tuning field of a `KernelConfig`.
///
/// Each variant carries the replacement value for the configuration field
/// named after it; `resolve` applies it on top of the built-in default.
#[derive(Clone, Copy, Debug)]
pub enum Override {
    /// Replaces `neon_seq_threshold`.
    NeonSeqThreshold(usize),
    /// Replaces `par_grain`.
    ParGrain(usize),
    /// Replaces `par_threshold`.
    ParThreshold(usize),
    /// Replaces `fuse_attn_threshold`.
    FuseAttnThreshold(usize),
}
/// Storage for registered overrides.
///
/// The map is keyed by `(CpuArch, OpClass)`; each value is the list of
/// overrides registered for that pair. All access goes through the mutex.
#[derive(Debug, Default)]
struct Table {
    /// Overrides per `(arch, op)` pair, guarded by a mutex.
    overrides: Mutex<HashMap<(CpuArch, OpClass), Vec<Override>>>,
}
fn table() -> &'static Table {
static T: OnceLock<Table> = OnceLock::new();
T.get_or_init(Table::default)
}
/// Registers `ov` for the `(arch, op)` pair.
///
/// Any previously registered override of the same variant for that pair is
/// dropped first, so the list holds at most one entry per variant and the
/// most recent call wins.
///
/// # Panics
///
/// Panics if the table's mutex is poisoned.
pub fn set_override(arch: CpuArch, op: OpClass, ov: Override) {
    // Variant identity only: the payload value is deliberately ignored here.
    let incoming_kind = std::mem::discriminant(&ov);

    let mut guard = table()
        .overrides
        .lock()
        .expect("dispatch table poisoned");

    let slot = guard.entry((arch, op)).or_default();
    slot.retain(|prev| std::mem::discriminant(prev) != incoming_kind);
    slot.push(ov);
}
/// Removes every registered override, for every `(arch, op)` pair.
///
/// # Panics
///
/// Panics if the table's mutex is poisoned.
#[doc(hidden)]
pub fn clear_overrides_for_tests() {
    table()
        .overrides
        .lock()
        .expect("dispatch table poisoned")
        .clear();
}
/// Builds the effective configuration for `(arch, op)`.
///
/// Starts from `kernel_config_for(arch, op)` and then writes each override
/// registered for that pair into its matching field. With no overrides
/// registered, the default configuration is returned unchanged.
///
/// # Panics
///
/// Panics if the table's mutex is poisoned.
pub fn resolve(arch: CpuArch, op: OpClass) -> KernelConfig {
    let mut cfg = kernel_config_for(arch, op);

    let guard = table()
        .overrides
        .lock()
        .expect("dispatch table poisoned");

    // A missing entry yields an empty iteration, so no separate branch is needed.
    for ov in guard.get(&(arch, op)).into_iter().flatten() {
        match *ov {
            Override::NeonSeqThreshold(value) => cfg.neon_seq_threshold = value,
            Override::ParGrain(value) => cfg.par_grain = value,
            Override::ParThreshold(value) => cfg.par_threshold = value,
            Override::FuseAttnThreshold(value) => cfg.fuse_attn_threshold = value,
        }
    }

    cfg
}
/// Resolves the configuration for `op` using the architecture returned by
/// `CpuArch::current()`.
///
/// # Panics
///
/// Panics under the same condition as `resolve`.
pub fn resolve_current(op: OpClass) -> KernelConfig {
    let arch = CpuArch::current();
    resolve(arch, op)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Serializes the tests in this module: they all mutate the one shared
    /// override table, so they must not interleave with each other.
    static DISPATCH_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Runs `f` while holding the test lock, starting from an empty table.
    fn with_clean_table(f: impl FnOnce()) {
        // A failed assertion inside `f` unwinds while the guard is held and
        // poisons the lock. The lock protects `()`, so there is no state that
        // could have been left inconsistent; recover the guard instead of
        // making every later test fail with an unrelated poison panic.
        let _guard = DISPATCH_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        clear_overrides_for_tests();
        f();
    }

    /// With no overrides registered, `resolve` returns the defaults.
    #[test]
    fn defaults_pass_through() {
        with_clean_table(|| {
            let arch = CpuArch::AppleSilicon;
            let op = OpClass::Matmul;
            let resolved = resolve(arch, op);
            let default = kernel_config_for(arch, op);
            assert_eq!(resolved.neon_seq_threshold, default.neon_seq_threshold);
            assert_eq!(resolved.par_threshold, default.par_threshold);
        });
    }

    /// An override changes its own field and leaves another field at its default.
    #[test]
    fn override_replaces_field() {
        with_clean_table(|| {
            let arch = CpuArch::AppleSilicon;
            let op = OpClass::Matmul;
            set_override(arch, op, Override::NeonSeqThreshold(7));
            let r = resolve(arch, op);
            assert_eq!(r.neon_seq_threshold, 7);
            let d = kernel_config_for(arch, op);
            assert_eq!(r.par_threshold, d.par_threshold);
        });
    }

    /// Setting the same variant twice keeps only the latest value and does
    /// not disturb an unrelated field.
    #[test]
    fn override_for_one_field_replaces_just_that() {
        with_clean_table(|| {
            let arch = CpuArch::X86_64;
            let op = OpClass::Attention;
            set_override(arch, op, Override::NeonSeqThreshold(5));
            set_override(arch, op, Override::NeonSeqThreshold(9));
            let r = resolve(arch, op);
            assert_eq!(r.neon_seq_threshold, 9);
            let d = kernel_config_for(arch, op);
            assert_eq!(r.par_threshold, d.par_threshold);
        });
    }
}