Skip to main content

rlx_cpu/
op_registry.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Per-backend (CPU) kernel registry for `Op::Custom`.
17//!
18//! Companion to [`rlx_ir::op_registry`]. The IR-level [`rlx_ir::OpExtension`]
19//! covers shape inference + autodiff; this registry covers
20//! *execution* on the CPU backend. Splitting them keeps `rlx-ir`
21//! portable and lets a custom op honestly support a subset of
22//! backends — attempting to compile an `Op::Custom` for a backend
23//! whose kernel isn't registered is a hard error, not a silent no-op.
24//!
25//! ## API contract for downstream kernel authors
26//!
27//! - **One method, typed views in.** Each input arrives as a
28//!   [`CpuTensorRef`] variant matching that input's declared dtype.
29//!   The output is a [`CpuTensorMut`] matching the output dtype. No
30//!   byte reinterpretation in user code.
31//! - **Mixed-dtype inputs work directly.** A Sparse-LU op with
32//!   `(F64 values, I32 col_idx, I32 row_ptr, F64 b)` gets each input
33//!   as the right typed slice — no manual byte casts.
34//! - **Contiguous, dense buffers from the arena.** Strided / broadcast
35//!   inputs need to be materialized by the caller before reaching the
36//!   kernel; the IR's `Op::Expand` / `Op::Transpose` already cover
37//!   that.
38//! - **`attrs` is opaque** — same `Vec<u8>` as the IR variant. Decode
39//!   it however the kernel likes (typical: `bincode`, `bytemuck`, a
40//!   hand-rolled struct cast).
41//!
42//! ## Multiple logical outputs
43//!
44//! `Op::Custom` produces a single tensor by design — the IR's `Node`
45//! is one-shape-per-node. Ops that conceptually return multiple
46//! outputs (LU returning L+U, eigendecomp returning λ+V) write a
47//! *packed* output and the user follows the custom op with `Narrow`
48//! to extract each logical output. Use
49//! [`rlx_ir::Graph::custom_op_packed`] when registry-driven shape
50//! inference isn't sufficient.
51
52use std::collections::HashMap;
53use std::sync::{Arc, OnceLock, RwLock};
54
55use rlx_ir::{DType, Shape};
56
57// Why an enum, not generics? `CpuKernel` takes inputs of *mixed*
58// dtypes (e.g. Sparse-LU has `(F64 values, I32 col_idx, I32 row_ptr,
59// F64 b)`). A generic `CpuKernel<T>` couldn't express that — every
60// input would have to share the same `T`. The enum-of-typed-views
61// is the right shape for this contract; generics over `T: Pod`
62// would only buy us the per-input case, which is what `as_*` /
63// `expect_*` accessors already provide.
64//
65// One variant per `rlx_ir::DType`. The dispatcher in
66// `thunk.rs::dispatch_custom_op` enumerates all of them — adding a
67// dtype to `DType` requires adding a variant here and an arm
68// there. Single source of truth for "what's wired."
69
/// Generates the typed tensor-view enums and their accessors from a
/// table of dtype rows.
///
/// Each row has the form
/// `Variant => ElemTy, as_x, as_x_mut, expect_x, expect_x_mut,` and
/// contributes:
///
/// - one `Variant { data, shape }` arm to both [`CpuTensorRef`] and
///   [`CpuTensorMut`], with `data` typed as a slice of `ElemTy`;
/// - `as_x` / `expect_x` on [`CpuTensorRef`] (borrowing accessors);
/// - `as_x_mut` / `expect_x_mut` on [`CpuTensorMut`] (consuming
///   accessors, so the returned slice keeps the full `'a` lifetime).
///
/// `Variant` must also be the name of a `DType` variant: the generated
/// `expect_*` error messages refer to `DType::Variant`.
macro_rules! dtype_variants {
    (
        $(
            $tag:ident => $elem:ty,
            $as_ref:ident, $as_mut:ident,
            $expect_ref:ident, $expect_mut:ident,
        )*
    ) => {
        /// Read-only typed view of one input tensor handed to a [`CpuKernel`].
        /// The variant matches the input's declared dtype on the IR side.
        pub enum CpuTensorRef<'a> {
            $(
                $tag { data: &'a [$elem], shape: &'a Shape },
            )*
        }

        /// Mutable typed view of the output tensor handed to a [`CpuKernel`].
        pub enum CpuTensorMut<'a> {
            $(
                $tag { data: &'a mut [$elem], shape: &'a Shape },
            )*
        }

        impl<'a> CpuTensorRef<'a> {
            /// Shape carried by this view, whichever variant it is.
            pub fn shape(&self) -> &Shape {
                match *self {
                    $( Self::$tag { shape, .. } => shape, )*
                }
            }

            /// Dtype as reported by the view's shape.
            pub fn dtype(&self) -> DType {
                self.shape().dtype()
            }

            $(
                /// Returns the data slice if this view is the matching
                /// variant, `None` otherwise.
                pub fn $as_ref(&self) -> Option<&[$elem]> {
                    match *self {
                        Self::$tag { data, .. } => Some(data),
                        _ => None,
                    }
                }

                /// Like the `as_*` accessor, but a variant mismatch becomes
                /// an `Err` whose message is prefixed with `role` and names
                /// the expected and actual dtypes.
                pub fn $expect_ref(&self, role: &str) -> Result<&[$elem], String> {
                    match self.$as_ref() {
                        Some(slice) => Ok(slice),
                        None => Err(format!(
                            "{role}: expected {:?}, got {:?}",
                            DType::$tag, self.dtype())),
                    }
                }
            )*
        }

        impl<'a> CpuTensorMut<'a> {
            /// Shape carried by this view, whichever variant it is.
            pub fn shape(&self) -> &Shape {
                match *self {
                    $( Self::$tag { shape, .. } => shape, )*
                }
            }

            /// Dtype as reported by the view's shape.
            pub fn dtype(&self) -> DType {
                self.shape().dtype()
            }

            $(
                /// Consumes the view and returns the mutable data slice if
                /// it is the matching variant, `None` otherwise.
                pub fn $as_mut(self) -> Option<&'a mut [$elem]> {
                    match self {
                        Self::$tag { data, .. } => Some(data),
                        _ => None,
                    }
                }

                /// Like the `as_*_mut` accessor, but a variant mismatch
                /// becomes an `Err` whose message is prefixed with `role`
                /// and names the expected and actual dtypes.
                pub fn $expect_mut(self, role: &str) -> Result<&'a mut [$elem], String> {
                    // Read the dtype first: the accessor below consumes `self`.
                    let dt = self.dtype();
                    match self.$as_mut() {
                        Some(slice) => Ok(slice),
                        None => Err(format!(
                            "{role}: expected {:?}, got {dt:?}", DType::$tag)),
                    }
                }
            )*
        }
    };
}
134
135// One row per DType. Bool is stored as `u8` on the wire (one byte
136// per element, 0 = false / non-zero = true) — exposing it as a bool
137// slice directly would be UB if any byte pattern other than 0/1
138// landed there, which the IR doesn't guarantee.
139dtype_variants! {
140    F32  => f32,        as_f32,  as_f32_mut,  expect_f32,  expect_f32_mut,
141    F64  => f64,        as_f64,  as_f64_mut,  expect_f64,  expect_f64_mut,
142    F16  => half::f16,  as_f16,  as_f16_mut,  expect_f16,  expect_f16_mut,
143    BF16 => half::bf16, as_bf16, as_bf16_mut, expect_bf16, expect_bf16_mut,
144    I8   => i8,         as_i8,   as_i8_mut,   expect_i8,   expect_i8_mut,
145    I16  => i16,        as_i16,  as_i16_mut,  expect_i16,  expect_i16_mut,
146    I32  => i32,        as_i32,  as_i32_mut,  expect_i32,  expect_i32_mut,
147    I64  => i64,        as_i64,  as_i64_mut,  expect_i64,  expect_i64_mut,
148    U8   => u8,         as_u8,   as_u8_mut,   expect_u8,   expect_u8_mut,
149    U32  => u32,        as_u32,  as_u32_mut,  expect_u32,  expect_u32_mut,
150    Bool => u8,         as_bool, as_bool_mut, expect_bool, expect_bool_mut,
151}
152
153/// Trait a CPU kernel implements for one custom op. Registered under
154/// the same `name` used in `Op::Custom` and `OpExtension::name`.
155///
156/// One method, typed views in. Match on the variants you support and
157/// return `Err(...)` for anything else — the executor surfaces that
158/// as a panic naming the op + dtype, so missing support fails loudly
159/// instead of silently zeroing the output.
160pub trait CpuKernel: Send + Sync {
161    fn name(&self) -> &str;
162
163    fn execute(
164        &self,
165        inputs: &[CpuTensorRef<'_>],
166        output: CpuTensorMut<'_>,
167        attrs: &[u8],
168    ) -> Result<(), String>;
169}
170
171pub struct CpuKernelRegistry {
172    kernels: RwLock<HashMap<String, Arc<dyn CpuKernel>>>,
173}
174
175impl CpuKernelRegistry {
176    pub fn new() -> Self {
177        Self {
178            kernels: RwLock::new(HashMap::new()),
179        }
180    }
181
182    /// Register a kernel. Re-registration replaces the previous entry
183    /// and prints a one-line warning to stderr — silent overwrite has
184    /// bitten us before, the warning is cheap.
185    pub fn register(&self, k: Arc<dyn CpuKernel>) {
186        let name = k.name().to_string();
187        let mut g = self.kernels.write().unwrap();
188        if g.contains_key(&name) {
189            eprintln!(
190                "rlx-cpu: CpuKernel '{name}' was already registered — \
191                 replacing the previous entry"
192            );
193        }
194        g.insert(name, k);
195    }
196
197    pub fn lookup(&self, name: &str) -> Option<Arc<dyn CpuKernel>> {
198        self.kernels.read().unwrap().get(name).cloned()
199    }
200}
201
202impl Default for CpuKernelRegistry {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208pub fn global_cpu_kernels() -> &'static CpuKernelRegistry {
209    static R: OnceLock<CpuKernelRegistry> = OnceLock::new();
210    R.get_or_init(CpuKernelRegistry::new)
211}
212
213pub fn register_cpu_kernel(k: Arc<dyn CpuKernel>) {
214    global_cpu_kernels().register(k);
215}
216
217pub fn lookup_cpu_kernel(name: &str) -> Option<Arc<dyn CpuKernel>> {
218    global_cpu_kernels().lookup(name)
219}