use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
/// Type-erased gradient (backward) function.
///
/// Called as `grad_fn(inputs, output_grad)`: the first slice holds the forward
/// inputs, the second the upstream gradients, and the result holds one
/// gradient vector per input. Stored behind an [`Arc`] so a registered function
/// can be cloned out of a registry cheaply and called after any lock on the
/// registry has been released.
pub type GradFn = Arc<dyn Fn(&[Vec<f64>], &[Vec<f64>]) -> Vec<Vec<f64>> + Send + Sync>;

/// Table of custom gradient functions keyed by op name.
///
/// Registering under a name that is already present replaces the previous
/// entry.
#[derive(Default)]
pub struct CustomGradRegistry {
    /// Gradient functions keyed by op name.
    ops: HashMap<String, GradFn>,
}

impl std::fmt::Debug for CustomGradRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The stored closures are not `Debug`, so only the registered names are
        // shown; they are sorted so the output does not depend on `HashMap`
        // iteration order.
        let mut names: Vec<&str> = self.names().collect();
        names.sort_unstable();
        f.debug_struct("CustomGradRegistry")
            .field("ops", &names)
            .finish()
    }
}

impl CustomGradRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `grad_fn` under `op_name`, replacing any function previously
    /// registered under the same name.
    pub fn register(
        &mut self,
        op_name: impl Into<String>,
        grad_fn: impl Fn(&[Vec<f64>], &[Vec<f64>]) -> Vec<Vec<f64>> + Send + Sync + 'static,
    ) {
        self.ops.insert(op_name.into(), Arc::new(grad_fn));
    }

    /// Returns the gradient function registered under `op_name`, if any.
    pub fn get(&self, op_name: &str) -> Option<&GradFn> {
        self.ops.get(op_name)
    }

    /// Returns `true` if a gradient function is registered under `op_name`.
    pub fn contains(&self, op_name: &str) -> bool {
        self.ops.contains_key(op_name)
    }

    /// Removes the function registered under `op_name` and returns it, or
    /// `None` if nothing was registered under that name.
    pub fn remove(&mut self, op_name: &str) -> Option<GradFn> {
        self.ops.remove(op_name)
    }

    /// Number of registered gradient functions.
    pub fn len(&self) -> usize {
        self.ops.len()
    }

    /// Returns `true` if no gradient functions are registered.
    pub fn is_empty(&self) -> bool {
        self.ops.is_empty()
    }

    /// Iterates over the registered op names in unspecified (`HashMap`) order.
    pub fn names(&self) -> impl Iterator<Item = &str> {
        self.ops.keys().map(|s| s.as_str())
    }

    /// Returns the process-wide registry, created empty on first access.
    ///
    /// All access goes through the returned [`RwLock`]; the free functions
    /// `register_global_grad` and `lookup_global_grad` use this instance.
    pub fn global() -> &'static RwLock<CustomGradRegistry> {
        static GLOBAL: OnceLock<RwLock<CustomGradRegistry>> = OnceLock::new();
        GLOBAL.get_or_init(|| RwLock::new(CustomGradRegistry::new()))
    }
}
/// An operation that pairs a forward function with a hand-written gradient.
///
/// `forward` maps the inputs to an output vector. `gradient` receives the same
/// inputs plus the upstream (output) gradients and returns the input gradients.
pub struct CustomGradientFn<F, GF> {
    /// Forward computation: `inputs -> output`.
    forward: F,
    /// Backward computation: `(inputs, output_grad) -> input_grads`.
    gradient: GF,
}

impl<F, GF> CustomGradientFn<F, GF>
where
    F: Fn(&[Vec<f64>]) -> Vec<f64>,
    GF: Fn(&[Vec<f64>], &[Vec<f64>]) -> Vec<Vec<f64>>,
{
    /// Bundles a forward function and its gradient into one operation.
    pub fn new(forward: F, gradient: GF) -> Self {
        CustomGradientFn { forward, gradient }
    }

    /// Evaluates the forward function on `inputs`.
    pub fn call(&self, inputs: &[Vec<f64>]) -> Vec<f64> {
        let forward_fn = &self.forward;
        forward_fn(inputs)
    }

    /// Evaluates the gradient function for `inputs` given the upstream
    /// gradients `output_grad`.
    pub fn grad(&self, inputs: &[Vec<f64>], output_grad: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let gradient_fn = &self.gradient;
        gradient_fn(inputs, output_grad)
    }

    /// Runs the forward pass and returns its output together with a backward
    /// closure.
    ///
    /// The closure takes the upstream gradients and applies the gradient
    /// function to a saved copy of `inputs`.
    pub fn forward_with_backward(
        &self,
        inputs: &[Vec<f64>],
    ) -> (Vec<f64>, impl Fn(&[Vec<f64>]) -> Vec<Vec<f64>> + '_) {
        let result = self.call(inputs);
        // The returned closure is tied only to `self`'s lifetime, not to the
        // caller's `inputs` borrow, so it needs an owned copy of the inputs.
        let saved_inputs: Vec<Vec<f64>> = inputs.to_vec();
        let backward = move |upstream: &[Vec<f64>]| self.grad(&saved_inputs, upstream);
        (result, backward)
    }
}
/// Builds a `CustomGradientFn` from a forward and a backward closure.
///
/// Usage: `define_custom_gradient!(fwd: <forward expr>, bwd: <backward expr>)`.
/// A trailing comma after the backward expression is accepted.
///
/// The expansion names the type as `$crate::custom_grad::CustomGradientFn`,
/// so the type must be reachable at that path from the crate root.
#[macro_export]
macro_rules! define_custom_gradient {
    (
        fwd: $fwd:expr,
        bwd: $bwd:expr $(,)?
    ) => {
        $crate::custom_grad::CustomGradientFn::new($fwd, $bwd)
    };
}
/// Registers `grad_fn` under `op_name` in the process-wide registry returned by
/// [`CustomGradRegistry::global`], replacing any existing entry of that name.
///
/// # Errors
///
/// Returns `Err` with a description if the registry's lock is poisoned.
pub fn register_global_grad(
    op_name: impl Into<String>,
    grad_fn: impl Fn(&[Vec<f64>], &[Vec<f64>]) -> Vec<Vec<f64>> + Send + Sync + 'static,
) -> Result<(), String> {
    // Run the caller-supplied `Into<String>` conversion before taking the write
    // lock. A panic inside it would otherwise poison the global lock, and every
    // later registration would fail and every lookup would return `None`.
    let op_name: String = op_name.into();
    let mut registry = CustomGradRegistry::global()
        .write()
        .map_err(|e| format!("CustomGradRegistry lock poisoned: {e}"))?;
    registry.register(op_name, grad_fn);
    Ok(())
}
pub fn lookup_global_grad(op_name: &str) -> Option<GradFn> {
CustomGradRegistry::global()
.read()
.ok()
.and_then(|reg| reg.get(op_name).cloned())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward pass of `x^2`, element-wise on the first input.
    fn square_fwd(inputs: &[Vec<f64>]) -> Vec<f64> {
        inputs[0].iter().map(|&v| v * v).collect()
    }

    /// Gradient of `x^2` (`2 * x * g`), element-wise on the first input.
    fn square_grad(inputs: &[Vec<f64>], og: &[Vec<f64>]) -> Vec<Vec<f64>> {
        vec![inputs[0]
            .iter()
            .zip(og[0].iter())
            .map(|(&x, &g)| 2.0 * x * g)
            .collect()]
    }

    #[test]
    fn test_registry_register_and_lookup() {
        let mut reg = CustomGradRegistry::new();
        reg.register("sq_grad", square_grad);
        assert!(reg.contains("sq_grad"));
        assert!(!reg.contains("nonexistent"));
        assert!(reg.get("nonexistent").is_none());
        let f = reg.get("sq_grad").expect("should be registered");
        let grads = f(&[vec![3.0_f64]], &[vec![1.0_f64]]);
        assert!((grads[0][0] - 6.0).abs() < 1e-10, "grad={}", grads[0][0]);
    }

    #[test]
    fn test_registry_get_and_call() {
        let mut reg = CustomGradRegistry::new();
        reg.register("double_grad", |_inputs: &[Vec<f64>], og: &[Vec<f64>]| {
            vec![og[0].iter().map(|&g| 2.0 * g).collect()]
        });
        let f = reg.get("double_grad").expect("should be registered");
        let inputs = vec![vec![1.0_f64, 2.0]];
        let og = vec![vec![1.0_f64, 1.0]];
        let grads = f(&inputs, &og);
        assert_eq!(grads[0], vec![2.0, 2.0]);
    }

    #[test]
    fn test_registry_replace_existing() {
        let mut reg = CustomGradRegistry::new();
        reg.register("op", |_i: &[Vec<f64>], og: &[Vec<f64>]| vec![og[0].clone()]);
        reg.register("op", |_i: &[Vec<f64>], og: &[Vec<f64>]| {
            vec![og[0].iter().map(|&g| g * 3.0).collect()]
        });
        assert_eq!(reg.len(), 1);
        let f = reg.get("op").expect("should exist");
        let og = vec![vec![2.0_f64]];
        let grads = f(&[], &og);
        assert!((grads[0][0] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_registry_remove() {
        let mut reg = CustomGradRegistry::new();
        reg.register("removable", |_i: &[Vec<f64>], og: &[Vec<f64>]| vec![og[0].clone()]);
        assert!(reg.contains("removable"));
        let removed = reg.remove("removable");
        assert!(removed.is_some());
        assert!(!reg.contains("removable"));
        assert!(reg.remove("removable").is_none());
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let mut reg = CustomGradRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
        reg.register("a", |_i, og| vec![og[0].clone()]);
        reg.register("b", |_i, og| vec![og[0].clone()]);
        assert_eq!(reg.len(), 2);
        assert!(!reg.is_empty());
        reg.remove("a");
        reg.remove("b");
        assert!(reg.is_empty());
    }

    #[test]
    fn test_registry_names() {
        let mut reg = CustomGradRegistry::new();
        assert_eq!(reg.names().count(), 0);
        reg.register("beta", square_grad);
        reg.register("alpha", square_grad);
        let mut names: Vec<&str> = reg.names().collect();
        names.sort_unstable();
        assert_eq!(names, ["alpha", "beta"]);
    }

    #[test]
    fn test_registry_default_is_empty() {
        let reg = CustomGradRegistry::default();
        assert!(reg.is_empty());
    }

    #[test]
    fn test_decorator_forward_square() {
        let sq = CustomGradientFn::new(square_fwd, square_grad);
        let out = sq.call(&[vec![3.0_f64, 4.0]]);
        assert!((out[0] - 9.0).abs() < 1e-10, "out[0]={}", out[0]);
        assert!((out[1] - 16.0).abs() < 1e-10, "out[1]={}", out[1]);
    }

    #[test]
    fn test_decorator_gradient_square() {
        let sq = CustomGradientFn::new(square_fwd, square_grad);
        let grads = sq.grad(&[vec![3.0_f64]], &[vec![1.0_f64]]);
        assert!((grads[0][0] - 6.0).abs() < 1e-10, "grad={}", grads[0][0]);
    }

    #[test]
    fn test_decorator_forward_with_backward() {
        let sq = CustomGradientFn::new(square_fwd, square_grad);
        let (out, bwd) = sq.forward_with_backward(&[vec![5.0_f64]]);
        assert!((out[0] - 25.0).abs() < 1e-10);
        let grads = bwd(&[vec![1.0_f64]]);
        assert!((grads[0][0] - 10.0).abs() < 1e-10);
        // The backward closure is `Fn`, so it can be invoked repeatedly.
        let grads = bwd(&[vec![2.0_f64]]);
        assert!((grads[0][0] - 20.0).abs() < 1e-10);
    }

    #[test]
    fn test_macro_cube() {
        let cube = define_custom_gradient!(
            fwd: |inputs: &[Vec<f64>]| -> Vec<f64> {
                inputs[0].iter().map(|&v| v * v * v).collect()
            },
            bwd: |inputs: &[Vec<f64>], og: &[Vec<f64>]| -> Vec<Vec<f64>> {
                vec![inputs[0]
                    .iter()
                    .zip(og[0].iter())
                    .map(|(&x, &g)| 3.0 * x * x * g)
                    .collect()]
            }
        );
        let z = cube.call(&[vec![2.0_f64]]);
        assert!((z[0] - 8.0).abs() < 1e-10, "z[0]={}", z[0]);
        let g = cube.grad(&[vec![2.0_f64]], &[vec![1.0_f64]]);
        assert!((g[0][0] - 12.0).abs() < 1e-10, "g[0][0]={}", g[0][0]);
    }

    #[test]
    fn test_macro_log_grad() {
        let logop = define_custom_gradient!(
            fwd: |inputs: &[Vec<f64>]| -> Vec<f64> {
                inputs[0].iter().map(|&v| v.ln()).collect()
            },
            bwd: |inputs: &[Vec<f64>], og: &[Vec<f64>]| -> Vec<Vec<f64>> {
                vec![inputs[0]
                    .iter()
                    .zip(og[0].iter())
                    .map(|(&x, &g)| g / x)
                    .collect()]
            }
        );
        let z = logop.call(&[vec![std::f64::consts::E]]);
        assert!((z[0] - 1.0).abs() < 1e-10);
        let g = logop.grad(&[vec![2.0_f64]], &[vec![1.0_f64]]);
        assert!((g[0][0] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_global_registry_register_and_lookup() {
        register_global_grad("test_global_abs_grad", |inputs: &[Vec<f64>], og: &[Vec<f64>]| {
            vec![inputs[0]
                .iter()
                .zip(og[0].iter())
                .map(|(&x, &g)| g * x.signum())
                .collect()]
        })
        .expect("lock should not be poisoned");
        let f = lookup_global_grad("test_global_abs_grad").expect("should be registered");
        let inputs = vec![vec![-2.0_f64, 3.0]];
        let og = vec![vec![1.0_f64, 1.0]];
        let grads = f(&inputs, &og);
        assert!((grads[0][0] + 1.0).abs() < 1e-10);
        assert!((grads[0][1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_global_registry_replace_existing() {
        // Unique name: the global registry is shared by tests running in parallel.
        let name = "test_global_replace_grad";
        register_global_grad(name, |_i: &[Vec<f64>], og: &[Vec<f64>]| vec![og[0].clone()])
            .expect("lock should not be poisoned");
        register_global_grad(name, |_i: &[Vec<f64>], og: &[Vec<f64>]| {
            vec![og[0].iter().map(|&g| -g).collect()]
        })
        .expect("lock should not be poisoned");
        let f = lookup_global_grad(name).expect("should be registered");
        assert_eq!(f(&[], &[vec![4.0_f64]]), vec![vec![-4.0]]);
    }

    #[test]
    fn test_global_registry_lookup_missing() {
        let f = lookup_global_grad("definitely_not_registered_xyz_abc");
        assert!(f.is_none());
    }
}