use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use crate::dtype::Dtype;
use crate::error::Result;
use crate::error::TensogramError;
pub use tensogram_encodings::ByteOrder;
/// Descriptor for one mask entry: a method name, an `offset`/`length` pair and
/// optional free-form parameters.
///
/// NOTE(review): `offset` and `length` look like a byte range locating the
/// encoded mask — confirm the unit and the origin against the writer.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MaskDescriptor {
/// Name of the mask method (the unit tests in this file use `"roaring"`).
pub method: String,
/// Start position of the mask (presumably in bytes — TODO confirm).
pub offset: u64,
/// Extent of the mask, paired with `offset` (presumably in bytes — TODO confirm).
pub length: u64,
/// Extra key/value parameters. Defaults to an empty map when the key is
/// missing on input and is left out of the serialized form when empty.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub params: BTreeMap<String, ciborium::Value>,
}
/// Up to three optional mask descriptors, one slot each for `nan`, `pos_inf`
/// and `neg_inf`.
///
/// Every slot defaults to `None` when its key is missing on input and is left
/// out of the serialized form when `None`. `Default` yields all three slots
/// empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct MasksMetadata {
/// Mask descriptor stored under the key `nan`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub nan: Option<MaskDescriptor>,
/// Mask descriptor stored under the key `inf+` (note the rename).
#[serde(rename = "inf+", default, skip_serializing_if = "Option::is_none")]
pub pos_inf: Option<MaskDescriptor>,
/// Mask descriptor stored under the key `inf-` (note the rename).
#[serde(rename = "inf-", default, skip_serializing_if = "Option::is_none")]
pub neg_inf: Option<MaskDescriptor>,
}
impl MasksMetadata {
    /// Reports whether no mask of any kind is recorded.
    ///
    /// Returns `true` only when `nan`, `pos_inf` and `neg_inf` are all `None`;
    /// a single populated slot makes the result `false`.
    pub fn is_empty(&self) -> bool {
        let slots = [&self.nan, &self.pos_inf, &self.neg_inf];
        slots.iter().all(|slot| slot.is_none())
    }
}
/// Descriptor of one data object: its kind, shape information, element type,
/// byte order, processing-stage names, optional masks and any additional
/// keys.
///
/// Nothing in this type enforces consistency between `ndim`, `shape` and
/// `strides`; NOTE(review): confirm where that validation happens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataObjectDescriptor {
/// Object kind, stored under the key `type` (the unit tests use `"ntensor"`).
#[serde(rename = "type")]
pub obj_type: String,
/// Declared number of dimensions (presumably equal to `shape.len()` — TODO confirm).
pub ndim: u64,
/// Extent of each dimension; `num_elements` multiplies these together.
pub shape: Vec<u64>,
/// Per-dimension strides (unit — elements or bytes — not visible here; TODO confirm).
pub strides: Vec<u64>,
/// Element data type.
pub dtype: Dtype,
/// Byte order of the data; falls back to `ByteOrder::native()` when the key
/// is missing on input.
#[serde(default = "ByteOrder::native")]
pub byte_order: ByteOrder,
/// Name of the encoding stage (the unit tests use `"none"`).
pub encoding: String,
/// Name of the filter stage (the unit tests use `"none"`).
pub filter: String,
/// Name of the compression stage (the unit tests use `"none"`).
pub compression: String,
/// Optional mask metadata; defaults to `None` when missing on input and is
/// left out of the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub masks: Option<MasksMetadata>,
/// Flattened catch-all map: on input it collects every key not matched by a
/// named field above, and on output its entries are written inline at the
/// same level as the named fields.
#[serde(flatten)]
pub params: BTreeMap<String, ciborium::Value>,
}
impl DataObjectDescriptor {
    /// Returns the total element count, i.e. the product of every entry in
    /// `shape`.
    ///
    /// An empty `shape` yields `1` (the empty product). A zero dimension
    /// yields `0`, provided no overflow occurred before it was reached.
    ///
    /// # Errors
    ///
    /// Returns [`TensogramError::Metadata`] when the running product no longer
    /// fits in `u64`, or when the final product does not fit in `usize`.
    /// Overflow is checked dimension by dimension in `shape` order, so a zero
    /// that appears after an overflowing prefix still results in an error.
    #[inline]
    pub fn num_elements(&self) -> Result<usize> {
        let mut total: u64 = 1;
        for &dim in &self.shape {
            total = match total.checked_mul(dim) {
                Some(next) => next,
                None => {
                    return Err(TensogramError::Metadata(
                        "shape product overflow".to_string(),
                    ));
                }
            };
        }

        // Narrow to the platform word size last; this can only fail on
        // targets where `usize` is smaller than `u64`.
        match usize::try_from(total) {
            Ok(count) => Ok(count),
            Err(_) => Err(TensogramError::Metadata(
                "element count overflows usize".to_string(),
            )),
        }
    }
}
/// Message-level metadata made of three optional sections: a list of `base`
/// maps plus the `_reserved_` and `_extra_` maps.
///
/// Every section defaults to empty when its key is missing on input and is
/// left out of the serialized form when empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GlobalMetadata {
/// List of key/value maps (presumably one entry per data object — TODO confirm).
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub base: Vec<BTreeMap<String, ciborium::Value>>,
/// Map stored under the key `_reserved_`.
#[serde(
rename = "_reserved_",
default,
skip_serializing_if = "BTreeMap::is_empty"
)]
pub reserved: BTreeMap<String, ciborium::Value>,
/// Map stored under the key `_extra_`.
#[serde(
rename = "_extra_",
default,
skip_serializing_if = "BTreeMap::is_empty"
)]
pub extra: BTreeMap<String, ciborium::Value>,
}
/// Pair of offset and length lists; `Default` yields two empty vectors.
///
/// NOTE(review): `offsets` and `lengths` look like parallel arrays with one
/// entry per indexed object — nothing here enforces equal lengths; confirm
/// against the code that builds and reads the frame.
#[derive(Debug, Clone, Default)]
pub struct IndexFrame {
/// Offsets of the indexed items (unit and origin not visible here — TODO confirm).
pub offsets: Vec<u64>,
/// Lengths of the indexed items (presumably paired with `offsets` by position).
pub lengths: Vec<u64>,
}
/// A hash algorithm name together with a list of hash strings.
///
/// NOTE(review): presumably one hash per data object, in object order, with
/// the string encoding (e.g. hex) defined by the writer — confirm.
#[derive(Debug, Clone)]
pub struct HashFrame {
/// Name of the hash algorithm that produced `hashes`.
pub algorithm: String,
/// Hash values as strings.
pub hashes: Vec<String>,
}
/// A data object's descriptor paired with a byte buffer (presumably the
/// object's decoded payload — TODO confirm against the decoder).
pub type DecodedObject = (DataObjectDescriptor, Vec<u8>);
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a descriptor with the given `shape`; every other field gets a
    /// fixed placeholder value.
    fn descriptor_with_shape(shape: Vec<u64>) -> DataObjectDescriptor {
        // Read the length before `shape` is moved into the struct.
        let ndim = shape.len() as u64;
        DataObjectDescriptor {
            obj_type: String::from("ntensor"),
            ndim,
            shape,
            strides: Vec::new(),
            dtype: Dtype::Float32,
            byte_order: ByteOrder::native(),
            encoding: String::from("none"),
            filter: String::from("none"),
            compression: String::from("none"),
            masks: None,
            params: BTreeMap::new(),
        }
    }

    /// Asserts that `err` is a `Metadata` error whose message contains `needle`.
    fn assert_metadata_error_contains(err: TensogramError, needle: &str) {
        match err {
            TensogramError::Metadata(msg) => {
                assert!(msg.contains(needle), "unexpected message: {msg}");
            }
            other => panic!("expected Metadata error, got: {other:?}"),
        }
    }

    // ── MasksMetadata::is_empty ──────────────────────────────────────────

    #[test]
    fn masks_metadata_is_empty_detects_every_kind_absent() {
        assert!(MasksMetadata::default().is_empty());
    }

    #[test]
    fn masks_metadata_is_empty_false_when_any_kind_present() {
        let sample_mask = || MaskDescriptor {
            method: String::from("roaring"),
            offset: 0,
            length: 1,
            params: BTreeMap::new(),
        };

        // One populated slot at a time: nan, then inf+, then inf-.
        let single_kind = [
            MasksMetadata {
                nan: Some(sample_mask()),
                ..Default::default()
            },
            MasksMetadata {
                pos_inf: Some(sample_mask()),
                ..Default::default()
            },
            MasksMetadata {
                neg_inf: Some(sample_mask()),
                ..Default::default()
            },
        ];

        for masks in &single_kind {
            assert!(!masks.is_empty());
        }
    }

    // ── DataObjectDescriptor deserialization ─────────────────────────────

    #[test]
    fn descriptor_deserialize_defaults_byte_order_to_native() {
        // No "byte_order" key on purpose: the serde default must kick in.
        let payload = r#"{
"type": "ntensor",
"ndim": 1,
"shape": [4],
"strides": [1],
"dtype": "float32",
"encoding": "none",
"filter": "none",
"compression": "none"
}"#;
        let parsed: DataObjectDescriptor =
            serde_json::from_str(payload).expect("deserialize should succeed without byte_order");
        assert_eq!(parsed.byte_order, ByteOrder::native());
    }

    #[test]
    fn descriptor_deserialize_honours_explicit_byte_order() {
        let cases = [("little", ByteOrder::Little), ("big", ByteOrder::Big)];
        for (literal, want) in cases {
            // `literal` is captured by name inside the format string below.
            let json = format!(
                r#"{{
"type": "ntensor", "ndim": 1, "shape": [4], "strides": [1],
"dtype": "float32", "byte_order": "{literal}",
"encoding": "none", "filter": "none", "compression": "none"
}}"#
            );
            let parsed: DataObjectDescriptor =
                serde_json::from_str(&json).expect("deserialize should accept explicit byte_order");
            assert_eq!(parsed.byte_order, want);
        }
    }

    // ── DataObjectDescriptor::num_elements ───────────────────────────────

    #[test]
    fn num_elements_empty_shape_is_one() {
        let count = descriptor_with_shape(vec![]).num_elements().unwrap();
        assert_eq!(count, 1);
    }

    #[test]
    fn num_elements_single_dim() {
        let count = descriptor_with_shape(vec![100]).num_elements().unwrap();
        assert_eq!(count, 100);
    }

    #[test]
    fn num_elements_multi_dim() {
        let count = descriptor_with_shape(vec![3, 4, 5]).num_elements().unwrap();
        assert_eq!(count, 60);
    }

    #[test]
    fn num_elements_zero_dim_yields_zero() {
        let count = descriptor_with_shape(vec![10, 0, 5]).num_elements().unwrap();
        assert_eq!(count, 0);
    }

    #[test]
    fn num_elements_u64_overflow_is_metadata_error() {
        let failure = descriptor_with_shape(vec![u64::MAX, 2])
            .num_elements()
            .unwrap_err();
        assert_metadata_error_contains(failure, "shape product overflow");
    }

    // Only meaningful where `usize` is narrower than `u64`.
    #[cfg(target_pointer_width = "32")]
    #[test]
    fn num_elements_usize_overflow_is_metadata_error_on_32bit() {
        let failure = descriptor_with_shape(vec![(usize::MAX as u64) + 1])
            .num_elements()
            .unwrap_err();
        assert_metadata_error_contains(failure, "element count overflows usize");
    }
}