1use std::sync::{Arc, Mutex};
64
65use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
66use ferrotorch_core::storage::TensorStorage;
67use ferrotorch_core::tensor::{GradFn, Tensor};
68use ferrotorch_core::{Float, is_grad_enabled};
69use ferrotorch_nn::{Module, Parameter};
70
71use crate::backend::Backend;
72use crate::collective::{ReduceOp, allreduce};
73
74#[non_exhaustive]
85pub struct SyncBatchNorm2d<T: Float> {
86 pub num_features: usize,
87 pub eps: f64,
88 pub momentum: f64,
89 pub affine: bool,
90 pub weight: Option<Parameter<T>>,
91 pub bias: Option<Parameter<T>>,
92 running_mean: Mutex<Vec<f64>>,
93 running_var: Mutex<Vec<f64>>,
94 num_batches_tracked: Mutex<usize>,
95 training: Mutex<bool>,
96 backend: Option<Arc<dyn Backend>>,
99}
100
101impl<T: Float> std::fmt::Debug for SyncBatchNorm2d<T> {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 f.debug_struct("SyncBatchNorm2d")
104 .field("num_features", &self.num_features)
105 .field("eps", &self.eps)
106 .field("momentum", &self.momentum)
107 .field("affine", &self.affine)
108 .field(
109 "world_size",
110 &self.backend.as_ref().map(|b| b.world_size()).unwrap_or(1),
111 )
112 .field("training", &self.training)
113 .finish()
114 }
115}
116
117impl<T: Float> SyncBatchNorm2d<T> {
118 pub fn new(
122 num_features: usize,
123 eps: f64,
124 momentum: f64,
125 affine: bool,
126 ) -> FerrotorchResult<Self> {
127 if num_features == 0 {
128 return Err(FerrotorchError::InvalidArgument {
129 message: "SyncBatchNorm2d: num_features must be positive".into(),
130 });
131 }
132 let weight = if affine {
133 Some(Parameter::ones(&[num_features])?)
134 } else {
135 None
136 };
137 let bias = if affine {
138 Some(Parameter::zeros(&[num_features])?)
139 } else {
140 None
141 };
142 Ok(Self {
143 num_features,
144 eps,
145 momentum,
146 affine,
147 weight,
148 bias,
149 running_mean: Mutex::new(vec![0.0; num_features]),
150 running_var: Mutex::new(vec![1.0; num_features]),
151 num_batches_tracked: Mutex::new(0),
152 training: Mutex::new(true),
153 backend: None,
154 })
155 }
156
157 pub fn with_backend(mut self, backend: Arc<dyn Backend>) -> Self {
160 self.backend = Some(backend);
161 self
162 }
163
164 pub fn running_mean(&self) -> Vec<f64> {
166 self.running_mean.lock().unwrap().clone()
167 }
168
169 pub fn running_var(&self) -> Vec<f64> {
171 self.running_var.lock().unwrap().clone()
172 }
173
174 pub fn num_batches_tracked(&self) -> usize {
176 *self.num_batches_tracked.lock().unwrap()
177 }
178}
179
180impl<T: Float> Module<T> for SyncBatchNorm2d<T> {
181 #[allow(clippy::manual_memcpy)]
184 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
185 let shape = input.shape().to_vec();
186 if shape.len() != 4 {
187 return Err(FerrotorchError::ShapeMismatch {
188 message: format!(
189 "SyncBatchNorm2d: expected 4D input [B, C, H, W], got {:?}",
190 shape
191 ),
192 });
193 }
194 let batch = shape[0];
195 let channels = shape[1];
196 let height = shape[2];
197 let width = shape[3];
198 let spatial = height * width;
199
200 if channels != self.num_features {
201 return Err(FerrotorchError::ShapeMismatch {
202 message: format!(
203 "SyncBatchNorm2d: expected {} channels, got {}",
204 self.num_features, channels
205 ),
206 });
207 }
208
209 if input.is_cuda() {
210 return Err(FerrotorchError::NotImplementedOnCuda {
211 op: "SyncBatchNorm2d::forward",
212 });
213 }
214
215 let input_data = input.data()?;
216 let eps_t = T::from(self.eps).unwrap();
217 let weight_data = self.weight.as_ref().map(|w| w.tensor().data().unwrap());
218 let bias_data = self.bias.as_ref().map(|b| b.tensor().data().unwrap());
219 let is_training = *self.training.lock().unwrap();
220
221 let mut chan_mean = vec![<T as num_traits::Zero>::zero(); channels];
222 let mut chan_var = vec![<T as num_traits::Zero>::zero(); channels];
223
224 if is_training {
225 let local_count = batch * spatial;
227 let mut sum = vec![<T as num_traits::Zero>::zero(); channels];
228 let mut sum_sq = vec![<T as num_traits::Zero>::zero(); channels];
229
230 for c in 0..channels {
231 for b in 0..batch {
232 let base = b * channels * spatial + c * spatial;
233 for s in 0..spatial {
234 let v = input_data[base + s];
235 sum[c] += v;
236 sum_sq[c] += v * v;
237 }
238 }
239 }
240
241 let global_count = if let Some(ref backend) = self.backend {
243 let world_size = backend.world_size();
244 if world_size > 1 {
245 let mut packed: Vec<T> = Vec::with_capacity(2 * channels);
248 packed.extend_from_slice(&sum);
249 packed.extend_from_slice(&sum_sq);
250 let packed_t = Tensor::from_storage(
251 TensorStorage::cpu(packed),
252 vec![2 * channels],
253 false,
254 )?;
255 let reduced = allreduce(&packed_t, backend.as_ref(), ReduceOp::Sum)?;
256 let reduced_data = reduced.data()?;
257 for c in 0..channels {
258 sum[c] = reduced_data[c];
259 sum_sq[c] = reduced_data[channels + c];
260 }
261 local_count * world_size
262 } else {
263 local_count
264 }
265 } else {
266 local_count
267 };
268
269 let global_count_t = T::from(global_count).unwrap();
270 for c in 0..channels {
271 let m = sum[c] / global_count_t;
272 chan_mean[c] = m;
273 chan_var[c] = sum_sq[c] / global_count_t - m * m;
275 }
276
277 {
279 let mut rm = self.running_mean.lock().unwrap();
280 let mut rv = self.running_var.lock().unwrap();
281 let mut nbt = self.num_batches_tracked.lock().unwrap();
282 *nbt += 1;
283 let mom = self.momentum;
284 let bessel = if global_count > 1 {
285 global_count as f64 / (global_count as f64 - 1.0)
286 } else {
287 1.0
288 };
289 for c in 0..channels {
290 let bm = chan_mean[c].to_f64().unwrap();
291 let bv = chan_var[c].to_f64().unwrap();
292 rm[c] = (1.0 - mom) * rm[c] + mom * bm;
293 rv[c] = (1.0 - mom) * rv[c] + mom * bv * bessel;
294 }
295 }
296 } else {
297 let rm = self.running_mean.lock().unwrap();
299 let rv = self.running_var.lock().unwrap();
300 for c in 0..channels {
301 chan_mean[c] = T::from(rm[c]).unwrap();
302 chan_var[c] = T::from(rv[c]).unwrap();
303 }
304 }
305
306 let mut output = vec![<T as num_traits::Zero>::zero(); input.numel()];
308 let mut x_hat_data = if is_grad_enabled() && input.requires_grad() {
309 Vec::with_capacity(input.numel())
310 } else {
311 Vec::new()
312 };
313 let need_x_hat = is_grad_enabled() && input.requires_grad();
314
315 let mut inv_std = vec![<T as num_traits::Zero>::zero(); channels];
316 for c in 0..channels {
317 inv_std[c] = (chan_var[c] + eps_t).sqrt().recip();
318 }
319
320 for b in 0..batch {
321 for c in 0..channels {
322 let base = b * channels * spatial + c * spatial;
323 for s in 0..spatial {
324 let idx = base + s;
325 let normed = (input_data[idx] - chan_mean[c]) * inv_std[c];
326 if need_x_hat {
327 x_hat_data.push(normed);
328 }
329 if self.affine {
330 let w = weight_data.as_ref().unwrap();
331 let bi = bias_data.as_ref().unwrap();
332 output[idx] = normed * w[c] + bi[c];
333 } else {
334 output[idx] = normed;
335 }
336 }
337 }
338 }
339
340 let result = Tensor::from_storage(TensorStorage::cpu(output), shape.clone(), false)?;
341
342 if need_x_hat {
343 let weight_tensor = self.weight.as_ref().map(|w| w.tensor().clone());
344 let bias_tensor = self.bias.as_ref().map(|b| b.tensor().clone());
345 let local_count = batch * spatial;
346 let global_count = self
347 .backend
348 .as_ref()
349 .map(|b| local_count * b.world_size())
350 .unwrap_or(local_count);
351 let grad_fn = Arc::new(SyncBatchNorm2dBackward {
352 input: input.clone(),
353 x_hat: Tensor::from_storage(TensorStorage::cpu(x_hat_data), shape.clone(), false)?,
354 weight: weight_tensor,
355 bias: bias_tensor,
356 chan_var: chan_var.iter().map(|v| v.to_f64().unwrap()).collect(),
357 eps: self.eps,
358 affine: self.affine,
359 global_count,
360 backend: self.backend.clone(),
361 });
362 Tensor::from_operation(
363 TensorStorage::cpu(result.data()?.to_vec()),
364 result.shape().to_vec(),
365 grad_fn,
366 )
367 } else {
368 Ok(result)
369 }
370 }
371
372 fn parameters(&self) -> Vec<&Parameter<T>> {
373 match (&self.weight, &self.bias) {
374 (Some(w), Some(b)) => vec![w, b],
375 _ => vec![],
376 }
377 }
378
379 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
380 match (&mut self.weight, &mut self.bias) {
381 (Some(w), Some(b)) => vec![w, b],
382 _ => vec![],
383 }
384 }
385
386 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
387 match (&self.weight, &self.bias) {
388 (Some(w), Some(b)) => vec![("weight".to_string(), w), ("bias".to_string(), b)],
389 _ => vec![],
390 }
391 }
392
393 fn train(&mut self) {
394 *self.training.lock().unwrap() = true;
395 }
396
397 fn eval(&mut self) {
398 *self.training.lock().unwrap() = false;
399 }
400
401 fn is_training(&self) -> bool {
402 *self.training.lock().unwrap()
403 }
404}
405
406struct SyncBatchNorm2dBackward<T: Float> {
412 input: Tensor<T>,
413 x_hat: Tensor<T>,
414 weight: Option<Tensor<T>>,
415 bias: Option<Tensor<T>>,
416 chan_var: Vec<f64>,
417 eps: f64,
418 affine: bool,
419 global_count: usize,
420 backend: Option<Arc<dyn Backend>>,
421}
422
423impl<T: Float> std::fmt::Debug for SyncBatchNorm2dBackward<T> {
424 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
425 f.debug_struct("SyncBatchNorm2dBackward")
426 .field("global_count", &self.global_count)
427 .finish()
428 }
429}
430
431impl<T: Float> GradFn<T> for SyncBatchNorm2dBackward<T> {
432 #[allow(clippy::manual_memcpy)]
433 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
434 let shape = self.input.shape();
435 let batch = shape[0];
436 let channels = shape[1];
437 let height = shape[2];
438 let width = shape[3];
439 let spatial = height * width;
440
441 if self.input.is_cuda() {
442 return Err(FerrotorchError::NotImplementedOnCuda {
443 op: "SyncBatchNorm2dBackward",
444 });
445 }
446
447 let go_data = grad_output.data()?;
448 let x_hat_data = self.x_hat.data()?;
449 let weight_data = self.weight.as_ref().map(|w| w.data().unwrap().to_vec());
450
451 let mut grad_input = vec![<T as num_traits::Zero>::zero(); self.input.numel()];
452 let mut grad_weight = vec![<T as num_traits::Zero>::zero(); channels];
453 let mut grad_bias = vec![<T as num_traits::Zero>::zero(); channels];
454
455 let mut local_dl_dx_hat_sum = vec![<T as num_traits::Zero>::zero(); channels];
457 let mut local_dl_dx_hat_x_hat_sum = vec![<T as num_traits::Zero>::zero(); channels];
458
459 for c in 0..channels {
460 for b in 0..batch {
461 let base = b * channels * spatial + c * spatial;
462 for s in 0..spatial {
463 let idx = base + s;
464 let x_h = x_hat_data[idx];
465 let go = go_data[idx];
466 let dl_dx_hat = if self.affine {
467 go * weight_data.as_ref().unwrap()[c]
468 } else {
469 go
470 };
471 local_dl_dx_hat_sum[c] += dl_dx_hat;
472 local_dl_dx_hat_x_hat_sum[c] += dl_dx_hat * x_h;
473 if self.affine {
474 grad_weight[c] += go * x_h;
475 grad_bias[c] += go;
476 }
477 }
478 }
479 }
480
481 let mut global_dl_dx_hat_sum = local_dl_dx_hat_sum.clone();
484 let mut global_dl_dx_hat_x_hat_sum = local_dl_dx_hat_x_hat_sum.clone();
485
486 if let Some(ref backend) = self.backend {
487 if backend.world_size() > 1 {
488 let mut packed: Vec<T> = Vec::with_capacity(2 * channels);
489 packed.extend_from_slice(&local_dl_dx_hat_sum);
490 packed.extend_from_slice(&local_dl_dx_hat_x_hat_sum);
491 let packed_t =
492 Tensor::from_storage(TensorStorage::cpu(packed), vec![2 * channels], false)?;
493 let reduced = allreduce(&packed_t, backend.as_ref(), ReduceOp::Sum)?;
494 let reduced_data = reduced.data()?;
495 for c in 0..channels {
496 global_dl_dx_hat_sum[c] = reduced_data[c];
497 global_dl_dx_hat_x_hat_sum[c] = reduced_data[channels + c];
498 }
499 }
500 }
501
502 let global_count_t = T::from(self.global_count).unwrap();
503
504 for c in 0..channels {
506 let var_f64 = self.chan_var[c];
507 let inv_std = T::from(1.0 / (var_f64 + self.eps).sqrt()).unwrap();
508
509 let dl_dx_hat_mean = global_dl_dx_hat_sum[c] / global_count_t;
510 let dl_dx_hat_x_hat_mean = global_dl_dx_hat_x_hat_sum[c] / global_count_t;
511
512 for b in 0..batch {
513 let base = b * channels * spatial + c * spatial;
514 for s in 0..spatial {
515 let idx = base + s;
516 let x_h = x_hat_data[idx];
517 let go = go_data[idx];
518 let dl_dx_hat = if self.affine {
519 go * weight_data.as_ref().unwrap()[c]
520 } else {
521 go
522 };
523 grad_input[idx] =
524 inv_std * (dl_dx_hat - dl_dx_hat_mean - x_h * dl_dx_hat_x_hat_mean);
525 }
526 }
527 }
528
529 let grad_input_tensor = Tensor::from_storage(
530 TensorStorage::cpu(grad_input),
531 self.input.shape().to_vec(),
532 false,
533 )?;
534 let grad_weight_out = if self.affine {
535 self.weight.as_ref().and_then(|w| {
536 if w.requires_grad() {
537 Some(
538 Tensor::from_storage(
539 TensorStorage::cpu(grad_weight),
540 vec![channels],
541 false,
542 )
543 .unwrap(),
544 )
545 } else {
546 None
547 }
548 })
549 } else {
550 None
551 };
552 let grad_bias_out = if self.affine {
553 self.bias.as_ref().and_then(|b| {
554 if b.requires_grad() {
555 Some(
556 Tensor::from_storage(TensorStorage::cpu(grad_bias), vec![channels], false)
557 .unwrap(),
558 )
559 } else {
560 None
561 }
562 })
563 } else {
564 None
565 };
566
567 Ok(vec![
568 Some(grad_input_tensor),
569 grad_weight_out,
570 grad_bias_out,
571 ])
572 }
573
574 fn inputs(&self) -> Vec<&Tensor<T>> {
575 let mut v: Vec<&Tensor<T>> = vec![&self.input];
576 if let Some(ref w) = self.weight {
577 v.push(w);
578 }
579 if let Some(ref b) = self.bias {
580 v.push(b);
581 }
582 v
583 }
584
585 fn name(&self) -> &'static str {
586 "SyncBatchNorm2dBackward"
587 }
588}
589
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;
    use ferrotorch_core::Tensor;
    use ferrotorch_nn::BatchNorm2d;
    use std::thread;

    /// What one simulated rank reports: (output, running mean, running variance).
    type RankReport = (Vec<f32>, Vec<f64>, Vec<f64>);

    /// Builds a CPU tensor without gradient tracking from a flat slice.
    fn cpu_tensor(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        let storage = TensorStorage::cpu(data.to_vec());
        Tensor::from_storage(storage, shape.to_vec(), false).unwrap()
    }

    /// Runs one training-mode forward pass of a 3-channel layer on its own
    /// thread, using `backend` for the cross-rank reduction.
    fn spawn_rank(
        shard: Tensor<f32>,
        backend: Arc<dyn Backend>,
    ) -> thread::JoinHandle<RankReport> {
        thread::spawn(move || {
            let mut sync = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, true)
                .unwrap()
                .with_backend(backend);
            sync.train();
            let out = sync.forward(&shard).unwrap();
            (
                out.data().unwrap().to_vec(),
                sync.running_mean(),
                sync.running_var(),
            )
        })
    }

    /// Without a backend the layer must reproduce plain `BatchNorm2d`.
    #[test]
    fn test_sync_bn_world_size_1_matches_batch_norm() {
        let values: Vec<f32> = (0..24).map(|i| i as f32 / 10.0).collect();
        let input = cpu_tensor(&values, &[2, 3, 2, 2]);

        let mut sync = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        sync.train();
        let mut plain = BatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        plain.train();

        let out_sync = sync.forward(&input).unwrap();
        let out_plain = plain.forward(&input).unwrap();

        let s = out_sync.data().unwrap();
        let p = out_plain.data().unwrap();
        for (i, (a, b)) in s.iter().zip(p.iter()).enumerate() {
            assert!((a - b).abs() < 1e-5, "out[{i}]: sync={a}, plain={b}");
        }
    }

    /// Two ranks that each see half of a batch must match a single
    /// `BatchNorm2d` that sees the whole batch.
    #[test]
    fn test_sync_bn_two_ranks_match_full_batch() {
        let full_data: Vec<f32> = (0..48).map(|i| (i as f32 - 24.0) / 10.0).collect();

        // Reference: ordinary batch norm over all four samples.
        let mut plain = BatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        plain.train();
        let plain_out = plain
            .forward(&cpu_tensor(&full_data, &[4, 3, 2, 2]))
            .unwrap();
        let plain_data = plain_out.data().unwrap().to_vec();
        let plain_running_mean = plain.running_mean();
        let plain_running_var = plain.running_var();

        // Each rank receives two consecutive samples.
        let r0_t = cpu_tensor(&full_data[..24], &[2, 3, 2, 2]);
        let r1_t = cpu_tensor(&full_data[24..], &[2, 3, 2, 2]);

        let mut ranks = SimulatedBackend::create_group(2).unwrap().into_iter();
        // The concrete handles stay alive until the end of the test.
        let b0 = Arc::new(ranks.next().unwrap());
        let b1 = Arc::new(ranks.next().unwrap());
        let b0_dyn: Arc<dyn Backend> = b0.clone();
        let b1_dyn: Arc<dyn Backend> = b1.clone();

        let h0 = spawn_rank(r0_t, b0_dyn);
        let h1 = spawn_rank(r1_t, b1_dyn);
        let (out0, rm0, rv0) = h0.join().unwrap();
        let (out1, rm1, rv1) = h1.join().unwrap();

        let stitched = out0.iter().chain(out1.iter());
        for (i, (a, b)) in stitched.zip(plain_data.iter()).enumerate() {
            assert!((a - b).abs() < 1e-4, "out[{i}]: sync={a}, plain={b}");
        }

        for c in 0..3 {
            assert!(
                (rm0[c] - rm1[c]).abs() < 1e-6,
                "rank0 and rank1 running_mean disagree at c={c}"
            );
            assert!(
                (rm0[c] - plain_running_mean[c]).abs() < 1e-4,
                "running_mean[{c}] sync={} plain={}",
                rm0[c],
                plain_running_mean[c]
            );
            assert!(
                (rv0[c] - rv1[c]).abs() < 1e-6,
                "rank0 and rank1 running_var disagree at c={c}"
            );
            assert!(
                (rv0[c] - plain_running_var[c]).abs() < 1e-4,
                "running_var[{c}] sync={} plain={}",
                rv0[c],
                plain_running_var[c]
            );
        }
    }

    /// After a few training steps, eval mode must yield finite outputs even
    /// for inputs far from the training data.
    #[test]
    fn test_sync_bn_eval_mode_uses_running_stats() {
        let ramp: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let input = cpu_tensor(&ramp, &[1, 3, 2, 2]);

        let mut sync = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        sync.train();
        for _ in 0..3 {
            let _ = sync.forward(&input).unwrap();
        }

        sync.eval();
        let other = cpu_tensor(&[100.0_f32; 12], &[1, 3, 2, 2]);
        let out = sync.forward(&other).unwrap();
        for v in out.data().unwrap() {
            assert!(v.is_finite(), "output should be finite, got {v}");
        }
    }

    #[test]
    fn test_sync_bn_constructor_validates_num_features() {
        let result = SyncBatchNorm2d::<f32>::new(0, 1e-5, 0.1, true);
        assert!(result.is_err());
    }

    #[test]
    fn test_sync_bn_rejects_wrong_input_shape() {
        let sync = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        let one_dim = cpu_tensor(&[1.0, 2.0, 3.0], &[3]);
        assert!(sync.forward(&one_dim).is_err());
    }

    #[test]
    fn test_sync_bn_rejects_wrong_channel_count() {
        let sync = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        let four_channels = cpu_tensor(&[0.0; 16], &[1, 4, 2, 2]);
        assert!(sync.forward(&four_channels).is_err());
    }
}