//! Normalization layers: `BatchNorm1d`, `BatchNorm2d`, `LayerNorm`,
//! `GroupNorm`, and `InstanceNorm2d`.

use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::Variable;
use axonml_autograd::functions::{
    BatchNorm1dBackward, BatchNorm2dBackward, GroupNormBackward, InstanceNorm2dBackward,
    LayerNormBackward,
};
use axonml_autograd::grad_fn::GradFn;
use axonml_autograd::no_grad::is_grad_enabled;
use axonml_tensor::Tensor;
use parking_lot::RwLock;

use crate::init::{ones, zeros};
use crate::module::Module;
use crate::parameter::Parameter;
pub struct BatchNorm1d {
    pub weight: Parameter,
    pub bias: Parameter,
    running_mean: RwLock<Tensor<f32>>,
    running_var: RwLock<Tensor<f32>>,
    num_features: usize,
    eps: f32,
    momentum: f32,
    track_running_stats: bool,
    training: AtomicBool,
}

impl BatchNorm1d {
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1, true)
    }

    pub fn with_options(
        num_features: usize,
        eps: f32,
        momentum: f32,
        track_running_stats: bool,
    ) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            track_running_stats,
            training: AtomicBool::new(true),
        }
    }

    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm1d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_features = shape[1];

        assert_eq!(
            num_features, self.num_features,
            "BatchNorm1d: expected {} features, got {}",
            self.num_features, num_features
        );

        let is_training = self.training.load(Ordering::Relaxed);
        let spatial_size: usize = if shape.len() > 2 {
            shape[2..].iter().product()
        } else {
            1
        };

        // Fused CUDA fast path (training only): the kernel returns the batch
        // statistics alongside the normalized output.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() && is_training {
            let gamma_data = self.weight.data();
            let beta_data = self.bias.data();

            let gamma_gpu = if !gamma_data.device().is_gpu() {
                gamma_data
                    .to_device(input_data.device())
                    .unwrap_or(gamma_data)
            } else {
                gamma_data
            };
            let beta_gpu = if !beta_data.device().is_gpu() {
                beta_data
                    .to_device(input_data.device())
                    .unwrap_or(beta_data)
            } else {
                beta_data
            };

            if let Some((output_tensor, means, vars)) = input_data.batchnorm_fused(
                &gamma_gpu,
                &beta_gpu,
                self.eps,
                num_features,
                spatial_size,
            ) {
                if self.track_running_stats {
                    let mut running_mean = self.running_mean.write();
                    let mut running_var = self.running_var.write();
                    let running_mean_vec = running_mean.to_vec();
                    let running_var_vec = running_var.to_vec();
                    let new_mean: Vec<f32> = running_mean_vec
                        .iter()
                        .zip(means.iter())
                        .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                        .collect();
                    let new_var: Vec<f32> = running_var_vec
                        .iter()
                        .zip(vars.iter())
                        .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                        .collect();
                    *running_mean = Tensor::from_vec(new_mean, &[num_features])
                        .expect("tensor creation failed");
                    *running_var = Tensor::from_vec(new_var, &[num_features])
                        .expect("tensor creation failed");
                }

                let weight_vec = gamma_gpu.to_vec();
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_var = self.bias.variable();
                    let grad_fn = GradFn::new(BatchNorm1dBackward::new(
                        input.grad_fn().cloned(),
                        weight_var.grad_fn().cloned(),
                        bias_var.grad_fn().cloned(),
                        input_data,
                        means,
                        vars,
                        weight_vec,
                        self.eps,
                        self.num_features,
                    ));
                    return Variable::from_operation(output_tensor, grad_fn, true);
                }
                return Variable::new(output_tensor, false);
            }
        }

        // CPU fallback.
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut means = vec![0.0f32; num_features];
        let mut vars = vec![0.0f32; num_features];

        if is_training {
            // Per-channel mean and (biased) variance over batch and spatial dims.
            for c in 0..num_features {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                means[c] = sum / (batch_size * spatial_size) as f32;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        let diff = input_vec[idx] - means[c];
                        var_sum += diff * diff;
                    }
                }
                vars[c] = var_sum / (batch_size * spatial_size) as f32;
            }

            if self.track_running_stats {
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();

                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();

                *running_mean =
                    Tensor::from_vec(new_mean, &[num_features]).expect("tensor creation failed");
                *running_var =
                    Tensor::from_vec(new_var, &[num_features]).expect("tensor creation failed");
            }
        } else {
            // Eval mode: use the tracked running statistics.
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        let mut output_vec = vec![0.0f32; input_vec.len()];
        for b in 0..batch_size {
            for c in 0..num_features {
                for s in 0..spatial_size {
                    let idx = b * num_features * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
                    output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_var = self.bias.variable();

            let grad_fn = GradFn::new(BatchNorm1dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_var.grad_fn().cloned(),
                input_data,
                means.clone(),
                vars.clone(),
                weight_vec,
                self.eps,
                self.num_features,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm1d"
    }

    fn to_device(&self, device: axonml_core::Device) {
        for param in self.parameters() {
            param.to_device(device);
        }
        if self.track_running_stats {
            let mut rm = self.running_mean.write();
            if let Ok(moved) = rm.to_device(device) {
                *rm = moved;
            }
            let mut rv = self.running_var.write();
            if let Ok(moved) = rv.to_device(device) {
                *rv = moved;
            }
        }
    }
}

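/// Batch normalization over a 4D `(N, C, H, W)` input.
///
/// Per-channel statistics are computed over the batch and spatial dimensions
/// in training mode; running statistics (always tracked) are used in eval
/// mode. The transform matches [`BatchNorm1d`]:
/// `y = (x - mean) / sqrt(var + eps) * weight + bias`.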
pub struct BatchNorm2d {
    pub weight: Parameter,
    pub bias: Parameter,
    running_mean: RwLock<Tensor<f32>>,
    running_var: RwLock<Tensor<f32>>,
    num_features: usize,
    eps: f32,
    momentum: f32,
    training: AtomicBool,
}

impl BatchNorm2d {
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1)
    }

    pub fn with_options(num_features: usize, eps: f32, momentum: f32) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            training: AtomicBool::new(true),
        }
    }

    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm2d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();

        assert!(
            shape.len() == 4,
            "BatchNorm2d expects 4D input (N, C, H, W)"
        );

        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        assert_eq!(
            channels, self.num_features,
            "BatchNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let is_training = self.training.load(Ordering::Relaxed);

        // Fused CUDA fast path (training only): the kernel returns the batch
        // statistics alongside the normalized output.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() && is_training {
            let gamma_data = self.weight.data();
            let beta_data = self.bias.data();

            let gamma_gpu = if !gamma_data.device().is_gpu() {
                gamma_data
                    .to_device(input_data.device())
                    .unwrap_or(gamma_data)
            } else {
                gamma_data
            };
            let beta_gpu = if !beta_data.device().is_gpu() {
                beta_data
                    .to_device(input_data.device())
                    .unwrap_or(beta_data)
            } else {
                beta_data
            };

            if let Some((output_tensor, means, vars)) =
                input_data.batchnorm_fused(&gamma_gpu, &beta_gpu, self.eps, channels, spatial_size)
            {
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();
                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();
                *running_mean =
                    Tensor::from_vec(new_mean, &[channels]).expect("tensor creation failed");
                *running_var =
                    Tensor::from_vec(new_var, &[channels]).expect("tensor creation failed");

                let weight_vec = gamma_gpu.to_vec();
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_var = self.bias.variable();
                    let grad_fn = GradFn::new(BatchNorm2dBackward::new(
                        input.grad_fn().cloned(),
                        weight_var.grad_fn().cloned(),
                        bias_var.grad_fn().cloned(),
                        input_data,
                        means,
                        vars,
                        weight_vec,
                        self.eps,
                        self.num_features,
                    ));
                    return Variable::from_operation(output_tensor, grad_fn, true);
                }
                return Variable::new(output_tensor, false);
            }
        }

        // CPU fallback.
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut means = vec![0.0f32; channels];
        let mut vars = vec![0.0f32; channels];

        if is_training {
            // Single pass per channel: Var[x] = E[x^2] - E[x]^2.
            let n_per_channel = (batch_size * spatial_size) as f32;
            for c in 0..channels {
                let mut sum = 0.0f32;
                let mut sum_sq = 0.0f32;
                for b in 0..batch_size {
                    let base = b * channels * spatial_size + c * spatial_size;
                    for s in 0..spatial_size {
                        let val = input_vec[base + s];
                        sum += val;
                        sum_sq += val * val;
                    }
                }
                means[c] = sum / n_per_channel;
                vars[c] = sum_sq / n_per_channel - means[c] * means[c];
            }

            let mut running_mean = self.running_mean.write();
            let mut running_var = self.running_var.write();
            let running_mean_vec = running_mean.to_vec();
            let running_var_vec = running_var.to_vec();

            let new_mean: Vec<f32> = running_mean_vec
                .iter()
                .zip(means.iter())
                .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                .collect();
            let new_var: Vec<f32> = running_var_vec
                .iter()
                .zip(vars.iter())
                .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                .collect();

            *running_mean =
                Tensor::from_vec(new_mean, &[channels]).expect("tensor creation failed");
            *running_var = Tensor::from_vec(new_var, &[channels]).expect("tensor creation failed");
        } else {
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        let total = input_vec.len();
        let mut output_vec = vec![0.0f32; total];

        // Precompute per-channel inverse standard deviations.
        let inv_stds: Vec<f32> = vars.iter().map(|v| 1.0 / (v + self.eps).sqrt()).collect();

        for i in 0..total {
            let c = (i / spatial_size) % channels;
            output_vec[i] = (input_vec[i] - means[c]) * inv_stds[c] * weight_vec[c] + bias_vec[c];
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_var = self.bias.variable();

            let grad_fn = GradFn::new(BatchNorm2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_var.grad_fn().cloned(),
                input_data,
                means.clone(),
                vars.clone(),
                weight_vec,
                self.eps,
                self.num_features,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm2d"
    }

    fn to_device(&self, device: axonml_core::Device) {
        for param in self.parameters() {
            param.to_device(device);
        }
        let mut rm = self.running_mean.write();
        if let Ok(moved) = rm.to_device(device) {
            *rm = moved;
        }
        let mut rv = self.running_var.write();
        if let Ok(moved) = rv.to_device(device) {
            *rv = moved;
        }
    }
}

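/// Layer normalization over the trailing `normalized_shape` dimensions.
///
/// Unlike batch normalization, statistics are computed per sample, so the
/// output is independent of the batch composition and no running statistics
/// are kept.
///
/// A minimal usage sketch (marked `ignore` since it is illustrative only):
///
/// ```ignore
/// let ln = LayerNorm::single(64);
/// let x = Variable::new(Tensor::from_vec(vec![0.0; 128], &[2, 64]).unwrap(), false);
/// let y = ln.forward(&x); // shape [2, 64]
/// ```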
pub struct LayerNorm {
    pub weight: Parameter,
    pub bias: Parameter,
    normalized_shape: Vec<usize>,
    eps: f32,
}

impl LayerNorm {
    pub fn new(normalized_shape: Vec<usize>) -> Self {
        Self::with_eps(normalized_shape, 1e-5)
    }

    pub fn single(size: usize) -> Self {
        Self::new(vec![size])
    }

    pub fn with_eps(normalized_shape: Vec<usize>, eps: f32) -> Self {
        let numel: usize = normalized_shape.iter().product();
        Self {
            weight: Parameter::named("weight", ones(&[numel]), true),
            bias: Parameter::named("bias", zeros(&[numel]), true),
            normalized_shape,
            eps,
        }
    }
}

impl Module for LayerNorm {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let norm_size: usize = self.normalized_shape.iter().product();
        let total_len = input_data.numel();
        let num_rows = total_len / norm_size;

        // Fused CUDA fast path.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() {
            let weight_data = self.weight.data();
            let weight_gpu = if weight_data.device().is_gpu() {
                weight_data.clone()
            } else {
                weight_data
                    .to_device(input_data.device().clone())
                    .expect("failed to move LayerNorm weight to device")
            };
            let bias_data = self.bias.data();
            let bias_gpu = if bias_data.device().is_gpu() {
                bias_data.clone()
            } else {
                bias_data
                    .to_device(input_data.device().clone())
                    .expect("failed to move LayerNorm bias to device")
            };

            let output = input_data
                .layer_norm_cuda(&weight_gpu, &bias_gpu, norm_size, self.eps)
                .expect("CUDA LayerNorm failed");

            // Also build the graph when only the affine parameters require
            // grad, matching the BatchNorm layers above.
            let requires_grad =
                (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
            return if requires_grad {
                let grad_fn = GradFn::new(LayerNormBackward::new(
                    input.grad_fn().cloned(),
                    self.weight.variable().grad_fn().cloned(),
                    self.bias.variable().grad_fn().cloned(),
                    input_data.clone(),
                    self.weight.data().clone(),
                    self.normalized_shape.clone(),
                    self.eps,
                ));
                Variable::from_operation(output, grad_fn, true)
            } else {
                Variable::from_tensor(output)
            };
        }

        // CPU fallback: normalize each row of `norm_size` elements.
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..num_rows {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &input_vec[start..end];

            let mean: f32 = slice.iter().sum::<f32>() / norm_size as f32;
            let var: f32 = slice.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / norm_size as f32;

            let inv_std = 1.0 / (var + self.eps).sqrt();
            for i in 0..norm_size {
                let normalized = (slice[i] - mean) * inv_std;
                output_vec[start + i] = normalized * weight_vec[i] + bias_vec[i];
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");
        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LayerNormBackward::new(
                input.grad_fn().cloned(),
                self.weight.variable().grad_fn().cloned(),
                self.bias.variable().grad_fn().cloned(),
                input_data.clone(),
                self.weight.data().clone(),
                self.normalized_shape.clone(),
                self.eps,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn name(&self) -> &'static str {
        "LayerNorm"
    }
}

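/// Group normalization: channels are split into `num_groups` groups and each
/// group is normalized per sample over its channels and spatial positions.
///
/// With `num_groups == 1` this reduces to layer normalization over
/// `(C, *spatial)`; with `num_groups == num_channels` it behaves like
/// instance normalization.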
pub struct GroupNorm {
    pub weight: Parameter,
    pub bias: Parameter,
    num_groups: usize,
    num_channels: usize,
    eps: f32,
    affine: bool,
}

impl GroupNorm {
    pub fn new(num_groups: usize, num_channels: usize) -> Self {
        Self::with_options(num_groups, num_channels, 1e-5, true)
    }

    pub fn with_options(num_groups: usize, num_channels: usize, eps: f32, affine: bool) -> Self {
        assert!(
            num_channels % num_groups == 0,
            "num_channels ({}) must be divisible by num_groups ({})",
            num_channels,
            num_groups
        );

        Self {
            weight: Parameter::named("weight", ones(&[num_channels]), affine),
            bias: Parameter::named("bias", zeros(&[num_channels]), affine),
            num_groups,
            num_channels,
            eps,
            affine,
        }
    }
}

impl Module for GroupNorm {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        assert_eq!(
            channels, self.num_channels,
            "GroupNorm: expected {} channels, got {}",
            self.num_channels, channels
        );

        let input_vec = input_data.to_vec();
        let channels_per_group = channels / self.num_groups;

        // Hoist the affine parameters out of the hot loops; `to_vec` copies
        // the whole tensor, so fetching it per channel would be wasteful.
        let (weight_vec, bias_vec) = if self.affine {
            (self.weight.data().to_vec(), self.bias.data().to_vec())
        } else {
            (vec![1.0f32; channels], vec![0.0f32; channels])
        };

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..batch_size {
            for g in 0..self.num_groups {
                let group_size = channels_per_group * spatial_size;

                // Mean over the group's channels and spatial positions.
                let mut sum = 0.0f32;
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                let mean = sum / group_size as f32;

                // Biased variance over the same elements.
                let mut var_sum = 0.0f32;
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let diff = input_vec[idx] - mean;
                        var_sum += diff * diff;
                    }
                }
                let var = var_sum / group_size as f32;

                let std_inv = 1.0 / (var + self.eps).sqrt();
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    let weight = weight_vec[channel_idx];
                    let bias = bias_vec[channel_idx];

                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let normalized = (input_vec[idx] - mean) * std_inv;
                        output_vec[idx] = normalized * weight + bias;
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");

        // Build the graph whenever the input (or, with affine enabled, the
        // parameters) requires grad; previously a non-affine GroupNorm
        // silently cut off gradient flow to the input.
        let requires_grad = (input.requires_grad()
            || (self.affine && self.weight.requires_grad()))
            && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(GroupNormBackward::new(
                input.grad_fn().cloned(),
                self.weight.variable().grad_fn().cloned(),
                self.bias.variable().grad_fn().cloned(),
                input_data.clone(),
                self.weight.data().clone(),
                self.num_groups,
                self.eps,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        if self.affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        if self.affine {
            let mut params = HashMap::new();
            params.insert("weight".to_string(), self.weight.clone());
            params.insert("bias".to_string(), self.bias.clone());
            params
        } else {
            HashMap::new()
        }
    }

    fn name(&self) -> &'static str {
        "GroupNorm"
    }
}

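/// Instance normalization over a 4D `(N, C, H, W)` input: each
/// `(sample, channel)` plane is normalized with its own spatial mean and
/// variance. Affine scale/shift parameters are optional and disabled by
/// default ([`InstanceNorm2d::new`]); use [`InstanceNorm2d::with_affine`]
/// to enable them.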
pub struct InstanceNorm2d {
    pub weight: Parameter,
    pub bias: Parameter,
    num_features: usize,
    eps: f32,
    affine: bool,
}

impl InstanceNorm2d {
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, false)
    }

    pub fn with_affine(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, true)
    }

    pub fn with_options(num_features: usize, eps: f32, affine: bool) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), affine),
            bias: Parameter::named("bias", zeros(&[num_features]), affine),
            num_features,
            eps,
            affine,
        }
    }
}

impl Module for InstanceNorm2d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();

        assert!(
            shape.len() == 4,
            "InstanceNorm2d expects 4D input (N, C, H, W)"
        );

        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        assert_eq!(
            channels, self.num_features,
            "InstanceNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let input_vec = input_data.to_vec();
        let mut output_vec = vec![0.0f32; input_vec.len()];

        // Hoist the affine parameters out of the per-instance loops.
        let (weight_vec, bias_vec) = if self.affine {
            (self.weight.data().to_vec(), self.bias.data().to_vec())
        } else {
            (vec![1.0f32; channels], vec![0.0f32; channels])
        };

        for b in 0..batch_size {
            for c in 0..channels {
                // Mean over this (sample, channel) plane.
                let mut sum = 0.0f32;
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    sum += input_vec[idx];
                }
                let mean = sum / spatial_size as f32;

                // Biased variance over the same plane.
                let mut var_sum = 0.0f32;
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let diff = input_vec[idx] - mean;
                    var_sum += diff * diff;
                }
                let var = var_sum / spatial_size as f32;

                let std_inv = 1.0 / (var + self.eps).sqrt();
                let weight = weight_vec[c];
                let bias = bias_vec[c];

                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - mean) * std_inv;
                    output_vec[idx] = normalized * weight + bias;
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");
        let requires_grad = (input.requires_grad()
            || (self.affine && self.weight.requires_grad()))
            && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(InstanceNorm2dBackward::new(
                input.grad_fn().cloned(),
                if self.affine {
                    self.weight.variable().grad_fn().cloned()
                } else {
                    None
                },
                if self.affine {
                    self.bias.variable().grad_fn().cloned()
                } else {
                    None
                },
                input_data.clone(),
                self.weight.data().clone(),
                self.eps,
                self.affine,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        if self.affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        if self.affine {
            let mut params = HashMap::new();
            params.insert("weight".to_string(), self.weight.clone());
            params.insert("bias".to_string(), self.bias.clone());
            params
        } else {
            HashMap::new()
        }
    }

    fn name(&self) -> &'static str {
        "InstanceNorm2d"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batchnorm1d() {
        let bn = BatchNorm1d::new(3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
                .expect("tensor creation failed"),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }

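    #[test]
    fn test_batchnorm1d_running_stats_update() {
        // A training-mode forward pass should move the running statistics
        // toward the batch statistics by `momentum` (0.1 by default). This
        // exercises the CPU path on a tiny, hand-checkable input.
        let bn = BatchNorm1d::new(1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0, 2.0], &[2, 1]).expect("tensor creation failed"),
            false,
        );
        let _ = bn.forward(&input);

        // Batch mean is 1.0 and batch var is 1.0, so:
        // running_mean = 0.9 * 0.0 + 0.1 * 1.0 = 0.1
        // running_var  = 0.9 * 1.0 + 0.1 * 1.0 = 1.0
        let rm = bn.running_mean.read().to_vec();
        let rv = bn.running_var.read().to_vec();
        assert!((rm[0] - 0.1).abs() < 1e-6);
        assert!((rv[0] - 1.0).abs() < 1e-6);
    }
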
    #[test]
    fn test_batchnorm2d() {
        let bn = BatchNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).expect("tensor creation failed"),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_layernorm() {
        let ln = LayerNorm::single(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])
                .expect("tensor creation failed"),
            false,
        );
        let output = ln.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);
    }

    #[test]
    fn test_batchnorm_parameters() {
        let bn = BatchNorm1d::new(10);
        assert_eq!(bn.parameters().len(), 2);
        assert_eq!(bn.num_parameters(), 20);
    }

    #[test]
    fn test_groupnorm() {
        let gn = GroupNorm::new(2, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 4, 2, 2]).expect("tensor creation failed"),
            false,
        );
        let output = gn.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 2, 2]);
    }

    #[test]
    fn test_groupnorm_normalization() {
        let gn = GroupNorm::with_options(2, 4, 1e-5, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 4, 1, 2])
                .expect("tensor creation failed"),
            false,
        );
        let output = gn.forward(&input);
        let out_vec = output.data().to_vec();
        let group1_mean: f32 = out_vec[0..4].iter().sum::<f32>() / 4.0;
        let group2_mean: f32 = out_vec[4..8].iter().sum::<f32>() / 4.0;
        assert!(group1_mean.abs() < 1e-5);
        assert!(group2_mean.abs() < 1e-5);
    }

    #[test]
    fn test_instancenorm2d() {
        let inn = InstanceNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).expect("tensor creation failed"),
            false,
        );
        let output = inn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_instancenorm2d_with_affine() {
        let inn = InstanceNorm2d::with_affine(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 64], &[1, 4, 4, 4]).expect("tensor creation failed"),
            false,
        );
        let output = inn.forward(&input);
        assert_eq!(output.shape(), vec![1, 4, 4, 4]);
        assert_eq!(inn.parameters().len(), 2);
    }

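    #[test]
    fn test_instancenorm2d_normalization() {
        // Without affine parameters, each (sample, channel) plane is
        // normalized with its own spatial statistics, so its mean should be
        // approximately zero.
        let inn = InstanceNorm2d::new(1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2])
                .expect("tensor creation failed"),
            false,
        );
        let out = inn.forward(&input).data().to_vec();
        let mean: f32 = out.iter().sum::<f32>() / out.len() as f32;
        assert!(
            mean.abs() < 1e-5,
            "InstanceNorm2d plane mean should be ~0, got {}",
            mean
        );
    }
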
    #[test]
    fn test_layernorm_zero_mean_unit_var() {
        let ln = LayerNorm::with_eps(vec![4], 1e-5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 5.0, 3.0, 7.0], &[1, 4]).unwrap(),
            false,
        );
        let output = ln.forward(&input);
        let out = output.data().to_vec();

        let mean: f32 = out.iter().sum::<f32>() / out.len() as f32;
        let var: f32 = out.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / out.len() as f32;

        assert!(
            mean.abs() < 1e-4,
            "LayerNorm output mean should be ~0, got {}",
            mean
        );
        assert!(
            (var - 1.0).abs() < 0.1,
            "LayerNorm output var should be ~1, got {}",
            var
        );
    }

    #[test]
    fn test_layernorm_gradient_flow() {
        use axonml_autograd::backward;

        let ln = LayerNorm::single(3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            true,
        );
        let output = ln.forward(&input);
        let loss = output.sum();

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input
            .grad()
            .expect("Should have gradient through LayerNorm");
        let gv = grad.to_vec();
        assert_eq!(gv.len(), 3);
        assert!(
            gv.iter().all(|g| g.is_finite()),
            "All gradients should be finite: {:?}",
            gv
        );
    }

    #[test]
    fn test_layernorm_batch_independence() {
        let ln = LayerNorm::with_eps(vec![3], 1e-5);

        let input1 = Variable::new(
            Tensor::from_vec(vec![10.0, 20.0, 30.0], &[1, 3]).unwrap(),
            false,
        );
        let out1 = ln.forward(&input1).data().to_vec();

        let input2 = Variable::new(
            Tensor::from_vec(vec![10.0, 20.0, 30.0, 1.0, 1.0, 1.0], &[2, 3]).unwrap(),
            false,
        );
        let out2 = ln.forward(&input2).data().to_vec();

        for i in 0..3 {
            assert!(
                (out1[i] - out2[i]).abs() < 1e-5,
                "LayerNorm should be batch-independent: {} vs {}",
                out1[i],
                out2[i]
            );
        }
    }

    #[test]
    fn test_layernorm_parameters_count() {
        let ln = LayerNorm::single(64);
        assert_eq!(ln.parameters().len(), 2);
        assert_eq!(ln.num_parameters(), 128);
    }

    #[test]
    fn test_batchnorm1d_normalization() {
        let bn = BatchNorm1d::with_options(2, 1e-5, 0.1, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 10.0, 3.0, 20.0, 5.0, 30.0], &[3, 2]).unwrap(),
            false,
        );
        let output = bn.forward(&input);
        let out = output.data().to_vec();

        let ch0_mean = (out[0] + out[2] + out[4]) / 3.0;
        let ch1_mean = (out[1] + out[3] + out[5]) / 3.0;
        assert!(
            ch0_mean.abs() < 0.1,
            "BatchNorm ch0 mean should be ~0, got {}",
            ch0_mean
        );
        assert!(
            ch1_mean.abs() < 0.1,
            "BatchNorm ch1 mean should be ~0, got {}",
            ch1_mean
        );
    }

    #[test]
    fn test_batchnorm1d_train_vs_eval() {
        let mut bn = BatchNorm1d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );

        bn.train();
        let train_out = bn.forward(&input).data().to_vec();

        bn.eval();
        let eval_out = bn.forward(&input).data().to_vec();

        // The training pass updates the running statistics, so the eval pass
        // (which uses them) should produce a different output.
        let diff: f32 = train_out
            .iter()
            .zip(eval_out.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff > 0.0, "train and eval outputs should differ");
    }

    #[test]
    fn test_batchnorm2d_gradient_flow() {
        use axonml_autograd::backward;

        let bn = BatchNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap(),
            true,
        );
        let output = bn.forward(&input);
        let loss = output.sum();
        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input
            .grad()
            .expect("Should have gradient through BatchNorm2d");
        assert_eq!(grad.shape(), &[2, 2, 2, 4]);
        assert!(grad.to_vec().iter().all(|g| g.is_finite()));
    }

    #[test]
    fn test_groupnorm_gradient_flow() {
        use axonml_autograd::backward;

        let gn = GroupNorm::new(2, 4);
        let input = Variable::new(
            Tensor::from_vec(
                (0..32).map(|i| i as f32 * 0.1).collect::<Vec<_>>(),
                &[1, 4, 2, 4],
            )
            .unwrap(),
            true,
        );
        let output = gn.forward(&input);
        let loss = output.sum();
        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input
            .grad()
            .expect("Should have gradient through GroupNorm");
        assert_eq!(grad.shape(), &[1, 4, 2, 4]);
        assert!(grad.to_vec().iter().all(|g| g.is_finite()));
    }
}