//! Batch-normalization layer (mininn/layers/types/batchnorm.rs).
use ndarray::Array1;
use serde::{Deserialize, Serialize};

// use crate::{layers::Layer, NNMode, NNResult};
5
6#[derive(Debug, Serialize, Deserialize)]
7pub(crate) struct BatchNorm {
8 input: Array1<f32>,
9 gamma: Array1<f32>,
10 beta: Array1<f32>,
11 epsilon: f32,
12 momentum: f32,
13 running_mean: Array1<f32>,
14 running_var: Array1<f32>,
15 mu: f32,
16 xmu: Array1<f32>,
17 carre: Array1<f32>,
18 var: f32,
19 sqrtvar: f32,
20 invvar: f32,
21 va2: Array1<f32>,
22 va3: Array1<f32>,
23 xbar: Array1<f32>,
24 layer_type: String,
25}
27impl BatchNorm {
28 #[inline]
29 pub fn _new(
30 epsilon: f32,
31 momentum: f32,
32 running_mean: Option<Array1<f32>>,
33 running_var: Option<Array1<f32>>,
34 ) -> Self {
35 Self {
36 input: Array1::zeros(0),
37 gamma: Array1::ones(0),
38 beta: Array1::zeros(0),
39 epsilon,
40 momentum,
41 running_mean: running_mean.unwrap_or(Array1::zeros(0)),
42 running_var: running_var.unwrap_or(Array1::zeros(0)),
43 mu: 0.,
44 xmu: Array1::zeros(0),
45 carre: Array1::zeros(0),
46 var: 0.,
47 sqrtvar: 0.,
48 invvar: 0.,
49 va2: Array1::zeros(0),
50 va3: Array1::zeros(0),
51 xbar: Array1::zeros(0),
52 layer_type: "BatchNorm".to_string(),
53 }
54 }
55
56 #[inline]
57 pub fn _gamma(&self) -> Array1<f32> {
58 self.gamma.to_owned()
59 }
60
61 #[inline]
62 pub fn _beta(&self) -> Array1<f32> {
63 self.beta.to_owned()
64 }
65
66 #[inline]
67 pub fn _epsilon(&self) -> f32 {
68 self.epsilon
69 }
70
71 #[inline]
72 pub fn _momentum(&self) -> f32 {
73 self.momentum
74 }
75
76 #[inline]
77 pub fn _running_mean(&self) -> Array1<f32> {
78 self.running_mean.to_owned()
79 }
80
81 #[inline]
82 pub fn _running_var(&self) -> Array1<f32> {
83 self.running_var.to_owned()
84 }
85}

// NOTE(review): the `1. / 1.` factors below look like placeholders for
// `1 / n` (batch size); confirm before re-enabling this impl.
// impl Layer for BatchNorm {
//     #[inline]
//     fn layer_type(&self) -> String {
//         self.layer_type.to_string()
//     }

//     #[inline]
//     fn to_json(&self) -> NNResult<String> {
//         Ok(serde_json::to_string(self)?)
//     }

//     #[inline]
//     // NOTE(review): parameter is named `json_path` but is parsed as a
//     // JSON string, not read from disk — rename or read the file.
//     fn from_json(json_path: &str) -> NNResult<Box<dyn Layer>> {
//         Ok(Box::new(serde_json::from_str::<Self>(json_path)?))
//     }

//     #[inline]
//     fn as_any(&self) -> &dyn std::any::Any {
//         self
//     }

//     fn forward(&mut self, input: &Array1<f32>, mode: &NNMode) -> NNResult<Array1<f32>> {
//         self.input = input.to_owned();
//         // let d = self.input.len(); // let (n, d) = input.dim();

//         match mode {
//             NNMode::Train => {
//                 self.mu = 1. / 1. * self.input.sum(); // 1 / n * sum(input)
//                 self.xmu = &self.input - self.mu;
//                 self.carre = self.xmu.pow2();
//                 self.var = 1. / 1. * self.carre.sum(); // 1 / n * sum(carre)
//                 self.sqrtvar = (self.var + self.epsilon).sqrt();
//                 self.invvar = 1. / self.sqrtvar;
//                 self.va2 = &self.xmu * self.invvar;
//                 self.va3 = &self.gamma * &self.va2;
//                 let output = &self.va3 + &self.beta;
//                 self.running_mean =
//                     self.momentum * &self.running_mean + (1. - self.momentum) * self.mu;
//                 self.running_var =
//                     self.momentum * &self.running_var + (1. - self.momentum) * self.var;
//                 Ok(output)
//             }
//             NNMode::Test => {
//                 self.xbar =
//                     (input - &self.running_mean) / (&self.running_var + self.epsilon).sqrt();
//                 Ok(&self.gamma * &self.xbar + &self.beta)
//             }
//         }
//     }

//     fn backward(
//         &mut self,
//         output_gradient: &Array1<f32>,
//         _learning_rate: f32,
//         _optimizer: &crate::prelude::Optimizer,
//         _mode: &NNMode,
//     ) -> NNResult<Array1<f32>> {
//         // Step 9
//         let dva3 = output_gradient.to_owned();
//         // let dbeta = output_gradient.sum();

//         // Step 8
//         let dva2 = &self.gamma * &dva3;
//         // let dgamma = (&self.va2 * &dva3).sum();

//         // Step 7
//         let mut dxmu = self.invvar * &dva2;
//         let dinvvar = (&self.xmu * &dva2).sum();

//         // Step 6
//         let dsqrtvar = -1. / (self.sqrtvar.powi(2)) * &dinvvar;

//         // Step 5
//         let dvar = 0.5 * (self.var + self.epsilon).powf(-0.5) * &dsqrtvar;

//         // Step 4
//         let dcarre = 1. / 1. * Array1::ones(self.carre.len()) * dvar; // 1 / n * sum(carre)

//         // Step 3
//         dxmu = dxmu + (2. * &self.xmu * dcarre);

//         // Step 2
//         let mut dx = dxmu.to_owned();
//         let dmu = dxmu.sum();

//         // Step 1
//         dx = dx + (1. / 1. * Array1::ones(dxmu.len()) * dmu);

//         Ok(dx)
//     }
// }