use super::common::{SerializationFormat, SerializationOptions, TensorMetadata};
use crate::{Tensor, TensorElement};
use std::io::{Read, Write};
use torsh_core::error::{Result, TorshError};
/// Magic bytes identifying the torsh binary tensor format ("TRSH").
const TORSH_MAGIC: &[u8] = b"TRSH";
/// On-disk format version; `BinaryHeader::read_from` rejects any other value.
const FORMAT_VERSION: u32 = 1;
/// Table-driven CRC-32 (IEEE 802.3 polynomial, reflected form 0xEDB88320).
struct Crc32 {
    /// Precomputed lookup table: `table[b]` is the CRC contribution of byte `b`.
    table: [u32; 256],
}

impl Crc32 {
    /// Builds the 256-entry lookup table by running the bitwise CRC for
    /// every possible byte value.
    fn new() -> Self {
        let mut table = [0u32; 256];
        for (byte, entry) in table.iter_mut().enumerate() {
            *entry = (0..8).fold(byte as u32, |crc, _| {
                if crc & 1 != 0 {
                    (crc >> 1) ^ 0xEDB88320
                } else {
                    crc >> 1
                }
            });
        }
        Self { table }
    }

    /// Computes the CRC-32 checksum of `data` (init 0xFFFFFFFF, final XOR).
    fn checksum(&self, data: &[u8]) -> u32 {
        let crc = data.iter().fold(0xFFFF_FFFF_u32, |crc, &byte| {
            (crc >> 8) ^ self.table[((crc ^ byte as u32) & 0xFF) as usize]
        });
        !crc
    }
}
fn calculate_crc32(data: &[u8]) -> u32 {
let crc = Crc32::new();
crc.checksum(data)
}
/// Fixed-size (28-byte, see `size()`) header that prefixes every serialized
/// tensor. Multi-byte fields are written little-endian by `write_to` and read
/// back by `read_from`.
#[derive(Debug)]
struct BinaryHeader {
    /// Format identifier; must equal `TORSH_MAGIC` (`b"TRSH"`).
    magic: [u8; 4],
    /// Format version; must equal `FORMAT_VERSION`.
    version: u32,
    /// Byte length of the serialized metadata section following the header.
    metadata_size: u64,
    /// Byte length of the (possibly compressed) tensor data section.
    data_size: u64,
    /// CRC32 over the metadata bytes followed by the data bytes.
    checksum: u32,
}
impl BinaryHeader {
    /// Creates a header whose checksum is not yet known (left as zero).
    /// Delegates to `with_checksum` instead of duplicating its body.
    #[allow(dead_code)]
    fn new(metadata_size: u64, data_size: u64) -> Self {
        Self::with_checksum(metadata_size, data_size, 0)
    }

    /// Creates a fully populated header with the current magic and version.
    fn with_checksum(metadata_size: u64, data_size: u64, checksum: u32) -> Self {
        Self {
            magic: *b"TRSH",
            version: FORMAT_VERSION,
            metadata_size,
            data_size,
            checksum,
        }
    }

    /// Writes one raw field, wrapping any I/O failure in a descriptive error.
    /// `what` names the field in the error message (e.g. "magic bytes").
    fn write_field<W: Write>(writer: &mut W, bytes: &[u8], what: &str) -> Result<()> {
        writer.write_all(bytes).map_err(|e| {
            TorshError::SerializationError(format!("Failed to write {}: {}", what, e))
        })
    }

    /// Reads exactly `N` bytes, wrapping any I/O failure in a descriptive error.
    fn read_array<R: Read, const N: usize>(reader: &mut R, what: &str) -> Result<[u8; N]> {
        let mut buf = [0u8; N];
        reader.read_exact(&mut buf).map_err(|e| {
            TorshError::SerializationError(format!("Failed to read {}: {}", what, e))
        })?;
        Ok(buf)
    }

    /// Serializes the header: magic bytes, then version, metadata size,
    /// data size, and checksum, all little-endian.
    fn write_to<W: Write>(&self, writer: &mut W) -> Result<()> {
        Self::write_field(writer, &self.magic, "magic bytes")?;
        Self::write_field(writer, &self.version.to_le_bytes(), "version")?;
        Self::write_field(writer, &self.metadata_size.to_le_bytes(), "metadata size")?;
        Self::write_field(writer, &self.data_size.to_le_bytes(), "data size")?;
        Self::write_field(writer, &self.checksum.to_le_bytes(), "checksum")?;
        Ok(())
    }

    /// Parses a header from `reader`, rejecting unknown magic bytes and
    /// unsupported format versions.
    fn read_from<R: Read>(reader: &mut R) -> Result<Self> {
        let magic: [u8; 4] = Self::read_array(reader, "magic bytes")?;
        if &magic != TORSH_MAGIC {
            return Err(TorshError::SerializationError(format!(
                "Invalid magic bytes in binary format: expected {:?}, got {:?}",
                TORSH_MAGIC, magic
            )));
        }
        let version = u32::from_le_bytes(Self::read_array(reader, "version")?);
        if version != FORMAT_VERSION {
            return Err(TorshError::SerializationError(format!(
                "Unsupported format version: expected {}, got {}",
                FORMAT_VERSION, version
            )));
        }
        let metadata_size = u64::from_le_bytes(Self::read_array(reader, "metadata size")?);
        let data_size = u64::from_le_bytes(Self::read_array(reader, "data size")?);
        let checksum = u32::from_le_bytes(Self::read_array(reader, "checksum")?);
        Ok(Self {
            magic,
            version,
            metadata_size,
            data_size,
            checksum,
        })
    }

    /// Encoded header size in bytes:
    /// 4 (magic) + 4 (version) + 8 (metadata size) + 8 (data size) + 4 (checksum) = 28.
    const fn size() -> usize {
        4 + 4 + 8 + 8 + 4
    }
}
/// Serializes `tensor` to `writer` in the TRSH binary format:
/// a `BinaryHeader`, the encoded `TensorMetadata`, then the raw (or
/// zstd-compressed, if `options.compression_level > 0`) tensor bytes.
///
/// # Errors
/// Returns a `SerializationError` if the `serialize` feature is disabled, or
/// on any encoding, compression, or I/O failure.
pub fn serialize_binary<T: TensorElement, W: Write>(
    tensor: &Tensor<T>,
    writer: &mut W,
    options: &SerializationOptions,
) -> Result<()> {
    #[cfg(not(feature = "serialize"))]
    {
        // Suppress unused-parameter warnings when the feature is disabled.
        let _ = (tensor, writer, options);
        return Err(TorshError::SerializationError(
            "Serialization feature not enabled".to_string(),
        ));
    }

    #[cfg(feature = "serialize")]
    {
        let data_size = tensor.numel() * std::mem::size_of::<T>();
        let mut metadata =
            TensorMetadata::from_tensor(tensor, options, SerializationFormat::Binary, data_size);

        let data = tensor.data()?;
        // SAFETY: `data` is a contiguous slice of `T` that stays alive for the
        // whole borrow; viewing it as bytes is sound because u8 has no
        // alignment or validity requirements.
        let data_bytes = unsafe {
            std::slice::from_raw_parts(
                data.as_ptr() as *const u8,
                data.len() * std::mem::size_of::<T>(),
            )
        };

        // Optionally compress the payload. The level is clamped to zstd's
        // valid 1..=22 range.
        let (final_data_bytes, compressed) = if options.compression_level > 0 {
            let zstd_level = (options.compression_level as i32).clamp(1, 22);
            let compressed_bytes =
                oxiarc_zstd::compress_with_level(data_bytes, zstd_level).map_err(|e| {
                    TorshError::SerializationError(format!(
                        "Failed to compress tensor data: {}",
                        e
                    ))
                })?;
            (compressed_bytes, true)
        } else {
            (data_bytes.to_vec(), false)
        };

        // Record the on-disk representation in the metadata, then encode the
        // metadata exactly once. (The original encoded it twice, discarding
        // the first result, and assigned `metadata.checksum` after the final
        // encode — a dead store; the checksum lives in the header.)
        metadata.compressed = compressed;
        metadata.data_size = final_data_bytes.len();
        let final_metadata_bytes =
            oxicode::serde::encode_to_vec(&metadata, oxicode::config::standard()).map_err(
                |e| TorshError::SerializationError(format!("Failed to serialize metadata: {}", e)),
            )?;

        // CRC32 over metadata bytes followed by payload bytes — the same
        // layout `deserialize_binary` re-checksums on read.
        let mut combined_data =
            Vec::with_capacity(final_metadata_bytes.len() + final_data_bytes.len());
        combined_data.extend_from_slice(&final_metadata_bytes);
        combined_data.extend_from_slice(&final_data_bytes);
        let checksum = calculate_crc32(&combined_data);

        let header = BinaryHeader::with_checksum(
            final_metadata_bytes.len() as u64,
            final_data_bytes.len() as u64,
            checksum,
        );
        header.write_to(writer)?;
        writer.write_all(&final_metadata_bytes).map_err(|e| {
            TorshError::SerializationError(format!("Failed to write metadata: {}", e))
        })?;
        writer.write_all(&final_data_bytes).map_err(|e| {
            TorshError::SerializationError(format!("Failed to write tensor data: {}", e))
        })?;
        Ok(())
    }
}
pub fn deserialize_binary<T: TensorElement, R: Read>(reader: &mut R) -> Result<Tensor<T>> {
let header = BinaryHeader::read_from(reader)?;
if header.metadata_size == 0 {
return Err(TorshError::SerializationError(
"Invalid header: metadata size cannot be zero".to_string(),
));
}
if header.data_size == 0 {
return Err(TorshError::SerializationError(
"Invalid header: data size cannot be zero".to_string(),
));
}
let mut metadata_bytes = vec![0u8; header.metadata_size as usize];
reader
.read_exact(&mut metadata_bytes)
.map_err(|e| TorshError::SerializationError(format!("Failed to read metadata: {}", e)))?;
#[cfg(feature = "serialize")]
let (metadata, _): (TensorMetadata, usize) =
oxicode::serde::decode_from_slice(&metadata_bytes, oxicode::config::standard()).map_err(
|e| TorshError::SerializationError(format!("Failed to deserialize metadata: {}", e)),
)?;
#[cfg(not(feature = "serialize"))]
let metadata = {
return Err(TorshError::SerializationError(
"Serialization feature not enabled".to_string(),
));
};
metadata
.validate()
.map_err(|e| TorshError::SerializationError(format!("Invalid metadata: {}", e)))?;
let mut data_bytes = vec![0u8; header.data_size as usize];
reader.read_exact(&mut data_bytes).map_err(|e| {
TorshError::SerializationError(format!("Failed to read tensor data: {}", e))
})?;
let mut combined_data = Vec::new();
combined_data.extend_from_slice(&metadata_bytes);
combined_data.extend_from_slice(&data_bytes);
let calculated_checksum = calculate_crc32(&combined_data);
if calculated_checksum != header.checksum {
return Err(TorshError::SerializationError(format!(
"Data corruption detected: checksum mismatch (expected 0x{:08X}, got 0x{:08X})",
header.checksum, calculated_checksum
)));
}
let final_data_bytes = if metadata.compressed {
#[cfg(feature = "serialize")]
{
oxiarc_zstd::decompress(&data_bytes).map_err(|e| {
TorshError::SerializationError(format!("Failed to decompress tensor data: {}", e))
})?
}
#[cfg(not(feature = "serialize"))]
{
return Err(TorshError::SerializationError(
"Cannot decompress: serialization feature not enabled".to_string(),
));
}
} else {
data_bytes
};
let expected_len = metadata.shape.numel();
let actual_len = final_data_bytes.len() / std::mem::size_of::<T>();
if expected_len != actual_len {
return Err(TorshError::SerializationError(format!(
"Data size mismatch: expected {} elements, got {} (shape: {:?}, element size: {} bytes)",
expected_len, actual_len, metadata.shape.dims(), std::mem::size_of::<T>()
)));
}
let mut typed_data = Vec::with_capacity(actual_len);
let byte_ptr = final_data_bytes.as_ptr();
for i in 0..actual_len {
unsafe {
let element_ptr = byte_ptr.add(i * std::mem::size_of::<T>()) as *const T;
typed_data.push(std::ptr::read(element_ptr));
}
}
Tensor::from_data(typed_data, metadata.shape.dims().to_vec(), metadata.device)
}
/// Estimates the serialized size in bytes of `tensor` under `options`,
/// without actually serializing it. The result is a heuristic, not a bound.
pub fn estimate_binary_size<T: TensorElement>(
    tensor: &Tensor<T>,
    options: &SerializationOptions,
) -> usize {
    let header_size = BinaryHeader::size();
    // Rough allowance for the encoded TensorMetadata blob; the real size
    // depends on shape rank and optional fields.
    let metadata_size = 200;
    let data_size = tensor.numel() * std::mem::size_of::<T>();
    let compressed_data_size = if options.compression_level > 0 {
        // Heuristic zstd ratios by level. `serialize_binary` clamps the level
        // to 1..=22, so levels 10..=22 are valid and get the strongest ratio
        // (the original match fell through to 1.0 — "no compression" — for
        // them, overestimating their output size).
        let compression_ratio = match options.compression_level {
            1..=3 => 0.8,
            4..=6 => 0.6,
            7..=9 => 0.4,
            _ => 0.3,
        };
        (data_size as f64 * compression_ratio) as usize
    } else {
        data_size
    };
    header_size + metadata_size + compressed_data_size
}
pub fn validate_binary_format<R: Read>(reader: &mut R) -> Result<TensorMetadata> {
let header = BinaryHeader::read_from(reader)?;
let mut metadata_bytes = vec![0u8; header.metadata_size as usize];
reader.read_exact(&mut metadata_bytes).map_err(|e| {
TorshError::SerializationError(format!("Failed to read metadata for validation: {}", e))
})?;
#[cfg(feature = "serialize")]
let (metadata, _): (TensorMetadata, usize) =
oxicode::serde::decode_from_slice(&metadata_bytes, oxicode::config::standard()).map_err(
|e| {
TorshError::SerializationError(format!(
"Failed to deserialize metadata for validation: {}",
e
))
},
)?;
#[cfg(not(feature = "serialize"))]
return Err(TorshError::SerializationError(
"Serialization feature not enabled".to_string(),
));
metadata.validate().map_err(|e| {
TorshError::SerializationError(format!("Invalid metadata during validation: {}", e))
})?;
Ok(metadata)
}