use crate::gradient_function::{GradientFuncTrait, GradientFunction};
use crate::{NdArray, Reshape, TensorDataType};
use std::cell::RefCell;
use std::rc::Rc;
/// Gradient function that sums every gradient passed to `backward` into one
/// owned buffer.
///
/// `AccumulateGrad::new` starts the buffer at zeros. Each `backward` call
/// adds the incoming gradient in place with `+=`. `gradient` hands out a view
/// of the running total.
///
/// NOTE(review): by analogy with autograd engines, this is presumably the node
/// attached to leaf tensors. TODO confirm against the graph-building code.
pub(crate) struct AccumulateGrad<T: TensorDataType> {
// Running sum of all gradients received so far.
// NOTE(review): the `'static` lifetime presumably means the array owns its
// storage rather than borrowing it. Confirm against `NdArray`'s definition.
gradient: NdArray<'static, T>,
}
impl<T: TensorDataType> GradientFuncTrait<T> for AccumulateGrad<T> {
    /// Returns a view of the accumulated gradient whose lifetime is tied to
    /// `&self`.
    ///
    /// This implementation always yields `Some`, because the buffer is stored
    /// directly in the struct.
    fn gradient<'a>(&'a self) -> Option<NdArray<'a, T>> {
        // Call `view` on `&NdArray` explicitly rather than relying on
        // autoref. The original code spelled out the borrow, so the
        // by-reference receiver is kept. (`view` presumably comes from the
        // imported `Reshape` trait.)
        let accumulated = &self.gradient;
        Some(accumulated.view())
    }

    /// Adds `grad` into the running total in place.
    ///
    /// Shape handling (broadcasting, panicking on mismatch, etc.) is whatever
    /// `NdArray`'s `+=` implementation does.
    /// NOTE(review): confirm mismatched shapes are rejected there.
    fn backward(&mut self, grad: &NdArray<T>) {
        let total = &mut self.gradient;
        *total += grad;
    }
}
impl<T: TensorDataType> AccumulateGrad<T> {
pub(crate) fn new(shape: Vec<usize>) -> GradientFunction<T> {
Rc::new(RefCell::new(Self {
gradient: NdArray::zeros(shape),
}))
}
}