//! Normalization layers: BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm, and InstanceNorm2d.

use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::Variable;
use axonml_autograd::functions::{
    BatchNorm1dBackward, BatchNorm2dBackward, GroupNormBackward, InstanceNorm2dBackward,
    LayerNormBackward,
};
use axonml_autograd::grad_fn::GradFn;
use axonml_autograd::no_grad::is_grad_enabled;
use axonml_tensor::Tensor;
use parking_lot::RwLock;

use crate::init::{ones, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

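/// Batch normalization over a 2D `(N, C)` or 3D `(N, C, L)` input.
///
/// Each feature channel is normalized as `y = (x - mean) / sqrt(var + eps) * weight + bias`,
/// using per-batch statistics in training mode and the tracked running statistics in
/// evaluation mode. Running statistics are updated as
/// `running = (1 - momentum) * running + momentum * batch_stat`.
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file (the exact
/// `Variable`/`Tensor` construction is assumed from those tests):
///
/// ```ignore
/// let bn = BatchNorm1d::new(3);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
///     false,
/// );
/// let output = bn.forward(&input); // shape [2, 3]
/// ```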
pub struct BatchNorm1d {
    /// Learnable per-feature scale (gamma).
    pub weight: Parameter,
    /// Learnable per-feature shift (beta).
    pub bias: Parameter,
    /// Running mean of each feature, tracked during training.
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance of each feature, tracked during training.
    running_var: RwLock<Tensor<f32>>,
    num_features: usize,
    eps: f32,
    momentum: f32,
    track_running_stats: bool,
    training: AtomicBool,
}

impl BatchNorm1d {
    /// Creates a `BatchNorm1d` with default options (`eps = 1e-5`, `momentum = 0.1`,
    /// running statistics tracked).
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1, true)
    }

    /// Creates a `BatchNorm1d` with explicit `eps`, `momentum`, and `track_running_stats`.
    pub fn with_options(
        num_features: usize,
        eps: f32,
        momentum: f32,
        track_running_stats: bool,
    ) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            track_running_stats,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of features this layer normalizes over.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm1d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_features = shape[1];

        assert_eq!(
            num_features, self.num_features,
            "BatchNorm1d: expected {} features, got {}",
            self.num_features, num_features
        );

        let is_training = self.training.load(Ordering::Relaxed);
        let spatial_size: usize = if shape.len() > 2 {
            shape[2..].iter().product()
        } else {
            1
        };

        // Fused GPU path: normalization and batch statistics in one kernel.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() && is_training {
            let gamma_data = self.weight.data();
            let beta_data = self.bias.data();

            let gamma_gpu = if !gamma_data.device().is_gpu() {
                gamma_data
                    .to_device(input_data.device())
                    .unwrap_or(gamma_data)
            } else {
                gamma_data
            };
            let beta_gpu = if !beta_data.device().is_gpu() {
                beta_data
                    .to_device(input_data.device())
                    .unwrap_or(beta_data)
            } else {
                beta_data
            };

            if let Some((output_tensor, means, vars)) = input_data.batchnorm_fused(
                &gamma_gpu,
                &beta_gpu,
                self.eps,
                num_features,
                spatial_size,
            ) {
                if self.track_running_stats {
                    // running = (1 - momentum) * running + momentum * batch_stat
                    let mut running_mean = self.running_mean.write();
                    let mut running_var = self.running_var.write();
                    let running_mean_vec = running_mean.to_vec();
                    let running_var_vec = running_var.to_vec();
                    let new_mean: Vec<f32> = running_mean_vec
                        .iter()
                        .zip(means.iter())
                        .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                        .collect();
                    let new_var: Vec<f32> = running_var_vec
                        .iter()
                        .zip(vars.iter())
                        .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                        .collect();
                    *running_mean = Tensor::from_vec(new_mean, &[num_features])
                        .expect("tensor creation failed");
                    *running_var =
                        Tensor::from_vec(new_var, &[num_features]).expect("tensor creation failed");
                }

                let weight_vec = gamma_gpu.to_vec();
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_var = self.bias.variable();
                    let grad_fn = GradFn::new(BatchNorm1dBackward::new(
                        input.grad_fn().cloned(),
                        weight_var.grad_fn().cloned(),
                        bias_var.grad_fn().cloned(),
                        input_data,
                        means,
                        vars,
                        weight_vec,
                        self.eps,
                        self.num_features,
                    ));
                    return Variable::from_operation(output_tensor, grad_fn, true);
                }
                return Variable::new(output_tensor, false);
            }
        }

        // CPU fallback: compute per-feature statistics explicitly.
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut means = vec![0.0f32; num_features];
        let mut vars = vec![0.0f32; num_features];

        if is_training {
            for c in 0..num_features {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                means[c] = sum / (batch_size * spatial_size) as f32;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        let diff = input_vec[idx] - means[c];
                        var_sum += diff * diff;
                    }
                }
                vars[c] = var_sum / (batch_size * spatial_size) as f32;
            }

            if self.track_running_stats {
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();

                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();

                *running_mean =
                    Tensor::from_vec(new_mean, &[num_features]).expect("tensor creation failed");
                *running_var =
                    Tensor::from_vec(new_var, &[num_features]).expect("tensor creation failed");
            }
        } else {
            // Evaluation mode: use the tracked running statistics.
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        let mut output_vec = vec![0.0f32; input_vec.len()];
        for b in 0..batch_size {
            for c in 0..num_features {
                for s in 0..spatial_size {
                    let idx = b * num_features * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
                    output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_var = self.bias.variable();

            let grad_fn = GradFn::new(BatchNorm1dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_var.grad_fn().cloned(),
                input_data,
                means.clone(),
                vars.clone(),
                weight_vec,
                self.eps,
                self.num_features,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm1d"
    }

    fn to_device(&self, device: axonml_core::Device) {
        for param in self.parameters() {
            param.to_device(device);
        }
        if self.track_running_stats {
            let mut rm = self.running_mean.write();
            if let Ok(moved) = rm.to_device(device) {
                *rm = moved;
            }
            let mut rv = self.running_var.write();
            if let Ok(moved) = rv.to_device(device) {
                *rv = moved;
            }
        }
    }
}

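/// Batch normalization over a 4D `(N, C, H, W)` input.
///
/// Same normalization as [`BatchNorm1d`], applied per channel over the batch and
/// spatial dimensions. Running statistics are always tracked and updated with the
/// configured `momentum`.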
pub struct BatchNorm2d {
    /// Learnable per-channel scale (gamma).
    pub weight: Parameter,
    /// Learnable per-channel shift (beta).
    pub bias: Parameter,
    /// Running mean of each channel, tracked during training.
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance of each channel, tracked during training.
    running_var: RwLock<Tensor<f32>>,
    num_features: usize,
    eps: f32,
    momentum: f32,
    training: AtomicBool,
}

impl BatchNorm2d {
    /// Creates a `BatchNorm2d` with default options (`eps = 1e-5`, `momentum = 0.1`).
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1)
    }

    /// Creates a `BatchNorm2d` with explicit `eps` and `momentum`.
    pub fn with_options(num_features: usize, eps: f32, momentum: f32) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of channels this layer normalizes over.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm2d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        assert_eq!(
            channels, self.num_features,
            "BatchNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let is_training = self.training.load(Ordering::Relaxed);

        // Fused GPU path: normalization and batch statistics in one kernel.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() && is_training {
            let gamma_data = self.weight.data();
            let beta_data = self.bias.data();

            let gamma_gpu = if !gamma_data.device().is_gpu() {
                gamma_data
                    .to_device(input_data.device())
                    .unwrap_or(gamma_data)
            } else {
                gamma_data
            };
            let beta_gpu = if !beta_data.device().is_gpu() {
                beta_data
                    .to_device(input_data.device())
                    .unwrap_or(beta_data)
            } else {
                beta_data
            };

            if let Some((output_tensor, means, vars)) =
                input_data.batchnorm_fused(&gamma_gpu, &beta_gpu, self.eps, channels, spatial_size)
            {
                // running = (1 - momentum) * running + momentum * batch_stat
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();
                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();
                *running_mean =
                    Tensor::from_vec(new_mean, &[channels]).expect("tensor creation failed");
                *running_var =
                    Tensor::from_vec(new_var, &[channels]).expect("tensor creation failed");

                let weight_vec = gamma_gpu.to_vec();
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_var = self.bias.variable();
                    let grad_fn = GradFn::new(BatchNorm2dBackward::new(
                        input.grad_fn().cloned(),
                        weight_var.grad_fn().cloned(),
                        bias_var.grad_fn().cloned(),
                        input_data,
                        means,
                        vars,
                        weight_vec,
                        self.eps,
                        self.num_features,
                    ));
                    return Variable::from_operation(output_tensor, grad_fn, true);
                }
                return Variable::new(output_tensor, false);
            }
        }

        // CPU fallback with a single pass per channel (sum and sum of squares).
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut means = vec![0.0f32; channels];
        let mut vars = vec![0.0f32; channels];

        if is_training {
            let n_per_channel = (batch_size * spatial_size) as f32;
            for c in 0..channels {
                let mut sum = 0.0f32;
                let mut sum_sq = 0.0f32;
                for b in 0..batch_size {
                    let base = b * channels * spatial_size + c * spatial_size;
                    for s in 0..spatial_size {
                        let val = input_vec[base + s];
                        sum += val;
                        sum_sq += val * val;
                    }
                }
                means[c] = sum / n_per_channel;
                vars[c] = sum_sq / n_per_channel - means[c] * means[c];
            }

            let mut running_mean = self.running_mean.write();
            let mut running_var = self.running_var.write();
            let running_mean_vec = running_mean.to_vec();
            let running_var_vec = running_var.to_vec();

            let new_mean: Vec<f32> = running_mean_vec
                .iter()
                .zip(means.iter())
                .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                .collect();
            let new_var: Vec<f32> = running_var_vec
                .iter()
                .zip(vars.iter())
                .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                .collect();

            *running_mean =
                Tensor::from_vec(new_mean, &[channels]).expect("tensor creation failed");
            *running_var = Tensor::from_vec(new_var, &[channels]).expect("tensor creation failed");
        } else {
            // Evaluation mode: use the tracked running statistics.
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        let total = input_vec.len();
        let mut output_vec = vec![0.0f32; total];

        let inv_stds: Vec<f32> = vars.iter().map(|v| 1.0 / (v + self.eps).sqrt()).collect();

        for i in 0..total {
            let c = (i / spatial_size) % channels;
            output_vec[i] = (input_vec[i] - means[c]) * inv_stds[c] * weight_vec[c] + bias_vec[c];
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_var = self.bias.variable();

            let grad_fn = GradFn::new(BatchNorm2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_var.grad_fn().cloned(),
                input_data,
                means.clone(),
                vars.clone(),
                weight_vec,
                self.eps,
                self.num_features,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm2d"
    }

    fn to_device(&self, device: axonml_core::Device) {
        for param in self.parameters() {
            param.to_device(device);
        }
        let mut rm = self.running_mean.write();
        if let Ok(moved) = rm.to_device(device) {
            *rm = moved;
        }
        let mut rv = self.running_var.write();
        if let Ok(moved) = rv.to_device(device) {
            *rv = moved;
        }
    }
}

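/// Layer normalization over the trailing `normalized_shape` dimensions of the input.
///
/// Each row is normalized independently of the rest of the batch:
/// `y = (x - mean) / sqrt(var + eps) * weight + bias`, where `mean` and `var` are
/// computed over the normalized dimensions only, so no running statistics are kept.
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file:
///
/// ```ignore
/// let ln = LayerNorm::single(4);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap(),
///     false,
/// );
/// let output = ln.forward(&input); // each row of length 4 is normalized independently
/// ```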
pub struct LayerNorm {
    /// Learnable per-element scale (gamma) over the normalized dimensions.
    pub weight: Parameter,
    /// Learnable per-element shift (beta) over the normalized dimensions.
    pub bias: Parameter,
    normalized_shape: Vec<usize>,
    eps: f32,
}

impl LayerNorm {
    /// Creates a `LayerNorm` over `normalized_shape` with the default `eps = 1e-5`.
    pub fn new(normalized_shape: Vec<usize>) -> Self {
        Self::with_eps(normalized_shape, 1e-5)
    }

    /// Convenience constructor for normalizing over a single trailing dimension.
    pub fn single(size: usize) -> Self {
        Self::new(vec![size])
    }

    /// Creates a `LayerNorm` with an explicit `eps`.
    pub fn with_eps(normalized_shape: Vec<usize>, eps: f32) -> Self {
        let numel: usize = normalized_shape.iter().product();
        Self {
            weight: Parameter::named("weight", ones(&[numel]), true),
            bias: Parameter::named("bias", zeros(&[numel]), true),
            normalized_shape,
            eps,
        }
    }
}

impl Module for LayerNorm {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let norm_size: usize = self.normalized_shape.iter().product();
        let total_len = input_data.numel();
        let num_rows = total_len / norm_size;

        // Fused GPU path.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() {
            let weight_data = self.weight.data();
            let weight_gpu = if weight_data.device().is_gpu() {
                weight_data.clone()
            } else {
                weight_data.to_device(input_data.device().clone()).unwrap()
            };
            let bias_data = self.bias.data();
            let bias_gpu = if bias_data.device().is_gpu() {
                bias_data.clone()
            } else {
                bias_data.to_device(input_data.device().clone()).unwrap()
            };

            let output = input_data
                .layer_norm_cuda(&weight_gpu, &bias_gpu, norm_size, self.eps)
                .expect("CUDA LayerNorm failed");

            let requires_grad = input.requires_grad() && is_grad_enabled();
            return if requires_grad {
                let grad_fn = GradFn::new(LayerNormBackward::new(
                    input.grad_fn().cloned(),
                    self.weight.variable().grad_fn().cloned(),
                    self.bias.variable().grad_fn().cloned(),
                    input_data.clone(),
                    self.weight.data().clone(),
                    self.normalized_shape.clone(),
                    self.eps,
                ));
                Variable::from_operation(output, grad_fn, true)
            } else {
                Variable::from_tensor(output)
            };
        }

        // CPU fallback: normalize each row of `norm_size` elements independently.
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..num_rows {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &input_vec[start..end];

            let mean: f32 = slice.iter().sum::<f32>() / norm_size as f32;
            let var: f32 = slice.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / norm_size as f32;

            for i in 0..norm_size {
                let normalized = (slice[i] - mean) / (var + self.eps).sqrt();
                output_vec[start + i] = normalized * weight_vec[i] + bias_vec[i];
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LayerNormBackward::new(
                input.grad_fn().cloned(),
                self.weight.variable().grad_fn().cloned(),
                self.bias.variable().grad_fn().cloned(),
                input_data.clone(),
                self.weight.data().clone(),
                self.normalized_shape.clone(),
                self.eps,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn name(&self) -> &'static str {
        "LayerNorm"
    }
}

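/// Group normalization over an input of shape `(N, C, *)`.
///
/// Channels are split into `num_groups` groups and each group is normalized per
/// sample, using the mean and variance over that group's channels and spatial
/// positions. When `affine` is set, a learnable per-channel scale and shift are
/// applied after normalization.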
pub struct GroupNorm {
    /// Learnable per-channel scale (gamma), used when `affine` is set.
    pub weight: Parameter,
    /// Learnable per-channel shift (beta), used when `affine` is set.
    pub bias: Parameter,
    num_groups: usize,
    num_channels: usize,
    eps: f32,
    affine: bool,
}

impl GroupNorm {
    /// Creates a `GroupNorm` with default options (`eps = 1e-5`, affine parameters enabled).
    pub fn new(num_groups: usize, num_channels: usize) -> Self {
        Self::with_options(num_groups, num_channels, 1e-5, true)
    }

    /// Creates a `GroupNorm` with explicit `eps` and `affine` settings.
    ///
    /// Panics if `num_channels` is not divisible by `num_groups`.
    pub fn with_options(num_groups: usize, num_channels: usize, eps: f32, affine: bool) -> Self {
        assert!(
            num_channels % num_groups == 0,
            "num_channels ({}) must be divisible by num_groups ({})",
            num_channels,
            num_groups
        );

        Self {
            weight: Parameter::named("weight", ones(&[num_channels]), affine),
            bias: Parameter::named("bias", zeros(&[num_channels]), affine),
            num_groups,
            num_channels,
            eps,
            affine,
        }
    }
}

impl Module for GroupNorm {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        assert_eq!(
            channels, self.num_channels,
            "GroupNorm: expected {} channels, got {}",
            self.num_channels, channels
        );

        let input_vec = input_data.to_vec();
        let channels_per_group = channels / self.num_groups;

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..batch_size {
            for g in 0..self.num_groups {
                // Mean over all channels and spatial positions in this group.
                let mut sum = 0.0f32;
                let group_size = channels_per_group * spatial_size;

                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                let mean = sum / group_size as f32;

                let mut var_sum = 0.0f32;
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let diff = input_vec[idx] - mean;
                        var_sum += diff * diff;
                    }
                }
                let var = var_sum / group_size as f32;

                // Normalize, then apply the optional per-channel affine transform.
                let std_inv = 1.0 / (var + self.eps).sqrt();
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    let weight = if self.affine {
                        self.weight.data().to_vec()[channel_idx]
                    } else {
                        1.0
                    };
                    let bias = if self.affine {
                        self.bias.data().to_vec()[channel_idx]
                    } else {
                        0.0
                    };

                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let normalized = (input_vec[idx] - mean) * std_inv;
                        output_vec[idx] = normalized * weight + bias;
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad && self.affine {
            let grad_fn = GradFn::new(GroupNormBackward::new(
                input.grad_fn().cloned(),
                self.weight.variable().grad_fn().cloned(),
                self.bias.variable().grad_fn().cloned(),
                input_data.clone(),
                self.weight.data().clone(),
                self.num_groups,
                self.eps,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        if self.affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        if self.affine {
            let mut params = HashMap::new();
            params.insert("weight".to_string(), self.weight.clone());
            params.insert("bias".to_string(), self.bias.clone());
            params
        } else {
            HashMap::new()
        }
    }

    fn name(&self) -> &'static str {
        "GroupNorm"
    }
}

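/// Instance normalization over a 4D `(N, C, H, W)` input.
///
/// Each channel of each sample is normalized independently over its spatial
/// positions, so no statistics are shared across the batch. Affine parameters are
/// disabled by default; use [`InstanceNorm2d::with_affine`] to enable them.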
pub struct InstanceNorm2d {
    /// Learnable per-channel scale (gamma), used when `affine` is set.
    pub weight: Parameter,
    /// Learnable per-channel shift (beta), used when `affine` is set.
    pub bias: Parameter,
    num_features: usize,
    eps: f32,
    affine: bool,
}

impl InstanceNorm2d {
    /// Creates an `InstanceNorm2d` with default options (`eps = 1e-5`, no affine parameters).
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, false)
    }

    /// Creates an `InstanceNorm2d` with learnable affine parameters enabled.
    pub fn with_affine(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, true)
    }

    /// Creates an `InstanceNorm2d` with explicit `eps` and `affine` settings.
    pub fn with_options(num_features: usize, eps: f32, affine: bool) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), affine),
            bias: Parameter::named("bias", zeros(&[num_features]), affine),
            num_features,
            eps,
            affine,
        }
    }
}

impl Module for InstanceNorm2d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();

        assert!(
            shape.len() == 4,
            "InstanceNorm2d expects 4D input (N, C, H, W)"
        );

        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        assert_eq!(
            channels, self.num_features,
            "InstanceNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let input_vec = input_data.to_vec();
        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..batch_size {
            for c in 0..channels {
                // Mean over the spatial positions of this (sample, channel) pair.
                let mut sum = 0.0f32;
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    sum += input_vec[idx];
                }
                let mean = sum / spatial_size as f32;

                let mut var_sum = 0.0f32;
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let diff = input_vec[idx] - mean;
                    var_sum += diff * diff;
                }
                let var = var_sum / spatial_size as f32;

                // Normalize, then apply the optional per-channel affine transform.
                let std_inv = 1.0 / (var + self.eps).sqrt();
                let weight = if self.affine {
                    self.weight.data().to_vec()[c]
                } else {
                    1.0
                };
                let bias = if self.affine {
                    self.bias.data().to_vec()[c]
                } else {
                    0.0
                };

                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - mean) * std_inv;
                    output_vec[idx] = normalized * weight + bias;
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(InstanceNorm2dBackward::new(
                input.grad_fn().cloned(),
                if self.affine {
                    self.weight.variable().grad_fn().cloned()
                } else {
                    None
                },
                if self.affine {
                    self.bias.variable().grad_fn().cloned()
                } else {
                    None
                },
                input_data.clone(),
                self.weight.data().clone(),
                self.eps,
                self.affine,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        if self.affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        if self.affine {
            let mut params = HashMap::new();
            params.insert("weight".to_string(), self.weight.clone());
            params.insert("bias".to_string(), self.bias.clone());
            params
        } else {
            HashMap::new()
        }
    }

    fn name(&self) -> &'static str {
        "InstanceNorm2d"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batchnorm1d() {
        let bn = BatchNorm1d::new(3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
                .expect("tensor creation failed"),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }

    #[test]
    fn test_batchnorm2d() {
        let bn = BatchNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).expect("tensor creation failed"),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_layernorm() {
        let ln = LayerNorm::single(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])
                .expect("tensor creation failed"),
            false,
        );
        let output = ln.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);
    }

    #[test]
    fn test_batchnorm_parameters() {
        let bn = BatchNorm1d::new(10);
        assert_eq!(bn.parameters().len(), 2);
        // weight + bias, 10 elements each
        assert_eq!(bn.num_parameters(), 20);
    }

    #[test]
    fn test_groupnorm() {
        let gn = GroupNorm::new(2, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 4, 2, 2]).expect("tensor creation failed"),
            false,
        );
        let output = gn.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 2, 2]);
    }

    #[test]
    fn test_groupnorm_normalization() {
        // No affine transform, so each group's output should have zero mean.
        let gn = GroupNorm::with_options(2, 4, 1e-5, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 4, 1, 2])
                .expect("tensor creation failed"),
            false,
        );
        let output = gn.forward(&input);
        let out_vec = output.data().to_vec();
        let group1_mean: f32 = out_vec[0..4].iter().sum::<f32>() / 4.0;
        let group2_mean: f32 = out_vec[4..8].iter().sum::<f32>() / 4.0;
        assert!(group1_mean.abs() < 1e-5);
        assert!(group2_mean.abs() < 1e-5);
    }

    #[test]
    fn test_instancenorm2d() {
        let inn = InstanceNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).expect("tensor creation failed"),
            false,
        );
        let output = inn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_instancenorm2d_with_affine() {
        let inn = InstanceNorm2d::with_affine(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 64], &[1, 4, 4, 4]).expect("tensor creation failed"),
            false,
        );
        let output = inn.forward(&input);
        assert_eq!(output.shape(), vec![1, 4, 4, 4]);
        assert_eq!(inn.parameters().len(), 2);
    }

    #[test]
    fn test_layernorm_zero_mean_unit_var() {
        let ln = LayerNorm::with_eps(vec![4], 1e-5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 5.0, 3.0, 7.0], &[1, 4]).unwrap(),
            false,
        );
        let output = ln.forward(&input);
        let out = output.data().to_vec();

        let mean: f32 = out.iter().sum::<f32>() / out.len() as f32;
        let var: f32 = out.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / out.len() as f32;

        assert!(
            mean.abs() < 1e-4,
            "LayerNorm output mean should be ~0, got {}",
            mean
        );
        assert!(
            (var - 1.0).abs() < 0.1,
            "LayerNorm output var should be ~1, got {}",
            var
        );
    }

    #[test]
    fn test_layernorm_gradient_flow() {
        use axonml_autograd::backward;

        let ln = LayerNorm::single(3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            true,
        );
        let output = ln.forward(&input);
        let loss = output.sum();

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input
            .grad()
            .expect("Should have gradient through LayerNorm");
        let gv = grad.to_vec();
        assert_eq!(gv.len(), 3);
        assert!(
            gv.iter().all(|g| g.is_finite()),
            "All gradients should be finite: {:?}",
            gv
        );
    }

    #[test]
    fn test_layernorm_batch_independence() {
        let ln = LayerNorm::with_eps(vec![3], 1e-5);

        let input1 = Variable::new(
            Tensor::from_vec(vec![10.0, 20.0, 30.0], &[1, 3]).unwrap(),
            false,
        );
        let out1 = ln.forward(&input1).data().to_vec();

        let input2 = Variable::new(
            Tensor::from_vec(vec![10.0, 20.0, 30.0, 1.0, 1.0, 1.0], &[2, 3]).unwrap(),
            false,
        );
        let out2 = ln.forward(&input2).data().to_vec();

        // The first row must normalize the same way regardless of what else is in the batch.
        for i in 0..3 {
            assert!(
                (out1[i] - out2[i]).abs() < 1e-5,
                "LayerNorm should be batch-independent: {} vs {}",
                out1[i],
                out2[i]
            );
        }
    }

    #[test]
    fn test_layernorm_parameters_count() {
        let ln = LayerNorm::single(64);
        assert_eq!(ln.parameters().len(), 2);
        // weight + bias, 64 elements each
        assert_eq!(ln.num_parameters(), 128);
    }

    #[test]
    fn test_batchnorm1d_normalization() {
        let bn = BatchNorm1d::with_options(2, 1e-5, 0.1, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 10.0, 3.0, 20.0, 5.0, 30.0], &[3, 2]).unwrap(),
            false,
        );
        let output = bn.forward(&input);
        let out = output.data().to_vec();

        // Each channel's outputs (strided by 2) should have roughly zero mean.
        let ch0_mean = (out[0] + out[2] + out[4]) / 3.0;
        let ch1_mean = (out[1] + out[3] + out[5]) / 3.0;
        assert!(
            ch0_mean.abs() < 0.1,
            "BatchNorm ch0 mean should be ~0, got {}",
            ch0_mean
        );
        assert!(
            ch1_mean.abs() < 0.1,
            "BatchNorm ch1 mean should be ~0, got {}",
            ch1_mean
        );
    }

    #[test]
    fn test_batchnorm1d_train_vs_eval() {
        let mut bn = BatchNorm1d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );

        // Training mode normalizes with batch statistics and updates the running stats.
        bn.train();
        let train_out = bn.forward(&input).data().to_vec();

        // Eval mode normalizes with the (partially updated) running statistics.
        bn.eval();
        let eval_out = bn.forward(&input).data().to_vec();

        let diff: f32 = train_out
            .iter()
            .zip(eval_out.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 0.0,
            "Train and eval outputs should differ once running stats have been updated"
        );
    }

    #[test]
    fn test_batchnorm2d_gradient_flow() {
        use axonml_autograd::backward;

        let bn = BatchNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap(),
            true,
        );
        let output = bn.forward(&input);
        let loss = output.sum();
        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input
            .grad()
            .expect("Should have gradient through BatchNorm2d");
        assert_eq!(grad.shape(), &[2, 2, 2, 4]);
        assert!(grad.to_vec().iter().all(|g| g.is_finite()));
    }

    #[test]
    fn test_groupnorm_gradient_flow() {
        use axonml_autograd::backward;

        let gn = GroupNorm::new(2, 4);
        let input = Variable::new(
            Tensor::from_vec(
                (0..32).map(|i| i as f32 * 0.1).collect::<Vec<_>>(),
                &[1, 4, 2, 4],
            )
            .unwrap(),
            true,
        );
        let output = gn.forward(&input);
        let loss = output.sum();
        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input
            .grad()
            .expect("Should have gradient through GroupNorm");
        assert_eq!(grad.shape(), &[1, 4, 2, 4]);
        assert!(grad.to_vec().iter().all(|g| g.is_finite()));
    }
}