#[cfg(feature = "serialize-arrow")]
use super::data_science;
#[cfg(feature = "serialize-onnx")]
use super::ml_formats;
use super::{
binary,
common::{SerializationFormat, SerializationOptions},
text_formats,
};
#[allow(unused_imports)]
#[cfg(feature = "serialize-hdf5")]
use super::scientific;
use crate::{Tensor, TensorElement};
use std::path::Path;
use torsh_core::error::{Result, TorshError};
#[cfg(feature = "serialize")]
impl<T: TensorElement + serde::Serialize + for<'a> serde::Deserialize<'a>> Tensor<T> {
    /// Serializes this tensor into an in-memory byte buffer.
    ///
    /// Only [`SerializationFormat::Binary`], [`SerializationFormat::Json`] and
    /// [`SerializationFormat::Numpy`] are encoded in memory. The file-backed
    /// formats (HDF5, Arrow/Parquet and ONNX, each present only when its cargo
    /// feature is enabled) are rejected with a message pointing at
    /// [`Self::serialize_to_file`].
    ///
    /// # Errors
    ///
    /// Returns [`TorshError::SerializationError`] for file-only formats and
    /// propagates any error reported by the selected encoder.
    pub fn serialize_to_bytes(
        &self,
        format: SerializationFormat,
        options: &SerializationOptions,
    ) -> Result<Vec<u8>> {
        let mut buffer = Vec::new();
        match format {
            SerializationFormat::Binary => {
                binary::serialize_binary(self, &mut buffer, options)?;
            }
            SerializationFormat::Json => {
                text_formats::serialize_json(self, &mut buffer, options)?;
            }
            SerializationFormat::Numpy => {
                text_formats::numpy::serialize_numpy(self, &mut buffer)?;
            }
            #[cfg(feature = "serialize-hdf5")]
            SerializationFormat::Hdf5 => {
                return Err(TorshError::SerializationError(
                    "HDF5 format requires file path, use serialize_to_file instead".to_string(),
                ));
            }
            #[cfg(feature = "serialize-arrow")]
            SerializationFormat::Arrow | SerializationFormat::Parquet => {
                return Err(TorshError::SerializationError(
                    "Arrow/Parquet format requires file path, use serialize_to_file instead"
                        .to_string(),
                ));
            }
            #[cfg(feature = "serialize-onnx")]
            SerializationFormat::Onnx => {
                return Err(TorshError::SerializationError(
                    "ONNX format requires file path, use serialize_to_file instead".to_string(),
                ));
            }
        }
        Ok(buffer)
    }

    /// Serializes this tensor to the file at `path`.
    ///
    /// Binary, JSON and NumPy output is produced with
    /// [`Self::serialize_to_bytes`] and written with [`std::fs::write`]. That
    /// call creates the file or replaces its existing contents. Arrow, Parquet
    /// and ONNX are delegated to their backend writers. HDF5 is rejected
    /// because this generic entry point cannot satisfy the `hdf5::H5Type`
    /// bound that the HDF5 backend needs.
    ///
    /// # Errors
    ///
    /// Returns [`TorshError::SerializationError`] if encoding fails, the format
    /// is unsupported here, or the file cannot be written. Write errors include
    /// the target path.
    pub fn serialize_to_file<P: AsRef<Path>>(
        &self,
        path: P,
        format: SerializationFormat,
        options: &SerializationOptions,
    ) -> Result<()> {
        let path = path.as_ref();
        match format {
            SerializationFormat::Binary
            | SerializationFormat::Json
            | SerializationFormat::Numpy => {
                let bytes = self.serialize_to_bytes(format, options)?;
                std::fs::write(path, bytes).map_err(|e| {
                    // Keep the historical "Failed to write file: <io error>" prefix
                    // and append the path so the failure can be diagnosed.
                    TorshError::SerializationError(format!(
                        "Failed to write file: {} (path: {})",
                        e,
                        path.display()
                    ))
                })?;
            }
            #[cfg(feature = "serialize-hdf5")]
            SerializationFormat::Hdf5 => {
                return Err(TorshError::SerializationError(
                    "HDF5 format requires hdf5::H5Type trait bound. Use binary or numpy format instead, or call scientific::hdf5::serialize_hdf5 directly with H5Type-compatible types".to_string(),
                ));
            }
            #[cfg(feature = "serialize-arrow")]
            SerializationFormat::Arrow => {
                data_science::arrow::serialize_arrow(self, path, options)?;
            }
            #[cfg(feature = "serialize-arrow")]
            SerializationFormat::Parquet => {
                data_science::parquet::serialize_parquet(self, path, options)?;
            }
            #[cfg(feature = "serialize-onnx")]
            SerializationFormat::Onnx => {
                ml_formats::onnx::serialize_onnx(self, path, options)?;
            }
        }
        Ok(())
    }

    /// Reconstructs a tensor from an in-memory buffer produced in `format`.
    ///
    /// Only Binary, JSON and NumPy can be decoded from memory. File-backed
    /// formats return an error pointing at [`Self::deserialize_from_file`].
    ///
    /// # Errors
    ///
    /// Returns [`TorshError::SerializationError`] for file-only formats and
    /// propagates any error reported by the selected decoder.
    pub fn deserialize_from_bytes(data: &[u8], format: SerializationFormat) -> Result<Tensor<T>> {
        let mut cursor = std::io::Cursor::new(data);
        match format {
            SerializationFormat::Binary => binary::deserialize_binary(&mut cursor),
            SerializationFormat::Json => text_formats::deserialize_json(&mut cursor),
            SerializationFormat::Numpy => text_formats::numpy::deserialize_numpy(&mut cursor),
            #[cfg(feature = "serialize-hdf5")]
            SerializationFormat::Hdf5 => Err(TorshError::SerializationError(
                "HDF5 format requires file path, use deserialize_from_file instead".to_string(),
            )),
            #[cfg(feature = "serialize-arrow")]
            SerializationFormat::Arrow | SerializationFormat::Parquet => {
                Err(TorshError::SerializationError(
                    "Arrow/Parquet format requires file path, use deserialize_from_file instead"
                        .to_string(),
                ))
            }
            #[cfg(feature = "serialize-onnx")]
            SerializationFormat::Onnx => Err(TorshError::SerializationError(
                "ONNX format requires file path, use deserialize_from_file instead".to_string(),
            )),
        }
    }

    /// Loads a tensor stored at `path` in `format`.
    ///
    /// Binary, JSON and NumPy files are read fully into memory and decoded
    /// with [`Self::deserialize_from_bytes`]. Arrow, Parquet and ONNX are
    /// delegated to their backend readers. HDF5 is rejected for the same
    /// `hdf5::H5Type` reason as in [`Self::serialize_to_file`].
    ///
    /// # Errors
    ///
    /// Returns [`TorshError::SerializationError`] if the file cannot be read
    /// (the message includes the path), the format is unsupported here, or
    /// decoding fails.
    pub fn deserialize_from_file<P: AsRef<Path>>(
        path: P,
        format: SerializationFormat,
    ) -> Result<Tensor<T>> {
        let path = path.as_ref();
        match format {
            SerializationFormat::Binary
            | SerializationFormat::Json
            | SerializationFormat::Numpy => {
                let bytes = std::fs::read(path).map_err(|e| {
                    // Same prefix as before; the path is appended for diagnosability.
                    TorshError::SerializationError(format!(
                        "Failed to read file: {} (path: {})",
                        e,
                        path.display()
                    ))
                })?;
                Self::deserialize_from_bytes(&bytes, format)
            }
            #[cfg(feature = "serialize-hdf5")]
            SerializationFormat::Hdf5 => {
                Err(TorshError::SerializationError(
                    "HDF5 format requires hdf5::H5Type trait bound. Use binary or numpy format instead, or call scientific::hdf5::deserialize_hdf5 directly with H5Type-compatible types".to_string(),
                ))
            }
            #[cfg(feature = "serialize-arrow")]
            SerializationFormat::Arrow => data_science::arrow::deserialize_arrow(path),
            #[cfg(feature = "serialize-arrow")]
            SerializationFormat::Parquet => data_science::parquet::deserialize_parquet(path),
            #[cfg(feature = "serialize-onnx")]
            SerializationFormat::Onnx => ml_formats::onnx::deserialize_onnx(path),
        }
    }

    /// Saves this tensor to `path`, inferring the format from the file
    /// extension.
    ///
    /// Matching is case-insensitive: `.trsh`/`.bin` select binary, `.json`
    /// selects JSON and `.npy` selects NumPy. Other extensions are recognised
    /// only when their cargo feature is enabled.
    ///
    /// # Errors
    ///
    /// Fails if the format cannot be inferred or if [`Self::serialize_to_file`]
    /// fails.
    pub fn save<P: AsRef<Path>>(&self, path: P, options: &SerializationOptions) -> Result<()> {
        let path = path.as_ref();
        let format = detect_format_from_path(path)?;
        self.serialize_to_file(path, format, options)
    }

    /// Loads a tensor from `path`, inferring the format from the file
    /// extension (same rules as [`Self::save`]).
    ///
    /// # Errors
    ///
    /// Fails if the format cannot be inferred or if
    /// [`Self::deserialize_from_file`] fails.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Tensor<T>> {
        let path = path.as_ref();
        let format = detect_format_from_path(path)?;
        Self::deserialize_from_file(path, format)
    }
}
#[cfg(not(feature = "serialize"))]
impl<T: TensorElement> Tensor<T> {
    /// Serializes this tensor into an in-memory byte buffer.
    ///
    /// This impl is compiled when the `serialize` feature is disabled, so JSON
    /// is unavailable and returns an error. Binary and NumPy are encoded in
    /// memory. File-backed formats (HDF5, Arrow/Parquet and ONNX, each present
    /// only when its cargo feature is enabled) are rejected with a message
    /// pointing at [`Self::serialize_to_file`].
    ///
    /// # Errors
    ///
    /// Returns [`TorshError::SerializationError`] for JSON and file-only
    /// formats, and propagates any error reported by the selected encoder.
    pub fn serialize_to_bytes(
        &self,
        format: SerializationFormat,
        options: &SerializationOptions,
    ) -> Result<Vec<u8>> {
        let mut buffer = Vec::new();
        match format {
            SerializationFormat::Binary => {
                binary::serialize_binary(self, &mut buffer, options)?;
            }
            SerializationFormat::Json => {
                return Err(TorshError::SerializationError(
                    "JSON serialization requires the 'serialize' feature to be enabled".to_string(),
                ));
            }
            SerializationFormat::Numpy => {
                text_formats::numpy::serialize_numpy(self, &mut buffer)?;
            }
            #[cfg(feature = "serialize-hdf5")]
            SerializationFormat::Hdf5 => {
                return Err(TorshError::SerializationError(
                    "HDF5 format requires file path, use serialize_to_file instead".to_string(),
                ));
            }
            #[cfg(feature = "serialize-arrow")]
            SerializationFormat::Arrow | SerializationFormat::Parquet => {
                return Err(TorshError::SerializationError(
                    "Arrow/Parquet format requires file path, use serialize_to_file instead"
                        .to_string(),
                ));
            }
            #[cfg(feature = "serialize-onnx")]
            SerializationFormat::Onnx => {
                return Err(TorshError::SerializationError(
                    "ONNX format requires file path, use serialize_to_file instead".to_string(),
                ));
            }
        }
        Ok(buffer)
    }

    /// Serializes this tensor to the file at `path`.
    ///
    /// Binary and NumPy output is produced with [`Self::serialize_to_bytes`]
    /// and written with [`std::fs::write`]. That call creates the file or
    /// replaces its existing contents. JSON is unavailable without the
    /// `serialize` feature. Arrow, Parquet and ONNX are delegated to their
    /// backend writers. HDF5 is rejected because this generic entry point
    /// cannot satisfy the `hdf5::H5Type` bound.
    ///
    /// # Errors
    ///
    /// Returns [`TorshError::SerializationError`] if encoding fails, the format
    /// is unsupported here, or the file cannot be written. Write errors include
    /// the target path.
    pub fn serialize_to_file<P: AsRef<Path>>(
        &self,
        path: P,
        format: SerializationFormat,
        options: &SerializationOptions,
    ) -> Result<()> {
        let path = path.as_ref();
        match format {
            SerializationFormat::Binary | SerializationFormat::Numpy => {
                let bytes = self.serialize_to_bytes(format, options)?;
                std::fs::write(path, bytes).map_err(|e| {
                    // Keep the historical "Failed to write file: <io error>" prefix
                    // and append the path so the failure can be diagnosed.
                    TorshError::SerializationError(format!(
                        "Failed to write file: {} (path: {})",
                        e,
                        path.display()
                    ))
                })?;
            }
            SerializationFormat::Json => {
                return Err(TorshError::SerializationError(
                    "JSON serialization requires the 'serialize' feature to be enabled".to_string(),
                ));
            }
            #[cfg(feature = "serialize-hdf5")]
            SerializationFormat::Hdf5 => {
                return Err(TorshError::SerializationError(
                    "HDF5 format requires hdf5::H5Type trait bound. Use binary or numpy format instead, or call scientific::hdf5::serialize_hdf5 directly with H5Type-compatible types".to_string(),
                ));
            }
            #[cfg(feature = "serialize-arrow")]
            SerializationFormat::Arrow => {
                data_science::arrow::serialize_arrow(self, path, options)?;
            }
            #[cfg(feature = "serialize-arrow")]
            SerializationFormat::Parquet => {
                data_science::parquet::serialize_parquet(self, path, options)?;
            }
            #[cfg(feature = "serialize-onnx")]
            SerializationFormat::Onnx => {
                ml_formats::onnx::serialize_onnx(self, path, options)?;
            }
        }
        Ok(())
    }

    /// Reconstructs a tensor from an in-memory buffer produced in `format`.
    ///
    /// Binary and NumPy are decoded from memory. JSON needs the `serialize`
    /// feature. File-backed formats return an error pointing at
    /// [`Self::deserialize_from_file`].
    ///
    /// # Errors
    ///
    /// Returns [`TorshError::SerializationError`] for JSON and file-only
    /// formats, and propagates any error reported by the selected decoder.
    pub fn deserialize_from_bytes(data: &[u8], format: SerializationFormat) -> Result<Tensor<T>> {
        let mut cursor = std::io::Cursor::new(data);
        match format {
            SerializationFormat::Binary => binary::deserialize_binary(&mut cursor),
            SerializationFormat::Json => Err(TorshError::SerializationError(
                "JSON deserialization requires the 'serialize' feature to be enabled".to_string(),
            )),
            SerializationFormat::Numpy => text_formats::numpy::deserialize_numpy(&mut cursor),
            #[cfg(feature = "serialize-hdf5")]
            SerializationFormat::Hdf5 => Err(TorshError::SerializationError(
                "HDF5 format requires file path, use deserialize_from_file instead".to_string(),
            )),
            #[cfg(feature = "serialize-arrow")]
            SerializationFormat::Arrow | SerializationFormat::Parquet => {
                Err(TorshError::SerializationError(
                    "Arrow/Parquet format requires file path, use deserialize_from_file instead"
                        .to_string(),
                ))
            }
            #[cfg(feature = "serialize-onnx")]
            SerializationFormat::Onnx => Err(TorshError::SerializationError(
                "ONNX format requires file path, use deserialize_from_file instead".to_string(),
            )),
        }
    }

    /// Loads a tensor stored at `path` in `format`.
    ///
    /// Binary and NumPy files are read fully into memory and decoded with
    /// [`Self::deserialize_from_bytes`]. JSON needs the `serialize` feature.
    /// Arrow, Parquet and ONNX are delegated to their backend readers. HDF5 is
    /// rejected for the same `hdf5::H5Type` reason as in
    /// [`Self::serialize_to_file`].
    ///
    /// # Errors
    ///
    /// Returns [`TorshError::SerializationError`] if the file cannot be read
    /// (the message includes the path), the format is unsupported here, or
    /// decoding fails.
    pub fn deserialize_from_file<P: AsRef<Path>>(
        path: P,
        format: SerializationFormat,
    ) -> Result<Tensor<T>> {
        let path = path.as_ref();
        match format {
            SerializationFormat::Binary | SerializationFormat::Numpy => {
                let bytes = std::fs::read(path).map_err(|e| {
                    // Same prefix as before; the path is appended for diagnosability.
                    TorshError::SerializationError(format!(
                        "Failed to read file: {} (path: {})",
                        e,
                        path.display()
                    ))
                })?;
                Self::deserialize_from_bytes(&bytes, format)
            }
            SerializationFormat::Json => Err(TorshError::SerializationError(
                "JSON deserialization requires the 'serialize' feature to be enabled".to_string(),
            )),
            #[cfg(feature = "serialize-hdf5")]
            SerializationFormat::Hdf5 => {
                Err(TorshError::SerializationError(
                    "HDF5 format requires hdf5::H5Type trait bound. Use binary or numpy format instead, or call scientific::hdf5::deserialize_hdf5 directly with H5Type-compatible types".to_string(),
                ))
            }
            #[cfg(feature = "serialize-arrow")]
            SerializationFormat::Arrow => data_science::arrow::deserialize_arrow(path),
            #[cfg(feature = "serialize-arrow")]
            SerializationFormat::Parquet => data_science::parquet::deserialize_parquet(path),
            #[cfg(feature = "serialize-onnx")]
            SerializationFormat::Onnx => ml_formats::onnx::deserialize_onnx(path),
        }
    }

    /// Saves this tensor to `path`, inferring the format from the file
    /// extension.
    ///
    /// Matching is case-insensitive: `.trsh`/`.bin` select binary, `.json`
    /// selects JSON (rejected in this build) and `.npy` selects NumPy. Other
    /// extensions are recognised only when their cargo feature is enabled.
    ///
    /// # Errors
    ///
    /// Fails if the format cannot be inferred or if [`Self::serialize_to_file`]
    /// fails.
    pub fn save<P: AsRef<Path>>(&self, path: P, options: &SerializationOptions) -> Result<()> {
        let path = path.as_ref();
        let format = detect_format_from_path(path)?;
        self.serialize_to_file(path, format, options)
    }

    /// Loads a tensor from `path`, inferring the format from the file
    /// extension (same rules as [`Self::save`]).
    ///
    /// # Errors
    ///
    /// Fails if the format cannot be inferred or if
    /// [`Self::deserialize_from_file`] fails.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Tensor<T>> {
        let path = path.as_ref();
        let format = detect_format_from_path(path)?;
        Self::deserialize_from_file(path, format)
    }
}
/// Infers the [`SerializationFormat`] from the extension of `path`.
///
/// Matching is case-insensitive. Recognised extensions are `.trsh`/`.bin`
/// (binary), `.json` and `.npy`. When their cargo features are enabled,
/// `.h5`/`.hdf5`, `.arrow`, `.parquet` and `.onnx` are also recognised.
///
/// # Errors
///
/// Returns [`TorshError::SerializationError`] in any of these cases:
/// - the path has no extension;
/// - the extension is not valid UTF-8;
/// - the extension belongs to a format whose feature is disabled in this build;
/// - the extension is unknown.
fn detect_format_from_path(path: &Path) -> Result<SerializationFormat> {
    let raw_extension = path.extension().ok_or_else(|| {
        TorshError::SerializationError("Cannot detect format: file has no extension".to_string())
    })?;
    // A non-UTF-8 extension does exist; report that instead of claiming it is missing.
    let extension = raw_extension.to_str().ok_or_else(|| {
        TorshError::SerializationError(format!(
            "Cannot detect format: file extension of '{}' is not valid UTF-8",
            path.display()
        ))
    })?;
    match extension.to_lowercase().as_str() {
        "trsh" | "bin" => Ok(SerializationFormat::Binary),
        "json" => Ok(SerializationFormat::Json),
        "npy" => Ok(SerializationFormat::Numpy),
        #[cfg(feature = "serialize-hdf5")]
        "h5" | "hdf5" => Ok(SerializationFormat::Hdf5),
        // Known extension, but the backend is compiled out: name the missing feature.
        #[cfg(not(feature = "serialize-hdf5"))]
        "h5" | "hdf5" => Err(TorshError::SerializationError(format!(
            "Cannot handle .{} files: HDF5 support requires the 'serialize-hdf5' feature",
            extension
        ))),
        #[cfg(feature = "serialize-arrow")]
        "arrow" => Ok(SerializationFormat::Arrow),
        #[cfg(feature = "serialize-arrow")]
        "parquet" => Ok(SerializationFormat::Parquet),
        #[cfg(not(feature = "serialize-arrow"))]
        "arrow" | "parquet" => Err(TorshError::SerializationError(format!(
            "Cannot handle .{} files: Arrow/Parquet support requires the 'serialize-arrow' feature",
            extension
        ))),
        #[cfg(feature = "serialize-onnx")]
        "onnx" => Ok(SerializationFormat::Onnx),
        #[cfg(not(feature = "serialize-onnx"))]
        "onnx" => Err(TorshError::SerializationError(format!(
            "Cannot handle .{} files: ONNX support requires the 'serialize-onnx' feature",
            extension
        ))),
        _ => Err(TorshError::SerializationError(format!(
            "Unsupported file extension: .{}",
            extension
        ))),
    }
}
/// Reports whether `format` is usable in the current build.
///
/// Every variant that exists is usable, except
/// [`SerializationFormat::Json`] when the `serialize` feature is disabled.
/// Feature-gated variants such as HDF5, Arrow/Parquet and ONNX only exist
/// when their feature is compiled in, so no runtime check is needed for them.
///
/// # Errors
///
/// Returns [`TorshError::SerializationError`] for JSON in a build without the
/// `serialize` feature.
pub fn validate_format_support(format: SerializationFormat) -> Result<()> {
    match format {
        SerializationFormat::Json => {
            if cfg!(feature = "serialize") {
                Ok(())
            } else {
                Err(TorshError::SerializationError(
                    "JSON format requires the 'serialize' feature".to_string(),
                ))
            }
        }
        SerializationFormat::Binary | SerializationFormat::Numpy => Ok(()),
        #[cfg(feature = "serialize-hdf5")]
        SerializationFormat::Hdf5 => Ok(()),
        #[cfg(feature = "serialize-arrow")]
        SerializationFormat::Arrow | SerializationFormat::Parquet => Ok(()),
        #[cfg(feature = "serialize-onnx")]
        SerializationFormat::Onnx => Ok(()),
    }
}