//! Layer norm and RMS norm configurations (`oxidized_transformers/layers/layer_norm.rs`).
use candle_core::ModuleT;
use candle_nn::{layer_norm, rms_norm, LayerNormConfig as CandleLayerNormConfig, VarBuilder};

use crate::error::BoxedError;
use crate::layers::build_module::BuildModule;

/// Layer norm configuration.
///
/// Start from `LayerNormConfig::default()` and adjust with the builder
/// methods, or set the public fields directly. `BuildModule::build` turns
/// the configuration into a candle layer norm module.
#[derive(Clone, Debug, PartialEq)]
pub struct LayerNormConfig {
    /// Whether to use an affine transformation (default: `true`).
    ///
    /// Forwarded as candle's `LayerNormConfig::affine`.
    pub affine: bool,

    /// Epsilon value forwarded to the candle layer norm (default: `1e-12`).
    pub eps: f64,

    /// Whether to remove the mean (default: `true`).
    ///
    /// If the mean is not removed, this layer is equivalent to `RMSNorm`.
    pub remove_mean: bool,

    /// Dimensionality of the layer (default: `768`).
    pub size: usize,
}

/// Builder-style setters.
///
/// Each setter consumes the configuration and returns the updated copy, so
/// calls chain, e.g. `LayerNormConfig::default().eps(1e-5).size(512)`.
/// The setters are `#[must_use]` because discarding their return value
/// discards the change.
impl LayerNormConfig {
    /// Whether to use an affine transformation.
    ///
    /// Default: `true`
    #[must_use]
    pub fn affine(mut self, affine: bool) -> Self {
        self.affine = affine;
        self
    }

    /// Epsilon value.
    ///
    /// Default: `1e-12`
    #[must_use]
    pub fn eps(mut self, eps: f64) -> Self {
        self.eps = eps;
        self
    }

    /// Whether to remove the mean.
    ///
    /// If the mean is not removed, this layer is equivalent to `RMSNorm`.
    ///
    /// Default: `true`
    #[must_use]
    pub fn remove_mean(mut self, remove_mean: bool) -> Self {
        self.remove_mean = remove_mean;
        self
    }

    /// Dimensionality of the layer.
    ///
    /// Default: `768`
    #[must_use]
    pub fn size(mut self, size: usize) -> Self {
        self.size = size;
        self
    }
}

impl Default for LayerNormConfig {
    fn default() -> Self {
        Self {
            affine: true,
            eps: 1e-12,
            remove_mean: true,
            size: 768,
        }
    }
}

impl BuildModule for LayerNormConfig {
    /// Build a candle layer norm from this configuration.
    ///
    /// `affine`, `eps` and `remove_mean` are copied into candle's own
    /// `LayerNormConfig`. `size` and `vb` are passed straight through to
    /// candle's `layer_norm`. Any error it returns is converted into a
    /// [`BoxedError`] by `?`.
    fn build(&self, vb: VarBuilder) -> Result<Box<dyn ModuleT>, BoxedError> {
        let candle_config = CandleLayerNormConfig {
            eps: self.eps,
            remove_mean: self.remove_mean,
            affine: self.affine,
        };
        let norm = layer_norm(self.size, candle_config, vb)?;
        Ok(Box::new(norm))
    }
}

/// RMS norm configuration.
///
/// Start from `RMSNormConfig::default()` and adjust with the builder
/// methods, or set the public fields directly. `BuildModule::build` turns
/// the configuration into a candle RMS norm module.
#[derive(Clone, Debug, PartialEq)]
pub struct RMSNormConfig {
    /// Epsilon value forwarded to candle's `rms_norm` (default: `1e-12`).
    pub eps: f64,

    /// Dimensionality of the layer (default: `768`).
    pub size: usize,
}

/// Builder-style setters.
///
/// Each setter consumes the configuration and returns the updated copy, so
/// calls chain, e.g. `RMSNormConfig::default().eps(1e-6).size(4096)`.
/// The setters are `#[must_use]` because discarding their return value
/// discards the change.
impl RMSNormConfig {
    /// Epsilon value.
    ///
    /// Default: `1e-12`
    #[must_use]
    pub fn eps(mut self, eps: f64) -> Self {
        self.eps = eps;
        self
    }

    /// Dimensionality of the layer.
    ///
    /// Default: `768`
    #[must_use]
    pub fn size(mut self, size: usize) -> Self {
        self.size = size;
        self
    }
}

impl Default for RMSNormConfig {
    fn default() -> Self {
        Self {
            eps: 1e-12,
            size: 768,
        }
    }
}

impl BuildModule for RMSNormConfig {
    /// Build a candle RMS norm from this configuration.
    ///
    /// `size`, `eps` and `vb` are forwarded unchanged to candle's
    /// `rms_norm`. Any error it returns is converted into a [`BoxedError`]
    /// by `?`.
    fn build(&self, vb: VarBuilder) -> Result<Box<dyn ModuleT>, BoxedError> {
        let norm = rms_norm(self.size, self.eps, vb)?;
        Ok(Box::new(norm))
    }
}