1use std::sync::OnceLock;
42use std::sync::atomic::{AtomicBool, Ordering};
43
44use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
45
46use crate::conv::{Conv1d, Conv2d, Conv3d};
47use crate::module::Module;
48use crate::parameter::Parameter;
49
/// A 1-D convolution whose input channel count is supplied after construction.
///
/// The wrapped [`Conv1d`] is built at most once: either by an explicit call to
/// `materialize`, or by the first `forward` call, which reads the channel count
/// from dimension 1 of its input.
#[derive(Debug)]
pub struct LazyConv1d<T: Float> {
    /// Output channel count forwarded to the inner convolution.
    out_channels: usize,
    /// Kernel length forwarded to the inner convolution.
    kernel_size: usize,
    /// Stride forwarded to the inner convolution.
    stride: usize,
    /// Padding forwarded to the inner convolution.
    padding: usize,
    /// Whether the inner convolution is created with a bias term.
    bias_enabled: bool,
    /// The concrete convolution; empty until the layer is materialized.
    inner: OnceLock<Conv1d<T>>,
    /// Training-mode flag reported by `is_training`.
    training: AtomicBool,
}
66
67impl<T: Float> LazyConv1d<T> {
68 pub fn new(
71 out_channels: usize,
72 kernel_size: usize,
73 stride: usize,
74 padding: usize,
75 bias: bool,
76 ) -> FerrotorchResult<Self> {
77 if out_channels == 0 {
78 return Err(FerrotorchError::InvalidArgument {
79 message: "LazyConv1d: out_channels must be > 0".into(),
80 });
81 }
82 if kernel_size == 0 {
83 return Err(FerrotorchError::InvalidArgument {
84 message: "LazyConv1d: kernel_size must be > 0".into(),
85 });
86 }
87 if stride == 0 {
88 return Err(FerrotorchError::InvalidArgument {
89 message: "LazyConv1d: stride must be > 0".into(),
90 });
91 }
92 Ok(Self {
93 out_channels,
94 kernel_size,
95 stride,
96 padding,
97 bias_enabled: bias,
98 inner: OnceLock::new(),
99 training: AtomicBool::new(true),
100 })
101 }
102
103 pub fn is_initialized(&self) -> bool {
106 self.inner.get().is_some()
107 }
108
109 pub fn materialize(&self, in_channels: usize) -> FerrotorchResult<()> {
111 if in_channels == 0 {
112 return Err(FerrotorchError::InvalidArgument {
113 message: "LazyConv1d: in_channels must be > 0".into(),
114 });
115 }
116 if self.inner.get().is_none() {
117 let conv = Conv1d::new(
118 in_channels,
119 self.out_channels,
120 self.kernel_size,
121 self.stride,
122 self.padding,
123 self.bias_enabled,
124 )?;
125 let _ = self.inner.set(conv);
126 }
127 Ok(())
128 }
129}
130
131impl<T: Float> Module<T> for LazyConv1d<T> {
132 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
133 if input.ndim() != 3 {
134 return Err(FerrotorchError::InvalidArgument {
135 message: format!(
136 "LazyConv1d expects 3-D input [B, C, L], got {:?}",
137 input.shape()
138 ),
139 });
140 }
141 if self.inner.get().is_none() {
142 let in_channels = input.shape()[1];
143 self.materialize(in_channels)?;
144 }
145 let conv = self.inner.get().expect("initialized after materialize()");
146 conv.forward(input)
147 }
148
149 fn parameters(&self) -> Vec<&Parameter<T>> {
150 self.inner.get().map(|c| c.parameters()).unwrap_or_default()
151 }
152
153 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
154 self.inner
155 .get_mut()
156 .map(|c| c.parameters_mut())
157 .unwrap_or_default()
158 }
159
160 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
161 self.inner
162 .get()
163 .map(|c| c.named_parameters())
164 .unwrap_or_default()
165 }
166
167 fn train(&mut self) {
168 self.training.store(true, Ordering::Relaxed);
169 if let Some(c) = self.inner.get_mut() {
170 c.train();
171 }
172 }
173
174 fn eval(&mut self) {
175 self.training.store(false, Ordering::Relaxed);
176 if let Some(c) = self.inner.get_mut() {
177 c.eval();
178 }
179 }
180
181 fn is_training(&self) -> bool {
182 self.training.load(Ordering::Relaxed)
183 }
184}
185
/// A 2-D convolution whose input channel count is supplied after construction.
///
/// The wrapped [`Conv2d`] is built at most once: either by an explicit call to
/// `materialize`, or by the first `forward` call, which reads the channel count
/// from dimension 1 of its input.
#[derive(Debug)]
pub struct LazyConv2d<T: Float> {
    /// Output channel count forwarded to the inner convolution.
    out_channels: usize,
    /// Two-component kernel size forwarded to the inner convolution.
    kernel_size: (usize, usize),
    /// Two-component stride forwarded to the inner convolution.
    stride: (usize, usize),
    /// Two-component padding forwarded to the inner convolution.
    padding: (usize, usize),
    /// Whether the inner convolution is created with a bias term.
    bias_enabled: bool,
    /// The concrete convolution; empty until the layer is materialized.
    inner: OnceLock<Conv2d<T>>,
    /// Training-mode flag reported by `is_training`.
    training: AtomicBool,
}
202
203impl<T: Float> LazyConv2d<T> {
204 pub fn new(
207 out_channels: usize,
208 kernel_size: (usize, usize),
209 stride: (usize, usize),
210 padding: (usize, usize),
211 bias: bool,
212 ) -> FerrotorchResult<Self> {
213 if out_channels == 0 {
214 return Err(FerrotorchError::InvalidArgument {
215 message: "LazyConv2d: out_channels must be > 0".into(),
216 });
217 }
218 if kernel_size.0 == 0 || kernel_size.1 == 0 {
219 return Err(FerrotorchError::InvalidArgument {
220 message: "LazyConv2d: kernel_size must be > 0 in both dimensions".into(),
221 });
222 }
223 if stride.0 == 0 || stride.1 == 0 {
224 return Err(FerrotorchError::InvalidArgument {
225 message: "LazyConv2d: stride must be > 0 in both dimensions".into(),
226 });
227 }
228 Ok(Self {
229 out_channels,
230 kernel_size,
231 stride,
232 padding,
233 bias_enabled: bias,
234 inner: OnceLock::new(),
235 training: AtomicBool::new(true),
236 })
237 }
238
239 pub fn is_initialized(&self) -> bool {
240 self.inner.get().is_some()
241 }
242
243 pub fn materialize(&self, in_channels: usize) -> FerrotorchResult<()> {
244 if in_channels == 0 {
245 return Err(FerrotorchError::InvalidArgument {
246 message: "LazyConv2d: in_channels must be > 0".into(),
247 });
248 }
249 if self.inner.get().is_none() {
250 let conv = Conv2d::new(
251 in_channels,
252 self.out_channels,
253 self.kernel_size,
254 self.stride,
255 self.padding,
256 self.bias_enabled,
257 )?;
258 let _ = self.inner.set(conv);
259 }
260 Ok(())
261 }
262}
263
264impl<T: Float> Module<T> for LazyConv2d<T> {
265 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
266 if input.ndim() != 4 {
267 return Err(FerrotorchError::InvalidArgument {
268 message: format!(
269 "LazyConv2d expects 4-D input [B, C, H, W], got {:?}",
270 input.shape()
271 ),
272 });
273 }
274 if self.inner.get().is_none() {
275 let in_channels = input.shape()[1];
276 self.materialize(in_channels)?;
277 }
278 let conv = self.inner.get().expect("initialized after materialize()");
279 conv.forward(input)
280 }
281
282 fn parameters(&self) -> Vec<&Parameter<T>> {
283 self.inner.get().map(|c| c.parameters()).unwrap_or_default()
284 }
285
286 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
287 self.inner
288 .get_mut()
289 .map(|c| c.parameters_mut())
290 .unwrap_or_default()
291 }
292
293 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
294 self.inner
295 .get()
296 .map(|c| c.named_parameters())
297 .unwrap_or_default()
298 }
299
300 fn train(&mut self) {
301 self.training.store(true, Ordering::Relaxed);
302 if let Some(c) = self.inner.get_mut() {
303 c.train();
304 }
305 }
306
307 fn eval(&mut self) {
308 self.training.store(false, Ordering::Relaxed);
309 if let Some(c) = self.inner.get_mut() {
310 c.eval();
311 }
312 }
313
314 fn is_training(&self) -> bool {
315 self.training.load(Ordering::Relaxed)
316 }
317}
318
/// A 3-D convolution whose input channel count is supplied after construction.
///
/// The wrapped [`Conv3d`] is built at most once: either by an explicit call to
/// `materialize`, or by the first `forward` call, which reads the channel count
/// from dimension 1 of its input.
#[derive(Debug)]
pub struct LazyConv3d<T: Float> {
    /// Output channel count forwarded to the inner convolution.
    out_channels: usize,
    /// Three-component kernel size forwarded to the inner convolution.
    kernel_size: (usize, usize, usize),
    /// Three-component stride forwarded to the inner convolution.
    stride: (usize, usize, usize),
    /// Three-component padding forwarded to the inner convolution.
    padding: (usize, usize, usize),
    /// Whether the inner convolution is created with a bias term.
    bias_enabled: bool,
    /// The concrete convolution; empty until the layer is materialized.
    inner: OnceLock<Conv3d<T>>,
    /// Training-mode flag reported by `is_training`.
    training: AtomicBool,
}
335
336impl<T: Float> LazyConv3d<T> {
337 pub fn new(
340 out_channels: usize,
341 kernel_size: (usize, usize, usize),
342 stride: (usize, usize, usize),
343 padding: (usize, usize, usize),
344 bias: bool,
345 ) -> FerrotorchResult<Self> {
346 if out_channels == 0 {
347 return Err(FerrotorchError::InvalidArgument {
348 message: "LazyConv3d: out_channels must be > 0".into(),
349 });
350 }
351 if kernel_size.0 == 0 || kernel_size.1 == 0 || kernel_size.2 == 0 {
352 return Err(FerrotorchError::InvalidArgument {
353 message: "LazyConv3d: kernel_size must be > 0 in all dimensions".into(),
354 });
355 }
356 if stride.0 == 0 || stride.1 == 0 || stride.2 == 0 {
357 return Err(FerrotorchError::InvalidArgument {
358 message: "LazyConv3d: stride must be > 0 in all dimensions".into(),
359 });
360 }
361 Ok(Self {
362 out_channels,
363 kernel_size,
364 stride,
365 padding,
366 bias_enabled: bias,
367 inner: OnceLock::new(),
368 training: AtomicBool::new(true),
369 })
370 }
371
372 pub fn is_initialized(&self) -> bool {
373 self.inner.get().is_some()
374 }
375
376 pub fn materialize(&self, in_channels: usize) -> FerrotorchResult<()> {
377 if in_channels == 0 {
378 return Err(FerrotorchError::InvalidArgument {
379 message: "LazyConv3d: in_channels must be > 0".into(),
380 });
381 }
382 if self.inner.get().is_none() {
383 let conv = Conv3d::new(
384 in_channels,
385 self.out_channels,
386 self.kernel_size,
387 self.stride,
388 self.padding,
389 self.bias_enabled,
390 )?;
391 let _ = self.inner.set(conv);
392 }
393 Ok(())
394 }
395}
396
397impl<T: Float> Module<T> for LazyConv3d<T> {
398 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
399 if input.ndim() != 5 {
400 return Err(FerrotorchError::InvalidArgument {
401 message: format!(
402 "LazyConv3d expects 5-D input [B, C, D, H, W], got {:?}",
403 input.shape()
404 ),
405 });
406 }
407 if self.inner.get().is_none() {
408 let in_channels = input.shape()[1];
409 self.materialize(in_channels)?;
410 }
411 let conv = self.inner.get().expect("initialized after materialize()");
412 conv.forward(input)
413 }
414
415 fn parameters(&self) -> Vec<&Parameter<T>> {
416 self.inner.get().map(|c| c.parameters()).unwrap_or_default()
417 }
418
419 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
420 self.inner
421 .get_mut()
422 .map(|c| c.parameters_mut())
423 .unwrap_or_default()
424 }
425
426 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
427 self.inner
428 .get()
429 .map(|c| c.named_parameters())
430 .unwrap_or_default()
431 }
432
433 fn train(&mut self) {
434 self.training.store(true, Ordering::Relaxed);
435 if let Some(c) = self.inner.get_mut() {
436 c.train();
437 }
438 }
439
440 fn eval(&mut self) {
441 self.training.store(false, Ordering::Relaxed);
442 if let Some(c) = self.inner.get_mut() {
443 c.eval();
444 }
445 }
446
447 fn is_training(&self) -> bool {
448 self.training.load(Ordering::Relaxed)
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::Tensor;
    use ferrotorch_core::storage::TensorStorage;

    /// Builds a CPU-backed `f32` tensor from a slice and a shape, panicking if
    /// construction fails.
    fn cpu_tensor(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        let storage = TensorStorage::cpu(data.to_vec());
        Tensor::from_storage(storage, shape.to_vec(), false).unwrap()
    }

    /// Returns `len` values `0/div, 1/div, 2/div, …` used as filler input.
    fn ramp(len: usize, div: f32) -> Vec<f32> {
        (0..len).map(|i| i as f32 / div).collect()
    }

    // ---- LazyConv1d -------------------------------------------------------

    #[test]
    fn test_lazy_conv1d_uninitialized_until_first_forward() {
        let layer = LazyConv1d::<f32>::new(8, 3, 1, 0, true).unwrap();
        assert!(!layer.is_initialized());
        assert!(layer.parameters().is_empty());
    }

    #[test]
    fn test_lazy_conv1d_materializes_on_first_forward() {
        let layer = LazyConv1d::<f32>::new(4, 3, 1, 1, true).unwrap();
        let x = cpu_tensor(&ramp(10, 1.0), &[1, 2, 5]);
        let y = layer.forward(&x).unwrap();
        let dims = y.shape();
        assert_eq!(dims[0], 1);
        assert_eq!(dims[1], 4);
        assert!(layer.is_initialized());
        assert_eq!(layer.parameters().len(), 2);
    }

    #[test]
    fn test_lazy_conv1d_rejects_wrong_input_ndim() {
        let layer = LazyConv1d::<f32>::new(2, 3, 1, 0, true).unwrap();
        let rank_one = cpu_tensor(&[1.0, 2.0, 3.0], &[3]);
        assert!(layer.forward(&rank_one).is_err());
    }

    #[test]
    fn test_lazy_conv1d_explicit_materialize() {
        let layer = LazyConv1d::<f32>::new(8, 3, 1, 0, true).unwrap();
        layer.materialize(16).unwrap();
        assert!(layer.is_initialized());
        assert_eq!(layer.parameters().len(), 2);
    }

    #[test]
    fn test_lazy_conv1d_zero_out_channels_errors() {
        let outcome = LazyConv1d::<f32>::new(0, 3, 1, 0, true);
        assert!(outcome.is_err());
    }

    // ---- LazyConv2d -------------------------------------------------------

    #[test]
    fn test_lazy_conv2d_uninitialized_until_first_forward() {
        let layer = LazyConv2d::<f32>::new(16, (3, 3), (1, 1), (1, 1), true).unwrap();
        assert!(!layer.is_initialized());
        assert!(layer.parameters().is_empty());
    }

    #[test]
    fn test_lazy_conv2d_materializes_on_first_forward() {
        let layer = LazyConv2d::<f32>::new(4, (3, 3), (1, 1), (1, 1), true).unwrap();
        let x = cpu_tensor(&ramp(48, 10.0), &[1, 3, 4, 4]);
        let y = layer.forward(&x).unwrap();
        let dims = y.shape();
        assert_eq!(dims[0], 1);
        assert_eq!(dims[1], 4);
        assert_eq!(dims[2], 4);
        assert_eq!(dims[3], 4);
        assert!(layer.is_initialized());
        assert_eq!(layer.parameters().len(), 2);
    }

    #[test]
    fn test_lazy_conv2d_no_bias() {
        let layer = LazyConv2d::<f32>::new(2, (3, 3), (1, 1), (1, 1), false).unwrap();
        let x = cpu_tensor(&ramp(48, 1.0), &[1, 3, 4, 4]);
        let _ = layer.forward(&x).unwrap();
        assert_eq!(layer.parameters().len(), 1);
    }

    #[test]
    fn test_lazy_conv2d_subsequent_forward_reuses_inner() {
        let layer = LazyConv2d::<f32>::new(2, (3, 3), (1, 1), (1, 1), true).unwrap();
        let values = ramp(48, 1.0);
        // Address of the first parameter's data buffer; it must not move
        // between forward calls if the inner layer is reused.
        let weight_ptr =
            |l: &LazyConv2d<f32>| l.parameters()[0].tensor().data().unwrap().as_ptr();

        let first_out = layer.forward(&cpu_tensor(&values, &[1, 3, 4, 4])).unwrap();
        let ptr_before = weight_ptr(&layer);

        let second_out = layer.forward(&cpu_tensor(&values, &[1, 3, 4, 4])).unwrap();
        let ptr_after = weight_ptr(&layer);

        assert_eq!(ptr_before, ptr_after);
        assert_eq!(first_out.shape(), second_out.shape());
    }

    #[test]
    fn test_lazy_conv2d_rejects_wrong_ndim() {
        let layer = LazyConv2d::<f32>::new(2, (3, 3), (1, 1), (1, 1), true).unwrap();
        let rank_two = cpu_tensor(&[1.0; 9], &[3, 3]);
        assert!(layer.forward(&rank_two).is_err());
    }

    #[test]
    fn test_lazy_conv2d_train_eval_propagates_to_inner() {
        let mut layer = LazyConv2d::<f32>::new(2, (3, 3), (1, 1), (1, 1), true).unwrap();
        let x = cpu_tensor(&ramp(48, 1.0), &[1, 3, 4, 4]);
        let _ = layer.forward(&x).unwrap();
        layer.eval();
        assert!(!layer.is_training());
        layer.train();
        assert!(layer.is_training());
    }

    // ---- LazyConv3d -------------------------------------------------------

    #[test]
    fn test_lazy_conv3d_uninitialized_until_first_forward() {
        let layer = LazyConv3d::<f32>::new(4, (3, 3, 3), (1, 1, 1), (1, 1, 1), true).unwrap();
        assert!(!layer.is_initialized());
    }

    #[test]
    fn test_lazy_conv3d_materializes_on_first_forward() {
        let layer = LazyConv3d::<f32>::new(2, (3, 3, 3), (1, 1, 1), (1, 1, 1), true).unwrap();
        let x = cpu_tensor(&ramp(128, 10.0), &[1, 2, 4, 4, 4]);
        let y = layer.forward(&x).unwrap();
        let dims = y.shape();
        assert_eq!(dims[0], 1);
        assert_eq!(dims[1], 2);
        assert!(layer.is_initialized());
    }

    #[test]
    fn test_lazy_conv3d_rejects_wrong_ndim() {
        let layer = LazyConv3d::<f32>::new(2, (3, 3, 3), (1, 1, 1), (1, 1, 1), true).unwrap();
        let rank_four = cpu_tensor(&[0.0; 48], &[1, 3, 4, 4]);
        assert!(layer.forward(&rank_four).is_err());
    }

    #[test]
    fn test_lazy_conv3d_zero_kernel_errors() {
        let outcome = LazyConv3d::<f32>::new(2, (3, 0, 3), (1, 1, 1), (1, 1, 1), true);
        assert!(outcome.is_err());
    }
}