Skip to main content

ferrotorch_distributed/
sync_batch_norm.rs

1//! Synchronized Batch Normalization (SyncBatchNorm).
2//!
3//! Like [`ferrotorch_nn::BatchNorm2d`] but synchronizes per-channel mean and
4//! variance across all ranks of a distributed process group, so the
5//! normalization sees the *global* batch statistics instead of just the
6//! per-rank local mini-batch.
7//!
8//! # When to use
9//!
10//! SyncBatchNorm is the right normalization choice when:
11//!
12//! - The per-rank batch is small enough that local statistics are noisy
13//!   (e.g. detection/segmentation training where per-GPU batch size is
14//!   1–4 images).
//! - You want normalization that matches the no-DDP single-GPU case (the
//!   same statistics up to floating-point summation order, not bit-for-bit).
16//!
17//! For large per-rank batches (≥ 32 images on each rank) the synchronization
18//! overhead usually outweighs the statistical benefit and a plain
19//! `BatchNorm2d` is preferable.
20//!
21//! # Synchronization
22//!
23//! Forward pass:
24//!   1. Each rank computes its local per-channel `sum` and `sum_sq`.
25//!   2. Both vectors are concatenated and `allreduce`d (sum) across ranks.
26//!   3. The global mean and variance are computed by dividing by the
27//!      *global* element count `N_global = local_count * world_size`.
28//!   4. Normalization uses the global statistics.
29//!
30//! Backward pass:
31//!   1. Each rank computes its local per-channel `sum_dl_dx_hat` and
32//!      `sum_dl_dx_hat_x_hat`.
33//!   2. Both vectors are `allreduce`d (sum) across ranks.
34//!   3. `grad_input` uses the global means in the standard BatchNorm
35//!      VJP formula, ensuring the gradient is consistent with the
36//!      synchronized forward.
//!   4. `grad_weight` and `grad_bias` are accumulated locally per rank;
//!      DDP's gradient hook reduces them across ranks at the next
//!      synchronization point, like any other parameter gradient. (We do NOT
//!      pre-reduce gamma/beta gradients here, matching PyTorch SyncBatchNorm
//!      semantics.)
41//!
42//! Mirrors `torch.nn.SyncBatchNorm`. CL-392.
43//!
44//! ## REQ status (per `.design/ferrotorch-distributed/sync_batch_norm.md`)
45//!
46//! Full evidence rows (impl + non-test production consumer + upstream
47//! cites) live in the design doc; this synopsis is a one-line summary per
48//! REQ.
49//!
50//! | REQ | Status | Evidence |
51//! |---|---|---|
52//! | REQ-1 (`SyncBatchNorm2d<T>` struct) | SHIPPED | `pub struct SyncBatchNorm2d` (with `#[non_exhaustive]`) in `sync_batch_norm.rs` mirrors `class SyncBatchNorm(_BatchNorm)` in `torch/nn/modules/batchnorm.py`; consumer `pub use sync_batch_norm::SyncBatchNorm2d` in `lib.rs` |
53//! | REQ-2 (`new` constructor) | SHIPPED | `pub fn new` in `sync_batch_norm.rs` mirrors `SyncBatchNorm(num_features, eps, momentum, affine, ...)` in `torch/nn/modules/batchnorm.py`; consumer `lib.rs` re-export — `new` is the constructor |
54//! | REQ-3 (`with_backend` builder) | SHIPPED | `pub fn with_backend` in `sync_batch_norm.rs`; consumer `lib.rs` re-export of `SyncBatchNorm2d` |
55//! | REQ-4 (running-stat accessors) | SHIPPED | `pub fn running_mean` / `running_var` / `num_batches_tracked` in `sync_batch_norm.rs`; consumer via `lib.rs` re-export — the user-facing read path for checkpointing |
56//! | REQ-5 (`impl Module<T>`) | SHIPPED | `impl Module<T> for SyncBatchNorm2d<T>` in `sync_batch_norm.rs`; consumer via `lib.rs` re-export — the trait impl makes `SyncBatchNorm2d` drop-in wherever `impl Module<T>` is accepted |
57//! | REQ-6 (forward training with packed allreduce) | SHIPPED | training arm of `forward` in `sync_batch_norm.rs` mirrors `batch_norm_stats` in `torch/nn/modules/_functions.py`; consumer of `crate::collective::allreduce`; surfaced via `lib.rs` re-export |
58//! | REQ-7 (forward eval using running stats) | SHIPPED | eval arm of `forward` in `sync_batch_norm.rs`; consumer via `lib.rs` re-export of `SyncBatchNorm2d` |
59//! | REQ-8 (normalize + autograd hookup) | SHIPPED | normalize / affine / `Tensor::from_operation` arms of `forward` in `sync_batch_norm.rs`; consumer registers `SyncBatchNorm2dBackward` on the autograd graph; surfaced via `lib.rs` re-export |
60//! | REQ-9 (`SyncBatchNorm2dBackward` `GradFn`) | SHIPPED | `struct SyncBatchNorm2dBackward<T>` + `impl GradFn<T>` in `sync_batch_norm.rs` mirror `class SyncBatchNorm(Function)` in `torch/nn/modules/_functions.py`; consumer `forward` (same file) registers it; `ferrotorch_core` autograd engine is the runtime consumer |
61//! | REQ-10 (shape / channel / CUDA validation) | SHIPPED | top-of-`forward` and top-of-`SyncBatchNorm2dBackward::backward` validation arms in `sync_batch_norm.rs`; consumer via `lib.rs` re-export of `SyncBatchNorm2d` |
62
63use std::sync::{Arc, Mutex};
64
65use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
66use ferrotorch_core::storage::TensorStorage;
67use ferrotorch_core::tensor::{GradFn, Tensor};
68use ferrotorch_core::{Float, is_grad_enabled};
69use ferrotorch_nn::{Module, Parameter};
70
71use crate::backend::Backend;
72use crate::collective::{ReduceOp, allreduce};
73
74/// 2-D synchronized batch normalization.
75///
76/// Same API surface as `BatchNorm2d` plus an `Arc<dyn Backend>` for
77/// cross-rank communication. When `world_size == 1` (or no backend is
78/// provided), behaves exactly like a plain BatchNorm2d.
79///
80/// Marked `#[non_exhaustive]` so future configuration knobs (e.g.,
81/// running-stats sync mode, momentum schedule, fused fwd/bwd flags) can
82/// be added without breaking external struct-literal construction.
83/// Construct via [`SyncBatchNorm2d::new`] / [`SyncBatchNorm2d::with_backend`].
84#[non_exhaustive]
85pub struct SyncBatchNorm2d<T: Float> {
86    pub num_features: usize,
87    pub eps: f64,
88    pub momentum: f64,
89    pub affine: bool,
90    pub weight: Option<Parameter<T>>,
91    pub bias: Option<Parameter<T>>,
92    running_mean: Mutex<Vec<f64>>,
93    running_var: Mutex<Vec<f64>>,
94    num_batches_tracked: Mutex<usize>,
95    training: Mutex<bool>,
96    /// Optional process group backend used to synchronize statistics.
97    /// When `None`, behaves like a non-distributed BatchNorm2d.
98    backend: Option<Arc<dyn Backend>>,
99}
100
101impl<T: Float> std::fmt::Debug for SyncBatchNorm2d<T> {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        f.debug_struct("SyncBatchNorm2d")
104            .field("num_features", &self.num_features)
105            .field("eps", &self.eps)
106            .field("momentum", &self.momentum)
107            .field("affine", &self.affine)
108            .field(
109                "world_size",
110                &self.backend.as_ref().map(|b| b.world_size()).unwrap_or(1),
111            )
112            .field("training", &self.training)
113            .finish()
114    }
115}
116
117impl<T: Float> SyncBatchNorm2d<T> {
118    /// Create a new SyncBatchNorm2d with no backend (acts as a plain
119    /// BatchNorm2d). Use [`with_backend`](Self::with_backend) to attach
120    /// the process group.
121    pub fn new(
122        num_features: usize,
123        eps: f64,
124        momentum: f64,
125        affine: bool,
126    ) -> FerrotorchResult<Self> {
127        if num_features == 0 {
128            return Err(FerrotorchError::InvalidArgument {
129                message: "SyncBatchNorm2d: num_features must be positive".into(),
130            });
131        }
132        let weight = if affine {
133            Some(Parameter::ones(&[num_features])?)
134        } else {
135            None
136        };
137        let bias = if affine {
138            Some(Parameter::zeros(&[num_features])?)
139        } else {
140            None
141        };
142        Ok(Self {
143            num_features,
144            eps,
145            momentum,
146            affine,
147            weight,
148            bias,
149            running_mean: Mutex::new(vec![0.0; num_features]),
150            running_var: Mutex::new(vec![1.0; num_features]),
151            num_batches_tracked: Mutex::new(0),
152            training: Mutex::new(true),
153            backend: None,
154        })
155    }
156
157    /// Attach a backend so the layer synchronizes its statistics across
158    /// the distributed process group.
159    pub fn with_backend(mut self, backend: Arc<dyn Backend>) -> Self {
160        self.backend = Some(backend);
161        self
162    }
163
164    /// Snapshot of the current per-channel running mean.
165    pub fn running_mean(&self) -> Vec<f64> {
166        self.running_mean.lock().unwrap().clone()
167    }
168
169    /// Snapshot of the current per-channel running variance.
170    pub fn running_var(&self) -> Vec<f64> {
171        self.running_var.lock().unwrap().clone()
172    }
173
174    /// Number of training batches processed so far.
175    pub fn num_batches_tracked(&self) -> usize {
176        *self.num_batches_tracked.lock().unwrap()
177    }
178}
179
180impl<T: Float> Module<T> for SyncBatchNorm2d<T> {
181    // De-interleaving sum/sum_sq from a packed reduce buffer isn't expressible
182    // as a single slice memcpy.
183    #[allow(clippy::manual_memcpy)]
184    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
185        let shape = input.shape().to_vec();
186        if shape.len() != 4 {
187            return Err(FerrotorchError::ShapeMismatch {
188                message: format!(
189                    "SyncBatchNorm2d: expected 4D input [B, C, H, W], got {:?}",
190                    shape
191                ),
192            });
193        }
194        let batch = shape[0];
195        let channels = shape[1];
196        let height = shape[2];
197        let width = shape[3];
198        let spatial = height * width;
199
200        if channels != self.num_features {
201            return Err(FerrotorchError::ShapeMismatch {
202                message: format!(
203                    "SyncBatchNorm2d: expected {} channels, got {}",
204                    self.num_features, channels
205                ),
206            });
207        }
208
209        if input.is_cuda() {
210            return Err(FerrotorchError::NotImplementedOnCuda {
211                op: "SyncBatchNorm2d::forward",
212            });
213        }
214
215        let input_data = input.data()?;
216        let eps_t = T::from(self.eps).unwrap();
217        let weight_data = self.weight.as_ref().map(|w| w.tensor().data().unwrap());
218        let bias_data = self.bias.as_ref().map(|b| b.tensor().data().unwrap());
219        let is_training = *self.training.lock().unwrap();
220
221        let mut chan_mean = vec![<T as num_traits::Zero>::zero(); channels];
222        let mut chan_var = vec![<T as num_traits::Zero>::zero(); channels];
223
224        if is_training {
225            // Local per-channel sum and sum of squares.
226            let local_count = batch * spatial;
227            let mut sum = vec![<T as num_traits::Zero>::zero(); channels];
228            let mut sum_sq = vec![<T as num_traits::Zero>::zero(); channels];
229
230            for c in 0..channels {
231                for b in 0..batch {
232                    let base = b * channels * spatial + c * spatial;
233                    for s in 0..spatial {
234                        let v = input_data[base + s];
235                        sum[c] += v;
236                        sum_sq[c] += v * v;
237                    }
238                }
239            }
240
241            // Synchronize across ranks if a backend is attached.
242            let global_count = if let Some(ref backend) = self.backend {
243                let world_size = backend.world_size();
244                if world_size > 1 {
245                    // Pack sum and sum_sq into a single tensor for one
246                    // allreduce instead of two.
247                    let mut packed: Vec<T> = Vec::with_capacity(2 * channels);
248                    packed.extend_from_slice(&sum);
249                    packed.extend_from_slice(&sum_sq);
250                    let packed_t = Tensor::from_storage(
251                        TensorStorage::cpu(packed),
252                        vec![2 * channels],
253                        false,
254                    )?;
255                    let reduced = allreduce(&packed_t, backend.as_ref(), ReduceOp::Sum)?;
256                    let reduced_data = reduced.data()?;
257                    for c in 0..channels {
258                        sum[c] = reduced_data[c];
259                        sum_sq[c] = reduced_data[channels + c];
260                    }
261                    local_count * world_size
262                } else {
263                    local_count
264                }
265            } else {
266                local_count
267            };
268
269            let global_count_t = T::from(global_count).unwrap();
270            for c in 0..channels {
271                let m = sum[c] / global_count_t;
272                chan_mean[c] = m;
273                // E[X^2] - E[X]^2 (biased variance, matches PyTorch)
274                chan_var[c] = sum_sq[c] / global_count_t - m * m;
275            }
276
277            // Update running statistics with the synchronized batch stats.
278            {
279                let mut rm = self.running_mean.lock().unwrap();
280                let mut rv = self.running_var.lock().unwrap();
281                let mut nbt = self.num_batches_tracked.lock().unwrap();
282                *nbt += 1;
283                let mom = self.momentum;
284                let bessel = if global_count > 1 {
285                    global_count as f64 / (global_count as f64 - 1.0)
286                } else {
287                    1.0
288                };
289                for c in 0..channels {
290                    let bm = chan_mean[c].to_f64().unwrap();
291                    let bv = chan_var[c].to_f64().unwrap();
292                    rm[c] = (1.0 - mom) * rm[c] + mom * bm;
293                    rv[c] = (1.0 - mom) * rv[c] + mom * bv * bessel;
294                }
295            }
296        } else {
297            // Eval mode: use the running statistics regardless of backend.
298            let rm = self.running_mean.lock().unwrap();
299            let rv = self.running_var.lock().unwrap();
300            for c in 0..channels {
301                chan_mean[c] = T::from(rm[c]).unwrap();
302                chan_var[c] = T::from(rv[c]).unwrap();
303            }
304        }
305
306        // Normalize and optionally scale/shift.
307        let mut output = vec![<T as num_traits::Zero>::zero(); input.numel()];
308        let mut x_hat_data = if is_grad_enabled() && input.requires_grad() {
309            Vec::with_capacity(input.numel())
310        } else {
311            Vec::new()
312        };
313        let need_x_hat = is_grad_enabled() && input.requires_grad();
314
315        let mut inv_std = vec![<T as num_traits::Zero>::zero(); channels];
316        for c in 0..channels {
317            inv_std[c] = (chan_var[c] + eps_t).sqrt().recip();
318        }
319
320        for b in 0..batch {
321            for c in 0..channels {
322                let base = b * channels * spatial + c * spatial;
323                for s in 0..spatial {
324                    let idx = base + s;
325                    let normed = (input_data[idx] - chan_mean[c]) * inv_std[c];
326                    if need_x_hat {
327                        x_hat_data.push(normed);
328                    }
329                    if self.affine {
330                        let w = weight_data.as_ref().unwrap();
331                        let bi = bias_data.as_ref().unwrap();
332                        output[idx] = normed * w[c] + bi[c];
333                    } else {
334                        output[idx] = normed;
335                    }
336                }
337            }
338        }
339
340        let result = Tensor::from_storage(TensorStorage::cpu(output), shape.clone(), false)?;
341
342        if need_x_hat {
343            let weight_tensor = self.weight.as_ref().map(|w| w.tensor().clone());
344            let bias_tensor = self.bias.as_ref().map(|b| b.tensor().clone());
345            let local_count = batch * spatial;
346            let global_count = self
347                .backend
348                .as_ref()
349                .map(|b| local_count * b.world_size())
350                .unwrap_or(local_count);
351            let grad_fn = Arc::new(SyncBatchNorm2dBackward {
352                input: input.clone(),
353                x_hat: Tensor::from_storage(TensorStorage::cpu(x_hat_data), shape.clone(), false)?,
354                weight: weight_tensor,
355                bias: bias_tensor,
356                chan_var: chan_var.iter().map(|v| v.to_f64().unwrap()).collect(),
357                eps: self.eps,
358                affine: self.affine,
359                global_count,
360                backend: self.backend.clone(),
361            });
362            Tensor::from_operation(
363                TensorStorage::cpu(result.data()?.to_vec()),
364                result.shape().to_vec(),
365                grad_fn,
366            )
367        } else {
368            Ok(result)
369        }
370    }
371
372    fn parameters(&self) -> Vec<&Parameter<T>> {
373        match (&self.weight, &self.bias) {
374            (Some(w), Some(b)) => vec![w, b],
375            _ => vec![],
376        }
377    }
378
379    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
380        match (&mut self.weight, &mut self.bias) {
381            (Some(w), Some(b)) => vec![w, b],
382            _ => vec![],
383        }
384    }
385
386    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
387        match (&self.weight, &self.bias) {
388            (Some(w), Some(b)) => vec![("weight".to_string(), w), ("bias".to_string(), b)],
389            _ => vec![],
390        }
391    }
392
393    fn train(&mut self) {
394        *self.training.lock().unwrap() = true;
395    }
396
397    fn eval(&mut self) {
398        *self.training.lock().unwrap() = false;
399    }
400
401    fn is_training(&self) -> bool {
402        *self.training.lock().unwrap()
403    }
404}
405
406/// Backward node for [`SyncBatchNorm2d`]. Synchronizes the two intermediate
407/// sums (`sum_dl_dx_hat` and `sum_dl_dx_hat_x_hat`) across ranks via
408/// allreduce so that `grad_input` is consistent with the synchronized
409/// forward. `grad_weight` and `grad_bias` are kept local — DDP will reduce
410/// them at the parameter sync step.
411struct SyncBatchNorm2dBackward<T: Float> {
412    input: Tensor<T>,
413    x_hat: Tensor<T>,
414    weight: Option<Tensor<T>>,
415    bias: Option<Tensor<T>>,
416    chan_var: Vec<f64>,
417    eps: f64,
418    affine: bool,
419    global_count: usize,
420    backend: Option<Arc<dyn Backend>>,
421}
422
423impl<T: Float> std::fmt::Debug for SyncBatchNorm2dBackward<T> {
424    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
425        f.debug_struct("SyncBatchNorm2dBackward")
426            .field("global_count", &self.global_count)
427            .finish()
428    }
429}
430
431impl<T: Float> GradFn<T> for SyncBatchNorm2dBackward<T> {
432    #[allow(clippy::manual_memcpy)]
433    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
434        let shape = self.input.shape();
435        let batch = shape[0];
436        let channels = shape[1];
437        let height = shape[2];
438        let width = shape[3];
439        let spatial = height * width;
440
441        if self.input.is_cuda() {
442            return Err(FerrotorchError::NotImplementedOnCuda {
443                op: "SyncBatchNorm2dBackward",
444            });
445        }
446
447        let go_data = grad_output.data()?;
448        let x_hat_data = self.x_hat.data()?;
449        let weight_data = self.weight.as_ref().map(|w| w.data().unwrap().to_vec());
450
451        let mut grad_input = vec![<T as num_traits::Zero>::zero(); self.input.numel()];
452        let mut grad_weight = vec![<T as num_traits::Zero>::zero(); channels];
453        let mut grad_bias = vec![<T as num_traits::Zero>::zero(); channels];
454
455        // First pass: per-channel local sums.
456        let mut local_dl_dx_hat_sum = vec![<T as num_traits::Zero>::zero(); channels];
457        let mut local_dl_dx_hat_x_hat_sum = vec![<T as num_traits::Zero>::zero(); channels];
458
459        for c in 0..channels {
460            for b in 0..batch {
461                let base = b * channels * spatial + c * spatial;
462                for s in 0..spatial {
463                    let idx = base + s;
464                    let x_h = x_hat_data[idx];
465                    let go = go_data[idx];
466                    let dl_dx_hat = if self.affine {
467                        go * weight_data.as_ref().unwrap()[c]
468                    } else {
469                        go
470                    };
471                    local_dl_dx_hat_sum[c] += dl_dx_hat;
472                    local_dl_dx_hat_x_hat_sum[c] += dl_dx_hat * x_h;
473                    if self.affine {
474                        grad_weight[c] += go * x_h;
475                        grad_bias[c] += go;
476                    }
477                }
478            }
479        }
480
481        // Synchronize the two sum vectors across ranks. Pack into a single
482        // tensor of shape [2*C] for one allreduce.
483        let mut global_dl_dx_hat_sum = local_dl_dx_hat_sum.clone();
484        let mut global_dl_dx_hat_x_hat_sum = local_dl_dx_hat_x_hat_sum.clone();
485
486        if let Some(ref backend) = self.backend {
487            if backend.world_size() > 1 {
488                let mut packed: Vec<T> = Vec::with_capacity(2 * channels);
489                packed.extend_from_slice(&local_dl_dx_hat_sum);
490                packed.extend_from_slice(&local_dl_dx_hat_x_hat_sum);
491                let packed_t =
492                    Tensor::from_storage(TensorStorage::cpu(packed), vec![2 * channels], false)?;
493                let reduced = allreduce(&packed_t, backend.as_ref(), ReduceOp::Sum)?;
494                let reduced_data = reduced.data()?;
495                for c in 0..channels {
496                    global_dl_dx_hat_sum[c] = reduced_data[c];
497                    global_dl_dx_hat_x_hat_sum[c] = reduced_data[channels + c];
498                }
499            }
500        }
501
502        let global_count_t = T::from(self.global_count).unwrap();
503
504        // Second pass: compute grad_input using the synchronized means.
505        for c in 0..channels {
506            let var_f64 = self.chan_var[c];
507            let inv_std = T::from(1.0 / (var_f64 + self.eps).sqrt()).unwrap();
508
509            let dl_dx_hat_mean = global_dl_dx_hat_sum[c] / global_count_t;
510            let dl_dx_hat_x_hat_mean = global_dl_dx_hat_x_hat_sum[c] / global_count_t;
511
512            for b in 0..batch {
513                let base = b * channels * spatial + c * spatial;
514                for s in 0..spatial {
515                    let idx = base + s;
516                    let x_h = x_hat_data[idx];
517                    let go = go_data[idx];
518                    let dl_dx_hat = if self.affine {
519                        go * weight_data.as_ref().unwrap()[c]
520                    } else {
521                        go
522                    };
523                    grad_input[idx] =
524                        inv_std * (dl_dx_hat - dl_dx_hat_mean - x_h * dl_dx_hat_x_hat_mean);
525                }
526            }
527        }
528
529        let grad_input_tensor = Tensor::from_storage(
530            TensorStorage::cpu(grad_input),
531            self.input.shape().to_vec(),
532            false,
533        )?;
534        let grad_weight_out = if self.affine {
535            self.weight.as_ref().and_then(|w| {
536                if w.requires_grad() {
537                    Some(
538                        Tensor::from_storage(
539                            TensorStorage::cpu(grad_weight),
540                            vec![channels],
541                            false,
542                        )
543                        .unwrap(),
544                    )
545                } else {
546                    None
547                }
548            })
549        } else {
550            None
551        };
552        let grad_bias_out = if self.affine {
553            self.bias.as_ref().and_then(|b| {
554                if b.requires_grad() {
555                    Some(
556                        Tensor::from_storage(TensorStorage::cpu(grad_bias), vec![channels], false)
557                            .unwrap(),
558                    )
559                } else {
560                    None
561                }
562            })
563        } else {
564            None
565        };
566
567        Ok(vec![
568            Some(grad_input_tensor),
569            grad_weight_out,
570            grad_bias_out,
571        ])
572    }
573
574    fn inputs(&self) -> Vec<&Tensor<T>> {
575        let mut v: Vec<&Tensor<T>> = vec![&self.input];
576        if let Some(ref w) = self.weight {
577            v.push(w);
578        }
579        if let Some(ref b) = self.bias {
580            v.push(b);
581        }
582        v
583    }
584
585    fn name(&self) -> &'static str {
586        "SyncBatchNorm2dBackward"
587    }
588}
589
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;
    use ferrotorch_core::Tensor;
    use ferrotorch_nn::BatchNorm2d;
    use std::thread;

    fn cpu_tensor(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
    }

    #[test]
    fn test_sync_bn_world_size_1_matches_batch_norm() {
        // With no backend (or world_size=1), SyncBatchNorm2d should produce
        // the exact same output as a plain BatchNorm2d on the same input.
        let input_data: Vec<f32> = (0..24).map(|i| i as f32 / 10.0).collect();
        let input = cpu_tensor(&input_data, &[2, 3, 2, 2]);

        let mut sync = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        let mut plain = BatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        sync.train();
        plain.train();

        let out_sync = sync.forward(&input).unwrap();
        let out_plain = plain.forward(&input).unwrap();

        let s = out_sync.data().unwrap();
        let p = out_plain.data().unwrap();
        for (i, (a, b)) in s.iter().zip(p.iter()).enumerate() {
            assert!((a - b).abs() < 1e-5, "out[{i}]: sync={a}, plain={b}");
        }
    }

    #[test]
    fn test_sync_bn_two_ranks_match_full_batch() {
        // Set up a 4-element batch of [4, 3, 2, 2] and split it across
        // two simulated ranks. SyncBatchNorm2d should compute the same
        // mean/var as a single-rank BatchNorm2d on the full batch.
        let full_data: Vec<f32> = (0..48).map(|i| (i as f32 - 24.0) / 10.0).collect();
        let full = cpu_tensor(&full_data, &[4, 3, 2, 2]);

        let mut plain = BatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        plain.train();
        let plain_out = plain.forward(&full).unwrap();
        let plain_data = plain_out.data().unwrap().to_vec();
        let plain_running_mean = plain.running_mean();
        let plain_running_var = plain.running_var();

        // Per-rank slices: rank 0 sees first 2 batch elements (indices
        // 0..24), rank 1 sees the last 2 (indices 24..48).
        let r0 = full_data[0..24].to_vec();
        let r1 = full_data[24..48].to_vec();
        let r0_t = cpu_tensor(&r0, &[2, 3, 2, 2]);
        let r1_t = cpu_tensor(&r1, &[2, 3, 2, 2]);

        // Build a 2-rank simulated backend and run forward on each rank
        // in its own thread (allreduce blocks pairs of ranks).
        let group = SimulatedBackend::create_group(2).unwrap();
        let mut iter = group.into_iter();
        let b0 = Arc::new(iter.next().unwrap());
        let b1 = Arc::new(iter.next().unwrap());

        let r0_clone = r0_t.clone();
        let r1_clone = r1_t.clone();
        let b0_clone: Arc<dyn Backend> = b0.clone();
        let b1_clone: Arc<dyn Backend> = b1.clone();

        let h0 = thread::spawn(move || {
            let mut sync = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, true)
                .unwrap()
                .with_backend(b0_clone);
            sync.train();
            let out = sync.forward(&r0_clone).unwrap();
            (
                out.data().unwrap().to_vec(),
                sync.running_mean(),
                sync.running_var(),
            )
        });
        let h1 = thread::spawn(move || {
            let mut sync = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, true)
                .unwrap()
                .with_backend(b1_clone);
            sync.train();
            let out = sync.forward(&r1_clone).unwrap();
            (
                out.data().unwrap().to_vec(),
                sync.running_mean(),
                sync.running_var(),
            )
        });

        let (out0, rm0, rv0) = h0.join().unwrap();
        let (out1, rm1, rv1) = h1.join().unwrap();

        // Concatenate the per-rank outputs in batch order — they should
        // match the single-rank full-batch output element-for-element.
        let mut concat = out0.clone();
        concat.extend_from_slice(&out1);
        for (i, (a, b)) in concat.iter().zip(plain_data.iter()).enumerate() {
            assert!((a - b).abs() < 1e-4, "out[{i}]: sync={a}, plain={b}");
        }

        // Both ranks should have identical running statistics, and they
        // should match the single-rank full-batch running statistics.
        for c in 0..3 {
            assert!(
                (rm0[c] - rm1[c]).abs() < 1e-6,
                "rank0 and rank1 running_mean disagree at c={c}"
            );
            assert!(
                (rm0[c] - plain_running_mean[c]).abs() < 1e-4,
                "running_mean[{c}] sync={} plain={}",
                rm0[c],
                plain_running_mean[c]
            );
            assert!(
                (rv0[c] - rv1[c]).abs() < 1e-6,
                "rank0 and rank1 running_var disagree at c={c}"
            );
            assert!(
                (rv0[c] - plain_running_var[c]).abs() < 1e-4,
                "running_var[{c}] sync={} plain={}",
                rv0[c],
                plain_running_var[c]
            );
        }
    }

    #[test]
    fn test_sync_bn_eval_mode_uses_running_stats() {
        // Warm up the running statistics in train mode, then check that eval
        // mode normalizes a different batch with exactly those statistics and
        // leaves them untouched.
        let input = cpu_tensor(
            &(0..12).map(|i| i as f32).collect::<Vec<_>>(),
            &[1, 3, 2, 2],
        );
        let mut sync = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        sync.train();
        for _ in 0..3 {
            let _ = sync.forward(&input).unwrap();
        }
        assert_eq!(sync.num_batches_tracked(), 3);

        // Channel c holds 4c..=4c+3: batch mean 4c + 1.5, unbiased variance
        // 5/3. Three momentum-0.1 updates move the running stats a fraction
        // 1 - 0.9^3 of the way from their (0, 1) initial values.
        let rm = sync.running_mean();
        let rv = sync.running_var();
        let frac = 1.0 - 0.9_f64.powi(3);
        for c in 0..3 {
            let expected_mean = (4.0 * c as f64 + 1.5) * frac;
            let expected_var = (1.0 - frac) + (5.0 / 3.0) * frac;
            assert!(
                (rm[c] - expected_mean).abs() < 1e-5,
                "running_mean[{c}] = {}, expected {expected_mean}",
                rm[c]
            );
            assert!(
                (rv[c] - expected_var).abs() < 1e-5,
                "running_var[{c}] = {}, expected {expected_var}",
                rv[c]
            );
        }

        sync.eval();
        let other = cpu_tensor(&[100.0_f32; 12], &[1, 3, 2, 2]);
        let out = sync.forward(&other).unwrap();
        let data = out.data().unwrap();
        for (i, &v) in data.iter().enumerate() {
            let c = i / 4;
            let expected = (100.0 - rm[c]) / (rv[c] + 1e-5).sqrt();
            assert!(
                (f64::from(v) - expected).abs() < 1e-3,
                "out[{i}] = {v}, expected {expected}"
            );
        }

        // Eval-mode forward must not touch the running statistics.
        assert_eq!(sync.num_batches_tracked(), 3);
        assert_eq!(sync.running_mean(), rm);
        assert_eq!(sync.running_var(), rv);
    }

    #[test]
    fn test_sync_bn_training_output_is_standardized() {
        // In training mode every output channel should have (biased) mean 0
        // and variance var / (var + eps), which is essentially 1 here.
        let input_data: Vec<f32> = (0..48)
            .map(|i| ((i * 13) % 17) as f32 * 0.5 - 3.0)
            .collect();
        let input = cpu_tensor(&input_data, &[2, 3, 2, 4]);
        let sync = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        let out = sync.forward(&input).unwrap();
        let data = out.data().unwrap();
        for c in 0..3 {
            let mut vals = Vec::new();
            for b in 0..2 {
                for s in 0..8 {
                    vals.push(f64::from(data[b * 24 + c * 8 + s]));
                }
            }
            let n = vals.len() as f64;
            let mean = vals.iter().sum::<f64>() / n;
            let var = vals.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n;
            assert!(mean.abs() < 1e-4, "channel {c}: output mean {mean}");
            assert!((var - 1.0).abs() < 1e-3, "channel {c}: output variance {var}");
        }
    }

    #[test]
    fn test_sync_bn_non_affine_matches_default_affine() {
        // A freshly constructed affine layer has weight = 1 and bias = 0, so
        // its output must equal that of a non-affine layer.
        let input_data: Vec<f32> = (0..24).map(|i| ((i * 7) % 11) as f32 - 5.0).collect();
        let input = cpu_tensor(&input_data, &[2, 3, 2, 2]);
        let affine = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        let non_affine = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, false).unwrap();
        assert!(non_affine.weight.is_none() && non_affine.bias.is_none());
        assert!(non_affine.parameters().is_empty());

        let a = affine.forward(&input).unwrap();
        let p = non_affine.forward(&input).unwrap();
        let a = a.data().unwrap();
        let p = p.data().unwrap();
        assert_eq!(a.len(), p.len());
        for (i, (x, y)) in a.iter().zip(p.iter()).enumerate() {
            assert!((x - y).abs() < 1e-6, "out[{i}]: affine={x}, non_affine={y}");
        }
    }

    #[test]
    fn test_sync_bn_constructor_validates_num_features() {
        assert!(SyncBatchNorm2d::<f32>::new(0, 1e-5, 0.1, true).is_err());
    }

    #[test]
    fn test_sync_bn_rejects_wrong_input_shape() {
        let sync = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        let bad = cpu_tensor(&[1.0, 2.0, 3.0], &[3]);
        assert!(sync.forward(&bad).is_err());
    }

    #[test]
    fn test_sync_bn_rejects_wrong_channel_count() {
        let sync = SyncBatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        let bad = cpu_tensor(&[0.0; 16], &[1, 4, 2, 2]);
        assert!(sync.forward(&bad).is_err());
    }
}