use burn_core as burn;
use std::path::PathBuf;
use crate::ModuleStore;
use crate::pytorch::PytorchStore;
use burn_core::module::Module;
use burn_nn::conv::{Conv2d, Conv2dConfig};
use burn_nn::{Linear, LinearConfig};
use burn_tensor::Tensor;
use burn_tensor::backend::Backend;
/// Builds the path to a fixture under
/// `<manifest>/pytorch-tests/tests/<subdir>/<filename>`.
fn pytorch_test_path(subdir: &str, filename: &str) -> PathBuf {
    // PathBuf implements FromIterator over path-like components, so the
    // manifest root and the fixed/variable segments can be collected directly.
    [env!("CARGO_MANIFEST_DIR"), "pytorch-tests", "tests", subdir, filename]
        .iter()
        .collect()
}
/// Builds the path to a reader-test fixture under
/// `<manifest>/src/pytorch/tests/reader/test_data/<filename>`.
fn test_data_path(filename: &str) -> PathBuf {
    let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    for segment in ["src", "pytorch", "tests", "reader", "test_data", filename] {
        path.push(segment);
    }
    path
}
/// Builds the path to a store-test fixture under
/// `<manifest>/src/pytorch/tests/store/test_data/<filename>`.
fn store_test_data_path(filename: &str) -> PathBuf {
    // Collect all components straight into a PathBuf instead of chaining joins.
    [env!("CARGO_MANIFEST_DIR"), "src", "pytorch", "tests", "store", "test_data", filename]
        .iter()
        .collect()
}
#[cfg(test)]
mod basic_tests {
    use super::*;

    /// A freshly constructed store carries validating, non-partial defaults.
    #[test]
    fn test_store_creation() {
        let s = PytorchStore::from_file("model.pth");
        assert!(s.top_level_key.is_none());
        assert!(s.validate);
        assert!(s.map_indices_contiguous);
        assert!(!s.allow_partial);
    }

    /// Contiguous index remapping is on unless explicitly disabled.
    #[test]
    fn test_store_map_indices_contiguous_default() {
        let s = PytorchStore::from_file("model.pth");
        assert!(
            s.map_indices_contiguous,
            "map_indices_contiguous should be enabled by default"
        );
    }

    /// An explicit `map_indices_contiguous(false)` turns the flag off.
    #[test]
    fn test_store_map_indices_contiguous_disabled() {
        let s = PytorchStore::from_file("model.pth").map_indices_contiguous(false);
        assert!(
            !s.map_indices_contiguous,
            "map_indices_contiguous should be disabled after explicit call"
        );
    }

    /// `with_top_level_key` records the key used to unwrap the checkpoint dict.
    #[test]
    fn test_store_with_top_level_key() {
        let s = PytorchStore::from_file("model.pth").with_top_level_key("state_dict");
        assert_eq!(s.top_level_key, Some("state_dict".to_string()));
    }

    /// Builder calls accumulate: boolean flags flip and filters are recorded.
    #[test]
    fn test_store_configuration() {
        let s = PytorchStore::from_file("model.pth")
            .validate(false)
            .allow_partial(true)
            .with_regex(r"^encoder\.")
            .with_full_path("decoder.weight");
        assert!(!s.filter.is_empty());
        assert!(s.allow_partial);
        assert!(!s.validate);
    }

    /// Key remapping rules are registered on the remapper.
    #[test]
    fn test_store_with_remapping() {
        let s = PytorchStore::from_file("model.pth").with_key_remapping(r"^old\.", "new.");
        assert!(!s.remapper.is_empty());
    }

    // NOTE(review): this only checks a default flag; it does not actually
    // exercise a save path — presumably saving is rejected elsewhere.
    #[test]
    fn test_store_save_not_supported() {
        let s = PytorchStore::from_file("test.pth");
        assert!(s.validate);
    }
}
#[cfg(test)]
mod linear_model_tests {
    use super::*;
    type TestBackend = burn_flex::Flex;

    /// Two-layer MLP (2 -> 3 -> 4) whose field names match the `linear.pt`
    /// fixture keys (`fc1.*`, `fc2.*`).
    #[derive(Module, Debug)]
    pub struct SimpleLinearModel<B: Backend> {
        fc1: Linear<B>,
        fc2: Linear<B>,
    }
    impl<B: Backend> SimpleLinearModel<B> {
        pub fn new(device: &B::Device) -> Self {
            Self {
                fc1: LinearConfig::new(2, 3).init(device),
                fc2: LinearConfig::new(3, 4).init(device),
            }
        }
        /// Runs both layers; used to confirm loaded weights still produce the
        /// expected output shape.
        pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
            let x = self.fc1.forward(x);
            self.fc2.forward(x)
        }
    }

    /// End-to-end: weights from `linear.pt` are applied and the model maps a
    /// [1, 2] input to a [1, 4] output.
    #[test]
    fn test_load_linear_model() {
        let device = Default::default();
        let path = pytorch_test_path("linear", "linear.pt");
        let mut model = SimpleLinearModel::<TestBackend>::new(&device);
        let mut store = PytorchStore::from_file(path).allow_partial(true);
        let result = store.apply_to::<TestBackend, _>(&mut model);
        assert!(
            result.is_ok(),
            "Failed to load linear model: {:?}",
            result.err()
        );
        let result = result.unwrap();
        assert!(!result.applied.is_empty(), "No tensors were applied");
        // Forward-pass sanity check on the loaded model.
        let input = Tensor::<TestBackend, 2>::ones([1, 2], &device);
        let output = model.forward(input);
        assert_eq!(&*output.shape(), [1, 4]);
    }

    /// Bias tensors from the file must show up among the applied paths.
    #[test]
    fn test_load_linear_with_bias() {
        let device = Default::default();
        let path = pytorch_test_path("linear", "linear_with_bias.pt");
        #[derive(Module, Debug)]
        struct LinearWithBias<B: Backend> {
            fc1: Linear<B>,
        }
        let mut model = LinearWithBias {
            fc1: LinearConfig::new(2, 3).init(&device),
        };
        let mut store = PytorchStore::from_file(path).allow_partial(true);
        let result = store.apply_to::<TestBackend, _>(&mut model);
        assert!(result.is_ok(), "Failed to load model with bias");
        let result = result.unwrap();
        let bias_loaded = result.applied.iter().any(|s| s.contains("bias"));
        assert!(bias_loaded, "Bias parameters not loaded");
    }

    /// A `^fc1\.` regex filter restricts loading to the first layer; no fc2
    /// tensor may appear among the applied paths.
    #[test]
    fn test_filter_layers() {
        let device = Default::default();
        let path = pytorch_test_path("linear", "linear.pt");
        let mut model = SimpleLinearModel::<TestBackend>::new(&device);
        let mut store = PytorchStore::from_file(path)
            .with_regex(r"^fc1\.")
            .allow_partial(true);
        let result = store.apply_to::<TestBackend, _>(&mut model).unwrap();
        for tensor in &result.applied {
            assert!(tensor.contains("fc1"));
            assert!(!tensor.contains("fc2"));
        }
    }

    /// Key remapping renames the file's `fc1`/`fc2` keys onto differently
    /// named model fields (`linear1`/`linear2`).
    #[test]
    fn test_remap_layer_names() {
        let device = Default::default();
        let path = pytorch_test_path("linear", "linear.pt");
        #[derive(Module, Debug)]
        struct RemappedModel<B: Backend> {
            linear1: Linear<B>,
            linear2: Linear<B>,
        }
        let mut model = RemappedModel {
            linear1: LinearConfig::new(2, 3).init(&device),
            linear2: LinearConfig::new(3, 4).init(&device),
        };
        let mut store = PytorchStore::from_file(path)
            .with_key_remapping(r"^fc1\.", "linear1.")
            .with_key_remapping(r"^fc2\.", "linear2.")
            .allow_partial(true);
        let result = store.apply_to::<TestBackend, _>(&mut model);
        assert!(result.is_ok(), "Failed to load with remapped names");
        let result = result.unwrap();
        let has_linear1 = result.applied.iter().any(|s| s.contains("linear1"));
        assert!(has_linear1, "Remapped names not applied");
    }
}
#[cfg(test)]
mod conv_model_tests {
    use super::*;
    type TestBackend = burn_flex::Flex;

    /// Two stacked 3x3 convolutions used as a load target.
    #[derive(Module, Debug)]
    struct SimpleConvModel<B: Backend> {
        conv1: Conv2d<B>,
        conv2: Conv2d<B>,
    }
    impl<B: Backend> SimpleConvModel<B> {
        pub fn new(device: &B::Device) -> Self {
            let conv1 = Conv2dConfig::new([3, 16], [3, 3]).init(device);
            let conv2 = Conv2dConfig::new([16, 32], [3, 3]).init(device);
            Self { conv1, conv2 }
        }
    }

    /// Loads conv2d weights from the fixture (skipped when the file is
    /// absent) and checks that at least one weight tensor was applied.
    #[test]
    fn test_load_conv2d_model() {
        let path = pytorch_test_path("conv2d", "conv2d.pt");
        if !path.exists() {
            println!("Skipping conv2d test - file not found: {:?}", path);
            return;
        }
        let device = Default::default();
        let mut model = SimpleConvModel::<TestBackend>::new(&device);
        let mut store = PytorchStore::from_file(path).allow_partial(true);
        if let Ok(outcome) = store.apply_to::<TestBackend, _>(&mut model) {
            assert!(!outcome.applied.is_empty(), "No conv tensors applied");
            let weight_seen = outcome.applied.iter().any(|s| s.contains("weight"));
            assert!(weight_seen, "Conv weights not loaded");
        }
    }

    /// The conv1d fixture is optional; only the builder flag is checked here.
    #[test]
    fn test_load_conv1d_model() {
        let path = pytorch_test_path("conv1d", "conv1d.pt");
        if !path.exists() {
            println!("Skipping conv1d test - file not found: {:?}", path);
            return;
        }
        let store = PytorchStore::from_file(path).allow_partial(true);
        assert!(store.allow_partial);
    }
}
#[cfg(test)]
mod complex_model_tests {
    use super::*;
    type TestBackend = burn_flex::Flex;

    /// The configured top-level key is stored verbatim on the builder.
    #[test]
    fn test_load_with_top_level_key() {
        let path = test_data_path("checkpoint.pt");
        let store = PytorchStore::from_file(path)
            .with_top_level_key("model_state_dict")
            .allow_partial(true);
        assert_eq!(store.top_level_key, Some("model_state_dict".to_string()));
    }

    /// Building a store for a nested-structure fixture keeps the partial flag.
    #[test]
    fn test_load_nested_structure() {
        let path = test_data_path("complex_structure.pt");
        let store = PytorchStore::from_file(path).allow_partial(true);
        assert!(store.allow_partial);
    }

    /// The legacy-format fixture is optional; only builder state is checked.
    #[test]
    fn test_legacy_format() {
        let path = test_data_path("simple_legacy.pt");
        if !path.exists() {
            println!("Skipping legacy format test - file not found: {:?}", path);
            return;
        }
        let store = PytorchStore::from_file(path).allow_partial(true);
        assert!(store.allow_partial);
    }

    /// Two chained remapping rules route `fc1`/`fc2` onto differently named
    /// model fields; a successful load must apply at least one tensor.
    #[test]
    fn test_key_remap_chained() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!("Skipping key remap test - file not found: {:?}", path);
            return;
        }
        let device = Default::default();
        #[derive(Module, Debug)]
        struct RemappedChainModel<B: Backend> {
            convolution1: Linear<B>,
            linear2: Linear<B>,
        }
        let mut model = RemappedChainModel {
            convolution1: LinearConfig::new(2, 3).init(&device),
            linear2: LinearConfig::new(3, 4).init(&device),
        };
        let mut store = PytorchStore::from_file(path)
            .with_key_remapping(r"^fc1\.", "convolution1.")
            .with_key_remapping(r"^fc2\.", "linear2.")
            .allow_partial(true);
        if let Ok(outcome) = store.apply_to::<TestBackend, _>(&mut model) {
            assert!(
                !outcome.applied.is_empty(),
                "No tensors were applied after remapping"
            );
        }
    }
}
#[cfg(test)]
mod adapter_tests {
    use super::*;
    type TestBackend = burn_flex::Flex;

    /// Minimal two-layer model matching the `linear.pt` fixture keys.
    #[derive(Module, Debug)]
    pub struct SimpleLinearModel<B: Backend> {
        fc1: Linear<B>,
        fc2: Linear<B>,
    }
    impl<B: Backend> SimpleLinearModel<B> {
        pub fn new(device: &B::Device) -> Self {
            Self {
                fc1: LinearConfig::new(2, 3).init(device),
                fc2: LinearConfig::new(3, 4).init(device),
            }
        }
    }

    /// A plain load must succeed and apply tensors without any explicit
    /// adapter configuration.
    // NOTE(review): the PyTorchToBurnAdapter behavior itself is only implied by
    // the test name; the assertions verify a successful, non-empty load.
    #[test]
    fn test_pytorch_adapter_always_applied() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!("Skipping adapter test - file not found: {:?}", path);
            return;
        }
        let device = Default::default();
        let mut model = SimpleLinearModel::<TestBackend>::new(&device);
        let mut store = PytorchStore::from_file(path).allow_partial(true);
        let result = store.apply_to::<TestBackend, _>(&mut model);
        assert!(
            result.is_ok(),
            "Failed to load with internal PyTorchToBurnAdapter: {:?}",
            result.err()
        );
        assert!(!result.unwrap().applied.is_empty());
    }

    /// A predicate filter excluding "bias" paths must keep every bias tensor
    /// out of the applied set.
    #[test]
    fn test_pytorch_adapter_with_filtering() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!("Skipping filtering test - file not found: {:?}", path);
            return;
        }
        let device = Default::default();
        let mut model = SimpleLinearModel::<TestBackend>::new(&device);
        let mut store = PytorchStore::from_file(path)
            .with_predicate(|path, _| !path.contains("bias"))
            .allow_partial(true);
        let result = store.apply_to::<TestBackend, _>(&mut model).unwrap();
        for applied_path in &result.applied {
            assert!(
                !applied_path.contains("bias"),
                "Bias tensor was not filtered: {}",
                applied_path
            );
        }
    }
}
#[cfg(test)]
mod error_handling_tests {
    use super::*;
    use burn_flex::Flex;

    /// Minimal two-layer model reused by every error-path test.
    #[derive(Module, Debug)]
    pub struct SimpleLinearModel<B: Backend> {
        fc1: Linear<B>,
        fc2: Linear<B>,
    }
    impl<B: Backend> SimpleLinearModel<B> {
        pub fn new(device: &B::Device) -> Self {
            Self {
                fc1: LinearConfig::new(2, 3).init(device),
                fc2: LinearConfig::new(3, 4).init(device),
            }
        }
    }

    /// A nonexistent file must surface as a `Reader` error variant.
    #[test]
    fn test_missing_file() {
        let device = Default::default();
        let mut model = SimpleLinearModel::<Flex>::new(&device);
        let mut store = PytorchStore::from_file("nonexistent.pth");
        let result = store.apply_to::<Flex, _>(&mut model);
        assert!(result.is_err());
        match result {
            Err(crate::pytorch::PytorchStoreError::Reader(_)) => {}
            _ => panic!("Expected reader error for missing file"),
        }
    }

    /// A top-level key that is not present in the file must fail the load.
    #[test]
    fn test_invalid_top_level_key() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!(
                "Skipping invalid top level key test - file not found: {:?}",
                path
            );
            return;
        }
        let device = Default::default();
        let mut model = SimpleLinearModel::<Flex>::new(&device);
        let mut store = PytorchStore::from_file(path).with_top_level_key("nonexistent_key");
        let result = store.apply_to::<Flex, _>(&mut model);
        assert!(result.is_err(), "Should fail with invalid top level key");
    }

    /// With validation on and `allow_partial(false)`, a filter that matches
    /// nothing must cause the load to error rather than silently no-op.
    #[test]
    fn test_strict_validation() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!(
                "Skipping strict validation test - file not found: {:?}",
                path
            );
            return;
        }
        let device = Default::default();
        let mut model = SimpleLinearModel::<Flex>::new(&device);
        let mut store = PytorchStore::from_file(path)
            .with_regex(r"^this_will_never_match$")
            .validate(true)
            .allow_partial(false);
        let result = store.apply_to::<Flex, _>(&mut model);
        assert!(
            result.is_err(),
            "Should fail when no tensors match with allow_partial=false"
        );
    }
}
#[cfg(test)]
mod enum_variant_tests {
    use super::*;
    use crate::ModuleSnapshot;
    use burn_flex::Flex;

    /// Enum-wrapped submodule: Burn inserts the variant name (e.g. `BaseConv`)
    /// into tensor paths, while plain PyTorch checkpoints do not contain it.
    #[derive(Module, Debug)]
    pub enum ConvBlock<B: Backend> {
        BaseConv(Linear<B>),
        DwsConv(Linear<B>),
    }
    #[derive(Module, Debug)]
    pub struct ModelWithEnum<B: Backend> {
        feature: ConvBlock<B>,
        classifier: Linear<B>,
    }
    impl<B: Backend> ModelWithEnum<B> {
        pub fn new(device: &B::Device) -> Self {
            Self {
                feature: ConvBlock::BaseConv(LinearConfig::new(3, 64).init(device)),
                classifier: LinearConfig::new(64, 10).init(device),
            }
        }
    }

    /// With `skip_enum_variants(false)` the variant name stays in the model
    /// paths, so a checkpoint without variant segments cannot match them:
    /// the tensors must be reported missing, flagged via their enum container
    /// stack, and mentioned in the Display output.
    #[test]
    fn test_enum_variant_path_mismatch() {
        let device = Default::default();
        let mut model = ModelWithEnum::<Flex>::new(&device);
        let pytorch_file = store_test_data_path("model_without_enum_variants.pt");
        let mut store = PytorchStore::from_file(pytorch_file)
            .skip_enum_variants(false)
            .allow_partial(true)
            .validate(false);
        let result = store.apply_to::<Flex, _>(&mut model);
        match result {
            Ok(apply_result) => {
                assert!(
                    !apply_result.missing.is_empty(),
                    "Should have missing tensors due to enum variant path mismatch"
                );
                // Missing entries are (path, container_stack) pairs; the stack
                // string carries an "Enum:" marker for enum containers.
                let enum_missing: Vec<_> = apply_result
                    .missing
                    .iter()
                    .filter(|(_, container_stack)| container_stack.contains("Enum:"))
                    .collect();
                assert!(
                    !enum_missing.is_empty(),
                    "Missing tensors should be detected as having enum containers"
                );
                let has_base_conv_path = apply_result
                    .missing
                    .iter()
                    .any(|(path, _)| path.contains("BaseConv"));
                assert!(
                    has_base_conv_path,
                    "Should have missing paths with 'BaseConv' enum variant. Missing: {:?}",
                    apply_result
                        .missing
                        .iter()
                        .map(|(p, _)| p)
                        .collect::<Vec<_>>()
                );
                // The user-facing report should call out the enum-variant cause.
                println!("\n{}", apply_result);
                let display_output = format!("{}", apply_result);
                assert!(
                    display_output.contains("enum variant"),
                    "Display output should mention enum variants"
                );
            }
            Err(e) => panic!(
                "Load should succeed with allow_partial=true, got error: {}",
                e
            ),
        }
    }

    /// Collected snapshots for enum-held fields must record an
    /// `Enum:<TypeName>` marker in their container stack.
    #[test]
    fn test_enum_variant_detection_in_container_stack() {
        let device = Default::default();
        let model = ModelWithEnum::<Flex>::new(&device);
        let snapshots = model.collect(None, None, false);
        let enum_snapshot = snapshots
            .iter()
            .find(|s| s.full_path().contains("feature"))
            .expect("Should have feature snapshots");
        if let Some(container_stack) = &enum_snapshot.container_stack {
            let container_str = container_stack.join(".");
            assert!(
                container_str.contains("Enum:ConvBlock"),
                "Container stack should contain Enum:ConvBlock marker. Got: {}",
                container_str
            );
        } else {
            panic!("Snapshot should have container_stack");
        }
    }

    /// With `skip_enum_variants(true)` the variant segment is dropped from
    /// model paths, so the same checkpoint now matches: feature tensors are
    /// applied and none are reported missing.
    #[test]
    fn test_skip_enum_variants_feature() {
        let device = Default::default();
        let mut model = ModelWithEnum::<Flex>::new(&device);
        let pytorch_file = store_test_data_path("model_without_enum_variants.pt");
        let mut store = PytorchStore::from_file(pytorch_file)
            .skip_enum_variants(true)
            .allow_partial(true)
            .validate(false);
        let result = store.apply_to::<Flex, _>(&mut model);
        match result {
            Ok(apply_result) => {
                println!("\n{}", apply_result);
                let feature_applied = apply_result
                    .applied
                    .iter()
                    .filter(|path| path.contains("feature"))
                    .count();
                assert!(
                    feature_applied > 0,
                    "Should have applied feature tensors with skip_enum_variants=true. Applied: {:?}",
                    apply_result.applied
                );
                let feature_missing = apply_result
                    .missing
                    .iter()
                    .filter(|(path, _)| path.contains("feature"))
                    .count();
                assert_eq!(
                    feature_missing, 0,
                    "Feature tensors should not be missing with skip_enum_variants=true. Missing: {:?}",
                    apply_result.missing
                );
            }
            Err(e) => panic!(
                "Load with skip_enum_variants should succeed, got error: {}",
                e
            ),
        }
    }
}
#[cfg(test)]
mod direct_access_tests {
    use super::*;

    /// `get_all_snapshots` exposes every tensor in the file keyed by path.
    #[test]
    fn test_get_all_snapshots() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!("Skipping test - file not found: {:?}", path);
            return;
        }
        let mut store = PytorchStore::from_file(path);
        let snapshots = store.get_all_snapshots().unwrap();
        assert!(!snapshots.is_empty(), "Should have snapshots");
        assert!(
            snapshots.contains_key("fc1.weight"),
            "Should contain fc1.weight"
        );
        assert!(
            snapshots.contains_key("fc1.bias"),
            "Should contain fc1.bias"
        );
    }

    /// A single known tensor can be fetched by path; its data materializes.
    #[test]
    fn test_get_snapshot_existing() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!("Skipping test - file not found: {:?}", path);
            return;
        }
        let mut store = PytorchStore::from_file(path);
        let snapshot = store.get_snapshot("fc1.weight").unwrap();
        assert!(snapshot.is_some(), "Should find fc1.weight");
        let snapshot = snapshot.unwrap();
        assert_eq!(snapshot.shape.len(), 2, "Weight should be 2D tensor");
        let data = snapshot.to_data().unwrap();
        assert!(!data.bytes.is_empty(), "Data should not be empty");
    }

    /// Looking up an unknown path yields Ok(None), not an error.
    #[test]
    fn test_get_snapshot_not_found() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!("Skipping test - file not found: {:?}", path);
            return;
        }
        let mut store = PytorchStore::from_file(path);
        let snapshot = store.get_snapshot("nonexistent.weight").unwrap();
        assert!(snapshot.is_none(), "Should not find nonexistent tensor");
    }

    /// `keys()` lists tensor paths without loading tensor data.
    #[test]
    fn test_keys() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!("Skipping test - file not found: {:?}", path);
            return;
        }
        let mut store = PytorchStore::from_file(path);
        let keys = store.keys().unwrap();
        assert!(!keys.is_empty(), "Should have keys");
        assert!(
            keys.contains(&"fc1.weight".to_string()),
            "Keys should contain fc1.weight"
        );
        assert!(
            keys.contains(&"fc1.bias".to_string()),
            "Keys should contain fc1.bias"
        );
    }

    /// `keys()` before and after snapshot materialization must agree —
    /// the key-only fast path and the full-load path return the same set size.
    #[test]
    fn test_keys_fast_path() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!("Skipping test - file not found: {:?}", path);
            return;
        }
        let mut store = PytorchStore::from_file(&path);
        let keys = store.keys().unwrap();
        assert!(!keys.is_empty(), "Should have keys via fast path");
        let snapshots = store.get_all_snapshots().unwrap();
        assert!(!snapshots.is_empty(), "Should have snapshots");
        let keys2 = store.keys().unwrap();
        assert_eq!(keys.len(), keys2.len(), "Keys count should match");
    }

    /// Repeated `get_all_snapshots` calls hit the cache and stay consistent.
    #[test]
    fn test_caching_behavior() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!("Skipping test - file not found: {:?}", path);
            return;
        }
        let mut store = PytorchStore::from_file(path);
        let snapshots1 = store.get_all_snapshots().unwrap();
        let count1 = snapshots1.len();
        let snapshots2 = store.get_all_snapshots().unwrap();
        let count2 = snapshots2.len();
        assert_eq!(count1, count2, "Cached results should match");
    }

    /// Key remapping applies to `get_all_snapshots`: only remapped keys exist.
    #[test]
    fn test_get_all_snapshots_with_remapping() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!("Skipping test - file not found: {:?}", path);
            return;
        }
        let mut store = PytorchStore::from_file(path).with_key_remapping(r"^fc1\.", "linear1.");
        let snapshots = store.get_all_snapshots().unwrap();
        assert!(
            snapshots.contains_key("linear1.weight"),
            "Should contain remapped key linear1.weight. Keys: {:?}",
            snapshots.keys().collect::<Vec<_>>()
        );
        assert!(
            snapshots.contains_key("linear1.bias"),
            "Should contain remapped key linear1.bias"
        );
        assert!(
            !snapshots.contains_key("fc1.weight"),
            "Should not contain original key fc1.weight"
        );
    }

    /// After remapping, lookup works only via the remapped name; the original
    /// name is gone.
    #[test]
    fn test_get_snapshot_with_remapped_name() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!("Skipping test - file not found: {:?}", path);
            return;
        }
        let mut store = PytorchStore::from_file(path).with_key_remapping(r"^fc1\.", "linear1.");
        let snapshot = store.get_snapshot("linear1.weight").unwrap();
        assert!(snapshot.is_some(), "Should find tensor by remapped name");
        let snapshot_orig = store.get_snapshot("fc1.weight").unwrap();
        assert!(
            snapshot_orig.is_none(),
            "Should not find tensor by original name after remapping"
        );
    }

    /// Filters configured via `with_regex` apply to `apply_to`, not to direct
    /// snapshot access: `get_all_snapshots` still returns everything.
    #[test]
    fn test_get_all_snapshots_ignores_filter() {
        let path = pytorch_test_path("linear", "linear.pt");
        if !path.exists() {
            println!("Skipping test - file not found: {:?}", path);
            return;
        }
        let mut store = PytorchStore::from_file(path).with_regex(r"^fc1\.");
        let snapshots = store.get_all_snapshots().unwrap();
        assert!(
            snapshots.contains_key("fc1.weight"),
            "Should contain fc1.weight"
        );
        assert!(
            snapshots.contains_key("fc2.weight"),
            "Should contain fc2.weight (filter not applied to get_all_snapshots)"
        );
    }
}
#[cfg(test)]
mod map_indices_contiguous_tests {
    use super::*;
    type TestBackend = burn_flex::Flex;

    /// Model with a Vec of conv layers addressed as `fc.0`, `fc.1`, ... —
    /// the target for the non-contiguous-index fixture (which contains
    /// gaps such as fc.0, fc.2, fc.4).
    #[derive(Module, Debug)]
    struct SequentialConvModel<B: Backend> {
        fc: Vec<Conv2d<B>>,
    }
    impl<B: Backend> SequentialConvModel<B> {
        pub fn new(device: &B::Device, num_layers: usize) -> Self {
            Self {
                fc: (0..num_layers)
                    .map(|_| {
                        Conv2dConfig::new([2, 2], [3, 3])
                            .with_bias(true)
                            .init(device)
                    })
                    .collect(),
            }
        }
    }

    /// With index mapping enabled, the file's gapped indices are compacted to
    /// 0..n, so all 5 model layers receive weight and bias tensors.
    #[test]
    fn test_load_non_contiguous_indexes_with_mapping() {
        let path = pytorch_test_path("non_contiguous_indexes", "non_contiguous_indexes.pt");
        if !path.exists() {
            println!("Skipping test - file not found: {:?}", path);
            return;
        }
        let device = Default::default();
        let mut model = SequentialConvModel::<TestBackend>::new(&device, 5);
        let mut store = PytorchStore::from_file(&path)
            .map_indices_contiguous(true)
            .allow_partial(true)
            .validate(false);
        let result = store.apply_to::<TestBackend, _>(&mut model);
        match result {
            Ok(apply_result) => {
                println!("Applied tensors: {:?}", apply_result.applied);
                println!("Missing tensors: {:?}", apply_result.missing);
                println!("Unused tensors: {:?}", apply_result.unused);
                assert!(
                    !apply_result.applied.is_empty(),
                    "Should have applied tensors"
                );
                // Every layer index 0..5 must have received both parameters.
                for i in 0..5 {
                    let has_weight = apply_result
                        .applied
                        .iter()
                        .any(|p| p.contains(&format!("fc.{}.weight", i)));
                    let has_bias = apply_result
                        .applied
                        .iter()
                        .any(|p| p.contains(&format!("fc.{}.bias", i)));
                    assert!(
                        has_weight,
                        "Should have applied fc.{}.weight, applied: {:?}",
                        i, apply_result.applied
                    );
                    assert!(
                        has_bias,
                        "Should have applied fc.{}.bias, applied: {:?}",
                        i, apply_result.applied
                    );
                }
                let missing_fc: Vec<_> = apply_result
                    .missing
                    .iter()
                    .filter(|(p, _)| p.starts_with("fc."))
                    .collect();
                assert!(
                    missing_fc.is_empty(),
                    "Should have no missing fc tensors with index mapping. Missing: {:?}",
                    missing_fc
                );
            }
            Err(e) => panic!("Failed to load with index mapping: {}", e),
        }
    }

    /// With mapping disabled, the model's fc.1 / fc.3 have no counterpart in
    /// the gapped file, so at least one of them must be reported missing.
    #[test]
    fn test_load_non_contiguous_indexes_without_mapping() {
        let path = pytorch_test_path("non_contiguous_indexes", "non_contiguous_indexes.pt");
        if !path.exists() {
            println!("Skipping test - file not found: {:?}", path);
            return;
        }
        let device = Default::default();
        let mut model = SequentialConvModel::<TestBackend>::new(&device, 5);
        let mut store = PytorchStore::from_file(&path)
            .map_indices_contiguous(false)
            .allow_partial(true)
            .validate(false);
        let result = store.apply_to::<TestBackend, _>(&mut model);
        match result {
            Ok(apply_result) => {
                println!(
                    "Without index mapping - Applied tensors: {:?}",
                    apply_result.applied
                );
                println!(
                    "Without index mapping - Missing tensors: {:?}",
                    apply_result.missing
                );
                let missing_fc: Vec<_> = apply_result
                    .missing
                    .iter()
                    .filter(|(p, _)| p.starts_with("fc."))
                    .collect();
                assert!(
                    !missing_fc.is_empty(),
                    "Should have missing fc tensors without index mapping (indices 1, 3 don't exist in file)"
                );
                let has_fc1_missing = apply_result
                    .missing
                    .iter()
                    .any(|(p, _)| p.starts_with("fc.1."));
                let has_fc3_missing = apply_result
                    .missing
                    .iter()
                    .any(|(p, _)| p.starts_with("fc.3."));
                assert!(
                    has_fc1_missing || has_fc3_missing,
                    "Should have fc.1 or fc.3 missing. Missing: {:?}",
                    apply_result.missing
                );
            }
            Err(e) => panic!("Unexpected error: {}", e),
        }
    }

    /// The index mapping is visible directly through `keys()`: with mapping,
    /// fc.2/fc.4 from the file surface as fc.1/fc.2; without it the original
    /// gapped indices are reported verbatim.
    #[test]
    fn test_mapping_applied_to_keys() {
        let path = pytorch_test_path("non_contiguous_indexes", "non_contiguous_indexes.pt");
        if !path.exists() {
            println!("Skipping test - file not found: {:?}", path);
            return;
        }
        let mut store_mapped = PytorchStore::from_file(&path).map_indices_contiguous(true);
        let keys_mapped = store_mapped.keys().unwrap();
        println!("Keys with index mapping: {:?}", keys_mapped);
        assert!(
            keys_mapped.iter().any(|k| k.starts_with("fc.1.")),
            "With index mapping, should have fc.1 (from fc.2)"
        );
        assert!(
            keys_mapped.iter().any(|k| k.starts_with("fc.2.")),
            "With index mapping, should have fc.2 (from fc.4)"
        );
        let mut store_no_mapping = PytorchStore::from_file(&path).map_indices_contiguous(false);
        let keys_no_mapping = store_no_mapping.keys().unwrap();
        println!("Keys without index mapping: {:?}", keys_no_mapping);
        assert!(
            keys_no_mapping.iter().any(|k| k.starts_with("fc.2.")),
            "Without index mapping, should have original fc.2"
        );
        assert!(
            keys_no_mapping.iter().any(|k| k.starts_with("fc.4.")),
            "Without index mapping, should have original fc.4"
        );
        assert!(
            !keys_no_mapping.iter().any(|k| k.starts_with("fc.1.")),
            "Without index mapping, should NOT have fc.1 (not in original file)"
        );
    }
}