use super::{Input, InputBackward, Param};
use crate::variable::{
self, Backward, Convolve, ConvolveWithGroups, Data, Dropout as DropoutNode,
DropoutBackward as DropoutBackwardNode, Eval, Forward, Gradient, MatMatMulT, Overwrite,
RawParam, Tensor, Var, VarDiff,
};
pub use crate::variable::{Constant, PaddingMode, Reflective, Replicative, Zero};
use ndarray::{Ix1, Ix2, Ix3, Ix4, Ix5};
use std::{cell::Cell, rc::Rc};
pub mod init;
pub mod loss;
#[cfg(feature = "serialize")]
use serde::{Deserialize, Serialize};
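/// A learnable parameter of dimensionality `D`: a differentiable variable
/// whose forward node is an [`Input`] and whose backward node accumulates
/// the parameter's gradient.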
pub type Learnable<D> = VarDiff<Input<D>, InputBackward<D>>;
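/// Tracks the learnable parameters and the train/eval status shared by a
/// model's components.
///
/// Custom models should hold a `ModelStatus`, route each layer through
/// [`ModelStatus::register`] at construction time, and hand the collected
/// parameters to the optimizer via [`ModelStatus::parameters`]. Calling
/// [`ModelStatus::train`] or [`ModelStatus::eval`] then switches every
/// registered component at once.
///
/// A minimal sketch of the intended wiring; `MyModel` and the layer sizes
/// are illustrative, not part of this module:
///
/// ```ignore
/// struct MyModel {
///     lin1: Linear,
///     lin2: Linear,
///     status: ModelStatus,
/// }
///
/// impl MyModel {
///     fn new() -> Self {
///         let mut status = ModelStatus::default();
///         Self {
///             // Registering hands each layer's parameters and its
///             // train/eval flag over to the shared status.
///             lin1: status.register(Linear::new(10, 5)),
///             lin2: status.register(Linear::new(5, 1)),
///             status,
///         }
///     }
///
///     fn parameters(&self) -> Vec<Param<'_>> {
///         self.status.parameters()
///     }
/// }
/// ```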
pub struct ModelStatus {
params: Vec<RawParam>,
train: Rc<Cell<bool>>,
}
impl ModelStatus {
pub fn parameters(&self) -> Vec<Param<'_>> {
self.params
.iter()
.cloned()
.map(RawParam::into_param)
.collect()
}
pub fn register<T: Register>(&mut self, mut component: T) -> T {
component.register_params(&mut self.params);
component.register_status(self.train.clone());
component
}
pub fn train(&self) {
<Self as Eval>::train(self)
}
pub fn eval(&self) {
<Self as Eval>::eval(self)
}
}
impl Default for ModelStatus {
fn default() -> Self {
Self {
params: Vec::new(),
train: Rc::new(Cell::new(true)),
}
}
}
impl Eval for ModelStatus {
fn train(&self) {
self.train.set(true)
}
fn eval(&self) {
self.train.set(false)
}
}
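/// Input abstraction for [`Dropout`], implemented by both [`Var`] and
/// [`VarDiff`] so that the layer accepts differentiable and
/// non-differentiable variables alike. The shared `status` flag ties the
/// resulting dropout node to the owning model's train/eval mode.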
pub trait DropoutInput {
type Output;
fn dropout(self, p: f64, status: Rc<Cell<bool>>) -> Self::Output;
}
impl<T, U> DropoutInput for VarDiff<T, U>
where
T: Data + Forward,
U: Gradient<Dim = T::Dim> + Overwrite + Backward,
{
type Output = VarDiff<DropoutNode<T>, DropoutBackwardNode<U, T>>;
fn dropout(self, p: f64, status: Rc<Cell<bool>>) -> Self::Output {
self.dropout_with_status(p, status)
}
}
impl<T> DropoutInput for Var<T>
where
T: Data + Forward,
{
type Output = Var<DropoutNode<T>>;
fn dropout(self, p: f64, status: Rc<Cell<bool>>) -> Self::Output {
self.dropout_with_status(p, status)
}
}
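/// Wiring trait used by [`ModelStatus::register`]. A component exposes its
/// learnable parameters through `register_params` and adopts the model's
/// shared train/eval flag through `register_status`; parameter-free
/// components such as [`Dropout`] only implement the latter meaningfully.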
pub trait Register {
fn register_params(&self, params: &mut Vec<RawParam>);
fn register_status(&mut self, status: Rc<Cell<bool>>);
}
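/// A dropout layer. While training, elements of the input are zeroed out
/// with probability `p`; in evaluation mode the input is expected to pass
/// through unchanged. The mode is read from the shared `status` flag, which
/// [`ModelStatus::register`] replaces with the model-wide one.
///
/// A minimal usage sketch; the tensor shape is illustrative:
///
/// ```ignore
/// let dropout = Dropout::new(0.5);
/// let x = Input::new(Tensor::ones((8, 20))).requires_grad();
/// let y = dropout.forward(x); // elements are dropped while training
/// dropout.eval();             // subsequent passes leave the input intact
/// ```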
pub struct Dropout {
pub status: Rc<Cell<bool>>,
pub p: f64,
}
impl Dropout {
pub fn new(p: f64) -> Self {
let status = Rc::new(Cell::new(true));
Self { status, p }
}
pub fn forward<I: DropoutInput>(&self, input: I) -> I::Output {
input.dropout(self.p, self.status.clone())
}
}
impl Eval for Dropout {
fn eval(&self) {
self.status.set(false)
}
fn train(&self) {
self.status.set(true)
}
}
impl Register for Dropout {
fn register_status(&mut self, status: Rc<Cell<bool>>) {
self.status = status;
}
fn register_params(&self, _: &mut Vec<RawParam>) {}
}
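/// Applies a linear transformation to the incoming data: `y = x Wᵀ + b`.
///
/// `weight` has shape `(out_features, in_features)`, `bias` has shape
/// `(out_features,)`, and both are initialized from `U(-k, k)` with
/// `k = 1 / sqrt(in_features)`.
///
/// A minimal usage sketch; the batch size of 8 is illustrative:
///
/// ```ignore
/// let linear = Linear::new(10, 5);
/// let x = Input::new(Tensor::zeros((8, 10)));
/// let y = linear.forward(x); // expected shape: (8, 5)
/// ```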
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct Linear {
pub weight: Learnable<Ix2>,
pub bias: Learnable<Ix1>,
}
impl Linear {
pub fn new(in_features: usize, out_features: usize) -> Self {
let weight = Input::new(Tensor::zeros((out_features, in_features))).requires_grad();
let bias = Input::new(Tensor::zeros(out_features)).requires_grad();
let k = (1. / (in_features as f32)).sqrt();
init::uniform(&weight, -k, k);
init::uniform(&bias, -k, k);
Self { weight, bias }
}
pub fn forward<I, T, U>(
&self,
input: I,
) -> VarDiff<impl Data<Dim = Ix2> + Forward, impl Gradient<Dim = Ix2> + Overwrite + Backward>
where
I: MatMatMulT<Learnable<Ix2>>,
I::Output: Into<VarDiff<T, U>>,
T: Data<Dim = Ix2>,
U: Gradient<Dim = Ix2> + Overwrite,
{
input.mm_t(self.weight.clone()).into() + self.bias.clone()
}
}
impl Register for Linear {
fn register_params(&self, params: &mut Vec<RawParam>) {
self.weight.register_params(params);
self.bias.register_params(params);
}
fn register_status(&mut self, _: Rc<Cell<bool>>) {}
}
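/// A long short-term memory (LSTM) cell.
///
/// The four gate projections are fused into a single affine transformation:
/// `weight_ih` has shape `(4 * hidden_size, input_size)`, `weight_hh` has
/// shape `(4 * hidden_size, hidden_size)`, and both biases hold
/// `4 * hidden_size` elements. All parameters are initialized from
/// `U(-k, k)` with `k = 1 / sqrt(hidden_size)`.
///
/// A minimal sketch of a single step; the sizes are illustrative:
///
/// ```ignore
/// let cell = LSTMCell::new(10, 20);
/// // Zeroed initial state: (cell state, hidden state), both (batch, hidden).
/// let state = (
///     Input::new(Tensor::zeros((8, 20))).requires_grad(),
///     Input::new(Tensor::zeros((8, 20))).requires_grad(),
/// );
/// let x = Input::new(Tensor::zeros((8, 10)));
/// let (new_cell_state, new_hidden) = cell.forward(state, x);
/// ```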
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[allow(clippy::upper_case_acronyms)]
pub struct LSTMCell {
pub weight_ih: Learnable<Ix2>,
pub weight_hh: Learnable<Ix2>,
pub bias_ih: Learnable<Ix1>,
pub bias_hh: Learnable<Ix1>,
}
impl LSTMCell {
pub fn new(input_size: usize, hidden_size: usize) -> Self {
let (weight_ih_shape, weight_hh_shape, bias_shape) = {
let xhidden_size = 4 * hidden_size;
(
(xhidden_size, input_size),
(xhidden_size, hidden_size),
xhidden_size,
)
};
let weight_ih = Input::new(Tensor::zeros(weight_ih_shape)).requires_grad();
let weight_hh = Input::new(Tensor::zeros(weight_hh_shape)).requires_grad();
let bias_ih = Input::new(Tensor::zeros(bias_shape)).requires_grad();
let bias_hh = Input::new(Tensor::zeros(bias_shape)).requires_grad();
let k = 1. / (hidden_size as f32).sqrt();
init::uniform(&weight_ih, -k, k);
init::uniform(&weight_hh, -k, k);
init::uniform(&bias_ih, -k, k);
init::uniform(&bias_hh, -k, k);
Self {
weight_ih,
weight_hh,
bias_ih,
bias_hh,
}
}
pub fn forward<Cf, Cb, Hf, Hb, I, T, U>(
&self,
state: (VarDiff<Cf, Cb>, VarDiff<Hf, Hb>),
input: I,
) -> (
VarDiff<impl Data<Dim = Ix2> + Forward, impl Gradient<Dim = Ix2> + Overwrite + Backward>,
VarDiff<impl Data<Dim = Ix2> + Forward, impl Gradient<Dim = Ix2> + Overwrite + Backward>,
)
where
Cf: Data<Dim = Ix2>,
Cb: Gradient<Dim = Ix2> + Overwrite,
Hf: Data<Dim = Ix2>,
Hb: Gradient<Dim = Ix2> + Overwrite,
I: MatMatMulT<Learnable<Ix2>>,
I::Output: Into<VarDiff<T, U>>,
T: Data<Dim = Ix2>,
U: Gradient<Dim = Ix2> + Overwrite,
{
let (cell_state, hidden) = state;
let gates = hidden.mm_t(self.weight_hh.clone())
+ self.bias_hh.clone()
+ input.mm_t(self.weight_ih.clone()).into()
+ self.bias_ih.clone();
// The projections are fused: split the (batch, 4 * hidden) gate matrix
// into four (batch, hidden) chunks, one per gate.
let gate_shape = {
let (gates_shape_rows, gates_shape_cols) = gates.data().dim();
(gates_shape_rows, gates_shape_cols / 4)
};
let chunked_gates = gates.chunks(gate_shape);
// Gate layout: input, forget, cell candidate, output. The input, forget
// and output gates are squashed with a sigmoid; the cell candidate goes
// through a tanh.
let (input_gate, forget_gate, cell_state_gate, output_gate) = (
chunked_gates[0].clone().sigmoid(),
chunked_gates[1].clone().sigmoid(),
chunked_gates[2].clone().tanh(),
chunked_gates[3].clone().sigmoid(),
);
let new_cell_state = forget_gate * cell_state + (input_gate * cell_state_gate);
let new_hidden = output_gate * new_cell_state.clone().tanh();
(new_cell_state, new_hidden)
}
}
impl Register for LSTMCell {
fn register_params(&self, params: &mut Vec<RawParam>) {
self.weight_hh.register_params(params);
self.weight_ih.register_params(params);
self.bias_hh.register_params(params);
self.bias_ih.register_params(params);
}
fn register_status(&mut self, _: Rc<Cell<bool>>) {}
}
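/// A gated recurrent unit (GRU) cell.
///
/// The three gate projections are fused: `weight_ih` has shape
/// `(3 * hidden_size, input_size)`, `weight_hh` has shape
/// `(3 * hidden_size, hidden_size)`, and both biases hold `3 * hidden_size`
/// elements. All parameters are initialized from `U(-k, k)` with
/// `k = 1 / sqrt(hidden_size)`.
///
/// A minimal sketch of a single step; the sizes are illustrative:
///
/// ```ignore
/// let cell = GRUCell::new(10, 20);
/// let hidden = Input::new(Tensor::zeros((8, 20))).requires_grad();
/// let x = Input::new(Tensor::zeros((8, 10)));
/// let new_hidden = cell.forward(hidden, x);
/// ```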
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[allow(clippy::upper_case_acronyms)]
pub struct GRUCell {
pub weight_ih: Learnable<Ix2>,
pub weight_hh: Learnable<Ix2>,
pub bias_ih: Learnable<Ix1>,
pub bias_hh: Learnable<Ix1>,
}
impl GRUCell {
pub fn new(input_size: usize, hidden_size: usize) -> Self {
let (weight_ih_shape, weight_hh_shape, bias_shape) = {
let xhidden_size = 3 * hidden_size;
(
(xhidden_size, input_size),
(xhidden_size, hidden_size),
xhidden_size,
)
};
let weight_ih = Input::new(Tensor::zeros(weight_ih_shape)).requires_grad();
let weight_hh = Input::new(Tensor::zeros(weight_hh_shape)).requires_grad();
let bias_ih = Input::new(Tensor::zeros(bias_shape)).requires_grad();
let bias_hh = Input::new(Tensor::zeros(bias_shape)).requires_grad();
let k = 1. / (hidden_size as f32).sqrt();
init::uniform(&weight_ih, -k, k);
init::uniform(&weight_hh, -k, k);
init::uniform(&bias_ih, -k, k);
init::uniform(&bias_hh, -k, k);
Self {
weight_ih,
weight_hh,
bias_ih,
bias_hh,
}
}
pub fn forward<Hf, Hb, I, T, U>(
&self,
hidden: VarDiff<Hf, Hb>,
input: I,
) -> VarDiff<impl Data<Dim = Ix2> + Forward, impl Gradient<Dim = Ix2> + Overwrite + Backward>
where
Hf: Data<Dim = Ix2>,
Hb: Gradient<Dim = Ix2> + Overwrite,
I: MatMatMulT<Learnable<Ix2>>,
I::Output: Into<VarDiff<T, U>>,
T: Data<Dim = Ix2>,
U: Gradient<Dim = Ix2> + Overwrite,
{
let (igates, hgates) = {
(
input.mm_t(self.weight_ih.clone()).into() + self.bias_ih.clone(),
hidden.clone().mm_t(self.weight_hh.clone()) + self.bias_hh.clone(),
)
};
let gate_shape = {
let (gates_shape_rows, gates_shape_cols) = hgates.data().dim();
(gates_shape_rows, gates_shape_cols / 3)
};
let (chunked_igates, chunked_hgates) =
(igates.chunks(gate_shape), hgates.chunks(gate_shape));
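// Gate layout: reset, update (named `input_gate` here), new. The return
// expression below rearranges (1 - z) * n + z * h into (h - n) * z + n.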
let reset_gate = (chunked_hgates[0].clone() + chunked_igates[0].clone()).sigmoid();
let input_gate = (chunked_hgates[1].clone() + chunked_igates[1].clone()).sigmoid();
let new_gate =
(chunked_igates[2].clone() + (chunked_hgates[2].clone() * reset_gate)).tanh();
(hidden - new_gate.clone()) * input_gate + new_gate
}
}
impl Register for GRUCell {
fn register_params(&self, params: &mut Vec<RawParam>) {
self.weight_hh.register_params(params);
self.weight_ih.register_params(params);
self.bias_hh.register_params(params);
self.bias_ih.register_params(params);
}
fn register_status(&mut self, _: Rc<Cell<bool>>) {}
}
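/// Applies a 1-dimensional convolution over an input of shape
/// `(batch, in_channels, length)`.
///
/// `weight` has shape `(out_channels, in_channels, kernel_size)` and is
/// initialized, together with `bias`, from `U(-k, k)` with
/// `k = 1 / sqrt(in_channels * kernel_size)`.
///
/// A minimal usage sketch; the sizes and the `Zero` padding mode are
/// illustrative, and the output length follows the usual convolution
/// arithmetic:
///
/// ```ignore
/// // in: 4, out: 8, kernel: 3, padding: 0, stride: 1, dilation: 1.
/// let conv = Conv1d::new(4, 8, 3, 0, Zero, 1, 1);
/// let x = Input::new(Tensor::zeros((1, 4, 16)));
/// let y = conv.forward(x); // expected shape: (1, 8, 14)
/// ```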
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct Conv1d<Pad: PaddingMode> {
pub padding: usize,
pub padding_mode: Pad,
pub stride: usize,
pub dilation: usize,
pub weight: Learnable<Ix3>,
pub bias: Learnable<Ix1>,
}
impl<Pad: PaddingMode> Conv1d<Pad> {
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
padding: usize,
padding_mode: Pad,
stride: usize,
dilation: usize,
) -> Self {
let weight =
Input::new(Tensor::zeros((out_channels, in_channels, kernel_size))).requires_grad();
let bias = Input::new(Tensor::zeros(out_channels)).requires_grad();
let k = (1. / (in_channels * kernel_size) as f32).sqrt();
init::uniform(&weight, -k, k);
init::uniform(&bias, -k, k);
Self {
padding,
padding_mode,
stride,
dilation,
weight,
bias,
}
}
pub fn forward<I, T, U>(
&self,
input: I,
) -> VarDiff<impl Data<Dim = Ix3> + Forward, impl Gradient<Dim = Ix3> + Overwrite + Backward>
where
I: Convolve<I, Learnable<Ix3>, Pad>,
I::Output: Into<VarDiff<T, U>>,
T: Data<Dim = Ix3>,
U: Gradient<Dim = Ix3> + Overwrite,
{
I::convolve(
input,
self.weight.clone(),
&[self.stride],
&[self.dilation],
&[self.padding],
self.padding_mode.clone(),
)
.into()
+ self.bias.clone()
}
}
impl<Pad: PaddingMode> Register for Conv1d<Pad> {
fn register_params(&self, params: &mut Vec<RawParam>) {
self.weight.register_params(params);
self.bias.register_params(params);
}
fn register_status(&mut self, _: Rc<Cell<bool>>) {}
}
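/// Applies a grouped 1-dimensional convolution: the input channels are
/// split into `groups` independent groups, each convolved with its own set
/// of filters, so `weight` has shape
/// `(out_channels, in_channels / groups, kernel_size)`. Both channel counts
/// must be divisible by `groups`. Initialization uses `U(-k, k)` with
/// `k = sqrt(groups / (in_channels * kernel_size))`.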
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct GroupedConv1d<Pad: PaddingMode> {
pub padding: usize,
pub padding_mode: Pad,
pub stride: usize,
pub dilation: usize,
pub groups: usize,
pub weight: Learnable<Ix3>,
pub bias: Learnable<Ix1>,
}
impl<Pad: PaddingMode> GroupedConv1d<Pad> {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
padding: usize,
padding_mode: Pad,
stride: usize,
dilation: usize,
groups: usize,
) -> Self {
let weight = Input::new(Tensor::zeros((
out_channels,
in_channels / groups,
kernel_size,
)))
.requires_grad();
let bias = Input::new(Tensor::zeros(out_channels)).requires_grad();
let k = (groups as f32 / (in_channels * kernel_size) as f32).sqrt();
init::uniform(&weight, -k, k);
init::uniform(&bias, -k, k);
Self {
padding,
padding_mode,
stride,
dilation,
groups,
weight,
bias,
}
}
pub fn forward<I, T, U>(
&self,
input: I,
) -> VarDiff<impl Data<Dim = Ix3> + Forward, impl Gradient<Dim = Ix3> + Overwrite + Backward>
where
I: ConvolveWithGroups<I, Learnable<Ix3>, Pad>,
I::Output: Into<VarDiff<T, U>>,
T: Data<Dim = Ix3>,
U: Gradient<Dim = Ix3> + Overwrite,
{
I::convolve_with_groups(
input,
self.weight.clone(),
&[self.stride],
&[self.dilation],
&[self.padding],
self.padding_mode.clone(),
self.groups,
)
.into()
+ self.bias.clone()
}
}
impl<Pad: PaddingMode> Register for GroupedConv1d<Pad> {
fn register_params(&self, params: &mut Vec<RawParam>) {
self.weight.register_params(params);
self.bias.register_params(params);
}
fn register_status(&mut self, _: Rc<Cell<bool>>) {}
}
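/// Applies a 2-dimensional convolution over an input of shape
/// `(batch, in_channels, height, width)`.
///
/// `weight` has shape `(out_channels, in_channels, kernel_h, kernel_w)` and
/// is initialized, together with `bias`, from `U(-k, k)` with
/// `k = 1 / sqrt(in_channels * kernel_h * kernel_w)`.
///
/// A minimal usage sketch; the sizes and the `Zero` padding mode are
/// illustrative:
///
/// ```ignore
/// // in: 3, out: 16, kernel: 3x3, padding: 1, stride: 1, dilation: 1.
/// let conv = Conv2d::new(3, 16, (3, 3), (1, 1), Zero, (1, 1), (1, 1));
/// let x = Input::new(Tensor::zeros((1, 3, 28, 28)));
/// let y = conv.forward(x); // expected shape: (1, 16, 28, 28)
/// ```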
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct Conv2d<Pad: PaddingMode> {
pub padding: (usize, usize),
pub padding_mode: Pad,
pub stride: (usize, usize),
pub dilation: (usize, usize),
pub weight: Learnable<Ix4>,
pub bias: Learnable<Ix1>,
}
impl<Pad: PaddingMode> Conv2d<Pad> {
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: (usize, usize),
padding: (usize, usize),
padding_mode: Pad,
stride: (usize, usize),
dilation: (usize, usize),
) -> Self {
let (kernel_h, kernel_w) = kernel_size;
let weight = Input::new(Tensor::zeros((
out_channels,
in_channels,
kernel_h,
kernel_w,
)))
.requires_grad();
let bias = Input::new(Tensor::zeros(out_channels)).requires_grad();
let k = (1. / (in_channels * kernel_h * kernel_w) as f32).sqrt();
init::uniform(&weight, -k, k);
init::uniform(&bias, -k, k);
Self {
padding,
padding_mode,
stride,
dilation,
weight,
bias,
}
}
pub fn forward<I, T, U>(
&self,
input: I,
) -> VarDiff<impl Data<Dim = Ix4> + Forward, impl Gradient<Dim = Ix4> + Overwrite + Backward>
where
I: Convolve<I, Learnable<Ix4>, Pad>,
I::Output: Into<VarDiff<T, U>>,
T: Data<Dim = Ix4>,
U: Gradient<Dim = Ix4> + Overwrite,
{
let (stride_h, stride_w) = self.stride;
let (padding_h, padding_w) = self.padding;
let (dilation_h, dilation_w) = self.dilation;
I::convolve(
input,
self.weight.clone(),
&[stride_h, stride_w],
&[dilation_h, dilation_w],
&[padding_h, padding_w],
self.padding_mode.clone(),
)
.into()
+ self.bias.clone()
}
}
impl<Pad: PaddingMode> Register for Conv2d<Pad> {
fn register_params(&self, params: &mut Vec<RawParam>) {
self.weight.register_params(params);
self.bias.register_params(params);
}
fn register_status(&mut self, _: Rc<Cell<bool>>) {}
}
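/// Applies a grouped 2-dimensional convolution; see [`GroupedConv1d`] for
/// the grouping semantics. `weight` has shape
/// `(out_channels, in_channels / groups, kernel_h, kernel_w)` and is
/// initialized from `U(-k, k)` with
/// `k = sqrt(groups / (in_channels * kernel_h * kernel_w))`.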
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct GroupedConv2d<Pad: PaddingMode> {
pub padding: (usize, usize),
pub padding_mode: Pad,
pub stride: (usize, usize),
pub dilation: (usize, usize),
pub groups: usize,
pub weight: Learnable<Ix4>,
pub bias: Learnable<Ix1>,
}
impl<Pad: PaddingMode> GroupedConv2d<Pad> {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: (usize, usize),
padding: (usize, usize),
padding_mode: Pad,
stride: (usize, usize),
dilation: (usize, usize),
groups: usize,
) -> Self {
let (kernel_h, kernel_w) = kernel_size;
let weight = Input::new(Tensor::zeros((
out_channels,
in_channels / groups,
kernel_h,
kernel_w,
)))
.requires_grad();
let bias = Input::new(Tensor::zeros(out_channels)).requires_grad();
let k = (groups as f32 / (in_channels * kernel_h * kernel_w) as f32).sqrt();
init::uniform(&weight, -k, k);
init::uniform(&bias, -k, k);
Self {
padding,
padding_mode,
stride,
dilation,
groups,
weight,
bias,
}
}
pub fn forward<I, T, U>(
&self,
input: I,
) -> VarDiff<impl Data<Dim = Ix4> + Forward, impl Gradient<Dim = Ix4> + Overwrite + Backward>
where
I: ConvolveWithGroups<I, Learnable<Ix4>, Pad>,
I::Output: Into<VarDiff<T, U>>,
T: Data<Dim = Ix4>,
U: Gradient<Dim = Ix4> + Overwrite,
{
let (stride_h, stride_w) = self.stride;
let (padding_h, padding_w) = self.padding;
let (dilation_h, dilation_w) = self.dilation;
I::convolve_with_groups(
input,
self.weight.clone(),
&[stride_h, stride_w],
&[dilation_h, dilation_w],
&[padding_h, padding_w],
self.padding_mode.clone(),
self.groups,
)
.into()
+ self.bias.clone()
}
}
impl<Pad: PaddingMode> Register for GroupedConv2d<Pad> {
fn register_params(&self, params: &mut Vec<RawParam>) {
self.weight.register_params(params);
self.bias.register_params(params);
}
fn register_status(&mut self, _: Rc<Cell<bool>>) {}
}
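/// Applies a 3-dimensional convolution over an input of shape
/// `(batch, in_channels, depth, height, width)`.
///
/// `weight` has shape
/// `(out_channels, in_channels, kernel_d, kernel_h, kernel_w)` and is
/// initialized, together with `bias`, from `U(-k, k)` with
/// `k = 1 / sqrt(in_channels * kernel_d * kernel_h * kernel_w)`.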
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct Conv3d<Pad: PaddingMode> {
pub padding: (usize, usize, usize),
pub padding_mode: Pad,
pub stride: (usize, usize, usize),
pub dilation: (usize, usize, usize),
pub weight: Learnable<Ix5>,
pub bias: Learnable<Ix1>,
}
impl<Pad: PaddingMode> Conv3d<Pad> {
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: (usize, usize, usize),
padding: (usize, usize, usize),
padding_mode: Pad,
stride: (usize, usize, usize),
dilation: (usize, usize, usize),
) -> Self {
let (kernel_d, kernel_h, kernel_w) = kernel_size;
let weight = Input::new(Tensor::zeros((
out_channels,
in_channels,
kernel_d,
kernel_h,
kernel_w,
)))
.requires_grad();
let bias = Input::new(Tensor::zeros(out_channels)).requires_grad();
let k = (1. / (in_channels * kernel_d * kernel_h * kernel_w) as f32).sqrt();
init::uniform(&weight, -k, k);
init::uniform(&bias, -k, k);
Self {
padding,
padding_mode,
stride,
dilation,
weight,
bias,
}
}
pub fn forward<I, T, U>(
&self,
input: I,
) -> VarDiff<impl Data<Dim = Ix5> + Forward, impl Gradient<Dim = Ix5> + Overwrite + Backward>
where
I: Convolve<I, Learnable<Ix5>, Pad>,
I::Output: Into<VarDiff<T, U>>,
T: Data<Dim = Ix5>,
U: Gradient<Dim = Ix5> + Overwrite,
{
let (stride_d, stride_h, stride_w) = self.stride;
let (padding_d, padding_h, padding_w) = self.padding;
let (dilation_d, dilation_h, dilation_w) = self.dilation;
I::convolve(
input,
self.weight.clone(),
&[stride_d, stride_h, stride_w],
&[dilation_d, dilation_h, dilation_w],
&[padding_d, padding_h, padding_w],
self.padding_mode.clone(),
)
.into()
+ self.bias.clone()
}
}
impl<Pad: PaddingMode> Register for Conv3d<Pad> {
fn register_params(&self, params: &mut Vec<RawParam>) {
self.weight.register_params(params);
self.bias.register_params(params);
}
fn register_status(&mut self, _: Rc<Cell<bool>>) {}
}
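/// Applies a grouped 3-dimensional convolution; see [`GroupedConv1d`] for
/// the grouping semantics. `weight` has shape
/// `(out_channels, in_channels / groups, kernel_d, kernel_h, kernel_w)` and
/// is initialized from `U(-k, k)` with
/// `k = sqrt(groups / (in_channels * kernel_d * kernel_h * kernel_w))`.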
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct GroupedConv3d<Pad: PaddingMode> {
pub padding: (usize, usize, usize),
pub padding_mode: Pad,
pub stride: (usize, usize, usize),
pub dilation: (usize, usize, usize),
pub groups: usize,
pub weight: Learnable<Ix5>,
pub bias: Learnable<Ix1>,
}
impl<Pad: PaddingMode> GroupedConv3d<Pad> {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: (usize, usize, usize),
padding: (usize, usize, usize),
padding_mode: Pad,
stride: (usize, usize, usize),
dilation: (usize, usize, usize),
groups: usize,
) -> Self {
let (kernel_d, kernel_h, kernel_w) = kernel_size;
let weight = Input::new(Tensor::zeros((
out_channels,
in_channels / groups,
kernel_d,
kernel_h,
kernel_w,
)))
.requires_grad();
let bias = Input::new(Tensor::zeros(out_channels)).requires_grad();
let k = (groups as f32 / (in_channels * kernel_d * kernel_h * kernel_w) as f32).sqrt();
init::uniform(&weight, -k, k);
init::uniform(&bias, -k, k);
Self {
padding,
padding_mode,
stride,
dilation,
groups,
weight,
bias,
}
}
pub fn forward<I, T, U>(
&self,
input: I,
) -> VarDiff<impl Data<Dim = Ix5> + Forward, impl Gradient<Dim = Ix5> + Overwrite + Backward>
where
I: ConvolveWithGroups<I, Learnable<Ix5>, Pad>,
I::Output: Into<VarDiff<T, U>>,
T: Data<Dim = Ix5>,
U: Gradient<Dim = Ix5> + Overwrite,
{
let (stride_d, stride_h, stride_w) = self.stride;
let (padding_d, padding_h, padding_w) = self.padding;
let (dilation_d, dilation_h, dilation_w) = self.dilation;
I::convolve_with_groups(
input,
self.weight.clone(),
&[stride_d, stride_h, stride_w],
&[dilation_d, dilation_h, dilation_w],
&[padding_d, padding_h, padding_w],
self.padding_mode.clone(),
self.groups,
)
.into()
+ self.bias.clone()
}
}
impl<Pad: PaddingMode> Register for GroupedConv3d<Pad> {
fn register_params(&self, params: &mut Vec<RawParam>) {
self.weight.register_params(params);
self.bias.register_params(params);
}
fn register_status(&mut self, _: Rc<Cell<bool>>) {}
}