1use crate::kernel_config::{CpuArch, KernelConfig, OpClass, kernel_config_for};
33use std::collections::HashMap;
34use std::sync::{Mutex, OnceLock};
35
/// A single-field replacement applied on top of a default `KernelConfig`.
///
/// Each variant carries the new value for the `KernelConfig` field of the
/// matching name. The comparison and hashing derives let callers and tests
/// compare overrides directly instead of destructuring them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Override {
    /// Replaces `KernelConfig::neon_seq_threshold`.
    NeonSeqThreshold(usize),
    /// Replaces `KernelConfig::par_grain`.
    ParGrain(usize),
    /// Replaces `KernelConfig::par_threshold`.
    ParThreshold(usize),
    /// Replaces `KernelConfig::fuse_attn_threshold`.
    FuseAttnThreshold(usize),
}
45
46#[derive(Debug, Default)]
47struct Table {
48 overrides: Mutex<HashMap<(CpuArch, OpClass), Vec<Override>>>,
49}
50
51fn table() -> &'static Table {
52 static T: OnceLock<Table> = OnceLock::new();
53 T.get_or_init(Table::default)
54}
55
56pub fn set_override(arch: CpuArch, op: OpClass, ov: Override) {
58 let t = table();
59 let mut m = t.overrides.lock().expect("dispatch table poisoned");
60 let entry = m.entry((arch, op)).or_default();
61 entry.retain(|existing| std::mem::discriminant(existing) != std::mem::discriminant(&ov));
63 entry.push(ov);
64}
65
66#[doc(hidden)]
68pub fn clear_overrides_for_tests() {
69 let t = table();
70 let mut m = t.overrides.lock().expect("dispatch table poisoned");
71 m.clear();
72}
73
74pub fn resolve(arch: CpuArch, op: OpClass) -> KernelConfig {
77 let mut cfg = kernel_config_for(arch, op);
78 let t = table();
79 let m = t.overrides.lock().expect("dispatch table poisoned");
80 if let Some(list) = m.get(&(arch, op)) {
81 for ov in list {
82 match ov {
83 Override::NeonSeqThreshold(v) => cfg.neon_seq_threshold = *v,
84 Override::ParGrain(v) => cfg.par_grain = *v,
85 Override::ParThreshold(v) => cfg.par_threshold = *v,
86 Override::FuseAttnThreshold(v) => cfg.fuse_attn_threshold = *v,
87 }
88 }
89 }
90 cfg
91}
92
93pub fn resolve_current(op: OpClass) -> KernelConfig {
95 resolve(CpuArch::current(), op)
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, PoisonError};

    /// Serializes the tests in this module, since they all mutate the one
    /// shared override table.
    static DISPATCH_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Runs `f` with the test lock held and the override table emptied first.
    fn with_clean_table(f: impl FnOnce()) {
        // A failed assertion in one test poisons the lock. Recover the guard
        // instead of panicking so a single failure does not cascade into
        // every later test; the table is cleared below, so no state from the
        // failed test is observed.
        let _guard = DISPATCH_TEST_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        clear_overrides_for_tests();
        f();
    }

    #[test]
    fn defaults_pass_through() {
        with_clean_table(|| {
            let arch = CpuArch::AppleSilicon;
            let op = OpClass::Matmul;
            let resolved = resolve(arch, op);
            let default = kernel_config_for(arch, op);
            assert_eq!(resolved.neon_seq_threshold, default.neon_seq_threshold);
            assert_eq!(resolved.par_threshold, default.par_threshold);
        });
    }

    #[test]
    fn override_replaces_field() {
        with_clean_table(|| {
            let arch = CpuArch::AppleSilicon;
            let op = OpClass::Matmul;
            set_override(arch, op, Override::NeonSeqThreshold(7));
            let r = resolve(arch, op);
            assert_eq!(r.neon_seq_threshold, 7);
            // Fields without an override keep their default value.
            let d = kernel_config_for(arch, op);
            assert_eq!(r.par_threshold, d.par_threshold);
        });
    }

    #[test]
    fn override_for_one_field_replaces_just_that() {
        with_clean_table(|| {
            let arch = CpuArch::X86_64;
            let op = OpClass::Attention;
            // Setting the same variant twice: the later value must win.
            set_override(arch, op, Override::NeonSeqThreshold(5));
            set_override(arch, op, Override::NeonSeqThreshold(9));
            let r = resolve(arch, op);
            assert_eq!(r.neon_seq_threshold, 9);
        });
    }
}