Struct dfdx::nn::modules::BatchNorm2D
source · pub struct BatchNorm2D<const C: usize, E: Dtype, D: Storage<E>> {
pub scale: Tensor<Rank1<C>, E, D>,
pub bias: Tensor<Rank1<C>, E, D>,
pub running_mean: Tensor<Rank1<C>, E, D>,
pub running_var: Tensor<Rank1<C>, E, D>,
pub epsilon: f64,
pub momentum: f64,
}
Expand description
Batch normalization for images, as described in *Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift* (Ioffe & Szegedy, 2015).
Generics:
C
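the number of channels, i.e. the dimension that is *not* reduced. One set of statistics and affine parameters is kept per channel; all other dimensions are reduced over. For 3d tensors this is the 0th dimension. For 4d tensors, this is the 1st dimension.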
Training vs Inference
BatchNorm2D supports the following cases (see sections below for more details):
- Training: `ModuleMut` with an `OwnedTape` on the input tensor.
- Inference: `Module` with a `NoneTape` on the input tensor.
NOTE: the combinations `ModuleMut`/`NoneTape` and `Module`/`OwnedTape` will fail to compile.
Examples:
type Model = BatchNorm2D<3>;
let bn = dev.build_module::<Model, f32>();
let _ = bn.forward(dev.zeros::<Rank3<3, 2, 2>>());
let _ = bn.forward(dev.zeros::<Rank4<4, 3, 2, 2>>());
Training
- Running statistics: updated with momentum
- Normalization: calculated using batch stats
Inference
- Running statistics: not updated
- Normalization: calculated using running stats
Fields§
§scale: Tensor<Rank1<C>, E, D>
Scale for affine transform. Defaults to 1.0
bias: Tensor<Rank1<C>, E, D>
Bias for affine transform. Defaults to 0.0
running_mean: Tensor<Rank1<C>, E, D>
Per-channel running mean that is updated during training. Defaults to 0.0
running_var: Tensor<Rank1<C>, E, D>
Per-channel running variance that is updated during training. Defaults to 1.0
epsilon: f64
Added to variance before taking sqrt for numerical stability. Defaults to 1e-5
momentum: f64
Controls the exponential moving average of the running stats. Defaults to 0.1. Each training step updates a running statistic as:
`running_stat * (1.0 - momentum) + stat * momentum`.
Trait Implementations§
source§impl<const C: usize, E: Clone + Dtype, D: Clone + Storage<E>> Clone for BatchNorm2D<C, E, D>
impl<const C: usize, E: Clone + Dtype, D: Clone + Storage<E>> Clone for BatchNorm2D<C, E, D>
source§fn clone(&self) -> BatchNorm2D<C, E, D>
fn clone(&self) -> BatchNorm2D<C, E, D>
1.0.0 · source§fn clone_from(&mut self, source: &Self)
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from `source`. Read more
source§impl<B: Dim, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>> Module<Tensor<(B, Const<C>, H, W), E, D, NoneTape>> for BatchNorm2D<C, E, D>
impl<B: Dim, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>> Module<Tensor<(B, Const<C>, H, W), E, D, NoneTape>> for BatchNorm2D<C, E, D>
source§fn try_forward(
&self,
x: Tensor<(B, Const<C>, H, W), E, D, NoneTape>
) -> Result<Self::Output, D::Err>
fn try_forward( &self, x: Tensor<(B, Const<C>, H, W), E, D, NoneTape> ) -> Result<Self::Output, D::Err>
Inference 4d forward - does not update Self::running_mean and Self::running_var
§type Output = Tensor<(B, Const<C>, H, W), E, D, NoneTape>
type Output = Tensor<(B, Const<C>, H, W), E, D, NoneTape>
The type that this unit produces given `Input`.
§type Error = <D as HasErr>::Err
source§impl<const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>> Module<Tensor<(Const<C>, H, W), E, D, NoneTape>> for BatchNorm2D<C, E, D>
impl<const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>> Module<Tensor<(Const<C>, H, W), E, D, NoneTape>> for BatchNorm2D<C, E, D>
source§fn try_forward(
&self,
x: Tensor<(Const<C>, H, W), E, D, NoneTape>
) -> Result<Self::Output, D::Err>
fn try_forward( &self, x: Tensor<(Const<C>, H, W), E, D, NoneTape> ) -> Result<Self::Output, D::Err>
Inference 3d forward - does not update Self::running_mean and Self::running_var
§type Output = Tensor<(Const<C>, H, W), E, D, NoneTape>
type Output = Tensor<(Const<C>, H, W), E, D, NoneTape>
The type that this unit produces given `Input`.
§type Error = <D as HasErr>::Err
source§impl<B: Dim, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>> ModuleMut<Tensor<(B, Const<C>, H, W), E, D, OwnedTape<E, D>>> for BatchNorm2D<C, E, D>
impl<B: Dim, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>> ModuleMut<Tensor<(B, Const<C>, H, W), E, D, OwnedTape<E, D>>> for BatchNorm2D<C, E, D>
source§fn try_forward_mut(
&mut self,
x: Tensor<(B, Const<C>, H, W), E, D, OwnedTape<E, D>>
) -> Result<Self::Output, D::Err>
fn try_forward_mut( &mut self, x: Tensor<(B, Const<C>, H, W), E, D, OwnedTape<E, D>> ) -> Result<Self::Output, D::Err>
Training 4d forward - updates Self::running_mean and Self::running_var
§type Output = Tensor<(B, Const<C>, H, W), E, D, OwnedTape<E, D>>
type Output = Tensor<(B, Const<C>, H, W), E, D, OwnedTape<E, D>>
The type that this unit produces given `Input`.
§type Error = <D as HasErr>::Err
source§fn forward_mut(&mut self, input: Input) -> Self::Output
fn forward_mut(&mut self, input: Input) -> Self::Output
source§impl<const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>> ModuleMut<Tensor<(Const<C>, H, W), E, D, OwnedTape<E, D>>> for BatchNorm2D<C, E, D>
impl<const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>> ModuleMut<Tensor<(Const<C>, H, W), E, D, OwnedTape<E, D>>> for BatchNorm2D<C, E, D>
source§fn try_forward_mut(
&mut self,
x: Tensor<(Const<C>, H, W), E, D, OwnedTape<E, D>>
) -> Result<Self::Output, D::Err>
fn try_forward_mut( &mut self, x: Tensor<(Const<C>, H, W), E, D, OwnedTape<E, D>> ) -> Result<Self::Output, D::Err>
Training 3d forward - updates Self::running_mean and Self::running_var
§type Output = Tensor<(Const<C>, H, W), E, D, OwnedTape<E, D>>
type Output = Tensor<(Const<C>, H, W), E, D, OwnedTape<E, D>>
The type that this unit produces given `Input`.
§type Error = <D as HasErr>::Err
source§fn forward_mut(&mut self, input: Input) -> Self::Output
fn forward_mut(&mut self, input: Input) -> Self::Output
source§impl<const C: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for BatchNorm2D<C, E, D>
impl<const C: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for BatchNorm2D<C, E, D>
§type To<E2: Dtype, D2: Device<E2>> = BatchNorm2D<C, E2, D2>
type To<E2: Dtype, D2: Device<E2>> = BatchNorm2D<C, E2, D2>
source§fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>
fn iter_tensors<V: ModuleVisitor<Self, E, D>>( visitor: &mut V ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>
Returns `Err(_)` to indicate an error, `Ok(None)` to indicate that there is no error and a module has not been built, and `Ok(Some(_))` containing the built module of type `Self::To<V::E2, V::D2>`.
source§fn module<F1, F2, Field>(
name: &str,
get_ref: F1,
get_mut: F2
) -> ModuleField<'_, F1, F2, Self, Field>where
F1: FnMut(&Self) -> &Field,
F2: FnMut(&mut Self) -> &mut Field,
Field: TensorCollection<E, D>,
fn module<F1, F2, Field>( name: &str, get_ref: F1, get_mut: F2 ) -> ModuleField<'_, F1, F2, Self, Field>where F1: FnMut(&Self) -> &Field, F2: FnMut(&mut Self) -> &mut Field, Field: TensorCollection<E, D>,
source§fn tensor<F1, F2, S>(
name: &str,
get_ref: F1,
get_mut: F2,
options: TensorOptions<S, E, D>
) -> TensorField<'_, F1, F2, Self, S, E, D>where
F1: FnMut(&Self) -> &Tensor<S, E, D>,
F2: FnMut(&mut Self) -> &mut Tensor<S, E, D>,
S: Shape,
fn tensor<F1, F2, S>( name: &str, get_ref: F1, get_mut: F2, options: TensorOptions<S, E, D> ) -> TensorField<'_, F1, F2, Self, S, E, D>where F1: FnMut(&Self) -> &Tensor<S, E, D>, F2: FnMut(&mut Self) -> &mut Tensor<S, E, D>, S: Shape,
source§fn scalar<F1, F2, N>(
name: &str,
get_ref: F1,
get_mut: F2,
options: ScalarOptions<N>
) -> ScalarField<'_, F1, F2, Self, N>where
F1: FnMut(&Self) -> &N,
F2: FnMut(&mut Self) -> &mut N,
N: NumCast,
fn scalar<F1, F2, N>( name: &str, get_ref: F1, get_mut: F2, options: ScalarOptions<N> ) -> ScalarField<'_, F1, F2, Self, N>where F1: FnMut(&Self) -> &N, F2: FnMut(&mut Self) -> &mut N, N: NumCast,
Auto Trait Implementations§
impl<const C: usize, E, D> RefUnwindSafe for BatchNorm2D<C, E, D>where D: RefUnwindSafe, <D as Storage<E>>::Vec: RefUnwindSafe,
impl<const C: usize, E, D> Send for BatchNorm2D<C, E, D>where D: Send,
impl<const C: usize, E, D> Sync for BatchNorm2D<C, E, D>where D: Sync,
impl<const C: usize, E, D> Unpin for BatchNorm2D<C, E, D>where D: Unpin,
impl<const C: usize, E, D> UnwindSafe for BatchNorm2D<C, E, D>where D: UnwindSafe, <D as Storage<E>>::Vec: RefUnwindSafe,
Blanket Implementations§
source§impl<T> BorrowMut<T> for Twhere
T: ?Sized,
impl<T> BorrowMut<T> for Twhere T: ?Sized,
source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
source§impl<D, E, M> BuildModule<D, E> for Mwhere
D: Device<E>,
E: Dtype,
M: TensorCollection<E, D, To<E, D> = M>,
impl<D, E, M> BuildModule<D, E> for Mwhere D: Device<E>, E: Dtype, M: TensorCollection<E, D, To<E, D> = M>,
source§impl<E, D, T> LoadFromNpz<E, D> for Twhere
E: Dtype + NumpyDtype,
D: Device<E>,
T: TensorCollection<E, D>,
impl<E, D, T> LoadFromNpz<E, D> for Twhere E: Dtype + NumpyDtype, D: Device<E>, T: TensorCollection<E, D>,
source§impl<E, D, T> LoadFromSafetensors<E, D> for Twhere
E: Dtype + SafeDtype,
D: Device<E>,
T: TensorCollection<E, D>,
impl<E, D, T> LoadFromSafetensors<E, D> for Twhere E: Dtype + SafeDtype, D: Device<E>, T: TensorCollection<E, D>,
source§impl<E, D, M> NumParams<E, D> for Mwhere
E: Dtype,
D: Device<E>,
M: TensorCollection<E, D>,
impl<E, D, M> NumParams<E, D> for Mwhere E: Dtype, D: Device<E>, M: TensorCollection<E, D>,
source§fn num_trainable_params(&self) -> usize
fn num_trainable_params(&self) -> usize
§impl<T> Pointable for T
impl<T> Pointable for T
source§impl<E, D, M> ResetParams<E, D> for Mwhere
E: Dtype,
D: Device<E>,
M: TensorCollection<E, D>,
impl<E, D, M> ResetParams<E, D> for Mwhere E: Dtype, D: Device<E>, M: TensorCollection<E, D>,
source§fn reset_params(&mut self)
fn reset_params(&mut self)
source§impl<E, D, T> SaveToNpz<E, D> for Twhere
E: Dtype + NumpyDtype,
D: Device<E>,
T: TensorCollection<E, D>,
impl<E, D, T> SaveToNpz<E, D> for Twhere E: Dtype + NumpyDtype, D: Device<E>, T: TensorCollection<E, D>,
source§impl<E, D, T> SaveToSafetensors<E, D> for Twhere
E: Dtype + SafeDtype,
D: Device<E>,
T: TensorCollection<E, D>,
impl<E, D, T> SaveToSafetensors<E, D> for Twhere E: Dtype + SafeDtype, D: Device<E>, T: TensorCollection<E, D>,
source§fn save_safetensors<P: AsRef<Path>>(
&self,
path: P
) -> Result<(), SafeTensorError>
fn save_safetensors<P: AsRef<Path>>( &self, path: P ) -> Result<(), SafeTensorError>
source§impl<E, D1, D2, T> ToDevice<E, D1, D2> for Twhere
E: Dtype,
D1: Device<E>,
D2: Device<E>,
T: TensorCollection<E, D1>,
impl<E, D1, D2, T> ToDevice<E, D1, D2> for Twhere E: Dtype, D1: Device<E>, D2: Device<E>, T: TensorCollection<E, D1>,
source§impl<E1, D, T> ToDtype<E1, D> for Twhere
E1: Dtype,
D: Device<E1>,
T: TensorCollection<E1, D>,
impl<E1, D, T> ToDtype<E1, D> for Twhere E1: Dtype, D: Device<E1>, T: TensorCollection<E1, D>,
source§impl<E, D, M> ZeroGrads<E, D> for Mwhere
E: Dtype,
D: Device<E>,
M: TensorCollection<E, D>,
impl<E, D, M> ZeroGrads<E, D> for Mwhere E: Dtype, D: Device<E>, M: TensorCollection<E, D>,
source§fn alloc_grads(&self) -> Gradients<E, D>
fn alloc_grads(&self) -> Gradients<E, D>
source§fn try_alloc_grads(&self) -> Result<Gradients<E, D>, D::Err>
fn try_alloc_grads(&self) -> Result<Gradients<E, D>, D::Err>
source§fn zero_grads(&self, gradients: &mut Gradients<E, D>)
fn zero_grads(&self, gradients: &mut Gradients<E, D>)
Zeroes any gradients associated with `self`.