use ghostflow_core::Tensor;

use crate::module::Module;
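/// 1D batch normalization over a 2D `[N, C]` or 3D `[N, C, L]` input.
///
/// Each channel is normalized as `y = gamma * (x - mean) / sqrt(var + eps) + beta`,
/// using per-batch statistics in training mode and the stored running
/// statistics in eval mode.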
pub struct BatchNorm1d {
    #[allow(dead_code)]
    num_features: usize,
    gamma: Tensor,
    beta: Tensor,
    running_mean: Tensor,
    running_var: Tensor,
    eps: f32,
    #[allow(dead_code)]
    momentum: f32,
    training: bool,
}

impl BatchNorm1d {
    pub fn new(num_features: usize) -> Self {
        Self::with_params(num_features, 1e-5, 0.1)
    }

    pub fn with_params(num_features: usize, eps: f32, momentum: f32) -> Self {
        BatchNorm1d {
            num_features,
            gamma: Tensor::ones(&[num_features]),
            beta: Tensor::zeros(&[num_features]),
            running_mean: Tensor::zeros(&[num_features]),
            running_var: Tensor::ones(&[num_features]),
            eps,
            momentum,
            training: true,
        }
    }
}

impl Module for BatchNorm1d {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let data = input.data_f32();
        let gamma = self.gamma.data_f32();
        let beta = self.beta.data_f32();

        let batch_size = dims[0];
        let channels = dims[1];
        let spatial_size = if dims.len() > 2 { dims[2] } else { 1 };

        let (mean, var) = if self.training {
            // Per-channel mean and biased variance over the batch and any
            // trailing length dimension.
            //
            // NOTE: the running statistics are not updated here because
            // `forward` takes `&self`; `momentum` is therefore unused for
            // now (hence the `#[allow(dead_code)]` markers above).
            let mut mean = vec![0.0f32; channels];
            let mut var = vec![0.0f32; channels];
            let n = (batch_size * spatial_size) as f32;

            for c in 0..channels {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + c * spatial_size + s;
                        sum += data[idx];
                    }
                }
                mean[c] = sum / n;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + c * spatial_size + s;
                        var_sum += (data[idx] - mean[c]).powi(2);
                    }
                }
                var[c] = var_sum / n;
            }

            (mean, var)
        } else {
            // Eval mode: use the accumulated running statistics.
            (self.running_mean.data_f32(), self.running_var.data_f32())
        };

        let mut output = vec![0.0f32; data.len()];

        // y = gamma * (x - mean) / sqrt(var + eps) + beta
        for b in 0..batch_size {
            for c in 0..channels {
                let std = (var[c] + self.eps).sqrt();
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    output[idx] = gamma[c] * (data[idx] - mean[c]) / std + beta[c];
                }
            }
        }

        Tensor::from_slice(&output, dims).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}
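
/// 2D batch normalization over a 4D `[N, C, H, W]` input, normalizing each
/// channel over the batch and both spatial dimensions.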
pub struct BatchNorm2d {
    #[allow(dead_code)]
    num_features: usize,
    gamma: Tensor,
    beta: Tensor,
    running_mean: Tensor,
    running_var: Tensor,
    eps: f32,
    #[allow(dead_code)]
    momentum: f32,
    training: bool,
}

impl BatchNorm2d {
    pub fn new(num_features: usize) -> Self {
        Self::with_params(num_features, 1e-5, 0.1)
    }

    pub fn with_params(num_features: usize, eps: f32, momentum: f32) -> Self {
        BatchNorm2d {
            num_features,
            gamma: Tensor::ones(&[num_features]),
            beta: Tensor::zeros(&[num_features]),
            running_mean: Tensor::zeros(&[num_features]),
            running_var: Tensor::ones(&[num_features]),
            eps,
            momentum,
            training: true,
        }
    }
}

impl Module for BatchNorm2d {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let data = input.data_f32();
        let gamma = self.gamma.data_f32();
        let beta = self.beta.data_f32();

        let batch_size = dims[0];
        let channels = dims[1];
        let height = dims[2];
        let width = dims[3];
        let spatial_size = height * width;

        let (mean, var) = if self.training {
            // Per-channel statistics over batch, height, and width. As with
            // BatchNorm1d, the running statistics are not updated here and
            // `momentum` is unused because `forward` takes `&self`.
            let mut mean = vec![0.0f32; channels];
            let mut var = vec![0.0f32; channels];
            let n = (batch_size * spatial_size) as f32;

            for c in 0..channels {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            let idx = b * channels * spatial_size + c * spatial_size + h * width + w;
                            sum += data[idx];
                        }
                    }
                }
                mean[c] = sum / n;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            let idx = b * channels * spatial_size + c * spatial_size + h * width + w;
                            var_sum += (data[idx] - mean[c]).powi(2);
                        }
                    }
                }
                var[c] = var_sum / n;
            }

            (mean, var)
        } else {
            // Eval mode: use the accumulated running statistics.
            (self.running_mean.data_f32(), self.running_var.data_f32())
        };

        let mut output = vec![0.0f32; data.len()];

        // y = gamma * (x - mean) / sqrt(var + eps) + beta
        for b in 0..batch_size {
            for c in 0..channels {
                let std = (var[c] + self.eps).sqrt();
                for h in 0..height {
                    for w in 0..width {
                        let idx = b * channels * spatial_size + c * spatial_size + h * width + w;
                        output[idx] = gamma[c] * (data[idx] - mean[c]) / std + beta[c];
                    }
                }
            }
        }

        Tensor::from_slice(&output, dims).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}
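
/// Layer normalization over the trailing `normalized_shape` elements of the
/// input, with one learned scale/shift per normalized element. Statistics
/// are computed per sample, so behavior is identical in train and eval mode.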
pub struct LayerNorm {
    normalized_shape: Vec<usize>,
    gamma: Tensor,
    beta: Tensor,
    eps: f32,
    training: bool,
}

impl LayerNorm {
    pub fn new(normalized_shape: &[usize]) -> Self {
        Self::with_eps(normalized_shape, 1e-5)
    }

    pub fn with_eps(normalized_shape: &[usize], eps: f32) -> Self {
        let numel: usize = normalized_shape.iter().product();

        LayerNorm {
            normalized_shape: normalized_shape.to_vec(),
            // gamma and beta are stored flattened, one element per
            // normalized position.
            gamma: Tensor::ones(&[numel]),
            beta: Tensor::zeros(&[numel]),
            eps,
            training: true,
        }
    }
}

impl Module for LayerNorm {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let data = input.data_f32();
        let gamma = self.gamma.data_f32();
        let beta = self.beta.data_f32();

        let norm_size: usize = self.normalized_shape.iter().product();
        let batch_size = data.len() / norm_size;

        let mut output = vec![0.0f32; data.len()];

        // Normalize each contiguous `norm_size` slice independently; all
        // leading dimensions are treated as batch.
        for b in 0..batch_size {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &data[start..end];

            let mean: f32 = slice.iter().sum::<f32>() / norm_size as f32;
            let var: f32 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / norm_size as f32;
            let std = (var + self.eps).sqrt();

            for i in 0..norm_size {
                output[start + i] = gamma[i] * (slice[i] - mean) / std + beta[i];
            }
        }

        Tensor::from_slice(&output, dims).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}
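
/// Group normalization: channels are split into `num_groups` groups and each
/// group is normalized per sample over its channels and spatial positions.
/// Like LayerNorm, it does not depend on batch statistics.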
pub struct GroupNorm {
    num_groups: usize,
    num_channels: usize,
    gamma: Tensor,
    beta: Tensor,
    eps: f32,
    training: bool,
}

impl GroupNorm {
    pub fn new(num_groups: usize, num_channels: usize) -> Self {
        Self::with_eps(num_groups, num_channels, 1e-5)
    }

    pub fn with_eps(num_groups: usize, num_channels: usize, eps: f32) -> Self {
        assert!(num_channels % num_groups == 0, "num_channels must be divisible by num_groups");

        GroupNorm {
            num_groups,
            num_channels,
            gamma: Tensor::ones(&[num_channels]),
            beta: Tensor::zeros(&[num_channels]),
            eps,
            training: true,
        }
    }
}

impl Module for GroupNorm {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let data = input.data_f32();
        let gamma = self.gamma.data_f32();
        let beta = self.beta.data_f32();

        let batch_size = dims[0];
        let channels = dims[1];
        let spatial_size: usize = dims[2..].iter().product();

        assert_eq!(channels, self.num_channels, "Input channels must match num_channels");

        let channels_per_group = channels / self.num_groups;
        let mut output = vec![0.0f32; data.len()];

        for b in 0..batch_size {
            for g in 0..self.num_groups {
                // One-pass statistics: variance is E[x^2] - mean^2. `eps`
                // also guards against the slight negative values this form
                // can produce through floating-point cancellation.
                let mut sum = 0.0f32;
                let mut sum_sq = 0.0f32;
                let group_size = (channels_per_group * spatial_size) as f32;

                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let val = data[idx];
                        sum += val;
                        sum_sq += val * val;
                    }
                }

                let mean = sum / group_size;
                let variance = (sum_sq / group_size) - (mean * mean);
                let std = (variance + self.eps).sqrt();

                // Normalize the group, then apply the per-channel affine.
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let val = data[idx];
                        let normalized = (val - mean) / std;
                        output[idx] = gamma[channel_idx] * normalized + beta[channel_idx];
                    }
                }
            }
        }

        Tensor::from_slice(&output, dims).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}
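
/// Instance normalization: each `(sample, channel)` slice is normalized over
/// its spatial positions only. Equivalent to GroupNorm with one group per
/// channel.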
pub struct InstanceNorm {
    num_channels: usize,
    gamma: Tensor,
    beta: Tensor,
    eps: f32,
    training: bool,
}

impl InstanceNorm {
    pub fn new(num_channels: usize) -> Self {
        Self::with_eps(num_channels, 1e-5)
    }

    pub fn with_eps(num_channels: usize, eps: f32) -> Self {
        InstanceNorm {
            num_channels,
            gamma: Tensor::ones(&[num_channels]),
            beta: Tensor::zeros(&[num_channels]),
            eps,
            training: true,
        }
    }
}

impl Module for InstanceNorm {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let data = input.data_f32();
        let gamma = self.gamma.data_f32();
        let beta = self.beta.data_f32();

        let batch_size = dims[0];
        let channels = dims[1];
        let spatial_size: usize = dims[2..].iter().product();

        assert_eq!(channels, self.num_channels, "Input channels must match num_channels");

        let mut output = vec![0.0f32; data.len()];

        for b in 0..batch_size {
            for c in 0..channels {
                // One-pass statistics per (sample, channel) slice, as in
                // GroupNorm above.
                let mut sum = 0.0f32;
                let mut sum_sq = 0.0f32;

                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let val = data[idx];
                    sum += val;
                    sum_sq += val * val;
                }

                let mean = sum / spatial_size as f32;
                let variance = (sum_sq / spatial_size as f32) - (mean * mean);
                let std = (variance + self.eps).sqrt();

                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let val = data[idx];
                    let normalized = (val - mean) / std;
                    output[idx] = gamma[c] * normalized + beta[c];
                }
            }
        }

        Tensor::from_slice(&output, dims).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batchnorm2d() {
        let bn = BatchNorm2d::new(16);
        let input = Tensor::randn(&[2, 16, 8, 8]);
        let output = bn.forward(&input);

        assert_eq!(output.dims(), input.dims());
    }
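
    // A BatchNorm1d smoke test mirroring test_batchnorm2d above; the
    // per-channel mean check and its 1e-3 tolerance are our own additions,
    // not part of the original test suite.
    #[test]
    fn test_batchnorm1d() {
        let bn = BatchNorm1d::new(4);
        let input = Tensor::randn(&[8, 4]);
        let output = bn.forward(&input);

        assert_eq!(output.dims(), input.dims());

        // In training mode each output channel should have ~zero mean,
        // since gamma starts at 1 and beta at 0.
        let out = output.data_f32();
        for c in 0..4 {
            let mean: f32 = (0..8).map(|b| out[b * 4 + c]).sum::<f32>() / 8.0;
            assert!(mean.abs() < 1e-3, "channel {} mean {} not ~0", c, mean);
        }
    }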

    #[test]
    fn test_layernorm() {
        let ln = LayerNorm::new(&[64]);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = ln.forward(&input);

        assert_eq!(output.dims(), input.dims());
    }
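
    // Smoke tests for the remaining modules, following the same shape-check
    // pattern as the tests above.
    #[test]
    fn test_groupnorm() {
        let gn = GroupNorm::new(4, 16);
        let input = Tensor::randn(&[2, 16, 8, 8]);
        let output = gn.forward(&input);

        assert_eq!(output.dims(), input.dims());
    }

    #[test]
    fn test_instancenorm() {
        let inorm = InstanceNorm::new(8);
        let input = Tensor::randn(&[2, 8, 16]);
        let output = inorm.forward(&input);

        assert_eq!(output.dims(), input.dims());
    }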
}