use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
/// Boxed form of a custom-op implementation: a function from a slice of input
/// buffers to a single output buffer.
pub type CustomOpFn = Box<dyn Fn(&[&[f32]]) -> Vec<f32> + Send + Sync>;

/// Shared handle to a registered op, as stored in the registry.
///
/// Reference-counted so [`execute`] can clone the handle out and release the
/// registry lock before running the op.
type SharedOp = Arc<dyn Fn(&[&[f32]]) -> Vec<f32> + Send + Sync>;

/// Process-wide table mapping op names to their implementations.
struct Registry {
    map: Mutex<HashMap<String, SharedOp>>,
}

/// Returns the process-global registry, creating it empty on first use.
fn registry() -> &'static Registry {
    static R: OnceLock<Registry> = OnceLock::new();
    R.get_or_init(|| Registry {
        map: Mutex::new(HashMap::new()),
    })
}

/// Registers `f` as the implementation of the custom op `name`.
///
/// Registering a name that already exists replaces the previous implementation.
/// Calls to [`execute`] that already fetched the previous implementation finish
/// running it; later calls see `f`.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned.
pub fn register<F>(name: impl Into<String>, f: F)
where
    F: Fn(&[&[f32]]) -> Vec<f32> + Send + Sync + 'static,
{
    // Convert before locking so caller-supplied code (`Into`) never runs while
    // the registry is locked.
    let name = name.into();
    let op: SharedOp = Arc::new(f);
    let previous = registry()
        .map
        .lock()
        .expect("custom-op registry poisoned")
        .insert(name, op);
    // The guard was a temporary of the statement above, so the lock is already
    // released: dropping a replaced op (and whatever it captured) cannot
    // re-enter the registry while it is held.
    drop(previous);
}

/// Runs the custom op registered under `name` on `inputs`.
///
/// Returns `None` if no op with that name is registered, otherwise the op's
/// output.
///
/// The op is invoked *without* the registry lock held. This means ops may call
/// [`execute`] / [`register`] themselves, different ops may run concurrently on
/// different threads, and a panicking op does not poison the registry.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned, and propagates any panic
/// raised by the op itself.
pub fn execute(name: &str, inputs: &[&[f32]]) -> Option<Vec<f32>> {
    let op = {
        let map = registry().map.lock().expect("custom-op registry poisoned");
        Arc::clone(map.get(name)?)
    };
    // `map` (the guard) was dropped at the end of the block above.
    Some(op(inputs))
}
/// Lists the names of all registered custom ops in ascending order.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned.
pub fn registered() -> Vec<String> {
    // Snapshot the names under the lock, then sort once the guard is gone.
    let mut names = {
        let guard = registry().map.lock().expect("custom-op registry poisoned");
        guard.keys().cloned().collect::<Vec<String>>()
    };
    // Map keys are unique, so an unstable sort yields the same order as a stable one.
    names.sort_unstable();
    names
}
/// Removes every registered custom op. Intended for tests only.
///
/// The ops are moved out of the map and dropped after the lock is released, so
/// any cleanup in their captured state never runs while the registry is locked.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned.
#[doc(hidden)]
pub fn clear_for_tests() {
    let removed = {
        let mut map = registry().map.lock().expect("custom-op registry poisoned");
        std::mem::take(&mut *map)
    };
    // The guard is released at the end of the block above; drop the ops now.
    drop(removed);
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{MutexGuard, PoisonError};

    /// Serializes the tests in this module.
    ///
    /// The registry is process-global and the default test harness runs tests on
    /// several threads. Without this lock, one test's `clear_for_tests()` could
    /// wipe another test's registrations between its `register` and `execute`
    /// calls. A poisoned lock (left by a failed test) is recovered so one failure
    /// does not cascade into every other test.
    fn serial() -> MutexGuard<'static, ()> {
        static LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
        LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn register_then_execute() {
        // Bind to a named variable: `let _ = ...` would release the lock immediately.
        let _guard = serial();
        clear_for_tests();
        register("test.identity", |ins| ins[0].to_vec());
        let out = execute("test.identity", &[&[1.0, 2.0, 3.0]]).unwrap();
        assert_eq!(out, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn unknown_op_returns_none() {
        let _guard = serial();
        clear_for_tests();
        assert!(execute("nope", &[]).is_none());
    }

    #[test]
    fn re_register_replaces() {
        let _guard = serial();
        clear_for_tests();
        register("test.f", |_| vec![1.0]);
        register("test.f", |_| vec![2.0]);
        assert_eq!(execute("test.f", &[]).unwrap(), vec![2.0]);
    }

    #[test]
    fn registered_lists_sorted_names() {
        let _guard = serial();
        clear_for_tests();
        register("test.b", |_| Vec::new());
        register("test.a", |_| Vec::new());
        assert_eq!(registered(), vec!["test.a".to_string(), "test.b".to_string()]);
    }
}