1use alloc::boxed::Box;
2use alloc::string::{String, ToString};
3use alloc::vec::Vec;
4
5use burn_tensor::{Bool, Int, Tensor, backend::Backend};
6
7use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
8use burn_core::module::{ModuleVisitor, Param, ParamId};
9
/// Module visitor that records a [`TensorSnapshot`] for every tensor it visits.
///
/// While walking a module tree it keeps a stack of entered module names
/// (`path_stack`, the tensor path) and a stack of entered container types
/// (`container_stack`). A snapshot is pushed only when the optional
/// [`PathFilter`] accepts the current path.
pub struct Collector {
    /// Snapshots collected so far, in visit order.
    pub tensors: Vec<TensorSnapshot>,
    /// Names of the modules currently entered; joined they form the tensor path.
    path_stack: Vec<String>,
    /// Container type of every module currently entered (e.g. `Struct:Linear`).
    /// Unlike `path_stack`, this is pushed for every module, enums included.
    container_stack: Vec<String>,
    /// Optional path filter; `None` collects every tensor.
    filter: Option<PathFilter>,
    /// Optional adapter applied to each snapshot by `into_tensors`.
    adapter: Option<Box<dyn ModuleAdapter>>,
    /// When `true`, modules whose container type starts with `Enum:` do not add
    /// a segment to `path_stack`.
    skip_enum_variants: bool,
}
55
56impl Default for Collector {
57 fn default() -> Self {
58 Self::new(None, None, false)
59 }
60}
61
62impl Collector {
63 pub fn new(
93 filter: Option<PathFilter>,
94 adapter: Option<Box<dyn ModuleAdapter>>,
95 skip_enum_variants: bool,
96 ) -> Self {
97 Self {
98 tensors: Vec::new(),
99 path_stack: Vec::new(),
100 container_stack: Vec::new(),
101 filter,
102 adapter,
103 skip_enum_variants,
104 }
105 }
106
107 pub fn into_tensors(self) -> Vec<TensorSnapshot> {
109 if let Some(adapter) = self.adapter {
110 self.tensors
111 .into_iter()
112 .map(|snapshot| adapter.adapt(&snapshot))
113 .collect()
114 } else {
115 self.tensors
116 }
117 }
118
119 fn should_collect(&self, path: &[String], container_stack: &[String]) -> bool {
120 match &self.filter {
122 None => true,
123 Some(f) => f.matches_with_container_path(path, container_stack),
124 }
125 }
126}
127
128impl<B: Backend> ModuleVisitor<B> for Collector {
129 fn enter_module(&mut self, name: &str, container_type: &str) {
130 self.container_stack.push(container_type.to_string());
132
133 if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
136 self.path_stack.push(name.to_string());
137 }
138 }
139
140 fn exit_module(&mut self, _name: &str, container_type: &str) {
141 self.container_stack.pop();
142
143 if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
145 self.path_stack.pop();
146 }
147 }
148
149 fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
150 if self.should_collect(&self.path_stack, &self.container_stack) {
151 self.tensors.push(TensorSnapshot::from_float(
152 ¶m.transform_for_save().val(),
153 self.path_stack.clone(),
154 self.container_stack.clone(),
155 param.id,
156 ));
157 }
158 }
159
160 fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {
161 if self.should_collect(&self.path_stack, &self.container_stack) {
162 self.tensors.push(TensorSnapshot::from_int(
163 ¶m.transform_for_save().val(),
164 self.path_stack.clone(),
165 self.container_stack.clone(),
166 param.id,
167 ));
168 }
169 }
170
171 fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {
172 if self.should_collect(&self.path_stack, &self.container_stack) {
173 self.tensors.push(TensorSnapshot::from_bool(
174 ¶m.transform_for_save().val(),
175 self.path_stack.clone(),
176 self.container_stack.clone(),
177 param.id,
178 ));
179 }
180 }
181
182 fn visit_float_with_path<const D: usize>(
183 &mut self,
184 path: &[String],
185 id: ParamId,
186 tensor: &Tensor<B, D>,
187 ) {
188 if self.should_collect(path, &self.container_stack) {
190 self.tensors.push(TensorSnapshot::from_float(
191 tensor,
192 path.to_vec(),
193 self.container_stack.clone(),
194 id,
195 ));
196 }
197 }
198
199 fn visit_int_with_path<const D: usize>(
200 &mut self,
201 path: &[String],
202 id: ParamId,
203 tensor: &Tensor<B, D, Int>,
204 ) {
205 if self.should_collect(path, &self.container_stack) {
206 self.tensors.push(TensorSnapshot::from_int(
207 tensor,
208 path.to_vec(),
209 self.container_stack.clone(),
210 id,
211 ));
212 }
213 }
214
215 fn visit_bool_with_path<const D: usize>(
216 &mut self,
217 path: &[String],
218 id: ParamId,
219 tensor: &Tensor<B, D, Bool>,
220 ) {
221 if self.should_collect(path, &self.container_stack) {
222 self.tensors.push(TensorSnapshot::from_bool(
223 tensor,
224 path.to_vec(),
225 self.container_stack.clone(),
226 id,
227 ));
228 }
229 }
230}
231
232#[cfg(all(test, feature = "std"))]
233mod tests {
234 use super::*;
235
236 use burn_core as burn;
237
    /// Backend used by every test in this module.
    type TestBackend = burn_ndarray::NdArray;
239 use alloc::collections::BTreeMap;
240 use alloc::string::String;
241 use burn_core::module::{Module, Param};
242 use burn_nn::LinearConfig;
243
244 #[test]
245 fn tensor_snapshot_collector() {
246 let device = Default::default();
247 let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
248
249 let mut collector = Collector::new(None, None, false);
250 let id = ParamId::new();
251
252 collector.visit_float_with_path(&["model".to_string(), "weight".to_string()], id, &tensor);
254
255 assert_eq!(collector.tensors.len(), 1);
256 assert_eq!(collector.tensors[0].full_path(), "model.weight");
257
258 let view = &collector.tensors[0];
260 let data = view.to_data().unwrap();
261 assert_eq!(data.shape, vec![2, 2]);
262 }
263
264 #[test]
265 fn root_level_parameters() {
266 use burn_core::module::ModuleVisitor;
267
268 let device = Default::default();
269
270 let weight = Param::<Tensor<TestBackend, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
272 let bias = Param::<Tensor<TestBackend, 1>>::from_data([5.0, 6.0], &device);
273
274 let mut collector = Collector::new(None, None, false);
275
276 ModuleVisitor::<TestBackend>::enter_module(&mut collector, "weight", "");
279 ModuleVisitor::<TestBackend>::visit_float(&mut collector, &weight);
280 ModuleVisitor::<TestBackend>::exit_module(&mut collector, "weight", "");
281
282 ModuleVisitor::<TestBackend>::enter_module(&mut collector, "bias", "");
284 ModuleVisitor::<TestBackend>::visit_float(&mut collector, &bias);
285 ModuleVisitor::<TestBackend>::exit_module(&mut collector, "bias", "");
286
287 assert_eq!(collector.tensors.len(), 2);
289
290 assert_eq!(collector.tensors[0].full_path(), "weight");
292 assert_eq!(collector.tensors[1].full_path(), "bias");
293
294 let weight_data = collector.tensors[0]
296 .to_data()
297 .unwrap()
298 .to_vec::<f32>()
299 .unwrap();
300 let bias_data = collector.tensors[1]
301 .to_data()
302 .unwrap()
303 .to_vec::<f32>()
304 .unwrap();
305
306 assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]);
307 assert_eq!(bias_data, vec![5.0, 6.0]);
308 }
309
310 #[test]
311 #[cfg(target_has_atomic = "ptr")]
312 fn tensor_snapshot_collector_with_filter() {
313 let device = Default::default();
314 let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
315
316 let filter = PathFilter::new().with_regex(r"^encoder\..*");
317 let mut collector = Collector::new(Some(filter), None, false);
318 let id = ParamId::new();
319
320 collector.visit_float_with_path(
322 &["encoder".to_string(), "weight".to_string()],
323 id,
324 &tensor,
325 );
326 collector.visit_float_with_path(
328 &["decoder".to_string(), "weight".to_string()],
329 id,
330 &tensor,
331 );
332
333 assert_eq!(collector.tensors.len(), 1);
334 assert_eq!(collector.tensors[0].full_path(), "encoder.weight");
335 }
336
337 #[test]
338 #[cfg(target_has_atomic = "ptr")]
339 fn tensor_snapshot_collector_with_multiple_filters() {
340 let device = Default::default();
341 let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
342
343 let filter = PathFilter::new()
345 .with_regex(r"^encoder\..*") .with_regex(r".*\.bias$"); let mut collector = Collector::new(Some(filter), None, false);
348 let id = ParamId::new();
349
350 collector.visit_float_with_path(
352 &["encoder".to_string(), "weight".to_string()],
353 id,
354 &tensor,
355 ); collector.visit_float_with_path(&["decoder".to_string(), "bias".to_string()], id, &tensor); collector.visit_float_with_path(&["encoder".to_string(), "bias".to_string()], id, &tensor); collector.visit_float_with_path(
361 &["decoder".to_string(), "weight".to_string()],
362 id,
363 &tensor,
364 ); assert_eq!(collector.tensors.len(), 3);
367 let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
368 assert!(paths.contains(&"encoder.weight".to_string()));
369 assert!(paths.contains(&"decoder.bias".to_string()));
370 assert!(paths.contains(&"encoder.bias".to_string()));
371 assert!(!paths.contains(&"decoder.weight".to_string()));
372 }
373
374 #[test]
375 fn tensor_snapshot_collector_with_predicate() {
376 let device = Default::default();
377 let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
378
379 fn filter_fn(path: &str, _container_path: &str) -> bool {
381 path.starts_with("encoder.") || path == "decoder.bias"
382 }
383 let filter = PathFilter::new().with_predicate(filter_fn);
384 let mut collector = Collector::new(Some(filter), None, false);
385 let id = ParamId::new();
386
387 collector.visit_float_with_path(
389 &["encoder".to_string(), "weight".to_string()],
390 id,
391 &tensor,
392 );
393 collector.visit_float_with_path(&["encoder".to_string(), "bias".to_string()], id, &tensor);
394 collector.visit_float_with_path(&["decoder".to_string(), "bias".to_string()], id, &tensor);
395
396 collector.visit_float_with_path(
398 &["decoder".to_string(), "weight".to_string()],
399 id,
400 &tensor,
401 );
402 collector.visit_float_with_path(&["other".to_string(), "tensor".to_string()], id, &tensor);
403
404 assert_eq!(collector.tensors.len(), 3);
405 let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
406 assert!(paths.contains(&"encoder.weight".to_string()));
407 assert!(paths.contains(&"encoder.bias".to_string()));
408 assert!(paths.contains(&"decoder.bias".to_string()));
409 assert!(!paths.contains(&"decoder.weight".to_string()));
410 assert!(!paths.contains(&"other.tensor".to_string()));
411 }
412
413 #[test]
414 fn tensor_snapshot_collector_predicate_with_complex_logic() {
415 let device = Default::default();
416 let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
417
418 fn complex_filter(path: &str, _container_path: &str) -> bool {
420 let parts: Vec<&str> = path.split('.').collect();
421 if parts.len() != 3 {
422 return false;
423 }
424 (parts[1] == "layer1" || parts[1] == "layer2") && parts[2] == "weight"
426 }
427 let filter = PathFilter::new().with_predicate(complex_filter);
428 let mut collector = Collector::new(Some(filter), None, false);
429 let id = ParamId::new();
430
431 collector.visit_float_with_path(
433 &[
434 "model".to_string(),
435 "layer1".to_string(),
436 "weight".to_string(),
437 ],
438 id,
439 &tensor,
440 );
441 collector.visit_float_with_path(
442 &[
443 "model".to_string(),
444 "layer2".to_string(),
445 "weight".to_string(),
446 ],
447 id,
448 &tensor,
449 );
450
451 collector.visit_float_with_path(
453 &[
454 "model".to_string(),
455 "layer1".to_string(),
456 "bias".to_string(),
457 ],
458 id,
459 &tensor,
460 );
461 collector.visit_float_with_path(
462 &[
463 "model".to_string(),
464 "layer3".to_string(),
465 "weight".to_string(),
466 ],
467 id,
468 &tensor,
469 );
470 collector.visit_float_with_path(
471 &["encoder".to_string(), "weight".to_string()],
472 id,
473 &tensor,
474 ); assert_eq!(collector.tensors.len(), 2);
477 let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
478 assert!(paths.contains(&"model.layer1.weight".to_string()));
479 assert!(paths.contains(&"model.layer2.weight".to_string()));
480 assert!(!paths.contains(&"model.layer1.bias".to_string()));
481 assert!(!paths.contains(&"model.layer3.weight".to_string()));
482 assert!(!paths.contains(&"encoder.weight".to_string()));
483 }
484
    /// Minimal test visitor that maps each tensor's dotted path to its parameter id
    /// and shape, used to check path construction independently of `Collector`.
    struct TensorPathCollector {
        /// Dotted tensor path -> (parameter id, shape of `transform_for_save().val()`).
        pub paths: BTreeMap<String, (ParamId, Vec<usize>)>,
        /// Names of the modules currently entered.
        path_stack: Vec<String>,
    }
490
491 impl TensorPathCollector {
492 fn new() -> Self {
493 Self {
494 paths: BTreeMap::new(),
495 path_stack: Vec::new(),
496 }
497 }
498
499 fn current_path(&self) -> String {
500 self.path_stack.join(".")
501 }
502 }
503
504 impl<B: Backend> ModuleVisitor<B> for TensorPathCollector {
505 fn enter_module(&mut self, name: &str, _container_type: &str) {
506 self.path_stack.push(name.to_string());
507 }
508
509 fn exit_module(&mut self, _name: &str, _container_type: &str) {
510 self.path_stack.pop();
511 }
512
513 fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
514 let path = self.current_path();
515 if !path.is_empty() {
516 self.paths.insert(
517 path,
518 (param.id, param.transform_for_save().val().shape().to_vec()),
519 );
520 }
521 }
522
523 fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {
524 let path = self.current_path();
525 if !path.is_empty() {
526 self.paths.insert(
527 path,
528 (param.id, param.transform_for_save().val().shape().to_vec()),
529 );
530 }
531 }
532
533 fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {
534 let path = self.current_path();
535 if !path.is_empty() {
536 self.paths.insert(
537 path,
538 (param.id, param.transform_for_save().val().shape().to_vec()),
539 );
540 }
541 }
542 }
543
544 #[derive(Module, Debug)]
546 struct InnerModule<B: Backend> {
547 weight: Param<Tensor<B, 2>>,
548 bias: Param<Tensor<B, 1>>,
549 }
550
551 #[derive(Module, Debug)]
552 struct OuterModule<B: Backend> {
553 layer1: InnerModule<B>,
554 layer2: InnerModule<B>,
555 }
556
557 impl<B: Backend> InnerModule<B> {
558 fn new(device: &B::Device) -> Self {
559 Self {
560 weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
561 bias: Param::from_data([5.0, 6.0], device),
562 }
563 }
564 }
565
566 impl<B: Backend> OuterModule<B> {
567 fn new(device: &B::Device) -> Self {
568 Self {
569 layer1: InnerModule::new(device),
570 layer2: InnerModule::new(device),
571 }
572 }
573 }
574
575 #[test]
576 fn nested_module_path_tracking() {
577 let device = Default::default();
578 let module = OuterModule::<TestBackend>::new(&device);
579
580 let mut collector = TensorPathCollector::new();
581 module.visit(&mut collector);
582
583 let paths = collector.paths;
584
585 assert!(paths.contains_key("layer1.weight"), "Missing layer1.weight");
588 assert!(paths.contains_key("layer1.bias"), "Missing layer1.bias");
589 assert!(paths.contains_key("layer2.weight"), "Missing layer2.weight");
590 assert!(paths.contains_key("layer2.bias"), "Missing layer2.bias");
591
592 assert_eq!(paths.get("layer1.weight").unwrap().1, vec![2, 2]);
594 assert_eq!(paths.get("layer1.bias").unwrap().1, vec![2]);
595 assert_eq!(paths.get("layer2.weight").unwrap().1, vec![2, 2]);
596 assert_eq!(paths.get("layer2.bias").unwrap().1, vec![2]);
597 }
598
599 #[test]
600 fn linear_module_paths() {
601 let device = Default::default();
602 let config = LinearConfig::new(10, 20).with_bias(true);
603 let linear = config.init::<TestBackend>(&device);
604
605 let mut collector = TensorPathCollector::new();
606 linear.visit(&mut collector);
607
608 let paths = collector.paths;
609
610 assert!(paths.contains_key("weight"));
612 assert!(paths.contains_key("bias"));
613
614 assert_eq!(paths.get("weight").unwrap().1, vec![10, 20]);
616 assert_eq!(paths.get("bias").unwrap().1, vec![20]);
617 }
618
619 #[derive(Module, Debug)]
621 struct Level4Module<B: Backend> {
622 weight: Param<Tensor<B, 2>>,
623 bias: Param<Tensor<B, 1>>,
624 }
625
626 #[derive(Module, Debug)]
627 struct Level3Module<B: Backend> {
628 layer: Level4Module<B>,
629 extra: Level4Module<B>,
630 }
631
632 #[derive(Module, Debug)]
633 struct Level2Module<B: Backend> {
634 block1: Level3Module<B>,
635 block2: Level3Module<B>,
636 }
637
638 #[derive(Module, Debug)]
639 struct Level1Module<B: Backend> {
640 encoder: Level2Module<B>,
641 decoder: Level2Module<B>,
642 }
643
644 #[derive(Module, Debug)]
645 struct DeepModel<B: Backend> {
646 backbone: Level1Module<B>,
647 head: Level4Module<B>,
648 }
649
650 impl<B: Backend> Level4Module<B> {
651 fn new(device: &B::Device) -> Self {
652 Self {
653 weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
654 bias: Param::from_data([5.0, 6.0], device),
655 }
656 }
657 }
658
659 impl<B: Backend> Level3Module<B> {
660 fn new(device: &B::Device) -> Self {
661 Self {
662 layer: Level4Module::new(device),
663 extra: Level4Module::new(device),
664 }
665 }
666 }
667
668 impl<B: Backend> Level2Module<B> {
669 fn new(device: &B::Device) -> Self {
670 Self {
671 block1: Level3Module::new(device),
672 block2: Level3Module::new(device),
673 }
674 }
675 }
676
677 impl<B: Backend> Level1Module<B> {
678 fn new(device: &B::Device) -> Self {
679 Self {
680 encoder: Level2Module::new(device),
681 decoder: Level2Module::new(device),
682 }
683 }
684 }
685
686 impl<B: Backend> DeepModel<B> {
687 fn new(device: &B::Device) -> Self {
688 Self {
689 backbone: Level1Module::new(device),
690 head: Level4Module::new(device),
691 }
692 }
693 }
694
695 #[test]
696 fn deep_module_path_tracking() {
697 let device = Default::default();
698 let model = DeepModel::<TestBackend>::new(&device);
699
700 let mut collector = Collector::new(None, None, false);
701 model.visit(&mut collector);
702
703 let views = collector.tensors;
704 let paths: Vec<String> = views.iter().map(|v| v.full_path()).collect();
705
706 assert!(paths.contains(&"backbone.encoder.block1.layer.weight".to_string()));
708 assert!(paths.contains(&"backbone.encoder.block1.layer.bias".to_string()));
709 assert!(paths.contains(&"backbone.encoder.block1.extra.weight".to_string()));
710 assert!(paths.contains(&"backbone.encoder.block1.extra.bias".to_string()));
711
712 assert!(paths.contains(&"backbone.encoder.block2.layer.weight".to_string()));
713 assert!(paths.contains(&"backbone.encoder.block2.layer.bias".to_string()));
714 assert!(paths.contains(&"backbone.encoder.block2.extra.weight".to_string()));
715 assert!(paths.contains(&"backbone.encoder.block2.extra.bias".to_string()));
716
717 assert!(paths.contains(&"backbone.decoder.block1.layer.weight".to_string()));
718 assert!(paths.contains(&"backbone.decoder.block1.layer.bias".to_string()));
719 assert!(paths.contains(&"backbone.decoder.block1.extra.weight".to_string()));
720 assert!(paths.contains(&"backbone.decoder.block1.extra.bias".to_string()));
721
722 assert!(paths.contains(&"backbone.decoder.block2.layer.weight".to_string()));
723 assert!(paths.contains(&"backbone.decoder.block2.layer.bias".to_string()));
724 assert!(paths.contains(&"backbone.decoder.block2.extra.weight".to_string()));
725 assert!(paths.contains(&"backbone.decoder.block2.extra.bias".to_string()));
726
727 assert!(paths.contains(&"head.weight".to_string()));
729 assert!(paths.contains(&"head.bias".to_string()));
730
731 assert_eq!(views.len(), 18);
733
734 let view = views
736 .iter()
737 .find(|v| v.full_path() == "backbone.encoder.block1.layer.weight")
738 .unwrap();
739 let data = view.to_data().unwrap();
740 assert_eq!(data.shape, vec![2, 2]);
741 }
742
743 #[test]
744 fn deep_module_filtered_export() {
745 let device = Default::default();
746 let model = DeepModel::<TestBackend>::new(&device);
747
748 #[cfg(target_has_atomic = "ptr")]
750 {
751 let filter = PathFilter::new().with_regex(r"^backbone\.encoder\..*");
752 let mut collector = Collector::new(Some(filter), None, false);
753 model.visit(&mut collector);
754 assert_eq!(collector.tensors.len(), 8); }
756
757 #[cfg(target_has_atomic = "ptr")]
759 {
760 let filter = PathFilter::new().with_regex(r".*\.block1\..*");
761 let mut collector = Collector::new(Some(filter), None, false);
762 model.visit(&mut collector);
763 assert_eq!(collector.tensors.len(), 8); }
765
766 #[cfg(target_has_atomic = "ptr")]
768 {
769 let filter = PathFilter::new().with_regex(r".*\.weight$");
770 let mut collector = Collector::new(Some(filter), None, false);
771 model.visit(&mut collector);
772 assert_eq!(collector.tensors.len(), 9); }
774
775 #[cfg(target_has_atomic = "ptr")]
777 {
778 let filter = PathFilter::new()
779 .with_regex(r"^backbone\.encoder\.block1\..*") .with_regex(r"^backbone\.decoder\..*\.bias$") .with_regex(r"^head\.weight$"); let mut collector = Collector::new(Some(filter), None, false);
783 model.visit(&mut collector);
784
785 assert_eq!(collector.tensors.len(), 9);
790
791 let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
792 assert!(paths.contains(&"backbone.encoder.block1.layer.weight".to_string()));
793 assert!(paths.contains(&"backbone.decoder.block1.layer.bias".to_string()));
794 assert!(paths.contains(&"head.weight".to_string()));
795 assert!(!paths.contains(&"head.bias".to_string())); }
797 }
798
799 use crate::traits::ModuleSnapshot;
800 use burn_nn::Linear;
801 use hashbrown::HashMap;
802
803 #[derive(Module, Debug)]
805 struct OptionalFieldModule<B: Backend> {
806 required: Param<Tensor<B, 2>>,
807 optional: Option<Param<Tensor<B, 1>>>,
808 }
809
810 impl<B: Backend> OptionalFieldModule<B> {
811 fn new_with_optional(device: &B::Device) -> Self {
812 Self {
813 required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
814 optional: Some(Param::from_data([5.0, 6.0], device)),
815 }
816 }
817
818 fn new_without_optional(device: &B::Device) -> Self {
819 Self {
820 required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
821 optional: None,
822 }
823 }
824 }
825
826 #[test]
827 fn optional_field_module_with_value() {
828 let device = Default::default();
829 let module = OptionalFieldModule::<TestBackend>::new_with_optional(&device);
830
831 let views: HashMap<String, TensorSnapshot> = module
832 .collect(None, None, false)
833 .into_iter()
834 .map(|v| (v.full_path(), v))
835 .collect();
836
837 assert_eq!(views.len(), 2);
838 assert!(views.contains_key("required"));
839 assert!(views.contains_key("optional"));
840 }
841
842 #[test]
843 fn optional_field_module_without_value() {
844 let device = Default::default();
845 let module = OptionalFieldModule::<TestBackend>::new_without_optional(&device);
846
847 let views: HashMap<String, TensorSnapshot> = module
848 .collect(None, None, false)
849 .into_iter()
850 .map(|v| (v.full_path(), v))
851 .collect();
852
853 assert_eq!(views.len(), 1);
854 assert!(views.contains_key("required"));
855 assert!(!views.contains_key("optional"));
856 }
857
858 #[derive(Module, Debug)]
860 struct VecModule<B: Backend> {
861 layers: Vec<Linear<B>>,
862 }
863
864 impl<B: Backend> VecModule<B> {
865 fn new(device: &B::Device, num_layers: usize) -> Self {
866 Self {
867 layers: (0..num_layers)
868 .map(|_| LinearConfig::new(10, 10).init(device))
869 .collect(),
870 }
871 }
872 }
873
874 #[derive(Module, Debug)]
876 struct TupleModule<B: Backend> {
877 layers: (Linear<B>, Linear<B>, Linear<B>),
878 }
879
880 impl<B: Backend> TupleModule<B> {
881 fn new(device: &B::Device) -> Self {
882 Self {
883 layers: (
884 LinearConfig::new(10, 10).init(device),
885 LinearConfig::new(10, 10).init(device),
886 LinearConfig::new(10, 10).init(device),
887 ),
888 }
889 }
890 }
891
892 #[test]
893 fn vec_module_collect() {
894 let device = Default::default();
895 let module = VecModule::<TestBackend>::new(&device, 3);
896
897 let views: HashMap<String, TensorSnapshot> = module
898 .collect(None, None, false)
899 .into_iter()
900 .map(|v| (v.full_path(), v))
901 .collect();
902
903 assert_eq!(views.len(), 6); assert!(views.contains_key("layers.0.weight"));
908 assert!(views.contains_key("layers.0.bias"));
909 assert!(views.contains_key("layers.1.weight"));
910 assert!(views.contains_key("layers.1.bias"));
911 assert!(views.contains_key("layers.2.weight"));
912 assert!(views.contains_key("layers.2.bias"));
913 }
914
915 #[test]
916 fn tuple_module_collect() {
917 let device = Default::default();
918 let module = TupleModule::<TestBackend>::new(&device);
919
920 let snapshots = module.collect(None, None, false);
921 assert_eq!(snapshots.len(), 6);
922
923 let views: HashMap<String, TensorSnapshot> =
924 snapshots.into_iter().map(|v| (v.full_path(), v)).collect();
925
926 assert_eq!(views.len(), 6);
927
928 assert!(views.contains_key("layers.0.weight"));
929 assert!(views.contains_key("layers.0.bias"));
930 assert!(views.contains_key("layers.1.weight"));
931 assert!(views.contains_key("layers.1.bias"));
932 assert!(views.contains_key("layers.2.weight"));
933 assert!(views.contains_key("layers.2.bias"));
934 }
935
936 #[derive(Module, Debug)]
938 struct ArrayModule<B: Backend> {
939 layers: [Linear<B>; 3],
940 }
941
942 impl<B: Backend> ArrayModule<B> {
943 fn new(device: &B::Device) -> Self {
944 Self {
945 layers: [
946 LinearConfig::new(10, 10).init(device),
947 LinearConfig::new(10, 10).init(device),
948 LinearConfig::new(10, 10).init(device),
949 ],
950 }
951 }
952 }
953
954 #[test]
955 fn array_module_collect() {
956 let device = Default::default();
957 let module = ArrayModule::<TestBackend>::new(&device);
958
959 let views: HashMap<String, TensorSnapshot> = module
960 .collect(None, None, false)
961 .into_iter()
962 .map(|v| (v.full_path(), v))
963 .collect();
964
965 assert_eq!(views.len(), 6); for i in 0..3 {
970 assert!(views.contains_key(&format!("layers.{}.weight", i)));
971 assert!(views.contains_key(&format!("layers.{}.bias", i)));
972 }
973 }
974
    /// Enum fixture whose variants each wrap a `Linear`. Used to check that, when
    /// enum variants are not skipped, the variant name appears as a path segment.
    #[derive(Module, Debug)]
    enum EnumModule<B: Backend> {
        LayerA(Linear<B>),
        LayerB(Linear<B>),
        LayerC(Linear<B>),
    }
982
983 #[test]
984 fn enum_module_collect() {
985 let device = Default::default();
986
987 let module_a = EnumModule::<TestBackend>::LayerA(LinearConfig::new(10, 20).init(&device));
989 let views_a: HashMap<String, TensorSnapshot> = module_a
990 .collect(None, None, false)
991 .into_iter()
992 .map(|v| (v.full_path(), v))
993 .collect();
994
995 assert_eq!(views_a.len(), 2);
997 assert!(views_a.contains_key("LayerA.weight"));
998 assert!(views_a.contains_key("LayerA.bias"));
999
1000 let module_b = EnumModule::<TestBackend>::LayerB(LinearConfig::new(10, 20).init(&device));
1002 let views_b: HashMap<String, TensorSnapshot> = module_b
1003 .collect(None, None, false)
1004 .into_iter()
1005 .map(|v| (v.full_path(), v))
1006 .collect();
1007
1008 assert_eq!(views_b.len(), 2);
1009 assert!(views_b.contains_key("LayerB.weight"));
1010 assert!(views_b.contains_key("LayerB.bias"));
1011 }
1012
1013 #[test]
1015 fn linear_container_type() {
1016 let device = Default::default();
1017
1018 #[derive(Module, Debug)]
1019 struct ModelWithLinear<B: Backend> {
1020 linear: Linear<B>,
1021 }
1022
1023 impl<B: Backend> ModelWithLinear<B> {
1024 fn new(device: &B::Device) -> Self {
1025 Self {
1026 linear: LinearConfig::new(10, 20).init(device),
1027 }
1028 }
1029 }
1030
1031 let model = ModelWithLinear::<TestBackend>::new(&device);
1032
1033 let views: HashMap<String, TensorSnapshot> = model
1034 .collect(None, None, false)
1035 .into_iter()
1036 .map(|v| (v.full_path(), v))
1037 .collect();
1038
1039 for (path, view) in views.iter() {
1041 if path == "linear.weight" || path == "linear.bias" {
1042 assert_eq!(
1043 view.module_type(),
1044 Some("Struct:Linear".to_string()),
1045 "Tensor '{}' should have module type 'Struct:Linear'",
1046 path
1047 );
1048 }
1049 }
1050 }
1051
1052 #[test]
1053 fn complex_model_container_types() {
1054 let device = Default::default();
1055
1056 #[derive(Module, Debug)]
1057 struct ComplexModel<B: Backend> {
1058 linear_layers: [Linear<B>; 2],
1059 vec_layers: Vec<Linear<B>>,
1060 single_linear: Linear<B>,
1061 }
1062
1063 impl<B: Backend> ComplexModel<B> {
1064 fn new(device: &B::Device) -> Self {
1065 Self {
1066 linear_layers: [
1067 LinearConfig::new(100, 50).init(device),
1068 LinearConfig::new(50, 10).init(device),
1069 ],
1070 vec_layers: vec![
1071 LinearConfig::new(10, 10).init(device),
1072 LinearConfig::new(10, 10).init(device),
1073 ],
1074 single_linear: LinearConfig::new(10, 1).init(device),
1075 }
1076 }
1077 }
1078
1079 let model = ComplexModel::<TestBackend>::new(&device);
1080
1081 let views: HashMap<String, TensorSnapshot> = model
1082 .collect(None, None, false)
1083 .into_iter()
1084 .map(|v| (v.full_path(), v))
1085 .collect();
1086
1087 assert_eq!(views.len(), 10);
1089
1090 for (_path, view) in views.iter() {
1092 assert_eq!(view.module_type(), Some("Struct:Linear".to_string()));
1093 }
1094 }
1095
1096 #[test]
1097 fn collect_with_container_filter() {
1098 let device = Default::default();
1099
1100 #[derive(Module, Debug)]
1101 struct FilterTestModel<B: Backend> {
1102 layers: Vec<Linear<B>>,
1103 }
1104
1105 impl<B: Backend> FilterTestModel<B> {
1106 fn new(device: &B::Device) -> Self {
1107 Self {
1108 layers: vec![
1109 LinearConfig::new(10, 10).init(device),
1110 LinearConfig::new(10, 10).init(device),
1111 ],
1112 }
1113 }
1114 }
1115
1116 let model = FilterTestModel::<TestBackend>::new(&device);
1117
1118 let filter = PathFilter::new().with_predicate(|_path, container_path| {
1120 container_path.split('.').next_back() == Some("Struct:Linear")
1121 });
1122
1123 let linear_views: Vec<TensorSnapshot> = model.collect(Some(filter), None, false);
1124
1125 for view in linear_views.iter() {
1127 assert_eq!(
1128 view.module_type(),
1129 Some("Struct:Linear".to_string()),
1130 "All tensors should be from Linear modules"
1131 );
1132 }
1133
1134 assert_eq!(linear_views.len(), 4);
1136 }
1137}