#![allow(clippy::needless_range_loop)]
use crate::pytorch::PytorchReader;
use crate::pytorch::reader::{ByteOrder, FileFormat};
use burn_tensor::{BoolStore, DType, TensorData, Tolerance, shape};
use std::path::PathBuf;
/// Resolve `filename` inside this crate's PyTorch reader test-data directory.
fn test_data_path(filename: &str) -> PathBuf {
    let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    for segment in ["src", "pytorch", "tests", "reader", "test_data", filename] {
        path.push(segment);
    }
    path
}
#[test]
fn test_float32_tensor() {
    // 1-D float32 tensor with positive, negative and zero entries.
    let reader =
        PytorchReader::new(&test_data_path("float32.pt")).expect("Failed to load float32.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::F32);
    assert_eq!(tensor.shape, shape![4]);
    let data = tensor.to_data().unwrap();
    let values = data.as_slice::<f32>().unwrap();
    assert_eq!(values.len(), 4);
    for (&got, want) in values.iter().zip([1.0_f32, 2.5, -3.7, 0.0]) {
        assert!((got - want).abs() < 1e-6);
    }
}
#[test]
fn test_float64_tensor() {
    // Double-precision values are checked with a tight tolerance.
    let reader =
        PytorchReader::new(&test_data_path("float64.pt")).expect("Failed to load float64.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::F64);
    assert_eq!(tensor.shape, shape![3]);
    let data = tensor.to_data().unwrap();
    let values = data.as_slice::<f64>().unwrap();
    assert_eq!(values.len(), 3);
    for (&got, want) in values.iter().zip([1.1_f64, 2.2, 3.3]) {
        assert!((got - want).abs() < 1e-10);
    }
}
#[test]
fn test_int64_tensor() {
    // Integer tensors round-trip exactly, so the slice is compared directly.
    let reader = PytorchReader::new(&test_data_path("int64.pt")).expect("Failed to load int64.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::I64);
    assert_eq!(tensor.shape, shape![4]);
    let data = tensor.to_data().unwrap();
    assert_eq!(data.as_slice::<i64>().unwrap(), &[100, -200, 300, 0]);
}
#[test]
fn test_int32_tensor() {
    // int32 data goes through an explicit convert before slicing.
    let reader = PytorchReader::new(&test_data_path("int32.pt")).expect("Failed to load int32.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::I32);
    assert_eq!(tensor.shape, shape![3]);
    let converted = tensor.to_data().unwrap().convert::<i32>();
    assert_eq!(converted.as_slice::<i32>().unwrap(), &[10, 20, -30]);
}
#[test]
fn test_int16_tensor() {
    // int16 data goes through an explicit convert before slicing.
    let reader = PytorchReader::new(&test_data_path("int16.pt")).expect("Failed to load int16.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::I16);
    assert_eq!(tensor.shape, shape![3]);
    let converted = tensor.to_data().unwrap().convert::<i16>();
    assert_eq!(converted.as_slice::<i16>().unwrap(), &[1000, -2000, 3000]);
}
#[test]
fn test_int8_tensor() {
    // Covers both i8 extremes (127 and -128) plus zero and a mid value.
    let reader = PytorchReader::new(&test_data_path("int8.pt")).expect("Failed to load int8.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::I8);
    assert_eq!(tensor.shape, shape![4]);
    let converted = tensor.to_data().unwrap().convert::<i8>();
    assert_eq!(converted.as_slice::<i8>().unwrap(), &[127, -128, 0, 50]);
}
#[test]
fn test_bool_tensor() {
    // Bool tensors load with native bool storage.
    let reader = PytorchReader::new(&test_data_path("bool.pt")).expect("Failed to load bool.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::Bool(BoolStore::Native));
    assert_eq!(tensor.shape, shape![5]);
    let data = tensor.to_data().unwrap();
    assert_eq!(
        data.as_slice::<bool>().unwrap(),
        &[true, false, true, true, false]
    );
}
#[test]
fn test_uint8_tensor() {
    // Covers both u8 extremes (0 and 255) plus mid-range values.
    let reader = PytorchReader::new(&test_data_path("uint8.pt")).expect("Failed to load uint8.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::U8);
    assert_eq!(tensor.shape, shape![4]);
    let data = tensor.to_data().unwrap();
    assert_eq!(data.as_slice::<u8>().unwrap(), &[0, 128, 255, 42]);
}
#[test]
fn test_float16_tensor() {
    use half::f16;
    // Half-precision values are compared via f32 with a loose tolerance.
    let reader =
        PytorchReader::new(&test_data_path("float16.pt")).expect("Failed to load float16.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::F16);
    assert_eq!(tensor.shape, shape![3]);
    let data = tensor.to_data().unwrap();
    assert_eq!(data.shape, shape![3]);
    let values = data.as_slice::<f16>().unwrap();
    assert_eq!(values.len(), 3);
    for (got, want) in values.iter().zip([1.5_f32, -2.25, 3.125]) {
        assert!((got.to_f32() - want).abs() < 1e-2);
    }
}
#[test]
fn test_bfloat16_tensor() {
    use half::bf16;
    // bfloat16 values are compared via f32 with a loose tolerance.
    let reader =
        PytorchReader::new(&test_data_path("bfloat16.pt")).expect("Failed to load bfloat16.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::BF16);
    assert_eq!(tensor.shape, shape![3]);
    let data = tensor.to_data().unwrap();
    assert_eq!(data.shape, shape![3]);
    let values = data.as_slice::<bf16>().unwrap();
    assert_eq!(values.len(), 3);
    for (got, want) in values.iter().zip([1.5_f32, -2.5, 3.5]) {
        assert!((got.to_f32() - want).abs() < 1e-2);
    }
}
#[test]
fn test_2d_tensor() {
    // A 3x2 matrix stored row-major; values enumerate 1..=6.
    let path = test_data_path("tensor_2d.pt");
    let reader = PytorchReader::new(&path).expect("Failed to load tensor_2d.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::F32);
    assert_eq!(tensor.shape, shape![3, 2]);
    let data = tensor.to_data().unwrap();
    let values = data.as_slice::<f32>().unwrap();
    assert_eq!(values.len(), 6);
    // Zip instead of enumerate + indexing: clearer intent, no bounds checks.
    for (value, expected) in values.iter().zip([1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]) {
        assert!((value - expected).abs() < 1e-6);
    }
}
#[test]
fn test_3d_tensor() {
    // Shape and element count are verified for a 2x3x4 tensor.
    let reader =
        PytorchReader::new(&test_data_path("tensor_3d.pt")).expect("Failed to load tensor_3d.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::F32);
    assert_eq!(tensor.shape, shape![2, 3, 4]);
    let data = tensor.to_data().unwrap();
    assert_eq!(data.shape, shape![2, 3, 4]);
    assert_eq!(data.as_slice::<f32>().unwrap().len(), 24);
}
#[test]
fn test_4d_tensor() {
    // Shape and element count are verified for a 2x3x2x2 tensor.
    let reader =
        PytorchReader::new(&test_data_path("tensor_4d.pt")).expect("Failed to load tensor_4d.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::F32);
    assert_eq!(tensor.shape, shape![2, 3, 2, 2]);
    let data = tensor.to_data().unwrap();
    assert_eq!(data.shape, shape![2, 3, 2, 2]);
    assert_eq!(data.as_slice::<f32>().unwrap().len(), 24);
}
#[test]
fn test_state_dict() {
    // A flat state dict with four entries typical of a batch-norm layer.
    let reader = PytorchReader::new(&test_data_path("state_dict.pt"))
        .expect("Failed to load state_dict.pt");
    let keys = reader.keys();
    assert_eq!(keys.len(), 4);
    for name in ["weight", "bias", "running_mean", "running_var"] {
        assert!(keys.contains(&name.to_string()));
    }
    let weight = reader.get("weight").unwrap();
    assert_eq!(weight.shape, shape![3, 4]);
    assert_eq!(weight.dtype, DType::F32);
    let bias = reader.get("bias").unwrap();
    assert_eq!(bias.shape, shape![3]);
    assert_eq!(bias.dtype, DType::F32);
    // running_mean should be all zeros, running_var all ones.
    let running_mean = reader.get("running_mean").unwrap();
    assert_eq!(running_mean.shape, shape![3]);
    let mean_data = running_mean.to_data().unwrap();
    let mean_values = mean_data.as_slice::<f32>().unwrap();
    assert!(mean_values.iter().all(|&v| v.abs() < 1e-6));
    let running_var = reader.get("running_var").unwrap();
    assert_eq!(running_var.shape, shape![3]);
    let var_data = running_var.to_data().unwrap();
    let var_values = var_data.as_slice::<f32>().unwrap();
    assert!(var_values.iter().all(|&v| (v - 1.0).abs() < 1e-6));
}
#[test]
fn test_nested_dict() {
    // Nested modules are flattened into dot-separated keys.
    let reader = PytorchReader::new(&test_data_path("nested_dict.pt"))
        .expect("Failed to load nested_dict.pt");
    let keys = reader.keys();
    assert_eq!(keys.len(), 4);
    for name in ["layer1.weight", "layer1.bias", "layer2.weight", "layer2.bias"] {
        assert!(keys.contains(&name.to_string()));
    }
    let layer1_weight = reader.get("layer1.weight").unwrap();
    assert_eq!(layer1_weight.shape, shape![2, 3]);
    assert_eq!(layer1_weight.dtype, DType::F32);
    let layer1_data = layer1_weight.to_data().unwrap();
    assert_eq!(layer1_data.as_slice::<f32>().unwrap().len(), 6);
    let layer2_weight = reader.get("layer2.weight").unwrap();
    assert_eq!(layer2_weight.shape, shape![4, 2]);
    assert_eq!(layer2_weight.dtype, DType::F32);
    let layer2_data = layer2_weight.to_data().unwrap();
    assert_eq!(layer2_data.as_slice::<f32>().unwrap().len(), 8);
}
#[test]
fn test_checkpoint() {
    // Checkpoint dicts expose model_state_dict entries as dotted keys.
    let reader = PytorchReader::new(&test_data_path("checkpoint.pt"))
        .expect("Failed to load checkpoint.pt");
    let keys = reader.keys();
    for name in [
        "model_state_dict.fc1.weight",
        "model_state_dict.fc1.bias",
        "model_state_dict.fc2.weight",
        "model_state_dict.fc2.bias",
    ] {
        assert!(keys.contains(&name.to_string()));
    }
    let fc1_weight = reader.get("model_state_dict.fc1.weight").unwrap();
    assert_eq!(fc1_weight.shape, shape![10, 5]);
    let fc1_data = fc1_weight.to_data().unwrap();
    assert_eq!(fc1_data.as_slice::<f32>().unwrap().len(), 50);
    let fc2_weight = reader.get("model_state_dict.fc2.weight").unwrap();
    assert_eq!(fc2_weight.shape, shape![3, 10]);
    let fc2_data = fc2_weight.to_data().unwrap();
    assert_eq!(fc2_data.as_slice::<f32>().unwrap().len(), 30);
}
#[test]
fn test_empty_tensor() {
    // Zero-length tensors keep their dtype and an explicit 0 dimension.
    let reader = PytorchReader::new(&test_data_path("empty.pt")).expect("Failed to load empty.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::F32);
    assert_eq!(tensor.shape, shape![0]);
}
#[test]
fn test_scalar_tensor() {
    // Scalars have an empty shape but still carry exactly one element.
    let reader =
        PytorchReader::new(&test_data_path("scalar.pt")).expect("Failed to load scalar.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.shape, shape![]);
    assert_eq!(tensor.dtype, DType::F32);
    let data = tensor.to_data().unwrap();
    let values = data.as_slice::<f32>().unwrap();
    assert_eq!(values.len(), 1);
    assert!((values[0] - 42.0).abs() < 1e-6);
}
#[test]
fn test_large_shape() {
    // Spot-check first, middle and last element of a 100x100 tensor.
    let reader = PytorchReader::new(&test_data_path("large_shape.pt"))
        .expect("Failed to load large_shape.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.shape, shape![100, 100]);
    assert_eq!(tensor.dtype, DType::F32);
    let data = tensor.to_data().unwrap();
    let values = data.as_slice::<f32>().unwrap();
    assert_eq!(values.len(), 10000);
    for (index, expected) in [(0, 1.0_f32), (5050, 2.0), (9999, 3.0)] {
        assert!((values[index] - expected).abs() < 1e-6);
    }
}
#[test]
fn test_mixed_types() {
    // One file holding four tensors of different dtypes.
    let reader = PytorchReader::new(&test_data_path("mixed_types.pt"))
        .expect("Failed to load mixed_types.pt");
    assert_eq!(reader.tensors().len(), 4);
    let float32 = reader.get("float32").unwrap();
    assert_eq!(float32.dtype, DType::F32);
    assert_eq!(float32.shape, shape![2]);
    let f32_data = float32.to_data().unwrap();
    for (got, want) in f32_data.as_slice::<f32>().unwrap().iter().zip([1.0_f32, 2.0]) {
        assert!((got - want).abs() < 1e-6);
    }
    let int64 = reader.get("int64").unwrap();
    assert_eq!(int64.dtype, DType::I64);
    assert_eq!(int64.shape, shape![2]);
    let i64_data = int64.to_data().unwrap();
    assert_eq!(i64_data.as_slice::<i64>().unwrap(), &[100, 200]);
    let bool_tensor = reader.get("bool").unwrap();
    assert_eq!(bool_tensor.dtype, DType::Bool(BoolStore::Native));
    assert_eq!(bool_tensor.shape, shape![2]);
    let bool_data = bool_tensor.to_data().unwrap();
    assert_eq!(bool_data.as_slice::<bool>().unwrap(), &[true, false]);
    let float64 = reader.get("float64").unwrap();
    assert_eq!(float64.dtype, DType::F64);
    assert_eq!(float64.shape, shape![2]);
    let f64_data = float64.to_data().unwrap();
    for (got, want) in f64_data.as_slice::<f64>().unwrap().iter().zip([1.1_f64, 2.2]) {
        assert!((got - want).abs() < 1e-10);
    }
}
#[test]
fn test_special_values() {
    // NaN and signed infinities must survive the round trip intact.
    let reader = PytorchReader::new(&test_data_path("special_values.pt"))
        .expect("Failed to load special_values.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::F32);
    assert_eq!(tensor.shape, shape![5]);
    let data = tensor.to_data().unwrap();
    let vals = data.as_slice::<f32>().unwrap();
    assert_eq!(vals.len(), 5);
    assert!(vals[0].is_nan());
    assert!(vals[1].is_infinite() && vals[1] > 0.0);
    assert!(vals[2].is_infinite() && vals[2] < 0.0);
    assert!((vals[3] - 0.0).abs() < 1e-6);
    assert!((vals[4] - 1.0).abs() < 1e-6);
}
#[test]
fn test_extreme_values() {
    // Tiny and huge magnitudes of both signs survive serialization.
    let reader = PytorchReader::new(&test_data_path("extreme_values.pt"))
        .expect("Failed to load extreme_values.pt");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::F32);
    assert_eq!(tensor.shape, shape![4]);
    let data = tensor.to_data().unwrap();
    let vals = data.as_slice::<f32>().unwrap();
    assert_eq!(vals.len(), 4);
    // Order: tiny positive, huge positive, tiny negative, huge negative.
    assert!(vals[0] > 0.0 && vals[0] < 1e-20);
    assert!(vals[1] > 1e20);
    assert!(vals[2] < 0.0 && vals[2] > -1e-20);
    assert!(vals[3] < -1e20);
}
#[test]
fn test_parameter() {
    // nn.Parameter wrappers unwrap to plain tensors.
    let reader =
        PytorchReader::new(&test_data_path("parameter.pt")).expect("Failed to load parameter.pt");
    assert_eq!(reader.tensors().len(), 1);
    let param = reader.get("param").unwrap();
    assert_eq!(param.shape, shape![3, 3]);
    assert_eq!(param.dtype, DType::F32);
    let data = param.to_data().unwrap();
    assert_eq!(data.as_slice::<f32>().unwrap().len(), 9);
}
#[test]
fn test_buffers() {
    // Registered buffers load alongside parameters.
    let reader =
        PytorchReader::new(&test_data_path("buffers.pt")).expect("Failed to load buffers.pt");
    assert_eq!(reader.tensors().len(), 2);
    let buffer1 = reader.get("buffer1").unwrap();
    assert_eq!(buffer1.dtype, DType::I32);
    assert_eq!(buffer1.shape, shape![3]);
    let buffer1_data = buffer1.to_data().unwrap().convert::<i32>();
    assert_eq!(buffer1_data.as_slice::<i32>().unwrap(), &[1, 2, 3]);
    let buffer2 = reader.get("buffer2").unwrap();
    assert_eq!(buffer2.dtype, DType::Bool(BoolStore::Native));
    assert_eq!(buffer2.shape, shape![2]);
    let buffer2_data = buffer2.to_data().unwrap();
    assert_eq!(buffer2_data.as_slice::<bool>().unwrap(), &[true, false]);
}
#[test]
fn test_complex_structure() {
    // Deeply nested dicts flatten into multi-segment dotted keys.
    let reader = PytorchReader::new(&test_data_path("complex_structure.pt"))
        .expect("Failed to load complex_structure.pt");
    let keys = reader.keys();
    for name in [
        "state.encoder.layer_0.weight",
        "state.encoder.layer_0.bias",
        "state.encoder.layer_1.weight",
        "state.encoder.layer_1.bias",
        "state.decoder.weight",
        "state.decoder.bias",
    ] {
        assert!(keys.contains(&name.to_string()));
    }
    let layer0_weight = reader.get("state.encoder.layer_0.weight").unwrap();
    assert_eq!(layer0_weight.shape, shape![4, 3]);
    let layer0_data = layer0_weight.to_data().unwrap();
    assert_eq!(layer0_data.as_slice::<f32>().unwrap().len(), 12);
    let decoder_weight = reader.get("state.decoder.weight").unwrap();
    assert_eq!(decoder_weight.shape, shape![3, 2]);
    let decoder_data = decoder_weight.to_data().unwrap();
    assert_eq!(decoder_data.as_slice::<f32>().unwrap().len(), 6);
}
#[test]
fn test_read_pytorch_tensors_convenience() {
    // Round-trips the state-dict fixture through the plain constructor.
    let reader = PytorchReader::new(&test_data_path("state_dict.pt")).expect("Failed to read file");
    let keys = reader.keys();
    assert_eq!(keys.len(), 4);
    assert!(keys.contains(&"weight".to_string()));
    assert!(keys.contains(&"bias".to_string()));
    let weight = reader.get("weight").unwrap();
    let weight_data = weight.to_data().unwrap();
    assert_eq!(weight_data.shape, shape![3, 4]);
    assert_eq!(weight_data.dtype, DType::F32);
}
#[test]
fn test_with_top_level_key() {
    // Selecting "model_state_dict" strips that prefix from every key.
    let reader =
        PytorchReader::with_top_level_key(&test_data_path("checkpoint.pt"), "model_state_dict")
            .expect("Failed to load with top-level key");
    let keys = reader.keys();
    for name in ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"] {
        assert!(keys.contains(&name.to_string()));
    }
    assert!(!keys.contains(&"model_state_dict.fc1.weight".to_string()));
}
#[test]
fn test_legacy_format() {
    // Pre-zip (legacy) serialization format round-trips a small state dict.
    let reader = PytorchReader::new(&test_data_path("simple_legacy.pt"))
        .expect("Failed to load legacy format");
    let keys = reader.keys();
    assert!(keys.contains(&"weight".to_string()), "Missing 'weight' key");
    assert!(keys.contains(&"bias".to_string()), "Missing 'bias' key");
    assert!(
        keys.contains(&"running_mean".to_string()),
        "Missing 'running_mean' key"
    );
    let weight = reader.get("weight").expect("weight not found");
    assert_eq!(weight.shape, shape![2, 3]);
    assert_eq!(weight.dtype, DType::F32);
    let bias = reader.get("bias").expect("bias not found");
    assert_eq!(bias.shape, shape![2]);
    assert_eq!(bias.dtype, DType::F32);
    let bias_data = bias.to_data().unwrap();
    bias_data.assert_approx_eq::<f32>(
        &TensorData::new(vec![1.0_f32, 1.0], vec![2]),
        Tolerance::default(),
    );
    let running_mean = reader.get("running_mean").expect("running_mean not found");
    assert_eq!(running_mean.shape, shape![2]);
    assert_eq!(running_mean.dtype, DType::F32);
    let mean_data = running_mean.to_data().unwrap();
    mean_data.assert_approx_eq::<f32>(
        &TensorData::new(vec![0.0_f32, 0.0], vec![2]),
        Tolerance::default(),
    );
}
#[test]
fn test_legacy_uncloned_views() {
    // Legacy files can store views into a larger shared storage without
    // cloning; each view must still read back only its own slice.
    let path = test_data_path("legacy_uncloned_views.pt");
    let reader = PytorchReader::new(&path).expect("Failed to load legacy format");
    let keys = reader.keys();
    assert!(keys.contains(&"tensor1".to_string()), "Missing 'tensor1' key");
    assert!(keys.contains(&"tensor2".to_string()), "Missing 'tensor2' key");
    let tensor1 = reader.get("tensor1").expect("tensor1 not found");
    assert_eq!(tensor1.shape, shape![10]);
    assert_eq!(tensor1.dtype, DType::F32);
    let tensor1_data = tensor1.to_data().unwrap();
    // `_f32` literals keep the expected data's dtype consistent with the F32
    // tensor under test (the sibling legacy tests all use `_f32` as well).
    let expected_tensor1_data = TensorData::new(
        vec![10.0_f32, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0],
        vec![10],
    );
    tensor1_data.assert_approx_eq::<f32>(&expected_tensor1_data, Tolerance::default());
    let tensor2 = reader.get("tensor2").expect("tensor2 not found");
    assert_eq!(tensor2.shape, shape![10]);
    assert_eq!(tensor2.dtype, DType::F32);
    let tensor2_data = tensor2.to_data().unwrap();
    let expected_tensor2_data = TensorData::new(
        vec![50.0_f32, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0],
        vec![10],
    );
    tensor2_data.assert_approx_eq::<f32>(&expected_tensor2_data, Tolerance::default());
}
#[test]
fn test_legacy_with_offsets() {
    // Legacy file whose tensors start at non-zero offsets within storage.
    let reader = PytorchReader::new(&test_data_path("legacy_with_offsets.pt"))
        .expect("Should read legacy file with offsets");
    assert_eq!(reader.keys().len(), 3, "Should have 3 tensors");
    let tensor1 = reader
        .get("tensor1")
        .expect("Legacy file should contain tensor1");
    assert_eq!(tensor1.shape, shape![10]);
    let data1 = tensor1.to_data().unwrap();
    data1.assert_approx_eq::<f32>(
        &TensorData::new(
            vec![
                1.00_f32, 1.01, 1.02, 1.03, 1.04, 1.05, 1.06, 1.07, 1.08, 1.09,
            ],
            vec![10],
        ),
        Tolerance::default(),
    );
    let tensor2 = reader
        .get("tensor2")
        .expect("Legacy file should contain tensor2");
    assert_eq!(tensor2.shape, shape![5]);
    let data2 = tensor2.to_data().unwrap();
    data2.assert_approx_eq::<f32>(
        &TensorData::new(vec![2.0_f32, 2.1, 2.2, 2.3, 2.4], vec![5]),
        Tolerance::default(),
    );
    let tensor3 = reader
        .get("tensor3")
        .expect("Legacy file should contain tensor3");
    assert_eq!(tensor3.shape, shape![5]);
    let data3 = tensor3.to_data().unwrap();
    data3.assert_approx_eq::<f32>(
        &TensorData::new(vec![3.0_f32, 3.1, 3.2, 3.3, 3.4], vec![5]),
        Tolerance::default(),
    );
}
#[test]
fn test_legacy_shared_storage() {
    // Tensors sharing one legacy storage must each materialize their data.
    let path = test_data_path("legacy_shared_storage.pt");
    let reader = PytorchReader::new(&path).expect("Should read legacy file with shared storage");
    let keys = reader.keys();
    assert!(keys.len() >= 2, "Should have at least 2 tensors");
    for key in &keys {
        // Single lookup instead of `get(..).is_some()` followed by a second
        // `get(..).unwrap()`; the panic message is preserved.
        let tensor = reader
            .get(key)
            .unwrap_or_else(|| panic!("Should have tensor: {}", key));
        let data = tensor.to_data().unwrap();
        match tensor.dtype {
            DType::F32 => {
                let values = data.as_slice::<f32>().unwrap();
                assert!(!values.is_empty(), "Tensor {} should have data", key);
            }
            DType::I64 => {
                let values = data.as_slice::<i64>().unwrap();
                assert!(!values.is_empty(), "Tensor {} should have data", key);
            }
            _ => {
                assert!(!data.shape.is_empty(), "Tensor {} should have shape", key);
            }
        }
    }
}
#[test]
fn test_metadata_zip_format() {
    // Modern .pt files are zip archives with little-endian payloads.
    let reader =
        PytorchReader::new(&test_data_path("float32.pt")).expect("Failed to load float32.pt");
    let meta = reader.metadata();
    assert_eq!(meta.format_type, FileFormat::Zip);
    assert_eq!(meta.byte_order, ByteOrder::LittleEndian);
    assert_eq!(meta.tensor_count, 1);
    assert!(meta.total_data_size.is_some());
    assert!(meta.is_modern_format());
    assert!(!meta.is_legacy_format());
}
#[test]
fn test_metadata_legacy_format() {
    // Legacy files report FileFormat::Legacy with a known tensor count.
    let reader = PytorchReader::new(&test_data_path("simple_legacy.pt"))
        .expect("Failed to load legacy file");
    let meta = reader.metadata();
    assert_eq!(meta.format_type, FileFormat::Legacy);
    assert_eq!(meta.byte_order, ByteOrder::LittleEndian);
    assert_eq!(meta.tensor_count, 3);
    assert!(meta.total_data_size.is_some());
}
#[test]
fn test_legacy_metadata_detailed() {
    // Exhaustive metadata expectations for the legacy (non-zip) format.
    let reader = PytorchReader::new(&test_data_path("simple_legacy.pt"))
        .expect("Failed to load legacy file");
    let metadata = reader.metadata();
    assert_eq!(metadata.format_type, FileFormat::Legacy, "Should be Legacy format");
    assert_eq!(metadata.byte_order, ByteOrder::LittleEndian, "Legacy format is little-endian");
    assert_eq!(
        metadata.tensor_count, 3,
        "Should have 3 tensors: weight, bias, running_mean"
    );
    assert!(metadata.total_data_size.is_some(), "Should have total data size");
    assert!(metadata.total_data_size.unwrap() > 0, "Data size should be positive");
    // Legacy files carry neither a version record nor aligned storage.
    assert_eq!(metadata.format_version, None, "Legacy format doesn't have version file");
    assert_eq!(
        metadata.pytorch_version, None,
        "Legacy format doesn't store PyTorch version reliably"
    );
    assert!(
        !metadata.has_storage_alignment,
        "Legacy format doesn't have storage alignment"
    );
    let keys = reader.keys();
    assert!(keys.contains(&"weight".to_string()), "Should have weight tensor");
    assert!(keys.contains(&"bias".to_string()), "Should have bias tensor");
    assert!(
        keys.contains(&"running_mean".to_string()),
        "Should have running_mean tensor"
    );
}
#[test]
fn test_small_invalid_file() {
    // A truncated/garbage file must produce a descriptive error, not a panic.
    let result = PytorchReader::new(&test_data_path("broken.pt"));
    assert!(result.is_err(), "Expected error for broken file");
    let err_str = format!("{}", result.unwrap_err());
    assert!(
        err_str.contains("Pickle") || err_str.contains("Invalid"),
        "Error should mention pickle or invalid format: {}",
        err_str
    );
}
#[test]
fn test_read_pickle_data_basic() {
    use crate::pytorch::reader::PickleValue;
    // Raw pickle access exposes non-tensor checkpoint fields.
    let path = test_data_path("checkpoint.pt");
    let data = PytorchReader::read_pickle_data(&path, None).expect("Failed to read pickle data");
    let dict = match data {
        PickleValue::Dict(dict) => dict,
        _ => panic!("Expected root to be a dictionary"),
    };
    for key in ["model_state_dict", "optimizer_state_dict", "epoch", "loss"] {
        assert!(dict.contains_key(key));
    }
    match dict.get("epoch") {
        Some(PickleValue::Int(epoch)) => assert_eq!(*epoch, 42),
        _ => panic!("Expected epoch to be an integer"),
    }
    match dict.get("loss") {
        Some(PickleValue::Float(loss)) => {
            assert!(*loss > 0.0 && *loss < 1.0, "Loss should be between 0 and 1")
        }
        _ => panic!("Expected loss to be a float"),
    }
}
#[test]
fn test_read_pickle_data_with_key() {
    use crate::pytorch::reader::PickleValue;
    // Selecting a sub-dict excludes sibling top-level entries.
    let path = test_data_path("checkpoint.pt");
    let data = PytorchReader::read_pickle_data(&path, Some("model_state_dict"))
        .expect("Failed to read pickle data with key");
    let dict = match data {
        PickleValue::Dict(dict) => dict,
        _ => panic!("Expected model_state_dict to be a dictionary"),
    };
    for key in ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"] {
        assert!(dict.contains_key(key));
    }
    assert!(!dict.contains_key("optimizer_state_dict"));
    assert!(!dict.contains_key("epoch"));
}
#[test]
fn test_read_pickle_data_nested_structure() {
    use crate::pytorch::reader::PickleValue;
    // Tensors become None placeholders; submodules stay as nested dicts.
    let data = PytorchReader::read_pickle_data(&test_data_path("nested_dict.pt"), None)
        .expect("Failed to read nested structure");
    let dict = match data {
        PickleValue::Dict(dict) => dict,
        _ => panic!("Expected nested_dict to be a dictionary"),
    };
    assert!(!dict.is_empty(), "Dictionary should not be empty");
    for value in dict.values() {
        assert!(
            matches!(value, PickleValue::None | PickleValue::Dict(_)),
            "Values should be None or nested dicts"
        );
    }
}
#[test]
fn test_read_pickle_data_types() {
    use crate::pytorch::reader::PickleValue;
    // In raw pickle data every tensor is replaced by PickleValue::None.
    let data = PytorchReader::read_pickle_data(&test_data_path("mixed_types.pt"), None)
        .expect("Failed to read mixed types");
    let dict = match data {
        PickleValue::Dict(dict) => dict,
        _ => panic!("Expected mixed_types to be a dictionary"),
    };
    assert!(dict.len() >= 3, "Should have at least 3 tensor types");
    for value in dict.values() {
        assert!(
            matches!(value, PickleValue::None),
            "Tensors should be None in pickle data"
        );
    }
}
#[test]
fn test_read_pickle_data_key_not_found() {
    // Requesting a missing top-level key surfaces a "not found" error.
    let result =
        PytorchReader::read_pickle_data(&test_data_path("checkpoint.pt"), Some("nonexistent_key"));
    assert!(result.is_err());
    let err_str = format!("{}", result.unwrap_err());
    assert!(
        err_str.contains("not found"),
        "Error should mention key not found: {}",
        err_str
    );
}
#[test]
fn test_read_pickle_data_simple_pickle() {
    use crate::pytorch::reader::PickleValue;
    // A plain state dict: tensor values appear as None placeholders.
    let data = PytorchReader::read_pickle_data(&test_data_path("state_dict.pt"), None)
        .expect("Failed to read simple pickle");
    let dict = match data {
        PickleValue::Dict(dict) => dict,
        _ => panic!("Expected state_dict to contain a dictionary"),
    };
    assert!(dict.len() >= 3);
    assert!(dict.contains_key("weight"));
    assert!(dict.contains_key("bias"));
    for value in dict.values() {
        assert!(matches!(value, PickleValue::None));
    }
}
#[test]
fn test_load_config_basic() {
    // Non-tensor checkpoint fields deserialize into a typed config struct.
    #[derive(Debug, serde::Deserialize, PartialEq)]
    struct CheckpointConfig {
        epoch: i64,
        loss: f64,
    }
    let config: CheckpointConfig =
        PytorchReader::load_config(&test_data_path("checkpoint.pt"), None)
            .expect("Failed to load config");
    assert_eq!(config.epoch, 42);
    assert!((config.loss - 0.123).abs() < 1e-6);
}
#[test]
fn test_load_config_with_top_level_key() {
    // A scalar top-level key cannot deserialize into a struct; a nested
    // dict fixture still reads back as a non-empty dictionary.
    #[derive(Debug, serde::Deserialize, PartialEq)]
    struct DummyConfig {
        field: String,
    }
    let result: Result<DummyConfig, _> =
        PytorchReader::load_config(&test_data_path("checkpoint.pt"), Some("epoch"));
    assert!(result.is_err());
    let data = PytorchReader::read_pickle_data(&test_data_path("nested_dict.pt"), None).unwrap();
    match data {
        crate::pytorch::reader::PickleValue::Dict(dict) => assert!(!dict.is_empty()),
        _ => panic!("Expected a dict"),
    }
}
#[test]
fn test_load_config_complex_types() {
    // Only a subset of checkpoint fields needs to be declared in the struct.
    #[derive(Debug, serde::Deserialize, PartialEq)]
    struct PartialCheckpoint {
        epoch: i64,
        loss: f64,
    }
    let config: PartialCheckpoint =
        PytorchReader::load_config(&test_data_path("checkpoint.pt"), None)
            .expect("Failed to load config");
    assert_eq!(config.epoch, 42);
    assert!((config.loss - 0.123).abs() < 1e-6);
}
#[test]
fn test_load_config_key_not_found() {
    // A missing config key yields an error that names the failed lookup.
    #[derive(Debug, serde::Deserialize)]
    struct DummyConfig {
        #[allow(dead_code)]
        field: String,
    }
    let result: Result<DummyConfig, _> =
        PytorchReader::load_config(&test_data_path("checkpoint.pt"), Some("nonexistent"));
    assert!(result.is_err());
    let message = result.unwrap_err().to_string();
    assert!(message.contains("not found") || message.contains("Key"));
}
#[test]
fn test_pickle_value_conversion() {
    use crate::pytorch::reader::PickleValue;
    // Sanity-check the PickleValue variants produced for a checkpoint.
    let data = PytorchReader::read_pickle_data(&test_data_path("checkpoint.pt"), None).unwrap();
    let dict = match data {
        PickleValue::Dict(dict) => dict,
        _ => panic!("Unexpected root type"),
    };
    if let Some(PickleValue::Int(epoch)) = dict.get("epoch") {
        assert!(*epoch >= 0);
    }
    if let Some(PickleValue::Float(loss)) = dict.get("loss") {
        assert!(loss.is_finite());
    }
    if let Some(PickleValue::Dict(model_dict)) = dict.get("model_state_dict") {
        assert!(!model_dict.is_empty());
    }
}
#[test]
fn test_tar_format_detection() {
    // TAR archives and ZIP archives must be told apart at open time.
    let tar_reader =
        PytorchReader::new(&test_data_path("tar_float32.tar")).expect("Failed to load TAR file");
    assert_eq!(tar_reader.metadata().format_type, FileFormat::Tar);
    let zip_reader =
        PytorchReader::new(&test_data_path("float32.pt")).expect("Failed to load ZIP file");
    assert_ne!(zip_reader.metadata().format_type, FileFormat::Tar);
}
#[test]
fn test_tar_float32_tensor() {
    // Same payload as float32.pt, but packaged in the TAR container.
    let reader = PytorchReader::new(&test_data_path("tar_float32.tar"))
        .expect("Failed to load tar_float32.tar");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::F32);
    assert_eq!(tensor.shape, shape![4]);
    let data = tensor.to_data().unwrap();
    let values = data.as_slice::<f32>().unwrap();
    assert_eq!(values.len(), 4);
    for (&got, want) in values.iter().zip([1.0_f32, 2.5, -3.7, 0.0]) {
        assert!((got - want).abs() < 1e-6);
    }
}
#[test]
fn test_tar_float64_tensor() {
    // Same payload as float64.pt, but packaged in the TAR container.
    let reader = PytorchReader::new(&test_data_path("tar_float64.tar"))
        .expect("Failed to load tar_float64.tar");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::F64);
    assert_eq!(tensor.shape, shape![3]);
    let data = tensor.to_data().unwrap();
    let values = data.as_slice::<f64>().unwrap();
    assert_eq!(values.len(), 3);
    for (&got, want) in values.iter().zip([1.1_f64, 2.2, 3.3]) {
        assert!((got - want).abs() < 1e-10);
    }
}
#[test]
fn test_tar_int64_tensor() {
    // Same payload as int64.pt, but packaged in the TAR container.
    let reader = PytorchReader::new(&test_data_path("tar_int64.tar"))
        .expect("Failed to load tar_int64.tar");
    let tensor = reader.get("tensor").expect("tensor key not found");
    assert_eq!(tensor.dtype, DType::I64);
    assert_eq!(tensor.shape, shape![4]);
    let data = tensor.to_data().unwrap();
    assert_eq!(data.as_slice::<i64>().unwrap(), &[100, -200, 300, 0]);
}
#[test]
fn test_tar_multiple_tensors() {
    // Legacy TAR archive holding both a weight matrix and a bias vector.
    let reader = PytorchReader::new(&test_data_path("tar_weight_bias.tar"))
        .expect("Failed to load tar_weight_bias.tar");
    let weight = reader.get("weight").expect("weight key not found");
    assert_eq!(weight.dtype, DType::F32);
    assert_eq!(weight.shape, shape![2, 3]);
    let weight_data = weight.to_data().unwrap();
    let weight_values = weight_data.as_slice::<f32>().unwrap();
    assert_eq!(weight_values.len(), 6);
    for (index, expected) in [(0, 0.1_f32), (1, 0.2), (5, 0.6)] {
        assert!((weight_values[index] - expected).abs() < 1e-6);
    }
    let bias = reader.get("bias").expect("bias key not found");
    assert_eq!(bias.dtype, DType::F32);
    assert_eq!(bias.shape, shape![2]);
    let bias_data = bias.to_data().unwrap();
    let bias_values = bias_data.as_slice::<f32>().unwrap();
    assert_eq!(bias_values.len(), 2);
    for (&got, want) in bias_values.iter().zip([0.01_f32, 0.02]) {
        assert!((got - want).abs() < 1e-6);
    }
}
#[test]
fn test_tar_multi_dtype() {
    // One TAR archive holding f32, f64 and i64 tensors.
    let reader = PytorchReader::new(&test_data_path("tar_multi_dtype.tar"))
        .expect("Failed to load tar_multi_dtype.tar");
    let float_tensor = reader
        .get("float_tensor")
        .expect("float_tensor key not found");
    assert_eq!(float_tensor.dtype, DType::F32);
    let f32_data = float_tensor.to_data().unwrap();
    assert!((f32_data.as_slice::<f32>().unwrap()[0] - 1.5).abs() < 1e-6);
    let double_tensor = reader
        .get("double_tensor")
        .expect("double_tensor key not found");
    assert_eq!(double_tensor.dtype, DType::F64);
    let f64_data = double_tensor.to_data().unwrap();
    assert!((f64_data.as_slice::<f64>().unwrap()[0] - 1.111).abs() < 1e-10);
    let int_tensor = reader.get("int_tensor").expect("int_tensor key not found");
    assert_eq!(int_tensor.dtype, DType::I64);
    let i64_data = int_tensor.to_data().unwrap();
    assert_eq!(i64_data.as_slice::<i64>().unwrap(), &[10, 20, 30, 40]);
}
#[test]
fn test_tar_2d_tensor_shape() {
    // A 3x4 matrix from a TAR archive; values enumerate 1..=12 row-major.
    let path = test_data_path("tar_2d_tensor.tar");
    let reader = PytorchReader::new(&path).expect("Failed to load tar_2d_tensor.tar");
    let matrix = reader.get("matrix").expect("matrix key not found");
    assert_eq!(matrix.dtype, DType::F32);
    assert_eq!(matrix.shape, shape![3, 4]);
    let data = matrix.to_data().unwrap();
    let values = data.as_slice::<f32>().unwrap();
    assert_eq!(values.len(), 12);
    // Iterate instead of indexing `0..12`: no per-access bounds checks and
    // no hard-coded range to keep in sync with the length assertion above.
    for (i, &value) in values.iter().enumerate() {
        assert!((value - (i as f32 + 1.0)).abs() < 1e-6);
    }
}
#[test]
fn test_tar_metadata() {
    // Metadata for a TAR-format file: one little-endian tensor with a size.
    let reader = PytorchReader::new(&test_data_path("tar_float32.tar"))
        .expect("Failed to load tar_float32.tar");
    let meta = reader.metadata();
    assert_eq!(meta.format_type, FileFormat::Tar);
    assert_eq!(meta.byte_order, ByteOrder::LittleEndian);
    assert_eq!(meta.tensor_count, 1);
    assert!(meta.total_data_size.is_some());
}