1use std::sync::OnceLock;
20use std::sync::atomic::{AtomicBool, Ordering};
21
22use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
23
24use crate::conv::{ConvTranspose1d, ConvTranspose2d, ConvTranspose3d};
25use crate::module::Module;
26use crate::parameter::Parameter;
27
28fn channels_from_input<T: Float>(
29 input: &Tensor<T>,
30 op: &str,
31 expected_ndim: usize,
32) -> FerrotorchResult<usize> {
33 if input.ndim() != expected_ndim {
34 return Err(FerrotorchError::ShapeMismatch {
35 message: format!(
36 "{op}: expected {expected_ndim}-D input [N, C, ...], got {}-D",
37 input.ndim()
38 ),
39 });
40 }
41 Ok(input.shape()[1])
42}
43
44#[derive(Debug)]
46pub struct LazyConvTranspose1d<T: Float> {
47 out_channels: usize,
48 kernel_size: usize,
49 stride: usize,
50 padding: usize,
51 output_padding: usize,
52 bias_enabled: bool,
53 inner: OnceLock<ConvTranspose1d<T>>,
54 training: AtomicBool,
55}
56
57impl<T: Float> LazyConvTranspose1d<T> {
58 pub fn new(
59 out_channels: usize,
60 kernel_size: usize,
61 stride: usize,
62 padding: usize,
63 output_padding: usize,
64 bias: bool,
65 ) -> Self {
66 Self {
67 out_channels,
68 kernel_size,
69 stride,
70 padding,
71 output_padding,
72 bias_enabled: bias,
73 inner: OnceLock::new(),
74 training: AtomicBool::new(true),
75 }
76 }
77
78 pub fn is_initialized(&self) -> bool {
79 self.inner.get().is_some()
80 }
81
82 pub fn materialize(&self, in_channels: usize) -> FerrotorchResult<()> {
83 if self.inner.get().is_none() {
84 let inner = ConvTranspose1d::<T>::new(
85 in_channels,
86 self.out_channels,
87 self.kernel_size,
88 self.stride,
89 self.padding,
90 self.output_padding,
91 self.bias_enabled,
92 )?;
93 let _ = self.inner.set(inner);
94 }
95 Ok(())
96 }
97}
98
99impl<T: Float> Module<T> for LazyConvTranspose1d<T> {
100 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
101 if self.inner.get().is_none() {
102 let c = channels_from_input(input, "LazyConvTranspose1d", 3)?;
103 self.materialize(c)?;
104 }
105 self.inner.get().expect("inner").forward(input)
106 }
107
108 fn parameters(&self) -> Vec<&Parameter<T>> {
109 self.inner.get().map(|m| m.parameters()).unwrap_or_default()
110 }
111
112 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
113 self.inner
114 .get_mut()
115 .map(|m| m.parameters_mut())
116 .unwrap_or_default()
117 }
118
119 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
120 self.inner
121 .get()
122 .map(|m| m.named_parameters())
123 .unwrap_or_default()
124 }
125
126 fn train(&mut self) {
127 self.training.store(true, Ordering::Relaxed);
128 if let Some(m) = self.inner.get_mut() {
129 m.train();
130 }
131 }
132
133 fn eval(&mut self) {
134 self.training.store(false, Ordering::Relaxed);
135 if let Some(m) = self.inner.get_mut() {
136 m.eval();
137 }
138 }
139
140 fn is_training(&self) -> bool {
141 self.training.load(Ordering::Relaxed)
142 }
143}
144
145#[derive(Debug)]
147pub struct LazyConvTranspose2d<T: Float> {
148 out_channels: usize,
149 kernel_size: (usize, usize),
150 stride: (usize, usize),
151 padding: (usize, usize),
152 output_padding: (usize, usize),
153 bias_enabled: bool,
154 inner: OnceLock<ConvTranspose2d<T>>,
155 training: AtomicBool,
156}
157
158impl<T: Float> LazyConvTranspose2d<T> {
159 pub fn new(
160 out_channels: usize,
161 kernel_size: (usize, usize),
162 stride: (usize, usize),
163 padding: (usize, usize),
164 output_padding: (usize, usize),
165 bias: bool,
166 ) -> Self {
167 Self {
168 out_channels,
169 kernel_size,
170 stride,
171 padding,
172 output_padding,
173 bias_enabled: bias,
174 inner: OnceLock::new(),
175 training: AtomicBool::new(true),
176 }
177 }
178
179 pub fn is_initialized(&self) -> bool {
180 self.inner.get().is_some()
181 }
182
183 pub fn materialize(&self, in_channels: usize) -> FerrotorchResult<()> {
184 if self.inner.get().is_none() {
185 let inner = ConvTranspose2d::<T>::new(
186 in_channels,
187 self.out_channels,
188 self.kernel_size,
189 self.stride,
190 self.padding,
191 self.output_padding,
192 self.bias_enabled,
193 )?;
194 let _ = self.inner.set(inner);
195 }
196 Ok(())
197 }
198}
199
200impl<T: Float> Module<T> for LazyConvTranspose2d<T> {
201 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
202 if self.inner.get().is_none() {
203 let c = channels_from_input(input, "LazyConvTranspose2d", 4)?;
204 self.materialize(c)?;
205 }
206 self.inner.get().expect("inner").forward(input)
207 }
208
209 fn parameters(&self) -> Vec<&Parameter<T>> {
210 self.inner.get().map(|m| m.parameters()).unwrap_or_default()
211 }
212
213 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
214 self.inner
215 .get_mut()
216 .map(|m| m.parameters_mut())
217 .unwrap_or_default()
218 }
219
220 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
221 self.inner
222 .get()
223 .map(|m| m.named_parameters())
224 .unwrap_or_default()
225 }
226
227 fn train(&mut self) {
228 self.training.store(true, Ordering::Relaxed);
229 if let Some(m) = self.inner.get_mut() {
230 m.train();
231 }
232 }
233
234 fn eval(&mut self) {
235 self.training.store(false, Ordering::Relaxed);
236 if let Some(m) = self.inner.get_mut() {
237 m.eval();
238 }
239 }
240
241 fn is_training(&self) -> bool {
242 self.training.load(Ordering::Relaxed)
243 }
244}
245
246#[derive(Debug)]
248pub struct LazyConvTranspose3d<T: Float> {
249 out_channels: usize,
250 kernel_size: (usize, usize, usize),
251 stride: (usize, usize, usize),
252 padding: (usize, usize, usize),
253 output_padding: (usize, usize, usize),
254 bias_enabled: bool,
255 inner: OnceLock<ConvTranspose3d<T>>,
256 training: AtomicBool,
257}
258
259impl<T: Float> LazyConvTranspose3d<T> {
260 pub fn new(
261 out_channels: usize,
262 kernel_size: (usize, usize, usize),
263 stride: (usize, usize, usize),
264 padding: (usize, usize, usize),
265 output_padding: (usize, usize, usize),
266 bias: bool,
267 ) -> Self {
268 Self {
269 out_channels,
270 kernel_size,
271 stride,
272 padding,
273 output_padding,
274 bias_enabled: bias,
275 inner: OnceLock::new(),
276 training: AtomicBool::new(true),
277 }
278 }
279
280 pub fn is_initialized(&self) -> bool {
281 self.inner.get().is_some()
282 }
283
284 pub fn materialize(&self, in_channels: usize) -> FerrotorchResult<()> {
285 if self.inner.get().is_none() {
286 let inner = ConvTranspose3d::<T>::new(
287 in_channels,
288 self.out_channels,
289 self.kernel_size,
290 self.stride,
291 self.padding,
292 self.output_padding,
293 self.bias_enabled,
294 )?;
295 let _ = self.inner.set(inner);
296 }
297 Ok(())
298 }
299}
300
301impl<T: Float> Module<T> for LazyConvTranspose3d<T> {
302 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
303 if self.inner.get().is_none() {
304 let c = channels_from_input(input, "LazyConvTranspose3d", 5)?;
305 self.materialize(c)?;
306 }
307 self.inner.get().expect("inner").forward(input)
308 }
309
310 fn parameters(&self) -> Vec<&Parameter<T>> {
311 self.inner.get().map(|m| m.parameters()).unwrap_or_default()
312 }
313
314 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
315 self.inner
316 .get_mut()
317 .map(|m| m.parameters_mut())
318 .unwrap_or_default()
319 }
320
321 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
322 self.inner
323 .get()
324 .map(|m| m.named_parameters())
325 .unwrap_or_default()
326 }
327
328 fn train(&mut self) {
329 self.training.store(true, Ordering::Relaxed);
330 if let Some(m) = self.inner.get_mut() {
331 m.train();
332 }
333 }
334
335 fn eval(&mut self) {
336 self.training.store(false, Ordering::Relaxed);
337 if let Some(m) = self.inner.get_mut() {
338 m.eval();
339 }
340 }
341
342 fn is_training(&self) -> bool {
343 self.training.load(Ordering::Relaxed)
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::storage::TensorStorage;

    /// Builds a CPU `f32` tensor that does not require gradients.
    fn cpu_tensor(data: Vec<f32>, shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data), shape.to_vec(), false).unwrap()
    }

    #[test]
    fn lazy_conv_transpose2d_explicit_materialize() {
        let m: LazyConvTranspose2d<f32> =
            LazyConvTranspose2d::new(8, (3, 3), (1, 1), (1, 1), (0, 0), true);
        assert!(!m.is_initialized());
        m.materialize(4).unwrap();
        assert!(m.is_initialized());
        // Weight + bias.
        assert_eq!(m.parameters().len(), 2);
    }

    #[test]
    fn lazy_conv_transpose_has_no_parameters_before_materialize() {
        let m: LazyConvTranspose1d<f32> = LazyConvTranspose1d::new(4, 3, 1, 0, 0, true);
        assert!(m.parameters().is_empty());
        assert!(m.named_parameters().is_empty());
    }

    #[test]
    fn lazy_conv_transpose1d_forward_infers_in_channels() {
        let m: LazyConvTranspose1d<f32> = LazyConvTranspose1d::new(4, 3, 1, 0, 0, true);
        let input = cpu_tensor(vec![0.5; 3 * 5], &[1, 3, 5]);
        let out = m.forward(&input).unwrap();
        assert!(m.is_initialized());
        // L_out = (L_in - 1) * stride - 2 * padding + kernel + output_padding = 4 + 3.
        assert_eq!(out.shape(), &[1, 4, 7]);
    }

    #[test]
    fn lazy_conv_transpose1d_rejects_wrong_rank() {
        let m: LazyConvTranspose1d<f32> = LazyConvTranspose1d::new(4, 3, 1, 0, 0, true);
        let input = cpu_tensor(vec![1.0, 2.0], &[2]);
        let err = m.forward(&input).unwrap_err();
        assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
        // A rejected input must not leave a half-configured layer behind.
        assert!(!m.is_initialized());
    }

    #[test]
    fn lazy_conv_transpose3d_explicit_materialize() {
        let m: LazyConvTranspose3d<f32> =
            LazyConvTranspose3d::new(2, (2, 2, 2), (1, 1, 1), (0, 0, 0), (0, 0, 0), false);
        m.materialize(3).unwrap();
        assert!(m.is_initialized());
    }

    #[test]
    fn lazy_conv_transpose_train_eval_toggle() {
        let mut m: LazyConvTranspose2d<f32> =
            LazyConvTranspose2d::new(4, (3, 3), (1, 1), (0, 0), (0, 0), true);
        assert!(m.is_training());
        m.eval();
        assert!(!m.is_training());
        m.train();
        assert!(m.is_training());
    }
}