Skip to main content

ferrotorch_nn/
lazy_linear.rs

1//! Lazy variants of [`Linear`](super::Linear) and convolution layers.
2//!
3//! Lazy modules defer parameter allocation until the first forward call,
4//! at which point the input tensor's shape is inspected to determine the
5//! missing dimensions (`in_features` for `LazyLinear`, `in_channels` for
6//! `LazyConv*d`). Mirrors `torch.nn.LazyLinear` and friends.
7//!
8//! # Use cases
9//!
10//! Lazy modules are useful when:
11//! - The input feature size is hard to compute by hand (e.g. after a
12//!   variable-length sequence of pooling/strided convs feeds into a
13//!   classifier head — `LazyLinear` lets you write the model definition
14//!   without manually working out the flattened dimension).
15//! - You want to load a state_dict that already has the correct weight
16//!   shape and let the module pick that up at load time.
17//!
18//! # Thread safety
19//!
//! Initialization uses [`std::sync::OnceLock`], so concurrent first forward
//! calls are race-safe. Each racing thread may build candidate parameters,
//! but exactly one set is stored and every thread then uses that set.
//! Subsequent forward calls see the initialized parameters via a regular
//! reference (no lock acquisition on the hot path).
24//!
25//! # Limitations
26//!
27//! Once initialized, lazy modules are functionally identical to their
28//! eager counterparts. They cannot be re-initialized with a different
29//! input feature size — you would need to construct a fresh lazy module
30//! for that. CL-445.
31//!
32//! ## REQ status (per `.design/ferrotorch-nn/lazy_linear.md`)
33//!
34//! | REQ | Status | Evidence |
35//! |---|---|---|
36//! | REQ-1 | SHIPPED | impl: `pub struct LazyLinear<T: Float>` with `OnceLock<Parameter<T>>` fields here, mirroring upstream `LazyLinear` with `UninitializedParameter` at `torch/nn/modules/linear.py:18` and the `LazyModuleMixin` protocol at `torch/nn/modules/lazy.py`; non-test consumer: `pub use lazy_linear::LazyLinear` in `lib.rs`. |
37//! | REQ-2 | SHIPPED | impl: the `LazyLinear::new(out_features, bias)` constructor here rejecting `out_features == 0`; non-test consumer: dynamic-shape model construction in `ferrotorch-train`'s learner setup. |
38//! | REQ-3 | SHIPPED | impl: the `is_initialized` / `in_features` / `out_features` accessors here; non-test consumer: dispatch logic in dynamic-shape model setup queries `is_initialized` to decide whether to call the materialize path eagerly. |
39//! | REQ-4 | SHIPPED | impl: the `LazyLinear::materialize` body here (idempotent first-wins allocator); non-test consumer: dynamic-shape model setup calls `materialize(known_in_features)` to populate the param list before constructing the optimizer. |
40//! | REQ-5 | SHIPPED | impl: `<LazyLinear as Module>::forward` body here (materialize-on-first + `linear_fused` dispatch); non-test consumer: any model containing a `LazyLinear` runs this on every forward pass. |
41//! | REQ-6 | SHIPPED | impl: flatten-then-reshape branch inside `<LazyLinear as Module>::forward` here, mirroring `Linear::forward`; non-test consumer: 3-D / 4-D inputs flow through the same path in production transformer / vision usage. |
42//! | REQ-7 | SHIPPED | impl: `Module::parameters` / `parameters_mut` / `named_parameters` building `Vec` from `OnceLock` contents here; non-test consumer: `ferrotorch_optim::Optimizer` walks `model.parameters_mut()` AFTER the first forward (or after explicit materialize), at which point the lazy params surface. |
43//! | REQ-8 | SHIPPED | impl: `kaiming_uniform(&mut w, NonLinearity::ReLU)` and `init_zeros(&mut b)` inside the materialize body here; non-test consumer: every `LazyLinear` instance goes through this code path on first init. |
44//! | REQ-9 | SHIPPED | `OnceLock::set` is documented race-safe by the standard library; the hot post-init path uses lock-free `OnceLock::get`. Verified by `Send + Sync` requirements on `OnceLock<Parameter<T>>` (held by composition of `Send + Sync` field types); non-test consumer: any multi-threaded training scaffolding requiring `Send + Sync`. |
45
46use std::sync::OnceLock;
47use std::sync::atomic::{AtomicBool, Ordering};
48
49use ferrotorch_core::grad_fns::linalg::linear_fused;
50use ferrotorch_core::grad_fns::shape::reshape;
51use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
52
53use crate::init::{NonLinearity, kaiming_uniform, zeros as init_zeros};
54use crate::module::Module;
55use crate::parameter::Parameter;
56
57/// A linear layer that defers `in_features` discovery to the first
58/// forward call.
59///
60/// On the first call to [`forward`](Self::forward), the input's last
61/// dimension is taken as `in_features`, the weight (shape
62/// `[out_features, in_features]`) and optional bias (shape
63/// `[out_features]`) are allocated and initialized identically to
64/// [`Linear`](super::Linear), and stored. Subsequent forward calls
65/// behave exactly like a standard `Linear`.
66///
67/// Mirrors `torch.nn.LazyLinear`.
68#[derive(Debug)]
69pub struct LazyLinear<T: Float> {
70    out_features: usize,
71    bias_enabled: bool,
72    weight: OnceLock<Parameter<T>>,
73    bias: OnceLock<Parameter<T>>,
74    training: AtomicBool,
75}
76
77impl<T: Float> LazyLinear<T> {
78    /// Build a new `LazyLinear` with the given `out_features` and bias flag.
79    /// `in_features` will be discovered from the first forward input.
80    ///
81    /// # Errors
82    ///
83    /// Returns an error if `out_features == 0`.
84    pub fn new(out_features: usize, bias: bool) -> FerrotorchResult<Self> {
85        if out_features == 0 {
86            return Err(FerrotorchError::InvalidArgument {
87                message: "LazyLinear: out_features must be > 0".into(),
88            });
89        }
90        Ok(Self {
91            out_features,
92            bias_enabled: bias,
93            weight: OnceLock::new(),
94            bias: OnceLock::new(),
95            training: AtomicBool::new(true),
96        })
97    }
98
99    /// Returns `true` once the parameters have been materialized
100    /// (i.e. after the first successful forward call).
101    pub fn is_initialized(&self) -> bool {
102        self.weight.get().is_some()
103    }
104
105    /// Number of output features. Always known at construction time.
106    pub fn out_features(&self) -> usize {
107        self.out_features
108    }
109
110    /// Number of input features. `None` until the first forward call
111    /// has materialized the weight.
112    pub fn in_features(&self) -> Option<usize> {
113        self.weight.get().map(|w| w.tensor().shape()[1])
114    }
115
116    /// Eagerly materialize the parameters with the given `in_features`.
117    /// Useful when you want the parameters present before any forward
118    /// call (e.g. so they show up in `parameters()` for the optimizer).
119    ///
120    /// Calling this after the parameters are already initialized is
121    /// a no-op (returns Ok). Calling this with a different in_features
122    /// than was previously materialized is also a no-op — the existing
123    /// parameters are kept; the contract is "first one wins".
124    pub fn materialize(&self, in_features: usize) -> FerrotorchResult<()> {
125        if in_features == 0 {
126            return Err(FerrotorchError::InvalidArgument {
127                message: "LazyLinear: in_features must be > 0".into(),
128            });
129        }
130        if self.weight.get().is_none() {
131            let mut w = Parameter::zeros(&[self.out_features, in_features])?;
132            kaiming_uniform(&mut w, NonLinearity::ReLU)?;
133            // set() returns Err if another thread won the race; that's
134            // fine, the other initialization wins and ours is dropped.
135            let _ = self.weight.set(w);
136        }
137        if self.bias_enabled && self.bias.get().is_none() {
138            let mut b = Parameter::zeros(&[self.out_features])?;
139            init_zeros(&mut b)?;
140            let _ = self.bias.set(b);
141        }
142        Ok(())
143    }
144}
145
146impl<T: Float> Module<T> for LazyLinear<T> {
147    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
148        if input.ndim() == 0 {
149            return Err(FerrotorchError::ShapeMismatch {
150                message: "LazyLinear: scalar input not supported".into(),
151            });
152        }
153
154        // Materialize on first call. Subsequent calls hit a fast path.
155        if self.weight.get().is_none() {
156            let last_dim = input.shape()[input.ndim() - 1];
157            self.materialize(last_dim)?;
158        }
159
160        let weight = self
161            .weight
162            .get()
163            .expect("weight should be initialized after materialize()");
164        let in_features = weight.tensor().shape()[1];
165
166        let last_dim = input.shape()[input.ndim() - 1];
167        if last_dim != in_features {
168            return Err(FerrotorchError::ShapeMismatch {
169                message: format!(
170                    "LazyLinear: input has {} features but layer was initialized with {}",
171                    last_dim, in_features
172                ),
173            });
174        }
175
176        // Same logic as Linear::forward — flatten leading dims, fused
177        // linear, reshape back.
178        let input_shape = input.shape().to_vec();
179        let batch_shape = &input_shape[..input_shape.len() - 1];
180        let n: usize = batch_shape.iter().product::<usize>().max(1);
181        let needs_reshape = input.ndim() != 2;
182        let input_2d = if needs_reshape {
183            reshape(input, &[n as isize, in_features as isize])?
184        } else {
185            input.clone()
186        };
187
188        let output_2d = linear_fused(
189            &input_2d,
190            weight.tensor(),
191            self.bias.get().map(|b| b.tensor()),
192        )?;
193
194        if needs_reshape {
195            let mut out_shape: Vec<isize> = batch_shape.iter().map(|&d| d as isize).collect();
196            out_shape.push(self.out_features as isize);
197            reshape(&output_2d, &out_shape)
198        } else {
199            Ok(output_2d)
200        }
201    }
202
203    fn parameters(&self) -> Vec<&Parameter<T>> {
204        let mut params = Vec::new();
205        if let Some(w) = self.weight.get() {
206            params.push(w);
207        }
208        if let Some(b) = self.bias.get() {
209            params.push(b);
210        }
211        params
212    }
213
214    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
215        let mut params = Vec::new();
216        if let Some(w) = self.weight.get_mut() {
217            params.push(w);
218        }
219        if let Some(b) = self.bias.get_mut() {
220            params.push(b);
221        }
222        params
223    }
224
225    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
226        let mut params = Vec::new();
227        if let Some(w) = self.weight.get() {
228            params.push(("weight".to_string(), w));
229        }
230        if let Some(b) = self.bias.get() {
231            params.push(("bias".to_string(), b));
232        }
233        params
234    }
235
236    fn train(&mut self) {
237        self.training.store(true, Ordering::Relaxed);
238    }
239
240    fn eval(&mut self) {
241        self.training.store(false, Ordering::Relaxed);
242    }
243
244    fn is_training(&self) -> bool {
245        self.training.load(Ordering::Relaxed)
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::Tensor;
    use ferrotorch_core::storage::TensorStorage;

    fn cpu_tensor(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
    }

    #[test]
    fn test_lazy_linear_uninitialized_until_first_forward() {
        let lazy: LazyLinear<f32> = LazyLinear::new(8, true).unwrap();
        assert!(!lazy.is_initialized());
        assert_eq!(lazy.in_features(), None);
        // Empty parameters list pre-init.
        assert_eq!(lazy.parameters().len(), 0);
    }

    #[test]
    fn test_lazy_linear_materializes_on_first_forward() {
        let lazy: LazyLinear<f32> = LazyLinear::new(4, true).unwrap();
        let input = cpu_tensor(
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            &[2, 6],
        );
        let out = lazy.forward(&input).unwrap();
        assert_eq!(out.shape(), &[2, 4]);
        assert!(lazy.is_initialized());
        assert_eq!(lazy.in_features(), Some(6));
        assert_eq!(lazy.parameters().len(), 2); // weight + bias
    }

    #[test]
    fn test_lazy_linear_no_bias_has_one_param() {
        let lazy: LazyLinear<f32> = LazyLinear::new(3, false).unwrap();
        let input = cpu_tensor(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let _ = lazy.forward(&input).unwrap();
        assert_eq!(lazy.parameters().len(), 1);
        assert!(lazy.bias.get().is_none());
    }

    #[test]
    fn test_lazy_linear_subsequent_forward_uses_initialized_weights() {
        let lazy: LazyLinear<f32> = LazyLinear::new(2, true).unwrap();
        let input1 = cpu_tensor(&[1.0, 2.0, 3.0], &[1, 3]);
        let _ = lazy.forward(&input1).unwrap();

        // Second forward with the same in_features should succeed.
        let input2 = cpu_tensor(&[4.0, 5.0, 6.0], &[1, 3]);
        let out2 = lazy.forward(&input2).unwrap();
        assert_eq!(out2.shape(), &[1, 2]);
    }

    #[test]
    fn test_lazy_linear_rejects_mismatched_in_features() {
        let lazy: LazyLinear<f32> = LazyLinear::new(2, true).unwrap();
        let input1 = cpu_tensor(&[1.0, 2.0, 3.0], &[1, 3]);
        let _ = lazy.forward(&input1).unwrap();
        // Now in_features is locked to 3.
        let input_bad = cpu_tensor(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let result = lazy.forward(&input_bad);
        assert!(matches!(
            result,
            Err(FerrotorchError::ShapeMismatch { .. })
        ));
        // The failed call must not have re-initialized the layer.
        assert_eq!(lazy.in_features(), Some(3));
    }

    #[test]
    fn test_lazy_linear_explicit_materialize_initializes_eagerly() {
        let lazy: LazyLinear<f32> = LazyLinear::new(8, true).unwrap();
        assert!(!lazy.is_initialized());
        lazy.materialize(16).unwrap();
        assert!(lazy.is_initialized());
        assert_eq!(lazy.in_features(), Some(16));
        // Parameters are now visible to optimizers without a forward call.
        assert_eq!(lazy.parameters().len(), 2);
    }

    #[test]
    fn test_lazy_linear_materialize_idempotent() {
        let lazy: LazyLinear<f32> = LazyLinear::new(4, false).unwrap();
        lazy.materialize(8).unwrap();
        // Second call with same in_features is a no-op.
        lazy.materialize(8).unwrap();
        // Even with different in_features, the first one wins -- does
        // not panic, does not re-initialize.
        lazy.materialize(16).unwrap();
        assert_eq!(lazy.in_features(), Some(8));
    }

    #[test]
    fn test_lazy_linear_materialize_zero_in_features_errors() {
        let lazy: LazyLinear<f32> = LazyLinear::new(4, true).unwrap();
        assert!(matches!(
            lazy.materialize(0),
            Err(FerrotorchError::InvalidArgument { .. })
        ));
        // A rejected materialize must leave the layer untouched.
        assert!(!lazy.is_initialized());
        assert!(lazy.parameters().is_empty());
    }

    #[test]
    fn test_lazy_linear_materialize_parameter_shapes() {
        let lazy: LazyLinear<f32> = LazyLinear::new(4, true).unwrap();
        lazy.materialize(5).unwrap();
        let weight = lazy.weight.get().expect("weight materialized");
        assert_eq!(weight.tensor().shape(), &[4, 5]);
        // Whenever the layer reports initialized, an enabled bias exists.
        let bias = lazy.bias.get().expect("bias materialized with weight");
        assert_eq!(bias.tensor().shape(), &[4]);
    }

    #[test]
    fn test_lazy_linear_zero_out_features_errors() {
        let result = LazyLinear::<f32>::new(0, true);
        assert!(matches!(
            result,
            Err(FerrotorchError::InvalidArgument { .. })
        ));
    }

    #[test]
    fn test_lazy_linear_one_dim_input() {
        // A 1-D input [features] is a single sample; the output is 1-D
        // [out_features], matching torch.nn.LazyLinear.
        let lazy: LazyLinear<f32> = LazyLinear::new(2, true).unwrap();
        let input = cpu_tensor(&[1.0, 2.0, 3.0], &[3]);
        let out = lazy.forward(&input).unwrap();
        assert_eq!(out.shape(), &[2]);
        assert_eq!(lazy.in_features(), Some(3));
    }

    #[test]
    fn test_lazy_linear_higher_rank_input() {
        // 3-D input [batch, seq, features] should be handled like Linear.
        let lazy: LazyLinear<f32> = LazyLinear::new(2, true).unwrap();
        let data: Vec<f32> = (0..24).map(|i| i as f32 / 10.0).collect();
        let input = cpu_tensor(&data, &[2, 4, 3]);
        let out = lazy.forward(&input).unwrap();
        assert_eq!(out.shape(), &[2, 4, 2]);
        assert_eq!(lazy.in_features(), Some(3));
    }

    #[test]
    fn test_lazy_linear_named_parameters_after_init() {
        let lazy: LazyLinear<f32> = LazyLinear::new(2, true).unwrap();
        let input = cpu_tensor(&[1.0, 2.0, 3.0], &[1, 3]);
        let _ = lazy.forward(&input).unwrap();
        let names: Vec<String> = lazy
            .named_parameters()
            .iter()
            .map(|(n, _)| n.clone())
            .collect();
        assert!(names.contains(&"weight".to_string()));
        assert!(names.contains(&"bias".to_string()));
    }

    #[test]
    fn test_lazy_linear_train_eval_toggle() {
        let mut lazy: LazyLinear<f32> = LazyLinear::new(2, true).unwrap();
        assert!(lazy.is_training());
        lazy.eval();
        assert!(!lazy.is_training());
        lazy.train();
        assert!(lazy.is_training());
    }
}
379}