use std::fmt::Debug;
use std::marker::PhantomData;
use std::str::FromStr;
use crate::Stack;
use crate::arr::{Arr, SerializedVec};
use crate::cuda::ToHost;
use crate::device::{Device};
use crate::device::output::DeviceLinearOutput;
use crate::error::{ConfigReadError, EvaluateError, PersistenceError, SizeMismatchError, TrainingError, TypeConvertError};
use crate::layer::{AskDiffInput, BackwardAll, BatchBackward, BatchDataType, BatchForward, BatchForwardBase, BatchLoss, BatchPreTrain, BatchPreTrainBase, BatchSize, BatchTrain, ForwardAll, Loss, PreTrain, Train, UpdateWeight};
use crate::lossfunction::{BatchLossFunctionLinear, LossFunction, LossFunctionLinear};
use crate::ope::UnitValue;
use crate::persistence::{Linear, LinearPersistence, Persistence, Specialized, TextFilePersistence};
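/// Output layer that terminates a network with a linear output of `N` units.
/// It wraps a parent layer stack `P` and a device `D`, and converts the
/// parent's output `PI` into a host-side `Arr<U,N>` for callers.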
pub struct LinearOutputLayer<U,P,D,I,PI,const N:usize>
where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
U: Default + Clone + Copy + UnitValue<U>,
D: Device<U>,
PI: Debug + 'static,
I: Debug + Send + Sync {
u:PhantomData<U>,
i:PhantomData<I>,
io:PhantomData<PI>,
n:PhantomData<[();N]>,
parent:P,
device:D,
}
impl<U,P,D,I,PI,const N:usize> LinearOutputLayer<U,P,D,I,PI,N>
where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
U: Default + Clone + Copy + UnitValue<U>,
D: Device<U>,
PI: Debug + 'static,
I: Debug + Send + Sync {
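/// Creates and returns an instance of `LinearOutputLayer`.
///
/// # Arguments
/// * `parent` - parent layer whose output this layer exposes
/// * `device` - device used for loss computation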
pub fn new(parent:P,device:&D) -> LinearOutputLayer<U,P,D,I,PI,N> {
LinearOutputLayer {
u:PhantomData::<U>,
i:PhantomData::<I>,
io:PhantomData::<PI>,
n:PhantomData::<[();N]>,
parent,
device:device.clone(),
}
}
}
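// Persistence is delegated to the parent; loading additionally verifies
// that the whole persisted stream has been consumed.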
impl<U,P,D,I,PI,const N:usize> Persistence<U,TextFilePersistence<U>,Specialized> for LinearOutputLayer<U,P,D,I,PI,N>
where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> +
PreTrain<U,PreOutput=PI> + Loss<U> + Persistence<U,TextFilePersistence<U>,Specialized>,
U: Default + Clone + Copy + UnitValue<U> + FromStr,
D: Device<U>,
PI: Debug + 'static,
I: Debug + Send + Sync {
fn load(&mut self, persistence: &mut TextFilePersistence<U>) -> Result<(),ConfigReadError> {
self.parent.load(persistence)?;
persistence.verify_eof()
}
fn save(&mut self, persistence: &mut TextFilePersistence<U>) -> Result<(), PersistenceError> {
self.parent.save(persistence)
}
}
impl<T,U,P,D,I,PI,const N:usize> Persistence<U,T,Linear> for LinearOutputLayer<U,P,D,I,PI,N>
where T: LinearPersistence<U>,
P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> +
PreTrain<U,PreOutput=PI> + Loss<U> + Persistence<U,T,Linear>,
U: Default + Clone + Copy + UnitValue<U>,
D: Device<U>,
PI: Debug + 'static,
I: Debug + Send + Sync {
fn load(&mut self, persistence: &mut T) -> Result<(),ConfigReadError> {
self.parent.load(persistence)?;
persistence.verify_eof()
}
fn save(&mut self, persistence: &mut T) -> Result<(), PersistenceError> {
self.parent.save(persistence)
}
}
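// Inference: run the parent layers forward and transfer the device-resident
// result back to host memory.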
impl<U,P,D,I,PI,const N:usize> ForwardAll for LinearOutputLayer<U,P,D,I,PI,N>
where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
U: Default + Clone + Copy + UnitValue<U>,
PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
I: Debug + Send + Sync,
<PI as ToHost<U>>::Output: Debug + 'static,
for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
type Input = I;
type Output = <PI as ToHost<U>>::Output;
fn forward_all(&self, input: Self::Input) -> Result<Self::Output, EvaluateError> {
Ok(self.parent.forward_all(input)?.to_host()?)
}
}
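// Pre-training leaves the output in the parent's representation (`PI`) so
// the loss can be evaluated without converting to a host array first.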
impl<U,P,D,I,PI,const N:usize> PreTrain<U> for LinearOutputLayer<U,P,D,I,PI,N>
where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
U: Default + Clone + Copy + UnitValue<U>,
PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
I: Debug + Send + Sync,
<PI as ToHost<U>>::Output: Debug + 'static,
for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
type PreOutput = PI;
type OutStack = P::OutStack;
fn pre_train(&self, input: Self::Input) -> Result<Self::OutStack, EvaluateError> {
self.parent.pre_train(input)
}
}
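// Backpropagation, weight updates, and diff-input queries are pure
// delegation to the parent layer; this layer holds no weights of its own.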
impl<U,P,D,I,PI,const N:usize> BackwardAll<U> for LinearOutputLayer<U,P,D,I,PI,N>
where P: BackwardAll<U,LossInput=PI> +
ForwardAll<Input=I,Output=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
U: Default + Clone + Copy + UnitValue<U>,
PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
I: Debug + Send + Sync,
<PI as ToHost<U>>::Output: Debug + 'static,
for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
type LossInput = PI;
type LossOutput = <P as BackwardAll<U>>::LossOutput;
fn backward_all<L: LossFunction<U>>(&mut self, input: Self::LossInput, stack:Self::OutStack, lossf:&L)
-> Result<(<Self as BackwardAll<U>>::LossOutput,<Self as UpdateWeight<U>>::GradientStack), TrainingError> {
self.parent.backward_all(input, stack, lossf)
}
}
impl<U,P,D,I,PI,const N:usize> UpdateWeight<U> for LinearOutputLayer<U,P,D,I,PI,N>
where P: BackwardAll<U,LossInput=PI> +
ForwardAll<Input=I,Output=PI> +
PreTrain<U,PreOutput=PI> + Loss<U> + UpdateWeight<U>,
U: Default + Clone + Copy + UnitValue<U>,
PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
I: Debug + Send + Sync,
<PI as ToHost<U>>::Output: Debug + 'static,
for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
type GradientStack = <P as UpdateWeight<U>>::GradientStack;
fn update_weight(&mut self, stack: Self::GradientStack) -> Result<(), TrainingError> {
Ok(self.parent.update_weight(stack)?)
}
}
impl<U,P,D,I,PI,const N:usize> AskDiffInput<U> for LinearOutputLayer<U,P,D,I,PI,N>
where P: BackwardAll<U,LossInput=PI> +
ForwardAll<Input=I,Output=PI> + PreTrain<U,PreOutput=PI> + Loss<U> + AskDiffInput<U>,
U: Default + Clone + Copy + UnitValue<U>,
PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
I: Debug + Send + Sync,
<PI as ToHost<U>>::Output: Debug + 'static,
for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
type DiffInput = P::DiffInput;
fn ask_diff_input(&self, stack: &Self::OutStack) -> Result<Self::DiffInput,TypeConvertError> {
self.parent.ask_diff_input(stack)
}
}
impl<U,P,D,I,PI,L,const N:usize> Train<U,L> for LinearOutputLayer<U,P,D,I,PI,N>
where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
U: Default + Clone + Copy + UnitValue<U>,
PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
I: Debug + Send + Sync,
<PI as ToHost<U>>::Output: Debug + 'static,
for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI>,
for<'a> L: LossFunction<U> + LossFunctionLinear<'a,U,PI,D,N,Output=PI> {
fn train(&mut self, expected: Self::Output, input: Self::Input, lossf: &L) -> Result<U, TrainingError> {
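// Forward pass that records each layer's intermediate output on the stack.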
let stack = self.pre_train(input)?;
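// Total loss across the N output units, returned to the caller for reporting.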
let total_loss = stack.map(|actual| self.device.loss_linear_total(&expected,actual,lossf))?;
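// If the loss function is the canonical link of the output activation, the
// combined derivative simplifies (e.g. softmax with cross-entropy yields
// actual - expected), so the gradient is computed directly instead of
// chaining the activation's derivative.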
let (stack,loss) = if self.parent.is_canonical_link(lossf) {
let loss = stack.map(|actual| {
self.device.loss_linear_by_canonical_link(&expected, &actual)
})?;
(stack,loss)
} else {
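// Otherwise differentiate the loss itself and let the parent apply the
// activation's derivative via loss().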
let loss = stack.map(|actual| {
self.device.loss_linear(&expected,&actual,lossf)
})?;
self.parent.loss(loss,lossf,stack)?
};
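// Backpropagate the gradient through the parent stack, then apply the
// accumulated weight updates.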
let (_,s) = self.backward_all(loss,stack,lossf)?;
self.parent.update_weight(s)?;
Ok(total_loss)
}
}
impl<U,P,D,I,PI,const N:usize> BatchForwardBase for LinearOutputLayer<U,P,D,I,PI,N>
where P: PreTrain<U,PreOutput=PI> + ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + Loss<U> +
BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type>,
U: Default + Clone + Copy + Send + UnitValue<U>,
PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
I: Debug + Send + Sync + BatchDataType,
<PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
<PI as ToHost<U>>::Output: Debug + 'static,
<I as BatchDataType>::Type: Debug,
<<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
type BatchInput = <I as BatchDataType>::Type;
type BatchOutput = <<PI as BatchDataType>::Type as ToHost<U>>::Output;
}
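// Batched inference over a whole mini-batch, with the result copied back to
// the host.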
impl<U,P,D,I,PI,const N:usize> BatchForward for LinearOutputLayer<U,P,D,I,PI,N>
where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward,
U: Default + Clone + Copy + Send + UnitValue<U>,
PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
I: Debug + Send + Sync + BatchDataType,
<PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
<PI as ToHost<U>>::Output: Debug + 'static,
<I as BatchDataType>::Type: Debug,
<<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
fn batch_forward(&self, input: Self::BatchInput) -> Result<Self::BatchOutput, TrainingError> {
Ok(self.parent.batch_forward(input)?.to_host()?)
}
}
impl<U,P,D,I,PI,const N:usize> BatchPreTrainBase<U> for LinearOutputLayer<U,P,D,I,PI,N>
where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward +
BatchPreTrainBase<U,BatchPreOutput=<PI as BatchDataType>::Type>,
U: Default + Clone + Copy + Send + UnitValue<U>,
PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
I: Debug + Send + Sync + BatchDataType,
<PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
<PI as ToHost<U>>::Output: Debug + 'static,
<I as BatchDataType>::Type: Debug,
<<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
type BatchPreOutput = <PI as BatchDataType>::Type;
type BatchOutStack = P::BatchOutStack;
}
impl<U,P,D,I,PI,const N:usize> BatchPreTrain<U> for LinearOutputLayer<U,P,D,I,PI,N>
where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward +
BatchPreTrainBase<U,BatchPreOutput=<PI as BatchDataType>::Type> + BatchPreTrain<U>,
U: Default + Clone + Copy + Send + UnitValue<U>,
PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
I: Debug + Send + Sync + BatchDataType,
<PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
<PI as ToHost<U>>::Output: Debug + 'static,
<I as BatchDataType>::Type: Debug,
<<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
fn batch_pre_train(&self, input: Self::BatchInput) -> Result<Self::BatchOutStack, TrainingError> {
self.parent.batch_pre_train(input)
}
}
impl<U,P,D,I,PI,const N:usize> BatchBackward<U> for LinearOutputLayer<U,P,D,I,PI,N>
where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward +
BatchPreTrainBase<U,BatchPreOutput=<PI as BatchDataType>::Type> + BatchPreTrain<U> +
BatchBackward<U> + UpdateWeight<U> + BatchLoss<U,BatchLossInput=<PI as BatchDataType>::Type>,
U: Default + Clone + Copy + Send + UnitValue<U>,
PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
I: Debug + Send + Sync + BatchDataType,
<PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
<PI as ToHost<U>>::Output: Debug + 'static,
<I as BatchDataType>::Type: Debug + BatchSize,
<<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
type BatchLossInput = <PI as BatchDataType>::Type;
type BatchLossOutput = <P as BatchBackward<U>>::BatchLossOutput;
fn batch_backward<L: LossFunction<U>>(&mut self, input: Self::BatchLossInput, stack: Self::BatchOutStack, lossf: &L)
-> Result<(<Self as BatchBackward<U>>::BatchLossOutput,<Self as UpdateWeight<U>>::GradientStack), TrainingError> {
self.parent.batch_backward(input,stack,lossf)
}
}
impl<U,P,D,I,PI,L,const N:usize> BatchTrain<U,D,L> for LinearOutputLayer<U,P,D,I,PI,N>
where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward +
BatchPreTrainBase<U,BatchPreOutput=<PI as BatchDataType>::Type> + BatchPreTrain<U> +
BatchBackward<U> + UpdateWeight<U> + BatchLoss<U,BatchLossInput=<PI as BatchDataType>::Type>,
U: Default + Clone + Copy + Send + UnitValue<U>,
PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
I: Debug + Send + Sync + BatchDataType,
<PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
<PI as ToHost<U>>::Output: Debug + 'static,
<I as BatchDataType>::Type: Debug + BatchSize,
<<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI,BatchIO=<PI as BatchDataType>::Type>,
f64: From<U>,
Self: UpdateWeight<U,GradientStack = <P as UpdateWeight<U>>::GradientStack>,
for<'a> L: LossFunction<U> + BatchLossFunctionLinear<'a,U,<PI as BatchDataType>::Type,D,N,Output=<PI as BatchDataType>::Type> {
fn batch_train(&mut self, expected:Self::BatchOutput, input:Self::BatchInput, lossf:&L) -> Result<U, TrainingError> {
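// The number of expected rows must match the mini-batch size.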
if expected.len() != input.size() {
return Err(TrainingError::from(SizeMismatchError(expected.len(),input.size())));
}
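// Batched forward pass, recording intermediate outputs on the stack.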
let stack = self.batch_pre_train(input)?;
let total_loss = stack.map(|actual| self.device.batch_loss_linear_total(&expected,actual,lossf))?;
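// Same canonical-link shortcut as in train, applied sample-wise over the
// mini-batch.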
let (stack,loss) = if self.parent.is_canonical_link(lossf) {
let loss = stack.map(|actual| {
self.device.loss_linear_batch_by_canonical_link(&expected, &actual)
})?;
(stack,loss)
} else {
let loss = stack.map(|actual| {
self.device.batch_loss_linear(&expected,actual,lossf)
})?;
self.parent.batch_loss(loss,lossf,stack)?
};
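// Backpropagate the per-sample gradients through the parent and update its
// weights.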
let (_,s) = self.parent.batch_backward(loss,stack,lossf)?;
self.parent.update_weight(s)?;
Ok(total_loss)
}
}