Skip to main content

scivex_nn/
ops.rs

1//! Differentiable operations on [`Variable`]s.
2//!
3//! Each operation records its inputs and a closure (`grad_fn`) so that
4//! [`Variable::backward`] can compute gradients via reverse-mode autodiff.
5
6use std::ops;
7
8use scivex_core::{Float, Tensor};
9
10use crate::variable::Variable;
11
12// ── Element-wise binary ops ─────────────────────────────────────────
13
14/// Element-wise addition of two variables.
15///
16/// # Examples
17///
18/// ```
19/// # use scivex_core::Tensor;
20/// # use scivex_nn::variable::Variable;
21/// # use scivex_nn::ops::add;
22/// let a = Variable::new(Tensor::from_vec(vec![1.0_f64, 2.0], vec![2]).unwrap(), false);
23/// let b = Variable::new(Tensor::from_vec(vec![3.0_f64, 4.0], vec![2]).unwrap(), false);
24/// let c = add(&a, &b);
25/// assert_eq!(c.data().as_slice(), &[4.0, 6.0]);
26/// ```
27pub fn add<T: Float>(a: &Variable<T>, b: &Variable<T>) -> Variable<T> {
28    let data = &a.data() + &b.data();
29    Variable::from_op(
30        data,
31        vec![a.clone(), b.clone()],
32        Box::new(|g: &Tensor<T>| vec![g.clone(), g.clone()]),
33    )
34}
35
36/// Element-wise subtraction.
37///
38/// # Examples
39///
40/// ```
41/// # use scivex_core::Tensor;
42/// # use scivex_nn::variable::Variable;
43/// # use scivex_nn::ops::sub;
44/// let a = Variable::new(Tensor::from_vec(vec![5.0_f64, 3.0], vec![2]).unwrap(), false);
45/// let b = Variable::new(Tensor::from_vec(vec![1.0_f64, 1.0], vec![2]).unwrap(), false);
46/// let c = sub(&a, &b);
47/// assert_eq!(c.data().as_slice(), &[4.0, 2.0]);
48/// ```
49pub fn sub<T: Float>(a: &Variable<T>, b: &Variable<T>) -> Variable<T> {
50    let data = &a.data() - &b.data();
51    Variable::from_op(
52        data,
53        vec![a.clone(), b.clone()],
54        Box::new(|g: &Tensor<T>| vec![g.clone(), -g]),
55    )
56}
57
58/// Element-wise multiplication (Hadamard product).
59///
60/// # Examples
61///
62/// ```
63/// # use scivex_core::Tensor;
64/// # use scivex_nn::variable::Variable;
65/// # use scivex_nn::ops::mul;
66/// let a = Variable::new(Tensor::from_vec(vec![2.0_f64, 3.0], vec![2]).unwrap(), true);
67/// let b = Variable::new(Tensor::from_vec(vec![4.0_f64, 5.0], vec![2]).unwrap(), true);
68/// let c = mul(&a, &b);
69/// assert_eq!(c.data().as_slice(), &[8.0, 15.0]);
70/// ```
71pub fn mul<T: Float>(a: &Variable<T>, b: &Variable<T>) -> Variable<T> {
72    let a_data = a.data();
73    let b_data = b.data();
74    let data = &a_data * &b_data;
75    Variable::from_op(
76        data,
77        vec![a.clone(), b.clone()],
78        Box::new(move |g: &Tensor<T>| {
79            let ga = g
80                .zip_map(&b_data, |gi, bi| gi * bi)
81                .expect("shapes match from forward pass");
82            let gb = g
83                .zip_map(&a_data, |gi, ai| gi * ai)
84                .expect("shapes match from forward pass");
85            vec![ga, gb]
86        }),
87    )
88}
89
90/// Negation.
91///
92/// # Examples
93///
94/// ```
95/// # use scivex_core::Tensor;
96/// # use scivex_nn::variable::Variable;
97/// # use scivex_nn::ops::neg;
98/// let a = Variable::new(Tensor::from_vec(vec![1.0_f64, -2.0], vec![2]).unwrap(), false);
99/// let b = neg(&a);
100/// assert_eq!(b.data().as_slice(), &[-1.0, 2.0]);
101/// ```
102pub fn neg<T: Float>(a: &Variable<T>) -> Variable<T> {
103    let data = -&a.data();
104    Variable::from_op(data, vec![a.clone()], Box::new(|g: &Tensor<T>| vec![-g]))
105}
106
107// ── Matrix operations ───────────────────────────────────────────────
108
109/// Matrix multiplication: `a @ b`.
110///
111/// `a` has shape `[m, k]`, `b` has shape `[k, n]`, result is `[m, n]`.
112///
113/// # Examples
114///
115/// ```
116/// # use scivex_core::Tensor;
117/// # use scivex_nn::variable::Variable;
118/// # use scivex_nn::ops::matmul;
119/// let a = Variable::new(Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![2, 2]).unwrap(), false);
120/// let b = Variable::new(Tensor::from_vec(vec![1.0_f64, 0.0, 0.0, 1.0], vec![2, 2]).unwrap(), false);
121/// let c = matmul(&a, &b); // identity matmul
122/// assert_eq!(c.data().as_slice(), &[1.0, 2.0, 3.0, 4.0]);
123/// ```
124pub fn matmul<T: Float>(a: &Variable<T>, b: &Variable<T>) -> Variable<T> {
125    let a_data = a.data();
126    let b_data = b.data();
127    let data = a_data
128        .matmul(&b_data)
129        .expect("matmul shapes validated at call site");
130    Variable::from_op(
131        data,
132        vec![a.clone(), b.clone()],
133        Box::new(move |g: &Tensor<T>| {
134            // grad_a = g @ b^T
135            let bt = b_data.transpose().expect("2-D from forward pass");
136            let ga = g.matmul(&bt).expect("shapes match from forward pass");
137            // grad_b = a^T @ g
138            let at = a_data.transpose().expect("2-D from forward pass");
139            let gb = at.matmul(g).expect("shapes match from forward pass");
140            vec![ga, gb]
141        }),
142    )
143}
144
145// ── Reductions ──────────────────────────────────────────────────────
146
147/// Sum all elements to a scalar variable.
148///
149/// # Examples
150///
151/// ```
152/// # use scivex_core::Tensor;
153/// # use scivex_nn::variable::Variable;
154/// # use scivex_nn::ops::sum;
155/// let a = Variable::new(Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap(), false);
156/// let s = sum(&a);
157/// assert_eq!(s.data().as_slice(), &[6.0]);
158/// ```
159pub fn sum<T: Float>(a: &Variable<T>) -> Variable<T> {
160    let s = a.data().sum();
161    let shape = a.shape();
162    let data = Tensor::from_vec(vec![s], vec![1]).expect("scalar tensor");
163    Variable::from_op(
164        data,
165        vec![a.clone()],
166        Box::new(move |g: &Tensor<T>| {
167            // Broadcast scalar grad to input shape.
168            let g_val = g.as_slice()[0];
169            vec![Tensor::full(shape.clone(), g_val)]
170        }),
171    )
172}
173
174/// Mean of all elements to a scalar variable.
175///
176/// # Examples
177///
178/// ```
179/// # use scivex_core::Tensor;
180/// # use scivex_nn::variable::Variable;
181/// # use scivex_nn::ops::mean;
182/// let a = Variable::new(Tensor::from_vec(vec![2.0_f64, 4.0], vec![2]).unwrap(), false);
183/// let m = mean(&a);
184/// assert_eq!(m.data().as_slice(), &[3.0]);
185/// ```
186pub fn mean<T: Float>(a: &Variable<T>) -> Variable<T> {
187    let n = a.data().numel();
188    let m = a.data().mean();
189    let shape = a.shape();
190    let data = Tensor::from_vec(vec![m], vec![1]).expect("scalar tensor");
191    Variable::from_op(
192        data,
193        vec![a.clone()],
194        Box::new(move |g: &Tensor<T>| {
195            let g_val = g.as_slice()[0];
196            let scale = g_val / T::from_usize(n);
197            vec![Tensor::full(shape.clone(), scale)]
198        }),
199    )
200}
201
202/// Element-wise power: `a^exponent`.
203///
204/// # Examples
205///
206/// ```
207/// # use scivex_core::Tensor;
208/// # use scivex_nn::variable::Variable;
209/// # use scivex_nn::ops::pow;
210/// let a = Variable::new(Tensor::from_vec(vec![2.0_f64, 3.0], vec![2]).unwrap(), false);
211/// let b = pow(&a, 2.0);
212/// assert_eq!(b.data().as_slice(), &[4.0, 9.0]);
213/// ```
214pub fn pow<T: Float>(a: &Variable<T>, exponent: T) -> Variable<T> {
215    let a_data = a.data();
216    let data = a_data.powf(exponent);
217    Variable::from_op(
218        data,
219        vec![a.clone()],
220        Box::new(move |g: &Tensor<T>| {
221            // d/da (a^n) = n * a^(n-1)
222            let n_minus_1 = exponent - T::one();
223            let deriv = a_data.powf(n_minus_1).map(|v| exponent * v);
224            let grad = g
225                .zip_map(&deriv, |gi, di| gi * di)
226                .expect("shapes match from forward pass");
227            vec![grad]
228        }),
229    )
230}
231
232/// Scalar multiplication: `variable * scalar_value`.
233///
234/// # Examples
235///
236/// ```
237/// # use scivex_core::Tensor;
238/// # use scivex_nn::variable::Variable;
239/// # use scivex_nn::ops::scalar_mul;
240/// let a = Variable::new(Tensor::from_vec(vec![2.0_f64, 3.0], vec![2]).unwrap(), false);
241/// let b = scalar_mul(&a, 5.0);
242/// assert_eq!(b.data().as_slice(), &[10.0, 15.0]);
243/// ```
244pub fn scalar_mul<T: Float>(a: &Variable<T>, scalar: T) -> Variable<T> {
245    let data = &a.data() * scalar;
246    Variable::from_op(
247        data,
248        vec![a.clone()],
249        Box::new(move |g: &Tensor<T>| vec![g.map(|v| v * scalar)]),
250    )
251}
252
253/// Scalar division: `variable / scalar_value`.
254///
255/// # Examples
256///
257/// ```
258/// # use scivex_core::Tensor;
259/// # use scivex_nn::variable::Variable;
260/// # use scivex_nn::ops::scalar_div;
261/// let a = Variable::new(Tensor::from_vec(vec![10.0_f64, 6.0], vec![2]).unwrap(), false);
262/// let b = scalar_div(&a, 2.0);
263/// assert_eq!(b.data().as_slice(), &[5.0, 3.0]);
264/// ```
265pub fn scalar_div<T: Float>(a: &Variable<T>, scalar: T) -> Variable<T> {
266    scalar_mul(a, T::one() / scalar)
267}
268
269// ── Bias-add helper (manual broadcasting) ───────────────────────────
270
271/// Add a 1-D bias `[out]` to a 2-D input `[batch, out]` (row-wise broadcast).
272///
273/// # Examples
274///
275/// ```
276/// # use scivex_core::Tensor;
277/// # use scivex_nn::variable::Variable;
278/// # use scivex_nn::ops::add_bias;
279/// let x = Variable::new(Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![2, 2]).unwrap(), false);
280/// let b = Variable::new(Tensor::from_vec(vec![0.1_f64, 0.2], vec![2]).unwrap(), false);
281/// let y = add_bias(&x, &b);
282/// assert!((y.data().as_slice()[0] - 1.1).abs() < 1e-10);
283/// ```
284pub fn add_bias<T: Float>(input: &Variable<T>, bias: &Variable<T>) -> Variable<T> {
285    let x = input.data();
286    let b = bias.data();
287    let shape = x.shape().to_vec();
288    let rows = shape[0];
289    let cols = shape[1];
290
291    // Manually broadcast: each row gets bias added.
292    let mut out_data = Vec::with_capacity(rows * cols);
293    let b_slice = b.as_slice();
294    let x_slice = x.as_slice();
295    for r in 0..rows {
296        for c in 0..cols {
297            out_data.push(x_slice[r * cols + c] + b_slice[c]);
298        }
299    }
300    let data =
301        Tensor::from_vec(out_data, shape).expect("output data length matches shape from input");
302
303    let cols_copy = cols;
304    Variable::from_op(
305        data,
306        vec![input.clone(), bias.clone()],
307        Box::new(move |g: &Tensor<T>| {
308            // grad_input = g (same shape)
309            let g_input = g.clone();
310            // grad_bias = sum over rows (reduce axis 0)
311            let g_slice = g.as_slice();
312            let g_rows = g.shape()[0];
313            let mut bias_grad = vec![T::zero(); cols_copy];
314            for r in 0..g_rows {
315                for c in 0..cols_copy {
316                    bias_grad[c] += g_slice[r * cols_copy + c];
317                }
318            }
319            let g_bias = Tensor::from_vec(bias_grad, vec![cols_copy])
320                .expect("bias grad length matches feature count");
321            vec![g_input, g_bias]
322        }),
323    )
324}
325
326// ── Operator overloads ──────────────────────────────────────────────
327
328impl<T: Float> ops::Add for &Variable<T> {
329    type Output = Variable<T>;
330    fn add(self, rhs: Self) -> Variable<T> {
331        add(self, rhs)
332    }
333}
334
335impl<T: Float> ops::Add for Variable<T> {
336    type Output = Variable<T>;
337    fn add(self, rhs: Self) -> Variable<T> {
338        add(&self, &rhs)
339    }
340}
341
342impl<T: Float> ops::Sub for &Variable<T> {
343    type Output = Variable<T>;
344    fn sub(self, rhs: Self) -> Variable<T> {
345        sub(self, rhs)
346    }
347}
348
349impl<T: Float> ops::Sub for Variable<T> {
350    type Output = Variable<T>;
351    fn sub(self, rhs: Self) -> Variable<T> {
352        sub(&self, &rhs)
353    }
354}
355
356impl<T: Float> ops::Mul for &Variable<T> {
357    type Output = Variable<T>;
358    fn mul(self, rhs: Self) -> Variable<T> {
359        mul(self, rhs)
360    }
361}
362
363impl<T: Float> ops::Mul for Variable<T> {
364    type Output = Variable<T>;
365    fn mul(self, rhs: Self) -> Variable<T> {
366        mul(&self, &rhs)
367    }
368}
369
370impl<T: Float> ops::Neg for &Variable<T> {
371    type Output = Variable<T>;
372    fn neg(self) -> Variable<T> {
373        neg(self)
374    }
375}
376
377impl<T: Float> ops::Neg for Variable<T> {
378    type Output = Variable<T>;
379    fn neg(self) -> Variable<T> {
380        neg(&self)
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 1-D `f64` variable that requires gradients.
    fn var(vals: &[f64]) -> Variable<f64> {
        let t = Tensor::from_vec(vals.to_vec(), vec![vals.len()]).unwrap();
        Variable::new(t, true)
    }

    #[test]
    fn test_add_backward() {
        let a = var(&[2.0, 3.0]);
        let b = var(&[4.0, 5.0]);
        let c = add(&a, &b);
        let s = sum(&c);
        s.backward();
        // dc/da = 1, dc/db = 1, ds/dc = 1 => each grad = [1,1]
        assert_eq!(a.grad().unwrap().as_slice(), &[1.0, 1.0]);
        assert_eq!(b.grad().unwrap().as_slice(), &[1.0, 1.0]);
    }

    #[test]
    fn test_mul_backward() {
        let a = var(&[2.0, 3.0]);
        let b = var(&[4.0, 5.0]);
        let c = mul(&a, &b);
        let s = sum(&c);
        s.backward();
        // dc/da = b, dc/db = a
        assert_eq!(a.grad().unwrap().as_slice(), &[4.0, 5.0]);
        assert_eq!(b.grad().unwrap().as_slice(), &[2.0, 3.0]);
    }

    #[test]
    fn test_sub_backward() {
        let a = var(&[5.0]);
        let b = var(&[3.0]);
        let c = sub(&a, &b);
        let s = sum(&c);
        s.backward();
        assert_eq!(a.grad().unwrap().as_slice(), &[1.0]);
        assert_eq!(b.grad().unwrap().as_slice(), &[-1.0]);
    }

    #[test]
    fn test_matmul_backward() {
        // a: [2,3], b: [3,2]
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap(),
            true,
        );
        let b = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap(),
            true,
        );
        let c = matmul(&a, &b);
        // Forward: [[1,2,3],[4,5,6]] @ [[1,2],[3,4],[5,6]] = [[22,28],[49,64]]
        assert_eq!(c.data().as_slice(), &[22.0, 28.0, 49.0, 64.0]);

        let s = sum(&c);
        s.backward();
        // grad_a should be [2,3], grad_b should be [3,2]
        let grad_a = a.grad().unwrap();
        let grad_b = b.grad().unwrap();
        assert_eq!(grad_a.shape(), &[2, 3]);
        assert_eq!(grad_b.shape(), &[3, 2]);
        // With an all-ones upstream gradient g of shape [2,2]:
        // grad_a = g @ b^T => every row holds the row sums of b: [3, 7, 11]
        assert_eq!(grad_a.as_slice(), &[3.0, 7.0, 11.0, 3.0, 7.0, 11.0]);
        // grad_b = a^T @ g => row i holds the i-th column sum of a: 5, 7, 9
        assert_eq!(grad_b.as_slice(), &[5.0, 5.0, 7.0, 7.0, 9.0, 9.0]);
    }

    #[test]
    fn test_pow_backward() {
        let a = var(&[2.0, 3.0]);
        let c = pow(&a, 2.0);
        let s = sum(&c);
        s.backward();
        // d/da (a^2) = 2a
        assert_eq!(a.grad().unwrap().as_slice(), &[4.0, 6.0]);
    }

    #[test]
    fn test_mean_backward() {
        let a = var(&[2.0, 4.0]);
        let m = mean(&a);
        m.backward();
        // d/da mean = 1/n = 0.5
        assert_eq!(a.grad().unwrap().as_slice(), &[0.5, 0.5]);
    }

    #[test]
    fn test_neg_backward() {
        let a = var(&[3.0]);
        let c = neg(&a);
        let s = sum(&c);
        s.backward();
        assert_eq!(a.grad().unwrap().as_slice(), &[-1.0]);
    }

    #[test]
    fn test_operator_overloads() {
        let a = var(&[1.0, 2.0]);
        let b = var(&[3.0, 4.0]);
        let c = &a + &b;
        let d = &a * &b;
        let s = sum(&(&c + &d));
        s.backward();
        // c = a+b, d = a*b, s = sum(c+d) = sum(a+b+a*b)
        // ds/da = 1 + b = [4, 5]
        // ds/db = 1 + a = [2, 3]
        assert_eq!(a.grad().unwrap().as_slice(), &[4.0, 5.0]);
        assert_eq!(b.grad().unwrap().as_slice(), &[2.0, 3.0]);
    }

    #[test]
    fn test_scalar_mul_backward() {
        let a = var(&[2.0, 3.0]);
        let c = scalar_mul(&a, 5.0);
        let s = sum(&c);
        s.backward();
        // d/da (5*a) = 5
        assert_eq!(a.grad().unwrap().as_slice(), &[5.0, 5.0]);
    }

    #[test]
    fn test_scalar_div_backward() {
        let a = var(&[4.0, 8.0]);
        let c = scalar_div(&a, 2.0);
        let s = sum(&c);
        s.backward();
        // d/da (a/2) = 0.5
        assert_eq!(a.grad().unwrap().as_slice(), &[0.5, 0.5]);
    }

    #[test]
    fn test_add_bias_forward_and_backward() {
        // input: [2, 3], bias: [3]
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap(),
            true,
        );
        let bias = Variable::new(
            Tensor::from_vec(vec![0.1, 0.2, 0.3], vec![3]).unwrap(),
            true,
        );
        let y = add_bias(&input, &bias);
        // Check forward values
        let y_data = y.data();
        let y_s = y_data.as_slice();
        assert!((y_s[0] - 1.1).abs() < 1e-10);
        assert!((y_s[4] - 5.2).abs() < 1e-10);

        let s = sum(&y);
        s.backward();
        // grad_input = ones (same shape as input)
        let g_input = input.grad().unwrap();
        assert_eq!(g_input.shape(), &[2, 3]);
        for &v in g_input.as_slice() {
            assert!((v - 1.0).abs() < 1e-10);
        }
        // grad_bias = sum over rows = [2.0, 2.0, 2.0] (2 rows of ones)
        let g_bias = bias.grad().unwrap();
        assert_eq!(g_bias.shape(), &[3]);
        for &v in g_bias.as_slice() {
            assert!((v - 2.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_single_element_sum() {
        let a = var(&[7.0]);
        let s = sum(&a);
        assert_eq!(s.data().as_slice(), &[7.0]);
        s.backward();
        assert_eq!(a.grad().unwrap().as_slice(), &[1.0]);
    }

    #[test]
    fn test_pow_cubic_backward() {
        let a = var(&[2.0]);
        let c = pow(&a, 3.0);
        let s = sum(&c);
        s.backward();
        // d/da (a^3) = 3*a^2 = 3*4 = 12
        assert!((a.grad().unwrap().as_slice()[0] - 12.0).abs() < 1e-10);
    }
}