Skip to main content

ferrotorch_nn/
lazy_conv.rs

1//! Lazy variants of [`Conv1d`](super::Conv1d), [`Conv2d`](super::Conv2d),
2//! and [`Conv3d`](super::Conv3d).
3//!
4//! Like [`LazyLinear`](super::LazyLinear), these modules defer parameter
5//! allocation until the first forward call, at which point the input
6//! tensor's channel dimension (`dim 1` for the standard `[B, C, ...]`
7//! layout) is taken as `in_channels`, and the underlying `Conv{1,2,3}d`
8//! is constructed and stored.
9//!
10//! Mirrors `torch.nn.LazyConv1d`, `torch.nn.LazyConv2d`, and
11//! `torch.nn.LazyConv3d`.
12//!
13//! # Thread safety
14//!
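//! Materialization uses [`std::sync::OnceLock`], so the inner module is
//! stored exactly once no matter how many threads race on the first forward
//! call. Racing threads may each build a candidate module, but only the
//! first `set` wins; the losing candidates are dropped and every caller
//! observes the same stored parameters.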
17//!
18//! # Design
19//!
20//! Each lazy conv wraps a `OnceLock<ConvNd<T>>`. On first forward, the
21//! input's `dim 1` is inspected and a `ConvNd::new(...)` is constructed
22//! with the user's kernel_size / stride / padding and the newly-discovered
23//! in_channels. Subsequent forward calls delegate directly to the inner
24//! module. All parameter accessors (`parameters()`, `named_parameters()`,
25//! etc.) forward through the `OnceLock::get()` path and return an empty
26//! list before materialization. CL-445.
27//!
28//! ## REQ status (per `.design/ferrotorch-nn/lazy_conv.md`)
29//!
30//! | REQ | Status | Evidence |
31//! |---|---|---|
32//! | REQ-1 | SHIPPED | impl: `pub struct LazyConv1d<T: Float>` here with the deferred-init field layout; non-test consumer: `pub use lazy_conv::LazyConv1d` in `lib.rs`. |
33//! | REQ-2 | SHIPPED | impl: `pub struct LazyConv2d<T: Float>` here; non-test consumer: `pub use lazy_conv::LazyConv2d` in `lib.rs`. |
34//! | REQ-3 | SHIPPED | impl: `pub struct LazyConv3d<T: Float>` here; non-test consumer: `pub use lazy_conv::LazyConv3d` in `lib.rs`. |
35//! | REQ-4 | SHIPPED | impl: the `LazyConvNd::new` constructor bodies here validating `out_channels > 0`, kernel/stride > 0; non-test consumer: dynamic-shape pipeline construction in downstream vision code. |
36//! | REQ-5 | SHIPPED | impl: the `LazyConvNd::materialize(in_channels)` body constructing the inner `ConvNd::new(...)` here; non-test consumer: dynamic-shape pipelines call `materialize(known_in_channels)` to populate parameters before constructing the optimizer. |
37//! | REQ-6 | SHIPPED | impl: `<LazyConvNd as Module>::forward` body here (ndim check + first-call materialize + delegate to inner); non-test consumer: any model containing a `LazyConv2d` runs this on every training forward. |
38//! | REQ-7 | SHIPPED | impl: `Module<T>` impl forwarding `parameters` / `parameters_mut` / `named_parameters` through `inner.get()` here; non-test consumer: `ferrotorch_optim::Optimizer` walks `model.parameters_mut()` and sees the inner Conv's params after the first forward materializes. |
39//! | REQ-8 | SHIPPED | impl: the `LazyConvNd::is_initialized` accessor here; non-test consumer: training-loop setup code that queries `is_initialized` to decide whether to call the materialize path explicitly. |
40
41use std::sync::OnceLock;
42use std::sync::atomic::{AtomicBool, Ordering};
43
44use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
45
46use crate::conv::{Conv1d, Conv2d, Conv3d};
47use crate::module::Module;
48use crate::parameter::Parameter;
49
50// ===========================================================================
51// LazyConv1d
52// ===========================================================================
53
/// 1-D convolution layer that defers `in_channels` discovery to the first
/// forward call. Mirrors `torch.nn.LazyConv1d`.
///
/// The configuration fields are captured at construction time and handed to
/// `Conv1d::new` when the layer is materialized.
#[derive(Debug)]
pub struct LazyConv1d<T: Float> {
    /// Number of output channels passed to the inner convolution.
    out_channels: usize,
    /// Kernel length passed to the inner convolution.
    kernel_size: usize,
    /// Stride passed to the inner convolution.
    stride: usize,
    /// Padding passed to the inner convolution.
    padding: usize,
    /// Bias flag passed to the inner convolution's constructor.
    bias_enabled: bool,
    /// Empty until materialization; holds the concrete `Conv1d` afterwards.
    inner: OnceLock<Conv1d<T>>,
    /// Mode flag written by `train()` / `eval()` and read by `is_training()`.
    training: AtomicBool,
}
66
67impl<T: Float> LazyConv1d<T> {
68    /// Build a new `LazyConv1d`. `in_channels` will be discovered from
69    /// the first forward input (dim 1 of the `[B, C_in, L]` tensor).
70    pub fn new(
71        out_channels: usize,
72        kernel_size: usize,
73        stride: usize,
74        padding: usize,
75        bias: bool,
76    ) -> FerrotorchResult<Self> {
77        if out_channels == 0 {
78            return Err(FerrotorchError::InvalidArgument {
79                message: "LazyConv1d: out_channels must be > 0".into(),
80            });
81        }
82        if kernel_size == 0 {
83            return Err(FerrotorchError::InvalidArgument {
84                message: "LazyConv1d: kernel_size must be > 0".into(),
85            });
86        }
87        if stride == 0 {
88            return Err(FerrotorchError::InvalidArgument {
89                message: "LazyConv1d: stride must be > 0".into(),
90            });
91        }
92        Ok(Self {
93            out_channels,
94            kernel_size,
95            stride,
96            padding,
97            bias_enabled: bias,
98            inner: OnceLock::new(),
99            training: AtomicBool::new(true),
100        })
101    }
102
103    /// Returns `true` once `in_channels` has been discovered and the
104    /// inner [`Conv1d`] has been constructed.
105    pub fn is_initialized(&self) -> bool {
106        self.inner.get().is_some()
107    }
108
109    /// Eagerly materialize the inner Conv1d with the given `in_channels`.
110    pub fn materialize(&self, in_channels: usize) -> FerrotorchResult<()> {
111        if in_channels == 0 {
112            return Err(FerrotorchError::InvalidArgument {
113                message: "LazyConv1d: in_channels must be > 0".into(),
114            });
115        }
116        if self.inner.get().is_none() {
117            let conv = Conv1d::new(
118                in_channels,
119                self.out_channels,
120                self.kernel_size,
121                self.stride,
122                self.padding,
123                self.bias_enabled,
124            )?;
125            let _ = self.inner.set(conv);
126        }
127        Ok(())
128    }
129}
130
131impl<T: Float> Module<T> for LazyConv1d<T> {
132    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
133        if input.ndim() != 3 {
134            return Err(FerrotorchError::InvalidArgument {
135                message: format!(
136                    "LazyConv1d expects 3-D input [B, C, L], got {:?}",
137                    input.shape()
138                ),
139            });
140        }
141        if self.inner.get().is_none() {
142            let in_channels = input.shape()[1];
143            self.materialize(in_channels)?;
144        }
145        let conv = self.inner.get().expect("initialized after materialize()");
146        conv.forward(input)
147    }
148
149    fn parameters(&self) -> Vec<&Parameter<T>> {
150        self.inner.get().map(|c| c.parameters()).unwrap_or_default()
151    }
152
153    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
154        self.inner
155            .get_mut()
156            .map(|c| c.parameters_mut())
157            .unwrap_or_default()
158    }
159
160    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
161        self.inner
162            .get()
163            .map(|c| c.named_parameters())
164            .unwrap_or_default()
165    }
166
167    fn train(&mut self) {
168        self.training.store(true, Ordering::Relaxed);
169        if let Some(c) = self.inner.get_mut() {
170            c.train();
171        }
172    }
173
174    fn eval(&mut self) {
175        self.training.store(false, Ordering::Relaxed);
176        if let Some(c) = self.inner.get_mut() {
177            c.eval();
178        }
179    }
180
181    fn is_training(&self) -> bool {
182        self.training.load(Ordering::Relaxed)
183    }
184}
185
186// ===========================================================================
187// LazyConv2d
188// ===========================================================================
189
/// 2-D convolution layer that defers `in_channels` discovery to the first
/// forward call. Mirrors `torch.nn.LazyConv2d`.
///
/// The configuration fields are captured at construction time and handed to
/// `Conv2d::new` when the layer is materialized.
#[derive(Debug)]
pub struct LazyConv2d<T: Float> {
    /// Number of output channels passed to the inner convolution.
    out_channels: usize,
    /// Two-component kernel size passed to the inner convolution.
    kernel_size: (usize, usize),
    /// Two-component stride passed to the inner convolution.
    stride: (usize, usize),
    /// Two-component padding passed to the inner convolution.
    padding: (usize, usize),
    /// Bias flag passed to the inner convolution's constructor.
    bias_enabled: bool,
    /// Empty until materialization; holds the concrete `Conv2d` afterwards.
    inner: OnceLock<Conv2d<T>>,
    /// Mode flag written by `train()` / `eval()` and read by `is_training()`.
    training: AtomicBool,
}
202
203impl<T: Float> LazyConv2d<T> {
204    /// Build a new `LazyConv2d`. `in_channels` will be discovered from
205    /// the first forward input (dim 1 of the `[B, C_in, H, W]` tensor).
206    pub fn new(
207        out_channels: usize,
208        kernel_size: (usize, usize),
209        stride: (usize, usize),
210        padding: (usize, usize),
211        bias: bool,
212    ) -> FerrotorchResult<Self> {
213        if out_channels == 0 {
214            return Err(FerrotorchError::InvalidArgument {
215                message: "LazyConv2d: out_channels must be > 0".into(),
216            });
217        }
218        if kernel_size.0 == 0 || kernel_size.1 == 0 {
219            return Err(FerrotorchError::InvalidArgument {
220                message: "LazyConv2d: kernel_size must be > 0 in both dimensions".into(),
221            });
222        }
223        if stride.0 == 0 || stride.1 == 0 {
224            return Err(FerrotorchError::InvalidArgument {
225                message: "LazyConv2d: stride must be > 0 in both dimensions".into(),
226            });
227        }
228        Ok(Self {
229            out_channels,
230            kernel_size,
231            stride,
232            padding,
233            bias_enabled: bias,
234            inner: OnceLock::new(),
235            training: AtomicBool::new(true),
236        })
237    }
238
239    pub fn is_initialized(&self) -> bool {
240        self.inner.get().is_some()
241    }
242
243    pub fn materialize(&self, in_channels: usize) -> FerrotorchResult<()> {
244        if in_channels == 0 {
245            return Err(FerrotorchError::InvalidArgument {
246                message: "LazyConv2d: in_channels must be > 0".into(),
247            });
248        }
249        if self.inner.get().is_none() {
250            let conv = Conv2d::new(
251                in_channels,
252                self.out_channels,
253                self.kernel_size,
254                self.stride,
255                self.padding,
256                self.bias_enabled,
257            )?;
258            let _ = self.inner.set(conv);
259        }
260        Ok(())
261    }
262}
263
264impl<T: Float> Module<T> for LazyConv2d<T> {
265    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
266        if input.ndim() != 4 {
267            return Err(FerrotorchError::InvalidArgument {
268                message: format!(
269                    "LazyConv2d expects 4-D input [B, C, H, W], got {:?}",
270                    input.shape()
271                ),
272            });
273        }
274        if self.inner.get().is_none() {
275            let in_channels = input.shape()[1];
276            self.materialize(in_channels)?;
277        }
278        let conv = self.inner.get().expect("initialized after materialize()");
279        conv.forward(input)
280    }
281
282    fn parameters(&self) -> Vec<&Parameter<T>> {
283        self.inner.get().map(|c| c.parameters()).unwrap_or_default()
284    }
285
286    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
287        self.inner
288            .get_mut()
289            .map(|c| c.parameters_mut())
290            .unwrap_or_default()
291    }
292
293    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
294        self.inner
295            .get()
296            .map(|c| c.named_parameters())
297            .unwrap_or_default()
298    }
299
300    fn train(&mut self) {
301        self.training.store(true, Ordering::Relaxed);
302        if let Some(c) = self.inner.get_mut() {
303            c.train();
304        }
305    }
306
307    fn eval(&mut self) {
308        self.training.store(false, Ordering::Relaxed);
309        if let Some(c) = self.inner.get_mut() {
310            c.eval();
311        }
312    }
313
314    fn is_training(&self) -> bool {
315        self.training.load(Ordering::Relaxed)
316    }
317}
318
319// ===========================================================================
320// LazyConv3d
321// ===========================================================================
322
/// 3-D convolution layer that defers `in_channels` discovery to the first
/// forward call. Mirrors `torch.nn.LazyConv3d`.
///
/// The configuration fields are captured at construction time and handed to
/// `Conv3d::new` when the layer is materialized.
#[derive(Debug)]
pub struct LazyConv3d<T: Float> {
    /// Number of output channels passed to the inner convolution.
    out_channels: usize,
    /// Three-component kernel size passed to the inner convolution.
    kernel_size: (usize, usize, usize),
    /// Three-component stride passed to the inner convolution.
    stride: (usize, usize, usize),
    /// Three-component padding passed to the inner convolution.
    padding: (usize, usize, usize),
    /// Bias flag passed to the inner convolution's constructor.
    bias_enabled: bool,
    /// Empty until materialization; holds the concrete `Conv3d` afterwards.
    inner: OnceLock<Conv3d<T>>,
    /// Mode flag written by `train()` / `eval()` and read by `is_training()`.
    training: AtomicBool,
}
335
336impl<T: Float> LazyConv3d<T> {
337    /// Build a new `LazyConv3d`. `in_channels` will be discovered from
338    /// the first forward input (dim 1 of the `[B, C_in, D, H, W]` tensor).
339    pub fn new(
340        out_channels: usize,
341        kernel_size: (usize, usize, usize),
342        stride: (usize, usize, usize),
343        padding: (usize, usize, usize),
344        bias: bool,
345    ) -> FerrotorchResult<Self> {
346        if out_channels == 0 {
347            return Err(FerrotorchError::InvalidArgument {
348                message: "LazyConv3d: out_channels must be > 0".into(),
349            });
350        }
351        if kernel_size.0 == 0 || kernel_size.1 == 0 || kernel_size.2 == 0 {
352            return Err(FerrotorchError::InvalidArgument {
353                message: "LazyConv3d: kernel_size must be > 0 in all dimensions".into(),
354            });
355        }
356        if stride.0 == 0 || stride.1 == 0 || stride.2 == 0 {
357            return Err(FerrotorchError::InvalidArgument {
358                message: "LazyConv3d: stride must be > 0 in all dimensions".into(),
359            });
360        }
361        Ok(Self {
362            out_channels,
363            kernel_size,
364            stride,
365            padding,
366            bias_enabled: bias,
367            inner: OnceLock::new(),
368            training: AtomicBool::new(true),
369        })
370    }
371
372    pub fn is_initialized(&self) -> bool {
373        self.inner.get().is_some()
374    }
375
376    pub fn materialize(&self, in_channels: usize) -> FerrotorchResult<()> {
377        if in_channels == 0 {
378            return Err(FerrotorchError::InvalidArgument {
379                message: "LazyConv3d: in_channels must be > 0".into(),
380            });
381        }
382        if self.inner.get().is_none() {
383            let conv = Conv3d::new(
384                in_channels,
385                self.out_channels,
386                self.kernel_size,
387                self.stride,
388                self.padding,
389                self.bias_enabled,
390            )?;
391            let _ = self.inner.set(conv);
392        }
393        Ok(())
394    }
395}
396
397impl<T: Float> Module<T> for LazyConv3d<T> {
398    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
399        if input.ndim() != 5 {
400            return Err(FerrotorchError::InvalidArgument {
401                message: format!(
402                    "LazyConv3d expects 5-D input [B, C, D, H, W], got {:?}",
403                    input.shape()
404                ),
405            });
406        }
407        if self.inner.get().is_none() {
408            let in_channels = input.shape()[1];
409            self.materialize(in_channels)?;
410        }
411        let conv = self.inner.get().expect("initialized after materialize()");
412        conv.forward(input)
413    }
414
415    fn parameters(&self) -> Vec<&Parameter<T>> {
416        self.inner.get().map(|c| c.parameters()).unwrap_or_default()
417    }
418
419    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
420        self.inner
421            .get_mut()
422            .map(|c| c.parameters_mut())
423            .unwrap_or_default()
424    }
425
426    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
427        self.inner
428            .get()
429            .map(|c| c.named_parameters())
430            .unwrap_or_default()
431    }
432
433    fn train(&mut self) {
434        self.training.store(true, Ordering::Relaxed);
435        if let Some(c) = self.inner.get_mut() {
436            c.train();
437        }
438    }
439
440    fn eval(&mut self) {
441        self.training.store(false, Ordering::Relaxed);
442        if let Some(c) = self.inner.get_mut() {
443            c.eval();
444        }
445    }
446
447    fn is_training(&self) -> bool {
448        self.training.load(Ordering::Relaxed)
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::Tensor;
    use ferrotorch_core::storage::TensorStorage;

    /// Builds a CPU `f32` tensor from `data` with the given `shape`.
    fn cpu_tensor(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
    }

    // -----------------------------------------------------------------------
    // LazyConv1d
    // -----------------------------------------------------------------------

    #[test]
    fn test_lazy_conv1d_uninitialized_until_first_forward() {
        let lazy: LazyConv1d<f32> = LazyConv1d::new(8, 3, 1, 0, true).unwrap();
        assert!(!lazy.is_initialized());
        assert_eq!(lazy.parameters().len(), 0);
    }

    #[test]
    fn test_lazy_conv1d_materializes_on_first_forward() {
        let lazy: LazyConv1d<f32> = LazyConv1d::new(4, 3, 1, 1, true).unwrap();
        // Input shape: [batch=1, C_in=2, L=5]
        let input = cpu_tensor(&(0..10).map(|i| i as f32).collect::<Vec<_>>(), &[1, 2, 5]);
        let out = lazy.forward(&input).unwrap();
        // [1, 4, 5] with padding=1, stride=1, kernel=3
        assert_eq!(out.shape()[0], 1);
        assert_eq!(out.shape()[1], 4);
        assert_eq!(out.shape()[2], 5);
        assert!(lazy.is_initialized());
        // weight + bias
        assert_eq!(lazy.parameters().len(), 2);
    }

    #[test]
    fn test_lazy_conv1d_rejects_wrong_input_ndim() {
        let lazy: LazyConv1d<f32> = LazyConv1d::new(2, 3, 1, 0, true).unwrap();
        let bad = cpu_tensor(&[1.0, 2.0, 3.0], &[3]);
        assert!(lazy.forward(&bad).is_err());
        // A rejected input must not trigger materialization.
        assert!(!lazy.is_initialized());
    }

    #[test]
    fn test_lazy_conv1d_explicit_materialize() {
        let lazy: LazyConv1d<f32> = LazyConv1d::new(8, 3, 1, 0, true).unwrap();
        lazy.materialize(16).unwrap();
        assert!(lazy.is_initialized());
        assert_eq!(lazy.parameters().len(), 2);
    }

    #[test]
    fn test_lazy_conv1d_materialize_twice_is_ok() {
        let lazy: LazyConv1d<f32> = LazyConv1d::new(8, 3, 1, 0, true).unwrap();
        lazy.materialize(16).unwrap();
        lazy.materialize(16).unwrap();
        assert!(lazy.is_initialized());
        assert_eq!(lazy.parameters().len(), 2);
    }

    #[test]
    fn test_lazy_conv1d_materialize_rejects_zero_in_channels() {
        let lazy: LazyConv1d<f32> = LazyConv1d::new(8, 3, 1, 0, true).unwrap();
        assert!(lazy.materialize(0).is_err());
        assert!(!lazy.is_initialized());
    }

    #[test]
    fn test_lazy_conv1d_zero_out_channels_errors() {
        assert!(LazyConv1d::<f32>::new(0, 3, 1, 0, true).is_err());
    }

    #[test]
    fn test_lazy_conv1d_zero_kernel_or_stride_errors() {
        assert!(LazyConv1d::<f32>::new(4, 0, 1, 0, true).is_err());
        assert!(LazyConv1d::<f32>::new(4, 3, 0, 0, true).is_err());
    }

    // -----------------------------------------------------------------------
    // LazyConv2d
    // -----------------------------------------------------------------------

    #[test]
    fn test_lazy_conv2d_uninitialized_until_first_forward() {
        let lazy: LazyConv2d<f32> = LazyConv2d::new(16, (3, 3), (1, 1), (1, 1), true).unwrap();
        assert!(!lazy.is_initialized());
        assert_eq!(lazy.parameters().len(), 0);
    }

    #[test]
    fn test_lazy_conv2d_materializes_on_first_forward() {
        let lazy: LazyConv2d<f32> = LazyConv2d::new(4, (3, 3), (1, 1), (1, 1), true).unwrap();
        // Input: [batch=1, C_in=3, H=4, W=4]
        let data: Vec<f32> = (0..48).map(|i| i as f32 / 10.0).collect();
        let input = cpu_tensor(&data, &[1, 3, 4, 4]);
        let out = lazy.forward(&input).unwrap();
        assert_eq!(out.shape()[0], 1);
        assert_eq!(out.shape()[1], 4);
        assert_eq!(out.shape()[2], 4); // padding keeps H
        assert_eq!(out.shape()[3], 4);
        assert!(lazy.is_initialized());
        assert_eq!(lazy.parameters().len(), 2);
    }

    #[test]
    fn test_lazy_conv2d_no_bias() {
        let lazy: LazyConv2d<f32> = LazyConv2d::new(2, (3, 3), (1, 1), (1, 1), false).unwrap();
        let data: Vec<f32> = (0..48).map(|i| i as f32).collect();
        let input = cpu_tensor(&data, &[1, 3, 4, 4]);
        let _ = lazy.forward(&input).unwrap();
        assert_eq!(lazy.parameters().len(), 1);
    }

    #[test]
    fn test_lazy_conv2d_subsequent_forward_reuses_inner() {
        let lazy: LazyConv2d<f32> = LazyConv2d::new(2, (3, 3), (1, 1), (1, 1), true).unwrap();
        let data: Vec<f32> = (0..48).map(|i| i as f32).collect();
        let input1 = cpu_tensor(&data, &[1, 3, 4, 4]);
        let out1 = lazy.forward(&input1).unwrap();

        // Snapshot weight pointer to verify the inner module is not
        // re-initialized on the second call.
        let first_weight_ptr = lazy.parameters()[0].tensor().data().unwrap().as_ptr();

        let input2 = cpu_tensor(&data, &[1, 3, 4, 4]);
        let out2 = lazy.forward(&input2).unwrap();
        let second_weight_ptr = lazy.parameters()[0].tensor().data().unwrap().as_ptr();
        assert_eq!(first_weight_ptr, second_weight_ptr);
        assert_eq!(out1.shape(), out2.shape());
    }

    #[test]
    fn test_lazy_conv2d_rejects_wrong_ndim() {
        let lazy: LazyConv2d<f32> = LazyConv2d::new(2, (3, 3), (1, 1), (1, 1), true).unwrap();
        let bad = cpu_tensor(&[1.0; 9], &[3, 3]);
        assert!(lazy.forward(&bad).is_err());
        // A rejected input must not trigger materialization.
        assert!(!lazy.is_initialized());
    }

    #[test]
    fn test_lazy_conv2d_train_eval_propagates_to_inner() {
        let mut lazy: LazyConv2d<f32> = LazyConv2d::new(2, (3, 3), (1, 1), (1, 1), true).unwrap();
        let data: Vec<f32> = (0..48).map(|i| i as f32).collect();
        let input = cpu_tensor(&data, &[1, 3, 4, 4]);
        let _ = lazy.forward(&input).unwrap();
        lazy.eval();
        assert!(!lazy.is_training());
        lazy.train();
        assert!(lazy.is_training());
    }

    #[test]
    fn test_lazy_conv2d_zero_kernel_or_stride_errors() {
        assert!(LazyConv2d::<f32>::new(2, (0, 3), (1, 1), (0, 0), true).is_err());
        assert!(LazyConv2d::<f32>::new(2, (3, 3), (1, 0), (0, 0), true).is_err());
    }

    #[test]
    fn test_lazy_conv2d_materialize_rejects_zero_in_channels() {
        let lazy: LazyConv2d<f32> = LazyConv2d::new(2, (3, 3), (1, 1), (1, 1), true).unwrap();
        assert!(lazy.materialize(0).is_err());
        assert!(!lazy.is_initialized());
    }

    // -----------------------------------------------------------------------
    // LazyConv3d
    // -----------------------------------------------------------------------

    #[test]
    fn test_lazy_conv3d_uninitialized_until_first_forward() {
        let lazy: LazyConv3d<f32> =
            LazyConv3d::new(4, (3, 3, 3), (1, 1, 1), (1, 1, 1), true).unwrap();
        assert!(!lazy.is_initialized());
        assert_eq!(lazy.parameters().len(), 0);
    }

    #[test]
    fn test_lazy_conv3d_materializes_on_first_forward() {
        let lazy: LazyConv3d<f32> =
            LazyConv3d::new(2, (3, 3, 3), (1, 1, 1), (1, 1, 1), true).unwrap();
        // Input: [batch=1, C_in=2, D=4, H=4, W=4]
        let data: Vec<f32> = (0..128).map(|i| i as f32 / 10.0).collect();
        let input = cpu_tensor(&data, &[1, 2, 4, 4, 4]);
        let out = lazy.forward(&input).unwrap();
        assert_eq!(out.shape()[0], 1);
        assert_eq!(out.shape()[1], 2);
        // padding=1 with kernel=3, stride=1 keeps D, H and W
        assert_eq!(out.shape()[2], 4);
        assert_eq!(out.shape()[3], 4);
        assert_eq!(out.shape()[4], 4);
        assert!(lazy.is_initialized());
        // weight + bias
        assert_eq!(lazy.parameters().len(), 2);
    }

    #[test]
    fn test_lazy_conv3d_rejects_wrong_ndim() {
        let lazy: LazyConv3d<f32> =
            LazyConv3d::new(2, (3, 3, 3), (1, 1, 1), (1, 1, 1), true).unwrap();
        let bad = cpu_tensor(&[0.0; 48], &[1, 3, 4, 4]);
        assert!(lazy.forward(&bad).is_err());
        // A rejected input must not trigger materialization.
        assert!(!lazy.is_initialized());
    }

    #[test]
    fn test_lazy_conv3d_zero_kernel_errors() {
        assert!(LazyConv3d::<f32>::new(2, (3, 0, 3), (1, 1, 1), (1, 1, 1), true).is_err());
    }

    #[test]
    fn test_lazy_conv3d_zero_stride_errors() {
        assert!(LazyConv3d::<f32>::new(2, (3, 3, 3), (1, 1, 0), (1, 1, 1), true).is_err());
    }

    #[test]
    fn test_lazy_conv3d_materialize_rejects_zero_in_channels() {
        let lazy: LazyConv3d<f32> =
            LazyConv3d::new(2, (3, 3, 3), (1, 1, 1), (1, 1, 1), true).unwrap();
        assert!(lazy.materialize(0).is_err());
        assert!(!lazy.is_initialized());
    }
}