use std::fmt::Debug;
use crate::arr::*;
use crate::device::*;
use crate::{Stack};
use crate::cuda::ToCuda;
use crate::error::{EvaluateError, TrainingError, TypeConvertError};
use crate::ope::UnitValue;
use crate::lossfunction::*;
// Submodules providing the concrete layer implementations built on the
// traits declared in this module.
pub mod input;
pub mod output;
pub mod linear;
pub mod activation;
pub mod bridge;
pub mod batchnormalization;
pub mod bias;
/// Input to a layer that can be evaluated either differentially or from scratch.
///
/// Type parameters: `T` is the differential payload, `U` the numeric unit type,
/// `NI`/`NO` the input and output dimensions (fixed at compile time).
#[derive(Debug)]
pub enum DiffInput<T,U,const NI:usize,const NO:usize>
where U: UnitValue<U> + Clone + Copy + Debug, T: Debug {
/// Differential form: a diff payload `T` together with an `NO`-sized array
/// (presumably the previous output to update incrementally — confirm with users of this enum).
Diff(T,Arr<U,NO>),
/// Non-differential form: a full `NI`-sized input, evaluated from scratch.
NotDiff(Arr<U,NI>)
}
/// Maps a single-sample data type to its mini-batch counterpart.
pub trait BatchDataType {
/// The batched form of `Self` (e.g. `Vec<Self>`).
type Type;
}
/// The unit type batches to itself; there is nothing to collect.
impl BatchDataType for () {
type Type = ();
}
/// A batch of `DiffInput`s is simply a `Vec` of them.
impl<T,U,const NI:usize,const NO:usize> BatchDataType for DiffInput<T,U,NI,NO>
where U: UnitValue<U> + Clone + Copy + Debug, T: Debug {
type Type = Vec<DiffInput<T,U,NI,NO>>;
}
/// `DiffInput` needs no device-side conversion: `to_cuda` is the identity.
impl<T,U,const NI:usize,const NO:usize> ToCuda<U> for DiffInput<T,U,NI,NO>
where U: UnitValue<U> + Clone + Copy + Debug, T: Debug {
type Output = Self;
/// Pass-through conversion; the GPU device handle is unused.
fn to_cuda(self, _: &DeviceGpu<U>) -> Result<Self::Output, TypeConvertError> {
Ok(self)
}
}
/// A batch of `DiffInput`s likewise converts to CUDA as the identity.
impl<T,U,const NI:usize,const NO:usize> ToCuda<U> for Vec<DiffInput<T,U,NI,NO>>
where U: UnitValue<U> + Clone + Copy + Debug, T: Debug {
type Output = Self;
/// Pass-through conversion; the GPU device handle is unused.
fn to_cuda(self, _: &DeviceGpu<U>) -> Result<Self::Output, TypeConvertError> {
Ok(self)
}
}
/// Something that knows how many samples it contains.
pub trait BatchSize {
/// Number of samples in this batch.
fn size(&self) -> usize;
}
/// A single forward computation from a borrowed input `I` to an output `O`.
pub trait Forward<I,O> {
/// Computes the forward pass for one input.
fn forward(&self,input:&I) -> O;
}
/// Forward propagation through an entire layer chain, consuming the input.
pub trait ForwardAll {
/// Input type accepted by the first layer.
type Input: Debug;
/// Output type produced by the last layer.
type Output: Debug + 'static;
/// Runs the full forward pass; fails with [`EvaluateError`] on evaluation problems.
fn forward_all(&self, input:Self::Input) -> Result<Self::Output, EvaluateError>;
}
/// Backward propagation through an entire layer chain.
///
/// Requires [`PreTrain`] (to have recorded intermediate outputs in `OutStack`)
/// and [`UpdateWeight`] (to produce a `GradientStack` for the weight update).
pub trait BackwardAll<U>: PreTrain<U> + UpdateWeight<U> where U: UnitValue<U> {
/// Loss value type fed in at the output end of the chain.
type LossInput: Debug;
/// Loss value type emitted at the input end of the chain.
type LossOutput: Debug;
/// Runs the full backward pass over `stack`, returning the propagated loss
/// together with the gradients to apply via `UpdateWeight`.
fn backward_all<L: LossFunction<U>>(&mut self, input:Self::LossInput, stack:Self::OutStack, lossf:&L)
-> Result<(<Self as BackwardAll<U>>::LossOutput,<Self as UpdateWeight<U>>::GradientStack), TrainingError>;
/// Whether `L` is the canonical link for this layer (presumably enabling a
/// simplified gradient at the output layer — confirm with implementors).
/// Defaults to `false`.
fn is_canonical_link<L: LossFunction<U>>(&self,_:&L) -> bool {
false
}
}
/// Hook for applying a loss function before backpropagation.
pub trait Loss<U>: BackwardAll<U> where U: UnitValue<U> {
/// Transforms the incoming loss and stack; the default is a pass-through
/// that returns both unchanged.
fn loss<L: LossFunction<U>>(&mut self, loss:Self::LossInput, _lossf:&L, stack:Self::OutStack) -> Result<(Self::OutStack, Self::LossInput), TrainingError> {
Ok((stack,loss))
}
}
/// A single backward computation from `I` to `O` (one layer's step).
pub trait Backward<U,I,O> where U: UnitValue<U> {
/// Computes the backward pass for one input; may mutate internal state.
fn backward(&mut self, input:I) -> O;
}
/// Forward pass that records intermediate outputs for a later backward pass.
pub trait PreTrain<U>: ForwardAll where U: UnitValue<U> {
/// Output type stored at the head of the recorded stack.
type PreOutput: Debug + 'static;
/// Stack of intermediate layer outputs, headed by `PreOutput`.
type OutStack: Stack<Head=Self::PreOutput> + Debug + Sized;
/// Runs the forward pass, returning the stack of recorded outputs.
fn pre_train(&self, input:Self::Input) -> Result<Self::OutStack, EvaluateError>;
}
/// Applies accumulated gradients to the network's weights.
pub trait UpdateWeight<U> where U: UnitValue<U> {
/// Stack of per-layer gradients produced by a backward pass.
type GradientStack: Stack + Debug + Sized;
/// Consumes the gradient stack and updates the weights in place.
fn update_weight(&mut self, stack:Self::GradientStack) -> Result<(), TrainingError>;
}
/// Forward pass used for differential evaluation; blanket-implemented below
/// for every `PreTrain` type by delegating to `pre_train`.
pub trait ForwardDiff<U>: PreTrain<U> where U: UnitValue<U> {
/// Runs the forward pass, returning the recorded output stack.
fn forward_diff(&self, input:Self::Input) -> Result<Self::OutStack, EvaluateError>;
}
/// One complete single-sample training step (forward, loss, backward, update).
pub trait Train<U,L>: PreTrain<U>
where U: UnitValue<U> {
/// Trains on one (`input`, `expected`) pair, returning the scalar loss.
fn train(&mut self, expected:Self::Output, input:Self::Input, lossf:&L) -> Result<U, TrainingError>;
}
/// Extracts the differential-input representation from a recorded stack.
pub trait AskDiffInput<U>: PreTrain<U> where U: UnitValue<U> {
/// Type of the extracted differential input.
type DiffInput: Debug;
/// Reads the differential input out of `stack` without consuming it.
fn ask_diff_input(&self, stack: &Self::OutStack) -> Result<Self::DiffInput,TypeConvertError>;
}
/// Associated types for the batched counterpart of [`ForwardAll`].
pub trait BatchForwardBase: ForwardAll {
/// Batched input type (whole mini-batch).
type BatchInput: Debug;
/// Batched output type (whole mini-batch).
type BatchOutput: Debug;
}
/// Forward propagation over an entire mini-batch.
pub trait BatchForward: BatchForwardBase {
/// Runs the forward pass for the whole batch.
fn batch_forward(&self,input:Self::BatchInput) -> Result<Self::BatchOutput, TrainingError>;
}
/// Backward propagation over an entire mini-batch; batched counterpart of [`BackwardAll`].
pub trait BatchBackward<U>: BatchPreTrainBase<U> + UpdateWeight<U> where U: UnitValue<U> {
/// Batched loss fed in at the output end of the chain.
type BatchLossInput: Debug;
/// Batched loss emitted at the input end of the chain.
type BatchLossOutput: Debug;
/// Runs the batched backward pass over `stack`, returning the propagated
/// loss together with the gradients to apply via `UpdateWeight`.
fn batch_backward<L: LossFunction<U>>(&mut self, input:Self::BatchLossInput, stack:Self::BatchOutStack, lossf:&L)
-> Result<(<Self as BatchBackward<U>>::BatchLossOutput,<Self as UpdateWeight<U>>::GradientStack), TrainingError>;
}
/// Hook for applying a loss function over a batch before backpropagation.
pub trait BatchLoss<U>: BatchBackward<U> + Loss<U> where U: UnitValue<U> {
/// Transforms the incoming batched loss and stack; the default is a
/// pass-through that returns both unchanged.
fn batch_loss<L: LossFunction<U>>(&self, loss:Self::BatchLossInput, _lossf:&L, stack:Self::BatchOutStack) -> Result<(Self::BatchOutStack, Self::BatchLossInput), TrainingError> {
Ok((stack,loss))
}
}
/// Associated types for the batched counterpart of [`PreTrain`].
pub trait BatchPreTrainBase<U>: BatchForwardBase + PreTrain<U> where U: UnitValue<U> {
/// Batched output type stored at the head of the recorded stack.
type BatchPreOutput: Debug + 'static;
/// Stack of batched intermediate outputs, headed by `BatchPreOutput`.
type BatchOutStack: Stack<Head=Self::BatchPreOutput> + Sized + Debug;
}
/// Batched forward pass that records intermediate outputs for backpropagation;
/// batched counterpart of the `pre_train` step.
///
/// Fix: removed the stray trailing `+` that was left dangling before the
/// `where` clause in the supertrait list (accepted by the grammar but
/// clearly accidental).
pub trait BatchPreTrain<U>: BatchPreTrainBase<U> + BatchForwardBase + BatchForward where U: UnitValue<U> {
/// Runs the forward pass for the whole batch, returning the stack of
/// recorded intermediate outputs.
fn batch_pre_train(&self, input:Self::BatchInput) -> Result<Self::BatchOutStack, TrainingError>;
}
/// One complete mini-batch training step (forward, loss, backward, update).
///
/// NOTE(review): the device parameter `D` appears only in the bounds, not in
/// any method signature — presumably it pins the implementation to a specific
/// compute device; confirm against implementors.
pub trait BatchTrain<U,D,L>: BatchPreTrainBase<U> + BatchPreTrain<U> + BatchBackward<U> + PreTrain<U>
where U: UnitValue<U>, D: Device<U>,
L: LossFunction<U> {
/// Trains on one batch of (`input`, `expected`) pairs, returning the scalar loss.
fn batch_train(&mut self, expected:Self::BatchOutput, input:Self::BatchInput, lossf:&L) -> Result<U, TrainingError>;
}
/// Builder combinator: consume the network built so far and wrap it in the
/// next layer via the supplied constructor closure.
pub trait AddLayer: ForwardAll where Self: Sized {
/// Applies `f` to `self`, producing the extended network `C`.
fn add_layer<C,F>(self,f:F) -> C where C: ForwardAll, F: FnOnce(Self) -> C;
}
/// Blanket impl: any `ForwardAll` network can be extended by applying the closure.
impl<T> AddLayer for T where T: ForwardAll + Sized {
fn add_layer<C, F>(self, f: F) -> C where C: ForwardAll, F: FnOnce(Self) -> C {
f(self)
}
}
/// Fallible builder combinator: like [`AddLayer`] but the layer constructor
/// may fail with an error `E`.
pub trait TryAddLayer: ForwardAll where Self: Sized {
/// Applies `f` to `self`, producing the extended network or the construction error.
fn try_add_layer<C,F,E>(self,f:F) -> Result<C,E> where C: ForwardAll, F: FnOnce(Self) -> Result<C,E>;
}
/// Blanket impl: any `ForwardAll` network can be fallibly extended by applying the closure.
impl<T> TryAddLayer for T where T: ForwardAll + Sized {
fn try_add_layer<C,F,E>(self, f: F) -> Result<C,E> where C: ForwardAll, F: FnOnce(Self) -> Result<C,E> {
f(self)
}
}
/// Blanket impl: every `PreTrain` type gets `ForwardDiff` for free by
/// delegating directly to `pre_train`.
impl<T,U> ForwardDiff<U> for T where T: PreTrain<U> + Sized, U: UnitValue<U> {
fn forward_diff(&self, input: Self::Input) -> Result<Self::OutStack, EvaluateError> {
self.pre_train(input)
}
}