use ndarray::{
linalg::{general_mat_mul, general_mat_vec_mul},
Array, ArrayBase, ArrayD, ArrayView, Axis, DimMax, Dimension, IntoNdProducer, Ix1, Ix2, Zip,
};
use std::{
cell::{Ref, RefCell, RefMut},
rc::Rc,
};
pub(crate) use binary::*;
pub use binary::{
Constant, Convolve, ConvolveWithGroups, PaddingMode, Reflective, Replicative, Zero,
};
pub use input::{Input, InputBackward};
pub(crate) use nary::*;
pub(crate) use unary::*;
mod binary;
mod input;
mod nary;
mod unary;
/// Dimension obtained by broadcasting `Lhs` against `Rhs`.
pub(crate) type Broadcasted<Lhs, Rhs> = <Lhs as DimMax<Rhs>>::Output;
/// Tensor whose dimension is the broadcast of `Lhs` and `Rhs`.
pub(crate) type BroadTensor<Lhs, Rhs> = Tensor<Broadcasted<Lhs, Rhs>>;
/// `f32` tensor with a dynamically-known number of axes.
pub(crate) type DynTensor = ArrayD<f32>;
/// `f32` tensor with statically-known dimension `D`.
pub(crate) type Tensor<D> = Array<f32, D>;
/// Access to the forward value stored inside a computational-graph node.
pub trait Data {
    /// Dimension of the stored tensor.
    type Dim: Dimension;

    /// Immutably borrows the node's forward tensor.
    fn data(&self) -> Ref<Tensor<Self::Dim>>;

    /// Mutably borrows the node's forward tensor.
    fn data_mut(&self) -> RefMut<Tensor<Self::Dim>>;
}
/// Forward-pass behavior of a graph node.
pub trait Forward {
    /// Computes this node's value.
    fn forward(&self);

    /// Returns `true` if this node's value has already been computed
    /// in the current pass.
    fn was_computed(&self) -> bool;

    /// Clears the computed flag so that the next `forward` call
    /// recomputes the value.
    fn reset_computation(&self);
}
/// Access to the gradient stored inside a computational-graph node.
pub trait Gradient {
    /// Dimension of the gradient tensor.
    type Dim: Dimension;

    /// Immutably borrows the node's gradient tensor.
    fn gradient(&self) -> Ref<Tensor<Self::Dim>>;

    /// Mutably borrows the node's gradient tensor.
    fn gradient_mut(&self) -> RefMut<Tensor<Self::Dim>>;
}
/// Controls whether the next gradient pushed into a node overwrites the
/// stored gradient or is accumulated into it (see the `push_*_gradient`
/// helpers in this module).
pub trait Overwrite {
    /// Returns `true` if the stored gradient may be overwritten rather
    /// than accumulated into.
    fn can_overwrite(&self) -> bool;

    /// Sets the overwrite flag to `state`.
    fn set_overwrite(&self, state: bool);
}
/// Convenience trait for nodes that expose both a `D`-dimensional
/// gradient and the overwrite flag; blanket-implemented for all such types.
pub trait GradientOverwrite<D>: Gradient<Dim = D> + Overwrite {}
/// Forwarding implementation: a reference-counted node exposes the same
/// overwrite flag as the node it wraps.
impl<T> Overwrite for Rc<T>
where
    T: Overwrite,
{
    fn can_overwrite(&self) -> bool {
        (**self).can_overwrite()
    }

    fn set_overwrite(&self, state: bool) {
        (**self).set_overwrite(state)
    }
}
/// Forwarding implementation: a reference-counted node exposes the same
/// gradient as the node it wraps.
impl<T, D> Gradient for Rc<T>
where
    T: Gradient<Dim = D>,
    D: Dimension,
{
    type Dim = D;

    fn gradient(&self) -> Ref<Tensor<Self::Dim>> {
        (**self).gradient()
    }

    fn gradient_mut(&self) -> RefMut<Tensor<Self::Dim>> {
        (**self).gradient_mut()
    }
}
/// Every type providing both `Gradient` and `Overwrite` is automatically
/// a `GradientOverwrite`.
impl<D: Dimension, T> GradientOverwrite<D> for T where T: Gradient<Dim = D> + Overwrite {}
/// Backward-pass behavior of a graph node.
pub trait Backward: Overwrite {
    /// Propagates the gradient to this node's operands.
    fn backward(&self);

    /// De-allocates this node's gradient buffer; accessing the gradient
    /// afterwards panics (see `expect_tensor`).
    fn no_grad(&self);

    /// Re-allocates this node's gradient buffer after a `no_grad` call.
    fn with_grad(&self);
}
/// Switches a component between training and evaluation mode.
pub trait Eval {
    /// Puts the component in training mode.
    fn train(&self);

    /// Puts the component in evaluation mode.
    fn eval(&self);
}
/// Computes the dimension of the result of a dot product between
/// operands of dimensions `Self` and `Rhs` (vector-matrix,
/// matrix-vector and matrix-matrix combinations are implemented below).
trait DotDim<Rhs>
where
    Self: Dimension,
    Rhs: Dimension,
{
    /// Dimension of the dot-product result.
    type Output: Dimension;

    /// Builds the result shape from the two operand shapes.
    fn shape(lhs: Self, rhs: Rhs) -> <Self as DotDim<Rhs>>::Output;
}
impl DotDim<Ix2> for Ix1 {
type Output = Ix1;
fn shape(_: Self, rhs: Ix2) -> <Self as DotDim<Ix2>>::Output {
let mut res_shape = Ix1::zeros(1);
res_shape[0] = rhs.last_elem();
res_shape
}
}
impl DotDim<Ix1> for Ix2 {
type Output = Ix1;
fn shape(lhs: Self, _: Ix1) -> <Self as DotDim<Ix1>>::Output {
let mut res_shape = Ix1::zeros(1);
res_shape[0] = lhs[0];
res_shape
}
}
impl DotDim<Ix2> for Ix2 {
type Output = Ix2;
fn shape(lhs: Self, rhs: Ix2) -> <Self as DotDim<Ix2>>::Output {
let mut res_shape = Ix2::zeros(2);
res_shape[0] = lhs[0];
res_shape[1] = rhs[1];
res_shape
}
}
/// Sums `array` along `axis` in place, then removes that axis, so the
/// array loses one dimension without re-allocating.
fn sum_axis_inplace(array: &mut DynTensor, axis: Axis) {
    // Split off the leading slice along `axis`; `rest` holds indices 1..len.
    let (first, rest) = array.view_mut().split_at(axis, 1);
    // Accumulate every remaining lane into the leading slice.
    Zip::from(first.remove_axis(axis))
        .and(rest.lanes(axis))
        .for_each(|dst, src| *dst += src.sum());
    // Collapse `axis`, keeping only the accumulated leading slice.
    array.index_axis_inplace(axis, 0);
}
/// Reduces `src` to dimension `dim` by summation, reversing a broadcast:
/// axes that were prepended are summed away, and axes that `dim` has with
/// size 1 (i.e. stretched during the broadcast) are summed back to length 1.
///
/// Used to accumulate gradients flowing back through broadcasted
/// operations. Returns a standard-layout tensor of dimension `D`.
pub fn reduce<D: Dimension, E: Dimension>(dim: D, src: &Tensor<E>) -> Tensor<D> {
    let mut src = src.clone().into_dyn();
    // Sum away leading axes until the ranks match.
    while src.ndim() > dim.ndim() {
        sum_axis_inplace(&mut src, Axis(0));
    }
    // Size-1 axes of the target were stretched by broadcasting: collapse
    // them, then re-insert the axis so the rank is preserved.
    for (axis, size) in dim.slice().iter().enumerate() {
        if *size == 1 {
            sum_axis_inplace(&mut src, Axis(axis));
            src.insert_axis_inplace(Axis(axis));
        }
    }
    debug_assert_eq!(src.raw_dim(), dim.into_dyn());
    debug_assert!(src.is_standard_layout());
    src.into_dimensionality::<D>().unwrap()
}
/// Pushes `gradient` into `destination_node`'s gradient buffer,
/// broadcasting as needed.
///
/// When the node's overwrite flag is set the incoming values replace the
/// stored gradient and the flag is cleared; otherwise they are
/// accumulated into it.
pub fn push_gradient<'a, T, P, D>(destination_node: &T, gradient: P)
where
    T: Gradient + Overwrite + ?Sized,
    P: IntoNdProducer<Dim = D, Output = ArrayView<'a, f32, D>, Item = &'a f32>,
    D: Dimension,
{
    let mut dest = destination_node.gradient_mut();
    let zip = Zip::from(&mut *dest).and_broadcast(gradient);
    if !destination_node.can_overwrite() {
        // Accumulate into the existing gradient.
        zip.for_each(|d, s| *d += *s);
    } else {
        // Stale gradient: replace it, then clear the overwrite flag.
        zip.for_each(|d, s| *d = *s);
        destination_node.set_overwrite(false);
    }
}
/// Pushes the gradient of a matrix-matrix product, `first · second`,
/// into `destination_node`'s gradient buffer.
///
/// When the node's overwrite flag is set the product replaces the stored
/// gradient (`beta = 0`) and the flag is cleared; otherwise the product
/// is accumulated into it (`beta = 1`).
pub fn push_mat_mat_gradient<T, S1, S2>(
    destination_node: &T,
    first: &ArrayBase<S1, Ix2>,
    second: &ArrayBase<S2, Ix2>,
) where
    T: Gradient<Dim = Ix2> + Overwrite,
    S1: ndarray::Data<Elem = f32>,
    S2: ndarray::Data<Elem = f32>,
{
    let overwrite = destination_node.can_overwrite();
    let beta = if overwrite { 0. } else { 1. };
    general_mat_mul(1., first, second, beta, &mut destination_node.gradient_mut());
    if overwrite {
        destination_node.set_overwrite(false);
    }
}
/// Pushes the broadcasted element-wise product of `first` and `second`
/// into `destination_node`'s two-dimensional gradient buffer.
///
/// When the node's overwrite flag is set the products replace the stored
/// gradient and the flag is cleared; otherwise they are accumulated.
pub fn push_mat_vec_gradient<T, S1, S2>(
    destination_node: &T,
    first: &ArrayBase<S1, Ix2>,
    second: &ArrayBase<S2, Ix1>,
) where
    T: Gradient<Dim = Ix2> + Overwrite,
    S1: ndarray::Data<Elem = f32>,
    S2: ndarray::Data<Elem = f32>,
{
    let mut dest = destination_node.gradient_mut();
    let zip = Zip::from(&mut *dest)
        .and_broadcast(first)
        .and_broadcast(second);
    if !destination_node.can_overwrite() {
        // Accumulate into the existing gradient.
        zip.for_each(|d, f, s| *d += f * s);
    } else {
        // Stale gradient: replace it, then clear the overwrite flag.
        zip.for_each(|d, f, s| *d = f * s);
        destination_node.set_overwrite(false);
    }
}
/// Pushes the gradient of a matrix-vector product, `first · second`,
/// into `destination_node`'s one-dimensional gradient buffer.
///
/// When the node's overwrite flag is set the product replaces the stored
/// gradient (`beta = 0`) and the flag is cleared; otherwise the product
/// is accumulated into it (`beta = 1`).
pub fn push_vec_mat_gradient<T, S1, S2>(
    destination_node: &T,
    first: &ArrayBase<S1, Ix2>,
    second: &ArrayBase<S2, Ix1>,
) where
    T: Gradient<Dim = Ix1> + Overwrite,
    S1: ndarray::Data<Elem = f32>,
    S2: ndarray::Data<Elem = f32>,
{
    let overwrite = destination_node.can_overwrite();
    let beta = if overwrite { 0. } else { 1. };
    general_mat_vec_mul(1., first, second, beta, &mut destination_node.gradient_mut());
    if overwrite {
        destination_node.set_overwrite(false);
    }
}
/// Pushes `first` scaled by the scalar `second` into
/// `destination_node`'s one-dimensional gradient buffer.
///
/// When the node's overwrite flag is set the scaled values replace the
/// stored gradient and the flag is cleared; otherwise they are
/// accumulated into it.
pub fn push_vec_vec_gradient<T, S>(destination_node: &T, first: &ArrayBase<S, Ix1>, second: &f32)
where
    T: Gradient<Dim = Ix1> + Overwrite,
    S: ndarray::Data<Elem = f32>,
{
    let mut dest = destination_node.gradient_mut();
    let zip = Zip::from(&mut *dest).and_broadcast(first);
    if !destination_node.can_overwrite() {
        // Accumulate into the existing gradient.
        zip.for_each(|d, f| *d += f * second);
    } else {
        // Stale gradient: replace it, then clear the overwrite flag.
        zip.for_each(|d, f| *d = f * second);
        destination_node.set_overwrite(false);
    }
}
/// Creates a zeroed tensor whose shape is the broadcast of `left`'s and
/// `right`'s shapes (NumPy-style, right-aligned; the operand shapes are
/// assumed to be broadcast-compatible).
pub(crate) fn broadcasted_zeros<Lhs, Rhs>(
    left: &Tensor<Lhs>,
    right: &Tensor<Rhs>,
) -> BroadTensor<Lhs, Rhs>
where
    Lhs: Dimension + DimMax<Rhs>,
    Rhs: Dimension,
{
    // Order the two shapes so `longer` has at least as many axes.
    let (longer, shorter) = if left.ndim() >= right.ndim() {
        (left.shape(), right.shape())
    } else {
        (right.shape(), left.shape())
    };
    // Start from the longer shape, then right-align the shorter one and
    // take the element-wise maximum of the overlapping axes.
    let offset = longer.len() - shorter.len();
    let mut broadcasted = <Lhs as DimMax<Rhs>>::Output::zeros(longer.len());
    for (axis, dim) in broadcasted.slice_mut().iter_mut().enumerate() {
        *dim = longer[axis];
        if axis >= offset {
            *dim = std::cmp::max(*dim, shorter[axis - offset]);
        }
    }
    Tensor::zeros(broadcasted)
}
/// Immutably borrows the gradient stored in `tensor`, panicking with an
/// explanatory message if it was de-allocated by a `no_grad()` call.
pub(crate) fn expect_tensor<D: Dimension>(tensor: &RefCell<Option<Tensor<D>>>) -> Ref<Tensor<D>> {
    Ref::map(tensor.borrow(), |b| {
        b.as_ref().expect(
            "error: trying to get a de-allocated gradient.
Switch on the gradients first by using with_grad().",
        )
    })
}
/// Mutably borrows the gradient stored in `tensor`, panicking with an
/// explanatory message if it was de-allocated by a `no_grad()` call.
pub(crate) fn expect_tensor_mut<D: Dimension>(
    tensor: &RefCell<Option<Tensor<D>>>,
) -> RefMut<Tensor<D>> {
    RefMut::map(tensor.borrow_mut(), |b| {
        b.as_mut().expect(
            "error: trying to get a de-allocated gradient.
Switch on the gradients first by using with_grad().",
        )
    })
}
// Relative tolerance for test comparisons: approximately the machine
// epsilon of IEEE half precision (f16).
#[cfg(test)]
const F16_EPSILON: f32 = 9.77e-04;

/// Asserts that `array` and `target` are element-wise equal within a
/// relative tolerance of `F16_EPSILON`.
///
/// A pair of elements also matches when both are exactly `0.` or both
/// are non-finite (NaN or infinity). On failure, both tensors are
/// printed in the panic message.
#[cfg(test)]
fn assert_almost_equals<D: Dimension>(array: &Tensor<D>, target: &Tensor<D>) {
    assert!(
        Zip::from(array).and(target).all(|l, r| {
            (*l == 0. && *r == 0.)
                || (!l.is_finite() && !r.is_finite())
                || ((1. - r / l).abs() <= F16_EPSILON)
        }),
        "\nLeft:\n{}\nRight:\n{}",
        array,
        target
    );
}
/// Test helper: builds an `Input` node of dimension `D` from `shape` and
/// the flat `elements`, returning the forward node.
#[cfg(test)]
fn new_input<D, Sh>(shape: Sh, elements: Vec<f32>) -> Rc<Input<D>>
where
    D: Dimension + 'static,
    Sh: Into<ndarray::StrideShape<D>>,
{
    let tensor = new_tensor(shape, elements);
    Input::new(tensor).node
}
/// Test helper: builds a differentiable (backward) counterpart of an
/// `Input` node from `shape` and the flat `elements`.
#[cfg(test)]
fn new_backward_input<D, Sh>(shape: Sh, elements: Vec<f32>) -> Rc<InputBackward<D>>
where
    D: Dimension + 'static,
    Sh: Into<ndarray::StrideShape<D>>,
{
    let forward = Input::new(new_tensor(shape, elements)).node;
    Rc::new(forward.differentiable())
}
/// Test helper: builds a `Tensor<D>` from `shape` and the flat
/// `elements`, panicking if their lengths are inconsistent.
#[cfg(test)]
fn new_tensor<D, Sh>(shape: Sh, elements: Vec<f32>) -> Tensor<D>
where
    D: Dimension + 'static,
    Sh: Into<ndarray::StrideShape<D>>,
{
    let tensor = Tensor::from_shape_vec(shape, elements);
    tensor.unwrap()
}