rlx_cpu/op_registry.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Per-backend (CPU) kernel registry for `Op::Custom`.
17//!
18//! Companion to [`rlx_ir::op_registry`]. The IR-level [`rlx_ir::OpExtension`]
19//! covers shape inference + autodiff; this registry covers
20//! *execution* on the CPU backend. Splitting them keeps `rlx-ir`
21//! portable and lets a custom op honestly support a subset of
22//! backends — attempting to compile an `Op::Custom` for a backend
23//! whose kernel isn't registered is a hard error, not a silent no-op.
24//!
25//! ## API contract for downstream kernel authors
26//!
27//! - **One method, typed views in.** Each input arrives as a
28//! [`CpuTensorRef`] variant matching that input's declared dtype.
29//! The output is a [`CpuTensorMut`] matching the output dtype. No
30//! byte reinterpretation in user code.
31//! - **Mixed-dtype inputs work directly.** A Sparse-LU op with
32//! `(F64 values, I32 col_idx, I32 row_ptr, F64 b)` gets each input
33//! as the right typed slice — no manual byte casts.
34//! - **Contiguous, dense buffers from the arena.** Strided / broadcast
35//! inputs need to be materialized by the caller before reaching the
36//! kernel; the IR's `Op::Expand` / `Op::Transpose` already cover
37//! that.
38//! - **`attrs` is opaque** — same `Vec<u8>` as the IR variant. Decode
39//! it however the kernel likes (typical: `bincode`, `bytemuck`, a
40//! hand-rolled struct cast).
41//!
42//! ## Multiple logical outputs
43//!
44//! `Op::Custom` produces a single tensor by design — the IR's `Node`
45//! is one-shape-per-node. Ops that conceptually return multiple
46//! outputs (LU returning L+U, eigendecomp returning λ+V) write a
47//! *packed* output and the user follows the custom op with `Narrow`
48//! to extract each logical output. Use
49//! [`rlx_ir::Graph::custom_op_packed`] when registry-driven shape
50//! inference isn't sufficient.
51
52use std::collections::HashMap;
53use std::sync::{Arc, OnceLock, RwLock};
54
55use rlx_ir::{DType, Shape};
56
57// Why an enum, not generics? `CpuKernel` takes inputs of *mixed*
58// dtypes (e.g. Sparse-LU has `(F64 values, I32 col_idx, I32 row_ptr,
59// F64 b)`). A generic `CpuKernel<T>` couldn't express that — every
60// input would have to share the same `T`. The enum-of-typed-views
61// is the right shape for this contract; generics over `T: Pod`
62// would only buy us the per-input case, which is what `as_*` /
63// `expect_*` accessors already provide.
64//
65// One variant per `rlx_ir::DType`. The dispatcher in
66// `thunk.rs::dispatch_custom_op` enumerates all of them — adding a
67// dtype to `DType` requires adding a variant here and an arm
68// there. Single source of truth for "what's wired."
69
macro_rules! dtype_variants {
    (
        $(
            $variant:ident => $rust_ty:ty,
            $as_method:ident, $as_mut_method:ident,
            $expect_method:ident, $expect_mut_method:ident,
        )*
    ) => {
        /// Read-only typed view of one input tensor handed to a [`CpuKernel`].
        /// The variant matches the input's declared dtype on the IR side.
        ///
        /// The view only holds shared borrows, so it is `Copy`: a kernel can
        /// write `let x = inputs[0];` without keeping `inputs` borrowed.
        #[derive(Clone, Copy)]
        pub enum CpuTensorRef<'a> {
            $(
                $variant { data: &'a [$rust_ty], shape: &'a Shape },
            )*
        }

        /// Mutable typed view of the output tensor handed to a [`CpuKernel`].
        pub enum CpuTensorMut<'a> {
            $(
                $variant { data: &'a mut [$rust_ty], shape: &'a Shape },
            )*
        }

        impl<'a> CpuTensorRef<'a> {
            /// Shape of this input.
            ///
            /// The borrow lives for `'a` (the underlying tensor), not for the
            /// borrow of `self`.
            pub fn shape(&self) -> &'a Shape {
                match self {
                    $( Self::$variant { shape, .. } => *shape, )*
                }
            }

            /// Dtype recorded in this input's shape.
            pub fn dtype(&self) -> DType {
                self.shape().dtype()
            }

            $(
                /// Typed data slice if this view is the matching variant,
                /// `None` otherwise. The slice lives for `'a`.
                pub fn $as_method(&self) -> Option<&'a [$rust_ty]> {
                    if let Self::$variant { data, .. } = self { Some(*data) } else { None }
                }

                /// Like the matching `as_*` accessor, but a mismatch becomes an
                /// error message prefixed with `role`, ready to return from
                /// `CpuKernel::execute`.
                pub fn $expect_method(&self, role: &str) -> Result<&'a [$rust_ty], String> {
                    self.$as_method().ok_or_else(|| format!(
                        "{role}: expected {:?}, got {:?}",
                        DType::$variant, self.dtype()))
                }
            )*
        }

        impl<'a> CpuTensorMut<'a> {
            /// Shape of the output.
            ///
            /// The borrow lives for `'a`, so it can be held across the consuming
            /// `as_*_mut` / `expect_*_mut` call that yields the data slice.
            pub fn shape(&self) -> &'a Shape {
                match self {
                    $( Self::$variant { shape, .. } => *shape, )*
                }
            }

            /// Dtype recorded in the output's shape.
            pub fn dtype(&self) -> DType {
                self.shape().dtype()
            }

            $(
                /// Consume the view and return its typed data slice if it is the
                /// matching variant, `None` otherwise.
                pub fn $as_mut_method(self) -> Option<&'a mut [$rust_ty]> {
                    if let Self::$variant { data, .. } = self { Some(data) } else { None }
                }

                /// Like the matching `as_*_mut` accessor, but a mismatch becomes
                /// an error message prefixed with `role`.
                pub fn $expect_mut_method(self, role: &str) -> Result<&'a mut [$rust_ty], String> {
                    // Read the dtype before `self` is consumed by the accessor.
                    let dt = self.dtype();
                    self.$as_mut_method().ok_or_else(|| format!(
                        "{role}: expected {:?}, got {dt:?}", DType::$variant))
                }
            )*
        }
    };
}
134
135// One row per DType. Bool is stored as `u8` on the wire (one byte
136// per element, 0 = false / non-zero = true) — exposing it as a bool
137// slice directly would be UB if any byte pattern other than 0/1
138// landed there, which the IR doesn't guarantee.
139dtype_variants! {
140 F32 => f32, as_f32, as_f32_mut, expect_f32, expect_f32_mut,
141 F64 => f64, as_f64, as_f64_mut, expect_f64, expect_f64_mut,
142 F16 => half::f16, as_f16, as_f16_mut, expect_f16, expect_f16_mut,
143 BF16 => half::bf16, as_bf16, as_bf16_mut, expect_bf16, expect_bf16_mut,
144 I8 => i8, as_i8, as_i8_mut, expect_i8, expect_i8_mut,
145 I16 => i16, as_i16, as_i16_mut, expect_i16, expect_i16_mut,
146 I32 => i32, as_i32, as_i32_mut, expect_i32, expect_i32_mut,
147 I64 => i64, as_i64, as_i64_mut, expect_i64, expect_i64_mut,
148 U8 => u8, as_u8, as_u8_mut, expect_u8, expect_u8_mut,
149 U32 => u32, as_u32, as_u32_mut, expect_u32, expect_u32_mut,
150 Bool => u8, as_bool, as_bool_mut, expect_bool, expect_bool_mut,
151}
152
153/// Trait a CPU kernel implements for one custom op. Registered under
154/// the same `name` used in `Op::Custom` and `OpExtension::name`.
155///
156/// One method, typed views in. Match on the variants you support and
157/// return `Err(...)` for anything else — the executor surfaces that
158/// as a panic naming the op + dtype, so missing support fails loudly
159/// instead of silently zeroing the output.
160pub trait CpuKernel: Send + Sync {
161 fn name(&self) -> &str;
162
163 fn execute(
164 &self,
165 inputs: &[CpuTensorRef<'_>],
166 output: CpuTensorMut<'_>,
167 attrs: &[u8],
168 ) -> Result<(), String>;
169}
170
171pub struct CpuKernelRegistry {
172 kernels: RwLock<HashMap<String, Arc<dyn CpuKernel>>>,
173}
174
175impl CpuKernelRegistry {
176 pub fn new() -> Self {
177 Self {
178 kernels: RwLock::new(HashMap::new()),
179 }
180 }
181
182 /// Register a kernel. Re-registration replaces the previous entry
183 /// and prints a one-line warning to stderr — silent overwrite has
184 /// bitten us before, the warning is cheap.
185 pub fn register(&self, k: Arc<dyn CpuKernel>) {
186 let name = k.name().to_string();
187 let mut g = self.kernels.write().unwrap();
188 if g.contains_key(&name) {
189 eprintln!(
190 "rlx-cpu: CpuKernel '{name}' was already registered — \
191 replacing the previous entry"
192 );
193 }
194 g.insert(name, k);
195 }
196
197 pub fn lookup(&self, name: &str) -> Option<Arc<dyn CpuKernel>> {
198 self.kernels.read().unwrap().get(name).cloned()
199 }
200}
201
202impl Default for CpuKernelRegistry {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208pub fn global_cpu_kernels() -> &'static CpuKernelRegistry {
209 static R: OnceLock<CpuKernelRegistry> = OnceLock::new();
210 R.get_or_init(CpuKernelRegistry::new)
211}
212
213pub fn register_cpu_kernel(k: Arc<dyn CpuKernel>) {
214 global_cpu_kernels().register(k);
215}
216
217pub fn lookup_cpu_kernel(name: &str) -> Option<Arc<dyn CpuKernel>> {
218 global_cpu_kernels().lookup(name)
219}