use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::Variable;
use axonml_tensor::Tensor;
use parking_lot::RwLock;

use crate::init::{ones, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

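/// 1D batch normalization over a `[batch, features]` or `[batch, features, length]`
/// input. Each feature channel is normalized with batch statistics during training
/// and with running statistics during evaluation, then scaled and shifted by the
/// learnable `weight` and `bias` parameters.
///
/// A minimal usage sketch, mirroring the unit tests below; the `axonml_nn` crate
/// path and re-exports are assumptions:
///
/// ```ignore
/// use axonml_nn::{BatchNorm1d, Module, Tensor, Variable}; // hypothetical re-exports
///
/// let bn = BatchNorm1d::new(3);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
///     false,
/// );
/// let output = bn.forward(&input);
/// assert_eq!(output.shape(), vec![2, 3]);
/// ```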
pub struct BatchNorm1d {
    /// Learnable per-feature scale (gamma), initialized to ones.
    pub weight: Parameter,
    /// Learnable per-feature shift (beta), initialized to zeros.
    pub bias: Parameter,
    /// Running mean used for normalization in evaluation mode.
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance used for normalization in evaluation mode.
    running_var: RwLock<Tensor<f32>>,
    /// Number of feature channels.
    num_features: usize,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
    /// Interpolation factor for running-statistics updates.
    momentum: f32,
    /// Whether running statistics are updated during training.
    track_running_stats: bool,
    /// Training-mode flag; batch statistics are used when set.
    training: AtomicBool,
}

impl BatchNorm1d {
    /// Creates a batch-norm layer with default options
    /// (`eps = 1e-5`, `momentum = 0.1`, running statistics tracked).
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1, true)
    }

    /// Creates a batch-norm layer with explicit `eps`, `momentum`, and
    /// running-statistics settings.
    pub fn with_options(
        num_features: usize,
        eps: f32,
        momentum: f32,
        track_running_stats: bool,
    ) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            track_running_stats,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of feature channels.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm1d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_features = shape[1];

        assert_eq!(
            num_features, self.num_features,
            "BatchNorm1d: expected {} features, got {}",
            self.num_features, num_features
        );

        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let is_training = self.training.load(Ordering::Relaxed);
        // For a [N, C, L] input, normalize over both the batch and length dimensions.
        let spatial_size: usize = if shape.len() > 2 {
            shape[2..].iter().product()
        } else {
            1
        };

        let mut means = vec![0.0f32; num_features];
        let mut vars = vec![0.0f32; num_features];

        if is_training {
            // Compute per-channel mean and (biased) variance from the current batch.
            for c in 0..num_features {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                means[c] = sum / (batch_size * spatial_size) as f32;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        let diff = input_vec[idx] - means[c];
                        var_sum += diff * diff;
                    }
                }
                vars[c] = var_sum / (batch_size * spatial_size) as f32;
            }

            if self.track_running_stats {
                // Exponential moving average update of the running statistics.
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();

                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();

                *running_mean = Tensor::from_vec(new_mean, &[num_features]).unwrap();
                *running_var = Tensor::from_vec(new_var, &[num_features]).unwrap();
            }
        } else {
            // Evaluation mode: normalize with the stored running statistics.
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        // Normalize each element and apply the per-channel affine transform.
        let mut output_vec = vec![0.0f32; input_vec.len()];
        for b in 0..batch_size {
            for c in 0..num_features {
                for s in 0..spatial_size {
                    let idx = b * num_features * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
                    output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm1d"
    }
}

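/// 2D batch normalization over a `[batch, channels, height, width]` input.
/// Each channel is normalized with batch statistics during training and with
/// running statistics during evaluation, then scaled and shifted by `weight`
/// and `bias`. Running statistics are always tracked; unlike [`BatchNorm1d`],
/// there is no `track_running_stats` option.
///
/// A minimal usage sketch, mirroring the unit tests below; the `axonml_nn` crate
/// path and re-exports are assumptions:
///
/// ```ignore
/// use axonml_nn::{BatchNorm2d, Module, Tensor, Variable}; // hypothetical re-exports
///
/// let bn = BatchNorm2d::new(2);
/// let input = Variable::new(Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap(), false);
/// let output = bn.forward(&input);
/// assert_eq!(output.shape(), vec![2, 2, 2, 4]);
/// ```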
pub struct BatchNorm2d {
    /// Learnable per-channel scale (gamma), initialized to ones.
    pub weight: Parameter,
    /// Learnable per-channel shift (beta), initialized to zeros.
    pub bias: Parameter,
    /// Running mean used for normalization in evaluation mode.
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance used for normalization in evaluation mode.
    running_var: RwLock<Tensor<f32>>,
    /// Number of input channels.
    num_features: usize,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
    /// Interpolation factor for running-statistics updates.
    momentum: f32,
    /// Training-mode flag; batch statistics are used when set.
    training: AtomicBool,
}

impl BatchNorm2d {
    /// Creates a batch-norm layer with default options (`eps = 1e-5`, `momentum = 0.1`).
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1)
    }

    /// Creates a batch-norm layer with explicit `eps` and `momentum`.
    pub fn with_options(num_features: usize, eps: f32, momentum: f32) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of input channels.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm2d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        assert_eq!(
            channels, self.num_features,
            "BatchNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let is_training = self.training.load(Ordering::Relaxed);

        let mut means = vec![0.0f32; channels];
        let mut vars = vec![0.0f32; channels];

        if is_training {
            // Compute per-channel mean and (biased) variance over batch, height, and width.
            for c in 0..channels {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            let idx =
                                b * channels * spatial_size + c * spatial_size + h * width + w;
                            sum += input_vec[idx];
                        }
                    }
                }
                means[c] = sum / (batch_size * spatial_size) as f32;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            let idx =
                                b * channels * spatial_size + c * spatial_size + h * width + w;
                            let diff = input_vec[idx] - means[c];
                            var_sum += diff * diff;
                        }
                    }
                }
                vars[c] = var_sum / (batch_size * spatial_size) as f32;
            }

            // Exponential moving average update of the running statistics.
            let mut running_mean = self.running_mean.write();
            let mut running_var = self.running_var.write();
            let running_mean_vec = running_mean.to_vec();
            let running_var_vec = running_var.to_vec();

            let new_mean: Vec<f32> = running_mean_vec
                .iter()
                .zip(means.iter())
                .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                .collect();
            let new_var: Vec<f32> = running_var_vec
                .iter()
                .zip(vars.iter())
                .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                .collect();

            *running_mean = Tensor::from_vec(new_mean, &[channels]).unwrap();
            *running_var = Tensor::from_vec(new_var, &[channels]).unwrap();
        } else {
            // Evaluation mode: normalize with the stored running statistics.
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        // Normalize each element and apply the per-channel affine transform.
        let mut output_vec = vec![0.0f32; input_vec.len()];
        for b in 0..batch_size {
            for c in 0..channels {
                for h in 0..height {
                    for w in 0..width {
                        let idx = b * channels * spatial_size + c * spatial_size + h * width + w;
                        let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
                        output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm2d"
    }
}

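/// Layer normalization over the trailing `normalized_shape` dimensions of the
/// input. Each sample is normalized with its own mean and variance, then scaled
/// and shifted by the learnable `weight` and `bias` parameters. Unlike batch
/// normalization, the behaviour is identical in training and evaluation mode.
///
/// A minimal usage sketch, mirroring the unit tests below; the `axonml_nn` crate
/// path and re-exports are assumptions:
///
/// ```ignore
/// use axonml_nn::{LayerNorm, Module, Tensor, Variable}; // hypothetical re-exports
///
/// let ln = LayerNorm::single(4);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]).unwrap(),
///     false,
/// );
/// let output = ln.forward(&input);
/// assert_eq!(output.shape(), vec![2, 4]);
/// ```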
pub struct LayerNorm {
    /// Learnable element-wise scale (gamma), initialized to ones.
    pub weight: Parameter,
    /// Learnable element-wise shift (beta), initialized to zeros.
    pub bias: Parameter,
    /// Shape of the trailing dimensions that are normalized together.
    normalized_shape: Vec<usize>,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
}

impl LayerNorm {
    /// Creates a layer-norm over the given trailing shape with `eps = 1e-5`.
    pub fn new(normalized_shape: Vec<usize>) -> Self {
        Self::with_eps(normalized_shape, 1e-5)
    }

    /// Convenience constructor for normalizing a single trailing dimension.
    pub fn single(size: usize) -> Self {
        Self::new(vec![size])
    }

    /// Creates a layer-norm with an explicit `eps`.
    pub fn with_eps(normalized_shape: Vec<usize>, eps: f32) -> Self {
        let numel: usize = normalized_shape.iter().product();
        Self {
            weight: Parameter::named("weight", ones(&[numel]), true),
            bias: Parameter::named("bias", zeros(&[numel]), true),
            normalized_shape,
            eps,
        }
    }
}

impl Module for LayerNorm {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let input_vec = input_data.to_vec();

        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        // Treat everything before the normalized trailing dimensions as the batch.
        let norm_size: usize = self.normalized_shape.iter().product();
        let batch_size = input_vec.len() / norm_size;

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..batch_size {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &input_vec[start..end];

            // Per-sample mean and (biased) variance over the normalized dimensions.
            let mean: f32 = slice.iter().sum::<f32>() / norm_size as f32;
            let var: f32 =
                slice.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / norm_size as f32;

            for i in 0..norm_size {
                let normalized = (slice[i] - mean) / (var + self.eps).sqrt();
                output_vec[start + i] = normalized * weight_vec[i] + bias_vec[i];
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn name(&self) -> &'static str {
        "LayerNorm"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batchnorm1d() {
        let bn = BatchNorm1d::new(3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }

    #[test]
    fn test_batchnorm2d() {
        let bn = BatchNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap(),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_layernorm() {
        let ln = LayerNorm::single(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]).unwrap(),
            false,
        );
        let output = ln.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);
    }

    #[test]
    fn test_batchnorm_parameters() {
        let bn = BatchNorm1d::new(10);
        assert_eq!(bn.parameters().len(), 2);
        assert_eq!(bn.num_parameters(), 20);
    }
}