mininn/layers/types/batchnorm.rs

use ndarray::Array1;
use serde::{Deserialize, Serialize};

// use crate::{layers::Layer, NNMode, NNResult};

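/// Batch Normalization layer operating on a single 1-D activation vector.
///
/// `gamma` and `beta` are the learnable scale and shift parameters;
/// `running_mean` and `running_var` hold the moving averages used at test
/// time; the remaining fields (`mu`, `xmu`, `carre` = squared deviations,
/// `var`, `sqrtvar`, `invvar`, `va2`, `va3`, `xbar`) cache intermediate
/// values of the forward pass so that `backward` can reuse them.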
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct BatchNorm {
    input: Array1<f32>,
    gamma: Array1<f32>,
    beta: Array1<f32>,
    epsilon: f32,
    momentum: f32,
    running_mean: Array1<f32>,
    running_var: Array1<f32>,
    mu: f32,
    xmu: Array1<f32>,
    carre: Array1<f32>,
    var: f32,
    sqrtvar: f32,
    invvar: f32,
    va2: Array1<f32>,
    va3: Array1<f32>,
    xbar: Array1<f32>,
    layer_type: String,
}

impl BatchNorm {
    #[inline]
    pub fn _new(
        epsilon: f32,
        momentum: f32,
        running_mean: Option<Array1<f32>>,
        running_var: Option<Array1<f32>>,
    ) -> Self {
        Self {
            input: Array1::zeros(0),
            gamma: Array1::ones(0),
            beta: Array1::zeros(0),
            epsilon,
            momentum,
            running_mean: running_mean.unwrap_or_else(|| Array1::zeros(0)),
            running_var: running_var.unwrap_or_else(|| Array1::zeros(0)),
            mu: 0.,
            xmu: Array1::zeros(0),
            carre: Array1::zeros(0),
            var: 0.,
            sqrtvar: 0.,
            invvar: 0.,
            va2: Array1::zeros(0),
            va3: Array1::zeros(0),
            xbar: Array1::zeros(0),
            layer_type: "BatchNorm".to_string(),
        }
    }

    #[inline]
    pub fn _gamma(&self) -> Array1<f32> {
        self.gamma.to_owned()
    }

    #[inline]
    pub fn _beta(&self) -> Array1<f32> {
        self.beta.to_owned()
    }

    #[inline]
    pub fn _epsilon(&self) -> f32 {
        self.epsilon
    }

    #[inline]
    pub fn _momentum(&self) -> f32 {
        self.momentum
    }

    #[inline]
    pub fn _running_mean(&self) -> Array1<f32> {
        self.running_mean.to_owned()
    }

    #[inline]
    pub fn _running_var(&self) -> Array1<f32> {
        self.running_var.to_owned()
    }
}
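
// A minimal usage sketch for the constructor and getters above, kept as a
// test so it compiles while the `Layer` impl below remains commented out.
// The epsilon and momentum values are illustrative assumptions, not
// defaults taken from this crate.
#[cfg(test)]
mod construction_sketch {
    use super::BatchNorm;

    #[test]
    fn new_starts_with_empty_running_stats() {
        let bn = BatchNorm::_new(1e-5, 0.9, None, None);
        assert_eq!(bn._epsilon(), 1e-5);
        assert_eq!(bn._momentum(), 0.9);
        assert!(bn._running_mean().is_empty());
        assert!(bn._running_var().is_empty());
    }
}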

// impl Layer for BatchNorm {
//     #[inline]
//     fn layer_type(&self) -> String {
//         self.layer_type.to_string()
//     }

//     #[inline]
//     fn to_json(&self) -> NNResult<String> {
//         Ok(serde_json::to_string(self)?)
//     }

//     #[inline]
//     fn from_json(json: &str) -> NNResult<Box<dyn Layer>> {
//         // Note: the argument is the JSON string itself, not a file path.
//         Ok(Box::new(serde_json::from_str::<Self>(json)?))
//     }

//     #[inline]
//     fn as_any(&self) -> &dyn std::any::Any {
//         self
//     }

//     fn forward(&mut self, input: &Array1<f32>, mode: &NNMode) -> NNResult<Array1<f32>> {
//         self.input = input.to_owned();
//         // Inputs are single 1-D samples, so the statistics are taken over
//         // the feature dimension rather than over a batch axis.
//         let n = self.input.len() as f32;

//         match mode {
//             NNMode::Train => {
//                 self.mu = self.input.sum() / n;
//                 self.xmu = &self.input - self.mu;
//                 self.carre = self.xmu.pow2();
//                 self.var = self.carre.sum() / n;
//                 self.sqrtvar = (self.var + self.epsilon).sqrt();
//                 self.invvar = 1. / self.sqrtvar;
//                 self.va2 = &self.xmu * self.invvar;
//                 self.va3 = &self.gamma * &self.va2;
//                 let output = &self.va3 + &self.beta;
//                 // Exponential moving averages, used at test time.
//                 self.running_mean =
//                     self.momentum * &self.running_mean + (1. - self.momentum) * self.mu;
//                 self.running_var =
//                     self.momentum * &self.running_var + (1. - self.momentum) * self.var;
//                 Ok(output)
//             }
//             NNMode::Test => {
//                 // Normalize with the running statistics rather than the
//                 // statistics of the current input.
//                 self.xbar =
//                     (input - &self.running_mean) / (&self.running_var + self.epsilon).sqrt();
//                 Ok(&self.gamma * &self.xbar + &self.beta)
//             }
//         }
//     }

//     fn backward(
//         &mut self,
//         output_gradient: &Array1<f32>,
//         _learning_rate: f32,
//         _optimizer: &crate::prelude::Optimizer,
//         _mode: &NNMode,
//     ) -> NNResult<Array1<f32>> {
//         let n = self.carre.len() as f32;

//         // Step 9: output = va3 + beta
//         let dva3 = output_gradient.to_owned();
//         // let dbeta = output_gradient.sum();

//         // Step 8: va3 = gamma * va2
//         let dva2 = &self.gamma * &dva3;
//         // let dgamma = (&self.va2 * &dva3).sum();

//         // Step 7: va2 = xmu * invvar
//         let mut dxmu = self.invvar * &dva2;
//         let dinvvar = (&self.xmu * &dva2).sum();

//         // Step 6: invvar = 1 / sqrtvar
//         let dsqrtvar = -1. / self.sqrtvar.powi(2) * dinvvar;

//         // Step 5: sqrtvar = sqrt(var + epsilon)
//         let dvar = 0.5 * (self.var + self.epsilon).powf(-0.5) * dsqrtvar;

//         // Step 4: var = 1 / n * sum(carre)
//         let dcarre = Array1::ones(self.carre.len()) * (dvar / n);

//         // Step 3: carre = xmu^2
//         dxmu = dxmu + (2. * &self.xmu * dcarre);

//         // Step 2: xmu = input - mu; the gradient flowing into mu picks up
//         // a negative sign.
//         let mut dx = dxmu.to_owned();
//         let dmu = -dxmu.sum();

//         // Step 1: mu = 1 / n * sum(input)
//         dx = dx + (Array1::ones(dxmu.len()) * (dmu / n));

//         Ok(dx)
//     }
// }
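
// Hedged sketches of the math in the commented-out `Layer` impl above,
// written against plain `ndarray` so they compile on their own. The first
// reproduces the train-mode normalization with gamma = 1 and beta = 0 and
// checks that it yields zero mean and (near-)unit variance; the second
// shows the exponential-moving-average update applied to `running_mean`.
// All concrete numbers are illustrative assumptions.
#[cfg(test)]
mod batchnorm_math_sketch {
    use ndarray::Array1;

    #[test]
    fn train_mode_normalization_is_zero_mean_unit_variance() {
        let input = Array1::from(vec![1.0_f32, 2.0, 3.0, 4.0]);
        let n = input.len() as f32;
        let epsilon = 1e-5_f32;

        // mu = sum(x) / n; var = sum((x - mu)^2) / n; out = (x - mu) / sqrt(var + eps)
        let mu = input.sum() / n;
        let xmu = &input - mu;
        let var = xmu.mapv(|v| v * v).sum() / n;
        let out = &xmu / (var + epsilon).sqrt();

        let out_mean = out.sum() / n;
        let out_var = out.mapv(|v| v * v).sum() / n;
        assert!(out_mean.abs() < 1e-6);
        assert!((out_var - 1.0).abs() < 1e-3);
    }

    #[test]
    fn running_mean_converges_to_the_batch_mean() {
        let momentum = 0.9_f32;
        let mut running_mean = 0.0_f32;
        let mu = 2.5_f32; // batch mean, identical across steps here

        // running_mean = momentum * running_mean + (1 - momentum) * mu
        for _ in 0..100 {
            running_mean = momentum * running_mean + (1.0 - momentum) * mu;
        }
        assert!((running_mean - mu).abs() < 1e-3);
    }
}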