pub mod binary;
pub mod legacy;
pub use binary::{BinarySerializer, SerializedDType, SerializedDevice};
#[cfg(feature = "compression")]
pub use legacy::{compress_bytes, decompress_bytes};
pub use legacy::{
deserialize_tensor_binary, deserialize_tensor_json, deserialize_tensor_msgpack,
load_checkpoint, load_tensor, save_checkpoint, save_tensor, serialize_tensor_binary,
serialize_tensor_json, serialize_tensor_msgpack, SerializationFormat,
TensorMetadata as LegacyTensorMetadata, MAGIC_NUMBER, SERIALIZATION_VERSION,
};
pub use binary::TensorMetadata;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tensor;
    use std::io::Cursor;

    /// Serializing a tensor and deserializing it back must preserve both the
    /// shape and the element data exactly.
    #[test]
    fn test_roundtrip_serialization() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        // `data` is moved here — no clone needed, it is not used afterwards.
        let original =
            Tensor::from_data(data, &[3, 3]).expect("test: operation should succeed");

        let mut buffer = Vec::new();
        BinarySerializer::serialize(&original, &mut buffer, None)
            .expect("test: serialize should succeed");

        let mut cursor = Cursor::new(buffer);
        let (restored, _): (Tensor<f32>, _) =
            BinarySerializer::deserialize(&mut cursor).expect("test: deserialize should succeed");

        assert_eq!(original.shape().dims(), restored.shape().dims());
        assert_eq!(
            original.as_slice().expect("tensor should be contiguous"),
            restored.as_slice().expect("tensor should be contiguous")
        );
    }

    /// Metadata attached at serialization time (name, requires_grad flag, and
    /// arbitrary string key/value fields) must survive a round trip unchanged,
    /// and `deserialize` must report it as `Some(..)`.
    #[test]
    fn test_metadata_preservation() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        // `data` is moved here — no clone needed, it is not used afterwards.
        let original =
            Tensor::from_data(data, &[2, 2]).expect("test: operation should succeed");

        let mut metadata = TensorMetadata::new();
        metadata.name = Some("weight_matrix".to_string());
        metadata.requires_grad = true;
        metadata.add_field("layer".to_string(), "conv1".to_string());
        metadata.add_field("param_type".to_string(), "weight".to_string());

        let mut buffer = Vec::new();
        BinarySerializer::serialize(&original, &mut buffer, Some(&metadata))
            .expect("test: operation should succeed");

        let mut cursor = Cursor::new(buffer);
        let (restored, meta): (Tensor<f32>, _) =
            BinarySerializer::deserialize(&mut cursor).expect("test: deserialize should succeed");

        assert_eq!(
            original.as_slice().expect("tensor should be contiguous"),
            restored.as_slice().expect("tensor should be contiguous")
        );
        assert!(meta.is_some());
        let meta = meta.expect("test: operation should succeed");
        assert_eq!(meta.name, Some("weight_matrix".to_string()));
        assert!(meta.requires_grad);
        assert_eq!(meta.fields.get("layer"), Some(&"conv1".to_string()));
        assert_eq!(meta.fields.get("param_type"), Some(&"weight".to_string()));
    }

    /// A larger 1-D tensor round-trips correctly, and the serialized buffer is
    /// at least as large as the raw element payload plus header overhead.
    #[test]
    fn test_large_tensor_serialization() {
        let size = 1000;
        let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
        // `data` is moved here — no clone needed, it is not used afterwards.
        let original =
            Tensor::from_data(data, &[size]).expect("test: operation should succeed");

        let mut buffer = Vec::new();
        BinarySerializer::serialize(&original, &mut buffer, None)
            .expect("test: serialize should succeed");

        // The binary format is assumed to add at least 64 bytes of
        // header/metadata on top of the raw f32 payload.
        // NOTE(review): confirm this minimum against the format in `binary`.
        const MIN_HEADER_BYTES: usize = 64;
        let expected_size = std::mem::size_of::<f32>() * size + MIN_HEADER_BYTES;
        assert!(buffer.len() >= expected_size);

        let mut cursor = Cursor::new(buffer);
        let (restored, _): (Tensor<f32>, _) =
            BinarySerializer::deserialize(&mut cursor).expect("test: deserialize should succeed");

        assert_eq!(
            original.as_slice().expect("tensor should be contiguous"),
            restored.as_slice().expect("tensor should be contiguous")
        );
    }
}