// rai_nn/layer_norm.rs
use rai_core::{AsDevice, Tensor, Type};
use rai_derive::Module;

/// Layer normalization: normalizes the input over its last axis and optionally
/// applies a per-element scale (`weight`) and shift (`bias`).
#[derive(Clone, Debug, Module)]
#[module(crate = rai_core)]
pub struct LayerNorm {
    // Optional scale of shape `[dims]`; `new` initializes it to ones when `affine`
    // is true. Presumably registered as a module parameter by the `Module` derive,
    // since only `eps` is marked `#[param(skip)]`.
    weight: Option<Tensor>,
    // Optional shift of shape `[dims]`; `new` initializes it to zeros when `affine`
    // is true.
    bias: Option<Tensor>,
    // Constant added to the variance before the reciprocal square root.
    // Excluded from the module's parameters via `#[param(skip)]`.
    #[param(skip)]
    eps: f64,
}

impl LayerNorm {
    /// Creates a layer-normalization module for inputs whose last axis has
    /// length `dims`.
    ///
    /// * `dims` - length of the normalized axis; also the shape (`[dims]`) of the
    ///   affine weight and bias.
    /// * `eps` - added to the variance before the reciprocal square root, which
    ///   keeps the result finite when the variance is zero (assuming `eps > 0`).
    /// * `affine` - when `true`, allocates a weight initialized to ones and a bias
    ///   initialized to zeros; when `false`, [`LayerNorm::fwd`] applies no scale
    ///   or shift.
    /// * `dtype`, `device` - element type and device of the affine tensors.
    pub fn new(
        dims: usize,
        eps: f64,
        affine: bool,
        dtype: impl Type,
        device: impl AsDevice,
    ) -> Self {
        let device = device.device();
        let (weight, bias) = if affine {
            // Identity affine transform at initialization: scale 1, shift 0.
            (
                Some(Tensor::ones([dims], dtype, device)),
                Some(Tensor::zeros([dims], dtype, device)),
            )
        } else {
            (None, None)
        };
        Self { weight, bias, eps }
    }

    /// Normalizes `x` along axis `-1` and applies the optional affine transform.
    ///
    /// Computes `(x - mean) * rsqrt(var + eps)`. The `(-1, true)` reduction
    /// argument presumably keeps the reduced axis, so `mean` and `var` broadcast
    /// back against `x`. If a weight is present, the result is multiplied by it.
    /// If a bias is present, it is then added.
    ///
    /// NOTE(review): whether `Tensor::var` uses the biased (population) or
    /// unbiased estimator by default is decided by `rai_core`. Confirm it matches
    /// the intended layer-norm definition, which is usually biased.
    pub fn fwd(&self, x: &Tensor) -> Tensor {
        let mean = x.mean((-1, true));
        let var = x.var((-1, true));
        let normalized = (x - mean) * (var + self.eps).rsqrt();
        // Apply scale and shift independently instead of assuming both are
        // present, so an inconsistent weight/bias pair cannot panic here.
        let scaled = match &self.weight {
            Some(weight) => weight * normalized,
            None => normalized,
        };
        match &self.bias {
            Some(bias) => scaled + bias,
            None => scaled,
        }
    }
}