1use alloc::boxed::Box;
2use alloc::string::{String, ToString};
3use alloc::vec::Vec;
4
5use burn_tensor::{Bool, Int, Tensor, backend::Backend};
6
7use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
8use burn_core::module::{ModuleVisitor, Param, ParamId};
9
10pub struct Collector {
45 pub tensors: Vec<TensorSnapshot>,
47 path_stack: Vec<String>,
48 container_stack: Vec<String>,
49 filter: Option<PathFilter>,
50 adapter: Option<Box<dyn ModuleAdapter>>,
51}
52
53impl Default for Collector {
54 fn default() -> Self {
55 Self::new(None, None)
56 }
57}
58
59impl Collector {
60 pub fn new(filter: Option<PathFilter>, adapter: Option<Box<dyn ModuleAdapter>>) -> Self {
83 Self {
84 tensors: Vec::new(),
85 path_stack: Vec::new(),
86 container_stack: Vec::new(),
87 filter,
88 adapter,
89 }
90 }
91
92 pub fn into_tensors(self) -> Vec<TensorSnapshot> {
94 if let Some(adapter) = self.adapter {
95 self.tensors
96 .into_iter()
97 .map(|snapshot| adapter.adapt(&snapshot))
98 .collect()
99 } else {
100 self.tensors
101 }
102 }
103
104 fn should_collect(&self, path: &[String], container_stack: &[String]) -> bool {
105 match &self.filter {
107 None => true,
108 Some(f) => f.matches_with_container_path(path, container_stack),
109 }
110 }
111}
112
113impl<B: Backend> ModuleVisitor<B> for Collector {
114 fn enter_module(&mut self, name: &str, container_type: &str) {
115 self.path_stack.push(name.to_string());
116 self.container_stack.push(container_type.to_string());
117 }
118
119 fn exit_module(&mut self, _name: &str, _container_type: &str) {
120 self.path_stack.pop();
121 self.container_stack.pop();
122 }
123
124 fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
125 if self.should_collect(&self.path_stack, &self.container_stack) {
126 self.tensors.push(TensorSnapshot::from_float(
127 ¶m.transform_for_save().val(),
128 self.path_stack.clone(),
129 self.container_stack.clone(),
130 param.id,
131 ));
132 }
133 }
134
135 fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {
136 if self.should_collect(&self.path_stack, &self.container_stack) {
137 self.tensors.push(TensorSnapshot::from_int(
138 ¶m.transform_for_save().val(),
139 self.path_stack.clone(),
140 self.container_stack.clone(),
141 param.id,
142 ));
143 }
144 }
145
146 fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {
147 if self.should_collect(&self.path_stack, &self.container_stack) {
148 self.tensors.push(TensorSnapshot::from_bool(
149 ¶m.transform_for_save().val(),
150 self.path_stack.clone(),
151 self.container_stack.clone(),
152 param.id,
153 ));
154 }
155 }
156
157 fn visit_float_with_path<const D: usize>(
158 &mut self,
159 path: &[String],
160 id: ParamId,
161 tensor: &Tensor<B, D>,
162 ) {
163 if self.should_collect(path, &self.container_stack) {
165 self.tensors.push(TensorSnapshot::from_float(
166 tensor,
167 path.to_vec(),
168 self.container_stack.clone(),
169 id,
170 ));
171 }
172 }
173
174 fn visit_int_with_path<const D: usize>(
175 &mut self,
176 path: &[String],
177 id: ParamId,
178 tensor: &Tensor<B, D, Int>,
179 ) {
180 if self.should_collect(path, &self.container_stack) {
181 self.tensors.push(TensorSnapshot::from_int(
182 tensor,
183 path.to_vec(),
184 self.container_stack.clone(),
185 id,
186 ));
187 }
188 }
189
190 fn visit_bool_with_path<const D: usize>(
191 &mut self,
192 path: &[String],
193 id: ParamId,
194 tensor: &Tensor<B, D, Bool>,
195 ) {
196 if self.should_collect(path, &self.container_stack) {
197 self.tensors.push(TensorSnapshot::from_bool(
198 tensor,
199 path.to_vec(),
200 self.container_stack.clone(),
201 id,
202 ));
203 }
204 }
205}
206
207#[cfg(all(test, feature = "std"))]
208mod tests {
209 use super::*;
210
211 use burn_core as burn;
212
213 type TestBackend = burn_ndarray::NdArray;
214 use alloc::collections::BTreeMap;
215 use alloc::string::String;
216 use burn_core::module::{Module, Param};
217 use burn_nn::LinearConfig;
218
219 #[test]
220 fn tensor_snapshot_collector() {
221 let device = Default::default();
222 let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
223
224 let mut collector = Collector::new(None, None);
225 let id = ParamId::new();
226
227 collector.visit_float_with_path(&["model".to_string(), "weight".to_string()], id, &tensor);
229
230 assert_eq!(collector.tensors.len(), 1);
231 assert_eq!(collector.tensors[0].full_path(), "model.weight");
232
233 let view = &collector.tensors[0];
235 let data = view.to_data().unwrap();
236 assert_eq!(data.shape, vec![2, 2]);
237 }
238
239 #[test]
240 fn root_level_parameters() {
241 use burn_core::module::ModuleVisitor;
242
243 let device = Default::default();
244
245 let weight = Param::<Tensor<TestBackend, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
247 let bias = Param::<Tensor<TestBackend, 1>>::from_data([5.0, 6.0], &device);
248
249 let mut collector = Collector::new(None, None);
250
251 ModuleVisitor::<TestBackend>::enter_module(&mut collector, "weight", "");
254 ModuleVisitor::<TestBackend>::visit_float(&mut collector, &weight);
255 ModuleVisitor::<TestBackend>::exit_module(&mut collector, "weight", "");
256
257 ModuleVisitor::<TestBackend>::enter_module(&mut collector, "bias", "");
259 ModuleVisitor::<TestBackend>::visit_float(&mut collector, &bias);
260 ModuleVisitor::<TestBackend>::exit_module(&mut collector, "bias", "");
261
262 assert_eq!(collector.tensors.len(), 2);
264
265 assert_eq!(collector.tensors[0].full_path(), "weight");
267 assert_eq!(collector.tensors[1].full_path(), "bias");
268
269 let weight_data = collector.tensors[0]
271 .to_data()
272 .unwrap()
273 .to_vec::<f32>()
274 .unwrap();
275 let bias_data = collector.tensors[1]
276 .to_data()
277 .unwrap()
278 .to_vec::<f32>()
279 .unwrap();
280
281 assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]);
282 assert_eq!(bias_data, vec![5.0, 6.0]);
283 }
284
285 #[test]
286 #[cfg(target_has_atomic = "ptr")]
287 fn tensor_snapshot_collector_with_filter() {
288 let device = Default::default();
289 let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
290
291 let filter = PathFilter::new().with_regex(r"^encoder\..*");
292 let mut collector = Collector::new(Some(filter), None);
293 let id = ParamId::new();
294
295 collector.visit_float_with_path(
297 &["encoder".to_string(), "weight".to_string()],
298 id,
299 &tensor,
300 );
301 collector.visit_float_with_path(
303 &["decoder".to_string(), "weight".to_string()],
304 id,
305 &tensor,
306 );
307
308 assert_eq!(collector.tensors.len(), 1);
309 assert_eq!(collector.tensors[0].full_path(), "encoder.weight");
310 }
311
312 #[test]
313 #[cfg(target_has_atomic = "ptr")]
314 fn tensor_snapshot_collector_with_multiple_filters() {
315 let device = Default::default();
316 let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
317
318 let filter = PathFilter::new()
320 .with_regex(r"^encoder\..*") .with_regex(r".*\.bias$"); let mut collector = Collector::new(Some(filter), None);
323 let id = ParamId::new();
324
325 collector.visit_float_with_path(
327 &["encoder".to_string(), "weight".to_string()],
328 id,
329 &tensor,
330 ); collector.visit_float_with_path(&["decoder".to_string(), "bias".to_string()], id, &tensor); collector.visit_float_with_path(&["encoder".to_string(), "bias".to_string()], id, &tensor); collector.visit_float_with_path(
336 &["decoder".to_string(), "weight".to_string()],
337 id,
338 &tensor,
339 ); assert_eq!(collector.tensors.len(), 3);
342 let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
343 assert!(paths.contains(&"encoder.weight".to_string()));
344 assert!(paths.contains(&"decoder.bias".to_string()));
345 assert!(paths.contains(&"encoder.bias".to_string()));
346 assert!(!paths.contains(&"decoder.weight".to_string()));
347 }
348
349 #[test]
350 fn tensor_snapshot_collector_with_predicate() {
351 let device = Default::default();
352 let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
353
354 fn filter_fn(path: &str, _container_path: &str) -> bool {
356 path.starts_with("encoder.") || path == "decoder.bias"
357 }
358 let filter = PathFilter::new().with_predicate(filter_fn);
359 let mut collector = Collector::new(Some(filter), None);
360 let id = ParamId::new();
361
362 collector.visit_float_with_path(
364 &["encoder".to_string(), "weight".to_string()],
365 id,
366 &tensor,
367 );
368 collector.visit_float_with_path(&["encoder".to_string(), "bias".to_string()], id, &tensor);
369 collector.visit_float_with_path(&["decoder".to_string(), "bias".to_string()], id, &tensor);
370
371 collector.visit_float_with_path(
373 &["decoder".to_string(), "weight".to_string()],
374 id,
375 &tensor,
376 );
377 collector.visit_float_with_path(&["other".to_string(), "tensor".to_string()], id, &tensor);
378
379 assert_eq!(collector.tensors.len(), 3);
380 let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
381 assert!(paths.contains(&"encoder.weight".to_string()));
382 assert!(paths.contains(&"encoder.bias".to_string()));
383 assert!(paths.contains(&"decoder.bias".to_string()));
384 assert!(!paths.contains(&"decoder.weight".to_string()));
385 assert!(!paths.contains(&"other.tensor".to_string()));
386 }
387
388 #[test]
389 fn tensor_snapshot_collector_predicate_with_complex_logic() {
390 let device = Default::default();
391 let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
392
393 fn complex_filter(path: &str, _container_path: &str) -> bool {
395 let parts: Vec<&str> = path.split('.').collect();
396 if parts.len() != 3 {
397 return false;
398 }
399 (parts[1] == "layer1" || parts[1] == "layer2") && parts[2] == "weight"
401 }
402 let filter = PathFilter::new().with_predicate(complex_filter);
403 let mut collector = Collector::new(Some(filter), None);
404 let id = ParamId::new();
405
406 collector.visit_float_with_path(
408 &[
409 "model".to_string(),
410 "layer1".to_string(),
411 "weight".to_string(),
412 ],
413 id,
414 &tensor,
415 );
416 collector.visit_float_with_path(
417 &[
418 "model".to_string(),
419 "layer2".to_string(),
420 "weight".to_string(),
421 ],
422 id,
423 &tensor,
424 );
425
426 collector.visit_float_with_path(
428 &[
429 "model".to_string(),
430 "layer1".to_string(),
431 "bias".to_string(),
432 ],
433 id,
434 &tensor,
435 );
436 collector.visit_float_with_path(
437 &[
438 "model".to_string(),
439 "layer3".to_string(),
440 "weight".to_string(),
441 ],
442 id,
443 &tensor,
444 );
445 collector.visit_float_with_path(
446 &["encoder".to_string(), "weight".to_string()],
447 id,
448 &tensor,
449 ); assert_eq!(collector.tensors.len(), 2);
452 let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
453 assert!(paths.contains(&"model.layer1.weight".to_string()));
454 assert!(paths.contains(&"model.layer2.weight".to_string()));
455 assert!(!paths.contains(&"model.layer1.bias".to_string()));
456 assert!(!paths.contains(&"model.layer3.weight".to_string()));
457 assert!(!paths.contains(&"encoder.weight".to_string()));
458 }
459
460 struct TensorPathCollector {
462 pub paths: BTreeMap<String, (ParamId, Vec<usize>)>,
463 path_stack: Vec<String>,
464 }
465
466 impl TensorPathCollector {
467 fn new() -> Self {
468 Self {
469 paths: BTreeMap::new(),
470 path_stack: Vec::new(),
471 }
472 }
473
474 fn current_path(&self) -> String {
475 self.path_stack.join(".")
476 }
477 }
478
479 impl<B: Backend> ModuleVisitor<B> for TensorPathCollector {
480 fn enter_module(&mut self, name: &str, _container_type: &str) {
481 self.path_stack.push(name.to_string());
482 }
483
484 fn exit_module(&mut self, _name: &str, _container_type: &str) {
485 self.path_stack.pop();
486 }
487
488 fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
489 let path = self.current_path();
490 if !path.is_empty() {
491 self.paths.insert(
492 path,
493 (param.id, param.transform_for_save().val().shape().to_vec()),
494 );
495 }
496 }
497
498 fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {
499 let path = self.current_path();
500 if !path.is_empty() {
501 self.paths.insert(
502 path,
503 (param.id, param.transform_for_save().val().shape().to_vec()),
504 );
505 }
506 }
507
508 fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {
509 let path = self.current_path();
510 if !path.is_empty() {
511 self.paths.insert(
512 path,
513 (param.id, param.transform_for_save().val().shape().to_vec()),
514 );
515 }
516 }
517 }
518
519 #[derive(Module, Debug)]
521 struct InnerModule<B: Backend> {
522 weight: Param<Tensor<B, 2>>,
523 bias: Param<Tensor<B, 1>>,
524 }
525
526 #[derive(Module, Debug)]
527 struct OuterModule<B: Backend> {
528 layer1: InnerModule<B>,
529 layer2: InnerModule<B>,
530 }
531
532 impl<B: Backend> InnerModule<B> {
533 fn new(device: &B::Device) -> Self {
534 Self {
535 weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
536 bias: Param::from_data([5.0, 6.0], device),
537 }
538 }
539 }
540
541 impl<B: Backend> OuterModule<B> {
542 fn new(device: &B::Device) -> Self {
543 Self {
544 layer1: InnerModule::new(device),
545 layer2: InnerModule::new(device),
546 }
547 }
548 }
549
550 #[test]
551 fn nested_module_path_tracking() {
552 let device = Default::default();
553 let module = OuterModule::<TestBackend>::new(&device);
554
555 let mut collector = TensorPathCollector::new();
556 module.visit(&mut collector);
557
558 let paths = collector.paths;
559
560 assert!(paths.contains_key("layer1.weight"), "Missing layer1.weight");
563 assert!(paths.contains_key("layer1.bias"), "Missing layer1.bias");
564 assert!(paths.contains_key("layer2.weight"), "Missing layer2.weight");
565 assert!(paths.contains_key("layer2.bias"), "Missing layer2.bias");
566
567 assert_eq!(paths.get("layer1.weight").unwrap().1, vec![2, 2]);
569 assert_eq!(paths.get("layer1.bias").unwrap().1, vec![2]);
570 assert_eq!(paths.get("layer2.weight").unwrap().1, vec![2, 2]);
571 assert_eq!(paths.get("layer2.bias").unwrap().1, vec![2]);
572 }
573
574 #[test]
575 fn linear_module_paths() {
576 let device = Default::default();
577 let config = LinearConfig::new(10, 20).with_bias(true);
578 let linear = config.init::<TestBackend>(&device);
579
580 let mut collector = TensorPathCollector::new();
581 linear.visit(&mut collector);
582
583 let paths = collector.paths;
584
585 assert!(paths.contains_key("weight"));
587 assert!(paths.contains_key("bias"));
588
589 assert_eq!(paths.get("weight").unwrap().1, vec![10, 20]);
591 assert_eq!(paths.get("bias").unwrap().1, vec![20]);
592 }
593
594 #[derive(Module, Debug)]
596 struct Level4Module<B: Backend> {
597 weight: Param<Tensor<B, 2>>,
598 bias: Param<Tensor<B, 1>>,
599 }
600
601 #[derive(Module, Debug)]
602 struct Level3Module<B: Backend> {
603 layer: Level4Module<B>,
604 extra: Level4Module<B>,
605 }
606
607 #[derive(Module, Debug)]
608 struct Level2Module<B: Backend> {
609 block1: Level3Module<B>,
610 block2: Level3Module<B>,
611 }
612
613 #[derive(Module, Debug)]
614 struct Level1Module<B: Backend> {
615 encoder: Level2Module<B>,
616 decoder: Level2Module<B>,
617 }
618
619 #[derive(Module, Debug)]
620 struct DeepModel<B: Backend> {
621 backbone: Level1Module<B>,
622 head: Level4Module<B>,
623 }
624
625 impl<B: Backend> Level4Module<B> {
626 fn new(device: &B::Device) -> Self {
627 Self {
628 weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
629 bias: Param::from_data([5.0, 6.0], device),
630 }
631 }
632 }
633
634 impl<B: Backend> Level3Module<B> {
635 fn new(device: &B::Device) -> Self {
636 Self {
637 layer: Level4Module::new(device),
638 extra: Level4Module::new(device),
639 }
640 }
641 }
642
643 impl<B: Backend> Level2Module<B> {
644 fn new(device: &B::Device) -> Self {
645 Self {
646 block1: Level3Module::new(device),
647 block2: Level3Module::new(device),
648 }
649 }
650 }
651
652 impl<B: Backend> Level1Module<B> {
653 fn new(device: &B::Device) -> Self {
654 Self {
655 encoder: Level2Module::new(device),
656 decoder: Level2Module::new(device),
657 }
658 }
659 }
660
661 impl<B: Backend> DeepModel<B> {
662 fn new(device: &B::Device) -> Self {
663 Self {
664 backbone: Level1Module::new(device),
665 head: Level4Module::new(device),
666 }
667 }
668 }
669
670 #[test]
671 fn deep_module_path_tracking() {
672 let device = Default::default();
673 let model = DeepModel::<TestBackend>::new(&device);
674
675 let mut collector = Collector::new(None, None);
676 model.visit(&mut collector);
677
678 let views = collector.tensors;
679 let paths: Vec<String> = views.iter().map(|v| v.full_path()).collect();
680
681 assert!(paths.contains(&"backbone.encoder.block1.layer.weight".to_string()));
683 assert!(paths.contains(&"backbone.encoder.block1.layer.bias".to_string()));
684 assert!(paths.contains(&"backbone.encoder.block1.extra.weight".to_string()));
685 assert!(paths.contains(&"backbone.encoder.block1.extra.bias".to_string()));
686
687 assert!(paths.contains(&"backbone.encoder.block2.layer.weight".to_string()));
688 assert!(paths.contains(&"backbone.encoder.block2.layer.bias".to_string()));
689 assert!(paths.contains(&"backbone.encoder.block2.extra.weight".to_string()));
690 assert!(paths.contains(&"backbone.encoder.block2.extra.bias".to_string()));
691
692 assert!(paths.contains(&"backbone.decoder.block1.layer.weight".to_string()));
693 assert!(paths.contains(&"backbone.decoder.block1.layer.bias".to_string()));
694 assert!(paths.contains(&"backbone.decoder.block1.extra.weight".to_string()));
695 assert!(paths.contains(&"backbone.decoder.block1.extra.bias".to_string()));
696
697 assert!(paths.contains(&"backbone.decoder.block2.layer.weight".to_string()));
698 assert!(paths.contains(&"backbone.decoder.block2.layer.bias".to_string()));
699 assert!(paths.contains(&"backbone.decoder.block2.extra.weight".to_string()));
700 assert!(paths.contains(&"backbone.decoder.block2.extra.bias".to_string()));
701
702 assert!(paths.contains(&"head.weight".to_string()));
704 assert!(paths.contains(&"head.bias".to_string()));
705
706 assert_eq!(views.len(), 18);
708
709 let view = views
711 .iter()
712 .find(|v| v.full_path() == "backbone.encoder.block1.layer.weight")
713 .unwrap();
714 let data = view.to_data().unwrap();
715 assert_eq!(data.shape, vec![2, 2]);
716 }
717
718 #[test]
719 fn deep_module_filtered_export() {
720 let device = Default::default();
721 let model = DeepModel::<TestBackend>::new(&device);
722
723 #[cfg(target_has_atomic = "ptr")]
725 {
726 let filter = PathFilter::new().with_regex(r"^backbone\.encoder\..*");
727 let mut collector = Collector::new(Some(filter), None);
728 model.visit(&mut collector);
729 assert_eq!(collector.tensors.len(), 8); }
731
732 #[cfg(target_has_atomic = "ptr")]
734 {
735 let filter = PathFilter::new().with_regex(r".*\.block1\..*");
736 let mut collector = Collector::new(Some(filter), None);
737 model.visit(&mut collector);
738 assert_eq!(collector.tensors.len(), 8); }
740
741 #[cfg(target_has_atomic = "ptr")]
743 {
744 let filter = PathFilter::new().with_regex(r".*\.weight$");
745 let mut collector = Collector::new(Some(filter), None);
746 model.visit(&mut collector);
747 assert_eq!(collector.tensors.len(), 9); }
749
750 #[cfg(target_has_atomic = "ptr")]
752 {
753 let filter = PathFilter::new()
754 .with_regex(r"^backbone\.encoder\.block1\..*") .with_regex(r"^backbone\.decoder\..*\.bias$") .with_regex(r"^head\.weight$"); let mut collector = Collector::new(Some(filter), None);
758 model.visit(&mut collector);
759
760 assert_eq!(collector.tensors.len(), 9);
765
766 let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
767 assert!(paths.contains(&"backbone.encoder.block1.layer.weight".to_string()));
768 assert!(paths.contains(&"backbone.decoder.block1.layer.bias".to_string()));
769 assert!(paths.contains(&"head.weight".to_string()));
770 assert!(!paths.contains(&"head.bias".to_string())); }
772 }
773
774 use crate::traits::ModuleSnapshot;
775 use burn_nn::Linear;
776 use hashbrown::HashMap;
777
778 #[derive(Module, Debug)]
780 struct OptionalFieldModule<B: Backend> {
781 required: Param<Tensor<B, 2>>,
782 optional: Option<Param<Tensor<B, 1>>>,
783 }
784
785 impl<B: Backend> OptionalFieldModule<B> {
786 fn new_with_optional(device: &B::Device) -> Self {
787 Self {
788 required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
789 optional: Some(Param::from_data([5.0, 6.0], device)),
790 }
791 }
792
793 fn new_without_optional(device: &B::Device) -> Self {
794 Self {
795 required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
796 optional: None,
797 }
798 }
799 }
800
801 #[test]
802 fn optional_field_module_with_value() {
803 let device = Default::default();
804 let module = OptionalFieldModule::<TestBackend>::new_with_optional(&device);
805
806 let views: HashMap<String, TensorSnapshot> = module
807 .collect(None, None)
808 .into_iter()
809 .map(|v| (v.full_path(), v))
810 .collect();
811
812 assert_eq!(views.len(), 2);
813 assert!(views.contains_key("required"));
814 assert!(views.contains_key("optional"));
815 }
816
817 #[test]
818 fn optional_field_module_without_value() {
819 let device = Default::default();
820 let module = OptionalFieldModule::<TestBackend>::new_without_optional(&device);
821
822 let views: HashMap<String, TensorSnapshot> = module
823 .collect(None, None)
824 .into_iter()
825 .map(|v| (v.full_path(), v))
826 .collect();
827
828 assert_eq!(views.len(), 1);
829 assert!(views.contains_key("required"));
830 assert!(!views.contains_key("optional"));
831 }
832
833 #[derive(Module, Debug)]
835 struct VecModule<B: Backend> {
836 layers: Vec<Linear<B>>,
837 }
838
839 impl<B: Backend> VecModule<B> {
840 fn new(device: &B::Device, num_layers: usize) -> Self {
841 Self {
842 layers: (0..num_layers)
843 .map(|_| LinearConfig::new(10, 10).init(device))
844 .collect(),
845 }
846 }
847 }
848
849 #[test]
850 fn vec_module_collect() {
851 let device = Default::default();
852 let module = VecModule::<TestBackend>::new(&device, 3);
853
854 let views: HashMap<String, TensorSnapshot> = module
855 .collect(None, None)
856 .into_iter()
857 .map(|v| (v.full_path(), v))
858 .collect();
859
860 assert_eq!(views.len(), 6); assert!(views.contains_key("layers.0.weight"));
865 assert!(views.contains_key("layers.0.bias"));
866 assert!(views.contains_key("layers.1.weight"));
867 assert!(views.contains_key("layers.1.bias"));
868 assert!(views.contains_key("layers.2.weight"));
869 assert!(views.contains_key("layers.2.bias"));
870 }
871
872 #[derive(Module, Debug)]
874 struct ArrayModule<B: Backend> {
875 layers: [Linear<B>; 3],
876 }
877
878 impl<B: Backend> ArrayModule<B> {
879 fn new(device: &B::Device) -> Self {
880 Self {
881 layers: [
882 LinearConfig::new(10, 10).init(device),
883 LinearConfig::new(10, 10).init(device),
884 LinearConfig::new(10, 10).init(device),
885 ],
886 }
887 }
888 }
889
890 #[test]
891 fn array_module_collect() {
892 let device = Default::default();
893 let module = ArrayModule::<TestBackend>::new(&device);
894
895 let views: HashMap<String, TensorSnapshot> = module
896 .collect(None, None)
897 .into_iter()
898 .map(|v| (v.full_path(), v))
899 .collect();
900
901 assert_eq!(views.len(), 6); for i in 0..3 {
906 assert!(views.contains_key(&format!("layers.{}.weight", i)));
907 assert!(views.contains_key(&format!("layers.{}.bias", i)));
908 }
909 }
910
911 #[derive(Module, Debug)]
913 enum EnumModule<B: Backend> {
914 LayerA(Linear<B>),
915 LayerB(Linear<B>),
916 LayerC(Linear<B>),
917 }
918
919 #[test]
920 fn enum_module_collect() {
921 let device = Default::default();
922
923 let module_a = EnumModule::<TestBackend>::LayerA(LinearConfig::new(10, 20).init(&device));
925 let views_a: HashMap<String, TensorSnapshot> = module_a
926 .collect(None, None)
927 .into_iter()
928 .map(|v| (v.full_path(), v))
929 .collect();
930
931 assert_eq!(views_a.len(), 2);
933 assert!(views_a.contains_key("LayerA.weight"));
934 assert!(views_a.contains_key("LayerA.bias"));
935
936 let module_b = EnumModule::<TestBackend>::LayerB(LinearConfig::new(10, 20).init(&device));
938 let views_b: HashMap<String, TensorSnapshot> = module_b
939 .collect(None, None)
940 .into_iter()
941 .map(|v| (v.full_path(), v))
942 .collect();
943
944 assert_eq!(views_b.len(), 2);
945 assert!(views_b.contains_key("LayerB.weight"));
946 assert!(views_b.contains_key("LayerB.bias"));
947 }
948
949 #[test]
951 fn linear_container_type() {
952 let device = Default::default();
953
954 #[derive(Module, Debug)]
955 struct ModelWithLinear<B: Backend> {
956 linear: Linear<B>,
957 }
958
959 impl<B: Backend> ModelWithLinear<B> {
960 fn new(device: &B::Device) -> Self {
961 Self {
962 linear: LinearConfig::new(10, 20).init(device),
963 }
964 }
965 }
966
967 let model = ModelWithLinear::<TestBackend>::new(&device);
968
969 let views: HashMap<String, TensorSnapshot> = model
970 .collect(None, None)
971 .into_iter()
972 .map(|v| (v.full_path(), v))
973 .collect();
974
975 for (path, view) in views.iter() {
977 if path == "linear.weight" || path == "linear.bias" {
978 assert_eq!(
979 view.container_type(),
980 "Linear",
981 "Tensor '{}' should have container type 'Linear'",
982 path
983 );
984 }
985 }
986 }
987
988 #[test]
989 fn complex_model_container_types() {
990 let device = Default::default();
991
992 #[derive(Module, Debug)]
993 struct ComplexModel<B: Backend> {
994 linear_layers: [Linear<B>; 2],
995 vec_layers: Vec<Linear<B>>,
996 single_linear: Linear<B>,
997 }
998
999 impl<B: Backend> ComplexModel<B> {
1000 fn new(device: &B::Device) -> Self {
1001 Self {
1002 linear_layers: [
1003 LinearConfig::new(100, 50).init(device),
1004 LinearConfig::new(50, 10).init(device),
1005 ],
1006 vec_layers: vec![
1007 LinearConfig::new(10, 10).init(device),
1008 LinearConfig::new(10, 10).init(device),
1009 ],
1010 single_linear: LinearConfig::new(10, 1).init(device),
1011 }
1012 }
1013 }
1014
1015 let model = ComplexModel::<TestBackend>::new(&device);
1016
1017 let views: HashMap<String, TensorSnapshot> = model
1018 .collect(None, None)
1019 .into_iter()
1020 .map(|v| (v.full_path(), v))
1021 .collect();
1022
1023 assert_eq!(views.len(), 10);
1025
1026 for (_path, view) in views.iter() {
1028 assert_eq!(view.container_type(), "Linear");
1029 }
1030 }
1031
1032 #[test]
1033 fn collect_with_container_filter() {
1034 let device = Default::default();
1035
1036 #[derive(Module, Debug)]
1037 struct FilterTestModel<B: Backend> {
1038 layers: Vec<Linear<B>>,
1039 }
1040
1041 impl<B: Backend> FilterTestModel<B> {
1042 fn new(device: &B::Device) -> Self {
1043 Self {
1044 layers: vec![
1045 LinearConfig::new(10, 10).init(device),
1046 LinearConfig::new(10, 10).init(device),
1047 ],
1048 }
1049 }
1050 }
1051
1052 let model = FilterTestModel::<TestBackend>::new(&device);
1053
1054 let filter = PathFilter::new().with_predicate(|_path, container_path| {
1056 container_path.split('.').next_back() == Some("Linear")
1057 });
1058
1059 let linear_views: Vec<TensorSnapshot> = model.collect(Some(filter), None);
1060
1061 for view in linear_views.iter() {
1063 assert_eq!(
1064 view.container_type(),
1065 "Linear",
1066 "All tensors should be from Linear modules"
1067 );
1068 }
1069
1070 assert_eq!(linear_views.len(), 4);
1072 }
1073}