1use std::sync::OnceLock;
22use std::sync::atomic::{AtomicBool, Ordering};
23
24use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
25
26use crate::module::Module;
27use crate::norm::{
28 BatchNorm1d, BatchNorm2d, BatchNorm3d, InstanceNorm1d, InstanceNorm2d, InstanceNorm3d,
29};
30use crate::parameter::Parameter;
31
32fn channels_from_input<T: Float>(
34 input: &Tensor<T>,
35 op: &str,
36 expected_ndim: usize,
37) -> FerrotorchResult<usize> {
38 if input.ndim() != expected_ndim {
39 return Err(FerrotorchError::ShapeMismatch {
40 message: format!(
41 "{op}: expected {expected_ndim}-D input, got {}-D",
42 input.ndim()
43 ),
44 });
45 }
46 Ok(input.shape()[1])
47}
48
/// Generates a lazily-initialised batch-norm wrapper.
///
/// Arguments: `$name` is the generated type, `$inner` the eager layer it
/// wraps, `$expected_ndim` the exact input rank accepted on the first
/// `forward`, and `$kind` the label used in shape-error messages.
macro_rules! lazy_batchnorm {
    ($name:ident, $inner:ident, $expected_ndim:expr, $kind:literal) => {
        #[doc = concat!("Lazy variant of [`", stringify!($inner), "`] — `num_features` is")]
        #[doc = "discovered from the input's channel dim on the first forward call."]
        #[doc = ""]
        #[doc = "Until the layer is materialized it exposes no parameters and no running"]
        #[doc = "statistics. A `train()` / `eval()` call made before materialization is"]
        #[doc = "remembered and applied to the inner layer when it is created."]
        #[derive(Debug)]
        pub struct $name<T: Float> {
            /// Forwarded unchanged to the inner layer's constructor.
            eps: f64,
            /// Forwarded unchanged to the inner layer's constructor.
            momentum: f64,
            /// Forwarded unchanged to the inner layer's constructor.
            affine: bool,
            /// The eager layer; empty until `materialize` / first `forward`.
            inner: OnceLock<$inner<T>>,
            /// Mode requested through `train()` / `eval()`; starts as `true`.
            training: AtomicBool,
        }

        impl<T: Float> $name<T> {
            /// Creates an uninitialised layer in training mode.
            ///
            /// The three arguments are stored and handed to the inner layer's
            /// constructor once the feature count is known.
            pub fn new(eps: f64, momentum: f64, affine: bool) -> Self {
                Self {
                    eps,
                    momentum,
                    affine,
                    inner: OnceLock::new(),
                    training: AtomicBool::new(true),
                }
            }

            /// Returns `true` once the inner layer has been created.
            pub fn is_initialized(&self) -> bool {
                self.inner.get().is_some()
            }

            /// Returns the feature count of the inner layer, or `None` while
            /// the layer is still uninitialised.
            ///
            /// The count is read from dimension 0 of the first parameter. When
            /// the inner layer exposes no parameters (presumably the
            /// `affine = false` case — confirm against the inner layer), the
            /// length of the running-mean buffer is used instead of reporting
            /// a bogus `0`.
            pub fn num_features(&self) -> Option<usize> {
                self.inner.get().map(|m| {
                    m.parameters()
                        .first()
                        .map(|p| p.tensor().shape()[0])
                        .unwrap_or_else(|| m.running_mean().len())
                })
            }

            /// Creates the inner layer with `num_features` features if it does
            /// not exist yet.
            ///
            /// Calling this on an already-initialised layer is a no-op that
            /// returns `Ok(())`, even if `num_features` differs from the value
            /// used originally.
            ///
            /// # Errors
            ///
            /// Propagates any error returned by the inner layer's constructor.
            pub fn materialize(&self, num_features: usize) -> FerrotorchResult<()> {
                if self.inner.get().is_none() {
                    let mut inner =
                        $inner::<T>::new(num_features, self.eps, self.momentum, self.affine)?;
                    // Apply the mode requested before materialization. Without
                    // this, `eval()` followed by the first `forward` would run
                    // the freshly built inner layer in its own default mode
                    // while `is_training()` reports `false`.
                    if self.training.load(Ordering::Relaxed) {
                        inner.train();
                    } else {
                        inner.eval();
                    }
                    // `set` fails only if another caller initialised the cell
                    // in the meantime; that value is kept and ours is dropped.
                    let _ = self.inner.set(inner);
                }
                Ok(())
            }

            /// Returns the inner layer's running mean, or `None` while the
            /// layer is uninitialised.
            pub fn running_mean(&self) -> Option<Vec<f64>> {
                self.inner.get().map(|m| m.running_mean())
            }

            /// Returns the inner layer's running variance, or `None` while the
            /// layer is uninitialised.
            pub fn running_var(&self) -> Option<Vec<f64>> {
                self.inner.get().map(|m| m.running_var())
            }
        }

        impl<T: Float> Module<T> for $name<T> {
            /// Materializes the inner layer from `input`'s dimension 1 on the
            /// first call, then delegates to the inner layer.
            ///
            /// The rank check against the expected rank only happens while the
            /// layer is uninitialised; afterwards validation is left to the
            /// inner layer.
            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
                if self.inner.get().is_none() {
                    let c = channels_from_input(input, $kind, $expected_ndim)?;
                    self.materialize(c)?;
                }
                let inner = self.inner.get().ok_or_else(|| FerrotorchError::Internal {
                    message: "LazyBatchNorm: inner not initialized after materialize() — invariant violated".into(),
                })?;
                inner.forward(input)
            }

            /// Inner layer's parameters; empty while uninitialised.
            fn parameters(&self) -> Vec<&Parameter<T>> {
                self.inner.get().map(|m| m.parameters()).unwrap_or_default()
            }

            /// Mutable view of the inner layer's parameters; empty while
            /// uninitialised.
            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
                self.inner
                    .get_mut()
                    .map(|m| m.parameters_mut())
                    .unwrap_or_default()
            }

            /// Inner layer's named parameters; empty while uninitialised.
            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
                self.inner
                    .get()
                    .map(|m| m.named_parameters())
                    .unwrap_or_default()
            }

            /// Switches to training mode, forwarding to the inner layer when
            /// it already exists.
            fn train(&mut self) {
                self.training.store(true, Ordering::Relaxed);
                if let Some(m) = self.inner.get_mut() {
                    m.train();
                }
            }

            /// Switches to evaluation mode, forwarding to the inner layer when
            /// it already exists.
            fn eval(&mut self) {
                self.training.store(false, Ordering::Relaxed);
                if let Some(m) = self.inner.get_mut() {
                    m.eval();
                }
            }

            /// Reports the mode last requested through `train()` / `eval()`.
            fn is_training(&self) -> bool {
                self.training.load(Ordering::Relaxed)
            }
        }
    };
}
166
167lazy_batchnorm!(LazyBatchNorm1d, BatchNorm1d, 2, "LazyBatchNorm1d"); lazy_batchnorm!(LazyBatchNorm2d, BatchNorm2d, 4, "LazyBatchNorm2d");
169lazy_batchnorm!(LazyBatchNorm3d, BatchNorm3d, 5, "LazyBatchNorm3d");
170
/// Generates a lazily-initialised instance-norm wrapper.
///
/// Arguments: `$name` is the generated type, `$inner` the eager layer it
/// wraps, `$expected_ndim` the exact input rank accepted on the first
/// `forward`, and `$kind` the label used in shape-error messages.
macro_rules! lazy_instancenorm {
    ($name:ident, $inner:ident, $expected_ndim:expr, $kind:literal) => {
        #[doc = concat!("Lazy variant of [`", stringify!($inner), "`].")]
        #[doc = ""]
        #[doc = "The feature count is taken from dimension 1 of the first input. A"]
        #[doc = "`train()` / `eval()` call made before materialization is remembered and"]
        #[doc = "applied to the inner layer when it is created."]
        #[derive(Debug)]
        pub struct $name<T: Float> {
            /// Forwarded unchanged to the inner layer's constructor.
            eps: f64,
            /// Forwarded unchanged to the inner layer's constructor.
            affine: bool,
            /// The eager layer; empty until `materialize` / first `forward`.
            inner: OnceLock<$inner<T>>,
            /// Mode requested through `train()` / `eval()`; starts as `true`.
            training: AtomicBool,
        }

        impl<T: Float> $name<T> {
            /// Creates an uninitialised layer in training mode.
            pub fn new(eps: f64, affine: bool) -> Self {
                Self {
                    eps,
                    affine,
                    inner: OnceLock::new(),
                    training: AtomicBool::new(true),
                }
            }

            /// Returns `true` once the inner layer has been created.
            pub fn is_initialized(&self) -> bool {
                self.inner.get().is_some()
            }

            /// Creates the inner layer with `num_features` features if it does
            /// not exist yet.
            ///
            /// Calling this on an already-initialised layer is a no-op that
            /// returns `Ok(())`, even if `num_features` differs from the value
            /// used originally.
            ///
            /// # Errors
            ///
            /// Propagates any error returned by the inner layer's constructor.
            pub fn materialize(&self, num_features: usize) -> FerrotorchResult<()> {
                if self.inner.get().is_none() {
                    let mut inner = $inner::<T>::new(num_features, self.eps, self.affine)?;
                    // Apply the mode requested before materialization so the
                    // inner layer and `is_training()` never disagree.
                    if self.training.load(Ordering::Relaxed) {
                        inner.train();
                    } else {
                        inner.eval();
                    }
                    // `set` fails only if another caller initialised the cell
                    // in the meantime; that value is kept and ours is dropped.
                    let _ = self.inner.set(inner);
                }
                Ok(())
            }
        }

        impl<T: Float> Module<T> for $name<T> {
            /// Materializes the inner layer from `input`'s dimension 1 on the
            /// first call, then delegates to the inner layer.
            ///
            /// The rank check against the expected rank only happens while the
            /// layer is uninitialised; afterwards validation is left to the
            /// inner layer.
            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
                if self.inner.get().is_none() {
                    let c = channels_from_input(input, $kind, $expected_ndim)?;
                    self.materialize(c)?;
                }
                self.inner
                    .get()
                    .ok_or_else(|| FerrotorchError::Internal {
                        message: "LazyInstanceNorm: inner not initialized after materialize() — invariant violated".into(),
                    })?
                    .forward(input)
            }

            /// Inner layer's parameters; empty while uninitialised.
            fn parameters(&self) -> Vec<&Parameter<T>> {
                self.inner.get().map(|m| m.parameters()).unwrap_or_default()
            }

            /// Mutable view of the inner layer's parameters; empty while
            /// uninitialised.
            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
                self.inner
                    .get_mut()
                    .map(|m| m.parameters_mut())
                    .unwrap_or_default()
            }

            /// Inner layer's named parameters; empty while uninitialised.
            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
                self.inner
                    .get()
                    .map(|m| m.named_parameters())
                    .unwrap_or_default()
            }

            /// Switches to training mode, forwarding to the inner layer when
            /// it already exists.
            fn train(&mut self) {
                self.training.store(true, Ordering::Relaxed);
                if let Some(m) = self.inner.get_mut() {
                    m.train();
                }
            }

            /// Switches to evaluation mode, forwarding to the inner layer when
            /// it already exists.
            fn eval(&mut self) {
                self.training.store(false, Ordering::Relaxed);
                if let Some(m) = self.inner.get_mut() {
                    m.eval();
                }
            }

            /// Reports the mode last requested through `train()` / `eval()`.
            fn is_training(&self) -> bool {
                self.training.load(Ordering::Relaxed)
            }
        }
    };
}
258
259lazy_instancenorm!(LazyInstanceNorm1d, InstanceNorm1d, 3, "LazyInstanceNorm1d");
260lazy_instancenorm!(LazyInstanceNorm2d, InstanceNorm2d, 4, "LazyInstanceNorm2d");
261lazy_instancenorm!(LazyInstanceNorm3d, InstanceNorm3d, 5, "LazyInstanceNorm3d");
262
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::storage::TensorStorage;

    /// Wraps `data` in CPU storage and builds an `f32` tensor of `shape`.
    /// Panics if the tensor cannot be constructed.
    fn cpu_tensor(data: Vec<f32>, shape: &[usize]) -> Tensor<f32> {
        let storage = TensorStorage::cpu(data);
        Tensor::from_storage(storage, shape.to_vec(), false).unwrap()
    }

    /// Produces `[0.0, 1.0, ..., (len - 1) as f32]`.
    fn ramp(len: u16) -> Vec<f32> {
        (0..len).map(f32::from).collect()
    }

    #[test]
    fn lazy_batchnorm2d_materializes_on_first_forward() {
        let layer = LazyBatchNorm2d::<f32>::new(1e-5, 0.1, true);
        assert!(!layer.is_initialized());

        // 2 * 4 * 3 * 3 = 72 elements, with 4 along dimension 1.
        let x = cpu_tensor(ramp(72), &[2, 4, 3, 3]);
        let _out = layer.forward(&x).unwrap();

        assert!(layer.is_initialized());
        assert_eq!(layer.num_features(), Some(4));
    }

    #[test]
    fn lazy_batchnorm2d_rejects_wrong_rank() {
        let layer = LazyBatchNorm2d::<f32>::new(1e-5, 0.1, true);
        // A 2-D input must be refused by the 4-D layer.
        let x = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let failure = layer.forward(&x).unwrap_err();
        assert!(matches!(failure, FerrotorchError::ShapeMismatch { .. }));
    }

    #[test]
    fn lazy_batchnorm2d_explicit_materialize() {
        let layer = LazyBatchNorm2d::<f32>::new(1e-5, 0.1, true);
        layer.materialize(8).unwrap();

        assert!(layer.is_initialized());
        assert_eq!(layer.num_features(), Some(8));
    }

    #[test]
    fn lazy_instancenorm2d_materializes() {
        let layer = LazyInstanceNorm2d::<f32>::new(1e-5, true);
        assert!(!layer.is_initialized());

        let x = cpu_tensor(ramp(36), &[1, 4, 3, 3]);
        let _out = layer.forward(&x).unwrap();

        assert!(layer.is_initialized());
    }

    #[test]
    fn lazy_batchnorm3d_materializes_on_5d_input() {
        let layer = LazyBatchNorm3d::<f32>::new(1e-5, 0.1, true);
        let x = cpu_tensor(ramp(16), &[1, 2, 2, 2, 2]);

        let _ = layer.forward(&x).unwrap();
        assert!(layer.is_initialized());
    }

    #[test]
    fn lazy_instancenorm3d_explicit_materialize() {
        let layer = LazyInstanceNorm3d::<f32>::new(1e-5, true);
        layer.materialize(4).unwrap();
        assert!(layer.is_initialized());
    }

    #[test]
    fn lazy_batchnorm_accessors_some_after_materialize() {
        let layer = LazyBatchNorm2d::<f32>::new(1e-5, 0.1, true);

        // Nothing to report before the inner layer exists.
        assert!(layer.running_mean().is_none());
        assert!(layer.running_var().is_none());

        let x = cpu_tensor(ramp(72), &[2, 4, 3, 3]);
        let _out = layer.forward(&x).unwrap();

        let rm = layer
            .running_mean()
            .expect("running_mean Some after forward");
        let rv = layer.running_var().expect("running_var Some after forward");

        assert_eq!(rm.len(), 4, "running_mean length must equal num_features");
        assert_eq!(rv.len(), 4, "running_var length must equal num_features");
        assert!(
            rm.iter().any(|&v| v != 0.0),
            "running_mean must update on training forward pass; got {rm:?}"
        );
        assert!(
            rv.iter().all(|&v| v > 0.0),
            "running_var must remain positive; got {rv:?}"
        );
    }

    #[test]
    fn lazy_batchnorm1d_and_3d_accessors_match_inner() {
        // 1d wrapper: freshly materialized stats are zeros and ones.
        let one_d = LazyBatchNorm1d::<f32>::new(1e-5, 0.1, true);
        one_d.materialize(3).unwrap();
        let mean_1d = one_d.running_mean().expect("Some after materialize");
        let var_1d = one_d.running_var().expect("Some after materialize");
        assert_eq!(mean_1d, vec![0.0, 0.0, 0.0]);
        assert_eq!(var_1d, vec![1.0, 1.0, 1.0]);

        // 3d wrapper: same expectation with two features.
        let three_d = LazyBatchNorm3d::<f32>::new(1e-5, 0.1, true);
        three_d.materialize(2).unwrap();
        assert_eq!(three_d.running_mean().unwrap(), vec![0.0, 0.0]);
        assert_eq!(three_d.running_var().unwrap(), vec![1.0, 1.0]);
    }

    #[test]
    fn lazy_norm_train_eval_toggle() {
        let mut layer = LazyBatchNorm2d::<f32>::new(1e-5, 0.1, true);
        assert!(layer.is_training());

        layer.eval();
        assert!(!layer.is_training());

        layer.train();
        assert!(layer.is_training());
    }
}