ferrite/autograd/grad.rs

1use crate::tensor::*;
2
/// Sums a gradient over the axes that were broadcast, so that it matches `$shape`.
///
/// The axes of `$grad` and `$shape` are paired from the leading axis (index 0)
/// onward. Consider each pair where the target size is `1` but the gradient size is
/// not. For each such pair, the running result is replaced by `sum_dim` called with a
/// boolean mask that is `true` only at that axis. Axes that need no reduction are
/// left untouched. If no axis needs reducing, the result is simply a clone of `$grad`.
///
/// Requirements visible from the expansion:
/// - `$grad` must provide `clone()` and `shape()`, where `shape()` returns something
///   with `iter()` and `len()`.
/// - `$grad` must provide `sum_dim(&Vec<bool>)`, returning the same type as `clone()`.
/// - `$shape` must provide `iter()` over integer sizes.
///
/// Each argument expression is evaluated exactly once.
///
/// NOTE(review): pairing uses `zip`, so when the ranks differ, the extra axes of the
/// longer operand are ignored. A higher-rank gradient is not summed down to a
/// lower-rank `$shape`, and axes are not right-aligned as in NumPy-style
/// broadcasting. Confirm that callers only pass equal-rank shapes.
///
/// NOTE(review): every mask has the rank of the original `$grad` and indexes by the
/// original axis position. This assumes `sum_dim` keeps reduced axes (as size 1)
/// rather than removing them. Confirm against the tensor's `sum_dim`.
#[macro_export]
macro_rules! reduce_grad {
  ($grad:expr, $shape:expr) => {{
    // Bind each argument once. Re-expanding `$grad` would re-run the caller's
    // expression (and any side effects) on every use.
    let grad = &$grad;
    let shape = &$shape;
    let rank = grad.shape().len();
    // Deref before cloning so the pointee is cloned, exactly as `$grad.clone()`
    // would, even when `$grad` is itself a reference. `grad.clone()` on a `&&T`
    // would only copy the reference.
    let mut reduced_grad = (*grad).clone();
    for (dim, (grad_size, shape_size)) in grad.shape().iter().zip(shape.iter()).enumerate() {
      // A size-1 target axis facing a larger gradient axis was broadcast; sum it out.
      if *shape_size == 1 && *grad_size != 1 {
        let mut sum_dims = vec![false; rank];
        sum_dims[dim] = true;
        reduced_grad = reduced_grad.sum_dim(&sum_dims);
      }
    }
    reduced_grad
  }};
}
17
18pub trait GradientFunction: std::fmt::Debug {
19  fn backward(&self);
20  fn prev(&self) -> Vec<&Tensor>;
21}