1use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
25
26use crate::module::Module;
27use crate::parameter::Parameter;
28
29pub struct Sequential<T: Float> {
54 layers: Vec<Box<dyn Module<T>>>,
55 training: bool,
56}
57
58impl<T: Float> Sequential<T> {
59 pub fn new(layers: Vec<Box<dyn Module<T>>>) -> Self {
61 Self {
62 layers,
63 training: true,
64 }
65 }
66
67 pub fn push(&mut self, layer: Box<dyn Module<T>>) {
69 self.layers.push(layer);
70 }
71
72 #[inline]
74 pub fn len(&self) -> usize {
75 self.layers.len()
76 }
77
78 #[inline]
80 pub fn is_empty(&self) -> bool {
81 self.layers.is_empty()
82 }
83}
84
85impl<T: Float> Module<T> for Sequential<T> {
86 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
91 if self.layers.is_empty() {
92 return Err(FerrotorchError::InvalidArgument {
93 message: "Sequential: cannot forward through empty container".into(),
94 });
95 }
96
97 let mut output = self.layers[0].forward(input)?;
98 for layer in &self.layers[1..] {
99 output = layer.forward(&output)?;
100 }
101 Ok(output)
102 }
103
104 fn parameters(&self) -> Vec<&Parameter<T>> {
105 self.layers
106 .iter()
107 .flat_map(|layer| layer.parameters())
108 .collect()
109 }
110
111 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
112 self.layers
113 .iter_mut()
114 .flat_map(|layer| layer.parameters_mut())
115 .collect()
116 }
117
118 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
119 self.layers
120 .iter()
121 .enumerate()
122 .flat_map(|(i, layer)| {
123 layer
124 .named_parameters()
125 .into_iter()
126 .map(move |(name, param)| (format!("{i}.{name}"), param))
127 })
128 .collect()
129 }
130
131 fn train(&mut self) {
132 self.training = true;
133 for layer in &mut self.layers {
134 layer.train();
135 }
136 }
137
138 fn eval(&mut self) {
139 self.training = false;
140 for layer in &mut self.layers {
141 layer.eval();
142 }
143 }
144
145 fn is_training(&self) -> bool {
146 self.training
147 }
148}
149
150impl<T: Float> std::fmt::Display for Sequential<T> {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 writeln!(f, "Sequential(")?;
153 for (i, _layer) in self.layers.iter().enumerate() {
154 writeln!(f, " ({i}): <module>")?;
155 }
156 write!(f, ")")
157 }
158}
159
160pub struct ModuleList<T: Float> {
195 modules: Vec<Box<dyn Module<T>>>,
196 training: bool,
197}
198
199impl<T: Float> ModuleList<T> {
200 pub fn new(modules: Vec<Box<dyn Module<T>>>) -> Self {
202 Self {
203 modules,
204 training: true,
205 }
206 }
207
208 pub fn empty() -> Self {
210 Self {
211 modules: Vec::new(),
212 training: true,
213 }
214 }
215
216 pub fn get(&self, index: usize) -> Option<&dyn Module<T>> {
218 self.modules.get(index).map(|m| m.as_ref())
219 }
220
221 pub fn get_mut(&mut self, index: usize) -> Option<&mut dyn Module<T>> {
223 match self.modules.get_mut(index) {
224 Some(m) => Some(m.as_mut()),
225 None => None,
226 }
227 }
228
229 pub fn push(&mut self, module: Box<dyn Module<T>>) {
231 self.modules.push(module);
232 }
233
234 #[inline]
236 pub fn len(&self) -> usize {
237 self.modules.len()
238 }
239
240 #[inline]
242 pub fn is_empty(&self) -> bool {
243 self.modules.is_empty()
244 }
245}
246
247impl<T: Float> Module<T> for ModuleList<T> {
248 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
252 Err(FerrotorchError::InvalidArgument {
253 message: "ModuleList does not implement forward. \
254 Iterate over the list and call each module's forward() manually."
255 .into(),
256 })
257 }
258
259 fn parameters(&self) -> Vec<&Parameter<T>> {
260 self.modules.iter().flat_map(|m| m.parameters()).collect()
261 }
262
263 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
264 self.modules
265 .iter_mut()
266 .flat_map(|m| m.parameters_mut())
267 .collect()
268 }
269
270 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
271 self.modules
272 .iter()
273 .enumerate()
274 .flat_map(|(i, m)| {
275 m.named_parameters()
276 .into_iter()
277 .map(move |(name, param)| (format!("{i}.{name}"), param))
278 })
279 .collect()
280 }
281
282 fn train(&mut self) {
283 self.training = true;
284 for m in &mut self.modules {
285 m.train();
286 }
287 }
288
289 fn eval(&mut self) {
290 self.training = false;
291 for m in &mut self.modules {
292 m.eval();
293 }
294 }
295
296 fn is_training(&self) -> bool {
297 self.training
298 }
299}
300
301impl<T: Float> std::fmt::Display for ModuleList<T> {
302 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303 writeln!(f, "ModuleList(")?;
304 for (i, _m) in self.modules.iter().enumerate() {
305 writeln!(f, " ({i}): <module>")?;
306 }
307 write!(f, ")")
308 }
309}
310
311pub struct ModuleDict<T: Float> {
349 entries: Vec<(String, Box<dyn Module<T>>)>,
350 training: bool,
351}
352
353impl<T: Float> ModuleDict<T> {
354 pub fn new() -> Self {
356 Self {
357 entries: Vec::new(),
358 training: true,
359 }
360 }
361
362 pub fn insert(&mut self, key: impl Into<String>, module: Box<dyn Module<T>>) {
367 let key = key.into();
368 for entry in &mut self.entries {
370 if entry.0 == key {
371 entry.1 = module;
372 return;
373 }
374 }
375 self.entries.push((key, module));
376 }
377
378 pub fn get(&self, key: &str) -> Option<&dyn Module<T>> {
380 self.entries
381 .iter()
382 .find(|(k, _)| k == key)
383 .map(|(_, m)| m.as_ref())
384 }
385
386 pub fn get_mut(&mut self, key: &str) -> Option<&mut dyn Module<T>> {
388 for (k, m) in &mut self.entries {
389 if k == key {
390 return Some(m.as_mut());
391 }
392 }
393 None
394 }
395
396 pub fn keys(&self) -> Vec<&str> {
398 self.entries.iter().map(|(k, _)| k.as_str()).collect()
399 }
400
401 #[inline]
403 pub fn len(&self) -> usize {
404 self.entries.len()
405 }
406
407 #[inline]
409 pub fn is_empty(&self) -> bool {
410 self.entries.is_empty()
411 }
412}
413
414impl<T: Float> Default for ModuleDict<T> {
415 fn default() -> Self {
416 Self::new()
417 }
418}
419
420impl<T: Float> Module<T> for ModuleDict<T> {
421 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
425 Err(FerrotorchError::InvalidArgument {
426 message: "ModuleDict does not implement forward. \
427 Look up modules by key and call forward() manually."
428 .into(),
429 })
430 }
431
432 fn parameters(&self) -> Vec<&Parameter<T>> {
433 self.entries
434 .iter()
435 .flat_map(|(_, m)| m.parameters())
436 .collect()
437 }
438
439 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
440 self.entries
441 .iter_mut()
442 .flat_map(|(_, m)| m.parameters_mut())
443 .collect()
444 }
445
446 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
447 self.entries
448 .iter()
449 .flat_map(|(key, m)| {
450 m.named_parameters()
451 .into_iter()
452 .map(move |(name, param)| (format!("{key}.{name}"), param))
453 })
454 .collect()
455 }
456
457 fn train(&mut self) {
458 self.training = true;
459 for (_, m) in &mut self.entries {
460 m.train();
461 }
462 }
463
464 fn eval(&mut self) {
465 self.training = false;
466 for (_, m) in &mut self.entries {
467 m.eval();
468 }
469 }
470
471 fn is_training(&self) -> bool {
472 self.training
473 }
474}
475
476impl<T: Float> std::fmt::Display for ModuleDict<T> {
477 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
478 writeln!(f, "ModuleDict(")?;
479 for (key, _m) in &self.entries {
480 writeln!(f, " ({key}): <module>")?;
481 }
482 write!(f, ")")
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double: `forward` returns a clone of its input, and the module
    /// exposes a single parameter named `"weight"` created with
    /// `Parameter::zeros(&[size])`.
    struct IdentityWithParam<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> IdentityWithParam<T> {
        /// Builds the double in training mode with a `[size]`-shaped weight.
        fn new(size: usize) -> FerrotorchResult<Self> {
            let module = Self {
                weight: Parameter::zeros(&[size])?,
                training: true,
            };
            Ok(module)
        }
    }

    impl<T: Float> Module<T> for IdentityWithParam<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            Ok(input.clone())
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![(String::from("weight"), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Shorthand for a boxed `f32` test double with a `[size]`-shaped weight.
    fn boxed_identity(size: usize) -> Box<dyn Module<f32>> {
        Box::new(IdentityWithParam::<f32>::new(size).unwrap())
    }

    // --- Sequential -------------------------------------------------------

    #[test]
    fn test_sequential_forward_chains_layers() {
        let seq = Sequential::<f32>::new(vec![
            boxed_identity(4),
            boxed_identity(4),
            boxed_identity(4),
        ]);

        let input = ferrotorch_core::zeros::<f32>(&[2, 4]).unwrap();
        let output = seq.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 4]);
    }

    #[test]
    fn test_sequential_empty_forward_errors() {
        let seq = Sequential::<f32>::new(Vec::new());
        let input = ferrotorch_core::zeros::<f32>(&[1, 4]).unwrap();
        assert!(seq.forward(&input).is_err());
    }

    #[test]
    fn test_sequential_parameter_count() {
        let seq = Sequential::<f32>::new(vec![
            boxed_identity(3),
            boxed_identity(5),
            boxed_identity(7),
        ]);

        let params = seq.parameters();
        assert_eq!(params.len(), 3);

        let total = params.iter().map(|p| p.numel()).sum::<usize>();
        assert_eq!(total, 3 + 5 + 7);
    }

    #[test]
    fn test_sequential_named_parameters_keys() {
        let seq = Sequential::<f32>::new(vec![
            boxed_identity(2),
            boxed_identity(3),
            boxed_identity(4),
        ]);

        let named = seq.named_parameters();
        let keys: Vec<&str> = named.iter().map(|(k, _)| k.as_str()).collect();
        assert_eq!(keys, &["0.weight", "1.weight", "2.weight"]);
    }

    #[test]
    fn test_sequential_train_eval_propagation() {
        let mut seq = Sequential::<f32>::new(vec![boxed_identity(2), boxed_identity(3)]);

        assert!(seq.is_training());

        seq.eval();
        assert!(!seq.is_training());
        assert!(seq.layers.iter().all(|layer| !layer.is_training()));

        seq.train();
        assert!(seq.is_training());
        assert!(seq.layers.iter().all(|layer| layer.is_training()));
    }

    #[test]
    fn test_sequential_state_dict_roundtrip() {
        let seq = Sequential::<f32>::new(vec![boxed_identity(3), boxed_identity(5)]);

        let sd = seq.state_dict();
        assert!(sd.contains_key("0.weight"));
        assert!(sd.contains_key("1.weight"));
        assert_eq!(sd["0.weight"].shape(), &[3]);
        assert_eq!(sd["1.weight"].shape(), &[5]);

        // A structurally identical container must accept the state dict.
        let mut seq2 = Sequential::<f32>::new(vec![boxed_identity(3), boxed_identity(5)]);
        seq2.load_state_dict(&sd, true).unwrap();

        let sd2 = seq2.state_dict();
        for key in ["0.weight", "1.weight"] {
            assert_eq!(sd[key].data().unwrap(), sd2[key].data().unwrap());
        }
    }

    #[test]
    fn test_sequential_push() {
        let mut seq = Sequential::<f32>::new(Vec::new());
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);

        seq.push(boxed_identity(4));
        assert_eq!(seq.len(), 1);
        assert!(!seq.is_empty());
    }

    // --- ModuleList -------------------------------------------------------

    #[test]
    fn test_module_list_forward_errors() {
        let list = ModuleList::<f32>::new(vec![boxed_identity(4)]);
        let input = ferrotorch_core::zeros::<f32>(&[1, 4]).unwrap();
        assert!(list.forward(&input).is_err());
    }

    #[test]
    fn test_module_list_get() {
        let list = ModuleList::<f32>::new(vec![boxed_identity(3), boxed_identity(5)]);

        assert!(list.get(0).is_some());
        assert!(list.get(1).is_some());
        assert!(list.get(2).is_none());
    }

    #[test]
    fn test_module_list_get_mut() {
        let mut list = ModuleList::<f32>::new(vec![boxed_identity(3)]);

        list.get_mut(0).unwrap().eval();
        assert!(!list.get(0).unwrap().is_training());
    }

    #[test]
    fn test_module_list_push() {
        let mut list = ModuleList::<f32>::empty();
        assert_eq!(list.len(), 0);
        assert!(list.is_empty());

        list.push(boxed_identity(4));
        assert_eq!(list.len(), 1);
        assert!(!list.is_empty());
    }

    #[test]
    fn test_module_list_parameters() {
        let list = ModuleList::<f32>::new(vec![boxed_identity(2), boxed_identity(3)]);

        assert_eq!(list.parameters().len(), 2);

        let named = list.named_parameters();
        let keys: Vec<&str> = named.iter().map(|(k, _)| k.as_str()).collect();
        assert_eq!(keys, &["0.weight", "1.weight"]);
    }

    #[test]
    fn test_module_list_train_eval() {
        let mut list = ModuleList::<f32>::new(vec![boxed_identity(2), boxed_identity(3)]);

        list.eval();
        assert!(!list.is_training());
        assert!((0..2).all(|i| !list.get(i).unwrap().is_training()));

        list.train();
        assert!(list.is_training());
        assert!((0..2).all(|i| list.get(i).unwrap().is_training()));
    }

    // --- ModuleDict -------------------------------------------------------

    #[test]
    fn test_module_dict_forward_errors() {
        let mut dict = ModuleDict::<f32>::new();
        dict.insert("enc", boxed_identity(4));
        let input = ferrotorch_core::zeros::<f32>(&[1, 4]).unwrap();
        assert!(dict.forward(&input).is_err());
    }

    #[test]
    fn test_module_dict_insert_get() {
        let mut dict = ModuleDict::<f32>::new();
        dict.insert("encoder", boxed_identity(3));
        dict.insert("decoder", boxed_identity(5));

        assert!(dict.get("encoder").is_some());
        assert!(dict.get("decoder").is_some());
        assert!(dict.get("missing").is_none());
        assert_eq!(dict.len(), 2);
    }

    #[test]
    fn test_module_dict_insert_replaces() {
        let mut dict = ModuleDict::<f32>::new();
        dict.insert("layer", boxed_identity(3));
        dict.insert("layer", boxed_identity(7));

        assert_eq!(dict.len(), 1);
        // The surviving entry must be the second module (weight shape [7]).
        let named = dict.named_parameters();
        assert_eq!(named.len(), 1);
        let (_, param) = &named[0];
        assert_eq!(param.shape(), &[7]);
    }

    #[test]
    fn test_module_dict_keys_insertion_order() {
        let mut dict = ModuleDict::<f32>::new();
        dict.insert("c_layer", boxed_identity(1));
        dict.insert("a_layer", boxed_identity(2));
        dict.insert("b_layer", boxed_identity(3));

        assert_eq!(dict.keys(), &["c_layer", "a_layer", "b_layer"]);
    }

    #[test]
    fn test_module_dict_get_mut() {
        let mut dict = ModuleDict::<f32>::new();
        dict.insert("layer", boxed_identity(3));

        dict.get_mut("layer").unwrap().eval();
        assert!(!dict.get("layer").unwrap().is_training());
    }

    #[test]
    fn test_module_dict_named_parameters_prefixed_by_key() {
        let mut dict = ModuleDict::<f32>::new();
        dict.insert("encoder", boxed_identity(3));
        dict.insert("decoder", boxed_identity(5));

        let named = dict.named_parameters();
        let keys: Vec<&str> = named.iter().map(|(k, _)| k.as_str()).collect();
        assert_eq!(keys, &["encoder.weight", "decoder.weight"]);
    }

    #[test]
    fn test_module_dict_train_eval() {
        let mut dict = ModuleDict::<f32>::new();
        dict.insert("a", boxed_identity(2));
        dict.insert("b", boxed_identity(3));

        dict.eval();
        assert!(!dict.is_training());
        assert!(["a", "b"]
            .iter()
            .all(|key| !dict.get(key).unwrap().is_training()));

        dict.train();
        assert!(dict.is_training());
        assert!(["a", "b"]
            .iter()
            .all(|key| dict.get(key).unwrap().is_training()));
    }

    #[test]
    fn test_module_dict_default() {
        let dict = ModuleDict::<f32>::default();
        assert!(dict.is_empty());
        assert_eq!(dict.len(), 0);
    }

    #[test]
    fn test_module_dict_state_dict_roundtrip() {
        let mut dict = ModuleDict::<f32>::new();
        dict.insert("enc", boxed_identity(4));
        dict.insert("dec", boxed_identity(6));

        let sd = dict.state_dict();
        assert!(sd.contains_key("enc.weight"));
        assert!(sd.contains_key("dec.weight"));

        let mut dict2 = ModuleDict::<f32>::new();
        dict2.insert("enc", boxed_identity(4));
        dict2.insert("dec", boxed_identity(6));
        dict2.load_state_dict(&sd, true).unwrap();

        let sd2 = dict2.state_dict();
        for key in ["enc.weight", "dec.weight"] {
            assert_eq!(sd[key].data().unwrap(), sd2[key].data().unwrap());
        }
    }

    // --- Auto traits ------------------------------------------------------

    #[test]
    fn test_containers_are_send_sync() {
        // Compile-time check only: the calls below do nothing at runtime.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Sequential<f32>>();
        assert_send_sync::<ModuleList<f32>>();
        assert_send_sync::<ModuleDict<f32>>();
    }
}