use std::collections::HashMap;
/// A differentiable operation on flat `f64` buffers.
///
/// Each implementor supplies a forward evaluation and a vector-Jacobian product (VJP)
/// used for reverse-mode differentiation. The `Send + Sync` bound allows boxed ops to be
/// shared across threads (e.g. when held in a `DiffOpRegistry`).
pub trait DiffOp: Send + Sync {
    /// Identifier for the op (e.g. `"add"`); the registry uses it as the lookup key.
    fn name(&self) -> &str;
    /// Evaluates the op on `inputs` (one slice per operand) and returns the output buffer.
    fn forward(&self, inputs: &[&[f64]]) -> Vec<f64>;
    /// Pulls `output_grad` (the gradient w.r.t. the forward output) back through the op,
    /// returning one gradient buffer per operand, in the same order as `inputs`.
    fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>>;
}
/// Applies `f` to each aligned pair `(a[i], b[i])` and collects the results.
///
/// Pairing stops at the end of the shorter slice, so surplus elements of the longer
/// operand are silently ignored rather than reported.
fn elementwise_binary<F>(a: &[f64], b: &[f64], f: F) -> Vec<f64>
where
    F: Fn(f64, f64) -> f64,
{
    let len = a.len().min(b.len());
    let mut out = Vec::with_capacity(len);
    for (&lhs, &rhs) in a.iter().zip(b) {
        out.push(f(lhs, rhs));
    }
    out
}
/// Maps `f` over every element of `a`, preserving order and length.
fn elementwise_unary<F>(a: &[f64], f: F) -> Vec<f64>
where
    F: Fn(f64) -> f64,
{
    a.iter().copied().map(f).collect()
}
/// Element-wise addition `x + y`.
pub struct Add;

impl DiffOp for Add {
    fn name(&self) -> &str {
        "add"
    }

    /// Returns `x + y`; empty when fewer than two operands are given or `x` is empty.
    fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
        match inputs {
            [x, y, ..] if !x.is_empty() => elementwise_binary(x, y, |a, b| a + b),
            _ => Vec::new(),
        }
    }

    /// Addition passes the upstream gradient through unchanged to both operands.
    /// Both gradients are empty when there is no first operand or it is empty.
    fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
        let has_data = matches!(inputs.first(), Some(x) if !x.is_empty());
        if has_data {
            vec![output_grad.to_vec(); 2]
        } else {
            vec![Vec::new(), Vec::new()]
        }
    }
}
/// Element-wise subtraction `x - y`.
pub struct Sub;

impl DiffOp for Sub {
    fn name(&self) -> &str {
        "sub"
    }

    /// Returns `x - y`; empty when an operand is missing or `x` is empty.
    fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
        match inputs {
            [lhs, rhs, ..] if !lhs.is_empty() => elementwise_binary(lhs, rhs, |a, b| a - b),
            _ => Vec::new(),
        }
    }

    /// `∂(x - y)/∂x = 1` and `∂(x - y)/∂y = -1`; the operand values are not needed.
    fn vjp(&self, _inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
        let through = output_grad.to_vec();
        let flipped: Vec<f64> = through.iter().map(|g| -g).collect();
        vec![through, flipped]
    }
}
/// Element-wise (Hadamard) product `x * y`.
pub struct Mul;

impl DiffOp for Mul {
    fn name(&self) -> &str {
        "mul"
    }

    /// Returns `x * y`; empty when an operand is missing or `x` is empty.
    fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
        match inputs {
            [x, y, ..] if !x.is_empty() => elementwise_binary(x, y, |a, b| a * b),
            _ => Vec::new(),
        }
    }

    /// Product rule: `∂(x*y)/∂x = y` and `∂(x*y)/∂y = x`, each scaled by the upstream gradient.
    fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
        match inputs {
            [x, y, ..] => vec![
                elementwise_binary(output_grad, y, |g, yi| g * yi),
                elementwise_binary(output_grad, x, |g, xi| g * xi),
            ],
            _ => vec![Vec::new(), Vec::new()],
        }
    }
}
/// Element-wise quotient `x / y` (IEEE semantics: division by zero yields ±inf/NaN).
pub struct Div;

impl DiffOp for Div {
    fn name(&self) -> &str {
        "div"
    }

    /// Returns `x / y`; empty when an operand is missing or `x` is empty.
    fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
        match inputs {
            [num, den, ..] if !num.is_empty() => elementwise_binary(num, den, |a, b| a / b),
            _ => Vec::new(),
        }
    }

    /// Quotient rule: `∂(x/y)/∂x = 1/y` and `∂(x/y)/∂y = -x/y²`.
    fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
        let (num, den) = match inputs {
            [num, den, ..] => (*num, *den),
            _ => return vec![Vec::new(), Vec::new()],
        };
        let d_num = elementwise_binary(output_grad, den, |g, d| g / d);
        let d_den: Vec<f64> = output_grad
            .iter()
            .zip(num)
            .zip(den)
            .map(|((&g, &nu), &de)| -g * nu / (de * de))
            .collect();
        vec![d_num, d_den]
    }
}
/// Element-wise negation `-x`.
pub struct Neg;

impl DiffOp for Neg {
    fn name(&self) -> &str {
        "neg"
    }

    /// Returns `-x`, or an empty buffer when no operand is supplied.
    fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
        inputs
            .first()
            .map(|x| elementwise_unary(x, |v| -v))
            .unwrap_or_default()
    }

    /// The derivative of `-x` is `-1`, so the upstream gradient is simply negated.
    fn vjp(&self, _inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
        let dx: Vec<f64> = output_grad.iter().map(|&g| -g).collect();
        vec![dx]
    }
}
pub struct Exp;
impl DiffOp for Exp {
fn name(&self) -> &str {
"exp"
}
fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
if inputs.is_empty() {
return Vec::new();
}
elementwise_unary(inputs[0], f64::exp)
}
fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
if inputs.is_empty() {
return vec![Vec::new()];
}
let dx: Vec<f64> = output_grad
.iter()
.zip(inputs[0].iter())
.map(|(&g, &x)| g * x.exp())
.collect();
vec![dx]
}
}
/// Element-wise natural logarithm `ln x` (non-positive inputs follow IEEE: -inf / NaN).
pub struct Log;

impl DiffOp for Log {
    fn name(&self) -> &str {
        "log"
    }

    /// Returns `ln x`, or an empty buffer when no operand is supplied.
    fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
        inputs
            .first()
            .map_or_else(Vec::new, |x| elementwise_unary(x, f64::ln))
    }

    /// `d(ln x)/dx = 1/x`, so each upstream gradient is divided by its input.
    fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
        let x = match inputs.first() {
            Some(x) => x,
            None => return vec![Vec::new()],
        };
        let dx: Vec<f64> = output_grad
            .iter()
            .zip(x.iter())
            .map(|(&g, &xi)| g / xi)
            .collect();
        vec![dx]
    }
}
pub struct Sqrt;
impl DiffOp for Sqrt {
fn name(&self) -> &str {
"sqrt"
}
fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
if inputs.is_empty() {
return Vec::new();
}
elementwise_unary(inputs[0], f64::sqrt)
}
fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
if inputs.is_empty() {
return vec![Vec::new()];
}
let dx: Vec<f64> = output_grad
.iter()
.zip(inputs[0].iter())
.map(|(&g, &x)| g / (2.0 * x.sqrt()))
.collect();
vec![dx]
}
}
/// Rectified linear unit `max(x, 0)`.
pub struct Relu;

impl DiffOp for Relu {
    fn name(&self) -> &str {
        "relu"
    }

    /// Returns `x.max(0.0)` element-wise, or an empty buffer when no operand is supplied.
    fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
        match inputs {
            [x, ..] => elementwise_unary(x, |v| v.max(0.0)),
            [] => Vec::new(),
        }
    }

    /// Passes the gradient where `x > 0` and blocks it elsewhere; the strict comparison
    /// means the subgradient chosen at `x == 0` is 0.
    fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
        match inputs {
            [x, ..] => {
                let dx: Vec<f64> = output_grad
                    .iter()
                    .zip(x.iter())
                    .map(|(&g, &xi)| if xi > 0.0 { g } else { 0.0 })
                    .collect();
                vec![dx]
            }
            [] => vec![Vec::new()],
        }
    }
}
pub struct Sigmoid;
impl DiffOp for Sigmoid {
fn name(&self) -> &str {
"sigmoid"
}
fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
if inputs.is_empty() {
return Vec::new();
}
elementwise_unary(inputs[0], |x| 1.0 / (1.0 + (-x).exp()))
}
fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
if inputs.is_empty() {
return vec![Vec::new()];
}
let dx: Vec<f64> = output_grad
.iter()
.zip(inputs[0].iter())
.map(|(&g, &x)| {
let z = 1.0 / (1.0 + (-x).exp());
g * z * (1.0 - z)
})
.collect();
vec![dx]
}
}
pub struct Tanh;
impl DiffOp for Tanh {
fn name(&self) -> &str {
"tanh"
}
fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
if inputs.is_empty() {
return Vec::new();
}
elementwise_unary(inputs[0], f64::tanh)
}
fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
if inputs.is_empty() {
return vec![Vec::new()];
}
let dx: Vec<f64> = output_grad
.iter()
.zip(inputs[0].iter())
.map(|(&g, &x)| {
let t = x.tanh();
g * (1.0 - t * t)
})
.collect();
vec![dx]
}
}
/// Raises every element to a fixed real power: `x^n` (via `f64::powf`).
pub struct Pow {
    /// The exponent `n` applied to each element.
    pub n: f64,
}

impl DiffOp for Pow {
    fn name(&self) -> &str {
        "pow"
    }

    /// Returns `x^n` element-wise, or an empty buffer when no operand is supplied.
    fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
        if inputs.is_empty() {
            return Vec::new();
        }
        let n = self.n;
        elementwise_unary(inputs[0], |x| x.powf(n))
    }

    /// Applies the power rule `d(x^n)/dx = n·x^(n-1)` to the upstream gradient.
    ///
    /// For `n == 0` the forward op is the constant `x.powf(0.0) == 1`, so the gradient is
    /// exactly zero. The general formula would instead evaluate `0 · 0^(-1) = 0 · ∞ = NaN`
    /// at `x == 0`, so that case is handled separately.
    fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
        if inputs.is_empty() {
            return vec![Vec::new()];
        }
        let x = inputs[0];
        let n = self.n;
        if n == 0.0 {
            // Same length the zip below would produce for any other exponent.
            return vec![vec![0.0; output_grad.len().min(x.len())]];
        }
        let dx: Vec<f64> = output_grad
            .iter()
            .zip(x.iter())
            .map(|(&g, &xi)| g * n * xi.powf(n - 1.0))
            .collect();
        vec![dx]
    }
}
/// Reduces a vector to the one-element buffer `[Σ x_i]`.
pub struct Sum;

impl DiffOp for Sum {
    fn name(&self) -> &str {
        "sum"
    }

    /// Returns `[Σ x_i]`; with no operand at all the result is `[0.0]`.
    fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
        let total = match inputs.first() {
            Some(x) => x.iter().sum::<f64>(),
            None => 0.0,
        };
        vec![total]
    }

    /// Broadcasts the scalar upstream gradient (its first element, or `0.0` when
    /// `output_grad` is empty) back to every position of the input.
    fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
        match inputs.first() {
            None => vec![Vec::new()],
            Some(x) => {
                let g = match output_grad {
                    [first, ..] => *first,
                    [] => 0.0,
                };
                vec![vec![g; x.len()]]
            }
        }
    }
}
pub struct Scale {
pub alpha: f64,
}
impl DiffOp for Scale {
fn name(&self) -> &str {
"scale"
}
fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
if inputs.is_empty() {
return Vec::new();
}
let a = self.alpha;
elementwise_unary(inputs[0], |v| a * v)
}
fn vjp(&self, _inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
let a = self.alpha;
vec![elementwise_unary(output_grad, |g| a * g)]
}
}
/// Numerically stabilized softmax over a single vector.
pub struct Softmax;

impl Softmax {
    /// Computes `exp(x_i - max) / Σ_j exp(x_j - max)`.
    ///
    /// Subtracting the maximum keeps `exp` from overflowing without changing the result.
    /// Shared by `forward` and `vjp` so both use exactly the same probabilities.
    fn probabilities(x: &[f64]) -> Vec<f64> {
        let peak = x.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let mut out: Vec<f64> = x.iter().map(|&v| (v - peak).exp()).collect();
        let total: f64 = out.iter().sum();
        for p in out.iter_mut() {
            *p /= total;
        }
        out
    }
}

impl DiffOp for Softmax {
    fn name(&self) -> &str {
        "softmax"
    }

    /// Returns the softmax of the first operand, or an empty buffer when none is supplied.
    fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
        match inputs.first() {
            Some(x) => Self::probabilities(x),
            None => Vec::new(),
        }
    }

    /// With `z = softmax(x)` and upstream gradient `g`, the VJP is
    /// `dx_i = z_i · (g_i - Σ_j g_j z_j)`, which avoids forming the full Jacobian.
    fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
        let z = match inputs.first() {
            Some(x) => Self::probabilities(x),
            None => return vec![Vec::new()],
        };
        let dot: f64 = output_grad.iter().zip(&z).map(|(&g, &zi)| g * zi).sum();
        let dx: Vec<f64> = output_grad
            .iter()
            .zip(&z)
            .map(|(&g, &zi)| zi * (g - dot))
            .collect();
        vec![dx]
    }
}
/// Dense matrix product `Z = A · B` over flat row-major buffers.
///
/// `A` is `m × n`, `B` is `n × k`, and `Z` is `m × k`; element `(r, c)` of an `R × C`
/// matrix lives at index `r * C + c`.
pub struct MatMul {
    /// Rows of `A` and of the output `Z`.
    pub m: usize,
    /// Columns of `A` / rows of `B` (the contracted dimension).
    pub n: usize,
    /// Columns of `B` and of the output `Z`.
    pub k: usize,
}

impl MatMul {
    /// Row-major product of a `rows_a × cols_a` matrix `a` with a `cols_a × cols_b` matrix `b`.
    ///
    /// Loops in i-l-j order so the innermost loop walks contiguous rows of `b` and `out`.
    fn matmul(a: &[f64], b: &[f64], rows_a: usize, cols_a: usize, cols_b: usize) -> Vec<f64> {
        let mut out = vec![0.0_f64; rows_a * cols_b];
        for i in 0..rows_a {
            for l in 0..cols_a {
                let a_il = a[i * cols_a + l];
                for j in 0..cols_b {
                    out[i * cols_b + j] += a_il * b[l * cols_b + j];
                }
            }
        }
        out
    }

    /// Asserts that `buf` is large enough to hold a `rows × cols` row-major matrix.
    ///
    /// Without this check an undersized buffer fails as a bare index-out-of-bounds panic
    /// deep inside the loops; this names the offending operand instead. Longer buffers are
    /// accepted, and only their prefix is read.
    fn check_len(operand: &str, buf: &[f64], rows: usize, cols: usize) {
        assert!(
            buf.len() >= rows * cols,
            "matmul: {} has {} values but a {}x{} matrix needs {}",
            operand,
            buf.len(),
            rows,
            cols,
            rows * cols
        );
    }
}

impl DiffOp for MatMul {
    fn name(&self) -> &str {
        "matmul"
    }

    /// Computes `Z = A · B` from `inputs = [A, B]`; empty when an operand is missing.
    ///
    /// # Panics
    /// Panics if `A` holds fewer than `m * n` values or `B` fewer than `n * k`.
    fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
        if inputs.len() < 2 {
            return Vec::new();
        }
        let (a, b) = (inputs[0], inputs[1]);
        Self::check_len("A", a, self.m, self.n);
        Self::check_len("B", b, self.n, self.k);
        Self::matmul(a, b, self.m, self.n, self.k)
    }

    /// Returns `[dA, dB]` with `dA = dZ · Bᵀ` (`m × n`) and `dB = Aᵀ · dZ` (`n × k`).
    ///
    /// # Panics
    /// Panics if `A`, `B`, or `output_grad` holds fewer than `m * n`, `n * k`, or `m * k`
    /// values respectively.
    fn vjp(&self, inputs: &[&[f64]], output_grad: &[f64]) -> Vec<Vec<f64>> {
        if inputs.len() < 2 {
            return vec![Vec::new(), Vec::new()];
        }
        let (a, b, dz) = (inputs[0], inputs[1], output_grad);
        let (m, n, k) = (self.m, self.n, self.k);
        Self::check_len("A", a, m, n);
        Self::check_len("B", b, n, k);
        Self::check_len("output_grad", dz, m, k);

        // dA[i, l] = Σ_j dZ[i, j] · B[l, j]
        let mut da = vec![0.0_f64; m * n];
        for i in 0..m {
            for l in 0..n {
                let mut s = 0.0;
                for j in 0..k {
                    s += dz[i * k + j] * b[l * k + j];
                }
                da[i * n + l] = s;
            }
        }

        // dB[l, j] = Σ_i A[i, l] · dZ[i, j]
        let mut db = vec![0.0_f64; n * k];
        for l in 0..n {
            for j in 0..k {
                let mut s = 0.0;
                for i in 0..m {
                    s += a[i * n + l] * dz[i * k + j];
                }
                db[l * k + j] = s;
            }
        }

        vec![da, db]
    }
}
/// A name-keyed table of differentiable operations.
///
/// Ops are stored as boxed trait objects keyed by [`DiffOp::name`]; registering an op
/// whose name is already present replaces the earlier entry.
pub struct DiffOpRegistry {
    ops: HashMap<String, Box<dyn DiffOp>>,
}

impl Default for DiffOpRegistry {
    /// Same as [`DiffOpRegistry::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}

impl DiffOpRegistry {
    /// Creates a registry containing no operations.
    pub fn new() -> Self {
        DiffOpRegistry {
            ops: HashMap::new(),
        }
    }

    /// Creates a registry pre-populated with the parameter-free ops defined in this module.
    ///
    /// `Pow`, `Scale` and `MatMul` carry configuration fields and are not included.
    pub fn with_standard_ops() -> Self {
        let builtins: Vec<Box<dyn DiffOp>> = vec![
            Box::new(Add),
            Box::new(Sub),
            Box::new(Mul),
            Box::new(Div),
            Box::new(Neg),
            Box::new(Exp),
            Box::new(Log),
            Box::new(Sqrt),
            Box::new(Relu),
            Box::new(Sigmoid),
            Box::new(Tanh),
            Box::new(Softmax),
            Box::new(Sum),
        ];
        let mut registry = DiffOpRegistry::new();
        for op in builtins {
            registry.ops.insert(op.name().to_string(), op);
        }
        registry
    }

    /// Adds `op` under its own name, replacing any op already registered with that name.
    pub fn register(&mut self, op: impl DiffOp + 'static) {
        let key = op.name().to_owned();
        self.ops.insert(key, Box::new(op));
    }

    /// Looks up an op by name.
    pub fn get(&self, name: &str) -> Option<&dyn DiffOp> {
        self.ops.get(name).map(|boxed| &**boxed)
    }

    /// Reports whether an op with this name is registered.
    pub fn contains(&self, name: &str) -> bool {
        self.get(name).is_some()
    }

    /// Iterates over registered op names in unspecified (hash-map) order.
    pub fn names(&self) -> impl Iterator<Item = &str> {
        self.ops.keys().map(String::as_str)
    }

    /// Number of registered ops.
    pub fn len(&self) -> usize {
        self.ops.len()
    }

    /// Whether no ops are registered.
    pub fn is_empty(&self) -> bool {
        self.ops.is_empty()
    }
}
/// Unit tests: exact checks for the simple ops plus central-difference checks for the
/// gradients that are easiest to get wrong (softmax, matmul).
#[cfg(test)]
mod tests {
    use super::*;

    /// Central-difference estimate of `Σ_i dz_i · ∂z_i/∂p`, given the op outputs at
    /// `p + eps` (`zp`) and at `p - eps` (`zm`).
    fn directional_fd(zp: &[f64], zm: &[f64], dz: &[f64], eps: f64) -> f64 {
        zp.iter()
            .zip(zm)
            .zip(dz)
            .map(|((p, m), g)| g * (p - m) / (2.0 * eps))
            .sum()
    }

    #[test]
    fn test_add_forward() {
        let op = Add;
        let x = vec![1.0_f64, 2.0, 3.0];
        let y = vec![4.0_f64, 5.0, 6.0];
        let z = op.forward(&[&x, &y]);
        assert_eq!(z, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_vjp() {
        let op = Add;
        let x = vec![1.0_f64, 2.0];
        let y = vec![3.0_f64, 4.0];
        let dz = vec![1.0_f64, 2.0];
        let grads = op.vjp(&[&x, &y], &dz);
        assert_eq!(grads[0], dz);
        assert_eq!(grads[1], dz);
    }

    #[test]
    fn test_sub_forward() {
        let op = Sub;
        let z = op.forward(&[&[5.0_f64, 3.0], &[2.0_f64, 1.0]]);
        assert_eq!(z, vec![3.0, 2.0]);
    }

    #[test]
    fn test_sub_vjp() {
        let op = Sub;
        let dz = vec![1.0_f64, 1.0];
        let grads = op.vjp(&[&[0.0_f64], &[0.0_f64]], &dz);
        assert_eq!(grads[0], dz);
        assert_eq!(grads[1], vec![-1.0, -1.0]);
    }

    #[test]
    fn test_mul_forward() {
        let op = Mul;
        let z = op.forward(&[&[2.0_f64, 3.0], &[4.0_f64, 5.0]]);
        assert_eq!(z, vec![8.0, 15.0]);
    }

    #[test]
    fn test_mul_vjp() {
        let op = Mul;
        let x = vec![2.0_f64, 3.0];
        let y = vec![4.0_f64, 5.0];
        let dz = vec![1.0_f64, 1.0];
        let grads = op.vjp(&[&x, &y], &dz);
        assert_eq!(grads[0], vec![4.0, 5.0]);
        assert_eq!(grads[1], vec![2.0, 3.0]);
    }

    #[test]
    fn test_div_forward() {
        let op = Div;
        let z = op.forward(&[&[6.0_f64, 9.0], &[2.0_f64, 3.0]]);
        assert_eq!(z, vec![3.0, 3.0]);
    }

    #[test]
    fn test_div_vjp_numerically() {
        let x = vec![6.0_f64];
        let y = vec![2.0_f64];
        let op = Div;
        let dz = vec![1.0_f64];
        let grads = op.vjp(&[&x, &y], &dz);
        assert!((grads[0][0] - 0.5).abs() < 1e-10);
        assert!((grads[1][0] + 1.5).abs() < 1e-10);
    }

    #[test]
    fn test_neg_forward_and_vjp() {
        let op = Neg;
        let x = vec![1.0_f64, -2.0, 3.0];
        let z = op.forward(&[&x]);
        assert_eq!(z, vec![-1.0, 2.0, -3.0]);
        let grads = op.vjp(&[&x], &[1.0, 1.0, 1.0]);
        assert_eq!(grads[0], vec![-1.0, -1.0, -1.0]);
    }

    #[test]
    fn test_exp_forward() {
        let op = Exp;
        let z = op.forward(&[&[0.0_f64, 1.0]]);
        assert!((z[0] - 1.0).abs() < 1e-10);
        assert!((z[1] - 1.0_f64.exp()).abs() < 1e-10);
    }

    #[test]
    fn test_exp_vjp() {
        let op = Exp;
        let x = vec![1.0_f64];
        let dz = vec![1.0_f64];
        let grads = op.vjp(&[&x], &dz);
        assert!((grads[0][0] - 1.0_f64.exp()).abs() < 1e-10);
    }

    #[test]
    fn test_log_forward() {
        let op = Log;
        let z = op.forward(&[&[1.0_f64, std::f64::consts::E]]);
        assert!((z[0] - 0.0).abs() < 1e-10);
        assert!((z[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_log_vjp() {
        let op = Log;
        let x = vec![2.0_f64];
        let grads = op.vjp(&[&x], &[1.0]);
        assert!((grads[0][0] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_sqrt_forward_and_vjp() {
        let op = Sqrt;
        let x = vec![4.0_f64];
        let z = op.forward(&[&x]);
        assert!((z[0] - 2.0).abs() < 1e-10);
        let grads = op.vjp(&[&x], &[1.0]);
        assert!((grads[0][0] - 0.25).abs() < 1e-10);
    }

    #[test]
    fn test_relu_forward() {
        let op = Relu;
        let x = vec![-1.0_f64, 0.0, 2.0];
        let z = op.forward(&[&x]);
        assert_eq!(z, vec![0.0, 0.0, 2.0]);
    }

    #[test]
    fn test_relu_vjp() {
        let op = Relu;
        let x = vec![-1.0_f64, 0.0, 2.0];
        let dz = vec![1.0_f64, 1.0, 1.0];
        let grads = op.vjp(&[&x], &dz);
        assert_eq!(grads[0], vec![0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_sigmoid_forward() {
        let op = Sigmoid;
        let z = op.forward(&[&[0.0_f64]]);
        assert!((z[0] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_sigmoid_vjp() {
        let op = Sigmoid;
        let x = vec![0.0_f64];
        let grads = op.vjp(&[&x], &[1.0]);
        assert!((grads[0][0] - 0.25).abs() < 1e-10);
    }

    #[test]
    fn test_tanh_forward_and_vjp() {
        let op = Tanh;
        let x = vec![0.0_f64];
        let z = op.forward(&[&x]);
        assert!((z[0] - 0.0).abs() < 1e-10);
        let grads = op.vjp(&[&x], &[1.0]);
        assert!((grads[0][0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_pow_forward_and_vjp() {
        let op = Pow { n: 3.0 };
        let x = vec![2.0_f64];
        let z = op.forward(&[&x]);
        assert!((z[0] - 8.0).abs() < 1e-10);
        let grads = op.vjp(&[&x], &[1.0]);
        assert!((grads[0][0] - 12.0).abs() < 1e-10);
    }

    #[test]
    fn test_sum_forward_and_vjp() {
        let op = Sum;
        let x = vec![1.0_f64, 2.0, 3.0];
        let z = op.forward(&[&x]);
        assert!((z[0] - 6.0).abs() < 1e-10);
        let grads = op.vjp(&[&x], &[2.0]);
        assert_eq!(grads[0], vec![2.0, 2.0, 2.0]);
    }

    #[test]
    fn test_scale_forward_and_vjp() {
        let op = Scale { alpha: 3.0 };
        let x = vec![1.0_f64, 2.0];
        let z = op.forward(&[&x]);
        assert_eq!(z, vec![3.0, 6.0]);
        let grads = op.vjp(&[&x], &[1.0, 1.0]);
        assert_eq!(grads[0], vec![3.0, 3.0]);
    }

    #[test]
    fn test_softmax_forward_sums_to_one() {
        let op = Softmax;
        let x = vec![1.0_f64, 2.0, 3.0];
        let z = op.forward(&[&x]);
        let s: f64 = z.iter().sum();
        assert!((s - 1.0).abs() < 1e-10, "sum={}", s);
    }

    #[test]
    fn test_softmax_vjp_numerically() {
        let op = Softmax;
        let x = vec![0.5_f64, 1.0, 0.2];
        let dz = vec![1.0_f64, 0.0, 0.0];
        let grads = op.vjp(&[&x], &dz);
        let eps = 1e-6;
        for k in 0..x.len() {
            let mut xp = x.clone();
            let mut xm = x.clone();
            xp[k] += eps;
            xm[k] -= eps;
            let fd = directional_fd(&op.forward(&[&xp]), &op.forward(&[&xm]), &dz, eps);
            assert!(
                (grads[0][k] - fd).abs() < 1e-5,
                "k={}: vjp={} fd={}",
                k,
                grads[0][k],
                fd
            );
        }
    }

    #[test]
    fn test_matmul_forward() {
        let op = MatMul { m: 2, n: 2, k: 2 };
        let a = vec![1.0_f64, 2.0, 3.0, 4.0];
        let b = vec![5.0_f64, 6.0, 7.0, 8.0];
        let z = op.forward(&[&a, &b]);
        assert!((z[0] - 19.0).abs() < 1e-10);
        assert!((z[1] - 22.0).abs() < 1e-10);
        assert!((z[2] - 43.0).abs() < 1e-10);
        assert!((z[3] - 50.0).abs() < 1e-10);
    }

    #[test]
    fn test_matmul_vjp_numerically() {
        let op = MatMul { m: 2, n: 3, k: 2 };
        let a: Vec<f64> = (0..6).map(|v| v as f64 + 1.0).collect();
        let b: Vec<f64> = (0..6).map(|v| v as f64 + 7.0).collect();
        let dz = vec![1.0_f64; 4];
        let grads = op.vjp(&[&a, &b], &dz);
        let eps = 1e-6;

        // Gradient w.r.t. A.
        for idx in 0..a.len() {
            let mut ap = a.clone();
            let mut am = a.clone();
            ap[idx] += eps;
            am[idx] -= eps;
            let fd = directional_fd(&op.forward(&[&ap, &b]), &op.forward(&[&am, &b]), &dz, eps);
            assert!(
                (grads[0][idx] - fd).abs() < 1e-6,
                "da[{}]: vjp={} fd={}",
                idx,
                grads[0][idx],
                fd
            );
        }

        // Gradient w.r.t. B (previously unchecked).
        for idx in 0..b.len() {
            let mut bp = b.clone();
            let mut bm = b.clone();
            bp[idx] += eps;
            bm[idx] -= eps;
            let fd = directional_fd(&op.forward(&[&a, &bp]), &op.forward(&[&a, &bm]), &dz, eps);
            assert!(
                (grads[1][idx] - fd).abs() < 1e-6,
                "db[{}]: vjp={} fd={}",
                idx,
                grads[1][idx],
                fd
            );
        }
    }

    #[test]
    fn test_registry_standard_ops() {
        let reg = DiffOpRegistry::with_standard_ops();
        for name in &[
            "add", "sub", "mul", "div", "neg", "exp", "log", "sqrt", "relu", "sigmoid", "tanh",
            "softmax", "sum",
        ] {
            assert!(reg.contains(name), "missing: {}", name);
        }
    }

    #[test]
    fn test_registry_custom_op() {
        let mut reg = DiffOpRegistry::with_standard_ops();
        struct DoubleOp;
        impl DiffOp for DoubleOp {
            fn name(&self) -> &str {
                "double"
            }
            fn forward(&self, inputs: &[&[f64]]) -> Vec<f64> {
                inputs[0].iter().map(|&v| 2.0 * v).collect()
            }
            fn vjp(&self, _inputs: &[&[f64]], og: &[f64]) -> Vec<Vec<f64>> {
                vec![og.iter().map(|&g| 2.0 * g).collect()]
            }
        }
        reg.register(DoubleOp);
        assert!(reg.contains("double"));
        let op = reg.get("double").expect("registered op should be found");
        let z = op.forward(&[&[3.0_f64, 4.0]]);
        assert_eq!(z, vec![6.0, 8.0]);
    }

    #[test]
    fn test_registry_nonexistent_returns_none() {
        let reg = DiffOpRegistry::with_standard_ops();
        assert!(reg.get("nonexistent_op_xyz").is_none());
    }
}