1use crate::tensor::*;
2
#[macro_export]
macro_rules! reduce_grad {
    // Reduce a broadcasted gradient back to the shape of the original tensor:
    // every dimension where the target `$shape` is 1 but the gradient is not
    // gets summed over, undoing the broadcast for that axis.
    //
    // NOTE(review): this assumes `sum_dim` keeps reduced dimensions as size 1
    // (the loop indexes dims against the *original* gradient shape) — confirm
    // against the `Tensor::sum_dim` implementation.
    ($grad:expr, $shape:expr) => {{
        // Bind the macro arguments once. The previous expansion evaluated
        // `$grad` three separate times, so a caller passing an expression
        // with side effects (or an expensive call) would run it repeatedly.
        let grad_ref = &$grad;
        let target_shape = &$shape;
        let grad_shape = grad_ref.shape();
        let mut reduced_grad = grad_ref.clone();
        for (dim, (grad_size, shape_size)) in
            grad_shape.iter().zip(target_shape.iter()).enumerate()
        {
            // Broadcasting expanded this axis (target is 1, gradient is not):
            // collapse it by summing over exactly this dimension.
            if shape_size == &1 && grad_size != &1 {
                let mut sum_dims = vec![false; grad_shape.len()];
                sum_dims[dim] = true;
                reduced_grad = reduced_grad.sum_dim(&sum_dims);
            }
        }
        reduced_grad
    }};
}
17
/// Interface for a node in the autograd (backward) graph.
///
/// Implementors record an operation performed on tensors so the backward
/// pass can later propagate gradients through it. `Debug` is required so
/// the graph can be inspected/printed for diagnostics.
pub trait GradientFunction: std::fmt::Debug {
    /// Run this operation's backward step.
    ///
    /// NOTE(review): takes `&self` and returns nothing, so gradients are
    /// presumably accumulated via interior mutability on the tensors
    /// returned by [`GradientFunction::prev`] — confirm with implementors.
    fn backward(&self);

    /// The tensors this node depends on — presumably the operation's
    /// inputs (predecessor nodes), used to traverse the graph during the
    /// backward pass. TODO confirm ordering guarantees with implementors.
    fn prev(&self) -> Vec<&Tensor>;
}