use super::common::{SerializationFormat, SerializationOptions, TensorMetadata};
use crate::{Tensor, TensorElement};
use std::io::{Read, Write};
use torsh_core::{
device::DeviceType,
error::{Result, TorshError},
};
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
struct SerializableTensor<T> {
metadata: TensorMetadata,
data: Vec<T>,
}
#[cfg(feature = "serialize")]
pub fn serialize_json<T: TensorElement + serde::Serialize, W: Write>(
tensor: &Tensor<T>,
writer: &mut W,
options: &SerializationOptions,
) -> Result<()> {
let data = tensor.data()?.clone();
let metadata = TensorMetadata::from_tensor(
tensor,
options,
SerializationFormat::Json,
data.len() * std::mem::size_of::<T>(),
);
let serializable = SerializableTensor { metadata, data };
let json_data = if options.compression_level == 0 {
serde_json::to_vec_pretty(&serializable)
} else {
serde_json::to_vec(&serializable)
}
.map_err(|e| TorshError::SerializationError(format!("JSON serialization failed: {}", e)))?;
writer
.write_all(&json_data)
.map_err(|e| TorshError::SerializationError(format!("Failed to write JSON data: {}", e)))?;
Ok(())
}
#[cfg(not(feature = "serialize"))]
pub fn serialize_json<T: TensorElement, W: Write>(
_tensor: &Tensor<T>,
_writer: &mut W,
_options: &SerializationOptions,
) -> Result<()> {
Err(TorshError::SerializationError(
"JSON serialization requires the 'serialize' feature to be enabled".to_string(),
))
}
#[cfg(feature = "serialize")]
pub fn deserialize_json<T: TensorElement + for<'a> serde::Deserialize<'a>, R: Read>(
reader: &mut R,
) -> Result<Tensor<T>> {
let mut json_data = Vec::new();
reader
.read_to_end(&mut json_data)
.map_err(|e| TorshError::SerializationError(format!("Failed to read JSON data: {}", e)))?;
let serializable: SerializableTensor<T> = serde_json::from_slice(&json_data).map_err(|e| {
TorshError::SerializationError(format!("JSON deserialization failed: {}", e))
})?;
serializable
.metadata
.validate()
.map_err(|e| TorshError::SerializationError(format!("Invalid metadata in JSON: {}", e)))?;
Tensor::from_data(
serializable.data,
serializable.metadata.shape.dims().to_vec(),
serializable.metadata.device,
)
}
#[cfg(not(feature = "serialize"))]
pub fn deserialize_json<T: TensorElement, R: Read>(_reader: &mut R) -> Result<Tensor<T>> {
Err(TorshError::SerializationError(
"JSON deserialization requires the 'serialize' feature to be enabled".to_string(),
))
}
pub mod numpy {
use super::*;
#[derive(Debug, Clone)]
struct NumpyHeader {
dtype: String,
fortran_order: bool,
shape: Vec<usize>,
}
impl NumpyHeader {
fn new<T: TensorElement>(shape: &[usize]) -> Self {
let dtype = match std::any::TypeId::of::<T>() {
id if id == std::any::TypeId::of::<f32>() => "<f4".to_string(),
id if id == std::any::TypeId::of::<f64>() => "<f8".to_string(),
id if id == std::any::TypeId::of::<i8>() => "<i1".to_string(),
id if id == std::any::TypeId::of::<i16>() => "<i2".to_string(),
id if id == std::any::TypeId::of::<i32>() => "<i4".to_string(),
id if id == std::any::TypeId::of::<i64>() => "<i8".to_string(),
id if id == std::any::TypeId::of::<u8>() => "<u1".to_string(),
id if id == std::any::TypeId::of::<u16>() => "<u2".to_string(),
id if id == std::any::TypeId::of::<u32>() => "<u4".to_string(),
id if id == std::any::TypeId::of::<u64>() => "<u8".to_string(),
_ => "<f4".to_string(), };
Self {
dtype,
fortran_order: false, shape: shape.to_vec(),
}
}
fn to_string(&self) -> String {
format!(
"{{'descr': '{}', 'fortran_order': {}, 'shape': {:?}}}",
self.dtype, self.fortran_order, self.shape
)
}
fn from_string(s: &str) -> Result<Self> {
let s = s.trim().trim_start_matches('{').trim_end_matches('}');
let mut dtype = String::new();
let mut fortran_order = false;
let mut shape = Vec::new();
let parts: Vec<&str> = s.split(',').collect();
for part in parts {
let part = part.trim();
if part.starts_with("'descr'") || part.starts_with("\"descr\"") {
if let Some(start) = part.find(':') {
let value_part = &part[start + 1..].trim();
if let Some(quote_start) = value_part.find(['\'', '"']) {
let quote_char = value_part
.chars()
.nth(quote_start)
.expect("char at found index should exist");
if let Some(quote_end) = value_part[quote_start + 1..].find(quote_char)
{
dtype = value_part[quote_start + 1..quote_start + 1 + quote_end]
.to_string();
}
}
}
} else if part.starts_with("'fortran_order'")
|| part.starts_with("\"fortran_order\"")
{
fortran_order = part.contains("True");
} else if part.starts_with("'shape'") || part.starts_with("\"shape\"") {
if let Some(tuple_start) = part.find('(') {
if let Some(tuple_end) = part.rfind(')') {
let tuple_content = &part[tuple_start + 1..tuple_end];
if !tuple_content.trim().is_empty() {
for dim in tuple_content.split(',') {
let dim = dim.trim();
if !dim.is_empty() {
shape.push(dim.parse().map_err(|_| {
TorshError::SerializationError(format!(
"Invalid shape dimension: '{}'",
dim
))
})?);
}
}
}
}
}
}
}
if dtype.is_empty() {
return Err(TorshError::SerializationError(
"Missing or invalid dtype in NumPy header".to_string(),
));
}
Ok(Self {
dtype,
fortran_order,
shape,
})
}
#[allow(dead_code)]
fn element_size(&self) -> Result<usize> {
let size_str = &self.dtype[2..]; size_str.parse().map_err(|_| {
TorshError::SerializationError(format!("Invalid dtype format: {}", self.dtype))
})
}
}
pub fn serialize_numpy<T: TensorElement, W: Write>(
tensor: &Tensor<T>,
writer: &mut W,
) -> Result<()> {
writer.write_all(b"\x93NUMPY").map_err(|e| {
TorshError::SerializationError(format!("Failed to write NumPy magic: {}", e))
})?;
writer.write_all(&[0x01, 0x00]).map_err(|e| {
TorshError::SerializationError(format!("Failed to write NumPy version: {}", e))
})?;
let header = NumpyHeader::new::<T>(tensor.shape().dims());
let header_str = header.to_string();
let base_len = 6 + 2 + 2 + header_str.len() + 1; let padding = if base_len % 64 == 0 {
0
} else {
64 - (base_len % 64)
};
let padded_header = format!("{}{}\n", header_str, " ".repeat(padding));
let header_len = padded_header.len() as u16;
writer.write_all(&header_len.to_le_bytes()).map_err(|e| {
TorshError::SerializationError(format!("Failed to write header length: {}", e))
})?;
writer.write_all(padded_header.as_bytes()).map_err(|e| {
TorshError::SerializationError(format!("Failed to write header: {}", e))
})?;
let data = tensor.data()?;
let data_bytes = unsafe {
std::slice::from_raw_parts(
data.as_ptr() as *const u8,
data.len() * std::mem::size_of::<T>(),
)
};
writer.write_all(data_bytes).map_err(|e| {
TorshError::SerializationError(format!("Failed to write tensor data: {}", e))
})?;
Ok(())
}
pub fn deserialize_numpy<T: TensorElement, R: Read>(reader: &mut R) -> Result<Tensor<T>> {
let mut magic = [0u8; 6];
reader.read_exact(&mut magic).map_err(|e| {
TorshError::SerializationError(format!("Failed to read NumPy magic: {}", e))
})?;
if &magic != b"\x93NUMPY" {
return Err(TorshError::SerializationError(format!(
"Invalid NumPy magic string: expected b\"\\x93NUMPY\", got {:?}",
magic
)));
}
let mut version = [0u8; 2];
reader.read_exact(&mut version).map_err(|e| {
TorshError::SerializationError(format!("Failed to read NumPy version: {}", e))
})?;
if version[0] != 1 || version[1] != 0 {
return Err(TorshError::SerializationError(format!(
"Unsupported NumPy version: {}.{} (only 1.0 is supported)",
version[0], version[1]
)));
}
let mut header_len_bytes = [0u8; 2];
reader.read_exact(&mut header_len_bytes).map_err(|e| {
TorshError::SerializationError(format!("Failed to read header length: {}", e))
})?;
let header_len = u16::from_le_bytes(header_len_bytes) as usize;
let mut header_bytes = vec![0u8; header_len];
reader
.read_exact(&mut header_bytes)
.map_err(|e| TorshError::SerializationError(format!("Failed to read header: {}", e)))?;
let header_str = String::from_utf8(header_bytes).map_err(|e| {
TorshError::SerializationError(format!("Invalid UTF-8 in NumPy header: {}", e))
})?;
let header = NumpyHeader::from_string(&header_str)?;
let expected_header = NumpyHeader::new::<T>(&header.shape);
if header.dtype != expected_header.dtype {
return Err(TorshError::SerializationError(format!(
"NumPy dtype mismatch: file contains '{}', expected '{}' for type {}",
header.dtype,
expected_header.dtype,
std::any::type_name::<T>()
)));
}
if header.fortran_order {
return Err(TorshError::SerializationError(
"Fortran order arrays are not currently supported".to_string(),
));
}
let numel = header.shape.iter().product::<usize>();
let expected_data_size = numel * std::mem::size_of::<T>();
if numel == 0 {
return Err(TorshError::SerializationError(
"Cannot deserialize array with zero elements".to_string(),
));
}
let mut data_bytes = vec![0u8; expected_data_size];
reader.read_exact(&mut data_bytes).map_err(|e| {
TorshError::SerializationError(format!("Failed to read tensor data: {}", e))
})?;
let mut typed_data = Vec::with_capacity(numel);
let byte_ptr = data_bytes.as_ptr();
for i in 0..numel {
unsafe {
let element_ptr = byte_ptr.add(i * std::mem::size_of::<T>()) as *const T;
typed_data.push(std::ptr::read(element_ptr));
}
}
Tensor::from_data(typed_data, header.shape, DeviceType::Cpu)
}
pub fn validate_numpy_format<R: Read>(reader: &mut R) -> Result<(Vec<usize>, String)> {
let mut magic = [0u8; 6];
reader.read_exact(&mut magic).map_err(|e| {
TorshError::SerializationError(format!("Failed to read NumPy magic: {}", e))
})?;
if &magic != b"\x93NUMPY" {
return Err(TorshError::SerializationError(
"Invalid NumPy magic string".to_string(),
));
}
let mut version = [0u8; 2];
reader.read_exact(&mut version).map_err(|e| {
TorshError::SerializationError(format!("Failed to read NumPy version: {}", e))
})?;
if version[0] != 1 || version[1] != 0 {
return Err(TorshError::SerializationError(format!(
"Unsupported NumPy version: {}.{}",
version[0], version[1]
)));
}
let mut header_len_bytes = [0u8; 2];
reader.read_exact(&mut header_len_bytes).map_err(|e| {
TorshError::SerializationError(format!("Failed to read header length: {}", e))
})?;
let header_len = u16::from_le_bytes(header_len_bytes) as usize;
let mut header_bytes = vec![0u8; header_len];
reader
.read_exact(&mut header_bytes)
.map_err(|e| TorshError::SerializationError(format!("Failed to read header: {}", e)))?;
let header_str = String::from_utf8(header_bytes).map_err(|e| {
TorshError::SerializationError(format!("Invalid UTF-8 in header: {}", e))
})?;
let header = NumpyHeader::from_string(&header_str)?;
Ok((header.shape, header.dtype))
}
}