Skip to main content

rlx_cpu/
kernel_config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Compile-time kernel-config tables (plan #14).
17//!
//! Borrowed from MAX's `internal_utils/nvidia_configs.mojo` /
//! `amd_configs.mojo` pattern: per-target tuning values (tile sizes,
//! kernel-selection thresholds, etc.) live in compile-time data
//! structures that kernels query, instead of in scattered match arms.
22//!
23//! Today the values live as `const`s here and are surfaced through
24//! [`kernel_config_for`]. The goal is one source of truth — when
25//! we want to tune for a new arch (M5 Apple Silicon, x86 Zen5,
26//! etc.) we add a row to the table, not a new match arm in 12
27//! files.
28
/// Coarse classification of the CPU the crate is built for. Expected to
/// grow finer-grained as new SoCs need their own tuning rows.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CpuArch {
    /// Apple M-series (AMX + NEON).
    AppleSilicon,
    /// Any other 64-bit ARM, e.g. Raspberry Pi or AWS Graviton.
    AarchGeneric,
    /// 64-bit x86.
    X86_64,
    /// Anything not covered above.
    Other,
}
37
38impl CpuArch {
39    /// Pick the best label for the running target.
40    pub const fn current() -> Self {
41        #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
42        {
43            Self::AppleSilicon
44        }
45        #[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
46        {
47            Self::AarchGeneric
48        }
49        #[cfg(target_arch = "x86_64")]
50        {
51            Self::X86_64
52        }
53        #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
54        {
55            Self::Other
56        }
57    }
58}
59
/// Family of operations a [`KernelConfig`] lookup is keyed on.
///
/// Intentionally coarse for now; it will be split further once it is
/// clear which thresholds need per-shape tuning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpClass {
    /// Matmul and other linear-algebra hot paths.
    Matmul,
    /// Scaled dot-product attention (SDPA).
    Attention,
    /// Element-wise ops, activations and norms.
    Elementwise,
    /// View / shape ops such as reshape, narrow and transpose.
    Shape,
}
74
/// Tuning values the dispatch logic reads when choosing and sizing
/// kernels.
///
/// Units differ per field: `neon_seq_threshold` and
/// `fuse_attn_threshold` are compared against `batch * seq`, while
/// `par_grain` and `par_threshold` are counted in elements.
///
/// Rows are plain data, so they derive `PartialEq`/`Eq`/`Hash`. This lets
/// callers and tests compare configs or use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KernelConfig {
    /// Below this `batch * seq`, prefer the hand-written SIMD path over
    /// BLAS for matmul / SDPA. Named for NEON, but the x86_64 row
    /// (`X86_DEFAULT`) reuses it for the AVX2 path.
    pub neon_seq_threshold: usize,
    /// `par_for` granularity (in elements) for elementwise ops.
    pub par_grain: usize,
    /// Below this total element count, run sequentially (`par_for`
    /// has positive overhead even at zero work).
    pub par_threshold: usize,
    /// `FusedAttnBlock` fires when `batch * seq <= this`.
    pub fuse_attn_threshold: usize,
}
89
90const APPLE_SILICON: KernelConfig = KernelConfig {
91    neon_seq_threshold: 32,
92    par_grain: 64,
93    par_threshold: 30_000,
94    fuse_attn_threshold: 64,
95};
96
97const AARCH_GENERIC: KernelConfig = KernelConfig {
98    neon_seq_threshold: 24,
99    par_grain: 32,
100    par_threshold: 20_000,
101    fuse_attn_threshold: 48,
102};
103
104const X86_DEFAULT: KernelConfig = KernelConfig {
105    neon_seq_threshold: 16, // AVX2 path; lower threshold reflects bigger vector unit
106    par_grain: 32,
107    par_threshold: 20_000,
108    fuse_attn_threshold: 32,
109};
110
111const FALLBACK: KernelConfig = KernelConfig {
112    neon_seq_threshold: 16,
113    par_grain: 16,
114    par_threshold: 10_000,
115    fuse_attn_threshold: 16,
116};
117
118/// Look up the canonical kernel config for `(arch, op_class)`. The
119/// table is `const`-evaluated so callers pay no lookup cost.
120pub const fn kernel_config_for(arch: CpuArch, op: OpClass) -> KernelConfig {
121    // Today the per-op variation is small — we return one row per
122    // arch and let callers read the field they care about. As more
123    // shape-specific tuning data lands, the OpClass dimension
124    // becomes load-bearing.
125    let _ = op;
126    match arch {
127        CpuArch::AppleSilicon => APPLE_SILICON,
128        CpuArch::AarchGeneric => AARCH_GENERIC,
129        CpuArch::X86_64 => X86_DEFAULT,
130        CpuArch::Other => FALLBACK,
131    }
132}
133
134/// Convenience: defaults for the running target.
135pub const fn current_config(op: OpClass) -> KernelConfig {
136    kernel_config_for(CpuArch::current(), op)
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Every arch variant, so the tests sweep the whole table.
    const ALL_ARCHES: [CpuArch; 4] = [
        CpuArch::AppleSilicon,
        CpuArch::AarchGeneric,
        CpuArch::X86_64,
        CpuArch::Other,
    ];

    /// Every op class, so a bad per-op row is caught once `op` matters.
    const ALL_OPS: [OpClass; 4] = [
        OpClass::Matmul,
        OpClass::Attention,
        OpClass::Elementwise,
        OpClass::Shape,
    ];

    // Compile-time guards: these fail to build if either lookup stops
    // being usable in const context.
    const _: KernelConfig = kernel_config_for(CpuArch::Other, OpClass::Shape);
    const _: KernelConfig = current_config(OpClass::Matmul);

    #[test]
    fn current_resolves() {
        for &op in ALL_OPS.iter() {
            let cfg = current_config(op);
            // All targets at minimum produce non-zero thresholds.
            assert!(cfg.neon_seq_threshold > 0, "{:?}: {:?}", op, cfg);
            assert!(cfg.par_threshold > 0, "{:?}: {:?}", op, cfg);
        }
    }

    #[test]
    fn every_row_is_nonzero() {
        for &arch in ALL_ARCHES.iter() {
            for &op in ALL_OPS.iter() {
                let cfg = kernel_config_for(arch, op);
                // A zero in any field is almost certainly a table typo.
                assert!(cfg.neon_seq_threshold > 0, "{:?}/{:?}: {:?}", arch, op, cfg);
                assert!(cfg.par_grain > 0, "{:?}/{:?}: {:?}", arch, op, cfg);
                assert!(cfg.par_threshold > 0, "{:?}/{:?}: {:?}", arch, op, cfg);
                assert!(cfg.fuse_attn_threshold > 0, "{:?}/{:?}: {:?}", arch, op, cfg);
            }
        }
    }

    #[test]
    fn apple_silicon_picks_higher_thresholds() {
        let m = kernel_config_for(CpuArch::AppleSilicon, OpClass::Matmul);
        let f = kernel_config_for(CpuArch::Other, OpClass::Matmul);
        assert!(m.neon_seq_threshold >= f.neon_seq_threshold);
        assert!(m.par_threshold >= f.par_threshold);
        assert!(m.fuse_attn_threshold >= f.fuse_attn_threshold);
    }
}
158}