/// Features defining the implementation of the various computational
/// processes in the batch normalization layer.
///
/// Type parameters:
/// * `U` - scalar unit value type (used e.g. for the `momentum` factor)
/// * `C` - container type for per-channel parameters and statistics
///   (scale γ, bias β, means, variances)
/// * `I` - input type for a single sample; `<I as BatchDataType>::Type`
///   is the corresponding batched input type
/// * `N` - const generic; presumably the number of channels/features —
///   TODO confirm (it is not used in any signature visible here)
pub trait DeviceBatchNorm<U, C, I, const N: usize>where
U: UnitValue<U>,
I: BatchDataType + Debug + 'static,
<I as BatchDataType>::Type: Debug + 'static,{
// Required methods

/// Inference-mode forward pass for a single (unbatched) input.
/// Normalizes `input` with the supplied estimated statistics, then
/// applies scale (γ) and bias (β), returning the normalized output.
fn forward_batch_norm<'a>(
&self,
input: &'a I,
scale: &C,
bias: &C,
estimated_mean: &C,
estimated_variance: &C,
) -> Result<I, EvaluateError>;

/// Training-mode forward pass for a single (unbatched) input.
/// Returns `(output, C, C)`; the two `C` values are presumably the
/// saved mean and saved inverse variance consumed by
/// `backward_batch_norm` — confirm against the implementation.
fn forward_batch_norm_train<'a>(
&self,
input: &'a I,
scale: &C,
bias: &C,
estimated_mean: &C,
estimated_variance: &C,
) -> Result<(I, C, C), EvaluateError>;

/// Inference-mode forward pass over a batch of inputs, using the
/// estimated (running) statistics rather than batch statistics.
fn batch_forward_batch_norm<'a>(
&self,
input: &'a <I as BatchDataType>::Type,
scale: &C,
bias: &C,
estimated_mean: &C,
estimated_variance: &C,
) -> Result<<I as BatchDataType>::Type, EvaluateError>;

/// Training-mode forward pass over a batch. Computes the batch mean
/// μΒ and variance σ²Β, updates the running statistics with the
/// given `momentum`, and returns
/// `(output, batch mean, batch inverse variance, running_mean, running_variance)`.
fn batch_forward_batch_norm_train<'a>(
&self,
input: &'a <I as BatchDataType>::Type,
scale: &C,
bias: &C,
running_mean: &C,
running_variance: &C,
momentum: U,
) -> Result<(<I as BatchDataType>::Type, C, C, C, C), TrainingError>;

/// Backward pass for a single (unbatched) input, given the mean and
/// inverse variance saved by the training-mode forward pass.
/// Returns `(input gradient, C, C)`; the two `C` values are
/// presumably the scale (γ) and bias (β) gradients — confirm.
fn backward_batch_norm<'a>(
&self,
loss: &'a I,
input: &'a I,
scale: &C,
saved_mean: &C,
saved_inv_variance: &C,
) -> Result<(I, C, C), TrainingError>;

/// Backward pass over a batch, given the saved batch statistics.
/// Returns `(input gradient, C, C)`; the two `C` values are
/// presumably the scale (γ) and bias (β) gradients — confirm.
fn batch_backward_batch_norm<'a>(
&self,
loss: &'a <I as BatchDataType>::Type,
input: &'a <I as BatchDataType>::Type,
scale: &C,
saved_mean: &C,
saved_inv_variance: &C,
) -> Result<(<I as BatchDataType>::Type, C, C), TrainingError>;
}

Features defining the implementation of the various computational processes in the batch normalization layer.

Required Methods
fn forward_batch_norm<'a>(
&self,
input: &'a I,
scale: &C,
bias: &C,
estimated_mean: &C,
estimated_variance: &C,
) -> Result<I, EvaluateError>
fn forward_batch_norm<'a>( &self, input: &'a I, scale: &C, bias: &C, estimated_mean: &C, estimated_variance: &C, ) -> Result<I, EvaluateError>
fn forward_batch_norm_train<'a>(
&self,
input: &'a I,
scale: &C,
bias: &C,
estimated_mean: &C,
estimated_variance: &C,
) -> Result<(I, C, C), EvaluateError>
fn forward_batch_norm_train<'a>( &self, input: &'a I, scale: &C, bias: &C, estimated_mean: &C, estimated_variance: &C, ) -> Result<(I, C, C), EvaluateError>
fn batch_forward_batch_norm<'a>(
&self,
input: &'a <I as BatchDataType>::Type,
scale: &C,
bias: &C,
estimated_mean: &C,
estimated_variance: &C,
) -> Result<<I as BatchDataType>::Type, EvaluateError>
fn batch_forward_batch_norm<'a>( &self, input: &'a <I as BatchDataType>::Type, scale: &C, bias: &C, estimated_mean: &C, estimated_variance: &C, ) -> Result<<I as BatchDataType>::Type, EvaluateError>
fn batch_forward_batch_norm_train<'a>(
&self,
input: &'a <I as BatchDataType>::Type,
scale: &C,
bias: &C,
running_mean: &C,
running_variance: &C,
momentum: U,
) -> Result<(<I as BatchDataType>::Type, C, C, C, C), TrainingError>
fn batch_forward_batch_norm_train<'a>( &self, input: &'a <I as BatchDataType>::Type, scale: &C, bias: &C, running_mean: &C, running_variance: &C, momentum: U, ) -> Result<(<I as BatchDataType>::Type, C, C, C, C), TrainingError>
Forward-propagation calculation over a batch (training-mode implementation).
Arguments

* input - input
* scale - γ
* bias - β
* running_mean - μΒ (running mean)
* running_variance - σ²Β (running variance)
* momentum - momentum coefficient used to update the running statistics

running_mean = running_mean * momentum + (1 - momentum) * μΒ
running_variance = running_variance * momentum + (1 - momentum) * σ²Β

output = (γ * ((input - μΒ) / sqrt(σ²Β + 1e-6)) + β, μΒ, 1 / sqrt(σ²Β + 1e-6), running_mean, running_variance)
Errors
This function may return the following errors