rlx_cpu/kernel_config.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Compile-time kernel-config tables (plan #14).
17//!
18//! Borrowed from MAX's `internal_utils/nvidia_configs.mojo` /
19//! `amd_configs.mojo` pattern: tile sizes, kernel-selection
20//! thresholds, etc. as compile-time data structures kernels query
21//! instead of scattered match-arms.
22//!
23//! Today the values live as `const`s here and are surfaced through
24//! [`kernel_config_for`]. The goal is one source of truth — when
25//! we want to tune for a new arch (M5 Apple Silicon, x86 Zen5,
26//! etc.) we add a row to the table, not a new match arm in 12
27//! files.
28
/// Coarse classification of the CPU this crate was compiled for.
///
/// Deliberately coarse: new variants are added as distinct SoC
/// families start to need their own tuning rows.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CpuArch {
    /// Apple M-series (aarch64 + macOS); AMX + NEON.
    AppleSilicon,
    /// Any other 64-bit ARM target (e.g. Raspberry Pi, AWS Graviton).
    AarchGeneric,
    /// 64-bit x86.
    X86_64,
    /// Anything not covered above.
    Other,
}

impl CpuArch {
    /// Classify the compilation target.
    ///
    /// The answer comes purely from `cfg` flags, so it is fixed at
    /// compile time and the function is usable in `const` contexts:
    ///
    /// | `target_arch` | `target_os` | result         |
    /// |---------------|-------------|----------------|
    /// | `aarch64`     | `macos`     | `AppleSilicon` |
    /// | `aarch64`     | other       | `AarchGeneric` |
    /// | `x86_64`      | any         | `X86_64`       |
    /// | other         | any         | `Other`        |
    pub const fn current() -> Self {
        if cfg!(target_arch = "x86_64") {
            Self::X86_64
        } else if !cfg!(target_arch = "aarch64") {
            Self::Other
        } else if cfg!(target_os = "macos") {
            Self::AppleSilicon
        } else {
            Self::AarchGeneric
        }
    }
}
59
/// Category of operation a kernel config is looked up for.
///
/// Intentionally coarse for now; it will be split further once it is
/// clear which thresholds actually need per-shape tuning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpClass {
    /// Matmul and the rest of the linear-algebra hot path.
    Matmul,
    /// SDPA attention.
    Attention,
    /// Element-wise ops, activations and norms.
    Elementwise,
    /// View / shape ops such as reshape, narrow and transpose.
    Shape,
}
74
/// Tuning knobs the dispatch logic reads for one `(arch, op)` pair.
///
/// Counts are in elements unless a field says otherwise; the two
/// matmul/attention thresholds are in `batch * seq`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so rows can be compared, asserted on
/// in tests, and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KernelConfig {
    /// Below this `batch * seq`, prefer the hand-written SIMD path over
    /// BLAS for matmul / SDPA. Named after NEON; the x86_64 row reuses
    /// it for its AVX2 path.
    pub neon_seq_threshold: usize,
    /// `par_for` granularity for elementwise ops.
    pub par_grain: usize,
    /// Below this total element count, run sequentially (`par_for` has
    /// positive overhead even at 0 work).
    pub par_threshold: usize,
    /// `FusedAttnBlock` fires when `batch * seq <= this`.
    pub fuse_attn_threshold: usize,
}

/// Row for `CpuArch::AppleSilicon`.
const APPLE_SILICON: KernelConfig = KernelConfig {
    neon_seq_threshold: 32,
    par_grain: 64,
    par_threshold: 30_000,
    fuse_attn_threshold: 64,
};

/// Row for `CpuArch::AarchGeneric`.
const AARCH_GENERIC: KernelConfig = KernelConfig {
    neon_seq_threshold: 24,
    par_grain: 32,
    par_threshold: 20_000,
    fuse_attn_threshold: 48,
};

/// Row for `CpuArch::X86_64`.
const X86_DEFAULT: KernelConfig = KernelConfig {
    neon_seq_threshold: 16, // AVX2 path; lower threshold reflects bigger vector unit
    par_grain: 32,
    par_threshold: 20_000,
    fuse_attn_threshold: 32,
};

/// Row for `CpuArch::Other` (conservative defaults).
const FALLBACK: KernelConfig = KernelConfig {
    neon_seq_threshold: 16,
    par_grain: 16,
    par_threshold: 10_000,
    fuse_attn_threshold: 16,
};
117
118/// Look up the canonical kernel config for `(arch, op_class)`. The
119/// table is `const`-evaluated so callers pay no lookup cost.
120pub const fn kernel_config_for(arch: CpuArch, op: OpClass) -> KernelConfig {
121 // Today the per-op variation is small — we return one row per
122 // arch and let callers read the field they care about. As more
123 // shape-specific tuning data lands, the OpClass dimension
124 // becomes load-bearing.
125 let _ = op;
126 match arch {
127 CpuArch::AppleSilicon => APPLE_SILICON,
128 CpuArch::AarchGeneric => AARCH_GENERIC,
129 CpuArch::X86_64 => X86_DEFAULT,
130 CpuArch::Other => FALLBACK,
131 }
132}
133
134/// Convenience: defaults for the running target.
135pub const fn current_config(op: OpClass) -> KernelConfig {
136 kernel_config_for(CpuArch::current(), op)
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `CpuArch` variant, so table-wide checks cover all rows.
    const ALL_ARCHES: [CpuArch; 4] = [
        CpuArch::AppleSilicon,
        CpuArch::AarchGeneric,
        CpuArch::X86_64,
        CpuArch::Other,
    ];

    /// Every `OpClass` variant.
    const ALL_OPS: [OpClass; 4] = [
        OpClass::Matmul,
        OpClass::Attention,
        OpClass::Elementwise,
        OpClass::Shape,
    ];

    #[test]
    fn current_resolves() {
        let cfg = current_config(OpClass::Matmul);
        // All targets at minimum produce a non-zero threshold.
        assert!(cfg.neon_seq_threshold > 0);
        assert!(cfg.par_threshold > 0);
    }

    /// `current_resolves` only checks the running target and two fields.
    /// This walks every `(arch, op)` row and every field, so a zero in a
    /// row for a target other than the one running the tests still fails.
    /// (A zero `par_grain` would presumably be an invalid chunk size for
    /// `par_for` — confirm against its implementation.)
    #[test]
    fn every_row_has_nonzero_fields() {
        for &arch in &ALL_ARCHES {
            for &op in &ALL_OPS {
                let cfg = kernel_config_for(arch, op);
                let ctx = format!("{:?} / {:?}", arch, op);
                assert!(cfg.neon_seq_threshold > 0, "neon_seq_threshold == 0 for {}", ctx);
                assert!(cfg.par_grain > 0, "par_grain == 0 for {}", ctx);
                assert!(cfg.par_threshold > 0, "par_threshold == 0 for {}", ctx);
                assert!(cfg.fuse_attn_threshold > 0, "fuse_attn_threshold == 0 for {}", ctx);
            }
        }
    }

    #[test]
    fn apple_silicon_picks_higher_thresholds() {
        let m = kernel_config_for(CpuArch::AppleSilicon, OpClass::Matmul);
        let f = kernel_config_for(CpuArch::Other, OpClass::Matmul);
        assert!(m.neon_seq_threshold >= f.neon_seq_threshold);
        assert!(m.par_threshold >= f.par_threshold);
    }
}