#![no_std]
mod chain;
pub use chain::*;
/// A computation that maps an [`Input`](Forward::Input) to an
/// [`Output`](Forward::Output), yielding an extra [`Internal`](Forward::Internal)
/// value alongside the output.
///
/// Whatever `Internal` value [`forward`](Forward::forward) returns is what the
/// [`Backward`] methods later take back by reference.
pub trait Forward {
    /// Type passed (by reference) into [`forward`](Forward::forward).
    type Input;
    /// Extra value returned by [`forward`](Forward::forward) next to the output.
    type Internal;
    /// Type returned by [`forward`](Forward::forward) and [`run`](Forward::run).
    type Output;

    /// Evaluates `input`, returning the pair `(internal, output)`.
    fn forward(&self, input: &Self::Input) -> (Self::Internal, Self::Output);

    /// Evaluates `input` and returns only the output.
    ///
    /// The provided implementation calls [`forward`](Forward::forward) and throws
    /// the internal value away; implementors may override it when the output can
    /// be produced without building the internal value.
    fn run(&self, input: &Self::Input) -> Self::Output {
        let (_internal, output) = self.forward(input);
        output
    }
}
/// A [`Forward`] computation that can also run a backward pass.
///
/// Given the original input, the `internal` value produced by
/// [`Forward::forward`], and an output delta, [`backward`](Backward::backward)
/// returns an input delta together with a [`TrainDelta`](Backward::TrainDelta).
pub trait Backward: Forward {
    /// Type of the `output_delta` argument of the backward pass
    /// (presumably a gradient w.r.t. the output — semantics are up to the implementor).
    type OutputDelta;
    /// First component returned by [`backward`](Backward::backward).
    type InputDelta;
    /// Second component returned by [`backward`](Backward::backward); it is the
    /// argument type of `Train::train`.
    type TrainDelta;

    /// Runs the backward pass, returning `(input_delta, train_delta)`.
    fn backward(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> (Self::InputDelta, Self::TrainDelta);

    /// Returns only the input delta of [`backward`](Backward::backward).
    ///
    /// The provided implementation computes both deltas and drops the train
    /// delta; override it if the input delta alone is cheaper to obtain.
    fn backward_input(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::InputDelta {
        let (input_delta, _train_delta) = self.backward(input, internal, output_delta);
        input_delta
    }

    /// Returns only the train delta of [`backward`](Backward::backward).
    ///
    /// The provided implementation computes both deltas and drops the input
    /// delta; override it if the train delta alone is cheaper to obtain.
    fn backward_train(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::TrainDelta {
        let (_input_delta, train_delta) = self.backward(input, internal, output_delta);
        train_delta
    }
}
/// A [`Backward`] computation whose own state can be updated from a
/// [`TrainDelta`](Backward::TrainDelta).
pub trait Train: Backward {
    /// Updates `self` from `train_delta`; the update rule is up to the implementor.
    fn train(&mut self, train_delta: &Self::TrainDelta);

    /// Runs [`backward`](Backward::backward), feeds the resulting train delta to
    /// [`train`](Train::train), and returns the input delta.
    ///
    /// The backward pass is evaluated before `self` is modified, so both deltas
    /// are computed from the pre-update state.
    ///
    /// NOTE: the name is spelled `propogate`; it is kept as-is because renaming a
    /// trait method would break existing implementors and callers.
    fn propogate(
        &mut self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::InputDelta {
        let deltas = self.backward(input, internal, output_delta);
        self.train(&deltas.1);
        deltas.0
    }
}
/// Composition helpers available on every sized [`Train`] type.
///
/// Implemented for all `T: Train` by the blanket impl directly below, so a
/// manual impl is never needed (and would conflict with it).
pub trait Graph: Train + Sized {
    /// Moves `self` and `other` into a new [`Chain`], returning `Chain(self, other)`.
    ///
    /// No bound is placed on `U` here; what the resulting pair can do is decided
    /// by the impls on `Chain` (in the `chain` module).
    ///
    /// Both operands are consumed, so ignoring the returned value discards them;
    /// `#[must_use]` turns that mistake into a compiler warning.
    #[must_use = "`chain` consumes both layers and returns the combined graph"]
    fn chain<U>(self, other: U) -> Chain<Self, U> {
        Chain(self, other)
    }
}

// Every sized `Train` type gets `Graph` (and therefore `chain`) automatically.
impl<T> Graph for T where T: Train {}
// Forwarding impls: a shared or mutable reference to a layer behaves exactly like
// the layer itself. Every method, including the provided ones, is forwarded
// explicitly so any overrides on `T` are honoured.
//
// `T: ?Sized` lets references to unsized layers, in particular trait objects such
// as `&dyn Forward<..>` or `&mut dyn Train<..>`, be used wherever a layer is
// expected. `&**self` / `&mut **self` reborrow the pointee (`T`) from the
// reference held in `self`.

impl<'a, T> Forward for &'a T
where
    T: Forward + ?Sized,
{
    type Input = T::Input;
    type Internal = T::Internal;
    type Output = T::Output;
    fn forward(&self, input: &Self::Input) -> (Self::Internal, Self::Output) {
        T::forward(&**self, input)
    }
    fn run(&self, input: &Self::Input) -> Self::Output {
        T::run(&**self, input)
    }
}

impl<'a, T> Forward for &'a mut T
where
    T: Forward + ?Sized,
{
    type Input = T::Input;
    type Internal = T::Internal;
    type Output = T::Output;
    fn forward(&self, input: &Self::Input) -> (Self::Internal, Self::Output) {
        T::forward(&**self, input)
    }
    fn run(&self, input: &Self::Input) -> Self::Output {
        T::run(&**self, input)
    }
}

impl<'a, T> Backward for &'a T
where
    T: Backward + ?Sized,
{
    type OutputDelta = T::OutputDelta;
    type InputDelta = T::InputDelta;
    type TrainDelta = T::TrainDelta;
    fn backward(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> (Self::InputDelta, Self::TrainDelta) {
        T::backward(&**self, input, internal, output_delta)
    }
    fn backward_input(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::InputDelta {
        T::backward_input(&**self, input, internal, output_delta)
    }
    fn backward_train(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::TrainDelta {
        T::backward_train(&**self, input, internal, output_delta)
    }
}

impl<'a, T> Backward for &'a mut T
where
    T: Backward + ?Sized,
{
    type OutputDelta = T::OutputDelta;
    type InputDelta = T::InputDelta;
    type TrainDelta = T::TrainDelta;
    fn backward(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> (Self::InputDelta, Self::TrainDelta) {
        T::backward(&**self, input, internal, output_delta)
    }
    fn backward_input(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::InputDelta {
        T::backward_input(&**self, input, internal, output_delta)
    }
    fn backward_train(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::TrainDelta {
        T::backward_train(&**self, input, internal, output_delta)
    }
}

// Only `&mut T` can train: `Train::train` needs mutable access to the layer.
impl<'a, T> Train for &'a mut T
where
    T: Train + ?Sized,
{
    fn train(&mut self, train_delta: &Self::TrainDelta) {
        T::train(&mut **self, train_delta)
    }
    fn propogate(
        &mut self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::InputDelta {
        T::propogate(&mut **self, input, internal, output_delta)
    }
}