1use std::sync::OnceLock;
47use std::sync::atomic::{AtomicBool, Ordering};
48
49use ferrotorch_core::grad_fns::linalg::linear_fused;
50use ferrotorch_core::grad_fns::shape::reshape;
51use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
52
53use crate::init::{NonLinearity, kaiming_uniform, zeros as init_zeros};
54use crate::module::Module;
55use crate::parameter::Parameter;
56
57#[derive(Debug)]
69pub struct LazyLinear<T: Float> {
70 out_features: usize,
71 bias_enabled: bool,
72 weight: OnceLock<Parameter<T>>,
73 bias: OnceLock<Parameter<T>>,
74 training: AtomicBool,
75}
76
77impl<T: Float> LazyLinear<T> {
78 pub fn new(out_features: usize, bias: bool) -> FerrotorchResult<Self> {
85 if out_features == 0 {
86 return Err(FerrotorchError::InvalidArgument {
87 message: "LazyLinear: out_features must be > 0".into(),
88 });
89 }
90 Ok(Self {
91 out_features,
92 bias_enabled: bias,
93 weight: OnceLock::new(),
94 bias: OnceLock::new(),
95 training: AtomicBool::new(true),
96 })
97 }
98
99 pub fn is_initialized(&self) -> bool {
102 self.weight.get().is_some()
103 }
104
105 pub fn out_features(&self) -> usize {
107 self.out_features
108 }
109
110 pub fn in_features(&self) -> Option<usize> {
113 self.weight.get().map(|w| w.tensor().shape()[1])
114 }
115
116 pub fn materialize(&self, in_features: usize) -> FerrotorchResult<()> {
125 if in_features == 0 {
126 return Err(FerrotorchError::InvalidArgument {
127 message: "LazyLinear: in_features must be > 0".into(),
128 });
129 }
130 if self.weight.get().is_none() {
131 let mut w = Parameter::zeros(&[self.out_features, in_features])?;
132 kaiming_uniform(&mut w, NonLinearity::ReLU)?;
133 let _ = self.weight.set(w);
136 }
137 if self.bias_enabled && self.bias.get().is_none() {
138 let mut b = Parameter::zeros(&[self.out_features])?;
139 init_zeros(&mut b)?;
140 let _ = self.bias.set(b);
141 }
142 Ok(())
143 }
144}
145
146impl<T: Float> Module<T> for LazyLinear<T> {
147 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
148 if input.ndim() == 0 {
149 return Err(FerrotorchError::ShapeMismatch {
150 message: "LazyLinear: scalar input not supported".into(),
151 });
152 }
153
154 if self.weight.get().is_none() {
156 let last_dim = input.shape()[input.ndim() - 1];
157 self.materialize(last_dim)?;
158 }
159
160 let weight = self
161 .weight
162 .get()
163 .expect("weight should be initialized after materialize()");
164 let in_features = weight.tensor().shape()[1];
165
166 let last_dim = input.shape()[input.ndim() - 1];
167 if last_dim != in_features {
168 return Err(FerrotorchError::ShapeMismatch {
169 message: format!(
170 "LazyLinear: input has {} features but layer was initialized with {}",
171 last_dim, in_features
172 ),
173 });
174 }
175
176 let input_shape = input.shape().to_vec();
179 let batch_shape = &input_shape[..input_shape.len() - 1];
180 let n: usize = batch_shape.iter().product::<usize>().max(1);
181 let needs_reshape = input.ndim() != 2;
182 let input_2d = if needs_reshape {
183 reshape(input, &[n as isize, in_features as isize])?
184 } else {
185 input.clone()
186 };
187
188 let output_2d = linear_fused(
189 &input_2d,
190 weight.tensor(),
191 self.bias.get().map(|b| b.tensor()),
192 )?;
193
194 if needs_reshape {
195 let mut out_shape: Vec<isize> = batch_shape.iter().map(|&d| d as isize).collect();
196 out_shape.push(self.out_features as isize);
197 reshape(&output_2d, &out_shape)
198 } else {
199 Ok(output_2d)
200 }
201 }
202
203 fn parameters(&self) -> Vec<&Parameter<T>> {
204 let mut params = Vec::new();
205 if let Some(w) = self.weight.get() {
206 params.push(w);
207 }
208 if let Some(b) = self.bias.get() {
209 params.push(b);
210 }
211 params
212 }
213
214 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
215 let mut params = Vec::new();
216 if let Some(w) = self.weight.get_mut() {
217 params.push(w);
218 }
219 if let Some(b) = self.bias.get_mut() {
220 params.push(b);
221 }
222 params
223 }
224
225 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
226 let mut params = Vec::new();
227 if let Some(w) = self.weight.get() {
228 params.push(("weight".to_string(), w));
229 }
230 if let Some(b) = self.bias.get() {
231 params.push(("bias".to_string(), b));
232 }
233 params
234 }
235
236 fn train(&mut self) {
237 self.training.store(true, Ordering::Relaxed);
238 }
239
240 fn eval(&mut self) {
241 self.training.store(false, Ordering::Relaxed);
242 }
243
244 fn is_training(&self) -> bool {
245 self.training.load(Ordering::Relaxed)
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::Tensor;
    use ferrotorch_core::storage::TensorStorage;

    /// Builds a CPU tensor (no gradient tracking) from a flat buffer and shape.
    fn cpu_tensor(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        let storage = TensorStorage::cpu(data.to_vec());
        Tensor::from_storage(storage, shape.to_vec(), false).unwrap()
    }

    /// Constructs an `f32` layer, panicking on an invalid configuration.
    fn lazy_layer(out_features: usize, bias: bool) -> LazyLinear<f32> {
        LazyLinear::new(out_features, bias).unwrap()
    }

    #[test]
    fn test_lazy_linear_uninitialized_until_first_forward() {
        let layer = lazy_layer(8, true);
        assert!(!layer.is_initialized());
        assert_eq!(layer.in_features(), None);
        assert!(layer.parameters().is_empty());
    }

    #[test]
    fn test_lazy_linear_materializes_on_first_forward() {
        let layer = lazy_layer(4, true);
        let values: Vec<f32> = (1..=12).map(|v| v as f32).collect();
        let input = cpu_tensor(&values, &[2, 6]);

        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape(), &[2, 4]);
        assert!(layer.is_initialized());
        assert_eq!(layer.in_features(), Some(6));
        assert_eq!(layer.parameters().len(), 2);
    }

    #[test]
    fn test_lazy_linear_no_bias_has_one_param() {
        let layer = lazy_layer(3, false);
        let input = cpu_tensor(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);

        let _output = layer.forward(&input).unwrap();

        assert_eq!(layer.parameters().len(), 1);
        assert!(layer.bias.get().is_none());
    }

    #[test]
    fn test_lazy_linear_subsequent_forward_uses_initialized_weights() {
        let layer = lazy_layer(2, true);
        let _first = layer.forward(&cpu_tensor(&[1.0, 2.0, 3.0], &[1, 3])).unwrap();

        // Same width as the first call, so the existing weights are reused.
        let second = layer.forward(&cpu_tensor(&[4.0, 5.0, 6.0], &[1, 3])).unwrap();
        assert_eq!(second.shape(), &[1, 2]);
    }

    #[test]
    fn test_lazy_linear_rejects_mismatched_in_features() {
        let layer = lazy_layer(2, true);
        let _first = layer.forward(&cpu_tensor(&[1.0, 2.0, 3.0], &[1, 3])).unwrap();

        // Width 4 disagrees with the materialized width of 3.
        let wrong_width = cpu_tensor(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);
        assert!(layer.forward(&wrong_width).is_err());
    }

    #[test]
    fn test_lazy_linear_explicit_materialize_initializes_eagerly() {
        let layer = lazy_layer(8, true);
        assert!(!layer.is_initialized());

        layer.materialize(16).unwrap();

        assert!(layer.is_initialized());
        assert_eq!(layer.in_features(), Some(16));
        assert_eq!(layer.parameters().len(), 2);
    }

    #[test]
    fn test_lazy_linear_materialize_idempotent() {
        let layer = lazy_layer(4, false);
        // Repeating the width is a no-op; a different width after the first call
        // is accepted but ignored.
        for width in [8, 8, 16] {
            layer.materialize(width).unwrap();
        }
        assert_eq!(layer.in_features(), Some(8));
    }

    #[test]
    fn test_lazy_linear_zero_out_features_errors() {
        assert!(LazyLinear::<f32>::new(0, true).is_err());
    }

    #[test]
    fn test_lazy_linear_higher_rank_input() {
        let layer = lazy_layer(2, true);
        let values: Vec<f32> = (0..24).map(|i| i as f32 / 10.0).collect();
        let input = cpu_tensor(&values, &[2, 4, 3]);

        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape(), &[2, 4, 2]);
        assert_eq!(layer.in_features(), Some(3));
    }

    #[test]
    fn test_lazy_linear_named_parameters_after_init() {
        let layer = lazy_layer(2, true);
        let _output = layer.forward(&cpu_tensor(&[1.0, 2.0, 3.0], &[1, 3])).unwrap();

        let names: Vec<String> = layer
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();

        assert!(names.iter().any(|n| n == "weight"));
        assert!(names.iter().any(|n| n == "bias"));
    }

    #[test]
    fn test_lazy_linear_train_eval_toggle() {
        let mut layer = lazy_layer(2, true);
        assert!(layer.is_training());

        layer.eval();
        assert!(!layer.is_training());

        layer.train();
        assert!(layer.is_training());
    }
}