use ghostflow_core::Tensor;
use crate::module::Module;

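/// Batch normalization over a 2D `[batch, features]` or 3D `[batch, features, length]`
/// input. Each feature channel is normalized to zero mean and unit variance, then scaled
/// and shifted by the learnable `gamma` and `beta` parameters:
/// `y = gamma * (x - mean) / sqrt(var + eps) + beta`.
///
/// In training mode the statistics are computed from the current batch; in eval mode the
/// stored `running_mean` and `running_var` are used. Note that this implementation does
/// not yet update the running statistics during training (`momentum` is currently unused).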
pub struct BatchNorm1d {
    #[allow(dead_code)]
    num_features: usize,
    gamma: Tensor,
    beta: Tensor,
    running_mean: Tensor,
    running_var: Tensor,
    eps: f32,
    #[allow(dead_code)]
    momentum: f32,
    training: bool,
}

impl BatchNorm1d {
    pub fn new(num_features: usize) -> Self {
        Self::with_params(num_features, 1e-5, 0.1)
    }

    pub fn with_params(num_features: usize, eps: f32, momentum: f32) -> Self {
        BatchNorm1d {
            num_features,
            gamma: Tensor::ones(&[num_features]),
            beta: Tensor::zeros(&[num_features]),
            running_mean: Tensor::zeros(&[num_features]),
            running_var: Tensor::ones(&[num_features]),
            eps,
            momentum,
            training: true,
        }
    }
}

impl Module for BatchNorm1d {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let data = input.data_f32();
        let gamma = self.gamma.data_f32();
        let beta = self.beta.data_f32();

        let batch_size = dims[0];
        let channels = dims[1];
        let spatial_size = if dims.len() > 2 { dims[2] } else { 1 };

        let (mean, var) = if self.training {
            // Compute per-channel batch statistics over the batch and length dimensions.
            let mut mean = vec![0.0f32; channels];
            let mut var = vec![0.0f32; channels];
            let n = (batch_size * spatial_size) as f32;

            for c in 0..channels {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + c * spatial_size + s;
                        sum += data[idx];
                    }
                }
                mean[c] = sum / n;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + c * spatial_size + s;
                        var_sum += (data[idx] - mean[c]).powi(2);
                    }
                }
                var[c] = var_sum / n;
            }

            (mean, var)
        } else {
            // Eval mode: use the stored running statistics.
            (self.running_mean.data_f32(), self.running_var.data_f32())
        };

        let mut output = vec![0.0f32; data.len()];

        for b in 0..batch_size {
            for c in 0..channels {
                let std = (var[c] + self.eps).sqrt();
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    output[idx] = gamma[c] * (data[idx] - mean[c]) / std + beta[c];
                }
            }
        }

        Tensor::from_slice(&output, dims).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

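/// Batch normalization over a 4D `[batch, channels, height, width]` input, normalizing
/// each channel over the batch and spatial dimensions. Training/eval behaviour matches
/// `BatchNorm1d`: batch statistics in training mode, stored running statistics in eval
/// mode, with the running statistics not yet updated during training.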
pub struct BatchNorm2d {
    #[allow(dead_code)]
    num_features: usize,
    gamma: Tensor,
    beta: Tensor,
    running_mean: Tensor,
    running_var: Tensor,
    eps: f32,
    #[allow(dead_code)]
    momentum: f32,
    training: bool,
}

impl BatchNorm2d {
    pub fn new(num_features: usize) -> Self {
        Self::with_params(num_features, 1e-5, 0.1)
    }

    pub fn with_params(num_features: usize, eps: f32, momentum: f32) -> Self {
        BatchNorm2d {
            num_features,
            gamma: Tensor::ones(&[num_features]),
            beta: Tensor::zeros(&[num_features]),
            running_mean: Tensor::zeros(&[num_features]),
            running_var: Tensor::ones(&[num_features]),
            eps,
            momentum,
            training: true,
        }
    }
}

impl Module for BatchNorm2d {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let data = input.data_f32();
        let gamma = self.gamma.data_f32();
        let beta = self.beta.data_f32();

        let batch_size = dims[0];
        let channels = dims[1];
        let height = dims[2];
        let width = dims[3];
        let spatial_size = height * width;

        let (mean, var) = if self.training {
            // Compute per-channel batch statistics over the batch and spatial dims (NCHW layout).
            let mut mean = vec![0.0f32; channels];
            let mut var = vec![0.0f32; channels];
            let n = (batch_size * spatial_size) as f32;

            for c in 0..channels {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            let idx = b * channels * spatial_size + c * spatial_size + h * width + w;
                            sum += data[idx];
                        }
                    }
                }
                mean[c] = sum / n;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            let idx = b * channels * spatial_size + c * spatial_size + h * width + w;
                            var_sum += (data[idx] - mean[c]).powi(2);
                        }
                    }
                }
                var[c] = var_sum / n;
            }

            (mean, var)
        } else {
            // Eval mode: use the stored running statistics.
            (self.running_mean.data_f32(), self.running_var.data_f32())
        };

        let mut output = vec![0.0f32; data.len()];

        for b in 0..batch_size {
            for c in 0..channels {
                let std = (var[c] + self.eps).sqrt();
                for h in 0..height {
                    for w in 0..width {
                        let idx = b * channels * spatial_size + c * spatial_size + h * width + w;
                        output[idx] = gamma[c] * (data[idx] - mean[c]) / std + beta[c];
                    }
                }
            }
        }

        Tensor::from_slice(&output, dims).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

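/// Layer normalization over the trailing `normalized_shape` elements of the input.
/// Unlike batch normalization, the mean and variance are computed per sample over the
/// normalized dimensions, so behaviour is identical in training and eval mode.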
pub struct LayerNorm {
    normalized_shape: Vec<usize>,
    gamma: Tensor,
    beta: Tensor,
    eps: f32,
    training: bool,
}

impl LayerNorm {
    pub fn new(normalized_shape: &[usize]) -> Self {
        Self::with_eps(normalized_shape, 1e-5)
    }

    pub fn with_eps(normalized_shape: &[usize], eps: f32) -> Self {
        let numel: usize = normalized_shape.iter().product();

        LayerNorm {
            normalized_shape: normalized_shape.to_vec(),
            gamma: Tensor::ones(&[numel]),
            beta: Tensor::zeros(&[numel]),
            eps,
            training: true,
        }
    }
}

impl Module for LayerNorm {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let data = input.data_f32();
        let gamma = self.gamma.data_f32();
        let beta = self.beta.data_f32();

        let norm_size: usize = self.normalized_shape.iter().product();
        let batch_size = data.len() / norm_size;

        let mut output = vec![0.0f32; data.len()];

        // Normalize each contiguous slice of `norm_size` elements independently.
        for b in 0..batch_size {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &data[start..end];

            let mean: f32 = slice.iter().sum::<f32>() / norm_size as f32;
            let var: f32 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / norm_size as f32;
            let std = (var + self.eps).sqrt();

            for i in 0..norm_size {
                output[start + i] = gamma[i] * (slice[i] - mean) / std + beta[i];
            }
        }

        Tensor::from_slice(&output, dims).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

#[cfg(test)]
mod tests {
    use super::*;

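    #[test]
    fn test_batchnorm1d() {
        // Shape-preservation check for the 1D case, mirroring test_batchnorm2d below.
        let bn = BatchNorm1d::new(16);
        let input = Tensor::randn(&[2, 16, 8]);
        let output = bn.forward(&input);

        assert_eq!(output.dims(), input.dims());
    }
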
    #[test]
    fn test_batchnorm2d() {
        let bn = BatchNorm2d::new(16);
        let input = Tensor::randn(&[2, 16, 8, 8]);
        let output = bn.forward(&input);

        assert_eq!(output.dims(), input.dims());
    }

    #[test]
    fn test_layernorm() {
        let ln = LayerNorm::new(&[64]);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = ln.forward(&input);

        assert_eq!(output.dims(), input.dims());
    }
}