Skip to main content

ferrotorch_nn/
lazy_norm.rs

//! Lazy normalization modules. (#622)
//!
//! `LazyBatchNorm{1,2,3}d` and `LazyInstanceNorm{1,2,3}d` defer
//! `num_features` discovery to the first forward call, then materialize
//! a regular `BatchNorm*d` / `InstanceNorm*d` and forward to it.
//!
//! ## REQ status (per `.design/ferrotorch-nn/lazy_norm.md`)
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 | SHIPPED | impl: `lazy_batchnorm!(LazyBatchNorm1d, BatchNorm1d, ...)`, `lazy_batchnorm!(LazyBatchNorm2d, ...)`, `lazy_batchnorm!(LazyBatchNorm3d, ...)` here; non-test consumer: `pub use lazy_norm::{LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d}` in `lib.rs`. |
//! | REQ-2 | SHIPPED | impl: `lazy_instancenorm!(LazyInstanceNorm1d, InstanceNorm1d, ...)` and analogous invocations here; non-test consumer: `pub use lazy_norm::{LazyInstanceNorm1d, LazyInstanceNorm2d, LazyInstanceNorm3d}` in `lib.rs`. |
//! | REQ-3 | SHIPPED | impl: macro-generated `LazyNormNd::new(...)` constructor bodies here; non-test consumer: dynamic-shape vision pipeline construction in downstream code. |
//! | REQ-4 | SHIPPED | impl: macro-generated `LazyNormNd::materialize(num_features)` here; non-test consumer: dynamic-shape pipelines call `materialize(known_C)` to populate parameters before constructing the optimizer. |
//! | REQ-5 | SHIPPED | impl: macro-generated `<LazyNormNd as Module>::forward` body here; non-test consumer: any model containing a `LazyBatchNorm2d` runs this every training forward. |
//! | REQ-6 | SHIPPED | impl: the `num_features` accessor (returning `Option<usize>`) inside the `lazy_batchnorm!` macro body here; non-test consumer: optimizer-state introspection code calling `num_features()` to size buffers. |
//! | REQ-7 | SHIPPED | impl: the `running_mean` / `running_var` accessors in the `lazy_batchnorm!` macro body here (#1072); non-test consumer: checkpoint serialization code (e.g. `safetensors` exporter) reads `running_mean` / `running_var` snapshots to persist BN state. |
//! | REQ-8 | SHIPPED | impl: macro-generated `Module<T>` impl block forwarding `parameters` etc. through `inner` here; non-test consumer: `ferrotorch_optim::Optimizer` walks `model.parameters_mut()`, surfacing the inner BN/IN params after materialize. |
//! | REQ-9 | SHIPPED | impl: macro-generated `is_initialized` accessor here; non-test consumer: training-loop setup code querying initialization state. |
20
21use std::sync::OnceLock;
22use std::sync::atomic::{AtomicBool, Ordering};
23
24use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
25
26use crate::module::Module;
27use crate::norm::{
28    BatchNorm1d, BatchNorm2d, BatchNorm3d, InstanceNorm1d, InstanceNorm2d, InstanceNorm3d,
29};
30use crate::parameter::Parameter;
31
32/// Generic helper: extract the channel dim (dim 1) from input shape `[N, C, ...]`.
33fn channels_from_input<T: Float>(
34    input: &Tensor<T>,
35    op: &str,
36    expected_ndim: usize,
37) -> FerrotorchResult<usize> {
38    if input.ndim() != expected_ndim {
39        return Err(FerrotorchError::ShapeMismatch {
40            message: format!(
41                "{op}: expected {expected_ndim}-D input, got {}-D",
42                input.ndim()
43            ),
44        });
45    }
46    Ok(input.shape()[1])
47}
48
49macro_rules! lazy_batchnorm {
50    ($name:ident, $inner:ident, $expected_ndim:expr, $kind:literal) => {
51        #[doc = concat!("Lazy variant of [`", stringify!($inner), "`] — `num_features` is")]
52        #[doc = "discovered from the input's channel dim on the first forward call."]
53        #[derive(Debug)]
54        pub struct $name<T: Float> {
55            eps: f64,
56            momentum: f64,
57            affine: bool,
58            inner: OnceLock<$inner<T>>,
59            training: AtomicBool,
60        }
61
62        impl<T: Float> $name<T> {
63            pub fn new(eps: f64, momentum: f64, affine: bool) -> Self {
64                Self {
65                    eps,
66                    momentum,
67                    affine,
68                    inner: OnceLock::new(),
69                    training: AtomicBool::new(true),
70                }
71            }
72
73            pub fn is_initialized(&self) -> bool {
74                self.inner.get().is_some()
75            }
76
77            pub fn num_features(&self) -> Option<usize> {
78                self.inner.get().map(|m| {
79                    m.parameters()
80                        .first()
81                        .map(|p| p.tensor().shape()[0])
82                        .unwrap_or(0)
83                })
84            }
85
86            pub fn materialize(&self, num_features: usize) -> FerrotorchResult<()> {
87                if self.inner.get().is_none() {
88                    let inner =
89                        $inner::<T>::new(num_features, self.eps, self.momentum, self.affine)?;
90                    let _ = self.inner.set(inner);
91                }
92                Ok(())
93            }
94
95            /// Snapshot of the inner BN's running mean, or `None` if the
96            /// layer has not been materialized yet. Mirrors PyTorch's
97            /// `running_mean` attribute access on a lazy BN — `None`
98            /// before the first forward pass populates `num_features`,
99            /// `Some(vec)` afterwards. See [`norm`](crate::norm)'s
100            /// `BatchNorm*d::running_mean` for the underlying snapshot
101            /// semantics. (#1072)
102            ///
103            /// [`norm`]: crate::norm
104            pub fn running_mean(&self) -> Option<Vec<f64>> {
105                self.inner.get().map(|m| m.running_mean())
106            }
107
108            /// Snapshot of the inner BN's running variance, or `None`
109            /// if the layer has not been materialized yet. Counterpart
110            /// to [`running_mean`](Self::running_mean). (#1072)
111            pub fn running_var(&self) -> Option<Vec<f64>> {
112                self.inner.get().map(|m| m.running_var())
113            }
114        }
115
116        impl<T: Float> Module<T> for $name<T> {
117            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
118                if self.inner.get().is_none() {
119                    let c = channels_from_input(input, $kind, $expected_ndim)?;
120                    self.materialize(c)?;
121                }
122                let inner = self.inner.get().ok_or_else(|| FerrotorchError::Internal {
123                    message: "LazyBatchNorm: inner not initialized after materialize() — invariant violated".into(),
124                })?;
125                inner.forward(input)
126            }
127
128            fn parameters(&self) -> Vec<&Parameter<T>> {
129                self.inner.get().map(|m| m.parameters()).unwrap_or_default()
130            }
131
132            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
133                self.inner
134                    .get_mut()
135                    .map(|m| m.parameters_mut())
136                    .unwrap_or_default()
137            }
138
139            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
140                self.inner
141                    .get()
142                    .map(|m| m.named_parameters())
143                    .unwrap_or_default()
144            }
145
146            fn train(&mut self) {
147                self.training.store(true, Ordering::Relaxed);
148                if let Some(m) = self.inner.get_mut() {
149                    m.train();
150                }
151            }
152
153            fn eval(&mut self) {
154                self.training.store(false, Ordering::Relaxed);
155                if let Some(m) = self.inner.get_mut() {
156                    m.eval();
157                }
158            }
159
160            fn is_training(&self) -> bool {
161                self.training.load(Ordering::Relaxed)
162            }
163        }
164    };
165}
166
167lazy_batchnorm!(LazyBatchNorm1d, BatchNorm1d, 2, "LazyBatchNorm1d"); // BatchNorm1d also accepts 3D
168lazy_batchnorm!(LazyBatchNorm2d, BatchNorm2d, 4, "LazyBatchNorm2d");
169lazy_batchnorm!(LazyBatchNorm3d, BatchNorm3d, 5, "LazyBatchNorm3d");
170
171// InstanceNorm has a 3-arg ctor (no momentum); use a separate macro.
172macro_rules! lazy_instancenorm {
173    ($name:ident, $inner:ident, $expected_ndim:expr, $kind:literal) => {
174        #[doc = concat!("Lazy variant of [`", stringify!($inner), "`].")]
175        #[derive(Debug)]
176        pub struct $name<T: Float> {
177            eps: f64,
178            affine: bool,
179            inner: OnceLock<$inner<T>>,
180            training: AtomicBool,
181        }
182
183        impl<T: Float> $name<T> {
184            pub fn new(eps: f64, affine: bool) -> Self {
185                Self {
186                    eps,
187                    affine,
188                    inner: OnceLock::new(),
189                    training: AtomicBool::new(true),
190                }
191            }
192
193            pub fn is_initialized(&self) -> bool {
194                self.inner.get().is_some()
195            }
196
197            pub fn materialize(&self, num_features: usize) -> FerrotorchResult<()> {
198                if self.inner.get().is_none() {
199                    let inner = $inner::<T>::new(num_features, self.eps, self.affine)?;
200                    let _ = self.inner.set(inner);
201                }
202                Ok(())
203            }
204        }
205
206        impl<T: Float> Module<T> for $name<T> {
207            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
208                if self.inner.get().is_none() {
209                    let c = channels_from_input(input, $kind, $expected_ndim)?;
210                    self.materialize(c)?;
211                }
212                self.inner
213                    .get()
214                    .ok_or_else(|| FerrotorchError::Internal {
215                        message: "LazyInstanceNorm: inner not initialized after materialize() — invariant violated".into(),
216                    })?
217                    .forward(input)
218            }
219
220            fn parameters(&self) -> Vec<&Parameter<T>> {
221                self.inner.get().map(|m| m.parameters()).unwrap_or_default()
222            }
223
224            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
225                self.inner
226                    .get_mut()
227                    .map(|m| m.parameters_mut())
228                    .unwrap_or_default()
229            }
230
231            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
232                self.inner
233                    .get()
234                    .map(|m| m.named_parameters())
235                    .unwrap_or_default()
236            }
237
238            fn train(&mut self) {
239                self.training.store(true, Ordering::Relaxed);
240                if let Some(m) = self.inner.get_mut() {
241                    m.train();
242                }
243            }
244
245            fn eval(&mut self) {
246                self.training.store(false, Ordering::Relaxed);
247                if let Some(m) = self.inner.get_mut() {
248                    m.eval();
249                }
250            }
251
252            fn is_training(&self) -> bool {
253                self.training.load(Ordering::Relaxed)
254            }
255        }
256    };
257}
258
259lazy_instancenorm!(LazyInstanceNorm1d, InstanceNorm1d, 3, "LazyInstanceNorm1d");
260lazy_instancenorm!(LazyInstanceNorm2d, InstanceNorm2d, 4, "LazyInstanceNorm2d");
261lazy_instancenorm!(LazyInstanceNorm3d, InstanceNorm3d, 5, "LazyInstanceNorm3d");
262
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::storage::TensorStorage;

    /// Builds a CPU tensor (no gradient tracking) from flat data and a shape.
    fn cpu_tensor(data: Vec<f32>, shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data), shape.to_vec(), false).unwrap()
    }

    #[test]
    fn lazy_batchnorm2d_materializes_on_first_forward() {
        let bn: LazyBatchNorm2d<f32> = LazyBatchNorm2d::new(1e-5, 0.1, true);
        assert!(!bn.is_initialized());
        // Before materialization there is no channel count to report.
        assert_eq!(bn.num_features(), None);
        // Input: [N=2, C=4, H=3, W=3] = 72 elements.
        let data: Vec<f32> = (0..72).map(|i| i as f32).collect();
        let input = cpu_tensor(data, &[2, 4, 3, 3]);
        let _out = bn.forward(&input).unwrap();
        assert!(bn.is_initialized());
        assert_eq!(bn.num_features(), Some(4));
    }

    #[test]
    fn lazy_batchnorm2d_rejects_wrong_rank() {
        let bn: LazyBatchNorm2d<f32> = LazyBatchNorm2d::new(1e-5, 0.1, true);
        let input = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let err = bn.forward(&input).unwrap_err();
        assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
    }

    #[test]
    fn lazy_batchnorm2d_explicit_materialize() {
        let bn: LazyBatchNorm2d<f32> = LazyBatchNorm2d::new(1e-5, 0.1, true);
        bn.materialize(8).unwrap();
        assert!(bn.is_initialized());
        assert_eq!(bn.num_features(), Some(8));
    }

    #[test]
    fn lazy_instancenorm2d_materializes() {
        let inn: LazyInstanceNorm2d<f32> = LazyInstanceNorm2d::new(1e-5, true);
        assert!(!inn.is_initialized());
        // Input: [N=1, C=4, H=3, W=3] = 36 elements.
        let data: Vec<f32> = (0..36).map(|i| i as f32).collect();
        let input = cpu_tensor(data, &[1, 4, 3, 3]);
        let _out = inn.forward(&input).unwrap();
        assert!(inn.is_initialized());
    }

    #[test]
    fn lazy_batchnorm3d_materializes_on_5d_input() {
        let bn: LazyBatchNorm3d<f32> = LazyBatchNorm3d::new(1e-5, 0.1, true);
        let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
        // [N=1, C=2, D=2, H=2, W=2] = 16 elements
        let input = cpu_tensor(data, &[1, 2, 2, 2, 2]);
        let _ = bn.forward(&input).unwrap();
        assert!(bn.is_initialized());
    }

    #[test]
    fn lazy_instancenorm3d_explicit_materialize() {
        let inn: LazyInstanceNorm3d<f32> = LazyInstanceNorm3d::new(1e-5, true);
        inn.materialize(4).unwrap();
        assert!(inn.is_initialized());
    }

    #[test]
    fn lazy_batchnorm_accessors_some_after_materialize() {
        // #1072: pre-materialize, accessors return None.
        let bn: LazyBatchNorm2d<f32> = LazyBatchNorm2d::new(1e-5, 0.1, true);
        assert!(bn.running_mean().is_none());
        assert!(bn.running_var().is_none());

        // Materialize via a forward pass (channels=4).
        let data: Vec<f32> = (0..72).map(|i| i as f32).collect();
        let input = cpu_tensor(data, &[2, 4, 3, 3]);
        let _out = bn.forward(&input).unwrap();

        // Post-materialize: accessors must return Some(vec) of length C.
        let rm = bn.running_mean().expect("running_mean Some after forward");
        let rv = bn.running_var().expect("running_var Some after forward");
        assert_eq!(rm.len(), 4, "running_mean length must equal num_features");
        assert_eq!(rv.len(), 4, "running_var length must equal num_features");
        // After one training-mode forward over non-zero input, mean drifts
        // away from zero — discriminates against a stub returning zeros.
        assert!(
            rm.iter().any(|&v| v != 0.0),
            "running_mean must update on training forward pass; got {rm:?}"
        );
        // BN running_var initial is 1.0; after a training forward with
        // momentum 0.1 the per-channel variance estimate moves but stays
        // positive. The positivity check discriminates against a stub
        // returning all zeros.
        assert!(
            rv.iter().all(|&v| v > 0.0),
            "running_var must remain positive; got {rv:?}"
        );
        // The positivity check alone would accept the untouched initial
        // vector, so also require that at least one entry left 1.0.
        assert!(
            rv.iter().any(|&v| v != 1.0),
            "running_var must update on training forward pass; got {rv:?}"
        );
    }

    #[test]
    fn lazy_batchnorm1d_and_3d_accessors_match_inner() {
        // #1072: after an explicit materialize (no forward pass yet), the
        // LazyBN{1,3}d accessors must expose the initial running statistics:
        // zeros for the mean and ones for the variance, one entry per channel.
        let bn1: LazyBatchNorm1d<f32> = LazyBatchNorm1d::new(1e-5, 0.1, true);
        bn1.materialize(3).unwrap();
        let rm1 = bn1.running_mean().expect("Some after materialize");
        let rv1 = bn1.running_var().expect("Some after materialize");
        assert_eq!(rm1, vec![0.0, 0.0, 0.0]);
        assert_eq!(rv1, vec![1.0, 1.0, 1.0]);

        let bn3: LazyBatchNorm3d<f32> = LazyBatchNorm3d::new(1e-5, 0.1, true);
        bn3.materialize(2).unwrap();
        assert_eq!(bn3.running_mean().unwrap(), vec![0.0, 0.0]);
        assert_eq!(bn3.running_var().unwrap(), vec![1.0, 1.0]);
    }

    #[test]
    fn lazy_norm_train_eval_toggle() {
        let mut bn: LazyBatchNorm2d<f32> = LazyBatchNorm2d::new(1e-5, 0.1, true);
        assert!(bn.is_training());
        bn.eval();
        assert!(!bn.is_training());
        bn.train();
        assert!(bn.is_training());
    }
}