use crate::activation::{Activation, BatchActivation};
use crate::arr::{Arr, ArrView, IntoConverter, SerializedVec, SerializedVecView};
use crate::cuda::{CudaTensor1dPtr, CudaTensor1dPtrView, CudaVec, CudaVecView};
use crate::device::{Device, DeviceCpu, DeviceGpu};
use crate::error::{EvaluateError, TrainingError, TypeConvertError};
use crate::layer::BatchDataType;
use crate::lossfunction::LossFunction;
use crate::ope::UnitValue;
/// Device-specific dispatch of an activation function `A`.
///
/// An implementor (the device, `Self`) runs the forward pass (`apply`) and the
/// backward pass (`derive`) of the activation `f` on data of type `I`. It also
/// offers batched variants working on `<I as BatchDataType>::Type`.
///
/// `N` is a fixed-size const parameter that implementations use to choose their
/// native per-sample representation (e.g. `Arr<U,N>` or `CudaTensor1dPtr<U,N>`).
pub trait DeviceActivation<U, I, A, const N: usize>: Device<U>
where
    U: UnitValue<U>,
    I: BatchDataType,
{
    /// Runs the activation `f` forward on a single `input`.
    ///
    /// # Errors
    /// Returns an [`EvaluateError`] when the computation or a conversion fails.
    fn apply(&self, f: &A, input: &I) -> Result<I, EvaluateError>;

    /// Runs the backward pass of `f` for a single sample.
    ///
    /// `o`, `loss` and `u` are handed through to the activation. Presumably they are
    /// the activation output, the incoming gradient and the pre-activation input —
    /// TODO confirm against `Activation::derive`.
    ///
    /// # Errors
    /// Returns a [`TrainingError`] when the computation or a conversion fails.
    fn derive(&self, f: &A, o: &I, loss: &I, u: &I) -> Result<I, TrainingError>;

    /// Batched counterpart of [`DeviceActivation::apply`].
    ///
    /// # Errors
    /// Returns a [`TrainingError`] when the computation or a conversion fails.
    fn batch_apply(
        &self,
        f: &A,
        input: &<I as BatchDataType>::Type,
    ) -> Result<<I as BatchDataType>::Type, TrainingError>;

    /// Batched counterpart of [`DeviceActivation::derive`].
    ///
    /// # Errors
    /// Returns a [`TrainingError`] when the computation or a conversion fails.
    fn batch_derive(
        &self,
        f: &A,
        o: &<I as BatchDataType>::Type,
        loss: &<I as BatchDataType>::Type,
        u: &<I as BatchDataType>::Type,
    ) -> Result<<I as BatchDataType>::Type, TrainingError>;

    /// Reports whether the activation `f` and the loss function `l` form a
    /// canonical link pair.
    fn is_canonical_link<L: LossFunction<U>>(&self, f: &A, l: &L) -> bool;
}
impl<'a,U,I,A,const N:usize> DeviceActivation<U,I,A,N> for DeviceCpu<U>
where U: UnitValue<U>,
I: BatchDataType,
I: From<Arr<U,N>>,
SerializedVec<U,Arr<U,N>>: IntoConverter,
<I as BatchDataType>::Type: TryFrom<<SerializedVec<U,Arr<U,N>> as IntoConverter>::Converter,Error=TypeConvertError>,
for<'b> A: Activation<U,ArrView<'b,U,N>,Arr<U,N>,Self>,
for<'b> A: BatchActivation<U,SerializedVecView<'b,U,Arr<U,N>>,SerializedVec<U,Arr<U,N>>,Self>,
for<'b> ArrView<'b,U,N>: From<&'b I>,
for<'b> SerializedVecView<'b,U,Arr<U,N>>: TryFrom<&'b <I as BatchDataType>::Type,Error=TypeConvertError> {
#[inline]
fn apply(&self, f: &A, input: &I) -> Result<I, EvaluateError> {
Ok(f.apply(self, &input.into())?.into())
}
#[inline]
fn derive(&self, f: &A, o: &I, loss: &I, u: &I) -> Result<I, TrainingError> {
Ok(f.derive(self, &o.into(), &loss.into(), &u.into())?.into())
}
#[inline]
fn batch_apply(&self, f: &A, input: &<I as BatchDataType>::Type) -> Result<<I as BatchDataType>::Type, TrainingError> {
Ok(f.batch_apply(self, &input.try_into()?)?.into_converter().try_into()?)
}
#[inline]
fn batch_derive(&self, f: &A, o: &<I as BatchDataType>::Type, loss: &<I as BatchDataType>::Type, u: &<I as BatchDataType>::Type)
-> Result<<I as BatchDataType>::Type, TrainingError> {
Ok(f.batch_derive(self, &o.try_into()?, &loss.try_into()?, &u.try_into()?).unwrap().into_converter().try_into()?)
}
fn is_canonical_link<L: LossFunction<U>>(&self, f: &A, l: &L) -> bool {
f.is_canonical_link(l)
}
}
impl<'a,U,I,A,const N:usize> DeviceActivation<U,I,A,N> for DeviceGpu<U>
where U: UnitValue<U>,
I: BatchDataType,
I: From<CudaTensor1dPtr<U,N>>,
DeviceGpu<U>: Device<U>,
CudaTensor1dPtr<U,N>: From<I>,
CudaVec<U,CudaTensor1dPtr<U,N>>: IntoConverter,
<I as BatchDataType>::Type: TryFrom<<CudaVec<U,CudaTensor1dPtr<U,N>> as IntoConverter>::Converter,Error=TypeConvertError>,
for<'b> A: Activation<U,CudaTensor1dPtrView<'b,U,N>,CudaTensor1dPtr<U,N>,Self>,
for<'b> A: BatchActivation<U,CudaVecView<'b,U,CudaTensor1dPtr<U,N>>,CudaVec<U,CudaTensor1dPtr<U,N>>,Self>,
for<'b> CudaTensor1dPtrView<'b,U,N>: From<&'b I>,
for<'b> CudaVecView<'b,U,CudaTensor1dPtr<U,N>>: TryFrom<&'b <I as BatchDataType>::Type,Error=TypeConvertError> {
#[inline]
fn apply(&self, f: &A, input: &I) -> Result<I, EvaluateError> {
Ok(f.apply(self, &input.into())?.into())
}
#[inline]
fn derive(&self, f: &A, o: &I, loss: &I, u: &I) -> Result<I, TrainingError> {
Ok(f.derive(self, &o.into(), &loss.into(), &u.into())?.into())
}
#[inline]
fn batch_apply(&self, f: &A, input: &<I as BatchDataType>::Type) -> Result<<I as BatchDataType>::Type, TrainingError> {
Ok(f.batch_apply(self, &input.try_into()?)?.into_converter().try_into()?)
}
#[inline]
fn batch_derive(&self, f: &A, o: &<I as BatchDataType>::Type, loss: &<I as BatchDataType>::Type, u: &<I as BatchDataType>::Type)
-> Result<<I as BatchDataType>::Type, TrainingError> {
Ok(f.batch_derive(self, &o.try_into()?, &loss.try_into()?, &u.try_into()?).unwrap().into_converter().try_into()?)
}
fn is_canonical_link<L: LossFunction<U>>(&self, f: &A, l: &L) -> bool {
f.is_canonical_link(l)
}
}