use crate::{Tensor, TensorElement};
use std::collections::HashMap;
use torsh_core::{device::DeviceType, shape::Shape};
/// The on-disk / on-wire formats a tensor can be serialized to.
///
/// Variants behind a `serialize-*` cargo feature only exist when that
/// feature is enabled; every `match` below carries matching `#[cfg]` arms.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SerializationFormat {
    /// Crate-native binary layout (`.trsh`).
    Binary,
    /// Human-readable JSON (`.json`).
    Json,
    /// NumPy-compatible layout (`.npy`).
    Numpy,
    #[cfg(feature = "serialize-hdf5")]
    Hdf5,
    #[cfg(feature = "serialize-arrow")]
    Arrow,
    #[cfg(feature = "serialize-arrow")]
    Parquet,
    #[cfg(feature = "serialize-onnx")]
    Onnx,
}

impl SerializationFormat {
    /// Returns `true` if the format can be read/written incrementally in
    /// chunks rather than requiring the whole payload in memory at once.
    pub fn supports_streaming(self) -> bool {
        match self {
            Self::Json => false,
            Self::Binary | Self::Numpy => true,
            #[cfg(feature = "serialize-hdf5")]
            Self::Hdf5 => true,
            #[cfg(feature = "serialize-arrow")]
            Self::Arrow | Self::Parquet => true,
            #[cfg(feature = "serialize-onnx")]
            Self::Onnx => false,
        }
    }

    /// Returns `true` if the format must be backed by a real file path
    /// (as opposed to an arbitrary in-memory byte stream).
    pub fn requires_file_path(self) -> bool {
        // Exactly the three core formats work on byte streams; every
        // feature-gated backend needs an actual file on disk.
        !matches!(self, Self::Binary | Self::Json | Self::Numpy)
    }

    /// Returns `true` if the format has a native notion of compression.
    pub fn supports_compression(self) -> bool {
        match self {
            Self::Json | Self::Numpy => false,
            Self::Binary => true,
            #[cfg(feature = "serialize-hdf5")]
            Self::Hdf5 => true,
            #[cfg(feature = "serialize-arrow")]
            Self::Arrow | Self::Parquet => true,
            #[cfg(feature = "serialize-onnx")]
            Self::Onnx => false,
        }
    }

    /// Conventional file extension for the format, without the leading dot.
    pub fn file_extension(self) -> &'static str {
        match self {
            Self::Binary => "trsh",
            Self::Json => "json",
            Self::Numpy => "npy",
            #[cfg(feature = "serialize-hdf5")]
            Self::Hdf5 => "h5",
            #[cfg(feature = "serialize-arrow")]
            Self::Arrow => "arrow",
            #[cfg(feature = "serialize-arrow")]
            Self::Parquet => "parquet",
            #[cfg(feature = "serialize-onnx")]
            Self::Onnx => "onnx",
        }
    }

    /// MIME type to advertise when transferring this format over HTTP.
    pub fn mime_type(self) -> &'static str {
        match self {
            // Both raw layouts are served as opaque bytes.
            Self::Binary | Self::Numpy => "application/octet-stream",
            Self::Json => "application/json",
            #[cfg(feature = "serialize-hdf5")]
            Self::Hdf5 => "application/x-hdf5",
            #[cfg(feature = "serialize-arrow")]
            Self::Arrow => "application/vnd.apache.arrow.file",
            #[cfg(feature = "serialize-arrow")]
            Self::Parquet => "application/vnd.apache.parquet",
            #[cfg(feature = "serialize-onnx")]
            Self::Onnx => "application/onnx",
        }
    }
}
/// Tunable knobs controlling how a tensor is serialized.
#[derive(Debug, Clone)]
pub struct SerializationOptions {
    /// Also serialize gradient data when present.
    pub include_gradients: bool,
    /// Also serialize the operation history/graph.
    pub include_operations: bool,
    /// Compression level, 0 (off) through 9 (maximum).
    pub compression_level: u8,
    /// Free-form key/value metadata embedded alongside the data.
    pub metadata: HashMap<String, String>,
    /// Chunk size in bytes for streaming writes; `None` means unchunked.
    pub chunk_size: Option<usize>,
    /// Verify data integrity during serialization.
    pub validate_data: bool,
    /// Keep full numeric precision (no lossy narrowing).
    pub preserve_precision: bool,
}

impl Default for SerializationOptions {
    /// Safe baseline: no extras, no compression, validation on,
    /// full precision preserved.
    fn default() -> Self {
        Self {
            include_gradients: false,
            include_operations: false,
            compression_level: 0,
            metadata: HashMap::new(),
            chunk_size: None,
            validate_data: true,
            preserve_precision: true,
        }
    }
}

impl SerializationOptions {
    /// Preset tuned for throughput: large 64 MiB chunks, no validation,
    /// precision may be dropped.
    pub fn fast() -> Self {
        Self {
            chunk_size: Some(64 * 1024 * 1024),
            validate_data: false,
            preserve_precision: false,
            ..Self::default()
        }
    }

    /// Preset tuned for minimal output size: maximum compression,
    /// 1 MiB chunks, precision may be dropped.
    pub fn compact() -> Self {
        Self {
            compression_level: 9,
            chunk_size: Some(1024 * 1024),
            preserve_precision: false,
            ..Self::default()
        }
    }

    /// Preset tuned for debugging: everything included, light compression.
    pub fn debug() -> Self {
        Self {
            include_gradients: true,
            include_operations: true,
            compression_level: 1,
            ..Self::default()
        }
    }

    /// Builder-style: attach one metadata key/value pair.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Builder-style: set the compression level, clamped to the valid 0–9 range.
    pub fn with_compression(mut self, level: u8) -> Self {
        self.compression_level = std::cmp::min(level, 9);
        self
    }

    /// Builder-style: toggle gradient serialization.
    pub fn with_gradients(mut self, include: bool) -> Self {
        self.include_gradients = include;
        self
    }
}
/// Self-describing header stored alongside serialized tensor data.
///
/// Captures everything needed to reconstruct / sanity-check a tensor on
/// load without reading the payload. Serde derives are only available
/// when the `serialize` feature is enabled.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct TensorMetadata {
    /// Logical shape of the serialized tensor.
    pub shape: Shape,
    /// Device the tensor lived on when serialized.
    pub device: DeviceType,
    /// Whether the tensor participated in autograd.
    pub requires_grad: bool,
    /// Rust type name of the element type (from `std::any::type_name`).
    pub dtype_name: String,
    /// Crate version that wrote the data (`CARGO_PKG_VERSION`).
    pub version: String,
    /// Unix timestamp (seconds) of serialization.
    pub timestamp: u64,
    /// User-supplied key/value metadata copied from the options.
    pub custom_metadata: HashMap<String, String>,
    /// Debug-formatted name of the `SerializationFormat` used.
    pub format: String,
    /// Size of the serialized payload in bytes.
    pub data_size: usize,
    /// Whether the payload was written with compression enabled.
    pub compressed: bool,
    /// Optional integrity checksum of the payload; `None` if not computed.
    pub checksum: Option<String>,
}
impl TensorMetadata {
    /// Build the metadata header describing `tensor` as it will be written
    /// in `format`, with `data_size` bytes of serialized payload.
    ///
    /// The checksum is left unset here; callers fill it in after the
    /// payload bytes exist.
    pub fn from_tensor<T: TensorElement>(
        tensor: &Tensor<T>,
        options: &SerializationOptions,
        format: SerializationFormat,
        data_size: usize,
    ) -> Self {
        // Clock failures (pre-epoch time) degrade to timestamp 0 rather
        // than aborting serialization.
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        Self {
            shape: tensor.shape(),
            device: tensor.device(),
            requires_grad: tensor.requires_grad(),
            dtype_name: std::any::type_name::<T>().to_string(),
            version: env!("CARGO_PKG_VERSION").to_string(),
            timestamp,
            custom_metadata: options.metadata.clone(),
            format: format!("{:?}", format),
            data_size,
            compressed: options.compression_level > 0,
            checksum: None,
        }
    }

    /// Check internal consistency of the header, returning a
    /// human-readable message for the first problem found.
    pub fn validate(&self) -> Result<(), String> {
        if self.shape.numel() == 0 {
            Err("Invalid shape: tensor cannot have zero size".to_string())
        } else if self.dtype_name.is_empty() {
            Err("Invalid dtype: type name cannot be empty".to_string())
        } else if self.version.is_empty() {
            Err("Invalid version: version string cannot be empty".to_string())
        } else if self.data_size == 0 {
            Err("Invalid data size: cannot be zero".to_string())
        } else {
            Ok(())
        }
    }

    /// Rough in-memory footprint estimate: payload size plus a flat 10%
    /// bookkeeping allowance.
    pub fn estimated_memory_usage(&self) -> usize {
        self.data_size + self.data_size / 10
    }

    /// Whether gradient data is expected alongside this tensor.
    pub fn has_gradients(&self) -> bool {
        self.requires_grad
    }

    /// Human-readable payload size, e.g. `"512 B"` or `"1.5 MB"`
    /// (one decimal place only when the value is fractional).
    pub fn size_description(&self) -> String {
        const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB"];
        let mut remaining = self.data_size as f64;
        let mut idx = 0;
        while remaining >= 1024.0 && idx + 1 < UNITS.len() {
            remaining /= 1024.0;
            idx += 1;
        }
        // `{:.*}` takes the precision from the first argument.
        let precision = if remaining.fract() == 0.0 { 0 } else { 1 };
        format!("{:.*} {}", precision, remaining, UNITS[idx])
    }
}