use super::{
Addition, AdditionBackward, AdditionBackwardUnary, Backward, Cat, Chunk, ChunkBackward,
Concatenate, ConcatenateBackward, ConcatenateBackwardLeft, Data, DifferentiableVariable,
Division, DivisionBackward, DivisionBackwardLeft, DivisionBackwardRight, Dropout,
DropoutBackward, Exp, ExpBackward, Forward, Gradient, GradientOverwrite, Input, LeakyReLU,
LeakyReLUBackward, LogSoftmax, LogSoftmaxBackward, Logn, LognBackward, MatMatMul, MatMatMulT,
MatVecMul, MatrixMatrixMul, MatrixMatrixMulBackward, MatrixMatrixMulBackwardLeft,
MatrixMatrixMulT, MatrixMatrixMulTBackward, MatrixMatrixMulTBackwardLeft, MatrixVectorMul,
MatrixVectorMulBackward, MatrixVectorMulBackwardLeft, Mean, MeanBackward, MultiConcatenate,
MultiConcatenateBackward, MultiStack, MultiStackBackward, Multiplication,
MultiplicationBackward, MultiplicationBackwardUnary, Negation, NegationBackward, Overwrite,
Param, Power, PowerBackward, RawParam, ReLU, ReLUBackward, Sigmoid, SigmoidBackward, SoftPlus,
SoftPlusBackward, Softmax, SoftmaxBackward, Sqrt, SqrtBackward, Stack, StackBackward,
StackBackwardLeft, Subtraction, SubtractionBackward, SubtractionBackwardLeft,
SubtractionBackwardRight, Sum, SumBackward, TanH, TanHBackward, Tensor, Transpose,
TransposeBackward, Unsqueeze, UnsqueezeBackward, Var, VarDiffHistory, Variable, VecMatMul,
VecVecMul, VectorMatrixMul, VectorMatrixMulBackward, VectorMatrixMulBackwardLeft,
VectorVectorMul, VectorVectorMulBackward, VectorVectorMulBackwardUnary, OPERATIONS_COUNTER,
};
use crate::nn::Register;
use ndarray::{DimMax, Dimension, IntoDimension, Ix1, Ix2, RemoveAxis};
#[cfg(feature = "serialize")]
use serde::{
de::{Deserialize, Deserializer},
ser::{Serialize, Serializer},
};
use std::{
cell::{Cell, Ref, RefMut},
fmt::{Debug, Display},
ops::{Add, Div, Mul, Neg, Sub},
rc::Rc,
};
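/// A differentiable variable.
///
/// `VarDiff` pairs a forward variable ([`Var`]) with the backward node that computes its
/// gradient and with the history of backward operations that produced it. Operations applied
/// to a `VarDiff` are recorded in this history so that `.backward()` can later propagate
/// gradients through them.
///
/// A minimal sketch of the intended flow, assuming a 1-dimensional differentiable input
/// built as in the `Deserialize` impl at the bottom of this file:
///
/// ```ignore
/// let x = Input::new(ndarray::arr1(&[1., 2., 3.])).requires_grad();
/// let y = x.clone().sum();
/// y.forward();
/// y.backward(1.);
/// assert_eq!(x.grad().len(), 3);
/// ```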
pub struct VarDiff<T, U>
where
T: Data + 'static,
U: Gradient + Overwrite + 'static,
{
pub(crate) var: Var<T>,
pub(crate) node: Rc<U>,
pub(crate) past: VarDiffHistory,
}
impl<T, U> Clone for VarDiff<T, U>
where
T: Data + 'static,
U: Gradient + Overwrite + 'static,
{
fn clone(&self) -> Self {
Self {
var: self.var.clone(),
node: self.node.clone(),
past: self.past.clone(),
}
}
}
impl<T, U> VarDiff<T, U>
where
T: Data + Forward + 'static,
U: Gradient + Overwrite + Backward + 'static,
{
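    /// Wraps `node` into an `Rc`, registers it in `past` under a fresh operation id taken
    /// from the global operations counter, and ties it to the forward variable `var`.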
pub(crate) fn from(node: U, mut past: VarDiffHistory, var: Var<T>) -> VarDiff<T, U> {
let node = Rc::new(node);
past.append_backward(unsafe { OPERATIONS_COUNTER.next() }, node.clone());
VarDiff { var, node, past }
}
}
impl<T, U> VarDiff<T, U>
where
T: Data + 'static,
U: Gradient + Overwrite + 'static,
{
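    /// Returns an immutable reference to the data of the forward node.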
pub fn data(&self) -> Ref<Tensor<T::Dim>> {
self.var.node.data()
}
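    /// Returns a mutable reference to the data of the forward node.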
pub fn data_mut(&self) -> RefMut<Tensor<T::Dim>> {
self.var.node.data_mut()
}
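    /// Returns an immutable reference to the gradient stored in the backward node.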
pub fn grad(&self) -> Ref<Tensor<U::Dim>> {
self.node.gradient()
}
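    /// Returns a mutable reference to the gradient stored in the backward node.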
pub fn grad_mut(&self) -> RefMut<Tensor<U::Dim>> {
self.node.gradient_mut()
}
}
impl<T, U> VarDiff<T, U>
where
T: Data + Forward + 'static,
U: Gradient<Dim = T::Dim> + Overwrite + 'static,
{
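    /// Propagates the computation through the forward graph, then re-enables the overwrite
    /// status of the backward nodes so that the next backward pass starts from clean
    /// gradients.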
pub fn forward(&self) {
self.var.forward();
debug_assert!(self.past.buffer().is_empty() || self.past.len() == self.past.buffer().len());
self.past.prepare_buffer();
let buffer = self.past.buffer();
        // The nodes in the buffer are partitioned by their overwrite status. The comparator
        // never returns `Equal`, so the search always yields `Err` holding the index of the
        // first node whose gradient cannot be overwritten.
        let mut res = buffer.binary_search_by(|n| {
            if n.can_overwrite() {
                std::cmp::Ordering::Less
            } else {
                std::cmp::Ordering::Greater
            }
        });
        // The partition point is meaningful only if it falls inside the buffer.
        if let Err(i) = res {
            if buffer.get(i).is_some() {
                res = Ok(i);
            }
        }
        // Re-enable overwriting from the partition point onwards, so that the next backward
        // pass overwrites the stale gradients instead of accumulating into them.
        if let Ok(pos) = res {
            for node in &buffer[pos..] {
                node.set_overwrite(true);
            }
        }
}
}
impl<T, U> VarDiff<T, U>
where
T: Data + Forward + 'static,
U: Gradient<Dim = T::Dim> + Overwrite + Backward + 'static,
{
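    /// Back-propagates through the recorded history: fills the gradient of `self` with
    /// `seed`, visits the backward graph in reverse order, and finally resets the
    /// computation status of the forward graph.
    ///
    /// Call `forward` first; see the sketch in the documentation of [`VarDiff`].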
pub fn backward(&self, seed: f32) {
debug_assert!(!self.past.is_empty());
        // Seed the gradient of the output node.
        self.node.gradient_mut().fill(seed);
        // Traverse the backward graph in reverse topological order, accumulating gradients.
        self.past.prepare_buffer();
        let buffer = self.past.buffer();
        for node in buffer.iter().rev() {
            node.backward();
        }
        debug_assert_eq!(self.var.past.len(), self.var.past.buffer().len());
        // Reset the computation status of the forward graph so that a subsequent call to
        // `forward` re-evaluates it. As in `forward`, the comparator never returns `Equal`
        // and the search yields the partition point of the buffer.
        self.var.past.prepare_buffer();
        let buffer = self.var.past.buffer();
        let mut res = buffer.binary_search_by(|n| {
            if n.was_computed() {
                std::cmp::Ordering::Less
            } else {
                std::cmp::Ordering::Greater
            }
        });
        if let Err(i) = res {
            if buffer.get(i).is_some() {
                res = Ok(i);
            }
        }
        if let Ok(pos) = res {
            for node in &buffer[pos..] {
                node.reset_computation();
            }
        }
}
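    /// Turns gradient computation off for every node in `self`'s backward history.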
pub fn no_grad(&self) {
self.past.prepare_buffer();
for node in self.past.buffer.borrow().iter() {
node.no_grad();
}
}
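    /// Turns gradient computation back on for every node in `self`'s backward history.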
pub fn with_grad(&self) {
self.past.prepare_buffer();
for node in self.past.buffer.borrow().iter() {
node.with_grad();
}
}
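    /// Puts the forward graph in training mode; see also [`VarDiff::dropout`].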
pub fn train(&self) {
self.var.train();
}
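    /// Puts the forward graph in evaluation mode; see also [`VarDiff::dropout`].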
pub fn eval(&self) {
self.var.eval();
}
}
impl<T, U> VarDiff<T, U>
where
T: Data<Dim = Ix1> + 'static,
U: Gradient<Dim = Ix1> + Overwrite + 'static,
{
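    /// Performs a vector-matrix multiplication between `self` and `rhs`; see [`VecMatMul`].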
pub fn vm<Rhs>(self, rhs: Rhs) -> <Self as VecMatMul<Rhs>>::Output
where
Self: VecMatMul<Rhs>,
{
VecMatMul::vm(self, rhs)
}
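    /// Performs a vector-vector multiplication between `self` and `rhs`; see [`VecVecMul`].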
pub fn vv<Rhs>(self, rhs: Rhs) -> <Self as VecVecMul<Rhs>>::Output
where
Self: VecVecMul<Rhs>,
{
VecVecMul::vv(self, rhs)
}
}
impl<T, U> VarDiff<T, U>
where
T: Data<Dim = Ix2> + 'static,
U: Gradient<Dim = Ix2> + Overwrite + 'static,
{
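    /// Performs a matrix-matrix multiplication between `self` and `rhs`; see [`MatMatMul`].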
pub fn mm<Rhs>(self, rhs: Rhs) -> <Self as MatMatMul<Rhs>>::Output
where
Self: MatMatMul<Rhs>,
{
MatMatMul::mm(self, rhs)
}
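    /// Performs a matrix-matrix multiplication in which `rhs` is transposed; see
    /// [`MatMatMulT`].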
pub fn mm_t<Rhs>(self, rhs: Rhs) -> <Self as MatMatMulT<Rhs>>::Output
where
Self: MatMatMulT<Rhs>,
{
MatMatMulT::mm_t(self, rhs)
}
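    /// Performs a matrix-vector multiplication between `self` and `rhs`; see [`MatVecMul`].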
pub fn mv<Rhs>(self, rhs: Rhs) -> <Self as MatVecMul<Rhs>>::Output
where
Self: MatVecMul<Rhs>,
{
MatVecMul::mv(self, rhs)
}
}
impl<T, U> VarDiff<T, U>
where
T: Data + 'static,
U: Gradient<Dim = T::Dim> + Overwrite + 'static,
{
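    /// Returns the differentiable leaves (the trainable parameters) recorded in `self`'s
    /// history as a vector of [`Param`].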
pub fn parameters(&self) -> Vec<Param<'_>> {
self.past
.parameters
.iter()
.cloned()
.map(RawParam::into_param)
.collect()
}
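    /// Computes the sum of all the elements of `self`.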
pub fn sum(self) -> VarDiff<Sum<T>, SumBackward<U>> {
let node = SumBackward::new(self.node);
VarDiff::from(node, self.past, self.var.sum())
}
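    /// Computes the mean of all the elements of `self`.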
pub fn mean(self) -> VarDiff<Mean<T>, MeanBackward<U>> {
let node = MeanBackward::new(self.node);
VarDiff::from(node, self.past, self.var.mean())
}
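    /// Raises each element of `self` to the power `exp`.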
pub fn pow(self, exp: i32) -> VarDiff<Power<T>, PowerBackward<U, T>> {
let node = PowerBackward::new(self.node, self.var.node.clone(), exp);
VarDiff::from(node, self.past, self.var.pow(exp))
}
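    /// Computes the element-wise square root of `self`.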
pub fn sqrt(self) -> VarDiff<Sqrt<T>, SqrtBackward<U, Sqrt<T>>> {
let var = self.var.sqrt();
let node = SqrtBackward::new(self.node, var.node.clone());
VarDiff::from(node, self.past, var)
}
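    /// Applies the rectified linear unit, `max(0, x)`, element-wise.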
pub fn relu(self) -> VarDiff<ReLU<T>, ReLUBackward<U, T>> {
let node = ReLUBackward::new(self.node, self.var.node.clone());
VarDiff::from(node, self.past, self.var.relu())
}
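    /// Applies the leaky rectified linear unit element-wise; see [`LeakyReLU`] for the
    /// slope used on negative inputs.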
pub fn leaky_relu(self) -> VarDiff<LeakyReLU<T>, LeakyReLUBackward<U, T>> {
let node = LeakyReLUBackward::new(self.node, self.var.node.clone());
VarDiff::from(node, self.past, self.var.leaky_relu())
}
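    /// Applies the softplus function, `ln(1 + exp(x))`, element-wise.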
pub fn softplus(self) -> VarDiff<SoftPlus<T>, SoftPlusBackward<U, T>> {
let node = SoftPlusBackward::new(self.node, self.var.node.clone());
VarDiff::from(node, self.past, self.var.softplus())
}
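    /// Applies the sigmoid function element-wise.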
pub fn sigmoid(self) -> VarDiff<Sigmoid<T>, SigmoidBackward<U, Sigmoid<T>>> {
let var = self.var.sigmoid();
let node = SigmoidBackward::new(self.node, var.node.clone());
VarDiff::from(node, self.past, var)
}
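    /// Applies the hyperbolic tangent element-wise.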
pub fn tanh(self) -> VarDiff<TanH<T>, TanHBackward<U, TanH<T>>> {
let var = self.var.tanh();
let node = TanHBackward::new(self.node, var.node.clone());
VarDiff::from(node, self.past, var)
}
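    /// Computes the element-wise natural logarithm of `self`.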
pub fn ln(self) -> VarDiff<Logn<T>, LognBackward<U, T>> {
let node = LognBackward::new(self.node, self.var.node.clone());
VarDiff::from(node, self.past, self.var.ln())
}
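    /// Computes the element-wise exponential of `self`.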
pub fn exp(self) -> VarDiff<Exp<T>, ExpBackward<U, Exp<T>>> {
let var = self.var.exp();
let node = ExpBackward::new(self.node, var.node.clone());
VarDiff::from(node, self.past, var)
}
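    /// Applies the softmax function along the dimension `axis`.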
pub fn softmax(self, axis: usize) -> VarDiff<Softmax<T>, SoftmaxBackward<U, Softmax<T>>> {
let var = self.var.softmax(axis);
let node = SoftmaxBackward::new(self.node, var.node.clone(), axis);
VarDiff::from(node, self.past, var)
}
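    /// Applies the logarithm of the softmax function along the dimension `axis`; more
    /// numerically stable than composing `softmax` and `ln`.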
pub fn log_softmax(
self,
axis: usize,
) -> VarDiff<LogSoftmax<T>, LogSoftmaxBackward<U, LogSoftmax<T>>> {
let var = self.var.log_softmax(axis);
let node = LogSoftmaxBackward::new(self.node, var.node.clone(), axis);
VarDiff::from(node, self.past, var)
}
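    /// Returns the transpose of `self`.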
pub fn t(self) -> VarDiff<Transpose<T>, TransposeBackward<U>> {
let node = TransposeBackward::new(self.node);
VarDiff::from(node, self.past, self.var.t())
}
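    /// Applies dropout with probability `p`; see [`Dropout`] for the exact semantics.
    /// Whether dropout is active is controlled by the shared status flag toggled by
    /// `train` and `eval`. A sketch of the intended toggling:
    ///
    /// ```ignore
    /// let y = x.dropout(0.5);
    /// y.train(); // dropout is active
    /// y.eval();  // dropout acts as the identity
    /// ```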
pub fn dropout(self, p: f64) -> VarDiff<Dropout<T>, DropoutBackward<U, T>> {
self.dropout_with_status(p, Rc::new(Cell::new(true)))
}
pub(crate) fn dropout_with_status(
self,
p: f64,
status: Rc<Cell<bool>>,
) -> VarDiff<Dropout<T>, DropoutBackward<U, T>> {
let var = self.var.dropout_with_status(p, status);
let node = DropoutBackward::new(self.node, var.node.clone(), p, var.node.status());
VarDiff::from(node, self.past, var)
}
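    /// Splits `self` into a vector of differentiable chunks of size `chunk_size`, along
    /// the lines of `ndarray`'s `exact_chunks`.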
pub fn chunks<E>(self, chunk_size: E) -> Vec<VarDiff<Chunk<T>, ChunkBackward<U>>>
where
E: IntoDimension<Dim = T::Dim>,
{
self.var
.node
.data()
.exact_chunks(chunk_size)
.into_iter()
.enumerate()
.map(|(i, chunk)| {
let var = Var::from(
Chunk::new(self.var.node.clone(), chunk.to_owned(), i),
self.var.past.clone(),
);
VarDiff::from(
ChunkBackward::new(self.node.clone(), chunk.map(|_| 0.), i),
self.past.clone(),
var,
)
})
.collect()
}
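    /// Inserts a new axis of length one at the position `axis`.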
pub fn unsqueeze(self, axis: usize) -> VarDiff<Unsqueeze<T>, UnsqueezeBackward<U>> {
VarDiff::from(
UnsqueezeBackward::new(self.node, axis),
self.past,
self.var.unsqueeze(axis),
)
}
}
impl<T, U> VarDiff<T, U>
where
T: Data + 'static,
U: Gradient<Dim = T::Dim> + Overwrite + 'static,
T::Dim: RemoveAxis,
{
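    /// Concatenates `self` and all the `variables` along the dimension `axis`, merging
    /// their backward histories.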
pub fn cat(
mut self,
variables: &[Box<dyn DifferentiableVariable<T::Dim>>],
axis: usize,
) -> VarDiff<MultiConcatenate<T::Dim>, MultiConcatenateBackward<T::Dim>> {
let vars: Vec<Box<dyn Variable<T::Dim>>> =
variables.iter().map(|el| el.get_var()).collect();
let var = self.var.cat(&vars, axis);
let shape = var.data().raw_dim();
let mut operands: Vec<Rc<dyn GradientOverwrite<T::Dim>>> =
Vec::with_capacity(variables.len() + 1);
operands.push(self.node);
for variable in variables {
self.past.merge(variable.get_past());
operands.push(variable.get_node());
}
VarDiff::from(
MultiConcatenateBackward::new(operands, axis, shape),
self.past,
var,
)
}
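    /// Stacks `self` and all the `variables` along a new dimension at position `axis`,
    /// merging their backward histories.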
pub fn stack(
mut self,
variables: &[Box<dyn DifferentiableVariable<T::Dim>>],
axis: usize,
) -> VarDiff<MultiStack<T::Dim>, MultiStackBackward<T::Dim>> {
let vars: Vec<Box<dyn Variable<T::Dim>>> =
variables.iter().map(|el| el.get_var()).collect();
let var = self.var.stack(&vars, axis);
let shape = var.data().raw_dim();
let mut operands: Vec<Rc<dyn GradientOverwrite<T::Dim>>> =
Vec::with_capacity(variables.len() + 1);
operands.push(self.node);
for variable in variables {
self.past.merge(variable.get_past());
operands.push(variable.get_node());
}
VarDiff::from(
MultiStackBackward::new(operands, axis, shape),
self.past,
var,
)
}
}
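// Arithmetic with `f32` scalars, in both operand orders. The scalar is lifted to a
// one-element tensor with `crate::full(1, _)` and broadcast against the variable, so,
// as a sketch, `x * 2. + 1.` produces a `VarDiff` with the same shape as `x`.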
impl<T, U> Add<f32> for VarDiff<T, U>
where
T: Data + Forward + 'static,
U: Gradient<Dim = T::Dim> + Overwrite + 'static,
T::Dim: DimMax<Ix1>,
{
type Output = VarDiff<Addition<T, Input<Ix1>>, AdditionBackwardUnary<U, Input<Ix1>>>;
fn add(self, rhs: f32) -> Self::Output {
self + crate::full(1, rhs)
}
}
impl<T, U> Sub<f32> for VarDiff<T, U>
where
T: Data + Forward + 'static,
U: Gradient<Dim = T::Dim> + Overwrite + 'static,
T::Dim: DimMax<Ix1>,
{
type Output = VarDiff<Subtraction<T, Input<Ix1>>, SubtractionBackwardLeft<U, Input<Ix1>>>;
fn sub(self, rhs: f32) -> Self::Output {
self - crate::full(1, rhs)
}
}
impl<T, U> Mul<f32> for VarDiff<T, U>
where
T: Data + Forward + 'static,
U: Gradient<Dim = T::Dim> + Overwrite + 'static,
T::Dim: DimMax<Ix1>,
{
type Output =
VarDiff<Multiplication<T, Input<Ix1>>, MultiplicationBackwardUnary<U, Input<Ix1>>>;
fn mul(self, rhs: f32) -> Self::Output {
self * crate::full(1, rhs)
}
}
impl<T, U> Div<f32> for VarDiff<T, U>
where
T: Data + Forward + 'static,
U: Gradient<Dim = T::Dim> + Overwrite + 'static,
T::Dim: DimMax<Ix1>,
{
type Output = VarDiff<Division<T, Input<Ix1>>, DivisionBackwardLeft<U, Input<Ix1>>>;
fn div(self, rhs: f32) -> Self::Output {
self / crate::full(1, rhs)
}
}
impl<T, U> Add<VarDiff<T, U>> for f32
where
T: Data + Forward + 'static,
U: Gradient<Dim = T::Dim> + Overwrite + 'static,
Ix1: DimMax<T::Dim>,
T::Dim: DimMax<Ix1>,
{
type Output = VarDiff<Addition<Input<Ix1>, T>, AdditionBackwardUnary<U, Input<Ix1>>>;
fn add(self, rhs: VarDiff<T, U>) -> Self::Output {
crate::full(1, self) + rhs
}
}
impl<T, U> Sub<VarDiff<T, U>> for f32
where
T: Data + Forward + 'static,
U: Gradient<Dim = T::Dim> + Overwrite + 'static,
Ix1: DimMax<T::Dim>,
T::Dim: DimMax<Ix1>,
{
type Output = VarDiff<Subtraction<Input<Ix1>, T>, SubtractionBackwardRight<U, Input<Ix1>>>;
fn sub(self, rhs: VarDiff<T, U>) -> Self::Output {
crate::full(1, self) - rhs
}
}
impl<T, U> Mul<VarDiff<T, U>> for f32
where
T: Data + Forward + 'static,
U: Gradient<Dim = T::Dim> + Overwrite + 'static,
Ix1: DimMax<T::Dim>,
T::Dim: DimMax<Ix1>,
{
type Output =
VarDiff<Multiplication<Input<Ix1>, T>, MultiplicationBackwardUnary<U, Input<Ix1>>>;
fn mul(self, rhs: VarDiff<T, U>) -> Self::Output {
crate::full(1, self) * rhs
}
}
impl<T, U> Div<VarDiff<T, U>> for f32
where
T: Data + Forward + 'static,
U: Gradient<Dim = T::Dim> + Overwrite + 'static,
Ix1: DimMax<T::Dim>,
T::Dim: DimMax<Ix1>,
{
type Output = VarDiff<Division<Input<Ix1>, T>, DivisionBackwardRight<Input<Ix1>, T, U>>;
fn div(self, rhs: VarDiff<T, U>) -> Self::Output {
crate::full(1, self) / rhs
}
}
impl<T, U> Neg for VarDiff<T, U>
where
T: Data + 'static,
U: Gradient<Dim = T::Dim> + Overwrite + 'static,
{
type Output = VarDiff<Negation<T>, NegationBackward<U>>;
fn neg(self) -> Self::Output {
VarDiff::from(NegationBackward::new(self.node), self.past, self.var.neg())
}
}
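// Element-wise arithmetic between variables. The `Var` right-hand sides only need to
// capture the other forward node, while the `VarDiff` ones also merge the two backward
// histories.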
impl<F1, B1, F2> Add<Var<F2>> for VarDiff<F1, B1>
where
F1: Data + 'static,
F2: Data + 'static,
B1: Gradient + Overwrite + 'static,
F1::Dim: Dimension + DimMax<F2::Dim>,
B1::Dim: Dimension + DimMax<F2::Dim>,
{
type Output = VarDiff<Addition<F1, F2>, AdditionBackwardUnary<B1, F2>>;
fn add(self, rhs: Var<F2>) -> Self::Output {
let node = AdditionBackwardUnary::new(self.node, rhs.node.clone());
VarDiff::from(node, self.past, self.var.add(rhs))
}
}
impl<F1, B1, F2, B2> Add<VarDiff<F2, B2>> for VarDiff<F1, B1>
where
F1: Data + 'static,
F2: Data + 'static,
B1: Gradient + Overwrite + 'static,
B2: Gradient + Overwrite + 'static,
F1::Dim: Dimension + DimMax<F2::Dim>,
B1::Dim: Dimension + DimMax<B2::Dim>,
{
type Output = VarDiff<Addition<F1, F2>, AdditionBackward<B1, B2>>;
fn add(mut self, rhs: VarDiff<F2, B2>) -> Self::Output {
self.past.merge(rhs.past);
let node = AdditionBackward::new(self.node, rhs.node);
VarDiff::from(node, self.past, self.var.add(rhs.var))
}
}
impl<F1, B1, F2> Sub<Var<F2>> for VarDiff<F1, B1>
where
F1: Data + 'static,
F2: Data + 'static,
B1: Gradient + Overwrite + 'static,
F1::Dim: Dimension + DimMax<F2::Dim>,
B1::Dim: Dimension + DimMax<F2::Dim>,
{
type Output = VarDiff<Subtraction<F1, F2>, SubtractionBackwardLeft<B1, F2>>;
fn sub(self, rhs: Var<F2>) -> Self::Output {
let node = SubtractionBackwardLeft::new(self.node, rhs.node.clone());
VarDiff::from(node, self.past, self.var.sub(rhs))
}
}
impl<F1, B1, F2, B2> Sub<VarDiff<F2, B2>> for VarDiff<F1, B1>
where
F1: Data + 'static,
F2: Data + 'static,
B1: Gradient + Overwrite + 'static,
B2: Gradient + Overwrite + 'static,
F1::Dim: Dimension + DimMax<F2::Dim>,
B1::Dim: Dimension + DimMax<B2::Dim>,
{
type Output = VarDiff<Subtraction<F1, F2>, SubtractionBackward<B1, B2>>;
fn sub(mut self, rhs: VarDiff<F2, B2>) -> Self::Output {
self.past.merge(rhs.past);
let node = SubtractionBackward::new(self.node, rhs.node);
VarDiff::from(node, self.past, self.var.sub(rhs.var))
}
}
impl<F1, B1, F2> Mul<Var<F2>> for VarDiff<F1, B1>
where
F1: Data + 'static,
F2: Data + 'static,
B1: Gradient + Overwrite + 'static,
F1::Dim: Dimension + DimMax<F2::Dim>,
B1::Dim: Dimension + DimMax<F2::Dim>,
{
type Output = VarDiff<Multiplication<F1, F2>, MultiplicationBackwardUnary<B1, F2>>;
fn mul(self, rhs: Var<F2>) -> Self::Output {
let node = MultiplicationBackwardUnary::new(self.node, rhs.node.clone());
VarDiff::from(node, self.past, self.var.mul(rhs))
}
}
impl<F1, B1, F2, B2> Mul<VarDiff<F2, B2>> for VarDiff<F1, B1>
where
F1: Data + 'static,
F2: Data + 'static,
B1: Gradient + Overwrite + 'static,
B2: Gradient + Overwrite + 'static,
F1::Dim: Dimension + DimMax<F2::Dim>,
B1::Dim: Dimension + DimMax<B2::Dim>,
{
type Output = VarDiff<Multiplication<F1, F2>, MultiplicationBackward<F1, B1, F2, B2>>;
fn mul(mut self, rhs: VarDiff<F2, B2>) -> Self::Output {
self.past.merge(rhs.past);
let node = MultiplicationBackward::new(
self.var.node.clone(),
self.node,
rhs.var.node.clone(),
rhs.node,
);
VarDiff::from(node, self.past, self.var.mul(rhs.var))
}
}
impl<F1, B1, F2> Div<Var<F2>> for VarDiff<F1, B1>
where
F1: Data + 'static,
F2: Data + 'static,
B1: Gradient + Overwrite + 'static,
F1::Dim: Dimension + DimMax<F2::Dim>,
B1::Dim: Dimension + DimMax<F2::Dim>,
{
type Output = VarDiff<Division<F1, F2>, DivisionBackwardLeft<B1, F2>>;
fn div(self, rhs: Var<F2>) -> Self::Output {
let node = DivisionBackwardLeft::new(self.node, rhs.node.clone());
VarDiff::from(node, self.past, self.var.div(rhs))
}
}
impl<F1, B1, F2, B2> Div<VarDiff<F2, B2>> for VarDiff<F1, B1>
where
F1: Data + 'static,
F2: Data + 'static,
B1: Gradient + Overwrite + 'static,
B2: Gradient + Overwrite + 'static,
F1::Dim: Dimension + DimMax<F2::Dim>,
B1::Dim: Dimension + DimMax<B2::Dim>,
{
type Output = VarDiff<Division<F1, F2>, DivisionBackward<F1, B1, F2, B2>>;
fn div(mut self, rhs: VarDiff<F2, B2>) -> Self::Output {
self.past.merge(rhs.past);
let node = DivisionBackward::new(
self.var.node.clone(),
self.node,
rhs.var.node.clone(),
rhs.node,
);
VarDiff::from(node, self.past, self.var.div(rhs.var))
}
}
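// Linear-algebra products (`mm`, `mm_t`, `mv`, `vm`, `vv`), each against both a `Var`
// and a `VarDiff` right-hand side.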
impl<F1, B1, F2> MatMatMul<Var<F2>> for VarDiff<F1, B1>
where
F1: Data<Dim = Ix2> + 'static,
B1: Gradient<Dim = Ix2> + Overwrite + 'static,
F2: Data<Dim = Ix2> + 'static,
{
type Output = VarDiff<MatrixMatrixMul<F1, F2>, MatrixMatrixMulBackwardLeft<B1, F2>>;
fn mm(self, rhs: Var<F2>) -> Self::Output {
let node = MatrixMatrixMulBackwardLeft::new(self.node, rhs.node.clone());
VarDiff::from(node, self.past, self.var.mm(rhs))
}
}
impl<F1, B1, F2, B2> MatMatMul<VarDiff<F2, B2>> for VarDiff<F1, B1>
where
F1: Data<Dim = Ix2> + 'static,
B1: Gradient<Dim = Ix2> + Overwrite + 'static,
F2: Data<Dim = Ix2> + 'static,
B2: Gradient<Dim = Ix2> + Overwrite + 'static,
{
type Output = VarDiff<MatrixMatrixMul<F1, F2>, MatrixMatrixMulBackward<F1, B1, F2, B2>>;
fn mm(mut self, rhs: VarDiff<F2, B2>) -> Self::Output {
self.past.merge(rhs.past);
let node = MatrixMatrixMulBackward::new(
self.var.node.clone(),
self.node,
rhs.var.node.clone(),
rhs.node,
);
VarDiff::from(node, self.past, self.var.mm(rhs.var))
}
}
impl<F1, B1, F2> MatMatMulT<Var<F2>> for VarDiff<F1, B1>
where
F1: Data<Dim = Ix2> + 'static,
B1: Gradient<Dim = Ix2> + Overwrite + 'static,
F2: Data<Dim = Ix2> + 'static,
{
type Output = VarDiff<MatrixMatrixMulT<F1, F2>, MatrixMatrixMulTBackwardLeft<B1, F2>>;
fn mm_t(self, rhs: Var<F2>) -> Self::Output {
let node = MatrixMatrixMulTBackwardLeft::new(self.node, rhs.node.clone());
VarDiff::from(node, self.past, self.var.mm_t(rhs))
}
}
impl<F1, B1, F2, B2> MatMatMulT<VarDiff<F2, B2>> for VarDiff<F1, B1>
where
F1: Data<Dim = Ix2> + 'static,
B1: Gradient<Dim = Ix2> + Overwrite + 'static,
F2: Data<Dim = Ix2> + 'static,
B2: Gradient<Dim = Ix2> + Overwrite + 'static,
{
type Output = VarDiff<MatrixMatrixMulT<F1, F2>, MatrixMatrixMulTBackward<F1, B1, F2, B2>>;
fn mm_t(mut self, rhs: VarDiff<F2, B2>) -> Self::Output {
self.past.merge(rhs.past);
let node = MatrixMatrixMulTBackward::new(
self.var.node.clone(),
self.node,
rhs.var.node.clone(),
rhs.node,
);
VarDiff::from(node, self.past, self.var.mm_t(rhs.var))
}
}
impl<F1, B1, F2> MatVecMul<Var<F2>> for VarDiff<F1, B1>
where
F1: Data<Dim = Ix2> + 'static,
B1: Gradient<Dim = Ix2> + Overwrite + 'static,
F2: Data<Dim = Ix1> + 'static,
{
type Output = VarDiff<MatrixVectorMul<F1, F2>, MatrixVectorMulBackwardLeft<B1, F2>>;
fn mv(self, rhs: Var<F2>) -> Self::Output {
let node = MatrixVectorMulBackwardLeft::new(self.node, rhs.node.clone());
VarDiff::from(node, self.past, self.var.mv(rhs))
}
}
impl<F1, B1, F2, B2> MatVecMul<VarDiff<F2, B2>> for VarDiff<F1, B1>
where
F1: Data<Dim = Ix2> + 'static,
B1: Gradient<Dim = Ix2> + Overwrite + 'static,
F2: Data<Dim = Ix1> + 'static,
B2: Gradient<Dim = Ix1> + Overwrite + 'static,
{
type Output = VarDiff<MatrixVectorMul<F1, F2>, MatrixVectorMulBackward<F1, B1, F2, B2>>;
fn mv(mut self, rhs: VarDiff<F2, B2>) -> Self::Output {
self.past.merge(rhs.past);
let node = MatrixVectorMulBackward::new(
self.var.node.clone(),
self.node,
rhs.var.node.clone(),
rhs.node,
);
VarDiff::from(node, self.past, self.var.mv(rhs.var))
}
}
impl<F1, B1, F2> VecMatMul<Var<F2>> for VarDiff<F1, B1>
where
F1: Data<Dim = Ix1> + 'static,
B1: Gradient<Dim = Ix1> + Overwrite + 'static,
F2: Data<Dim = Ix2> + 'static,
{
type Output = VarDiff<VectorMatrixMul<F1, F2>, VectorMatrixMulBackwardLeft<B1, F2>>;
fn vm(self, rhs: Var<F2>) -> Self::Output {
let node = VectorMatrixMulBackwardLeft::new(self.node, rhs.node.clone());
VarDiff::from(node, self.past, self.var.vm(rhs))
}
}
impl<F1, B1, F2, B2> VecMatMul<VarDiff<F2, B2>> for VarDiff<F1, B1>
where
F1: Data<Dim = Ix1> + 'static,
B1: Gradient<Dim = Ix1> + Overwrite + 'static,
F2: Data<Dim = Ix2> + 'static,
B2: Gradient<Dim = Ix2> + Overwrite + 'static,
{
type Output = VarDiff<VectorMatrixMul<F1, F2>, VectorMatrixMulBackward<F1, B1, F2, B2>>;
fn vm(mut self, rhs: VarDiff<F2, B2>) -> Self::Output {
self.past.merge(rhs.past);
let node = VectorMatrixMulBackward::new(
self.var.node.clone(),
self.node,
rhs.var.node.clone(),
rhs.node,
);
VarDiff::from(node, self.past, self.var.vm(rhs.var))
}
}
impl<F1, B1, F2> VecVecMul<Var<F2>> for VarDiff<F1, B1>
where
F1: Data<Dim = Ix1> + 'static,
B1: Gradient<Dim = Ix1> + Overwrite + 'static,
F2: Data<Dim = Ix1> + 'static,
{
type Output = VarDiff<VectorVectorMul<F1, F2>, VectorVectorMulBackwardUnary<B1, F2>>;
fn vv(self, rhs: Var<F2>) -> Self::Output {
let node = VectorVectorMulBackwardUnary::new(self.node, rhs.node.clone());
VarDiff::from(node, self.past, self.var.vv(rhs))
}
}
impl<F1, B1, F2, B2> VecVecMul<VarDiff<F2, B2>> for VarDiff<F1, B1>
where
F1: Data<Dim = Ix1> + 'static,
B1: Gradient<Dim = Ix1> + Overwrite + 'static,
F2: Data<Dim = Ix1> + 'static,
B2: Gradient<Dim = Ix1> + Overwrite + 'static,
{
type Output = VarDiff<VectorVectorMul<F1, F2>, VectorVectorMulBackward<F1, B1, F2, B2>>;
fn vv(mut self, rhs: VarDiff<F2, B2>) -> Self::Output {
self.past.merge(rhs.past);
let node = VectorVectorMulBackward::new(
self.var.node.clone(),
self.node,
rhs.var.node.clone(),
rhs.node,
);
VarDiff::from(node, self.past, self.var.vv(rhs.var))
}
}
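// Concatenation and stacking against both `Var` and `VarDiff` operands.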
impl<F1, B1, F2> Cat<Var<F2>> for VarDiff<F1, B1>
where
F1: Data<Dim = B1::Dim> + 'static,
F2: Data<Dim = F1::Dim> + 'static,
B1: Gradient + Overwrite + 'static,
F1::Dim: RemoveAxis,
B1::Dim: RemoveAxis,
{
type Output = VarDiff<Concatenate<F1, F2>, ConcatenateBackwardLeft<B1>>;
fn cat(self, rhs: Var<F2>, axis: usize) -> Self::Output {
let node = ConcatenateBackwardLeft::new(self.node, rhs.node.clone(), axis);
VarDiff::from(node, self.past, Cat::cat(self.var, rhs, axis))
}
}
impl<F1, B1, F2, B2> Cat<VarDiff<F2, B2>> for VarDiff<F1, B1>
where
F1: Data + 'static,
B1: Gradient + Overwrite + 'static,
F2: Data<Dim = F1::Dim> + 'static,
B2: Gradient<Dim = B1::Dim> + Overwrite + 'static,
F1::Dim: RemoveAxis,
B1::Dim: RemoveAxis,
{
type Output = VarDiff<Concatenate<F1, F2>, ConcatenateBackward<B1, B2>>;
fn cat(mut self, rhs: VarDiff<F2, B2>, axis: usize) -> Self::Output {
self.past.merge(rhs.past);
let node = ConcatenateBackward::new(self.node, rhs.node, axis);
VarDiff::from(node, self.past, Cat::cat(self.var, rhs.var, axis))
}
}
impl<F1, B1, F2> Stack<Var<F2>> for VarDiff<F1, B1>
where
F1: Data<Dim = B1::Dim> + 'static,
F2: Data<Dim = F1::Dim> + 'static,
B1: Gradient + Overwrite + 'static,
F1::Dim: RemoveAxis,
B1::Dim: RemoveAxis,
{
type Output = VarDiff<super::node::Stack<F1, F2>, StackBackwardLeft<B1>>;
fn stack(self, rhs: Var<F2>, axis: usize) -> Self::Output {
let node = StackBackwardLeft::new(self.node, rhs.node.clone(), axis);
VarDiff::from(node, self.past, Stack::stack(self.var, rhs, axis))
}
}
impl<F1, B1, F2, B2> Stack<VarDiff<F2, B2>> for VarDiff<F1, B1>
where
F1: Data + 'static,
B1: Gradient + Overwrite + 'static,
F2: Data<Dim = F1::Dim> + 'static,
B2: Gradient<Dim = B1::Dim> + Overwrite + 'static,
F1::Dim: RemoveAxis,
B1::Dim: RemoveAxis,
{
type Output = VarDiff<super::node::Stack<F1, F2>, StackBackward<B1, B2>>;
fn stack(mut self, rhs: VarDiff<F2, B2>, axis: usize) -> Self::Output {
self.past.merge(rhs.past);
let node = StackBackward::new(self.node, rhs.node, axis);
VarDiff::from(node, self.past, Stack::stack(self.var, rhs.var, axis))
}
}
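// Parameter registration and formatting.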
impl<T, U> Register for VarDiff<T, U>
where
T: Data + 'static,
U: Gradient + Overwrite + 'static,
{
fn register_params(&self, params: &mut Vec<RawParam>) {
params.extend(self.past.parameters.iter().cloned())
}
fn register_status(&mut self, _: Rc<Cell<bool>>) {}
}
impl<T, U> Debug for VarDiff<T, U>
where
T: Data + Debug,
U: Gradient<Dim = T::Dim> + Overwrite + Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("VarDiff")
.field("var", &self.var)
.field("node", &self.node)
.field("past", &self.past.len())
.field("parameters", &self.parameters().len())
.finish()
}
}
impl<T: Data + Display, U: Gradient + Overwrite + Display> Display for VarDiff<T, U> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.var)
}
}
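// (De)serialization of differentiable inputs, gated behind the "serialize" feature.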
#[cfg(feature = "serialize")]
impl<D> Serialize for VarDiff<Input<D>, super::InputBackward<D>>
where
D: Dimension + Serialize,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.data().serialize(serializer)
}
}
#[cfg(feature = "serialize")]
impl<'d, D> Deserialize<'d> for VarDiff<Input<D>, super::InputBackward<D>>
where
D: Dimension + Deserialize<'d>,
{
fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
where
De: Deserializer<'d>,
{
        let data = ndarray::Array::<f32, D>::deserialize(deserializer)?;
Ok(Input::new(data).requires_grad())
}
}