burn_store/
collector.rs

1use alloc::boxed::Box;
2use alloc::string::{String, ToString};
3use alloc::vec::Vec;
4
5use burn_tensor::{Bool, Int, Tensor, backend::Backend};
6
7use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
8use burn_core::module::{ModuleVisitor, Param, ParamId};
9
10/// Collects tensor views from modules without copying data.
11///
12/// This collector traverses a module hierarchy and creates lightweight views
13/// of tensors that can be materialized to `TensorData` on demand.
14///
15/// # Examples
16///
17/// ## Collect all tensors
18/// ```rust,no_run
19/// # use burn_store::Collector;
20/// let collector = Collector::new(None, None, false);
21/// // Use with module.visit(&mut collector);
22/// let all_tensors = collector.tensors;
23/// ```
24///
25/// ## Filter with single pattern
26/// ```rust,no_run
27/// # use burn_store::{Collector, PathFilter};
28/// let filter = PathFilter::new().with_regex(r"^encoder\..*");
29/// let collector = Collector::new(Some(filter), None, false);
30/// // Use with module.visit(&mut collector);
31/// // Only collects tensors starting with "encoder."
32/// ```
33///
34/// ## Filter with multiple patterns (OR union)
35/// ```rust,no_run
36/// # use burn_store::{Collector, PathFilter};
37/// let filter = PathFilter::new()
38///     .with_regex(r"^encoder\..*")  // Match all encoder tensors
39///     .with_regex(r".*\.bias$");    // OR match any bias tensors
40/// let collector = Collector::new(Some(filter), None, false);
41/// // Use with module.visit(&mut collector);
42/// // Collects tensors matching ANY of the patterns
43/// ```
44pub struct Collector {
45    /// Collection of tensor views
46    pub tensors: Vec<TensorSnapshot>,
47    path_stack: Vec<String>,
48    container_stack: Vec<String>,
49    filter: Option<PathFilter>,
50    adapter: Option<Box<dyn ModuleAdapter>>,
51    /// Skip enum variant names when building paths
52    /// When true, enum variant names are not included in tensor paths
53    skip_enum_variants: bool,
54}
55
56impl Default for Collector {
57    fn default() -> Self {
58        Self::new(None, None, false)
59    }
60}
61
62impl Collector {
63    /// Create a new tensor view collector with an optional filter and adapter.
64    ///
65    /// # Arguments
66    ///
67    /// * `filter` - An optional [`PathFilter`] to determine which tensors to collect.
68    ///   When `None`, all tensors are collected.
69    /// * `adapter` - Optional adapter to transform tensors based on container types.
70    ///   Applied to all collected tensors before returning.
71    /// * `skip_enum_variants` - Skip enum variant names when building paths.
72    ///   When true, paths will not include enum variant names (e.g., "feature.weight"
73    ///   instead of "feature.BaseConv.weight"). Useful when exporting to formats
74    ///   like PyTorch that don't use enum variants.
75    ///
76    /// # Examples
77    ///
78    /// ```rust,no_run
79    /// # use burn_store::{Collector, PathFilter};
80    /// // Collect all tensors without adapter
81    /// let collector = Collector::new(None, None, false);
82    ///
83    /// // Use PathFilter builder
84    /// let filter = PathFilter::new()
85    ///     .with_regex(r"^encoder\..*")
86    ///     .with_full_path("decoder.weight");
87    /// let collector = Collector::new(Some(filter), None, false);
88    ///
89    /// // Skip enum variants for PyTorch export
90    /// let collector = Collector::new(None, None, true);
91    /// ```
92    pub fn new(
93        filter: Option<PathFilter>,
94        adapter: Option<Box<dyn ModuleAdapter>>,
95        skip_enum_variants: bool,
96    ) -> Self {
97        Self {
98            tensors: Vec::new(),
99            path_stack: Vec::new(),
100            container_stack: Vec::new(),
101            filter,
102            adapter,
103            skip_enum_variants,
104        }
105    }
106
107    /// Apply the adapter to collected tensors and return the result.
108    pub fn into_tensors(self) -> Vec<TensorSnapshot> {
109        if let Some(adapter) = self.adapter {
110            self.tensors
111                .into_iter()
112                .map(|snapshot| adapter.adapt(&snapshot))
113                .collect()
114        } else {
115            self.tensors
116        }
117    }
118
119    fn should_collect(&self, path: &[String], container_stack: &[String]) -> bool {
120        // If filter is present, use it; otherwise collect all
121        match &self.filter {
122            None => true,
123            Some(f) => f.matches_with_container_path(path, container_stack),
124        }
125    }
126}
127
128impl<B: Backend> ModuleVisitor<B> for Collector {
129    fn enter_module(&mut self, name: &str, container_type: &str) {
130        // Always track the container type for proper filtering and module type detection
131        self.container_stack.push(container_type.to_string());
132
133        // Only add to path if it's not an enum variant (when skip_enum_variants is enabled)
134        // This ensures paths are built without enum variant names from the start
135        if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
136            self.path_stack.push(name.to_string());
137        }
138    }
139
140    fn exit_module(&mut self, _name: &str, container_type: &str) {
141        self.container_stack.pop();
142
143        // Only pop from path if we added it (not an enum variant when skip_enum_variants is enabled)
144        if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
145            self.path_stack.pop();
146        }
147    }
148
149    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
150        if self.should_collect(&self.path_stack, &self.container_stack) {
151            self.tensors.push(TensorSnapshot::from_float(
152                &param.transform_for_save().val(),
153                self.path_stack.clone(),
154                self.container_stack.clone(),
155                param.id,
156            ));
157        }
158    }
159
160    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {
161        if self.should_collect(&self.path_stack, &self.container_stack) {
162            self.tensors.push(TensorSnapshot::from_int(
163                &param.transform_for_save().val(),
164                self.path_stack.clone(),
165                self.container_stack.clone(),
166                param.id,
167            ));
168        }
169    }
170
171    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {
172        if self.should_collect(&self.path_stack, &self.container_stack) {
173            self.tensors.push(TensorSnapshot::from_bool(
174                &param.transform_for_save().val(),
175                self.path_stack.clone(),
176                self.container_stack.clone(),
177                param.id,
178            ));
179        }
180    }
181
182    fn visit_float_with_path<const D: usize>(
183        &mut self,
184        path: &[String],
185        id: ParamId,
186        tensor: &Tensor<B, D>,
187    ) {
188        // For path-based visits, we use the current container stack for filtering
189        if self.should_collect(path, &self.container_stack) {
190            self.tensors.push(TensorSnapshot::from_float(
191                tensor,
192                path.to_vec(),
193                self.container_stack.clone(),
194                id,
195            ));
196        }
197    }
198
199    fn visit_int_with_path<const D: usize>(
200        &mut self,
201        path: &[String],
202        id: ParamId,
203        tensor: &Tensor<B, D, Int>,
204    ) {
205        if self.should_collect(path, &self.container_stack) {
206            self.tensors.push(TensorSnapshot::from_int(
207                tensor,
208                path.to_vec(),
209                self.container_stack.clone(),
210                id,
211            ));
212        }
213    }
214
215    fn visit_bool_with_path<const D: usize>(
216        &mut self,
217        path: &[String],
218        id: ParamId,
219        tensor: &Tensor<B, D, Bool>,
220    ) {
221        if self.should_collect(path, &self.container_stack) {
222            self.tensors.push(TensorSnapshot::from_bool(
223                tensor,
224                path.to_vec(),
225                self.container_stack.clone(),
226                id,
227            ));
228        }
229    }
230}
231
232#[cfg(all(test, feature = "std"))]
233mod tests {
234    use super::*;
235
236    use burn_core as burn;
237
238    type TestBackend = burn_ndarray::NdArray;
239    use alloc::collections::BTreeMap;
240    use alloc::string::String;
241    use burn_core::module::{Module, Param};
242    use burn_nn::LinearConfig;
243
244    #[test]
245    fn tensor_snapshot_collector() {
246        let device = Default::default();
247        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
248
249        let mut collector = Collector::new(None, None, false);
250        let id = ParamId::new();
251
252        // Collect a tensor
253        collector.visit_float_with_path(&["model".to_string(), "weight".to_string()], id, &tensor);
254
255        assert_eq!(collector.tensors.len(), 1);
256        assert_eq!(collector.tensors[0].full_path(), "model.weight");
257
258        // Verify the tensor can be converted to data
259        let view = &collector.tensors[0];
260        let data = view.to_data().unwrap();
261        assert_eq!(data.shape, vec![2, 2]);
262    }
263
264    #[test]
265    fn root_level_parameters() {
266        use burn_core::module::ModuleVisitor;
267
268        let device = Default::default();
269
270        // Create root-level parameters (single-element path, not nested in modules)
271        let weight = Param::<Tensor<TestBackend, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
272        let bias = Param::<Tensor<TestBackend, 1>>::from_data([5.0, 6.0], &device);
273
274        let mut collector = Collector::new(None, None, false);
275
276        // Simulate module traversal for root-level parameters
277        // Enter "weight" path (as if we're visiting a field named "weight")
278        ModuleVisitor::<TestBackend>::enter_module(&mut collector, "weight", "");
279        ModuleVisitor::<TestBackend>::visit_float(&mut collector, &weight);
280        ModuleVisitor::<TestBackend>::exit_module(&mut collector, "weight", "");
281
282        // Enter "bias" path (as if we're visiting a field named "bias")
283        ModuleVisitor::<TestBackend>::enter_module(&mut collector, "bias", "");
284        ModuleVisitor::<TestBackend>::visit_float(&mut collector, &bias);
285        ModuleVisitor::<TestBackend>::exit_module(&mut collector, "bias", "");
286
287        // Verify both parameters were collected
288        assert_eq!(collector.tensors.len(), 2);
289
290        // Verify paths are correct (single-element paths)
291        assert_eq!(collector.tensors[0].full_path(), "weight");
292        assert_eq!(collector.tensors[1].full_path(), "bias");
293
294        // Verify data is correct
295        let weight_data = collector.tensors[0]
296            .to_data()
297            .unwrap()
298            .to_vec::<f32>()
299            .unwrap();
300        let bias_data = collector.tensors[1]
301            .to_data()
302            .unwrap()
303            .to_vec::<f32>()
304            .unwrap();
305
306        assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]);
307        assert_eq!(bias_data, vec![5.0, 6.0]);
308    }
309
310    #[test]
311    #[cfg(target_has_atomic = "ptr")]
312    fn tensor_snapshot_collector_with_filter() {
313        let device = Default::default();
314        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
315
316        let filter = PathFilter::new().with_regex(r"^encoder\..*");
317        let mut collector = Collector::new(Some(filter), None, false);
318        let id = ParamId::new();
319
320        // This should be collected
321        collector.visit_float_with_path(
322            &["encoder".to_string(), "weight".to_string()],
323            id,
324            &tensor,
325        );
326        // This should NOT be collected
327        collector.visit_float_with_path(
328            &["decoder".to_string(), "weight".to_string()],
329            id,
330            &tensor,
331        );
332
333        assert_eq!(collector.tensors.len(), 1);
334        assert_eq!(collector.tensors[0].full_path(), "encoder.weight");
335    }
336
337    #[test]
338    #[cfg(target_has_atomic = "ptr")]
339    fn tensor_snapshot_collector_with_multiple_filters() {
340        let device = Default::default();
341        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
342
343        // Multiple patterns - collect if matches ANY (OR union)
344        let filter = PathFilter::new()
345            .with_regex(r"^encoder\..*") // Match encoder.*
346            .with_regex(r".*\.bias$"); // Match *.bias
347        let mut collector = Collector::new(Some(filter), None, false);
348        let id = ParamId::new();
349
350        // These should be collected
351        collector.visit_float_with_path(
352            &["encoder".to_string(), "weight".to_string()],
353            id,
354            &tensor,
355        ); // matches first pattern
356        collector.visit_float_with_path(&["decoder".to_string(), "bias".to_string()], id, &tensor); // matches second pattern
357        collector.visit_float_with_path(&["encoder".to_string(), "bias".to_string()], id, &tensor); // matches both patterns
358
359        // This should NOT be collected
360        collector.visit_float_with_path(
361            &["decoder".to_string(), "weight".to_string()],
362            id,
363            &tensor,
364        ); // matches neither
365
366        assert_eq!(collector.tensors.len(), 3);
367        let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
368        assert!(paths.contains(&"encoder.weight".to_string()));
369        assert!(paths.contains(&"decoder.bias".to_string()));
370        assert!(paths.contains(&"encoder.bias".to_string()));
371        assert!(!paths.contains(&"decoder.weight".to_string()));
372    }
373
374    #[test]
375    fn tensor_snapshot_collector_with_predicate() {
376        let device = Default::default();
377        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
378
379        // Use predicate function for filtering
380        fn filter_fn(path: &str, _container_path: &str) -> bool {
381            path.starts_with("encoder.") || path == "decoder.bias"
382        }
383        let filter = PathFilter::new().with_predicate(filter_fn);
384        let mut collector = Collector::new(Some(filter), None, false);
385        let id = ParamId::new();
386
387        // These should be collected
388        collector.visit_float_with_path(
389            &["encoder".to_string(), "weight".to_string()],
390            id,
391            &tensor,
392        );
393        collector.visit_float_with_path(&["encoder".to_string(), "bias".to_string()], id, &tensor);
394        collector.visit_float_with_path(&["decoder".to_string(), "bias".to_string()], id, &tensor);
395
396        // This should NOT be collected
397        collector.visit_float_with_path(
398            &["decoder".to_string(), "weight".to_string()],
399            id,
400            &tensor,
401        );
402        collector.visit_float_with_path(&["other".to_string(), "tensor".to_string()], id, &tensor);
403
404        assert_eq!(collector.tensors.len(), 3);
405        let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
406        assert!(paths.contains(&"encoder.weight".to_string()));
407        assert!(paths.contains(&"encoder.bias".to_string()));
408        assert!(paths.contains(&"decoder.bias".to_string()));
409        assert!(!paths.contains(&"decoder.weight".to_string()));
410        assert!(!paths.contains(&"other.tensor".to_string()));
411    }
412
413    #[test]
414    fn tensor_snapshot_collector_predicate_with_complex_logic() {
415        let device = Default::default();
416        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
417
418        // Complex predicate with multiple conditions
419        fn complex_filter(path: &str, _container_path: &str) -> bool {
420            let parts: Vec<&str> = path.split('.').collect();
421            if parts.len() != 3 {
422                return false;
423            }
424            // Only collect if it's layer1 or layer2, and it's a weight tensor
425            (parts[1] == "layer1" || parts[1] == "layer2") && parts[2] == "weight"
426        }
427        let filter = PathFilter::new().with_predicate(complex_filter);
428        let mut collector = Collector::new(Some(filter), None, false);
429        let id = ParamId::new();
430
431        // These should be collected
432        collector.visit_float_with_path(
433            &[
434                "model".to_string(),
435                "layer1".to_string(),
436                "weight".to_string(),
437            ],
438            id,
439            &tensor,
440        );
441        collector.visit_float_with_path(
442            &[
443                "model".to_string(),
444                "layer2".to_string(),
445                "weight".to_string(),
446            ],
447            id,
448            &tensor,
449        );
450
451        // These should NOT be collected
452        collector.visit_float_with_path(
453            &[
454                "model".to_string(),
455                "layer1".to_string(),
456                "bias".to_string(),
457            ],
458            id,
459            &tensor,
460        );
461        collector.visit_float_with_path(
462            &[
463                "model".to_string(),
464                "layer3".to_string(),
465                "weight".to_string(),
466            ],
467            id,
468            &tensor,
469        );
470        collector.visit_float_with_path(
471            &["encoder".to_string(), "weight".to_string()],
472            id,
473            &tensor,
474        ); // wrong structure
475
476        assert_eq!(collector.tensors.len(), 2);
477        let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
478        assert!(paths.contains(&"model.layer1.weight".to_string()));
479        assert!(paths.contains(&"model.layer2.weight".to_string()));
480        assert!(!paths.contains(&"model.layer1.bias".to_string()));
481        assert!(!paths.contains(&"model.layer3.weight".to_string()));
482        assert!(!paths.contains(&"encoder.weight".to_string()));
483    }
484
485    // Test visitor that collects tensor paths
486    struct TensorPathCollector {
487        pub paths: BTreeMap<String, (ParamId, Vec<usize>)>,
488        path_stack: Vec<String>,
489    }
490
491    impl TensorPathCollector {
492        fn new() -> Self {
493            Self {
494                paths: BTreeMap::new(),
495                path_stack: Vec::new(),
496            }
497        }
498
499        fn current_path(&self) -> String {
500            self.path_stack.join(".")
501        }
502    }
503
504    impl<B: Backend> ModuleVisitor<B> for TensorPathCollector {
505        fn enter_module(&mut self, name: &str, _container_type: &str) {
506            self.path_stack.push(name.to_string());
507        }
508
509        fn exit_module(&mut self, _name: &str, _container_type: &str) {
510            self.path_stack.pop();
511        }
512
513        fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
514            let path = self.current_path();
515            if !path.is_empty() {
516                self.paths.insert(
517                    path,
518                    (param.id, param.transform_for_save().val().shape().to_vec()),
519                );
520            }
521        }
522
523        fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {
524            let path = self.current_path();
525            if !path.is_empty() {
526                self.paths.insert(
527                    path,
528                    (param.id, param.transform_for_save().val().shape().to_vec()),
529                );
530            }
531        }
532
533        fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {
534            let path = self.current_path();
535            if !path.is_empty() {
536                self.paths.insert(
537                    path,
538                    (param.id, param.transform_for_save().val().shape().to_vec()),
539                );
540            }
541        }
542    }
543
544    // Simple nested module for testing
545    #[derive(Module, Debug)]
546    struct InnerModule<B: Backend> {
547        weight: Param<Tensor<B, 2>>,
548        bias: Param<Tensor<B, 1>>,
549    }
550
551    #[derive(Module, Debug)]
552    struct OuterModule<B: Backend> {
553        layer1: InnerModule<B>,
554        layer2: InnerModule<B>,
555    }
556
557    impl<B: Backend> InnerModule<B> {
558        fn new(device: &B::Device) -> Self {
559            Self {
560                weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
561                bias: Param::from_data([5.0, 6.0], device),
562            }
563        }
564    }
565
566    impl<B: Backend> OuterModule<B> {
567        fn new(device: &B::Device) -> Self {
568            Self {
569                layer1: InnerModule::new(device),
570                layer2: InnerModule::new(device),
571            }
572        }
573    }
574
575    #[test]
576    fn nested_module_path_tracking() {
577        let device = Default::default();
578        let module = OuterModule::<TestBackend>::new(&device);
579
580        let mut collector = TensorPathCollector::new();
581        module.visit(&mut collector);
582
583        let paths = collector.paths;
584
585        // Verify we have the expected paths
586        // Note: Param<Tensor> fields are themselves modules, so we get an extra level
587        assert!(paths.contains_key("layer1.weight"), "Missing layer1.weight");
588        assert!(paths.contains_key("layer1.bias"), "Missing layer1.bias");
589        assert!(paths.contains_key("layer2.weight"), "Missing layer2.weight");
590        assert!(paths.contains_key("layer2.bias"), "Missing layer2.bias");
591
592        // Verify the shapes are correct
593        assert_eq!(paths.get("layer1.weight").unwrap().1, vec![2, 2]);
594        assert_eq!(paths.get("layer1.bias").unwrap().1, vec![2]);
595        assert_eq!(paths.get("layer2.weight").unwrap().1, vec![2, 2]);
596        assert_eq!(paths.get("layer2.bias").unwrap().1, vec![2]);
597    }
598
599    #[test]
600    fn linear_module_paths() {
601        let device = Default::default();
602        let config = LinearConfig::new(10, 20).with_bias(true);
603        let linear = config.init::<TestBackend>(&device);
604
605        let mut collector = TensorPathCollector::new();
606        linear.visit(&mut collector);
607
608        let paths = collector.paths;
609
610        // Linear module has weight and optional bias
611        assert!(paths.contains_key("weight"));
612        assert!(paths.contains_key("bias"));
613
614        // Check dimensions
615        assert_eq!(paths.get("weight").unwrap().1, vec![10, 20]);
616        assert_eq!(paths.get("bias").unwrap().1, vec![20]);
617    }
618
619    // Deep nesting test structures (4+ levels)
620    #[derive(Module, Debug)]
621    struct Level4Module<B: Backend> {
622        weight: Param<Tensor<B, 2>>,
623        bias: Param<Tensor<B, 1>>,
624    }
625
626    #[derive(Module, Debug)]
627    struct Level3Module<B: Backend> {
628        layer: Level4Module<B>,
629        extra: Level4Module<B>,
630    }
631
632    #[derive(Module, Debug)]
633    struct Level2Module<B: Backend> {
634        block1: Level3Module<B>,
635        block2: Level3Module<B>,
636    }
637
638    #[derive(Module, Debug)]
639    struct Level1Module<B: Backend> {
640        encoder: Level2Module<B>,
641        decoder: Level2Module<B>,
642    }
643
644    #[derive(Module, Debug)]
645    struct DeepModel<B: Backend> {
646        backbone: Level1Module<B>,
647        head: Level4Module<B>,
648    }
649
650    impl<B: Backend> Level4Module<B> {
651        fn new(device: &B::Device) -> Self {
652            Self {
653                weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
654                bias: Param::from_data([5.0, 6.0], device),
655            }
656        }
657    }
658
659    impl<B: Backend> Level3Module<B> {
660        fn new(device: &B::Device) -> Self {
661            Self {
662                layer: Level4Module::new(device),
663                extra: Level4Module::new(device),
664            }
665        }
666    }
667
668    impl<B: Backend> Level2Module<B> {
669        fn new(device: &B::Device) -> Self {
670            Self {
671                block1: Level3Module::new(device),
672                block2: Level3Module::new(device),
673            }
674        }
675    }
676
677    impl<B: Backend> Level1Module<B> {
678        fn new(device: &B::Device) -> Self {
679            Self {
680                encoder: Level2Module::new(device),
681                decoder: Level2Module::new(device),
682            }
683        }
684    }
685
686    impl<B: Backend> DeepModel<B> {
687        fn new(device: &B::Device) -> Self {
688            Self {
689                backbone: Level1Module::new(device),
690                head: Level4Module::new(device),
691            }
692        }
693    }
694
695    #[test]
696    fn deep_module_path_tracking() {
697        let device = Default::default();
698        let model = DeepModel::<TestBackend>::new(&device);
699
700        let mut collector = Collector::new(None, None, false);
701        model.visit(&mut collector);
702
703        let views = collector.tensors;
704        let paths: Vec<String> = views.iter().map(|v| v.full_path()).collect();
705
706        // Test 5-level deep paths
707        assert!(paths.contains(&"backbone.encoder.block1.layer.weight".to_string()));
708        assert!(paths.contains(&"backbone.encoder.block1.layer.bias".to_string()));
709        assert!(paths.contains(&"backbone.encoder.block1.extra.weight".to_string()));
710        assert!(paths.contains(&"backbone.encoder.block1.extra.bias".to_string()));
711
712        assert!(paths.contains(&"backbone.encoder.block2.layer.weight".to_string()));
713        assert!(paths.contains(&"backbone.encoder.block2.layer.bias".to_string()));
714        assert!(paths.contains(&"backbone.encoder.block2.extra.weight".to_string()));
715        assert!(paths.contains(&"backbone.encoder.block2.extra.bias".to_string()));
716
717        assert!(paths.contains(&"backbone.decoder.block1.layer.weight".to_string()));
718        assert!(paths.contains(&"backbone.decoder.block1.layer.bias".to_string()));
719        assert!(paths.contains(&"backbone.decoder.block1.extra.weight".to_string()));
720        assert!(paths.contains(&"backbone.decoder.block1.extra.bias".to_string()));
721
722        assert!(paths.contains(&"backbone.decoder.block2.layer.weight".to_string()));
723        assert!(paths.contains(&"backbone.decoder.block2.layer.bias".to_string()));
724        assert!(paths.contains(&"backbone.decoder.block2.extra.weight".to_string()));
725        assert!(paths.contains(&"backbone.decoder.block2.extra.bias".to_string()));
726
727        // Test 2-level paths
728        assert!(paths.contains(&"head.weight".to_string()));
729        assert!(paths.contains(&"head.bias".to_string()));
730
731        // Total should be 18 tensors (16 from backbone + 2 from head)
732        assert_eq!(views.len(), 18);
733
734        // Verify data can be materialized
735        let view = views
736            .iter()
737            .find(|v| v.full_path() == "backbone.encoder.block1.layer.weight")
738            .unwrap();
739        let data = view.to_data().unwrap();
740        assert_eq!(data.shape, vec![2, 2]);
741    }
742
743    #[test]
744    fn deep_module_filtered_export() {
745        let device = Default::default();
746        let model = DeepModel::<TestBackend>::new(&device);
747
748        // Test filtering at different depths
749        #[cfg(target_has_atomic = "ptr")]
750        {
751            let filter = PathFilter::new().with_regex(r"^backbone\.encoder\..*");
752            let mut collector = Collector::new(Some(filter), None, false);
753            model.visit(&mut collector);
754            assert_eq!(collector.tensors.len(), 8); // Only encoder tensors
755        }
756
757        // Test filtering specific blocks
758        #[cfg(target_has_atomic = "ptr")]
759        {
760            let filter = PathFilter::new().with_regex(r".*\.block1\..*");
761            let mut collector = Collector::new(Some(filter), None, false);
762            model.visit(&mut collector);
763            assert_eq!(collector.tensors.len(), 8); // block1 in both encoder and decoder
764        }
765
766        // Test filtering by tensor type at any depth
767        #[cfg(target_has_atomic = "ptr")]
768        {
769            let filter = PathFilter::new().with_regex(r".*\.weight$");
770            let mut collector = Collector::new(Some(filter), None, false);
771            model.visit(&mut collector);
772            assert_eq!(collector.tensors.len(), 9); // All weight tensors
773        }
774
775        // Test complex multi-pattern filtering
776        #[cfg(target_has_atomic = "ptr")]
777        {
778            let filter = PathFilter::new()
779                .with_regex(r"^backbone\.encoder\.block1\..*") // All encoder.block1 tensors
780                .with_regex(r"^backbone\.decoder\..*\.bias$") // All decoder biases
781                .with_regex(r"^head\.weight$"); // Head weight only
782            let mut collector = Collector::new(Some(filter), None, false);
783            model.visit(&mut collector);
784
785            // Should have:
786            // - 4 from encoder.block1 (2 weights + 2 biases)
787            // - 4 decoder biases
788            // - 1 head weight
789            assert_eq!(collector.tensors.len(), 9);
790
791            let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
792            assert!(paths.contains(&"backbone.encoder.block1.layer.weight".to_string()));
793            assert!(paths.contains(&"backbone.decoder.block1.layer.bias".to_string()));
794            assert!(paths.contains(&"head.weight".to_string()));
795            assert!(!paths.contains(&"head.bias".to_string())); // Not included
796        }
797    }
798
799    use crate::traits::ModuleSnapshot;
800    use burn_nn::Linear;
801    use hashbrown::HashMap;
802
803    // Test module with Option fields
804    #[derive(Module, Debug)]
805    struct OptionalFieldModule<B: Backend> {
806        required: Param<Tensor<B, 2>>,
807        optional: Option<Param<Tensor<B, 1>>>,
808    }
809
810    impl<B: Backend> OptionalFieldModule<B> {
811        fn new_with_optional(device: &B::Device) -> Self {
812            Self {
813                required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
814                optional: Some(Param::from_data([5.0, 6.0], device)),
815            }
816        }
817
818        fn new_without_optional(device: &B::Device) -> Self {
819            Self {
820                required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
821                optional: None,
822            }
823        }
824    }
825
826    #[test]
827    fn optional_field_module_with_value() {
828        let device = Default::default();
829        let module = OptionalFieldModule::<TestBackend>::new_with_optional(&device);
830
831        let views: HashMap<String, TensorSnapshot> = module
832            .collect(None, None, false)
833            .into_iter()
834            .map(|v| (v.full_path(), v))
835            .collect();
836
837        assert_eq!(views.len(), 2);
838        assert!(views.contains_key("required"));
839        assert!(views.contains_key("optional"));
840    }
841
842    #[test]
843    fn optional_field_module_without_value() {
844        let device = Default::default();
845        let module = OptionalFieldModule::<TestBackend>::new_without_optional(&device);
846
847        let views: HashMap<String, TensorSnapshot> = module
848            .collect(None, None, false)
849            .into_iter()
850            .map(|v| (v.full_path(), v))
851            .collect();
852
853        assert_eq!(views.len(), 1);
854        assert!(views.contains_key("required"));
855        assert!(!views.contains_key("optional"));
856    }
857
    // Test Vec of modules.
    //
    // The collect tests expect each element's Vec index to appear as a path
    // segment (e.g. `layers.0.weight`).
    #[derive(Module, Debug)]
    struct VecModule<B: Backend> {
        // Variable-length list of Linear layers.
        layers: Vec<Linear<B>>,
    }
863
864    impl<B: Backend> VecModule<B> {
865        fn new(device: &B::Device, num_layers: usize) -> Self {
866            Self {
867                layers: (0..num_layers)
868                    .map(|_| LinearConfig::new(10, 10).init(device))
869                    .collect(),
870            }
871        }
872    }
873
    // Test tuple of modules.
    //
    // The collect test expects each tuple position to appear as an index path
    // segment (e.g. `layers.2.bias`).
    #[derive(Module, Debug)]
    struct TupleModule<B: Backend> {
        // Three Linear layers addressed by tuple position.
        layers: (Linear<B>, Linear<B>, Linear<B>),
    }
879
880    impl<B: Backend> TupleModule<B> {
881        fn new(device: &B::Device) -> Self {
882            Self {
883                layers: (
884                    LinearConfig::new(10, 10).init(device),
885                    LinearConfig::new(10, 10).init(device),
886                    LinearConfig::new(10, 10).init(device),
887                ),
888            }
889        }
890    }
891
892    #[test]
893    fn vec_module_collect() {
894        let device = Default::default();
895        let module = VecModule::<TestBackend>::new(&device, 3);
896
897        let views: HashMap<String, TensorSnapshot> = module
898            .collect(None, None, false)
899            .into_iter()
900            .map(|v| (v.full_path(), v))
901            .collect();
902
903        // With the fix, all Vec items should now be properly indexed and visited
904        assert_eq!(views.len(), 6); // 3 layers × 2 tensors each = 6 tensors
905
906        // Check that all indexed paths exist
907        assert!(views.contains_key("layers.0.weight"));
908        assert!(views.contains_key("layers.0.bias"));
909        assert!(views.contains_key("layers.1.weight"));
910        assert!(views.contains_key("layers.1.bias"));
911        assert!(views.contains_key("layers.2.weight"));
912        assert!(views.contains_key("layers.2.bias"));
913    }
914
915    #[test]
916    fn tuple_module_collect() {
917        let device = Default::default();
918        let module = TupleModule::<TestBackend>::new(&device);
919
920        let snapshots = module.collect(None, None, false);
921        assert_eq!(snapshots.len(), 6);
922
923        let views: HashMap<String, TensorSnapshot> =
924            snapshots.into_iter().map(|v| (v.full_path(), v)).collect();
925
926        assert_eq!(views.len(), 6);
927
928        assert!(views.contains_key("layers.0.weight"));
929        assert!(views.contains_key("layers.0.bias"));
930        assert!(views.contains_key("layers.1.weight"));
931        assert!(views.contains_key("layers.1.bias"));
932        assert!(views.contains_key("layers.2.weight"));
933        assert!(views.contains_key("layers.2.bias"));
934    }
935
    // Test array of modules.
    //
    // The collect test expects each array index to appear as a path segment
    // (e.g. `layers.1.weight`).
    #[derive(Module, Debug)]
    struct ArrayModule<B: Backend> {
        // Fixed-size array of three Linear layers.
        layers: [Linear<B>; 3],
    }
941
942    impl<B: Backend> ArrayModule<B> {
943        fn new(device: &B::Device) -> Self {
944            Self {
945                layers: [
946                    LinearConfig::new(10, 10).init(device),
947                    LinearConfig::new(10, 10).init(device),
948                    LinearConfig::new(10, 10).init(device),
949                ],
950            }
951        }
952    }
953
954    #[test]
955    fn array_module_collect() {
956        let device = Default::default();
957        let module = ArrayModule::<TestBackend>::new(&device);
958
959        let views: HashMap<String, TensorSnapshot> = module
960            .collect(None, None, false)
961            .into_iter()
962            .map(|v| (v.full_path(), v))
963            .collect();
964
965        // All array items should be properly indexed
966        assert_eq!(views.len(), 6); // 3 layers × 2 tensors each = 6 tensors
967
968        // Check indexed paths
969        for i in 0..3 {
970            assert!(views.contains_key(&format!("layers.{}.weight", i)));
971            assert!(views.contains_key(&format!("layers.{}.bias", i)));
972        }
973    }
974
    // Test enum modules.
    //
    // Only one variant is active at a time. With enum variant names not skipped
    // (`collect(.., false)`), the collect test expects the active variant's name
    // to prefix every tensor path (e.g. `LayerA.weight`).
    #[derive(Module, Debug)]
    enum EnumModule<B: Backend> {
        LayerA(Linear<B>),
        LayerB(Linear<B>),
        LayerC(Linear<B>),
    }
982
983    #[test]
984    fn enum_module_collect() {
985        let device = Default::default();
986
987        // Test variant A
988        let module_a = EnumModule::<TestBackend>::LayerA(LinearConfig::new(10, 20).init(&device));
989        let views_a: HashMap<String, TensorSnapshot> = module_a
990            .collect(None, None, false)
991            .into_iter()
992            .map(|v| (v.full_path(), v))
993            .collect();
994
995        // Should have the variant name in the path
996        assert_eq!(views_a.len(), 2);
997        assert!(views_a.contains_key("LayerA.weight"));
998        assert!(views_a.contains_key("LayerA.bias"));
999
1000        // Test variant B
1001        let module_b = EnumModule::<TestBackend>::LayerB(LinearConfig::new(10, 20).init(&device));
1002        let views_b: HashMap<String, TensorSnapshot> = module_b
1003            .collect(None, None, false)
1004            .into_iter()
1005            .map(|v| (v.full_path(), v))
1006            .collect();
1007
1008        assert_eq!(views_b.len(), 2);
1009        assert!(views_b.contains_key("LayerB.weight"));
1010        assert!(views_b.contains_key("LayerB.bias"));
1011    }
1012
1013    // Container type tracking tests
1014    #[test]
1015    fn linear_container_type() {
1016        let device = Default::default();
1017
1018        #[derive(Module, Debug)]
1019        struct ModelWithLinear<B: Backend> {
1020            linear: Linear<B>,
1021        }
1022
1023        impl<B: Backend> ModelWithLinear<B> {
1024            fn new(device: &B::Device) -> Self {
1025                Self {
1026                    linear: LinearConfig::new(10, 20).init(device),
1027                }
1028            }
1029        }
1030
1031        let model = ModelWithLinear::<TestBackend>::new(&device);
1032
1033        let views: HashMap<String, TensorSnapshot> = model
1034            .collect(None, None, false)
1035            .into_iter()
1036            .map(|v| (v.full_path(), v))
1037            .collect();
1038
1039        // Check that tensors inside Linear layers have "Struct:Linear" as their module type
1040        for (path, view) in views.iter() {
1041            if path == "linear.weight" || path == "linear.bias" {
1042                assert_eq!(
1043                    view.module_type(),
1044                    Some("Struct:Linear".to_string()),
1045                    "Tensor '{}' should have module type 'Struct:Linear'",
1046                    path
1047                );
1048            }
1049        }
1050    }
1051
1052    #[test]
1053    fn complex_model_container_types() {
1054        let device = Default::default();
1055
1056        #[derive(Module, Debug)]
1057        struct ComplexModel<B: Backend> {
1058            linear_layers: [Linear<B>; 2],
1059            vec_layers: Vec<Linear<B>>,
1060            single_linear: Linear<B>,
1061        }
1062
1063        impl<B: Backend> ComplexModel<B> {
1064            fn new(device: &B::Device) -> Self {
1065                Self {
1066                    linear_layers: [
1067                        LinearConfig::new(100, 50).init(device),
1068                        LinearConfig::new(50, 10).init(device),
1069                    ],
1070                    vec_layers: vec![
1071                        LinearConfig::new(10, 10).init(device),
1072                        LinearConfig::new(10, 10).init(device),
1073                    ],
1074                    single_linear: LinearConfig::new(10, 1).init(device),
1075                }
1076            }
1077        }
1078
1079        let model = ComplexModel::<TestBackend>::new(&device);
1080
1081        let views: HashMap<String, TensorSnapshot> = model
1082            .collect(None, None, false)
1083            .into_iter()
1084            .map(|v| (v.full_path(), v))
1085            .collect();
1086
1087        // Should have 10 tensors total
1088        assert_eq!(views.len(), 10);
1089
1090        // Verify different module types
1091        for (_path, view) in views.iter() {
1092            assert_eq!(view.module_type(), Some("Struct:Linear".to_string()));
1093        }
1094    }
1095
1096    #[test]
1097    fn collect_with_container_filter() {
1098        let device = Default::default();
1099
1100        #[derive(Module, Debug)]
1101        struct FilterTestModel<B: Backend> {
1102            layers: Vec<Linear<B>>,
1103        }
1104
1105        impl<B: Backend> FilterTestModel<B> {
1106            fn new(device: &B::Device) -> Self {
1107                Self {
1108                    layers: vec![
1109                        LinearConfig::new(10, 10).init(device),
1110                        LinearConfig::new(10, 10).init(device),
1111                    ],
1112                }
1113            }
1114        }
1115
1116        let model = FilterTestModel::<TestBackend>::new(&device);
1117
1118        // Filter to only collect tensors from Linear modules
1119        let filter = PathFilter::new().with_predicate(|_path, container_path| {
1120            container_path.split('.').next_back() == Some("Struct:Linear")
1121        });
1122
1123        let linear_views: Vec<TensorSnapshot> = model.collect(Some(filter), None, false);
1124
1125        // All collected tensors should be from Linear modules
1126        for view in linear_views.iter() {
1127            assert_eq!(
1128                view.module_type(),
1129                Some("Struct:Linear".to_string()),
1130                "All tensors should be from Linear modules"
1131            );
1132        }
1133
1134        // Should have collected all Linear tensors
1135        assert_eq!(linear_views.len(), 4);
1136    }
1137}