burn_store/
collector.rs

1use alloc::boxed::Box;
2use alloc::string::{String, ToString};
3use alloc::vec::Vec;
4
5use burn_tensor::{Bool, Int, Tensor, backend::Backend};
6
7use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
8use burn_core::module::{ModuleVisitor, Param, ParamId};
9
10/// Collects tensor views from modules without copying data.
11///
12/// This collector traverses a module hierarchy and creates lightweight views
13/// of tensors that can be materialized to `TensorData` on demand.
14///
15/// # Examples
16///
17/// ## Collect all tensors
18/// ```rust,no_run
19/// # use burn_store::Collector;
20/// let collector = Collector::new(None, None);
21/// // Use with module.visit(&mut collector);
22/// let all_tensors = collector.tensors;
23/// ```
24///
25/// ## Filter with single pattern
26/// ```rust,no_run
27/// # use burn_store::{Collector, PathFilter};
28/// let filter = PathFilter::new().with_regex(r"^encoder\..*");
29/// let collector = Collector::new(Some(filter), None);
30/// // Use with module.visit(&mut collector);
31/// // Only collects tensors starting with "encoder."
32/// ```
33///
34/// ## Filter with multiple patterns (OR union)
35/// ```rust,no_run
36/// # use burn_store::{Collector, PathFilter};
37/// let filter = PathFilter::new()
38///     .with_regex(r"^encoder\..*")  // Match all encoder tensors
39///     .with_regex(r".*\.bias$");    // OR match any bias tensors
40/// let collector = Collector::new(Some(filter), None);
41/// // Use with module.visit(&mut collector);
42/// // Collects tensors matching ANY of the patterns
43/// ```
44pub struct Collector {
45    /// Collection of tensor views
46    pub tensors: Vec<TensorSnapshot>,
47    path_stack: Vec<String>,
48    container_stack: Vec<String>,
49    filter: Option<PathFilter>,
50    adapter: Option<Box<dyn ModuleAdapter>>,
51}
52
53impl Default for Collector {
54    fn default() -> Self {
55        Self::new(None, None)
56    }
57}
58
59impl Collector {
60    /// Create a new tensor view collector with an optional filter and adapter.
61    ///
62    /// # Arguments
63    ///
64    /// * `filter` - An optional [`PathFilter`] to determine which tensors to collect.
65    ///   When `None`, all tensors are collected.
66    /// * `adapter` - Optional adapter to transform tensors based on container types.
67    ///   Applied to all collected tensors before returning.
68    ///
69    /// # Examples
70    ///
71    /// ```rust,no_run
72    /// # use burn_store::{Collector, PathFilter};
73    /// // Collect all tensors without adapter
74    /// let collector = Collector::new(None, None);
75    ///
76    /// // Use PathFilter builder
77    /// let filter = PathFilter::new()
78    ///     .with_regex(r"^encoder\..*")
79    ///     .with_full_path("decoder.weight");
80    /// let collector = Collector::new(Some(filter), None);
81    /// ```
82    pub fn new(filter: Option<PathFilter>, adapter: Option<Box<dyn ModuleAdapter>>) -> Self {
83        Self {
84            tensors: Vec::new(),
85            path_stack: Vec::new(),
86            container_stack: Vec::new(),
87            filter,
88            adapter,
89        }
90    }
91
92    /// Apply the adapter to collected tensors and return the result.
93    pub fn into_tensors(self) -> Vec<TensorSnapshot> {
94        if let Some(adapter) = self.adapter {
95            self.tensors
96                .into_iter()
97                .map(|snapshot| adapter.adapt(&snapshot))
98                .collect()
99        } else {
100            self.tensors
101        }
102    }
103
104    fn should_collect(&self, path: &[String], container_stack: &[String]) -> bool {
105        // If filter is present, use it; otherwise collect all
106        match &self.filter {
107            None => true,
108            Some(f) => f.matches_with_container_path(path, container_stack),
109        }
110    }
111}
112
113impl<B: Backend> ModuleVisitor<B> for Collector {
114    fn enter_module(&mut self, name: &str, container_type: &str) {
115        self.path_stack.push(name.to_string());
116        self.container_stack.push(container_type.to_string());
117    }
118
119    fn exit_module(&mut self, _name: &str, _container_type: &str) {
120        self.path_stack.pop();
121        self.container_stack.pop();
122    }
123
124    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
125        if self.should_collect(&self.path_stack, &self.container_stack) {
126            self.tensors.push(TensorSnapshot::from_float(
127                &param.transform_for_save().val(),
128                self.path_stack.clone(),
129                self.container_stack.clone(),
130                param.id,
131            ));
132        }
133    }
134
135    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {
136        if self.should_collect(&self.path_stack, &self.container_stack) {
137            self.tensors.push(TensorSnapshot::from_int(
138                &param.transform_for_save().val(),
139                self.path_stack.clone(),
140                self.container_stack.clone(),
141                param.id,
142            ));
143        }
144    }
145
146    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {
147        if self.should_collect(&self.path_stack, &self.container_stack) {
148            self.tensors.push(TensorSnapshot::from_bool(
149                &param.transform_for_save().val(),
150                self.path_stack.clone(),
151                self.container_stack.clone(),
152                param.id,
153            ));
154        }
155    }
156
157    fn visit_float_with_path<const D: usize>(
158        &mut self,
159        path: &[String],
160        id: ParamId,
161        tensor: &Tensor<B, D>,
162    ) {
163        // For path-based visits, we use the current container stack for filtering
164        if self.should_collect(path, &self.container_stack) {
165            self.tensors.push(TensorSnapshot::from_float(
166                tensor,
167                path.to_vec(),
168                self.container_stack.clone(),
169                id,
170            ));
171        }
172    }
173
174    fn visit_int_with_path<const D: usize>(
175        &mut self,
176        path: &[String],
177        id: ParamId,
178        tensor: &Tensor<B, D, Int>,
179    ) {
180        if self.should_collect(path, &self.container_stack) {
181            self.tensors.push(TensorSnapshot::from_int(
182                tensor,
183                path.to_vec(),
184                self.container_stack.clone(),
185                id,
186            ));
187        }
188    }
189
190    fn visit_bool_with_path<const D: usize>(
191        &mut self,
192        path: &[String],
193        id: ParamId,
194        tensor: &Tensor<B, D, Bool>,
195    ) {
196        if self.should_collect(path, &self.container_stack) {
197            self.tensors.push(TensorSnapshot::from_bool(
198                tensor,
199                path.to_vec(),
200                self.container_stack.clone(),
201                id,
202            ));
203        }
204    }
205}
206
207#[cfg(all(test, feature = "std"))]
208mod tests {
209    use super::*;
210
211    use burn_core as burn;
212
213    type TestBackend = burn_ndarray::NdArray;
214    use alloc::collections::BTreeMap;
215    use alloc::string::String;
216    use burn_core::module::{Module, Param};
217    use burn_nn::LinearConfig;
218
219    #[test]
220    fn tensor_snapshot_collector() {
221        let device = Default::default();
222        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
223
224        let mut collector = Collector::new(None, None);
225        let id = ParamId::new();
226
227        // Collect a tensor
228        collector.visit_float_with_path(&["model".to_string(), "weight".to_string()], id, &tensor);
229
230        assert_eq!(collector.tensors.len(), 1);
231        assert_eq!(collector.tensors[0].full_path(), "model.weight");
232
233        // Verify the tensor can be converted to data
234        let view = &collector.tensors[0];
235        let data = view.to_data().unwrap();
236        assert_eq!(data.shape, vec![2, 2]);
237    }
238
239    #[test]
240    fn root_level_parameters() {
241        use burn_core::module::ModuleVisitor;
242
243        let device = Default::default();
244
245        // Create root-level parameters (single-element path, not nested in modules)
246        let weight = Param::<Tensor<TestBackend, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
247        let bias = Param::<Tensor<TestBackend, 1>>::from_data([5.0, 6.0], &device);
248
249        let mut collector = Collector::new(None, None);
250
251        // Simulate module traversal for root-level parameters
252        // Enter "weight" path (as if we're visiting a field named "weight")
253        ModuleVisitor::<TestBackend>::enter_module(&mut collector, "weight", "");
254        ModuleVisitor::<TestBackend>::visit_float(&mut collector, &weight);
255        ModuleVisitor::<TestBackend>::exit_module(&mut collector, "weight", "");
256
257        // Enter "bias" path (as if we're visiting a field named "bias")
258        ModuleVisitor::<TestBackend>::enter_module(&mut collector, "bias", "");
259        ModuleVisitor::<TestBackend>::visit_float(&mut collector, &bias);
260        ModuleVisitor::<TestBackend>::exit_module(&mut collector, "bias", "");
261
262        // Verify both parameters were collected
263        assert_eq!(collector.tensors.len(), 2);
264
265        // Verify paths are correct (single-element paths)
266        assert_eq!(collector.tensors[0].full_path(), "weight");
267        assert_eq!(collector.tensors[1].full_path(), "bias");
268
269        // Verify data is correct
270        let weight_data = collector.tensors[0]
271            .to_data()
272            .unwrap()
273            .to_vec::<f32>()
274            .unwrap();
275        let bias_data = collector.tensors[1]
276            .to_data()
277            .unwrap()
278            .to_vec::<f32>()
279            .unwrap();
280
281        assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]);
282        assert_eq!(bias_data, vec![5.0, 6.0]);
283    }
284
285    #[test]
286    #[cfg(target_has_atomic = "ptr")]
287    fn tensor_snapshot_collector_with_filter() {
288        let device = Default::default();
289        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
290
291        let filter = PathFilter::new().with_regex(r"^encoder\..*");
292        let mut collector = Collector::new(Some(filter), None);
293        let id = ParamId::new();
294
295        // This should be collected
296        collector.visit_float_with_path(
297            &["encoder".to_string(), "weight".to_string()],
298            id,
299            &tensor,
300        );
301        // This should NOT be collected
302        collector.visit_float_with_path(
303            &["decoder".to_string(), "weight".to_string()],
304            id,
305            &tensor,
306        );
307
308        assert_eq!(collector.tensors.len(), 1);
309        assert_eq!(collector.tensors[0].full_path(), "encoder.weight");
310    }
311
312    #[test]
313    #[cfg(target_has_atomic = "ptr")]
314    fn tensor_snapshot_collector_with_multiple_filters() {
315        let device = Default::default();
316        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
317
318        // Multiple patterns - collect if matches ANY (OR union)
319        let filter = PathFilter::new()
320            .with_regex(r"^encoder\..*") // Match encoder.*
321            .with_regex(r".*\.bias$"); // Match *.bias
322        let mut collector = Collector::new(Some(filter), None);
323        let id = ParamId::new();
324
325        // These should be collected
326        collector.visit_float_with_path(
327            &["encoder".to_string(), "weight".to_string()],
328            id,
329            &tensor,
330        ); // matches first pattern
331        collector.visit_float_with_path(&["decoder".to_string(), "bias".to_string()], id, &tensor); // matches second pattern
332        collector.visit_float_with_path(&["encoder".to_string(), "bias".to_string()], id, &tensor); // matches both patterns
333
334        // This should NOT be collected
335        collector.visit_float_with_path(
336            &["decoder".to_string(), "weight".to_string()],
337            id,
338            &tensor,
339        ); // matches neither
340
341        assert_eq!(collector.tensors.len(), 3);
342        let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
343        assert!(paths.contains(&"encoder.weight".to_string()));
344        assert!(paths.contains(&"decoder.bias".to_string()));
345        assert!(paths.contains(&"encoder.bias".to_string()));
346        assert!(!paths.contains(&"decoder.weight".to_string()));
347    }
348
349    #[test]
350    fn tensor_snapshot_collector_with_predicate() {
351        let device = Default::default();
352        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
353
354        // Use predicate function for filtering
355        fn filter_fn(path: &str, _container_path: &str) -> bool {
356            path.starts_with("encoder.") || path == "decoder.bias"
357        }
358        let filter = PathFilter::new().with_predicate(filter_fn);
359        let mut collector = Collector::new(Some(filter), None);
360        let id = ParamId::new();
361
362        // These should be collected
363        collector.visit_float_with_path(
364            &["encoder".to_string(), "weight".to_string()],
365            id,
366            &tensor,
367        );
368        collector.visit_float_with_path(&["encoder".to_string(), "bias".to_string()], id, &tensor);
369        collector.visit_float_with_path(&["decoder".to_string(), "bias".to_string()], id, &tensor);
370
371        // This should NOT be collected
372        collector.visit_float_with_path(
373            &["decoder".to_string(), "weight".to_string()],
374            id,
375            &tensor,
376        );
377        collector.visit_float_with_path(&["other".to_string(), "tensor".to_string()], id, &tensor);
378
379        assert_eq!(collector.tensors.len(), 3);
380        let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
381        assert!(paths.contains(&"encoder.weight".to_string()));
382        assert!(paths.contains(&"encoder.bias".to_string()));
383        assert!(paths.contains(&"decoder.bias".to_string()));
384        assert!(!paths.contains(&"decoder.weight".to_string()));
385        assert!(!paths.contains(&"other.tensor".to_string()));
386    }
387
388    #[test]
389    fn tensor_snapshot_collector_predicate_with_complex_logic() {
390        let device = Default::default();
391        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
392
393        // Complex predicate with multiple conditions
394        fn complex_filter(path: &str, _container_path: &str) -> bool {
395            let parts: Vec<&str> = path.split('.').collect();
396            if parts.len() != 3 {
397                return false;
398            }
399            // Only collect if it's layer1 or layer2, and it's a weight tensor
400            (parts[1] == "layer1" || parts[1] == "layer2") && parts[2] == "weight"
401        }
402        let filter = PathFilter::new().with_predicate(complex_filter);
403        let mut collector = Collector::new(Some(filter), None);
404        let id = ParamId::new();
405
406        // These should be collected
407        collector.visit_float_with_path(
408            &[
409                "model".to_string(),
410                "layer1".to_string(),
411                "weight".to_string(),
412            ],
413            id,
414            &tensor,
415        );
416        collector.visit_float_with_path(
417            &[
418                "model".to_string(),
419                "layer2".to_string(),
420                "weight".to_string(),
421            ],
422            id,
423            &tensor,
424        );
425
426        // These should NOT be collected
427        collector.visit_float_with_path(
428            &[
429                "model".to_string(),
430                "layer1".to_string(),
431                "bias".to_string(),
432            ],
433            id,
434            &tensor,
435        );
436        collector.visit_float_with_path(
437            &[
438                "model".to_string(),
439                "layer3".to_string(),
440                "weight".to_string(),
441            ],
442            id,
443            &tensor,
444        );
445        collector.visit_float_with_path(
446            &["encoder".to_string(), "weight".to_string()],
447            id,
448            &tensor,
449        ); // wrong structure
450
451        assert_eq!(collector.tensors.len(), 2);
452        let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
453        assert!(paths.contains(&"model.layer1.weight".to_string()));
454        assert!(paths.contains(&"model.layer2.weight".to_string()));
455        assert!(!paths.contains(&"model.layer1.bias".to_string()));
456        assert!(!paths.contains(&"model.layer3.weight".to_string()));
457        assert!(!paths.contains(&"encoder.weight".to_string()));
458    }
459
460    // Test visitor that collects tensor paths
461    struct TensorPathCollector {
462        pub paths: BTreeMap<String, (ParamId, Vec<usize>)>,
463        path_stack: Vec<String>,
464    }
465
466    impl TensorPathCollector {
467        fn new() -> Self {
468            Self {
469                paths: BTreeMap::new(),
470                path_stack: Vec::new(),
471            }
472        }
473
474        fn current_path(&self) -> String {
475            self.path_stack.join(".")
476        }
477    }
478
479    impl<B: Backend> ModuleVisitor<B> for TensorPathCollector {
480        fn enter_module(&mut self, name: &str, _container_type: &str) {
481            self.path_stack.push(name.to_string());
482        }
483
484        fn exit_module(&mut self, _name: &str, _container_type: &str) {
485            self.path_stack.pop();
486        }
487
488        fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
489            let path = self.current_path();
490            if !path.is_empty() {
491                self.paths.insert(
492                    path,
493                    (param.id, param.transform_for_save().val().shape().to_vec()),
494                );
495            }
496        }
497
498        fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {
499            let path = self.current_path();
500            if !path.is_empty() {
501                self.paths.insert(
502                    path,
503                    (param.id, param.transform_for_save().val().shape().to_vec()),
504                );
505            }
506        }
507
508        fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {
509            let path = self.current_path();
510            if !path.is_empty() {
511                self.paths.insert(
512                    path,
513                    (param.id, param.transform_for_save().val().shape().to_vec()),
514                );
515            }
516        }
517    }
518
519    // Simple nested module for testing
520    #[derive(Module, Debug)]
521    struct InnerModule<B: Backend> {
522        weight: Param<Tensor<B, 2>>,
523        bias: Param<Tensor<B, 1>>,
524    }
525
526    #[derive(Module, Debug)]
527    struct OuterModule<B: Backend> {
528        layer1: InnerModule<B>,
529        layer2: InnerModule<B>,
530    }
531
532    impl<B: Backend> InnerModule<B> {
533        fn new(device: &B::Device) -> Self {
534            Self {
535                weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
536                bias: Param::from_data([5.0, 6.0], device),
537            }
538        }
539    }
540
541    impl<B: Backend> OuterModule<B> {
542        fn new(device: &B::Device) -> Self {
543            Self {
544                layer1: InnerModule::new(device),
545                layer2: InnerModule::new(device),
546            }
547        }
548    }
549
550    #[test]
551    fn nested_module_path_tracking() {
552        let device = Default::default();
553        let module = OuterModule::<TestBackend>::new(&device);
554
555        let mut collector = TensorPathCollector::new();
556        module.visit(&mut collector);
557
558        let paths = collector.paths;
559
560        // Verify we have the expected paths
561        // Note: Param<Tensor> fields are themselves modules, so we get an extra level
562        assert!(paths.contains_key("layer1.weight"), "Missing layer1.weight");
563        assert!(paths.contains_key("layer1.bias"), "Missing layer1.bias");
564        assert!(paths.contains_key("layer2.weight"), "Missing layer2.weight");
565        assert!(paths.contains_key("layer2.bias"), "Missing layer2.bias");
566
567        // Verify the shapes are correct
568        assert_eq!(paths.get("layer1.weight").unwrap().1, vec![2, 2]);
569        assert_eq!(paths.get("layer1.bias").unwrap().1, vec![2]);
570        assert_eq!(paths.get("layer2.weight").unwrap().1, vec![2, 2]);
571        assert_eq!(paths.get("layer2.bias").unwrap().1, vec![2]);
572    }
573
574    #[test]
575    fn linear_module_paths() {
576        let device = Default::default();
577        let config = LinearConfig::new(10, 20).with_bias(true);
578        let linear = config.init::<TestBackend>(&device);
579
580        let mut collector = TensorPathCollector::new();
581        linear.visit(&mut collector);
582
583        let paths = collector.paths;
584
585        // Linear module has weight and optional bias
586        assert!(paths.contains_key("weight"));
587        assert!(paths.contains_key("bias"));
588
589        // Check dimensions
590        assert_eq!(paths.get("weight").unwrap().1, vec![10, 20]);
591        assert_eq!(paths.get("bias").unwrap().1, vec![20]);
592    }
593
594    // Deep nesting test structures (4+ levels)
595    #[derive(Module, Debug)]
596    struct Level4Module<B: Backend> {
597        weight: Param<Tensor<B, 2>>,
598        bias: Param<Tensor<B, 1>>,
599    }
600
601    #[derive(Module, Debug)]
602    struct Level3Module<B: Backend> {
603        layer: Level4Module<B>,
604        extra: Level4Module<B>,
605    }
606
607    #[derive(Module, Debug)]
608    struct Level2Module<B: Backend> {
609        block1: Level3Module<B>,
610        block2: Level3Module<B>,
611    }
612
613    #[derive(Module, Debug)]
614    struct Level1Module<B: Backend> {
615        encoder: Level2Module<B>,
616        decoder: Level2Module<B>,
617    }
618
619    #[derive(Module, Debug)]
620    struct DeepModel<B: Backend> {
621        backbone: Level1Module<B>,
622        head: Level4Module<B>,
623    }
624
625    impl<B: Backend> Level4Module<B> {
626        fn new(device: &B::Device) -> Self {
627            Self {
628                weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
629                bias: Param::from_data([5.0, 6.0], device),
630            }
631        }
632    }
633
634    impl<B: Backend> Level3Module<B> {
635        fn new(device: &B::Device) -> Self {
636            Self {
637                layer: Level4Module::new(device),
638                extra: Level4Module::new(device),
639            }
640        }
641    }
642
643    impl<B: Backend> Level2Module<B> {
644        fn new(device: &B::Device) -> Self {
645            Self {
646                block1: Level3Module::new(device),
647                block2: Level3Module::new(device),
648            }
649        }
650    }
651
652    impl<B: Backend> Level1Module<B> {
653        fn new(device: &B::Device) -> Self {
654            Self {
655                encoder: Level2Module::new(device),
656                decoder: Level2Module::new(device),
657            }
658        }
659    }
660
661    impl<B: Backend> DeepModel<B> {
662        fn new(device: &B::Device) -> Self {
663            Self {
664                backbone: Level1Module::new(device),
665                head: Level4Module::new(device),
666            }
667        }
668    }
669
670    #[test]
671    fn deep_module_path_tracking() {
672        let device = Default::default();
673        let model = DeepModel::<TestBackend>::new(&device);
674
675        let mut collector = Collector::new(None, None);
676        model.visit(&mut collector);
677
678        let views = collector.tensors;
679        let paths: Vec<String> = views.iter().map(|v| v.full_path()).collect();
680
681        // Test 5-level deep paths
682        assert!(paths.contains(&"backbone.encoder.block1.layer.weight".to_string()));
683        assert!(paths.contains(&"backbone.encoder.block1.layer.bias".to_string()));
684        assert!(paths.contains(&"backbone.encoder.block1.extra.weight".to_string()));
685        assert!(paths.contains(&"backbone.encoder.block1.extra.bias".to_string()));
686
687        assert!(paths.contains(&"backbone.encoder.block2.layer.weight".to_string()));
688        assert!(paths.contains(&"backbone.encoder.block2.layer.bias".to_string()));
689        assert!(paths.contains(&"backbone.encoder.block2.extra.weight".to_string()));
690        assert!(paths.contains(&"backbone.encoder.block2.extra.bias".to_string()));
691
692        assert!(paths.contains(&"backbone.decoder.block1.layer.weight".to_string()));
693        assert!(paths.contains(&"backbone.decoder.block1.layer.bias".to_string()));
694        assert!(paths.contains(&"backbone.decoder.block1.extra.weight".to_string()));
695        assert!(paths.contains(&"backbone.decoder.block1.extra.bias".to_string()));
696
697        assert!(paths.contains(&"backbone.decoder.block2.layer.weight".to_string()));
698        assert!(paths.contains(&"backbone.decoder.block2.layer.bias".to_string()));
699        assert!(paths.contains(&"backbone.decoder.block2.extra.weight".to_string()));
700        assert!(paths.contains(&"backbone.decoder.block2.extra.bias".to_string()));
701
702        // Test 2-level paths
703        assert!(paths.contains(&"head.weight".to_string()));
704        assert!(paths.contains(&"head.bias".to_string()));
705
706        // Total should be 18 tensors (16 from backbone + 2 from head)
707        assert_eq!(views.len(), 18);
708
709        // Verify data can be materialized
710        let view = views
711            .iter()
712            .find(|v| v.full_path() == "backbone.encoder.block1.layer.weight")
713            .unwrap();
714        let data = view.to_data().unwrap();
715        assert_eq!(data.shape, vec![2, 2]);
716    }
717
718    #[test]
719    fn deep_module_filtered_export() {
720        let device = Default::default();
721        let model = DeepModel::<TestBackend>::new(&device);
722
723        // Test filtering at different depths
724        #[cfg(target_has_atomic = "ptr")]
725        {
726            let filter = PathFilter::new().with_regex(r"^backbone\.encoder\..*");
727            let mut collector = Collector::new(Some(filter), None);
728            model.visit(&mut collector);
729            assert_eq!(collector.tensors.len(), 8); // Only encoder tensors
730        }
731
732        // Test filtering specific blocks
733        #[cfg(target_has_atomic = "ptr")]
734        {
735            let filter = PathFilter::new().with_regex(r".*\.block1\..*");
736            let mut collector = Collector::new(Some(filter), None);
737            model.visit(&mut collector);
738            assert_eq!(collector.tensors.len(), 8); // block1 in both encoder and decoder
739        }
740
741        // Test filtering by tensor type at any depth
742        #[cfg(target_has_atomic = "ptr")]
743        {
744            let filter = PathFilter::new().with_regex(r".*\.weight$");
745            let mut collector = Collector::new(Some(filter), None);
746            model.visit(&mut collector);
747            assert_eq!(collector.tensors.len(), 9); // All weight tensors
748        }
749
750        // Test complex multi-pattern filtering
751        #[cfg(target_has_atomic = "ptr")]
752        {
753            let filter = PathFilter::new()
754                .with_regex(r"^backbone\.encoder\.block1\..*") // All encoder.block1 tensors
755                .with_regex(r"^backbone\.decoder\..*\.bias$") // All decoder biases
756                .with_regex(r"^head\.weight$"); // Head weight only
757            let mut collector = Collector::new(Some(filter), None);
758            model.visit(&mut collector);
759
760            // Should have:
761            // - 4 from encoder.block1 (2 weights + 2 biases)
762            // - 4 decoder biases
763            // - 1 head weight
764            assert_eq!(collector.tensors.len(), 9);
765
766            let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
767            assert!(paths.contains(&"backbone.encoder.block1.layer.weight".to_string()));
768            assert!(paths.contains(&"backbone.decoder.block1.layer.bias".to_string()));
769            assert!(paths.contains(&"head.weight".to_string()));
770            assert!(!paths.contains(&"head.bias".to_string())); // Not included
771        }
772    }
773
774    use crate::traits::ModuleSnapshot;
775    use burn_nn::Linear;
776    use hashbrown::HashMap;
777
778    // Test module with Option fields
779    #[derive(Module, Debug)]
780    struct OptionalFieldModule<B: Backend> {
781        required: Param<Tensor<B, 2>>,
782        optional: Option<Param<Tensor<B, 1>>>,
783    }
784
785    impl<B: Backend> OptionalFieldModule<B> {
786        fn new_with_optional(device: &B::Device) -> Self {
787            Self {
788                required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
789                optional: Some(Param::from_data([5.0, 6.0], device)),
790            }
791        }
792
793        fn new_without_optional(device: &B::Device) -> Self {
794            Self {
795                required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
796                optional: None,
797            }
798        }
799    }
800
801    #[test]
802    fn optional_field_module_with_value() {
803        let device = Default::default();
804        let module = OptionalFieldModule::<TestBackend>::new_with_optional(&device);
805
806        let views: HashMap<String, TensorSnapshot> = module
807            .collect(None, None)
808            .into_iter()
809            .map(|v| (v.full_path(), v))
810            .collect();
811
812        assert_eq!(views.len(), 2);
813        assert!(views.contains_key("required"));
814        assert!(views.contains_key("optional"));
815    }
816
817    #[test]
818    fn optional_field_module_without_value() {
819        let device = Default::default();
820        let module = OptionalFieldModule::<TestBackend>::new_without_optional(&device);
821
822        let views: HashMap<String, TensorSnapshot> = module
823            .collect(None, None)
824            .into_iter()
825            .map(|v| (v.full_path(), v))
826            .collect();
827
828        assert_eq!(views.len(), 1);
829        assert!(views.contains_key("required"));
830        assert!(!views.contains_key("optional"));
831    }
832
833    // Test Vec of modules
834    #[derive(Module, Debug)]
835    struct VecModule<B: Backend> {
836        layers: Vec<Linear<B>>,
837    }
838
839    impl<B: Backend> VecModule<B> {
840        fn new(device: &B::Device, num_layers: usize) -> Self {
841            Self {
842                layers: (0..num_layers)
843                    .map(|_| LinearConfig::new(10, 10).init(device))
844                    .collect(),
845            }
846        }
847    }
848
849    #[test]
850    fn vec_module_collect() {
851        let device = Default::default();
852        let module = VecModule::<TestBackend>::new(&device, 3);
853
854        let views: HashMap<String, TensorSnapshot> = module
855            .collect(None, None)
856            .into_iter()
857            .map(|v| (v.full_path(), v))
858            .collect();
859
860        // With the fix, all Vec items should now be properly indexed and visited
861        assert_eq!(views.len(), 6); // 3 layers × 2 tensors each = 6 tensors
862
863        // Check that all indexed paths exist
864        assert!(views.contains_key("layers.0.weight"));
865        assert!(views.contains_key("layers.0.bias"));
866        assert!(views.contains_key("layers.1.weight"));
867        assert!(views.contains_key("layers.1.bias"));
868        assert!(views.contains_key("layers.2.weight"));
869        assert!(views.contains_key("layers.2.bias"));
870    }
871
872    // Test array of modules
873    #[derive(Module, Debug)]
874    struct ArrayModule<B: Backend> {
875        layers: [Linear<B>; 3],
876    }
877
878    impl<B: Backend> ArrayModule<B> {
879        fn new(device: &B::Device) -> Self {
880            Self {
881                layers: [
882                    LinearConfig::new(10, 10).init(device),
883                    LinearConfig::new(10, 10).init(device),
884                    LinearConfig::new(10, 10).init(device),
885                ],
886            }
887        }
888    }
889
890    #[test]
891    fn array_module_collect() {
892        let device = Default::default();
893        let module = ArrayModule::<TestBackend>::new(&device);
894
895        let views: HashMap<String, TensorSnapshot> = module
896            .collect(None, None)
897            .into_iter()
898            .map(|v| (v.full_path(), v))
899            .collect();
900
901        // All array items should be properly indexed
902        assert_eq!(views.len(), 6); // 3 layers × 2 tensors each = 6 tensors
903
904        // Check indexed paths
905        for i in 0..3 {
906            assert!(views.contains_key(&format!("layers.{}.weight", i)));
907            assert!(views.contains_key(&format!("layers.{}.bias", i)));
908        }
909    }
910
    // Test enum modules
    //
    // Fixture: an enum module in which every variant wraps a single `Linear`
    // layer. `enum_module_collect` constructs the `LayerA` and `LayerB`
    // variants and expects the variant name as the leading path segment;
    // `LayerC` is declared but not constructed by that test.
    #[derive(Module, Debug)]
    enum EnumModule<B: Backend> {
        LayerA(Linear<B>),
        LayerB(Linear<B>),
        LayerC(Linear<B>),
    }
918
919    #[test]
920    fn enum_module_collect() {
921        let device = Default::default();
922
923        // Test variant A
924        let module_a = EnumModule::<TestBackend>::LayerA(LinearConfig::new(10, 20).init(&device));
925        let views_a: HashMap<String, TensorSnapshot> = module_a
926            .collect(None, None)
927            .into_iter()
928            .map(|v| (v.full_path(), v))
929            .collect();
930
931        // Should have the variant name in the path
932        assert_eq!(views_a.len(), 2);
933        assert!(views_a.contains_key("LayerA.weight"));
934        assert!(views_a.contains_key("LayerA.bias"));
935
936        // Test variant B
937        let module_b = EnumModule::<TestBackend>::LayerB(LinearConfig::new(10, 20).init(&device));
938        let views_b: HashMap<String, TensorSnapshot> = module_b
939            .collect(None, None)
940            .into_iter()
941            .map(|v| (v.full_path(), v))
942            .collect();
943
944        assert_eq!(views_b.len(), 2);
945        assert!(views_b.contains_key("LayerB.weight"));
946        assert!(views_b.contains_key("LayerB.bias"));
947    }
948
949    // Container type tracking tests
950    #[test]
951    fn linear_container_type() {
952        let device = Default::default();
953
954        #[derive(Module, Debug)]
955        struct ModelWithLinear<B: Backend> {
956            linear: Linear<B>,
957        }
958
959        impl<B: Backend> ModelWithLinear<B> {
960            fn new(device: &B::Device) -> Self {
961                Self {
962                    linear: LinearConfig::new(10, 20).init(device),
963                }
964            }
965        }
966
967        let model = ModelWithLinear::<TestBackend>::new(&device);
968
969        let views: HashMap<String, TensorSnapshot> = model
970            .collect(None, None)
971            .into_iter()
972            .map(|v| (v.full_path(), v))
973            .collect();
974
975        // Check that tensors inside Linear layers have "Linear" as their container type
976        for (path, view) in views.iter() {
977            if path == "linear.weight" || path == "linear.bias" {
978                assert_eq!(
979                    view.container_type(),
980                    "Linear",
981                    "Tensor '{}' should have container type 'Linear'",
982                    path
983                );
984            }
985        }
986    }
987
988    #[test]
989    fn complex_model_container_types() {
990        let device = Default::default();
991
992        #[derive(Module, Debug)]
993        struct ComplexModel<B: Backend> {
994            linear_layers: [Linear<B>; 2],
995            vec_layers: Vec<Linear<B>>,
996            single_linear: Linear<B>,
997        }
998
999        impl<B: Backend> ComplexModel<B> {
1000            fn new(device: &B::Device) -> Self {
1001                Self {
1002                    linear_layers: [
1003                        LinearConfig::new(100, 50).init(device),
1004                        LinearConfig::new(50, 10).init(device),
1005                    ],
1006                    vec_layers: vec![
1007                        LinearConfig::new(10, 10).init(device),
1008                        LinearConfig::new(10, 10).init(device),
1009                    ],
1010                    single_linear: LinearConfig::new(10, 1).init(device),
1011                }
1012            }
1013        }
1014
1015        let model = ComplexModel::<TestBackend>::new(&device);
1016
1017        let views: HashMap<String, TensorSnapshot> = model
1018            .collect(None, None)
1019            .into_iter()
1020            .map(|v| (v.full_path(), v))
1021            .collect();
1022
1023        // Should have 10 tensors total
1024        assert_eq!(views.len(), 10);
1025
1026        // Verify different container types
1027        for (_path, view) in views.iter() {
1028            assert_eq!(view.container_type(), "Linear");
1029        }
1030    }
1031
1032    #[test]
1033    fn collect_with_container_filter() {
1034        let device = Default::default();
1035
1036        #[derive(Module, Debug)]
1037        struct FilterTestModel<B: Backend> {
1038            layers: Vec<Linear<B>>,
1039        }
1040
1041        impl<B: Backend> FilterTestModel<B> {
1042            fn new(device: &B::Device) -> Self {
1043                Self {
1044                    layers: vec![
1045                        LinearConfig::new(10, 10).init(device),
1046                        LinearConfig::new(10, 10).init(device),
1047                    ],
1048                }
1049            }
1050        }
1051
1052        let model = FilterTestModel::<TestBackend>::new(&device);
1053
1054        // Filter to only collect tensors from Linear modules
1055        let filter = PathFilter::new().with_predicate(|_path, container_path| {
1056            container_path.split('.').next_back() == Some("Linear")
1057        });
1058
1059        let linear_views: Vec<TensorSnapshot> = model.collect(Some(filter), None);
1060
1061        // All collected tensors should be from Linear modules
1062        for view in linear_views.iter() {
1063            assert_eq!(
1064                view.container_type(),
1065                "Linear",
1066                "All tensors should be from Linear modules"
1067            );
1068        }
1069
1070        // Should have collected all Linear tensors
1071        assert_eq!(linear_views.len(), 4);
1072    }
1073}