use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::BTreeMap;
/// Magic bytes identifying the current file-format revision.
pub const MAGIC: &[u8; 8] = b"ZTEN1000";
/// Magic bytes of an earlier format revision (crate-internal). Presumably kept so
/// readers can still recognise legacy files — confirm.
pub(crate) const MAGIC_V01: &[u8; 8] = b"ZTEN0001";
/// Alignment quantum of 64. Presumably a byte alignment applied to component data
/// offsets — confirm against the writer.
pub const ALIGNMENT: u64 = 64;
/// Upper bound on manifest size: 2^30 (1 GiB, presumably in bytes).
pub const MAX_MANIFEST_SIZE: u64 = 1 << 30;
/// Element type of a tensor component.
///
/// Serialized in lowercase (`"f64"`, `"i8"`, `"bool"`, ...). Per-element widths come from
/// [`DType::byte_size`], and [`DType::as_str`] returns the same names as the serde form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DType {
    F64,
    F32,
    F16,
    // `rename_all = "lowercase"` would already turn `BF16` into "bf16"; the explicit
    // rename just pins the wire name so it is obvious at a glance.
    #[serde(rename = "bf16")]
    BF16,
    I64,
    I32,
    I16,
    I8,
    U64,
    U32,
    U16,
    U8,
    Bool,
}
impl DType {
pub fn byte_size(&self) -> usize {
match self {
Self::F64 | Self::I64 | Self::U64 => 8,
Self::F32 | Self::I32 | Self::U32 => 4,
Self::F16 | Self::BF16 | Self::I16 | Self::U16 => 2,
Self::I8 | Self::U8 | Self::Bool => 1,
}
}
pub fn is_multi_byte(&self) -> bool {
self.byte_size() > 1
}
pub fn as_str(&self) -> &'static str {
match self {
Self::F64 => "f64",
Self::F32 => "f32",
Self::F16 => "f16",
Self::BF16 => "bf16",
Self::I64 => "i64",
Self::I32 => "i32",
Self::I16 => "i16",
Self::I8 => "i8",
Self::U64 => "u64",
Self::U32 => "u32",
Self::U16 => "u16",
Self::U8 => "u8",
Self::Bool => "bool",
}
}
}
/// Encoding applied to a component's stored bytes.
///
/// Serialized in lowercase (`"raw"`, `"zstd"`). `Raw` is the default, and a serialized
/// `Component` omits the field when it is `Raw`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Encoding {
    /// No encoding (the default).
    #[default]
    Raw,
    /// Presumably Zstandard-compressed bytes, with `Component::uncompressed_length`
    /// recording the decoded size — confirm against the reader/writer.
    Zstd,
}
/// Storage layout of an object's components.
///
/// Known layouts get dedicated variants. Any other name is preserved verbatim in
/// [`Format::Other`], so unrecognised formats survive a parse/print round-trip.
///
/// Note: the derived `PartialEq` is structural, so `Format::Other("dense".into())` is
/// *not* equal to `Format::Dense` even though both print as `"dense"`. Build values
/// through [`Format::from_str`] or `str::parse` so known names always map to their
/// dedicated variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Format {
    Dense,
    SparseCsr,
    SparseCoo,
    QuantizedGroup,
    Other(String),
}

impl Format {
    /// Returns the canonical name of this format (e.g. `"sparse_csr"`).
    ///
    /// For [`Format::Other`] this is the stored string itself.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Dense => "dense",
            Self::SparseCsr => "sparse_csr",
            Self::SparseCoo => "sparse_coo",
            Self::QuantizedGroup => "quantized_group",
            Self::Other(s) => s,
        }
    }

    /// Parses a format name. This never fails.
    ///
    /// Known names map to their dedicated variant. Anything else becomes
    /// [`Format::Other`] holding a copy of `s`.
    ///
    /// This is the inherent counterpart of the `FromStr` impl below. It is kept so
    /// existing `Format::from_str(..)` callers keep compiling unchanged.
    #[allow(clippy::should_implement_trait)] // `FromStr` is implemented too; kept for compatibility
    pub fn from_str(s: &str) -> Self {
        match s {
            "dense" => Self::Dense,
            "sparse_csr" => Self::SparseCsr,
            "sparse_coo" => Self::SparseCoo,
            "quantized_group" => Self::QuantizedGroup,
            other => Self::Other(other.to_string()),
        }
    }
}

impl std::fmt::Display for Format {
    /// Writes the canonical name, identical to [`Format::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

/// Enables `"dense".parse::<Format>()`. Parsing is infallible, because unknown names
/// become [`Format::Other`].
impl std::str::FromStr for Format {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Path resolution prefers the inherent associated fn, so this delegates to the
        // inherent `Format::from_str` rather than recursing.
        Ok(Format::from_str(s))
    }
}
impl Serialize for Format {
    /// Serializes as the canonical format name (see [`Format::as_str`]).
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let name = self.as_str();
        serializer.serialize_str(name)
    }
}

impl<'de> Deserialize<'de> for Format {
    /// Deserializes from a string. Unknown names become [`Format::Other`] instead of
    /// producing an error.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        String::deserialize(deserializer).map(|name| Format::from_str(&name))
    }
}
/// One named piece of an object's stored payload (e.g. the `"data"` entry built by
/// `Object::dense`).
// NOTE(review): serde writes fields in declaration order, so reordering them changes
// the serialized layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Component {
    /// Element type of the values held in this component.
    pub dtype: DType,
    /// Optional free-form type tag, serialized under the key `"type"` and omitted when
    /// `None`. NOTE(review): its meaning is not established in this file — confirm.
    #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
    pub r#type: Option<String>,
    /// Start position of the stored bytes. Presumably an absolute byte offset in the
    /// file (likely `ALIGNMENT`-aligned) — TODO confirm against the reader/writer.
    pub offset: u64,
    /// Size of the stored (possibly encoded) data, presumably in bytes — confirm.
    pub length: u64,
    /// Optional uncompressed size, omitted from output when `None`. Presumably only set
    /// for non-`Raw` encodings — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub uncompressed_length: Option<u64>,
    /// Encoding of the stored bytes. Defaults to `Encoding::Raw` when absent and is
    /// omitted from output when `Raw`.
    #[serde(default, skip_serializing_if = "is_default_encoding")]
    pub encoding: Encoding,
    /// Optional digest string, omitted from output when `None`. A missing key still
    /// deserializes as `None` (serde's behaviour for `Option` fields).
    /// NOTE(review): the digest string format is not established here — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub digest: Option<String>,
}
/// `skip_serializing_if` predicate: returns `true` when `enc` is `Encoding::Raw`, so the
/// default encoding is left out of serialized components.
fn is_default_encoding(enc: &Encoding) -> bool {
    matches!(enc, Encoding::Raw)
}
/// A single tensor entry in the manifest: its logical shape, storage layout, optional
/// attributes, and the named components that hold its data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Object {
    /// Tensor dimensions. An empty shape denotes a scalar with one element
    /// (see `Object::num_elements`).
    pub shape: Vec<u64>,
    /// Storage layout. Presumably determines which component names are expected —
    /// confirm.
    pub format: Format,
    /// Optional free-form CBOR attributes, omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub attributes: Option<BTreeMap<String, ciborium::Value>>,
    /// Components keyed by name (e.g. `"data"`). `BTreeMap` keeps keys sorted, so
    /// serialization order is deterministic.
    pub components: BTreeMap<String, Component>,
}
impl Object {
pub fn num_elements(&self) -> Result<u64, crate::error::Error> {
if self.shape.is_empty() {
Ok(1)
} else {
self.shape.iter().try_fold(1u64, |acc, &d| {
acc.checked_mul(d).ok_or_else(|| {
crate::error::Error::InvalidFileStructure("Shape product overflows u64".into())
})
})
}
}
pub fn data_dtype(&self) -> Result<DType, crate::error::Error> {
self.components.get("data").map(|c| c.dtype).ok_or_else(|| {
crate::error::Error::InvalidFileStructure("Missing 'data' component".to_string())
})
}
pub fn dense(shape: Vec<u64>, dtype: DType, offset: u64, length: u64) -> Self {
let component = Component {
dtype,
r#type: None,
offset,
length,
uncompressed_length: None,
encoding: Encoding::Raw,
digest: None,
};
let mut components = BTreeMap::new();
components.insert("data".to_string(), component);
Self {
shape,
format: Format::Dense,
attributes: None,
components,
}
}
}
/// Top-level metadata: format version string, optional file-wide attributes, and the
/// objects keyed by their string identifier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Manifest {
    /// Format version string (`"1.2.0"` for `Manifest::default()`).
    pub version: String,
    /// Optional free-form CBOR attributes, omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub attributes: Option<BTreeMap<String, ciborium::Value>>,
    /// Objects keyed by identifier, kept in sorted key order by `BTreeMap`.
    pub objects: BTreeMap<String, Object>,
}
impl Default for Manifest {
fn default() -> Self {
Self {
version: "1.2.0".to_string(),
attributes: None,
objects: BTreeMap::new(),
}
}
}
/// Checksum selection; `None` (the default) means no checksum.
///
/// Unlike the other enums here, this type has no serde derives.
/// NOTE(review): how it relates to `Component::digest` is not shown in this file —
/// confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Checksum {
    /// No checksum (the default).
    #[default]
    None,
    /// Presumably CRC-32C (Castagnoli).
    Crc32c,
    /// Presumably SHA-256.
    Sha256,
}
#[derive(Debug, Clone)]
pub struct CooTensor<T> {
pub shape: Vec<u64>,
pub indices: Vec<Vec<u64>>,
pub values: Vec<T>,
}
#[derive(Debug, Clone)]
pub struct CsrTensor<T> {
pub shape: Vec<u64>,
pub indptr: Vec<u64>,
pub indices: Vec<u64>,
pub values: Vec<T>,
}