import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
# Output directory for all fixtures: a "test_data" folder that sits next to
# this script. It is created on first run and reused as-is afterwards
# (exist_ok=True), so files from earlier runs are not removed.
test_dir = Path(__file__).with_name("test_data")
test_dir.mkdir(exist_ok=True)
def save_test_file(filename, data, description):
    """Serialize *data* into ``test_dir`` and log a one-line summary.

    Args:
        filename: Name of the file to create inside ``test_dir``.
        data: Any object ``torch.save`` can pickle (tensors, nested dicts,
            plain Python scalars, ...).
        description: Human-readable summary printed after the file name.

    Returns:
        The ``Path`` that *data* was written to.
    """
    target = test_dir.joinpath(filename)
    torch.save(data, target)
    print(f"✓ {filename}: {description}")
    return target
print("\n=== Generating Basic Tensor Tests ===")
# One 1-D fixture per dtype: (file name, description label, dtype, values).
# The description is built from the same `values` list that feeds the tensor,
# so the printed summary can never drift from the data actually saved.
_BASIC_DTYPE_CASES = [
    ("float32.pt", "Float32", torch.float32, [1.0, 2.5, -3.7, 0.0]),
    ("float64.pt", "Float64", torch.float64, [1.1, 2.2, 3.3]),
    ("int64.pt", "Int64", torch.int64, [100, -200, 300, 0]),
    ("int32.pt", "Int32", torch.int32, [10, 20, -30]),
    ("int16.pt", "Int16", torch.int16, [1000, -2000, 3000]),
    ("int8.pt", "Int8", torch.int8, [127, -128, 0, 50]),  # int8 min/max
    ("bool.pt", "Bool", torch.bool, [True, False, True, True, False]),
    ("float16.pt", "Float16", torch.float16, [1.5, -2.25, 3.125]),
    ("bfloat16.pt", "BFloat16", torch.bfloat16, [1.5, -2.5, 3.5]),
    ("uint8.pt", "UInt8", torch.uint8, [0, 128, 255, 42]),  # uint8 min/max
]
for _filename, _label, _dtype, _values in _BASIC_DTYPE_CASES:
    save_test_file(
        _filename,
        {"tensor": torch.tensor(_values, dtype=_dtype)},
        # str(list) renders exactly like the original hand-written text,
        # e.g. "Float32 tensor [1.0, 2.5, -3.7, 0.0]".
        f"{_label} tensor {_values}",
    )
print("\n=== Generating Multi-dimensional Tensor Tests ===")
# Deterministic 2-D fixture: three rows of two float32 values.
tensor_2d = torch.tensor(
    [
        [1.0, 2.0],
        [3.0, 4.0],
        [5.0, 6.0],
    ],
    dtype=torch.float32,
)
save_test_file("tensor_2d.pt", {"tensor": tensor_2d}, "2D tensor shape (3, 2)")

# Seed once so every random fixture from here on is reproducible.
# NOTE: all later torch.randn calls in this script draw from this single
# generator stream, so their values depend on the order of those calls.
torch.manual_seed(42)

tensor_3d = 10 * torch.randn(2, 3, 4)  # scaled up to get larger magnitudes
save_test_file("tensor_3d.pt", {"tensor": tensor_3d}, "3D tensor shape (2, 3, 4)")

tensor_4d = torch.randn(2, 3, 2, 2)
save_test_file("tensor_4d.pt", {"tensor": tensor_4d}, "4D tensor shape (2, 3, 2, 2)")
print("\n=== Generating State Dict Tests ===")
# Flat mapping of parameter/buffer names to tensors. Keyword arguments are
# evaluated left to right, so the random draws happen in the listed order.
state_dict = dict(
    weight=torch.randn(3, 4),
    bias=torch.randn(3),
    running_mean=torch.zeros(3),
    running_var=torch.ones(3),
)
save_test_file("state_dict.pt", state_dict, "State dict with 4 tensors")

# Two levels of nesting: layer name -> {"weight", "bias"}.
nested_dict = dict(
    layer1=dict(weight=torch.randn(2, 3), bias=torch.randn(2)),
    layer2=dict(weight=torch.randn(4, 2), bias=torch.randn(4)),
)
save_test_file("nested_dict.pt", nested_dict, "Nested state dict")
print("\n=== Generating Model Checkpoint Tests ===")
# Assemble a training-checkpoint-style dict one entry at a time. The key
# insertion order and the order of the random draws match a single literal.
checkpoint = {}
checkpoint["model_state_dict"] = {
    "fc1.weight": torch.randn(10, 5),
    "fc1.bias": torch.randn(10),
    "fc2.weight": torch.randn(3, 10),
    "fc2.bias": torch.randn(3),
}
# NOTE(review): real torch.optim state dicts key "state" by integer parameter
# index and also carry "param_groups"; this fixture uses the string key "0"
# and omits param_groups. Confirm that is intentional for the reader tests.
checkpoint["optimizer_state_dict"] = {
    "state": {"0": {"momentum_buffer": torch.randn(10, 5)}},
}
# Plain Python scalars stored next to the tensors.
checkpoint["epoch"] = 42
checkpoint["loss"] = 0.123
save_test_file("checkpoint.pt", checkpoint, "Full checkpoint with model and optimizer state")
print("\n=== Generating Edge Case Tests ===")
# Zero-element 1-D tensor.
empty_tensor = torch.zeros(0)
save_test_file("empty.pt", {"tensor": empty_tensor}, "Empty tensor")

# 0-dim tensor built from a Python scalar.
scalar_tensor = torch.tensor(42.0)
save_test_file("scalar.pt", {"tensor": scalar_tensor}, "Scalar tensor (0-dim)")

# 100x100 zeros with three non-zero entries on the diagonal at (0, 0),
# (50, 50) and (99, 99). It is dense storage, just mostly zeros.
sparse_like = torch.zeros(100, 100)
sparse_like[[0, 50, 99], [0, 50, 99]] = torch.tensor([1.0, 2.0, 3.0])
save_test_file("large_shape.pt", {"tensor": sparse_like}, "Large shape (100, 100) mostly zeros")
print("\n=== Generating Mixed Type Tests ===")
# A single dict holding tensors of four different dtypes, in this key order.
mixed_types = {
    key: torch.tensor(values, dtype=dtype)
    for key, dtype, values in (
        ("float32", torch.float32, [1.0, 2.0]),
        ("int64", torch.int64, [100, 200]),
        ("bool", torch.bool, [True, False]),
        ("float64", torch.float64, [1.1, 2.2]),
    )
}
save_test_file("mixed_types.pt", mixed_types, "Dict with mixed tensor types")
print("\n=== Generating Special Value Tests ===")
# NaN and both infinities, followed by two ordinary finite values
# (default dtype, i.e. float32 unless the global default was changed).
special_values = torch.tensor(
    [float("nan"), float("inf"), float("-inf"), 0.0, 1.0]
)
save_test_file(
    "special_values.pt",
    {"tensor": special_values},
    "Tensor with NaN and Inf",
)

# Very small and very large magnitudes of both signs. They are still finite
# in float32 (max ~3.4e38, smallest normal ~1.2e-38).
extreme_values = torch.tensor(
    [1e-30, 1e30, -1e-30, -1e30],
    dtype=torch.float32,
)
save_test_file(
    "extreme_values.pt",
    {"tensor": extreme_values},
    "Tensor with extreme values",
)
print("\n=== Generating Parameter Tests ===")
# nn.Parameter is a torch.Tensor subclass. Saving one presumably lets the
# reader tests check how that subclass is handled inside a pickled dict.
param = nn.Parameter(torch.randn(3, 3))
param_dict = {"param": param}
save_test_file("parameter.pt", param_dict, "nn.Parameter wrapped tensor")
print("\n=== Generating Buffer Tests ===")
# Non-float tensors (int32 and bool), saved as a flat name -> tensor dict.
buffers = dict(
    buffer1=torch.tensor([1, 2, 3], dtype=torch.int32),
    buffer2=torch.tensor([True, False], dtype=torch.bool),
)
save_test_file("buffers.pt", buffers, "Model buffers")
print("\n=== Generating Complex Structure Tests ===")


def _linear_params(out_features, in_features):
    """Return a new {"weight", "bias"} dict of random tensors.

    The weight has shape (out_features, in_features) and the bias has shape
    (out_features,). The weight is drawn before the bias.
    """
    return {
        "weight": torch.randn(out_features, in_features),
        "bias": torch.randn(out_features),
    }


# Mixed nesting: plain metadata/config dicts around a two-level tensor tree.
# Dict literal values are evaluated top to bottom, so the random draws happen
# in the order layer_0, layer_1, decoder.
complex_structure = {
    "metadata": {"version": 1, "name": "test_model"},
    "state": {
        "encoder": {
            "layer_0": _linear_params(4, 3),
            "layer_1": _linear_params(2, 4),
        },
        "decoder": _linear_params(3, 2),
    },
    "config": {"hidden_size": 4, "num_layers": 2},
}
save_test_file("complex_structure.pt", complex_structure, "Complex nested structure")
# NOTE: this counts every *.pt file currently in test_dir. Files left over
# from earlier runs (the directory is reused, never cleared) are included.
_pt_count = sum(1 for _ in test_dir.glob("*.pt"))
print(f"\n✅ Generated {_pt_count} test files in {test_dir}")

_SUMMARY_LINES = (
    "\nTest files can be used to verify PyTorch reader functionality:",
    "- Different data types (float32, int64, bool, etc.)",
    "- Multi-dimensional tensors",
    "- State dicts and nested structures",
    "- Edge cases (empty, scalar, special values)",
    "- Model checkpoints and parameters",
)
for _summary_line in _SUMMARY_LINES:
    print(_summary_line)