Skip to main content

ferrotorch_nn/
lazy_conv_transpose.rs

//! Lazy variants of [`ConvTranspose1d`], [`ConvTranspose2d`], and [`ConvTranspose3d`]. (#622)
2//!
3//! `in_channels` is discovered from the input's channel dim on the first
4//! forward call; everything else is provided up front.
5//!
6//! ## REQ status (per `.design/ferrotorch-nn/lazy_conv_transpose.md`)
7//!
8//! | REQ | Status | Evidence |
9//! |---|---|---|
10//! | REQ-1 | SHIPPED | impl: `pub struct LazyConvTranspose1d<T: Float>` here; non-test consumer: `pub use lazy_conv_transpose::LazyConvTranspose1d` in `lib.rs`. |
11//! | REQ-2 | SHIPPED | impl: `pub struct LazyConvTranspose2d<T: Float>` here; non-test consumer: `pub use lazy_conv_transpose::LazyConvTranspose2d` in `lib.rs`. |
12//! | REQ-3 | SHIPPED | impl: `pub struct LazyConvTranspose3d<T: Float>` here; non-test consumer: `pub use lazy_conv_transpose::LazyConvTranspose3d` in `lib.rs`. |
13//! | REQ-4 | SHIPPED | impl: the `LazyConvTransposeNd::new(...)` constructor bodies (infallible) here; non-test consumer: dynamic-shape decoder pipelines instantiate via these constructors. |
14//! | REQ-5 | SHIPPED | impl: the `LazyConvTransposeNd::materialize(in_channels)` body constructing the inner `ConvTransposeNd::<T>::new(...)` here; non-test consumer: dynamic-shape decoder code calls `materialize(known_in_channels)`. |
//! | REQ-6 | SHIPPED | impl: `<LazyConvTransposeNd as Module>::forward` here (first-call rank check + channel discovery + materialize, then delegate to `inner`); non-test consumer: any U-Net-style decoder containing `LazyConvTranspose2d` runs this every training step. |
//! | REQ-7 | SHIPPED | impl: `Module<T>` impl block forwarding `parameters`, `parameters_mut`, `named_parameters`, `train`, and `eval` to `inner` here; non-test consumer: `ferrotorch_optim::Optimizer` walks `model.parameters_mut()`, which surfaces the inner `ConvTranspose`'s params after the first forward materializes. |
17//! | REQ-8 | SHIPPED | impl: the `LazyConvTransposeNd::is_initialized` accessor here; non-test consumer: training-loop setup code querying initialization state. |
18
19use std::sync::OnceLock;
20use std::sync::atomic::{AtomicBool, Ordering};
21
22use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
23
24use crate::conv::{ConvTranspose1d, ConvTranspose2d, ConvTranspose3d};
25use crate::module::Module;
26use crate::parameter::Parameter;
27
28fn channels_from_input<T: Float>(
29    input: &Tensor<T>,
30    op: &str,
31    expected_ndim: usize,
32) -> FerrotorchResult<usize> {
33    if input.ndim() != expected_ndim {
34        return Err(FerrotorchError::ShapeMismatch {
35            message: format!(
36                "{op}: expected {expected_ndim}-D input [N, C, ...], got {}-D",
37                input.ndim()
38            ),
39        });
40    }
41    Ok(input.shape()[1])
42}
43
44/// Lazy 1-D transposed convolution. `in_channels` discovered at first forward.
45#[derive(Debug)]
46pub struct LazyConvTranspose1d<T: Float> {
47    out_channels: usize,
48    kernel_size: usize,
49    stride: usize,
50    padding: usize,
51    output_padding: usize,
52    bias_enabled: bool,
53    inner: OnceLock<ConvTranspose1d<T>>,
54    training: AtomicBool,
55}
56
57impl<T: Float> LazyConvTranspose1d<T> {
58    pub fn new(
59        out_channels: usize,
60        kernel_size: usize,
61        stride: usize,
62        padding: usize,
63        output_padding: usize,
64        bias: bool,
65    ) -> Self {
66        Self {
67            out_channels,
68            kernel_size,
69            stride,
70            padding,
71            output_padding,
72            bias_enabled: bias,
73            inner: OnceLock::new(),
74            training: AtomicBool::new(true),
75        }
76    }
77
78    pub fn is_initialized(&self) -> bool {
79        self.inner.get().is_some()
80    }
81
82    pub fn materialize(&self, in_channels: usize) -> FerrotorchResult<()> {
83        if self.inner.get().is_none() {
84            let inner = ConvTranspose1d::<T>::new(
85                in_channels,
86                self.out_channels,
87                self.kernel_size,
88                self.stride,
89                self.padding,
90                self.output_padding,
91                self.bias_enabled,
92            )?;
93            let _ = self.inner.set(inner);
94        }
95        Ok(())
96    }
97}
98
99impl<T: Float> Module<T> for LazyConvTranspose1d<T> {
100    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
101        if self.inner.get().is_none() {
102            let c = channels_from_input(input, "LazyConvTranspose1d", 3)?;
103            self.materialize(c)?;
104        }
105        self.inner.get().expect("inner").forward(input)
106    }
107
108    fn parameters(&self) -> Vec<&Parameter<T>> {
109        self.inner.get().map(|m| m.parameters()).unwrap_or_default()
110    }
111
112    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
113        self.inner
114            .get_mut()
115            .map(|m| m.parameters_mut())
116            .unwrap_or_default()
117    }
118
119    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
120        self.inner
121            .get()
122            .map(|m| m.named_parameters())
123            .unwrap_or_default()
124    }
125
126    fn train(&mut self) {
127        self.training.store(true, Ordering::Relaxed);
128        if let Some(m) = self.inner.get_mut() {
129            m.train();
130        }
131    }
132
133    fn eval(&mut self) {
134        self.training.store(false, Ordering::Relaxed);
135        if let Some(m) = self.inner.get_mut() {
136            m.eval();
137        }
138    }
139
140    fn is_training(&self) -> bool {
141        self.training.load(Ordering::Relaxed)
142    }
143}
144
145/// Lazy 2-D transposed convolution.
146#[derive(Debug)]
147pub struct LazyConvTranspose2d<T: Float> {
148    out_channels: usize,
149    kernel_size: (usize, usize),
150    stride: (usize, usize),
151    padding: (usize, usize),
152    output_padding: (usize, usize),
153    bias_enabled: bool,
154    inner: OnceLock<ConvTranspose2d<T>>,
155    training: AtomicBool,
156}
157
158impl<T: Float> LazyConvTranspose2d<T> {
159    pub fn new(
160        out_channels: usize,
161        kernel_size: (usize, usize),
162        stride: (usize, usize),
163        padding: (usize, usize),
164        output_padding: (usize, usize),
165        bias: bool,
166    ) -> Self {
167        Self {
168            out_channels,
169            kernel_size,
170            stride,
171            padding,
172            output_padding,
173            bias_enabled: bias,
174            inner: OnceLock::new(),
175            training: AtomicBool::new(true),
176        }
177    }
178
179    pub fn is_initialized(&self) -> bool {
180        self.inner.get().is_some()
181    }
182
183    pub fn materialize(&self, in_channels: usize) -> FerrotorchResult<()> {
184        if self.inner.get().is_none() {
185            let inner = ConvTranspose2d::<T>::new(
186                in_channels,
187                self.out_channels,
188                self.kernel_size,
189                self.stride,
190                self.padding,
191                self.output_padding,
192                self.bias_enabled,
193            )?;
194            let _ = self.inner.set(inner);
195        }
196        Ok(())
197    }
198}
199
200impl<T: Float> Module<T> for LazyConvTranspose2d<T> {
201    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
202        if self.inner.get().is_none() {
203            let c = channels_from_input(input, "LazyConvTranspose2d", 4)?;
204            self.materialize(c)?;
205        }
206        self.inner.get().expect("inner").forward(input)
207    }
208
209    fn parameters(&self) -> Vec<&Parameter<T>> {
210        self.inner.get().map(|m| m.parameters()).unwrap_or_default()
211    }
212
213    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
214        self.inner
215            .get_mut()
216            .map(|m| m.parameters_mut())
217            .unwrap_or_default()
218    }
219
220    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
221        self.inner
222            .get()
223            .map(|m| m.named_parameters())
224            .unwrap_or_default()
225    }
226
227    fn train(&mut self) {
228        self.training.store(true, Ordering::Relaxed);
229        if let Some(m) = self.inner.get_mut() {
230            m.train();
231        }
232    }
233
234    fn eval(&mut self) {
235        self.training.store(false, Ordering::Relaxed);
236        if let Some(m) = self.inner.get_mut() {
237            m.eval();
238        }
239    }
240
241    fn is_training(&self) -> bool {
242        self.training.load(Ordering::Relaxed)
243    }
244}
245
246/// Lazy 3-D transposed convolution.
247#[derive(Debug)]
248pub struct LazyConvTranspose3d<T: Float> {
249    out_channels: usize,
250    kernel_size: (usize, usize, usize),
251    stride: (usize, usize, usize),
252    padding: (usize, usize, usize),
253    output_padding: (usize, usize, usize),
254    bias_enabled: bool,
255    inner: OnceLock<ConvTranspose3d<T>>,
256    training: AtomicBool,
257}
258
259impl<T: Float> LazyConvTranspose3d<T> {
260    pub fn new(
261        out_channels: usize,
262        kernel_size: (usize, usize, usize),
263        stride: (usize, usize, usize),
264        padding: (usize, usize, usize),
265        output_padding: (usize, usize, usize),
266        bias: bool,
267    ) -> Self {
268        Self {
269            out_channels,
270            kernel_size,
271            stride,
272            padding,
273            output_padding,
274            bias_enabled: bias,
275            inner: OnceLock::new(),
276            training: AtomicBool::new(true),
277        }
278    }
279
280    pub fn is_initialized(&self) -> bool {
281        self.inner.get().is_some()
282    }
283
284    pub fn materialize(&self, in_channels: usize) -> FerrotorchResult<()> {
285        if self.inner.get().is_none() {
286            let inner = ConvTranspose3d::<T>::new(
287                in_channels,
288                self.out_channels,
289                self.kernel_size,
290                self.stride,
291                self.padding,
292                self.output_padding,
293                self.bias_enabled,
294            )?;
295            let _ = self.inner.set(inner);
296        }
297        Ok(())
298    }
299}
300
301impl<T: Float> Module<T> for LazyConvTranspose3d<T> {
302    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
303        if self.inner.get().is_none() {
304            let c = channels_from_input(input, "LazyConvTranspose3d", 5)?;
305            self.materialize(c)?;
306        }
307        self.inner.get().expect("inner").forward(input)
308    }
309
310    fn parameters(&self) -> Vec<&Parameter<T>> {
311        self.inner.get().map(|m| m.parameters()).unwrap_or_default()
312    }
313
314    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
315        self.inner
316            .get_mut()
317            .map(|m| m.parameters_mut())
318            .unwrap_or_default()
319    }
320
321    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
322        self.inner
323            .get()
324            .map(|m| m.named_parameters())
325            .unwrap_or_default()
326    }
327
328    fn train(&mut self) {
329        self.training.store(true, Ordering::Relaxed);
330        if let Some(m) = self.inner.get_mut() {
331            m.train();
332        }
333    }
334
335    fn eval(&mut self) {
336        self.training.store(false, Ordering::Relaxed);
337        if let Some(m) = self.inner.get_mut() {
338            m.eval();
339        }
340    }
341
342    fn is_training(&self) -> bool {
343        self.training.load(Ordering::Relaxed)
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::storage::TensorStorage;

    /// Build a CPU `f32` tensor with the given data and shape.
    fn cpu_tensor(data: Vec<f32>, shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data), shape.to_vec(), false).unwrap()
    }

    #[test]
    fn lazy_conv_transpose2d_explicit_materialize() {
        let m: LazyConvTranspose2d<f32> =
            LazyConvTranspose2d::new(8, (3, 3), (1, 1), (1, 1), (0, 0), true);
        assert!(!m.is_initialized());
        m.materialize(4).unwrap();
        assert!(m.is_initialized());
        // weight + bias = 2 params
        assert_eq!(m.parameters().len(), 2);
    }

    #[test]
    fn lazy_conv_transpose2d_has_no_parameters_before_materialize() {
        let m: LazyConvTranspose2d<f32> =
            LazyConvTranspose2d::new(8, (3, 3), (1, 1), (1, 1), (0, 0), true);
        assert!(m.parameters().is_empty());
        assert!(m.named_parameters().is_empty());
    }

    #[test]
    fn lazy_conv_transpose2d_materialize_is_idempotent() {
        let m: LazyConvTranspose2d<f32> =
            LazyConvTranspose2d::new(8, (3, 3), (1, 1), (1, 1), (0, 0), true);
        m.materialize(4).unwrap();
        // A second call must not rebuild or duplicate the inner layer.
        m.materialize(4).unwrap();
        assert!(m.is_initialized());
        assert_eq!(m.parameters().len(), 2);
    }

    #[test]
    fn lazy_conv_transpose1d_rejects_wrong_rank() {
        let m: LazyConvTranspose1d<f32> = LazyConvTranspose1d::new(4, 3, 1, 0, 0, true);
        let input = cpu_tensor(vec![1.0, 2.0], &[2]);
        let err = m.forward(&input).unwrap_err();
        assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
    }

    #[test]
    fn lazy_conv_transpose2d_rejects_wrong_rank_and_stays_uninitialized() {
        let m: LazyConvTranspose2d<f32> =
            LazyConvTranspose2d::new(4, (3, 3), (1, 1), (0, 0), (0, 0), true);
        let input = cpu_tensor(vec![0.0; 6], &[1, 2, 3]);
        let err = m.forward(&input).unwrap_err();
        assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
        // A rejected input must not materialize the layer.
        assert!(!m.is_initialized());
    }

    #[test]
    fn lazy_conv_transpose3d_rejects_wrong_rank_and_stays_uninitialized() {
        let m: LazyConvTranspose3d<f32> =
            LazyConvTranspose3d::new(2, (2, 2, 2), (1, 1, 1), (0, 0, 0), (0, 0, 0), false);
        let input = cpu_tensor(vec![0.0; 8], &[1, 2, 2, 2]);
        let err = m.forward(&input).unwrap_err();
        assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
        assert!(!m.is_initialized());
    }

    #[test]
    fn lazy_conv_transpose3d_explicit_materialize() {
        let m: LazyConvTranspose3d<f32> =
            LazyConvTranspose3d::new(2, (2, 2, 2), (1, 1, 1), (0, 0, 0), (0, 0, 0), false);
        m.materialize(3).unwrap();
        assert!(m.is_initialized());
    }

    #[test]
    fn lazy_conv_transpose_train_eval_toggle() {
        let mut m: LazyConvTranspose2d<f32> =
            LazyConvTranspose2d::new(4, (3, 3), (1, 1), (0, 0), (0, 0), true);
        assert!(m.is_training());
        m.eval();
        assert!(!m.is_training());
        m.train();
        assert!(m.is_training());
    }
}