use super::Config;
use crate::{Biased, LinearParams, ParamMode, Unbiased};
use concision::Forward;
use nd::prelude::*;
use nd::{Data, RemoveAxis};
use num::traits::{Float, FromPrimitive, One, Zero};
pub struct LayerNorm<A = f64, K = crate::Biased, D = Ix2>
where
D: Dimension,
{
config: Config<D>,
params: LinearParams<A, K, D>,
}
// Generates shape-based constructors for `LayerNorm` which forward to a
// `LinearParams` constructor and wrap the result with `Self::from_params`.
//
// Accepted forms:
//   impl_norm_builder!(name where <bounds>);       // calls `LinearParams::name`
//   impl_norm_builder!(name.ctor where <bounds>);  // calls `LinearParams::ctor`
macro_rules! impl_norm_builder {
    // Entry point: forward everything to the internal `@impl` rules.
    ($name:ident $(.$ctor:ident)? where $($bounds:tt)*) => {
        impl_norm_builder!(@impl $name $(.$ctor)? where $($bounds)*);
    };
    // No explicit constructor supplied: reuse the method name for the call.
    (@impl $name:ident where $($bounds:tt)*) => {
        impl_norm_builder!(@impl $name.$name where $($bounds)*);
    };
    // Emit the actual constructor.
    (@impl $name:ident.$ctor:ident where $($bounds:tt)*) => {
        /// Creates a new layer from the given shape by building the parameters
        /// with the matching `LinearParams` constructor.
        pub fn $name<Sh>(shape: Sh) -> Self
        where
            Sh: ShapeBuilder<Dim = D>,
            $($bounds)*
        {
            let params = LinearParams::<A, K, D>::$ctor(shape);
            Self::from_params(params)
        }
    };
}
impl<A, K, D> LayerNorm<A, K, D>
where
    D: RemoveAxis,
    K: ParamMode,
{
    /// Creates a layer from an existing configuration; the parameters are built
    /// with `LinearParams::new` using the configured dimension.
    pub fn from_config(config: Config<D>) -> Self
    where
        A: Default,
    {
        Self {
            params: LinearParams::<A, K, D>::new(config.dim()),
            config,
        }
    }

    /// Creates a layer of the given shape whose parameters are built with
    /// `LinearParams::from_elem` using `elem`.
    pub fn from_elem<Sh>(shape: Sh, elem: A) -> Self
    where
        A: Clone,
        Sh: ShapeBuilder<Dim = D>,
    {
        let shape = shape.into_shape();
        let dim = shape.raw_dim().clone();
        Self {
            // One copy of the dimension goes to the config, the other is
            // consumed by the parameters.
            config: Config::new().dim(dim.clone()).build(),
            params: LinearParams::<A, K, D>::from_elem(dim, elem),
        }
    }

    /// Wraps existing parameters, deriving the configuration from their dimension.
    pub fn from_params(params: LinearParams<A, K, D>) -> Self {
        Self {
            config: Config::new().dim(params.raw_dim()).build(),
            params,
        }
    }

    impl_norm_builder!(new where A: Default);
    impl_norm_builder!(ones where A: Clone + One);
    impl_norm_builder!(zeros where A: Clone + Zero);

    /// Returns a reference to the layer configuration.
    pub const fn config(&self) -> &Config<D> {
        &self.config
    }

    /// Reports whether the parameters are biased, as answered by the parameters.
    pub fn is_biased(&self) -> bool {
        self.params.is_biased()
    }

    /// Returns a shared reference to the parameters.
    pub const fn params(&self) -> &LinearParams<A, K, D> {
        &self.params
    }

    /// Returns a mutable reference to the parameters.
    pub fn params_mut(&mut self) -> &mut LinearParams<A, K, D> {
        &mut self.params
    }

    /// Returns the configured dimension as a pattern.
    pub fn dim(&self) -> D::Pattern {
        self.config.dim()
    }

    /// Returns the configured epsilon.
    pub fn eps(&self) -> f64 {
        self.config.eps()
    }

    /// Returns the configured number of dimensions.
    pub fn ndim(&self) -> usize {
        self.config.ndim()
    }

    /// Returns the configured dimension in its raw form.
    pub fn raw_dim(&self) -> D {
        self.config.raw_dim()
    }

    /// Returns the configured shape as a slice.
    pub fn shape(&self) -> &[usize] {
        self.config.shape()
    }
}
impl<A, D> Default for LayerNorm<A, Biased, D>
where
A: Default,
D: RemoveAxis,
{
fn default() -> Self {
Self {
config: Config::default(),
params: Default::default(),
}
}
}
impl<A, D> Default for LayerNorm<A, Unbiased, D>
where
A: Default,
D: RemoveAxis,
{
fn default() -> Self {
Self {
config: Config::default(),
params: Default::default(),
}
}
}
impl<A, S, D> Forward<ArrayBase<S, D>> for LayerNorm<A, Biased, D>
where
    A: Float + FromPrimitive,
    D: RemoveAxis,
    S: Data<Elem = A>,
{
    type Output = Array<A, D>;

    /// Normalizes `x`, then multiplies by the weights and adds the bias.
    ///
    /// When the configuration specifies an axis the input is normalized with
    /// `layer_norm_axis` along it; otherwise `layer_norm` is used.
    fn forward(&self, x: &ArrayBase<S, D>) -> Self::Output {
        let eps = self.eps();
        let normalized = match self.config().axis() {
            Some(axis) => super::layer_norm_axis(x, *axis, eps),
            None => super::layer_norm(x, eps),
        };
        let params = self.params();
        let scaled = normalized * params.weights();
        scaled + params.bias()
    }
}
impl<A, S, D> Forward<ArrayBase<S, D>> for LayerNorm<A, Unbiased, D>
where
    A: Float + FromPrimitive,
    D: RemoveAxis,
    S: Data<Elem = A>,
{
    type Output = Array<A, D>;

    /// Normalizes `x`, then multiplies by the weights (no bias is applied).
    ///
    /// When the configuration specifies an axis the input is normalized with
    /// `layer_norm_axis` along it; otherwise `layer_norm` is used.
    fn forward(&self, x: &ArrayBase<S, D>) -> Self::Output {
        let eps = self.eps();
        let normalized = match self.config().axis() {
            Some(axis) => super::layer_norm_axis(x, *axis, eps),
            None => super::layer_norm(x, eps),
        };
        normalized * self.params().weights()
    }
}