Skip to main content

rlx_cpu/
dispatch.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Dispatch table — calibration-aware kernel selection (plan #2).
17//!
18//! Two layers:
19//!   - **Defaults** from `kernel_config` (compile-time, per-arch).
20//!   - **Overrides** filled at runtime by autotune / calibration
21//!     when a measured number disagrees with the default.
22//!
23//! Borrowed from MAX's `dispatch_table_a100_gpu.mojo` /
24//! `dispatch_table_amd.mojo` pattern: kernel-variant selection is a
25//! data lookup, not scattered match arms in dispatch sites.
26//!
27//! Today the table is consulted by the cost model and (when an
28//! override is set) used to override a `kernel_config` default.
29//! Future work wires fusion patterns through the same table so
30//! schedule decisions are uniformly data-driven.
31
32use crate::kernel_config::{CpuArch, KernelConfig, OpClass, kernel_config_for};
33use std::collections::HashMap;
34use std::sync::{Mutex, OnceLock};
35
/// A single-field override for a default `KernelConfig`.
///
/// Each variant names one `KernelConfig` field and carries the value that
/// `resolve` writes into that field in place of the compile-time default.
/// Add variants as autotune learns to tune more fields.
///
/// Equality compares both the variant and its payload, so
/// `ParGrain(3) != ParGrain(4)` and `ParGrain(3) != ParThreshold(3)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Override {
    /// Replaces `KernelConfig::neon_seq_threshold`.
    NeonSeqThreshold(usize),
    /// Replaces `KernelConfig::par_grain`.
    ParGrain(usize),
    /// Replaces `KernelConfig::par_threshold`.
    ParThreshold(usize),
    /// Replaces `KernelConfig::fuse_attn_threshold`.
    FuseAttnThreshold(usize),
}
45
46#[derive(Debug, Default)]
47struct Table {
48    overrides: Mutex<HashMap<(CpuArch, OpClass), Vec<Override>>>,
49}
50
51fn table() -> &'static Table {
52    static T: OnceLock<Table> = OnceLock::new();
53    T.get_or_init(Table::default)
54}
55
56/// Set an override for `(arch, op)`. Idempotent — re-set replaces.
57pub fn set_override(arch: CpuArch, op: OpClass, ov: Override) {
58    let t = table();
59    let mut m = t.overrides.lock().expect("dispatch table poisoned");
60    let entry = m.entry((arch, op)).or_default();
61    // Replace any existing override of the same field tag.
62    entry.retain(|existing| std::mem::discriminant(existing) != std::mem::discriminant(&ov));
63    entry.push(ov);
64}
65
66/// Reset all overrides (test hook).
67#[doc(hidden)]
68pub fn clear_overrides_for_tests() {
69    let t = table();
70    let mut m = t.overrides.lock().expect("dispatch table poisoned");
71    m.clear();
72}
73
74/// Resolve a `KernelConfig` for `(arch, op)`, applying any
75/// overrides on top of the const-time defaults.
76pub fn resolve(arch: CpuArch, op: OpClass) -> KernelConfig {
77    let mut cfg = kernel_config_for(arch, op);
78    let t = table();
79    let m = t.overrides.lock().expect("dispatch table poisoned");
80    if let Some(list) = m.get(&(arch, op)) {
81        for ov in list {
82            match ov {
83                Override::NeonSeqThreshold(v) => cfg.neon_seq_threshold = *v,
84                Override::ParGrain(v) => cfg.par_grain = *v,
85                Override::ParThreshold(v) => cfg.par_threshold = *v,
86                Override::FuseAttnThreshold(v) => cfg.fuse_attn_threshold = *v,
87            }
88        }
89    }
90    cfg
91}
92
93/// Convenience: resolve for the running target.
94pub fn resolve_current(op: OpClass) -> KernelConfig {
95    resolve(CpuArch::current(), op)
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, PoisonError};

    /// Serializes the tests in this module — `resolve` reads a
    /// process-global override table.
    static DISPATCH_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Runs `f` while holding the test lock, with the override table
    /// cleared both before and after.
    fn with_clean_table(f: impl FnOnce()) {
        // A failed assertion inside `f` unwinds while the guard is held and
        // poisons the lock. The guarded value is `()`, so there is no state
        // to corrupt: recover the guard instead of panicking, otherwise one
        // failing test makes every later test fail with a misleading
        // "poisoned" message.
        let _guard = DISPATCH_TEST_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        clear_overrides_for_tests();
        f();
        // Don't leave overrides behind in the process-global table once
        // this test is done with it.
        clear_overrides_for_tests();
    }

    #[test]
    fn defaults_pass_through() {
        with_clean_table(|| {
            let arch = CpuArch::AppleSilicon;
            let op = OpClass::Matmul;
            let resolved = resolve(arch, op);
            let default = kernel_config_for(arch, op);
            assert_eq!(resolved.neon_seq_threshold, default.neon_seq_threshold);
            assert_eq!(resolved.par_threshold, default.par_threshold);
        });
    }

    #[test]
    fn override_replaces_field() {
        with_clean_table(|| {
            let arch = CpuArch::AppleSilicon;
            let op = OpClass::Matmul;
            set_override(arch, op, Override::NeonSeqThreshold(7));
            let r = resolve(arch, op);
            assert_eq!(r.neon_seq_threshold, 7);
            // Other fields untouched.
            let d = kernel_config_for(arch, op);
            assert_eq!(r.par_threshold, d.par_threshold);
        });
    }

    #[test]
    fn override_for_one_field_replaces_just_that() {
        with_clean_table(|| {
            let arch = CpuArch::X86_64;
            let op = OpClass::Attention;
            set_override(arch, op, Override::NeonSeqThreshold(5));
            set_override(arch, op, Override::NeonSeqThreshold(9)); // replaces
            let r = resolve(arch, op);
            assert_eq!(r.neon_seq_threshold, 9);
        });
    }

    #[test]
    fn overrides_for_distinct_fields_coexist() {
        with_clean_table(|| {
            let arch = CpuArch::X86_64;
            let op = OpClass::Attention;
            set_override(arch, op, Override::NeonSeqThreshold(5));
            set_override(arch, op, Override::ParGrain(11));
            let r = resolve(arch, op);
            // Setting a second field must not evict the first.
            assert_eq!(r.neon_seq_threshold, 5);
            assert_eq!(r.par_grain, 11);
        });
    }
}
151}