use alloc::boxed::Box;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use burn_tensor::{Bool, Int, Tensor, backend::Backend};
use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
use burn_core::module::{ModuleVisitor, Param, ParamId};
/// Module visitor that walks a module tree and collects [`TensorSnapshot`]s
/// of every (optionally filtered) parameter it encounters.
pub struct Collector {
    /// Snapshots gathered so far, in visit order.
    pub tensors: Vec<TensorSnapshot>,
    // Names of the modules currently entered; joined with '.' they form a tensor path.
    path_stack: Vec<String>,
    // Container-type names (e.g. "Struct:Linear") parallel to the module nesting.
    container_stack: Vec<String>,
    // When set, only tensors whose (path, container path) the filter accepts are kept.
    filter: Option<PathFilter>,
    // When set, every snapshot is passed through the adapter in `into_tensors`.
    adapter: Option<Box<dyn ModuleAdapter>>,
    // When true, "Enum:*" container segments are omitted from the dotted path.
    skip_enum_variants: bool,
}
impl Default for Collector {
    /// An unfiltered, un-adapted collector that keeps enum-variant path segments.
    fn default() -> Self {
        Self {
            tensors: Vec::new(),
            path_stack: Vec::new(),
            container_stack: Vec::new(),
            filter: None,
            adapter: None,
            skip_enum_variants: false,
        }
    }
}
impl Collector {
    /// Creates a collector.
    ///
    /// * `filter` — optional path/container filter; `None` collects everything.
    /// * `adapter` — optional transformation applied to each snapshot on
    ///   [`Self::into_tensors`].
    /// * `skip_enum_variants` — drop "Enum:*" variant names from tensor paths.
    pub fn new(
        filter: Option<PathFilter>,
        adapter: Option<Box<dyn ModuleAdapter>>,
        skip_enum_variants: bool,
    ) -> Self {
        Self {
            tensors: Vec::new(),
            path_stack: Vec::new(),
            container_stack: Vec::new(),
            filter,
            adapter,
            skip_enum_variants,
        }
    }

    /// Consumes the collector and returns the collected snapshots, applying
    /// the adapter (if any) to each one.
    pub fn into_tensors(self) -> Vec<TensorSnapshot> {
        match self.adapter {
            Some(adapter) => self
                .tensors
                .into_iter()
                .map(|snapshot| adapter.adapt(&snapshot))
                .collect(),
            None => self.tensors,
        }
    }

    /// Returns `true` when no filter is configured, or when the filter accepts
    /// the (path, container path) pair.
    fn should_collect(&self, path: &[String], container_stack: &[String]) -> bool {
        self.filter
            .as_ref()
            .map_or(true, |f| f.matches_with_container_path(path, container_stack))
    }
}
impl<B: Backend> ModuleVisitor<B> for Collector {
fn enter_module(&mut self, name: &str, container_type: &str) {
self.container_stack.push(container_type.to_string());
if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
self.path_stack.push(name.to_string());
}
}
fn exit_module(&mut self, _name: &str, container_type: &str) {
self.container_stack.pop();
if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
self.path_stack.pop();
}
}
fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
if self.should_collect(&self.path_stack, &self.container_stack) {
self.tensors.push(TensorSnapshot::from_float(
¶m.transform_for_save().val(),
self.path_stack.clone(),
self.container_stack.clone(),
param.id,
));
}
}
fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {
if self.should_collect(&self.path_stack, &self.container_stack) {
self.tensors.push(TensorSnapshot::from_int(
¶m.transform_for_save().val(),
self.path_stack.clone(),
self.container_stack.clone(),
param.id,
));
}
}
fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {
if self.should_collect(&self.path_stack, &self.container_stack) {
self.tensors.push(TensorSnapshot::from_bool(
¶m.transform_for_save().val(),
self.path_stack.clone(),
self.container_stack.clone(),
param.id,
));
}
}
fn visit_float_with_path<const D: usize>(
&mut self,
path: &[String],
id: ParamId,
tensor: &Tensor<B, D>,
) {
if self.should_collect(path, &self.container_stack) {
self.tensors.push(TensorSnapshot::from_float(
tensor,
path.to_vec(),
self.container_stack.clone(),
id,
));
}
}
fn visit_int_with_path<const D: usize>(
&mut self,
path: &[String],
id: ParamId,
tensor: &Tensor<B, D, Int>,
) {
if self.should_collect(path, &self.container_stack) {
self.tensors.push(TensorSnapshot::from_int(
tensor,
path.to_vec(),
self.container_stack.clone(),
id,
));
}
}
fn visit_bool_with_path<const D: usize>(
&mut self,
path: &[String],
id: ParamId,
tensor: &Tensor<B, D, Bool>,
) {
if self.should_collect(path, &self.container_stack) {
self.tensors.push(TensorSnapshot::from_bool(
tensor,
path.to_vec(),
self.container_stack.clone(),
id,
));
}
}
}
#[cfg(all(test, feature = "std"))]
mod tests {
use super::*;
use burn_core as burn;
type TestBackend = burn_flex::Flex;
use alloc::collections::BTreeMap;
use alloc::string::String;
use burn_core::module::{Module, Param};
use burn_nn::LinearConfig;
use burn_tensor::shape;
/// A single visited tensor is recorded with its dotted path, shape and data.
#[test]
fn tensor_snapshot_collector() {
    let device = Default::default();
    let input = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    let mut collector = Collector::new(None, None, false);
    let path = ["model".to_string(), "weight".to_string()];
    collector.visit_float_with_path(&path, ParamId::new(), &input);
    assert_eq!(collector.tensors.len(), 1);
    let snapshot = &collector.tensors[0];
    assert_eq!(snapshot.full_path(), "model.weight");
    assert_eq!(snapshot.to_data().unwrap().shape, shape![2, 2]);
}
/// Parameters visited at the root get single-segment paths and keep their data.
#[test]
fn root_level_parameters() {
    use burn_core::module::ModuleVisitor;
    let device = Default::default();
    let weight = Param::<Tensor<TestBackend, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    let bias = Param::<Tensor<TestBackend, 1>>::from_data([5.0, 6.0], &device);
    let mut collector = Collector::new(None, None, false);
    // Manually drive the enter/visit/exit protocol for two root-level params.
    ModuleVisitor::<TestBackend>::enter_module(&mut collector, "weight", "");
    ModuleVisitor::<TestBackend>::visit_float(&mut collector, &weight);
    ModuleVisitor::<TestBackend>::exit_module(&mut collector, "weight", "");
    ModuleVisitor::<TestBackend>::enter_module(&mut collector, "bias", "");
    ModuleVisitor::<TestBackend>::visit_float(&mut collector, &bias);
    ModuleVisitor::<TestBackend>::exit_module(&mut collector, "bias", "");
    assert_eq!(collector.tensors.len(), 2);
    assert_eq!(collector.tensors[0].full_path(), "weight");
    assert_eq!(collector.tensors[1].full_path(), "bias");
    // Data round-trips through the snapshot unchanged.
    let weight_data = collector.tensors[0]
        .to_data()
        .unwrap()
        .to_vec::<f32>()
        .unwrap();
    let bias_data = collector.tensors[1]
        .to_data()
        .unwrap()
        .to_vec::<f32>()
        .unwrap();
    assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]);
    assert_eq!(bias_data, vec![5.0, 6.0]);
}
/// Only paths accepted by the regex filter are collected.
#[test]
#[cfg(target_has_atomic = "ptr")]
fn tensor_snapshot_collector_with_filter() {
    let device = Default::default();
    let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    let filter = PathFilter::new().with_regex(r"^encoder\..*");
    let mut collector = Collector::new(Some(filter), None, false);
    let id = ParamId::new();
    // Visit encoder first, then decoder; only the encoder path matches.
    for root in ["encoder", "decoder"] {
        collector.visit_float_with_path(
            &[root.to_string(), "weight".to_string()],
            id,
            &tensor,
        );
    }
    assert_eq!(collector.tensors.len(), 1);
    assert_eq!(collector.tensors[0].full_path(), "encoder.weight");
}
/// Multiple regexes act as a union: a path matching any pattern is kept.
/// (Reformatted: the original jammed several statements per line.)
#[test]
#[cfg(target_has_atomic = "ptr")]
fn tensor_snapshot_collector_with_multiple_filters() {
    let device = Default::default();
    let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    let filter = PathFilter::new()
        .with_regex(r"^encoder\..*")
        .with_regex(r".*\.bias$");
    let mut collector = Collector::new(Some(filter), None, false);
    let id = ParamId::new();
    collector.visit_float_with_path(
        &["encoder".to_string(), "weight".to_string()],
        id,
        &tensor,
    );
    collector.visit_float_with_path(&["decoder".to_string(), "bias".to_string()], id, &tensor);
    collector.visit_float_with_path(&["encoder".to_string(), "bias".to_string()], id, &tensor);
    collector.visit_float_with_path(
        &["decoder".to_string(), "weight".to_string()],
        id,
        &tensor,
    );
    // decoder.weight matches neither pattern; the other three all match.
    assert_eq!(collector.tensors.len(), 3);
    let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
    assert!(paths.contains(&"encoder.weight".to_string()));
    assert!(paths.contains(&"decoder.bias".to_string()));
    assert!(paths.contains(&"encoder.bias".to_string()));
    assert!(!paths.contains(&"decoder.weight".to_string()));
}
/// A plain function predicate can serve as a filter; it receives the dotted
/// tensor path and the container path (unused here).
#[test]
fn tensor_snapshot_collector_with_predicate() {
    let device = Default::default();
    let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    // Keep everything under `encoder.` plus the single `decoder.bias` tensor.
    fn filter_fn(path: &str, _container_path: &str) -> bool {
        path.starts_with("encoder.") || path == "decoder.bias"
    }
    let filter = PathFilter::new().with_predicate(filter_fn);
    let mut collector = Collector::new(Some(filter), None, false);
    let id = ParamId::new();
    collector.visit_float_with_path(
        &["encoder".to_string(), "weight".to_string()],
        id,
        &tensor,
    );
    collector.visit_float_with_path(&["encoder".to_string(), "bias".to_string()], id, &tensor);
    collector.visit_float_with_path(&["decoder".to_string(), "bias".to_string()], id, &tensor);
    collector.visit_float_with_path(
        &["decoder".to_string(), "weight".to_string()],
        id,
        &tensor,
    );
    collector.visit_float_with_path(&["other".to_string(), "tensor".to_string()], id, &tensor);
    assert_eq!(collector.tensors.len(), 3);
    let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
    assert!(paths.contains(&"encoder.weight".to_string()));
    assert!(paths.contains(&"encoder.bias".to_string()));
    assert!(paths.contains(&"decoder.bias".to_string()));
    assert!(!paths.contains(&"decoder.weight".to_string()));
    assert!(!paths.contains(&"other.tensor".to_string()));
}
/// Predicates can implement arbitrary structural logic on path segments.
#[test]
fn tensor_snapshot_collector_predicate_with_complex_logic() {
    let device = Default::default();
    let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    // Accept only three-segment paths of the form `<root>.layer{1,2}.weight`.
    fn complex_filter(path: &str, _container_path: &str) -> bool {
        let parts: Vec<&str> = path.split('.').collect();
        if parts.len() != 3 {
            return false;
        }
        (parts[1] == "layer1" || parts[1] == "layer2") && parts[2] == "weight"
    }
    let filter = PathFilter::new().with_predicate(complex_filter);
    let mut collector = Collector::new(Some(filter), None, false);
    let id = ParamId::new();
    collector.visit_float_with_path(
        &[
            "model".to_string(),
            "layer1".to_string(),
            "weight".to_string(),
        ],
        id,
        &tensor,
    );
    collector.visit_float_with_path(
        &[
            "model".to_string(),
            "layer2".to_string(),
            "weight".to_string(),
        ],
        id,
        &tensor,
    );
    collector.visit_float_with_path(
        &[
            "model".to_string(),
            "layer1".to_string(),
            "bias".to_string(),
        ],
        id,
        &tensor,
    );
    collector.visit_float_with_path(
        &[
            "model".to_string(),
            "layer3".to_string(),
            "weight".to_string(),
        ],
        id,
        &tensor,
    );
    // Two-segment path: rejected by the length check.
    collector.visit_float_with_path(
        &["encoder".to_string(), "weight".to_string()],
        id,
        &tensor,
    );
    assert_eq!(collector.tensors.len(), 2);
    let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
    assert!(paths.contains(&"model.layer1.weight".to_string()));
    assert!(paths.contains(&"model.layer2.weight".to_string()));
    assert!(!paths.contains(&"model.layer1.bias".to_string()));
    assert!(!paths.contains(&"model.layer3.weight".to_string()));
    assert!(!paths.contains(&"encoder.weight".to_string()));
}
/// Test helper: records `path -> (ParamId, shape)` for every visited tensor.
struct TensorPathCollector {
    pub paths: BTreeMap<String, (ParamId, Vec<usize>)>,
    // Module names entered so far; joined with '.' to form the current path.
    path_stack: Vec<String>,
}
impl TensorPathCollector {
    fn new() -> Self {
        Self {
            paths: BTreeMap::new(),
            path_stack: Vec::new(),
        }
    }
    /// Dotted path of the currently entered modules (empty string at root).
    fn current_path(&self) -> String {
        self.path_stack.join(".")
    }
}
impl<B: Backend> ModuleVisitor<B> for TensorPathCollector {
    fn enter_module(&mut self, name: &str, _container_type: &str) {
        self.path_stack.push(name.to_string());
    }
    fn exit_module(&mut self, _name: &str, _container_type: &str) {
        self.path_stack.pop();
    }
    // Each visit_* records the tensor's dotted path together with its id and
    // shape; visits with an empty path (no module entered) are ignored.
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
        let path = self.current_path();
        if !path.is_empty() {
            self.paths.insert(
                path,
                (param.id, param.transform_for_save().val().shape().to_vec()),
            );
        }
    }
    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {
        let path = self.current_path();
        if !path.is_empty() {
            self.paths.insert(
                path,
                (param.id, param.transform_for_save().val().shape().to_vec()),
            );
        }
    }
    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {
        let path = self.current_path();
        if !path.is_empty() {
            self.paths.insert(
                path,
                (param.id, param.transform_for_save().val().shape().to_vec()),
            );
        }
    }
}
/// Leaf test module: one 2x2 weight and one 2-element bias.
#[derive(Module, Debug)]
struct InnerModule<B: Backend> {
    weight: Param<Tensor<B, 2>>,
    bias: Param<Tensor<B, 1>>,
}
/// Two-level test module used to verify nested path tracking.
#[derive(Module, Debug)]
struct OuterModule<B: Backend> {
    layer1: InnerModule<B>,
    layer2: InnerModule<B>,
}
impl<B: Backend> InnerModule<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
            bias: Param::from_data([5.0, 6.0], device),
        }
    }
}
impl<B: Backend> OuterModule<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            layer1: InnerModule::new(device),
            layer2: InnerModule::new(device),
        }
    }
}
/// Nested module fields produce two-segment dotted paths with correct shapes.
#[test]
fn nested_module_path_tracking() {
    let device = Default::default();
    let module = OuterModule::<TestBackend>::new(&device);
    let mut collector = TensorPathCollector::new();
    module.visit(&mut collector);
    let paths = collector.paths;
    for layer in ["layer1", "layer2"] {
        for (field, expected_shape) in [("weight", vec![2, 2]), ("bias", vec![2])] {
            let key = format!("{}.{}", layer, field);
            let entry = paths.get(&key);
            assert!(entry.is_some(), "Missing {}", key);
            assert_eq!(entry.unwrap().1, expected_shape);
        }
    }
}
/// A bare Linear exposes root-level `weight` and `bias` paths.
#[test]
fn linear_module_paths() {
    let device = Default::default();
    let linear = LinearConfig::new(10, 20)
        .with_bias(true)
        .init::<TestBackend>(&device);
    let mut collector = TensorPathCollector::new();
    linear.visit(&mut collector);
    let paths = collector.paths;
    assert_eq!(paths.get("weight").expect("weight path missing").1, vec![10, 20]);
    assert_eq!(paths.get("bias").expect("bias path missing").1, vec![20]);
}
// Four-level module hierarchy used to exercise long dotted paths:
// DeepModel -> Level1 (backbone) -> Level2 -> Level3 -> Level4 (leaf params),
// plus a Level4 `head` directly on DeepModel. 18 tensors in total.
#[derive(Module, Debug)]
struct Level4Module<B: Backend> {
    weight: Param<Tensor<B, 2>>,
    bias: Param<Tensor<B, 1>>,
}
#[derive(Module, Debug)]
struct Level3Module<B: Backend> {
    layer: Level4Module<B>,
    extra: Level4Module<B>,
}
#[derive(Module, Debug)]
struct Level2Module<B: Backend> {
    block1: Level3Module<B>,
    block2: Level3Module<B>,
}
#[derive(Module, Debug)]
struct Level1Module<B: Backend> {
    encoder: Level2Module<B>,
    decoder: Level2Module<B>,
}
#[derive(Module, Debug)]
struct DeepModel<B: Backend> {
    backbone: Level1Module<B>,
    head: Level4Module<B>,
}
impl<B: Backend> Level4Module<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
            bias: Param::from_data([5.0, 6.0], device),
        }
    }
}
impl<B: Backend> Level3Module<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            layer: Level4Module::new(device),
            extra: Level4Module::new(device),
        }
    }
}
impl<B: Backend> Level2Module<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            block1: Level3Module::new(device),
            block2: Level3Module::new(device),
        }
    }
}
impl<B: Backend> Level1Module<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            encoder: Level2Module::new(device),
            decoder: Level2Module::new(device),
        }
    }
}
impl<B: Backend> DeepModel<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            backbone: Level1Module::new(device),
            head: Level4Module::new(device),
        }
    }
}
/// All 18 tensors of the deep model are collected under their full dotted paths.
#[test]
fn deep_module_path_tracking() {
    let device = Default::default();
    let model = DeepModel::<TestBackend>::new(&device);
    let mut collector = Collector::new(None, None, false);
    model.visit(&mut collector);
    let views = collector.tensors;
    let paths: Vec<String> = views.iter().map(|v| v.full_path()).collect();
    // 2 halves x 2 blocks x 2 sub-layers x 2 params = 16 backbone tensors.
    for half in ["encoder", "decoder"] {
        for block in ["block1", "block2"] {
            for layer in ["layer", "extra"] {
                for param in ["weight", "bias"] {
                    let expected = format!("backbone.{}.{}.{}.{}", half, block, layer, param);
                    assert!(paths.contains(&expected), "Missing {}", expected);
                }
            }
        }
    }
    // Plus the two head parameters.
    assert!(paths.contains(&"head.weight".to_string()));
    assert!(paths.contains(&"head.bias".to_string()));
    assert_eq!(views.len(), 18);
    let view = views
        .iter()
        .find(|v| v.full_path() == "backbone.encoder.block1.layer.weight")
        .unwrap();
    assert_eq!(view.to_data().unwrap().shape, shape![2, 2]);
}
/// Regex filters select the expected subsets of a deeply nested model.
/// (Reformatted: the original jammed statements and closing braces onto
/// single lines.)
#[test]
fn deep_module_filtered_export() {
    let device = Default::default();
    let model = DeepModel::<TestBackend>::new(&device);
    #[cfg(target_has_atomic = "ptr")]
    {
        // All 8 tensors under backbone.encoder.
        let filter = PathFilter::new().with_regex(r"^backbone\.encoder\..*");
        let mut collector = Collector::new(Some(filter), None, false);
        model.visit(&mut collector);
        assert_eq!(collector.tensors.len(), 8);
    }
    #[cfg(target_has_atomic = "ptr")]
    {
        // block1 appears in both encoder and decoder: 2 x 4 tensors.
        let filter = PathFilter::new().with_regex(r".*\.block1\..*");
        let mut collector = Collector::new(Some(filter), None, false);
        model.visit(&mut collector);
        assert_eq!(collector.tensors.len(), 8);
    }
    #[cfg(target_has_atomic = "ptr")]
    {
        // 8 backbone weights + head.weight.
        let filter = PathFilter::new().with_regex(r".*\.weight$");
        let mut collector = Collector::new(Some(filter), None, false);
        model.visit(&mut collector);
        assert_eq!(collector.tensors.len(), 9);
    }
    #[cfg(target_has_atomic = "ptr")]
    {
        // Union of three patterns: 4 + 4 + 1 = 9 tensors.
        let filter = PathFilter::new()
            .with_regex(r"^backbone\.encoder\.block1\..*")
            .with_regex(r"^backbone\.decoder\..*\.bias$")
            .with_regex(r"^head\.weight$");
        let mut collector = Collector::new(Some(filter), None, false);
        model.visit(&mut collector);
        assert_eq!(collector.tensors.len(), 9);
        let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();
        assert!(paths.contains(&"backbone.encoder.block1.layer.weight".to_string()));
        assert!(paths.contains(&"backbone.decoder.block1.layer.bias".to_string()));
        assert!(paths.contains(&"head.weight".to_string()));
        assert!(!paths.contains(&"head.bias".to_string()));
    }
}
use crate::traits::ModuleSnapshot;
use burn_nn::Linear;
use hashbrown::HashMap;
/// Test module with an `Option`-wrapped parameter to verify that `None`
/// fields simply produce no snapshot.
#[derive(Module, Debug)]
struct OptionalFieldModule<B: Backend> {
    required: Param<Tensor<B, 2>>,
    optional: Option<Param<Tensor<B, 1>>>,
}
impl<B: Backend> OptionalFieldModule<B> {
    fn new_with_optional(device: &B::Device) -> Self {
        Self {
            required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
            optional: Some(Param::from_data([5.0, 6.0], device)),
        }
    }
    fn new_without_optional(device: &B::Device) -> Self {
        Self {
            required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
            optional: None,
        }
    }
}
/// A populated Option field is collected alongside the required one.
#[test]
fn optional_field_module_with_value() {
    let device = Default::default();
    let module = OptionalFieldModule::<TestBackend>::new_with_optional(&device);
    let mut views: HashMap<String, TensorSnapshot> = HashMap::new();
    for snapshot in module.collect(None, None, false) {
        views.insert(snapshot.full_path(), snapshot);
    }
    assert_eq!(views.len(), 2);
    for key in ["required", "optional"] {
        assert!(views.contains_key(key), "missing snapshot for `{}`", key);
    }
}
/// A `None` Option field yields no snapshot at all.
#[test]
fn optional_field_module_without_value() {
    let device = Default::default();
    let module = OptionalFieldModule::<TestBackend>::new_without_optional(&device);
    let paths: Vec<String> = module
        .collect(None, None, false)
        .iter()
        .map(|v| v.full_path())
        .collect();
    // Only the required tensor is present; `optional` produced nothing.
    assert_eq!(paths, vec!["required".to_string()]);
}
/// Test module holding its layers in a `Vec` (indexed paths: layers.0, ...).
#[derive(Module, Debug)]
struct VecModule<B: Backend> {
    layers: Vec<Linear<B>>,
}
impl<B: Backend> VecModule<B> {
    fn new(device: &B::Device, num_layers: usize) -> Self {
        Self {
            layers: (0..num_layers)
                .map(|_| LinearConfig::new(10, 10).init(device))
                .collect(),
        }
    }
}
/// Test module holding its layers in a tuple (also indexed: layers.0, ...).
#[derive(Module, Debug)]
struct TupleModule<B: Backend> {
    layers: (Linear<B>, Linear<B>, Linear<B>),
}
impl<B: Backend> TupleModule<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            layers: (
                LinearConfig::new(10, 10).init(device),
                LinearConfig::new(10, 10).init(device),
                LinearConfig::new(10, 10).init(device),
            ),
        }
    }
}
/// Vec elements are addressed by numeric path segments (layers.0 .. layers.2).
#[test]
fn vec_module_collect() {
    let device = Default::default();
    let module = VecModule::<TestBackend>::new(&device, 3);
    let views: HashMap<String, TensorSnapshot> = module
        .collect(None, None, false)
        .into_iter()
        .map(|v| (v.full_path(), v))
        .collect();
    assert_eq!(views.len(), 6);
    for index in 0..3 {
        for field in ["weight", "bias"] {
            let key = format!("layers.{}.{}", index, field);
            assert!(views.contains_key(&key), "missing {}", key);
        }
    }
}
/// Tuple elements are addressed like Vec elements: layers.0 .. layers.2.
#[test]
fn tuple_module_collect() {
    let device = Default::default();
    let module = TupleModule::<TestBackend>::new(&device);
    let snapshots = module.collect(None, None, false);
    assert_eq!(snapshots.len(), 6);
    let views: HashMap<String, TensorSnapshot> =
        snapshots.into_iter().map(|v| (v.full_path(), v)).collect();
    assert_eq!(views.len(), 6);
    for index in 0..3 {
        assert!(views.contains_key(&format!("layers.{}.weight", index)));
        assert!(views.contains_key(&format!("layers.{}.bias", index)));
    }
}
/// Test module holding its layers in a fixed-size array.
#[derive(Module, Debug)]
struct ArrayModule<B: Backend> {
    layers: [Linear<B>; 3],
}
impl<B: Backend> ArrayModule<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            layers: [
                LinearConfig::new(10, 10).init(device),
                LinearConfig::new(10, 10).init(device),
                LinearConfig::new(10, 10).init(device),
            ],
        }
    }
}
/// Array elements are addressed by numeric path segments, like Vec/tuple.
#[test]
fn array_module_collect() {
    let device = Default::default();
    let module = ArrayModule::<TestBackend>::new(&device);
    let views: HashMap<String, TensorSnapshot> = module
        .collect(None, None, false)
        .into_iter()
        .map(|v| (v.full_path(), v))
        .collect();
    assert_eq!(views.len(), 6);
    for i in 0..3 {
        assert!(views.contains_key(&format!("layers.{}.weight", i)));
        assert!(views.contains_key(&format!("layers.{}.bias", i)));
    }
}
/// Enum module: the active variant's name becomes the first path segment.
#[derive(Module, Debug)]
enum EnumModule<B: Backend> {
    LayerA(Linear<B>),
    LayerB(Linear<B>),
    LayerC(Linear<B>),
}
/// The active enum variant's name prefixes the collected tensor paths.
#[test]
fn enum_module_collect() {
    let device = Default::default();
    let cases = [
        (
            EnumModule::<TestBackend>::LayerA(LinearConfig::new(10, 20).init(&device)),
            "LayerA",
        ),
        (
            EnumModule::<TestBackend>::LayerB(LinearConfig::new(10, 20).init(&device)),
            "LayerB",
        ),
    ];
    for (module, variant) in cases {
        let views: HashMap<String, TensorSnapshot> = module
            .collect(None, None, false)
            .into_iter()
            .map(|v| (v.full_path(), v))
            .collect();
        assert_eq!(views.len(), 2);
        assert!(views.contains_key(&format!("{}.weight", variant)));
        assert!(views.contains_key(&format!("{}.bias", variant)));
    }
}
/// Tensors directly owned by a `Linear` report container type "Struct:Linear".
#[test]
fn linear_container_type() {
    let device = Default::default();
    #[derive(Module, Debug)]
    struct ModelWithLinear<B: Backend> {
        linear: Linear<B>,
    }
    impl<B: Backend> ModelWithLinear<B> {
        fn new(device: &B::Device) -> Self {
            Self {
                linear: LinearConfig::new(10, 20).init(device),
            }
        }
    }
    let model = ModelWithLinear::<TestBackend>::new(&device);
    let views: HashMap<String, TensorSnapshot> = model
        .collect(None, None, false)
        .into_iter()
        .map(|v| (v.full_path(), v))
        .collect();
    // Both Linear-owned tensors must carry the Linear container type.
    for (path, view) in views.iter() {
        if path == "linear.weight" || path == "linear.bias" {
            assert_eq!(
                view.module_type(),
                Some("Struct:Linear".to_string()),
                "Tensor '{}' should have module type 'Struct:Linear'",
                path
            );
        }
    }
}
/// Regardless of the collection kind holding them (array, Vec, plain field),
/// every Linear-owned tensor reports module type "Struct:Linear".
#[test]
fn complex_model_container_types() {
    let device = Default::default();
    #[derive(Module, Debug)]
    struct ComplexModel<B: Backend> {
        linear_layers: [Linear<B>; 2],
        vec_layers: Vec<Linear<B>>,
        single_linear: Linear<B>,
    }
    impl<B: Backend> ComplexModel<B> {
        fn new(device: &B::Device) -> Self {
            Self {
                linear_layers: [
                    LinearConfig::new(100, 50).init(device),
                    LinearConfig::new(50, 10).init(device),
                ],
                vec_layers: vec![
                    LinearConfig::new(10, 10).init(device),
                    LinearConfig::new(10, 10).init(device),
                ],
                single_linear: LinearConfig::new(10, 1).init(device),
            }
        }
    }
    let model = ComplexModel::<TestBackend>::new(&device);
    let views: HashMap<String, TensorSnapshot> = model
        .collect(None, None, false)
        .into_iter()
        .map(|v| (v.full_path(), v))
        .collect();
    // 5 Linears x (weight + bias) = 10 tensors.
    assert_eq!(views.len(), 10);
    for (_path, view) in views.iter() {
        assert_eq!(view.module_type(), Some("Struct:Linear".to_string()));
    }
}
/// The filter predicate also receives the container path, enabling selection
/// by module type rather than by tensor name.
#[test]
fn collect_with_container_filter() {
    let device = Default::default();
    #[derive(Module, Debug)]
    struct FilterTestModel<B: Backend> {
        layers: Vec<Linear<B>>,
    }
    impl<B: Backend> FilterTestModel<B> {
        fn new(device: &B::Device) -> Self {
            Self {
                layers: vec![
                    LinearConfig::new(10, 10).init(device),
                    LinearConfig::new(10, 10).init(device),
                ],
            }
        }
    }
    let model = FilterTestModel::<TestBackend>::new(&device);
    // Keep only tensors whose innermost container is a Linear struct.
    let filter = PathFilter::new().with_predicate(|_path, container_path| {
        container_path.split('.').next_back() == Some("Struct:Linear")
    });
    let linear_views: Vec<TensorSnapshot> = model.collect(Some(filter), None, false);
    for view in linear_views.iter() {
        assert_eq!(
            view.module_type(),
            Some("Struct:Linear".to_string()),
            "All tensors should be from Linear modules"
        );
    }
    // 2 Linears x (weight + bias).
    assert_eq!(linear_views.len(), 4);
}
}