use crate::gradient_function::{GradientFuncTrait, GradientFunction};
use crate::{Constructors, NdArray, Reshape, StridedMemory, TensorDataType};
use std::cell::RefCell;
use std::hint::assert_unchecked;
use std::rc::Rc;
/// Gradient function that keeps a running sum of the gradients handed to it.
///
/// Built through [`AccumulateGrad::new`], which returns it already wrapped in
/// `Rc<RefCell<..>>` as a [`GradientFunction`].
pub(crate) struct AccumulateGrad<T: TensorDataType> {
    /// Running total of every gradient received by `backward`.
    /// NOTE(review): the `'static` lifetime presumably means this array owns its
    /// storage rather than viewing another tensor — confirm against `NdArray`.
    gradient: NdArray<'static, T>,
}
// Gradient-function behaviour for the accumulator: incoming gradients are summed
// into `self.gradient`, which can be read back as a view or reset to zero.
impl<T: TensorDataType> GradientFuncTrait<T> for AccumulateGrad<T> {
    /// Adds `grad` into the accumulated gradient with `+=`.
    ///
    /// # Panics
    ///
    /// Panics if `grad` does not have the same shape as the accumulated gradient.
    fn backward(&mut self, grad: &NdArray<T>) {
        // This used to be `unsafe { assert_unchecked(..) }`. Because `backward` is a
        // safe method and nothing forces callers to pre-check shapes, a mismatch was
        // undefined behaviour reachable from safe code. A checked assertion keeps
        // the same precondition without the soundness hole. After the check, the
        // optimizer may still treat the shapes as equal.
        assert!(
            self.gradient.shape() == grad.shape(),
            "AccumulateGrad::backward: incoming gradient shape does not match the accumulated gradient shape"
        );
        self.gradient += grad;
    }

    /// Returns a view of the accumulated gradient. This is always `Some`.
    fn gradient(&self) -> Option<NdArray<T>> {
        // `view` is invoked on a reference so the stored array is borrowed, not moved.
        Some((&self.gradient).view())
    }

    /// Zeroes the accumulated gradient through `NdArray::zero`.
    fn zero_gradient(&mut self) {
        self.gradient.zero();
    }
}
impl<T: TensorDataType> AccumulateGrad<T> {
    /// Builds an accumulator whose gradient is `NdArray::zeros(shape)`, and returns
    /// it wrapped in `Rc<RefCell<..>>` as a shared [`GradientFunction`].
    pub(crate) fn new(shape: Vec<usize>) -> GradientFunction<T> {
        let accumulator = Self {
            gradient: NdArray::zeros(shape),
        };
        // The concrete `Rc<RefCell<Self>>` is coerced to `GradientFunction<T>` at return.
        Rc::new(RefCell::new(accumulator))
    }
}