Skip to main content

ferrotorch_nn/
container.rs

//! Container modules: [`Sequential`], [`ModuleList`], and [`ModuleDict`].
//!
//! These mirror PyTorch's `nn.Sequential`, `nn.ModuleList`, and
//! `nn.ModuleDict`. They hold sub-modules and propagate `parameters()`,
//! `named_parameters()` (and therefore `state_dict()` keys such as
//! `"0.weight"`), and `train()`/`eval()` to all children.
//!
//! ## REQ status (per `.design/ferrotorch-nn/container.md`)
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 | SHIPPED | `pub struct Sequential<T: Float>` with `layers: Vec<Box<dyn Module<T>>>` mirrors `torch/nn/modules/container.py:59-333`; consumed by `pub use container::{ModuleDict, ModuleList, Sequential}` at `lib.rs:195` and downstream MLP/CNN composition code. |
//! | REQ-2 | SHIPPED | `Sequential::new(layers)` constructor mirrors `torch/nn/modules/container.py:108-122`; consumed by every downstream model composition site. |
//! | REQ-3 | SHIPPED | `push`, `len`, `is_empty` inherent methods mirror `nn.Sequential.append` (`container.py:256-275`); consumed by builder-pattern model construction in downstream crates. |
//! | REQ-4 | SHIPPED | `impl Module<T> for Sequential<T>` with chained forward, flat-mapped parameter iteration, and index-prefixed `named_parameters` mirrors `container.py:117-122, 248-254`; consumed by the `optimizer.parameters()` flow and by `ferrotorch-nn/src/transformer.rs`, which composes layers through `Box<dyn Module<T>>`. |
//! | REQ-5 | SHIPPED | `pub struct ModuleList<T: Float>` with `Vec<Box<dyn Module<T>>>` mirrors `container.py:335-502`; consumed via the `lib.rs:195` re-export; downstream MoE-style heads use it. |
//! | REQ-6 | SHIPPED | `::new`, `::empty`, `::get`, `::get_mut`, `::push`, `::len`, `::is_empty` mirror `container.py:361-500`; consumed by downstream multi-branch model code that iterates over a list of modules. |
//! | REQ-7 | SHIPPED | `impl Module<T> for ModuleList<T>` with `forward → InvalidArgument` (matches upstream's `_forward_unimplemented` fallback at `container.py:502`); consumed by `optimizer.parameters()` reading the container's parameter iteration. |
//! | REQ-8 | SHIPPED | `pub struct ModuleDict<T: Float>` with insertion-ordered `Vec<(String, Box<dyn Module<T>>)>` mirrors the order guarantee documented in `container.py:505-700`; consumed via the `lib.rs:195` re-export by downstream encoder/decoder architectures. |
//! | REQ-9 | SHIPPED | `::new`, `::insert` (in-place replace), `::get`, `::get_mut`, `::keys`, `::len`, `::is_empty` mirror `container.py:548-700`; consumed by dynamic-dispatch model heads in downstream code. |
//! | REQ-10 | SHIPPED | `impl Module<T> for ModuleDict<T>` with `forward → InvalidArgument`, key-prefixed `named_parameters`, and train/eval propagation; consumed by `optimizer.parameters()` and state-dict round-trips. |
//! | REQ-11 | SHIPPED | `impl Default for ModuleDict<T>` returning an empty dict; consumed by parent modules that derive `Default` over an embedded `ModuleDict` field. |
//! | REQ-12 | SHIPPED | `Display` impls for all three containers render one `(i): <module>` line per child (`(key): <module>` for `ModuleDict`), mirroring `container.py:34-44` `_addindent`; consumed by training-driver logging that prints `{model}`. |
23
24use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
25
26use crate::module::Module;
27use crate::parameter::Parameter;
28
29// ===========================================================================
30// Sequential
31// ===========================================================================
32
33/// A sequential container that chains modules in order.
34///
35/// The `forward()` method feeds the output of each layer as the input to the
36/// next, matching PyTorch's `nn.Sequential` semantics.
37///
38/// # Named parameters
39///
40/// Parameters are prefixed by layer index: `"0.weight"`, `"0.bias"`,
41/// `"1.weight"`, etc. — matching PyTorch's convention.
42///
43/// # Examples
44///
45/// ```ignore
46/// let model = Sequential::new(vec![
47///     Box::new(Linear::<f32>::new(784, 256, true)?),
48///     Box::new(ReLU::new()),
49///     Box::new(Linear::<f32>::new(256, 10, true)?),
50/// ]);
51/// let output = model.forward(&input)?;
52/// ```
53pub struct Sequential<T: Float> {
54    layers: Vec<Box<dyn Module<T>>>,
55    training: bool,
56}
57
58impl<T: Float> Sequential<T> {
59    /// Create a new sequential container from an ordered list of modules.
60    pub fn new(layers: Vec<Box<dyn Module<T>>>) -> Self {
61        Self {
62            layers,
63            training: true,
64        }
65    }
66
67    /// Append a module to the end of the sequence.
68    pub fn push(&mut self, layer: Box<dyn Module<T>>) {
69        self.layers.push(layer);
70    }
71
72    /// Number of layers.
73    #[inline]
74    pub fn len(&self) -> usize {
75        self.layers.len()
76    }
77
78    /// Whether the container is empty.
79    #[inline]
80    pub fn is_empty(&self) -> bool {
81        self.layers.is_empty()
82    }
83}
84
85impl<T: Float> Module<T> for Sequential<T> {
86    /// Forward pass: chains each layer's forward in order.
87    ///
88    /// Returns an error if there are no layers, or if any layer's forward
89    /// fails.
90    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
91        if self.layers.is_empty() {
92            return Err(FerrotorchError::InvalidArgument {
93                message: "Sequential: cannot forward through empty container".into(),
94            });
95        }
96
97        let mut output = self.layers[0].forward(input)?;
98        for layer in &self.layers[1..] {
99            output = layer.forward(&output)?;
100        }
101        Ok(output)
102    }
103
104    fn parameters(&self) -> Vec<&Parameter<T>> {
105        self.layers
106            .iter()
107            .flat_map(|layer| layer.parameters())
108            .collect()
109    }
110
111    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
112        self.layers
113            .iter_mut()
114            .flat_map(|layer| layer.parameters_mut())
115            .collect()
116    }
117
118    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
119        self.layers
120            .iter()
121            .enumerate()
122            .flat_map(|(i, layer)| {
123                layer
124                    .named_parameters()
125                    .into_iter()
126                    .map(move |(name, param)| (format!("{i}.{name}"), param))
127            })
128            .collect()
129    }
130
131    fn train(&mut self) {
132        self.training = true;
133        for layer in &mut self.layers {
134            layer.train();
135        }
136    }
137
138    fn eval(&mut self) {
139        self.training = false;
140        for layer in &mut self.layers {
141            layer.eval();
142        }
143    }
144
145    fn is_training(&self) -> bool {
146        self.training
147    }
148}
149
150impl<T: Float> std::fmt::Display for Sequential<T> {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        writeln!(f, "Sequential(")?;
153        for (i, _layer) in self.layers.iter().enumerate() {
154            writeln!(f, "  ({i}): <module>")?;
155        }
156        write!(f, ")")
157    }
158}
159
160// ===========================================================================
161// ModuleList
162// ===========================================================================
163
164/// An ordered list of modules, registered for parameter tracking.
165///
166/// Unlike [`Sequential`], `ModuleList` does **not** define a forward pass.
167/// Users iterate over the list manually and call each module's `forward()`
168/// as needed. This mirrors PyTorch's `nn.ModuleList`.
169///
170/// # Named parameters
171///
172/// Parameters are prefixed by list index: `"0.weight"`, `"1.weight"`, etc.
173///
174/// # Examples
175///
176/// ```ignore
177/// # use ferrotorch_core::FerrotorchError;
178/// fn example(input: &Tensor<f32>) -> Result<Tensor<f32>, FerrotorchError> {
179///     let list = ModuleList::<f32>::new(vec![
180///         Box::new(Linear::<f32>::new(10, 10, true)?),
181///         Box::new(Linear::<f32>::new(10, 10, true)?),
182///     ]);
183///
184///     let mut x = input.clone();
185///     for i in 0..list.len() {
186///         let module = list.get(i).ok_or_else(|| FerrotorchError::InvalidArgument {
187///             message: format!("ModuleList index {i} out of bounds"),
188///         })?;
189///         x = module.forward(&x)?;
190///     }
191///     Ok(x)
192/// }
193/// ```
194pub struct ModuleList<T: Float> {
195    modules: Vec<Box<dyn Module<T>>>,
196    training: bool,
197}
198
199impl<T: Float> ModuleList<T> {
200    /// Create a new module list from a vector of modules.
201    pub fn new(modules: Vec<Box<dyn Module<T>>>) -> Self {
202        Self {
203            modules,
204            training: true,
205        }
206    }
207
208    /// Create an empty module list.
209    pub fn empty() -> Self {
210        Self {
211            modules: Vec::new(),
212            training: true,
213        }
214    }
215
216    /// Get a reference to the module at the given index.
217    pub fn get(&self, index: usize) -> Option<&dyn Module<T>> {
218        self.modules.get(index).map(|m| m.as_ref())
219    }
220
221    /// Get a mutable reference to the module at the given index.
222    pub fn get_mut(&mut self, index: usize) -> Option<&mut dyn Module<T>> {
223        match self.modules.get_mut(index) {
224            Some(m) => Some(m.as_mut()),
225            None => None,
226        }
227    }
228
229    /// Append a module to the end of the list.
230    pub fn push(&mut self, module: Box<dyn Module<T>>) {
231        self.modules.push(module);
232    }
233
234    /// Number of modules.
235    #[inline]
236    pub fn len(&self) -> usize {
237        self.modules.len()
238    }
239
240    /// Whether the list is empty.
241    #[inline]
242    pub fn is_empty(&self) -> bool {
243        self.modules.is_empty()
244    }
245}
246
247impl<T: Float> Module<T> for ModuleList<T> {
248    /// ModuleList does not implement forward.
249    ///
250    /// Users should iterate manually and call each module's `forward()`.
251    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
252        Err(FerrotorchError::InvalidArgument {
253            message: "ModuleList does not implement forward. \
254                      Iterate over the list and call each module's forward() manually."
255                .into(),
256        })
257    }
258
259    fn parameters(&self) -> Vec<&Parameter<T>> {
260        self.modules.iter().flat_map(|m| m.parameters()).collect()
261    }
262
263    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
264        self.modules
265            .iter_mut()
266            .flat_map(|m| m.parameters_mut())
267            .collect()
268    }
269
270    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
271        self.modules
272            .iter()
273            .enumerate()
274            .flat_map(|(i, m)| {
275                m.named_parameters()
276                    .into_iter()
277                    .map(move |(name, param)| (format!("{i}.{name}"), param))
278            })
279            .collect()
280    }
281
282    fn train(&mut self) {
283        self.training = true;
284        for m in &mut self.modules {
285            m.train();
286        }
287    }
288
289    fn eval(&mut self) {
290        self.training = false;
291        for m in &mut self.modules {
292            m.eval();
293        }
294    }
295
296    fn is_training(&self) -> bool {
297        self.training
298    }
299}
300
301impl<T: Float> std::fmt::Display for ModuleList<T> {
302    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303        writeln!(f, "ModuleList(")?;
304        for (i, _m) in self.modules.iter().enumerate() {
305            writeln!(f, "  ({i}): <module>")?;
306        }
307        write!(f, ")")
308    }
309}
310
311// ===========================================================================
312// ModuleDict
313// ===========================================================================
314
315/// An ordered dictionary of named modules, registered for parameter tracking.
316///
317/// Uses a `Vec<(String, Box<dyn Module<T>>)>` internally to preserve
318/// insertion order without requiring an external dependency like `IndexMap`.
319///
320/// Like [`ModuleList`], `ModuleDict` does **not** define a forward pass.
321/// Users look up modules by key and call `forward()` manually. This mirrors
322/// PyTorch's `nn.ModuleDict`.
323///
324/// # Named parameters
325///
326/// Parameters are prefixed by their dictionary key: `"encoder.weight"`,
327/// `"decoder.weight"`, etc.
328///
329/// # Examples
330///
331/// ```ignore
332/// # use ferrotorch_core::FerrotorchError;
333/// fn example(input: &Tensor<f32>) -> Result<Tensor<f32>, FerrotorchError> {
334///     let mut dict = ModuleDict::<f32>::new();
335///     dict.insert("encoder", Box::new(Linear::<f32>::new(784, 256, true)?));
336///     dict.insert("decoder", Box::new(Linear::<f32>::new(256, 784, true)?));
337///
338///     let encoder = dict.get("encoder").ok_or_else(|| FerrotorchError::InvalidArgument {
339///         message: "missing 'encoder' module".into(),
340///     })?;
341///     let decoder = dict.get("decoder").ok_or_else(|| FerrotorchError::InvalidArgument {
342///         message: "missing 'decoder' module".into(),
343///     })?;
344///     let encoded = encoder.forward(input)?;
345///     decoder.forward(&encoded)
346/// }
347/// ```
348pub struct ModuleDict<T: Float> {
349    entries: Vec<(String, Box<dyn Module<T>>)>,
350    training: bool,
351}
352
353impl<T: Float> ModuleDict<T> {
354    /// Create an empty module dict.
355    pub fn new() -> Self {
356        Self {
357            entries: Vec::new(),
358            training: true,
359        }
360    }
361
362    /// Insert a module with the given key.
363    ///
364    /// If a module with the same key already exists, it is replaced
365    /// (preserving insertion position).
366    pub fn insert(&mut self, key: impl Into<String>, module: Box<dyn Module<T>>) {
367        let key = key.into();
368        // Replace existing entry if key already exists.
369        for entry in &mut self.entries {
370            if entry.0 == key {
371                entry.1 = module;
372                return;
373            }
374        }
375        self.entries.push((key, module));
376    }
377
378    /// Get a reference to the module with the given key.
379    pub fn get(&self, key: &str) -> Option<&dyn Module<T>> {
380        self.entries
381            .iter()
382            .find(|(k, _)| k == key)
383            .map(|(_, m)| m.as_ref())
384    }
385
386    /// Get a mutable reference to the module with the given key.
387    pub fn get_mut(&mut self, key: &str) -> Option<&mut dyn Module<T>> {
388        for (k, m) in &mut self.entries {
389            if k == key {
390                return Some(m.as_mut());
391            }
392        }
393        None
394    }
395
396    /// Return the keys in insertion order.
397    pub fn keys(&self) -> Vec<&str> {
398        self.entries.iter().map(|(k, _)| k.as_str()).collect()
399    }
400
401    /// Number of entries.
402    #[inline]
403    pub fn len(&self) -> usize {
404        self.entries.len()
405    }
406
407    /// Whether the dict is empty.
408    #[inline]
409    pub fn is_empty(&self) -> bool {
410        self.entries.is_empty()
411    }
412}
413
414impl<T: Float> Default for ModuleDict<T> {
415    fn default() -> Self {
416        Self::new()
417    }
418}
419
420impl<T: Float> Module<T> for ModuleDict<T> {
421    /// ModuleDict does not implement forward.
422    ///
423    /// Users should look up modules by key and call `forward()` manually.
424    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
425        Err(FerrotorchError::InvalidArgument {
426            message: "ModuleDict does not implement forward. \
427                      Look up modules by key and call forward() manually."
428                .into(),
429        })
430    }
431
432    fn parameters(&self) -> Vec<&Parameter<T>> {
433        self.entries
434            .iter()
435            .flat_map(|(_, m)| m.parameters())
436            .collect()
437    }
438
439    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
440        self.entries
441            .iter_mut()
442            .flat_map(|(_, m)| m.parameters_mut())
443            .collect()
444    }
445
446    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
447        self.entries
448            .iter()
449            .flat_map(|(key, m)| {
450                m.named_parameters()
451                    .into_iter()
452                    .map(move |(name, param)| (format!("{key}.{name}"), param))
453            })
454            .collect()
455    }
456
457    fn train(&mut self) {
458        self.training = true;
459        for (_, m) in &mut self.entries {
460            m.train();
461        }
462    }
463
464    fn eval(&mut self) {
465        self.training = false;
466        for (_, m) in &mut self.entries {
467            m.eval();
468        }
469    }
470
471    fn is_training(&self) -> bool {
472        self.training
473    }
474}
475
476impl<T: Float> std::fmt::Display for ModuleDict<T> {
477    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
478        writeln!(f, "ModuleDict(")?;
479        for (key, _m) in &self.entries {
480            writeln!(f, "  ({key}): <module>")?;
481        }
482        write!(f, ")")
483    }
484}
485
486// ===========================================================================
487// Tests
488// ===========================================================================
489
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // Fixtures
    // -----------------------------------------------------------------------

    /// Pass-through module owning one zero-filled `weight` parameter, so the
    /// containers have something to count, name, and (de)serialise.
    struct IdentityWithParam<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> IdentityWithParam<T> {
        /// Build a module whose `weight` has `size` elements.
        fn new(size: usize) -> FerrotorchResult<Self> {
            let weight = Parameter::zeros(&[size])?;
            Ok(Self {
                weight,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for IdentityWithParam<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            Ok(input.clone())
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![(String::from("weight"), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Boxed `IdentityWithParam<f32>` with a `size`-element weight.
    fn boxed(size: usize) -> Box<dyn Module<f32>> {
        Box::new(IdentityWithParam::<f32>::new(size).unwrap())
    }

    /// One boxed identity layer per entry of `sizes`, in order.
    fn boxed_all(sizes: &[usize]) -> Vec<Box<dyn Module<f32>>> {
        sizes.iter().map(|&size| boxed(size)).collect()
    }

    /// A `ModuleDict` built by inserting `(key, weight size)` pairs in order.
    fn dict_of(entries: &[(&str, usize)]) -> ModuleDict<f32> {
        let mut dict = ModuleDict::<f32>::new();
        for &(key, size) in entries {
            dict.insert(key, boxed(size));
        }
        dict
    }

    // -----------------------------------------------------------------------
    // Sequential
    // -----------------------------------------------------------------------

    #[test]
    fn test_sequential_forward_chains_layers() {
        // Three identity layers in a row must leave the shape untouched.
        let seq = Sequential::<f32>::new(boxed_all(&[4, 4, 4]));
        let input = ferrotorch_core::zeros::<f32>(&[2, 4]).unwrap();
        let output = seq.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 4]);
    }

    #[test]
    fn test_sequential_empty_forward_errors() {
        let seq = Sequential::<f32>::new(Vec::new());
        let input = ferrotorch_core::zeros::<f32>(&[1, 4]).unwrap();
        assert!(seq.forward(&input).is_err());
    }

    #[test]
    fn test_sequential_parameter_count() {
        let seq = Sequential::<f32>::new(boxed_all(&[3, 5, 7]));

        let params = seq.parameters();
        assert_eq!(params.len(), 3);

        let total: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total, 3 + 5 + 7);
    }

    #[test]
    fn test_sequential_named_parameters_keys() {
        let seq = Sequential::<f32>::new(boxed_all(&[2, 3, 4]));

        let named = seq.named_parameters();
        let keys: Vec<&str> = named.iter().map(|(k, _)| k.as_str()).collect();
        assert_eq!(keys, &["0.weight", "1.weight", "2.weight"]);
    }

    #[test]
    fn test_sequential_train_eval_propagation() {
        let mut seq = Sequential::<f32>::new(boxed_all(&[2, 3]));
        assert!(seq.is_training());

        seq.eval();
        assert!(!seq.is_training());
        // Every child must follow the container into eval mode ...
        assert!(seq.layers.iter().all(|layer| !layer.is_training()));

        seq.train();
        assert!(seq.is_training());
        // ... and back into training mode.
        assert!(seq.layers.iter().all(|layer| layer.is_training()));
    }

    #[test]
    fn test_sequential_state_dict_roundtrip() {
        let seq = Sequential::<f32>::new(boxed_all(&[3, 5]));

        let sd = seq.state_dict();
        assert!(sd.contains_key("0.weight"));
        assert!(sd.contains_key("1.weight"));
        assert_eq!(sd["0.weight"].shape(), &[3]);
        assert_eq!(sd["1.weight"].shape(), &[5]);

        // A freshly built model with the same architecture must accept the dict.
        let mut seq2 = Sequential::<f32>::new(boxed_all(&[3, 5]));
        seq2.load_state_dict(&sd, true).unwrap();

        let sd2 = seq2.state_dict();
        for key in ["0.weight", "1.weight"] {
            assert_eq!(sd[key].data().unwrap(), sd2[key].data().unwrap());
        }
    }

    #[test]
    fn test_sequential_push() {
        let mut seq = Sequential::<f32>::new(Vec::new());
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);

        seq.push(boxed(4));
        assert_eq!(seq.len(), 1);
        assert!(!seq.is_empty());
    }

    // -----------------------------------------------------------------------
    // ModuleList
    // -----------------------------------------------------------------------

    #[test]
    fn test_module_list_forward_errors() {
        let list = ModuleList::<f32>::new(boxed_all(&[4]));
        let input = ferrotorch_core::zeros::<f32>(&[1, 4]).unwrap();
        assert!(list.forward(&input).is_err());
    }

    #[test]
    fn test_module_list_get() {
        let list = ModuleList::<f32>::new(boxed_all(&[3, 5]));

        assert!(list.get(0).is_some());
        assert!(list.get(1).is_some());
        assert!(list.get(2).is_none());
    }

    #[test]
    fn test_module_list_get_mut() {
        let mut list = ModuleList::<f32>::new(boxed_all(&[3]));

        list.get_mut(0).unwrap().eval();
        assert!(!list.get(0).unwrap().is_training());
    }

    #[test]
    fn test_module_list_push() {
        let mut list = ModuleList::<f32>::empty();
        assert_eq!(list.len(), 0);
        assert!(list.is_empty());

        list.push(boxed(4));
        assert_eq!(list.len(), 1);
        assert!(!list.is_empty());
    }

    #[test]
    fn test_module_list_parameters() {
        let list = ModuleList::<f32>::new(boxed_all(&[2, 3]));

        assert_eq!(list.parameters().len(), 2);

        let named = list.named_parameters();
        let keys: Vec<&str> = named.iter().map(|(k, _)| k.as_str()).collect();
        assert_eq!(keys, &["0.weight", "1.weight"]);
    }

    #[test]
    fn test_module_list_train_eval() {
        let mut list = ModuleList::<f32>::new(boxed_all(&[2, 3]));

        list.eval();
        assert!(!list.is_training());
        assert!((0..2).all(|i| !list.get(i).unwrap().is_training()));

        list.train();
        assert!(list.is_training());
        assert!((0..2).all(|i| list.get(i).unwrap().is_training()));
    }

    // -----------------------------------------------------------------------
    // ModuleDict
    // -----------------------------------------------------------------------

    #[test]
    fn test_module_dict_forward_errors() {
        let dict = dict_of(&[("enc", 4)]);
        let input = ferrotorch_core::zeros::<f32>(&[1, 4]).unwrap();
        assert!(dict.forward(&input).is_err());
    }

    #[test]
    fn test_module_dict_insert_get() {
        let dict = dict_of(&[("encoder", 3), ("decoder", 5)]);

        assert!(dict.get("encoder").is_some());
        assert!(dict.get("decoder").is_some());
        assert!(dict.get("missing").is_none());
        assert_eq!(dict.len(), 2);
    }

    #[test]
    fn test_module_dict_insert_replaces() {
        // Re-inserting an existing key swaps the module in place.
        let dict = dict_of(&[("layer", 3), ("layer", 7)]);

        assert_eq!(dict.len(), 1);
        let named = dict.named_parameters();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].1.shape(), &[7]);
    }

    #[test]
    fn test_module_dict_keys_insertion_order() {
        let dict = dict_of(&[("c_layer", 1), ("a_layer", 2), ("b_layer", 3)]);

        assert_eq!(dict.keys(), &["c_layer", "a_layer", "b_layer"]);
    }

    #[test]
    fn test_module_dict_get_mut() {
        let mut dict = dict_of(&[("layer", 3)]);

        dict.get_mut("layer").unwrap().eval();
        assert!(!dict.get("layer").unwrap().is_training());
    }

    #[test]
    fn test_module_dict_named_parameters_prefixed_by_key() {
        let dict = dict_of(&[("encoder", 3), ("decoder", 5)]);

        let named = dict.named_parameters();
        let keys: Vec<&str> = named.iter().map(|(k, _)| k.as_str()).collect();
        assert_eq!(keys, &["encoder.weight", "decoder.weight"]);
    }

    #[test]
    fn test_module_dict_train_eval() {
        let mut dict = dict_of(&[("a", 2), ("b", 3)]);

        dict.eval();
        assert!(!dict.is_training());
        assert!(["a", "b"].iter().all(|k| !dict.get(k).unwrap().is_training()));

        dict.train();
        assert!(dict.is_training());
        assert!(["a", "b"].iter().all(|k| dict.get(k).unwrap().is_training()));
    }

    #[test]
    fn test_module_dict_default() {
        let dict = ModuleDict::<f32>::default();
        assert!(dict.is_empty());
        assert_eq!(dict.len(), 0);
    }

    #[test]
    fn test_module_dict_state_dict_roundtrip() {
        let dict = dict_of(&[("enc", 4), ("dec", 6)]);

        let sd = dict.state_dict();
        assert!(sd.contains_key("enc.weight"));
        assert!(sd.contains_key("dec.weight"));

        let mut dict2 = dict_of(&[("enc", 4), ("dec", 6)]);
        dict2.load_state_dict(&sd, true).unwrap();

        let sd2 = dict2.state_dict();
        for key in ["enc.weight", "dec.weight"] {
            assert_eq!(sd[key].data().unwrap(), sd2[key].data().unwrap());
        }
    }

    // -----------------------------------------------------------------------
    // Send + Sync
    // -----------------------------------------------------------------------

    #[test]
    fn test_containers_are_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Sequential<f32>>();
        assert_send_sync::<ModuleList<f32>>();
        assert_send_sync::<ModuleDict<f32>>();
    }
}