use crate::TensorSnapshot;
use crate::burnpack::{
base::{BurnpackHeader, HEADER_SIZE},
reader::BurnpackReader,
writer::BurnpackWriter,
};
use burn_core::module::ParamId;
use burn_tensor::{BoolStore, DType, TensorData, shape};
/// Serializes a writer holding 100 metadata entries (keys built from a
/// 1000-character stem, 10000-character values) and checks that the metadata
/// size recorded in the header is above 1,000,000 and below `u32::MAX`.
#[test]
fn test_maximum_metadata_size() {
    let key_stem = "x".repeat(1000);
    let value = "y".repeat(10000);

    // Thread the builder through every entry rather than rebinding it in a loop.
    let writer = (0..100).fold(BurnpackWriter::new(vec![]), |acc, idx| {
        acc.with_metadata(&format!("{}_{}", key_stem, idx), &value)
    });

    let serialized = writer.to_bytes();
    assert!(serialized.is_ok());
    let bytes = serialized.unwrap();

    // Only the fixed-size prefix is needed to decode the header.
    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
    assert!(header.metadata_size > 1_000_000);
    assert!(header.metadata_size < u32::MAX);
}
/// Writes five F32 tensors whose shapes each contain at least one zero-sized
/// axis (and therefore carry no bytes), then checks that all five names are
/// listed by the reader.
#[test]
fn test_zero_size_tensor_shapes() {
    // Every shape below has zero elements, so the payload is always empty.
    let zero_shapes = [
        vec![0],
        vec![0, 10],
        vec![10, 0],
        vec![0, 0],
        vec![5, 0, 10],
    ];

    let snapshots: Vec<_> = zero_shapes
        .iter()
        .enumerate()
        .map(|(idx, shape)| {
            TensorSnapshot::from_data(
                TensorData::from_bytes_vec(vec![], shape.clone(), DType::F32),
                vec![format!("zero_tensor_{}", idx)],
                vec![],
                ParamId::new(),
            )
        })
        .collect();

    let writer = BurnpackWriter::new(snapshots);
    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    assert_eq!(reader.tensor_names().len(), 5);
}
/// Round-trips a single tensor whose name is 10000 characters long and checks
/// that the name reported by the reader has the same length.
#[test]
fn test_extremely_long_tensor_names() {
    const NAME_LEN: usize = 10000;

    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),
        vec!["a".repeat(NAME_LEN)],
        vec![],
        ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![snapshot]);
    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    let stored_names = reader.tensor_names();
    assert_eq!(stored_names[0].len(), NAME_LEN);
}
/// Round-trips tensor names and metadata containing non-ASCII text (CJK,
/// Cyrillic, Greek, Hangul, Hebrew, emoji).
///
/// Every tensor name and every metadata pair that is written is also read
/// back and compared, so a name that is dropped or altered — or a lost
/// right-to-left metadata entry — fails the test instead of going unnoticed
/// behind a count-only check.
#[test]
fn test_unicode_in_names_and_metadata() {
    let unicode_names = [
        "测试_tensor",
        "тест_tensor",
        "テスト_tensor",
        "🔥_burn_tensor",
        "αβγδ_tensor",
        "한글_tensor",
    ];
    let unicode_metadata = [("模型名称", "测试模型"), ("מודל", "בדיקה"), ("🔥", "fire")];

    let snapshots: Vec<_> = unicode_names
        .iter()
        .map(|name| {
            TensorSnapshot::from_data(
                TensorData::from_bytes_vec(vec![1], vec![1], DType::U8),
                vec![name.to_string()],
                vec![],
                ParamId::new(),
            )
        })
        .collect();

    let writer = unicode_metadata
        .iter()
        .fold(BurnpackWriter::new(snapshots), |acc, &(key, value)| {
            acc.with_metadata(key, value)
        });
    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    // Same count plus membership of every expected name; no assumption is
    // made about the order in which the reader lists names.
    let names = reader.tensor_names();
    assert_eq!(names.len(), unicode_names.len());
    for expected in &unicode_names {
        assert!(
            names.contains(expected),
            "tensor name {:?} missing after round-trip",
            expected
        );
    }

    // All metadata entries are verified, including the Hebrew pair.
    let metadata = &reader.metadata().metadata;
    for &(key, value) in &unicode_metadata {
        assert_eq!(
            metadata.get(key).map(String::as_str),
            Some(value),
            "metadata entry {:?} did not round-trip",
            key
        );
    }
}
/// Writes one two-element tensor per listed dtype, filled with that type's
/// extreme values encoded little-endian (or `[0, 1]` for bool), and checks
/// that the reader reports the same dtype for each tensor.
#[test]
fn test_all_supported_dtypes() {
    /// Joins two fixed-width byte encodings into one contiguous buffer.
    fn le_pair<const N: usize>(low: [u8; N], high: [u8; N]) -> Vec<u8> {
        low.iter().chain(high.iter()).copied().collect()
    }

    let cases = [
        (DType::F32, le_pair(f32::MIN.to_le_bytes(), f32::MAX.to_le_bytes())),
        (DType::F64, le_pair(f64::MIN.to_le_bytes(), f64::MAX.to_le_bytes())),
        (DType::I32, le_pair(i32::MIN.to_le_bytes(), i32::MAX.to_le_bytes())),
        (DType::I64, le_pair(i64::MIN.to_le_bytes(), i64::MAX.to_le_bytes())),
        (DType::U32, le_pair(u32::MIN.to_le_bytes(), u32::MAX.to_le_bytes())),
        (DType::U64, le_pair(u64::MIN.to_le_bytes(), u64::MAX.to_le_bytes())),
        (DType::U8, vec![u8::MIN, u8::MAX]),
        (DType::Bool(BoolStore::Native), vec![0, 1]),
    ];

    // Tensor names embed the case index so they can be looked up again below.
    let snapshots: Vec<_> = cases
        .iter()
        .enumerate()
        .map(|(idx, (dtype, payload))| {
            TensorSnapshot::from_data(
                TensorData::from_bytes_vec(payload.clone(), vec![2], *dtype),
                vec![format!("dtype_test_{}", idx)],
                vec![],
                ParamId::new(),
            )
        })
        .collect();

    let writer = BurnpackWriter::new(snapshots);
    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    assert_eq!(reader.tensor_names().len(), cases.len());

    for (idx, (expected_dtype, _)) in cases.iter().enumerate() {
        let tensor_name = format!("dtype_test_{}", idx);
        let loaded = reader.get_tensor_snapshot(&tensor_name).unwrap();
        assert_eq!(loaded.dtype, *expected_dtype);
    }
}
/// Round-trips an F32 tensor holding NaN, both infinities and both signed
/// zeros, and checks that the bytes read back equal the bytes written.
#[test]
fn test_special_float_values() {
    let specials = [
        f32::NAN,
        f32::INFINITY,
        f32::NEG_INFINITY,
        0.0_f32,
        -0.0_f32,
    ];

    // Lay the values out back to back as little-endian bytes.
    let mut raw: Vec<u8> = Vec::new();
    for value in &specials {
        raw.extend_from_slice(&value.to_le_bytes());
    }

    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(raw.clone(), vec![5], DType::F32),
        vec!["special_floats".to_string()],
        vec![],
        ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![snapshot]);
    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    // The comparison is against the byte buffer, not against float values.
    let loaded = reader.get_tensor_data("special_floats").unwrap();
    assert_eq!(loaded, raw);
}
/// Round-trips metadata containing an empty value, an empty key and an
/// ordinary pair, and checks that each entry is read back unchanged.
#[test]
fn test_metadata_with_empty_values() {
    let entries = [("empty_value", ""), ("", "empty_key"), ("normal", "value")];

    let writer = entries
        .iter()
        .fold(BurnpackWriter::new(vec![]), |acc, &(key, value)| {
            acc.with_metadata(key, value)
        });
    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    let stored = &reader.metadata().metadata;
    for &(key, value) in &entries {
        assert_eq!(stored.get(key), Some(&value.to_string()));
    }
}
/// Round-trips a U8 tensor of shape `[1]` holding the single byte 42 and
/// checks that the same byte is read back.
#[test]
fn test_single_byte_tensor() {
    let payload: Vec<u8> = vec![42];

    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(payload.clone(), vec![1], DType::U8),
        vec!["single_byte".to_string()],
        vec![],
        ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![snapshot]);
    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    let loaded = reader.get_tensor_data("single_byte").unwrap();
    assert_eq!(loaded, payload);
}
/// Round-trips a U8 tensor with ten axes of extent 2 (1024 elements) and
/// checks that the shape reported by the reader matches the one written.
#[test]
fn test_high_dimensional_tensor() {
    let expected_shape = shape![2, 2, 2, 2, 2, 2, 2, 2, 2, 2];
    // 2^10 elements, one byte each.
    let payload = vec![1u8; 1024];

    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(payload, expected_shape.clone(), DType::U8),
        vec!["high_dim".to_string()],
        vec![],
        ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![snapshot]);
    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    let loaded = reader.get_tensor_snapshot("high_dim").unwrap();
    assert_eq!(loaded.shape, expected_shape);
}
/// Sets the same metadata key three times in a row and checks that the value
/// read back is the one that was set last.
#[test]
fn test_metadata_key_collision() {
    // Applied in order, so "value3" is the final assignment to "key".
    let writer = ["value1", "value2", "value3"]
        .iter()
        .fold(BurnpackWriter::new(vec![]), |acc, &value| {
            acc.with_metadata("key", value)
        });

    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    let stored = reader.metadata().metadata.get("key");
    assert_eq!(stored, Some(&"value3".to_string()));
}
/// Round-trips tensors whose names contain `/`, `\`, `::` and `.` and checks
/// that each name is present, unchanged, in the reader's name list.
#[test]
fn test_tensor_name_with_path_separators() {
    let separator_names = [
        "model/encoder/layer1/weights",
        "model\\decoder\\layer1\\bias",
        "model::module::param",
        "model.submodule.weight",
    ];

    let snapshots: Vec<_> = separator_names
        .iter()
        .map(|name| {
            TensorSnapshot::from_data(
                TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),
                vec![name.to_string()],
                vec![],
                ParamId::new(),
            )
        })
        .collect();

    let writer = BurnpackWriter::new(snapshots);
    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    let stored_names = reader.tensor_names();
    separator_names
        .iter()
        .for_each(|expected| assert!(stored_names.contains(expected)));
}