use crate::Tensor;
use std::borrow::Borrow;
/// Configuration for a [`LayerNorm`] layer built by [`layer_norm`].
///
/// The `Default` implementation enables the affine parameters, uses
/// `eps = 1e-5`, and initializes the weight to 1 and the bias to 0.
#[derive(Debug, Clone, Copy)]
pub struct LayerNormConfig {
/// Passed unchanged as the last argument of `Tensor::layer_norm` in `forward`.
pub cudnn_enabled: bool,
/// Passed unchanged as the `eps` argument of `Tensor::layer_norm` in `forward`
/// (presumably the numerical-stability term added to the variance; confirm
/// against the backend documentation).
pub eps: f64,
/// When `true`, [`layer_norm`] requests `weight` and `bias` variables;
/// when `false`, the layer stores `None` for both.
pub elementwise_affine: bool,
/// Initializer used for the `weight` variable.
pub ws_init: super::Init,
/// Initializer used for the `bias` variable.
pub bs_init: super::Init,
}
impl Default for LayerNormConfig {
    /// Returns the default configuration: affine parameters enabled, weight
    /// initialized to the constant 1, bias to the constant 0, `eps = 1e-5`,
    /// and `cudnn_enabled = true`.
    fn default() -> Self {
        Self {
            elementwise_affine: true,
            ws_init: super::Init::Const(1.0),
            bs_init: super::Init::Const(0.0),
            eps: 1e-5,
            cudnn_enabled: true,
        }
    }
}
/// A layer-normalization layer, created with [`layer_norm`].
#[derive(Debug)]
pub struct LayerNorm {
// Configuration captured at construction; `forward` reads `eps` and
// `cudnn_enabled` from it.
config: LayerNormConfig,
/// Weight tensor; `Some` only when the layer was built with
/// `elementwise_affine` set.
pub ws: Option<Tensor>,
/// Bias tensor; `Some` only when the layer was built with
/// `elementwise_affine` set.
pub bs: Option<Tensor>,
/// Shape handed to `Tensor::layer_norm`, and also used as the shape of the
/// `weight` and `bias` variables when they exist.
pub normalized_shape: Vec<i64>,
}
/// Builds a [`LayerNorm`] layer.
///
/// * `vs` - anything that borrows as a variable-store `Path`; used to obtain
///   the layer's variables.
/// * `normalized_shape` - shape stored on the layer and used for the `weight`
///   and `bias` variables.
/// * `config` - layer configuration, stored on the returned layer.
///
/// When `config.elementwise_affine` is `true`, a `weight` variable (initialized
/// with `config.ws_init`) and then a `bias` variable (initialized with
/// `config.bs_init`) are requested from `vs`. Otherwise no variables are
/// requested and both `ws` and `bs` are `None`.
pub fn layer_norm<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    normalized_shape: Vec<i64>,
    config: LayerNormConfig,
) -> LayerNorm {
    let path = vs.borrow();

    // Without affine parameters there is nothing to request from the path.
    if !config.elementwise_affine {
        return LayerNorm {
            config,
            ws: None,
            bs: None,
            normalized_shape,
        };
    }

    // Request `weight` before `bias` so the variables are obtained in a
    // stable order.
    let weight = path.var("weight", normalized_shape.as_slice(), config.ws_init);
    let bias = path.var("bias", normalized_shape.as_slice(), config.bs_init);

    LayerNorm {
        config,
        ws: Some(weight),
        bs: Some(bias),
        normalized_shape,
    }
}
impl super::module::Module for LayerNorm {
    /// Applies `Tensor::layer_norm` to `xs` using the stored
    /// `normalized_shape`, the optional weight and bias tensors, and the
    /// `eps` / `cudnn_enabled` values from the layer's configuration.
    fn forward(&self, xs: &Tensor) -> Tensor {
        let shape = self.normalized_shape.as_slice();
        // Borrow the optional parameters; `None` is passed through when the
        // layer has no affine parameters.
        let weight = self.ws.as_ref();
        let bias = self.bs.as_ref();
        let eps = self.config.eps;
        let cudnn_enabled = self.config.cudnn_enabled;

        Tensor::layer_norm(xs, shape, weight, bias, eps, cudnn_enabled)
    }
}