// ferrite-dl 0.2.0
//
// Deep learning library written in pure Rust.
// Documentation
use crate::tensor::*;

/// Sums a gradient over every dimension in which the target shape has size 1
/// but the gradient does not, and evaluates to the reduced gradient.
///
/// * `$grad`  - an expression providing `.clone()`, `.shape()` and (on the
///   cloned value) `.sum_dim(&mask)`. It is evaluated exactly once.
/// * `$shape` - an expression providing `.iter()` over the target dimension
///   sizes. It is evaluated exactly once.
///
/// Dimensions are paired positionally starting at index 0. If the two shapes
/// have different lengths, only the common leading dimensions are inspected
/// because `zip` stops at the shorter sequence.
///
/// NOTE(review): the boolean mask passed to `sum_dim` always has the rank of
/// the *original* gradient, so this presumably relies on `sum_dim` keeping the
/// summed dimension (as size 1) rather than dropping it - confirm against the
/// tensor implementation.
#[macro_export]
macro_rules! reduce_grad {
  ($grad:expr, $shape:expr) => {{
    // Bind the argument once so an expression with side effects or a
    // non-trivial cost is not re-evaluated for every use below.
    let grad_ref = &$grad;
    // `*grad_ref` has the same type as `$grad`, so `.clone()` resolves exactly
    // as it would have on the raw expression (owned value or reference alike).
    let mut reduced_grad = (*grad_ref).clone();
    // The rank is loop-invariant; compute it once instead of per reduced dim.
    let rank = grad_ref.shape().len();
    for (dim, (grad_size, shape_size)) in grad_ref.shape().iter().zip($shape.iter()).enumerate() {
      if *shape_size == 1 && *grad_size != 1 {
        // One-hot mask selecting only the current dimension for summation.
        let mut sum_dims = vec![false; rank];
        sum_dims[dim] = true;
        reduced_grad = reduced_grad.sum_dim(&sum_dims);
      }
    }
    reduced_grad
  }};
}

/// Interface for a node that can take part in a backward pass.
///
/// Implementors must also implement [`std::fmt::Debug`].
///
/// NOTE(review): the exact contract (what `backward` reads and writes, and in
/// which order nodes are visited) is defined by the implementors and callers,
/// which are not visible here - confirm against them.
pub trait GradientFunction: std::fmt::Debug {
  /// Runs this node's backward step. Takes `&self` and returns nothing, so any
  /// effect presumably happens through interior mutability or through the
  /// tensors the implementor holds - TODO confirm.
  fn backward(&self);
  /// Returns borrowed references to the tensors associated with this node;
  /// presumably the inputs the node was computed from - verify against
  /// implementors.
  fn prev(&self) -> Vec<&Tensor>;
}