rlx_runtime/custom_ops.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Custom-op extensibility (plan #25).
17//!
18//! Borrowed from MAX's `extensibility/compiler_internal/` /
19//! `extensibility/tensor/` pattern: downstream users register their
20//! own ops + executors without forking the framework.
21//!
22//! Today this is the data layer — a registry mapping a string op
23//! name to an executor closure. The matching `#[rlx_op]` proc macro
24//! is the syntactic sugar layer; adding it is straightforward when a
25//! real consumer needs less boilerplate.
26//!
//! The IR doesn't model custom ops natively (there's no
//! `Op::Custom("name")` variant) — they enter through the runtime
//! via this module's [`execute`] function rather than as graph nodes.
30//! That's deliberate: the optimizer's fusion patterns can't reason
31//! about ops it doesn't know, so custom ops should be opaque
32//! "black box" sub-stages rather than first-class IR citizens.
33//! Promote them to real ops once a fusion pattern would benefit.
34
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
37
/// Boxed executor: borrows a slice of read-only `f32` input buffers and
/// produces one owned `f32` output buffer.
///
/// Plain `f32` buffers for now, chosen to match the `rlx_runtime`
/// `CompiledGraph::run` signature; revisit when the runtime moves to
/// `Buffer` (plan #59) end-to-end.
pub type CustomOpFn = Box<dyn Fn(&[&[f32]]) -> Vec<f32> + Send + Sync>;

/// Reference-counted handle to a registered executor.
///
/// Executors are stored behind an `Arc` so that `execute` can clone the
/// handle out of the map and release the registry lock *before* running
/// user code.
type SharedOpFn = Arc<dyn Fn(&[&[f32]]) -> Vec<f32> + Send + Sync>;

/// Process-wide table mapping op names to their executors.
struct Registry {
    map: Mutex<HashMap<String, SharedOpFn>>,
}

/// Returns the lazily-initialised global registry.
fn registry() -> &'static Registry {
    static R: OnceLock<Registry> = OnceLock::new();
    R.get_or_init(|| Registry {
        map: Mutex::new(HashMap::new()),
    })
}

/// Register a custom op under `name`. Idempotent — re-registering
/// replaces. Names are arbitrary strings; convention: dotted
/// namespacing like `"my-crate.my-op"`.
///
/// An `execute` call that already looked up the previous executor keeps
/// running it to completion; later calls see the replacement.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned.
pub fn register<F>(name: impl Into<String>, f: F)
where
    F: Fn(&[&[f32]]) -> Vec<f32> + Send + Sync + 'static,
{
    // Do all caller-supplied work (name conversion, boxing) before locking.
    let key = name.into();
    let op: SharedOpFn = Arc::new(f);

    let previous = {
        let mut ops = registry()
            .map
            .lock()
            .expect("custom-op registry poisoned");
        ops.insert(key, op)
    };
    // Drop the replaced executor only after the lock is released, so a
    // destructor in its captured state can never run under the lock.
    drop(previous);
}

/// Execute a previously-registered op. Returns `None` if the op
/// isn't registered.
///
/// The registry lock is held only for the lookup, not while the op runs.
/// Ops may therefore call `execute` / `register` themselves, and a panic
/// inside an op does not poison the registry.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned, and propagates any
/// panic raised by the op itself.
pub fn execute(name: &str, inputs: &[&[f32]]) -> Option<Vec<f32>> {
    let op: Option<SharedOpFn> = {
        let ops = registry()
            .map
            .lock()
            .expect("custom-op registry poisoned");
        ops.get(name).cloned()
    };
    // Lock released here; user code runs without it.
    op.map(|f| f(inputs))
}
74
75/// Snapshot of registered op names (sorted, deterministic).
76pub fn registered() -> Vec<String> {
77 let r = registry();
78 let m = r.map.lock().expect("custom-op registry poisoned");
79 let mut v: Vec<String> = m.keys().cloned().collect();
80 v.sort();
81 v
82}
83
84#[doc(hidden)]
85pub fn clear_for_tests() {
86 let r = registry();
87 r.map.lock().unwrap().clear();
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes the tests in this module.
    ///
    /// They all mutate one process-global registry and start by clearing
    /// it; the default test harness runs tests on parallel threads, so
    /// without this lock one test's clear could wipe another test's
    /// registration between its `register` and `execute` calls.
    static SERIAL: Mutex<()> = Mutex::new(());

    /// Takes the serialization lock and resets the registry. Hold the
    /// returned guard for the whole test body.
    fn fresh_registry() -> MutexGuard<'static, ()> {
        // A failed assertion in another test poisons `SERIAL`; the guarded
        // value is `()`, so recovering is always safe and keeps one failure
        // from cascading into the rest of the module.
        let guard = SERIAL.lock().unwrap_or_else(PoisonError::into_inner);
        clear_for_tests();
        guard
    }

    #[test]
    fn register_then_execute() {
        let _serial = fresh_registry();
        register("test.identity", |ins| ins[0].to_vec());
        let out = execute("test.identity", &[&[1.0, 2.0, 3.0]]).unwrap();
        assert_eq!(out, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn unknown_op_returns_none() {
        let _serial = fresh_registry();
        assert!(execute("nope", &[]).is_none());
    }

    #[test]
    fn re_register_replaces() {
        let _serial = fresh_registry();
        register("test.f", |_| vec![1.0]);
        register("test.f", |_| vec![2.0]);
        assert_eq!(execute("test.f", &[]).unwrap(), vec![2.0]);
    }
}