use std::collections::BTreeMap;
use crate::dtype::Dtype;
use crate::error::{Result, TensogramError};
use crate::framing::{self, EncodedObject};
use crate::metadata::RESERVED_KEY;
use crate::substitute_and_mask::{self, MaskSet};
use crate::types::{DataObjectDescriptor, GlobalMetadata, MaskDescriptor, MasksMetadata};
pub use tensogram_encodings::bitmask::MaskMethod;
#[cfg(feature = "blosc2")]
use tensogram_encodings::pipeline::Blosc2Codec;
#[cfg(feature = "sz3")]
use tensogram_encodings::pipeline::Sz3ErrorBound;
#[cfg(feature = "zfp")]
use tensogram_encodings::pipeline::ZfpMode;
use tensogram_encodings::pipeline::{
self, ByteOrder, CompressionType, EncodingType, FilterType, PipelineConfig,
};
use tensogram_encodings::simple_packing::{self, SimplePackingParams};
/// Options controlling how [`encode`] and [`encode_pre_encoded`] build a message.
#[derive(Debug, Clone)]
pub struct EncodeOptions {
/// Forwarded to `framing::encode_message` as its hashing flag; in raw mode it
/// also requests an inline hash from the encoding pipeline.
pub hashing: bool,
/// Compression backend forwarded into the pipeline configuration.
pub compression_backend: pipeline::CompressionBackend,
/// Thread budget request, resolved by `crate::parallel::resolve_budget`.
/// NOTE(review): the default is 0 — presumably "let the library decide"; confirm in `parallel`.
pub threads: u32,
/// Optional size threshold forwarded to `crate::parallel::should_parallelise`.
pub parallel_threshold_bytes: Option<usize>,
/// Forwarded to `substitute_and_mask` for raw-mode encoding; presumably allows NaN
/// values to be masked instead of rejected — confirm in `substitute_and_mask`.
pub allow_nan: bool,
/// Like `allow_nan`, but for infinities.
pub allow_inf: bool,
/// Requested codec for the NaN mask (may be downgraded for small masks).
pub nan_mask_method: MaskMethod,
/// Requested codec for the +Inf mask (may be downgraded for small masks).
pub pos_inf_mask_method: MaskMethod,
/// Requested codec for the -Inf mask (may be downgraded for small masks).
pub neg_inf_mask_method: MaskMethod,
/// A mask whose bit-packed size (elements / 8, rounded up) is at most this many
/// bytes is stored with `MaskMethod::None` regardless of the requested method.
/// A value of 0 disables this fallback.
pub small_mask_threshold_bytes: usize,
/// Where message-level aggregate hash frames go; `Auto` resolves to `Header`
/// for buffered encoding.
pub aggregate_hash: AggregateHashPolicy,
}
/// Placement of the message-level aggregate hash frame(s).
///
/// `Auto` is resolved per encoding mode: buffered encoding maps it to
/// `Header`, streaming encoding maps it to `Footer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AggregateHashPolicy {
    /// Let the encoder choose based on the encoding mode (the default).
    #[default]
    Auto,
    /// Emit no aggregate hash frame.
    None,
    /// Emit the aggregate hash frame in the header only.
    Header,
    /// Emit the aggregate hash frame in the footer only.
    Footer,
    /// Emit the aggregate hash frame in both header and footer.
    Both,
}

impl AggregateHashPolicy {
    /// Resolve `Auto` for buffered (whole-message) encoding, where every
    /// object hash is known before the header is written.
    pub(crate) fn resolved_buffered(self) -> Self {
        if self == Self::Auto {
            Self::Header
        } else {
            self
        }
    }

    /// Resolve `Auto` for streaming encoding.
    ///
    /// # Errors
    /// `Header` and `Both` are rejected, because a streaming header is
    /// written before any object hash exists.
    pub(crate) fn resolved_streaming(self) -> Result<Self> {
        let variant = match self {
            Self::Auto => return Ok(Self::Footer),
            Self::None | Self::Footer => return Ok(self),
            Self::Header => "Header",
            Self::Both => "Both",
        };
        Err(TensogramError::Encoding(format!(
            "AggregateHashPolicy::{variant} is not supported in streaming mode \
             — the header is written before any data object, so per-object \
             hashes are not yet known. Use Auto (defaults to Footer in \
             streaming) or Footer explicitly."
        )))
    }

    /// Whether a (resolved) policy places a hash frame in the header.
    pub(crate) fn emits_header(self) -> bool {
        matches!(self, Self::Header | Self::Both)
    }

    /// Whether a (resolved) policy places a hash frame in the footer.
    pub(crate) fn emits_footer(self) -> bool {
        matches!(self, Self::Footer | Self::Both)
    }
}
impl Default for EncodeOptions {
fn default() -> Self {
Self {
hashing: true,
compression_backend: pipeline::CompressionBackend::default(),
threads: 0,
parallel_threshold_bytes: None,
allow_nan: false,
allow_inf: false,
nan_mask_method: MaskMethod::default(),
pos_inf_mask_method: MaskMethod::default(),
neg_inf_mask_method: MaskMethod::default(),
small_mask_threshold_bytes: 128,
aggregate_hash: AggregateHashPolicy::Auto,
}
}
}
pub(crate) fn validate_object(desc: &DataObjectDescriptor, data_len: usize) -> Result<()> {
if desc.obj_type.is_empty() {
return Err(TensogramError::Metadata(
"obj_type must not be empty".to_string(),
));
}
if desc.ndim as usize != desc.shape.len() {
return Err(TensogramError::Metadata(format!(
"ndim {} does not match shape.len() {}",
desc.ndim,
desc.shape.len()
)));
}
if desc.strides.len() != desc.shape.len() {
return Err(TensogramError::Metadata(format!(
"strides.len() {} does not match shape.len() {}",
desc.strides.len(),
desc.shape.len()
)));
}
if desc.encoding == "none" {
let product = desc
.shape
.iter()
.try_fold(1u64, |acc, &x| acc.checked_mul(x))
.ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
if desc.dtype.byte_width() > 0 {
let expected_bytes = product
.checked_mul(desc.dtype.byte_width() as u64)
.ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
if expected_bytes != data_len as u64 {
return Err(TensogramError::Metadata(format!(
"data_len {data_len} does not match expected {expected_bytes} bytes from shape and dtype"
)));
}
} else if desc.dtype == crate::Dtype::Bitmask {
let expected_bytes = product.div_ceil(8);
if expected_bytes != data_len as u64 {
return Err(TensogramError::Metadata(format!(
"data_len {data_len} does not match expected {expected_bytes} bytes for bitmask (ceil({product}/8))"
)));
}
}
}
if let Some(masks) = &desc.masks {
validate_mask_params(masks)?;
}
Ok(())
}
/// Param keys a mask descriptor may carry for a given method name.
///
/// Returns `None` for an unrecognised method name.
fn mask_method_allowed_params(method: &str) -> Option<&'static [&'static str]> {
    const NO_PARAMS: &[&str] = &[];
    const ZSTD_PARAMS: &[&str] = &["level"];
    const BLOSC2_PARAMS: &[&str] = &["codec", "level"];

    let allowed = match method {
        "zstd" => ZSTD_PARAMS,
        "blosc2" => BLOSC2_PARAMS,
        "none" | "rle" | "roaring" | "lz4" => NO_PARAMS,
        _ => return None,
    };
    Some(allowed)
}
/// Validate one mask descriptor: its method must be known, and each of its
/// param keys must be allowed for that method.
///
/// `kind` is the mask label used in error messages (e.g. `"nan"`).
fn validate_mask_descriptor(kind: &str, md: &MaskDescriptor) -> Result<()> {
    let Some(allowed) = mask_method_allowed_params(&md.method) else {
        return Err(TensogramError::Metadata(format!(
            "mask {kind} has unknown method {method:?}; \
             expected one of: none, rle, roaring, lz4, zstd, blosc2",
            method = md.method,
        )));
    };

    // Params are a BTreeMap, so the first offending key is reported in
    // sorted order.
    match md.params.keys().find(|key| !allowed.contains(&key.as_str())) {
        Some(key) => Err(TensogramError::Metadata(format!(
            "mask {kind} (method {method:?}) has unknown param {key:?}; \
             allowed for this method: {allowed:?}",
            method = md.method,
        ))),
        None => Ok(()),
    }
}
/// Validate every mask descriptor present in `masks`, in nan, +inf, -inf order.
fn validate_mask_params(masks: &MasksMetadata) -> Result<()> {
    let labelled = [
        ("nan", masks.nan.as_ref()),
        ("inf+", masks.pos_inf.as_ref()),
        ("inf-", masks.neg_inf.as_ref()),
    ];
    for (label, descriptor) in labelled {
        if let Some(descriptor) = descriptor {
            validate_mask_descriptor(label, descriptor)?;
        }
    }
    Ok(())
}
/// How `encode_one_object` treats the input bytes of each object.
#[derive(Debug, Clone, Copy)]
enum EncodeMode {
/// Raw values: NaN/Inf substitution, simple-packing param resolution and the
/// encoding pipeline are all run.
Raw,
/// Bytes are already encoded by the caller and are copied through after
/// szip-offset validation; no pipeline run and no masks.
PreEncoded,
}
/// Encode one data object into its final descriptor and payload bytes.
///
/// In `EncodeMode::Raw` the input goes through these steps:
/// 1. NaN/Inf substitution (`substitute_and_mask`).
/// 2. Simple-packing parameter resolution.
/// 3. The encoding pipeline.
///
/// Any NaN/Inf masks produced are then appended after the encoded bytes by
/// `compose_payload_region`, and their locations are recorded in
/// `descriptor.masks`. If the pipeline reports szip block offsets, they are
/// stored in the descriptor params.
///
/// In `EncodeMode::PreEncoded` the bytes are copied through unchanged after
/// the szip-offset checks.
///
/// `intra_codec_threads` is forwarded to the pipeline configuration and to
/// the substitution parallelism heuristic.
fn encode_one_object(
    desc: &DataObjectDescriptor,
    data: &[u8],
    mode: EncodeMode,
    options: &EncodeOptions,
    intra_codec_threads: u32,
) -> Result<EncodedObject> {
    validate_object(desc, data.len())?;
    let is_raw = matches!(mode, EncodeMode::Raw);

    let (pipeline_input, mask_set) = if is_raw {
        let parallel = crate::parallel::should_parallelise(
            intra_codec_threads,
            data.len(),
            options.parallel_threshold_bytes,
        );
        substitute_and_mask::substitute_and_mask(
            data,
            desc.dtype,
            desc.byte_order,
            options.allow_nan,
            options.allow_inf,
            parallel,
        )?
    } else {
        (std::borrow::Cow::Borrowed(data), MaskSet::empty(0))
    };

    let num_elements = desc.num_elements()?;
    let dtype = desc.dtype;
    let mut final_desc = desc.clone();
    if is_raw {
        // Auto-computed simple-packing params must describe the values the
        // pipeline actually encodes. With allow_nan/allow_inf those are the
        // substituted values in `pipeline_input`, not the original `data`
        // (which may still contain NaN/Inf). When nothing was substituted,
        // `pipeline_input` presumably borrows `data`, so results are unchanged.
        resolve_simple_packing_params(&mut final_desc, pipeline_input.as_ref())?;
    }

    let mut config = build_pipeline_config_with_backend(
        &final_desc,
        num_elements,
        dtype,
        options.compression_backend,
        intra_codec_threads,
    )?;
    // Only raw mode runs the pipeline, so only raw mode can hash inline.
    config.compute_hash = is_raw && options.hashing;

    let (encoded_payload, inline_hash) = match mode {
        EncodeMode::Raw => {
            let result = pipeline::encode_pipeline(pipeline_input.as_ref(), &config)
                .map_err(|e| TensogramError::Encoding(e.to_string()))?;
            if let Some(offsets) = &result.block_offsets {
                final_desc.params.insert(
                    "szip_block_offsets".to_string(),
                    ciborium::Value::Array(
                        offsets
                            .iter()
                            .map(|&o| ciborium::Value::Integer(o.into()))
                            .collect(),
                    ),
                );
            }
            (result.encoded_bytes, result.hash)
        }
        EncodeMode::PreEncoded => {
            validate_no_szip_offsets_for_non_szip(desc)?;
            if desc.compression == "szip" && desc.params.contains_key("szip_block_offsets") {
                validate_szip_block_offsets(&desc.params, data.len())?;
            }
            (data.to_vec(), None)
        }
    };

    let (payload_region, masks_metadata) = compose_payload_region(
        encoded_payload,
        mask_set,
        &options.nan_mask_method,
        &options.pos_inf_mask_method,
        &options.neg_inf_mask_method,
        options.small_mask_threshold_bytes,
    )?;
    if let Some(m) = masks_metadata {
        final_desc.masks = Some(m);
    }

    // NOTE(review): the pipeline's inline hash is currently discarded —
    // `EncodedObject` has no field for it. Kept so the pipeline config stays
    // unchanged; consider wiring it through or not requesting it.
    let _ = inline_hash;

    Ok(EncodedObject {
        descriptor: final_desc,
        encoded_payload: payload_region,
    })
}
/// Shared implementation of [`encode`] and [`encode_pre_encoded`].
///
/// Steps:
/// 1. Validate the caller's global metadata.
/// 2. Encode every object. Objects run in parallel across a pool ("axis A")
///    when the heuristics choose it; otherwise the thread budget goes to the
///    codecs ("axis B").
/// 3. Add the library-owned reserved metadata.
/// 4. Frame the message with the resolved aggregate-hash policy.
///
/// # Errors
/// - Metadata errors if the caller wrote reserved keys, or supplied more
///   `base` entries than descriptors.
/// - Any error from thread-budget resolution, per-object encoding, or framing.
fn encode_inner(
    global_metadata: &GlobalMetadata,
    descriptors: &[(&DataObjectDescriptor, &[u8])],
    options: &EncodeOptions,
    mode: EncodeMode,
) -> Result<Vec<u8>> {
    // Validate caller-supplied metadata first. These checks do not depend on
    // the encoded objects, so running them up front avoids a full (possibly
    // parallel) encode of a message that is going to be rejected anyway.
    validate_no_client_reserved(global_metadata)?;
    if global_metadata.base.len() > descriptors.len() {
        return Err(TensogramError::Metadata(format!(
            "metadata base has {} entries but only {} descriptors provided; \
             extra base entries would be discarded",
            global_metadata.base.len(),
            descriptors.len()
        )));
    }

    let budget = crate::parallel::resolve_budget(options.threads)?;
    let total_bytes: usize = descriptors.iter().map(|(_, d)| d.len()).sum();
    let parallel =
        crate::parallel::should_parallelise(budget, total_bytes, options.parallel_threshold_bytes);
    let any_axis_b = descriptors
        .iter()
        .any(|(d, _)| crate::parallel::is_axis_b_friendly(&d.encoding, &d.filter, &d.compression));
    let use_axis_a = parallel && crate::parallel::use_axis_a(descriptors.len(), budget, any_axis_b);
    // When objects are encoded in parallel, codecs run single-threaded;
    // otherwise the whole budget is handed to the codecs.
    let intra_codec_threads = if parallel && !use_axis_a { budget } else { 0 };
    let encode_one = |(desc, data): &(&DataObjectDescriptor, &[u8])| {
        encode_one_object(desc, data, mode, options, intra_codec_threads)
    };

    let encoded_objects: Vec<EncodedObject> = if use_axis_a {
        #[cfg(feature = "threads")]
        {
            use rayon::prelude::*;
            crate::parallel::with_pool(budget, || {
                descriptors
                    .par_iter()
                    .map(&encode_one)
                    .collect::<Result<Vec<_>>>()
            })?
        }
        #[cfg(not(feature = "threads"))]
        {
            descriptors.iter().map(encode_one).collect::<Result<_>>()?
        }
    } else {
        crate::parallel::run_maybe_pooled(budget, parallel, intra_codec_threads, || {
            descriptors.iter().map(encode_one).collect::<Result<_>>()
        })?
    };

    let mut enriched_meta = global_metadata.clone();
    populate_base_entries(&mut enriched_meta.base, &encoded_objects);
    populate_reserved_provenance(&mut enriched_meta.reserved);
    let resolved = options.aggregate_hash.resolved_buffered();
    let hash_policy = framing::HashFramePolicy {
        header: resolved.emits_header(),
        footer: resolved.emits_footer(),
    };
    framing::encode_message(
        &enriched_meta,
        &encoded_objects,
        options.hashing,
        hash_policy,
    )
}
/// Encode raw (not yet encoded) object bytes into a complete message.
///
/// Each `(descriptor, bytes)` pair is run through the descriptor's
/// encoding/filter/compression pipeline.
#[tracing::instrument(skip(global_metadata, descriptors, options), fields(objects = descriptors.len()))]
pub fn encode(
    global_metadata: &GlobalMetadata,
    descriptors: &[(&DataObjectDescriptor, &[u8])],
    options: &EncodeOptions,
) -> Result<Vec<u8>> {
    let mode = EncodeMode::Raw;
    encode_inner(global_metadata, descriptors, options, mode)
}
/// Build a message from payloads that the caller has already encoded.
///
/// The bytes are framed as-is. No pipeline runs and no NaN/Inf masks are
/// produced; the szip block offsets are still validated.
#[tracing::instrument(name = "encode_pre_encoded", skip_all, fields(num_objects = descriptors.len()))]
pub fn encode_pre_encoded(
    global_metadata: &GlobalMetadata,
    descriptors: &[(&DataObjectDescriptor, &[u8])],
    options: &EncodeOptions,
) -> Result<Vec<u8>> {
    let mode = EncodeMode::PreEncoded;
    encode_inner(global_metadata, descriptors, options, mode)
}
/// Reject caller metadata that writes to the library-owned reserved key.
///
/// The check covers both the message-level reserved map and every `base`
/// entry; the first offending `base` index is reported.
fn validate_no_client_reserved(meta: &GlobalMetadata) -> Result<()> {
    if !meta.reserved.is_empty() {
        return Err(TensogramError::Metadata(format!(
            "client code must not write to '{RESERVED_KEY}' at message level; \
             this field is populated by the library"
        )));
    }
    match meta
        .base
        .iter()
        .position(|entry| entry.contains_key(RESERVED_KEY))
    {
        Some(i) => Err(TensogramError::Metadata(format!(
            "client code must not write to '{RESERVED_KEY}' in base[{i}]; \
             this field is populated by the library"
        ))),
        None => Ok(()),
    }
}
/// Give every encoded object a `base` entry that records its tensor layout.
///
/// `base` is resized to exactly one entry per encoded object; new entries
/// are empty maps. Each entry then gets a `RESERVED_KEY` map of the form
/// `{"tensor": {"ndim", "shape", "strides", "dtype"}}`, built from the
/// object's final descriptor.
pub(crate) fn populate_base_entries(
    base: &mut Vec<BTreeMap<String, ciborium::Value>>,
    encoded_objects: &[crate::framing::EncodedObject],
) {
    use ciborium::Value;

    let key = |name: &str| Value::Text(name.to_string());

    base.resize_with(encoded_objects.len(), BTreeMap::new);
    for (slot, object) in base.iter_mut().zip(encoded_objects) {
        let d = &object.descriptor;
        let shape: Vec<Value> = d.shape.iter().map(|&n| Value::Integer(n.into())).collect();
        let strides: Vec<Value> = d.strides.iter().map(|&n| Value::Integer(n.into())).collect();
        let tensor = Value::Map(vec![
            (key("ndim"), Value::Integer(d.ndim.into())),
            (key("shape"), Value::Array(shape)),
            (key("strides"), Value::Array(strides)),
            (key("dtype"), Value::Text(d.dtype.to_string())),
        ]);
        slot.insert(
            RESERVED_KEY.to_string(),
            Value::Map(vec![(key("tensor"), tensor)]),
        );
    }
}
/// Write library provenance into the message-level reserved map.
///
/// The following keys are written:
/// - `"encoder"`: `{name: "tensogram", version: CARGO_PKG_VERSION}`.
/// - `"time"`: a UTC `YYYY-MM-DDTHH:MM:SSZ` stamp (non-wasm32 targets only).
/// - `"uuid"`: a fresh random v4 UUID.
pub(crate) fn populate_reserved_provenance(reserved: &mut BTreeMap<String, ciborium::Value>) {
    use ciborium::Value;

    let encoder = Value::Map(vec![
        (
            Value::Text("name".to_string()),
            Value::Text("tensogram".to_string()),
        ),
        (
            Value::Text("version".to_string()),
            Value::Text(env!("CARGO_PKG_VERSION").to_string()),
        ),
    ]);
    reserved.insert("encoder".to_string(), encoder);

    #[cfg(not(target_arch = "wasm32"))]
    {
        use std::time::SystemTime;
        const SECS_PER_DAY: u64 = 86400;

        // A clock set before the epoch yields 0 seconds instead of failing.
        let epoch_secs = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let (whole_days, secs_of_day) = (epoch_secs / SECS_PER_DAY, epoch_secs % SECS_PER_DAY);
        let hh = secs_of_day / 3600;
        let mm = secs_of_day % 3600 / 60;
        let ss = secs_of_day % 60;
        let (year, month, day) = civil_from_days(whole_days as i64);
        reserved.insert(
            "time".to_string(),
            Value::Text(format!(
                "{year:04}-{month:02}-{day:02}T{hh:02}:{mm:02}:{ss:02}Z"
            )),
        );
    }

    reserved.insert(
        "uuid".to_string(),
        Value::Text(uuid::Uuid::new_v4().to_string()),
    );
}
/// Convert a day count relative to 1970-01-01 into a proleptic-Gregorian
/// `(year, month, day)` triple, with month and day 1-based.
///
/// The caller passes whole days since the UNIX epoch. This is Howard
/// Hinnant's `civil_from_days` algorithm. It works on 400-year "eras" of
/// 146097 days, with years that start on March 1 so the leap day falls at
/// the end of the year.
#[cfg(not(target_arch = "wasm32"))]
fn civil_from_days(days: i64) -> (i64, u32, u32) {
// Shift the origin from 1970-01-01 to 0000-03-01.
let z = days + 719468;
// Era index (floor division, also correct for negative z).
let era = if z >= 0 { z } else { z - 146096 } / 146097;
// Day of era, in [0, 146096].
let doe = (z - era * 146097) as u32;
// Year of era, in [0, 399]; corrects for the 4/100/400 leap-year rules.
let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365;
let y = yoe as i64 + era * 400;
// Day of year counted from March 1, in [0, 365].
let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
// Month index counted from March, in [0, 11].
let mp = (5 * doy + 2) / 153;
let d = doy - (153 * mp + 2) / 5 + 1;
// Map the March-based index back to calendar months 1..=12.
let m = if mp < 10 { mp + 3 } else { mp - 9 };
// January and February belong to the next civil year.
let y = if m <= 2 { y + 1 } else { y };
(y, m, d)
}
/// Build a pipeline configuration with the default compression backend and
/// no intra-codec threading.
pub(crate) fn build_pipeline_config(
    desc: &DataObjectDescriptor,
    num_values: usize,
    dtype: Dtype,
) -> Result<PipelineConfig> {
    let backend = pipeline::CompressionBackend::default();
    let single_threaded: u32 = 0;
    build_pipeline_config_with_backend(desc, num_values, dtype, backend, single_threaded)
}
/// Map the descriptor's `encoding` string to a pipeline `EncodingType`.
///
/// `"simple_packing"` requires a `Float64` dtype and the full set of
/// explicit `sp_*` params.
///
/// # Errors
/// - An unknown encoding name.
/// - A non-float64 dtype for simple packing.
/// - Missing or invalid `sp_*` params.
fn resolve_encoding(desc: &DataObjectDescriptor, dtype: Dtype) -> Result<EncodingType> {
    let name = desc.encoding.as_str();
    if name == "none" {
        return Ok(EncodingType::None);
    }
    if name != "simple_packing" {
        return Err(TensogramError::Encoding(format!(
            "unknown encoding: {name}"
        )));
    }
    if dtype != Dtype::Float64 {
        return Err(TensogramError::Encoding(format!(
            "simple_packing only supports float64 dtype; got {dtype:?}"
        )));
    }
    extract_simple_packing_params(&desc.params).map(EncodingType::SimplePacking)
}
/// Map the descriptor's `filter` string to a pipeline `FilterType`.
///
/// `"shuffle"` requires a `shuffle_element_size` integer param that fits in
/// `usize`.
fn resolve_filter(desc: &DataObjectDescriptor) -> Result<FilterType> {
    match desc.filter.as_str() {
        "none" => Ok(FilterType::None),
        "shuffle" => {
            let raw_size = get_u64_param(&desc.params, "shuffle_element_size")?;
            usize::try_from(raw_size)
                .map(|element_size| FilterType::Shuffle { element_size })
                .map_err(|_| {
                    TensogramError::Metadata("shuffle_element_size out of usize range".to_string())
                })
        }
        unknown => Err(TensogramError::Encoding(format!("unknown filter: {unknown}"))),
    }
}
/// Map the descriptor's `compression` string (plus its params) to a
/// pipeline `CompressionType`.
///
/// Codec-specific details:
/// - szip and blosc2 derive their sample width / typesize from the resolved
///   `encoding` and `filter`.
/// - `rle` and `roaring` accept only `Dtype::Bitmask`.
/// - Codecs behind disabled cargo features fall through to the
///   "unknown compression" error.
///
/// # Errors
/// - Missing, mistyped or out-of-range params.
/// - Unknown mode/codec names.
/// - An unsupported compression name.
#[cfg_attr(
    not(any(feature = "szip", feature = "szip-pure", feature = "blosc2")),
    allow(unused_variables)
)]
fn resolve_compression(
    desc: &DataObjectDescriptor,
    dtype: Dtype,
    encoding: &EncodingType,
    filter: &FilterType,
) -> Result<CompressionType> {
    // Look up a required text param. A present-but-non-text value is reported
    // as a type error rather than as "missing", matching the get_*_param helpers.
    #[cfg(any(feature = "zfp", feature = "sz3"))]
    fn required_text_param<'p>(
        params: &'p BTreeMap<String, ciborium::Value>,
        key: &str,
    ) -> Result<&'p str> {
        match params.get(key) {
            Some(ciborium::Value::Text(s)) => Ok(s.as_str()),
            Some(other) => Err(TensogramError::Metadata(format!(
                "expected text for {key}, got {kind}",
                kind = crate::metadata::cbor_value_kind(other),
            ))),
            None => Err(TensogramError::Metadata(format!(
                "missing required parameter: {key}"
            ))),
        }
    }

    match desc.compression.as_str() {
        "none" => Ok(CompressionType::None),
        #[cfg(any(feature = "szip", feature = "szip-pure"))]
        "szip" => {
            let rsi = u32::try_from(get_u64_param(&desc.params, "szip_rsi")?)
                .map_err(|_| TensogramError::Metadata("szip_rsi out of u32 range".to_string()))?;
            let block_size = u32::try_from(get_u64_param(&desc.params, "szip_block_size")?)
                .map_err(|_| {
                    TensogramError::Metadata("szip_block_size out of u32 range".to_string())
                })?;
            let flags = u32::try_from(get_u64_param(&desc.params, "szip_flags")?)
                .map_err(|_| TensogramError::Metadata("szip_flags out of u32 range".to_string()))?;
            // szip compresses fixed-width samples: packed width for simple
            // packing, bytes after shuffle, otherwise the dtype width.
            let bits_per_sample = match (encoding, filter) {
                (EncodingType::SimplePacking(params), _) => params.bits_per_value,
                (EncodingType::None, FilterType::Shuffle { .. }) => 8,
                (EncodingType::None, FilterType::None) => (dtype.byte_width() * 8) as u32,
            };
            Ok(CompressionType::Szip {
                rsi,
                block_size,
                flags,
                bits_per_sample,
            })
        }
        #[cfg(any(feature = "zstd", feature = "zstd-pure"))]
        "zstd" => {
            let level_i64 = get_i64_param_or_default(&desc.params, "zstd_level", 3)?;
            let level = i32::try_from(level_i64).map_err(|_| {
                TensogramError::Metadata(format!("zstd_level value {level_i64} out of i32 range"))
            })?;
            Ok(CompressionType::Zstd { level })
        }
        #[cfg(feature = "lz4")]
        "lz4" => Ok(CompressionType::Lz4),
        #[cfg(feature = "blosc2")]
        "blosc2" => {
            let codec_str = get_text_param_or_default(&desc.params, "blosc2_codec", "lz4")?;
            let codec = match codec_str {
                "blosclz" => Blosc2Codec::Blosclz,
                "lz4" => Blosc2Codec::Lz4,
                "lz4hc" => Blosc2Codec::Lz4hc,
                "zlib" => Blosc2Codec::Zlib,
                "zstd" => Blosc2Codec::Zstd,
                other => {
                    return Err(TensogramError::Encoding(format!(
                        "unknown blosc2 codec: {other}"
                    )));
                }
            };
            let clevel_i64 = get_i64_param_or_default(&desc.params, "blosc2_clevel", 5)?;
            let clevel = i32::try_from(clevel_i64).map_err(|_| {
                TensogramError::Metadata(format!(
                    "blosc2_clevel value {clevel_i64} out of i32 range"
                ))
            })?;
            let typesize = match (encoding, filter) {
                (EncodingType::SimplePacking(params), _) => {
                    (params.bits_per_value as usize).div_ceil(8)
                }
                (EncodingType::None, FilterType::Shuffle { .. }) => 1,
                (EncodingType::None, FilterType::None) => dtype.byte_width(),
            };
            Ok(CompressionType::Blosc2 {
                codec,
                clevel,
                typesize,
            })
        }
        #[cfg(feature = "zfp")]
        "zfp" => {
            let mode = match required_text_param(&desc.params, "zfp_mode")? {
                "fixed_rate" => {
                    let rate = get_f64_param(&desc.params, "zfp_rate")?;
                    ZfpMode::FixedRate { rate }
                }
                "fixed_precision" => {
                    let precision = u32::try_from(get_u64_param(&desc.params, "zfp_precision")?)
                        .map_err(|_| {
                            TensogramError::Metadata("zfp_precision out of u32 range".to_string())
                        })?;
                    ZfpMode::FixedPrecision { precision }
                }
                "fixed_accuracy" => {
                    let tolerance = get_f64_param(&desc.params, "zfp_tolerance")?;
                    ZfpMode::FixedAccuracy { tolerance }
                }
                other => {
                    return Err(TensogramError::Encoding(format!(
                        "unknown zfp_mode: {other}"
                    )));
                }
            };
            Ok(CompressionType::Zfp { mode })
        }
        #[cfg(feature = "sz3")]
        "sz3" => {
            let mode_str = required_text_param(&desc.params, "sz3_error_bound_mode")?;
            let bound_val = get_f64_param(&desc.params, "sz3_error_bound")?;
            let error_bound = match mode_str {
                "abs" => Sz3ErrorBound::Absolute(bound_val),
                "rel" => Sz3ErrorBound::Relative(bound_val),
                "psnr" => Sz3ErrorBound::Psnr(bound_val),
                other => {
                    return Err(TensogramError::Encoding(format!(
                        "unknown sz3_error_bound_mode: {other}"
                    )));
                }
            };
            Ok(CompressionType::Sz3 { error_bound })
        }
        "rle" => {
            if dtype != Dtype::Bitmask {
                return Err(TensogramError::Encoding(format!(
                    "compression \"rle\" only supports dtype=bitmask, got dtype={:?}",
                    dtype
                )));
            }
            Ok(CompressionType::Rle)
        }
        "roaring" => {
            if dtype != Dtype::Bitmask {
                return Err(TensogramError::Encoding(format!(
                    "compression \"roaring\" only supports dtype=bitmask, got dtype={:?}",
                    dtype
                )));
            }
            Ok(CompressionType::Roaring)
        }
        other => Err(TensogramError::Encoding(format!(
            "unknown compression: {other}"
        ))),
    }
}
/// Resolve a descriptor into a full pipeline configuration.
///
/// Encoding, filter and compression are resolved in that order, and the
/// first error wins. Inline hashing is left off; callers opt in by setting
/// `compute_hash` afterwards.
pub(crate) fn build_pipeline_config_with_backend(
    desc: &DataObjectDescriptor,
    num_values: usize,
    dtype: Dtype,
    compression_backend: pipeline::CompressionBackend,
    intra_codec_threads: u32,
) -> Result<PipelineConfig> {
    let resolved_encoding = resolve_encoding(desc, dtype)?;
    let resolved_filter = resolve_filter(desc)?;
    let resolved_compression =
        resolve_compression(desc, dtype, &resolved_encoding, &resolved_filter)?;
    Ok(PipelineConfig {
        num_values,
        byte_order: desc.byte_order,
        dtype_byte_width: dtype.byte_width(),
        swap_unit_size: dtype.swap_unit_size(),
        encoding: resolved_encoding,
        filter: resolved_filter,
        compression: resolved_compression,
        compression_backend,
        intra_codec_threads,
        compute_hash: false,
    })
}
/// Read the four explicit `sp_*` simple-packing params from a descriptor.
///
/// Params are read in order: reference value (must be finite), binary
/// scale factor, decimal scale factor (both `i32`), and bits per value
/// (`u32`). The first missing, mistyped or out-of-range param is reported.
fn extract_simple_packing_params(
    params: &BTreeMap<String, ciborium::Value>,
) -> Result<SimplePackingParams> {
    let narrow_to_i32 = |key: &str| -> Result<i32> {
        let raw = get_i64_param(params, key)?;
        i32::try_from(raw)
            .map_err(|_| TensogramError::Metadata(format!("{key} out of i32 range")))
    };

    let reference_value = get_f64_param(params, "sp_reference_value")?;
    if !reference_value.is_finite() {
        return Err(TensogramError::Metadata(format!(
            "sp_reference_value must be finite, got {reference_value}"
        )));
    }
    let binary_scale_factor = narrow_to_i32("sp_binary_scale_factor")?;
    let decimal_scale_factor = narrow_to_i32("sp_decimal_scale_factor")?;
    let bits_per_value = u32::try_from(get_u64_param(params, "sp_bits_per_value")?)
        .map_err(|_| TensogramError::Metadata("sp_bits_per_value out of u32 range".to_string()))?;

    Ok(SimplePackingParams {
        reference_value,
        binary_scale_factor,
        decimal_scale_factor,
        bits_per_value,
    })
}
/// Fill in missing simple-packing params on a descriptor.
///
/// This is a no-op unless `encoding == "simple_packing"`. For simple
/// packing, `sp_bits_per_value` is always required, and
/// `sp_reference_value` / `sp_binary_scale_factor` must be supplied together
/// or not at all:
/// - If both are given, `sp_decimal_scale_factor` defaults to 0 and nothing
///   else changes.
/// - If neither is given, all four params are computed from `data_bytes`
///   (decoded as float64 in the descriptor's byte order) and written back.
pub(crate) fn resolve_simple_packing_params(
    desc: &mut DataObjectDescriptor,
    data_bytes: &[u8],
) -> Result<()> {
    if desc.encoding != "simple_packing" {
        return Ok(());
    }
    if desc.dtype != Dtype::Float64 {
        return Err(TensogramError::Encoding(format!(
            "simple_packing only supports float64 dtype; got {:?}",
            desc.dtype
        )));
    }
    if !desc.params.contains_key("sp_bits_per_value") {
        return Err(TensogramError::Metadata(
            "simple_packing requires sp_bits_per_value (the encoder can \
             auto-compute sp_reference_value + sp_binary_scale_factor from \
             the data, but the bit-width and decimal scale are the user \
             knobs). Provide at least sp_bits_per_value, or the full \
             explicit 4-key set."
                .to_string(),
        ));
    }

    let half_pair = |set: &str, missing: &str| {
        TensogramError::Metadata(format!(
            "simple_packing: descriptor sets {set} but not {missing}. \
             Provide both for explicit-params encoding, or neither to \
             let the encoder auto-compute them from the data."
        ))
    };
    match (
        desc.params.contains_key("sp_reference_value"),
        desc.params.contains_key("sp_binary_scale_factor"),
    ) {
        (true, true) => {
            // Explicit params: only the decimal scale gets a default.
            desc.params
                .entry("sp_decimal_scale_factor".to_string())
                .or_insert(ciborium::Value::Integer(0i64.into()));
            return Ok(());
        }
        (true, false) => {
            return Err(half_pair("sp_reference_value", "sp_binary_scale_factor"));
        }
        (false, true) => {
            return Err(half_pair("sp_binary_scale_factor", "sp_reference_value"));
        }
        (false, false) => {}
    }

    // Auto-compute path.
    let bits_per_value = u32::try_from(get_u64_param(&desc.params, "sp_bits_per_value")?)
        .map_err(|_| TensogramError::Metadata("sp_bits_per_value out of u32 range".to_string()))?;
    let decimal_scale_factor =
        get_i64_param_or_default(&desc.params, "sp_decimal_scale_factor", 0).and_then(|raw| {
            i32::try_from(raw).map_err(|_| {
                TensogramError::Metadata("sp_decimal_scale_factor out of i32 range".to_string())
            })
        })?;
    let samples = bytes_as_f64_vec(data_bytes, desc.byte_order)?;
    let computed = simple_packing::compute_params(&samples, bits_per_value, decimal_scale_factor)
        .map_err(|e| TensogramError::Encoding(e.to_string()))?;

    let resolved = [
        (
            "sp_reference_value",
            ciborium::Value::Float(computed.reference_value),
        ),
        (
            "sp_binary_scale_factor",
            ciborium::Value::Integer(i64::from(computed.binary_scale_factor).into()),
        ),
        (
            "sp_decimal_scale_factor",
            ciborium::Value::Integer(i64::from(computed.decimal_scale_factor).into()),
        ),
        (
            "sp_bits_per_value",
            ciborium::Value::Integer(i64::from(computed.bits_per_value).into()),
        ),
    ];
    for (key, value) in resolved {
        desc.params.insert(key.to_string(), value);
    }
    Ok(())
}
/// Decode a byte buffer into float64 values using `byte_order`.
///
/// # Errors
/// - A length that is not a multiple of 8 (Metadata error).
/// - Failure to reserve the output buffer (Encoding error). The reservation
///   is fallible, so huge inputs fail with an error instead of aborting.
fn bytes_as_f64_vec(bytes: &[u8], byte_order: ByteOrder) -> Result<Vec<f64>> {
    const F64_WIDTH: usize = std::mem::size_of::<f64>();

    if !bytes.len().is_multiple_of(F64_WIDTH) {
        return Err(TensogramError::Metadata(format!(
            "simple_packing: input byte length {} is not a multiple of 8 (float64)",
            bytes.len()
        )));
    }
    let count = bytes.len() / F64_WIDTH;

    let mut values: Vec<f64> = Vec::new();
    if let Err(e) = values.try_reserve_exact(count) {
        return Err(TensogramError::Encoding(format!(
            "simple_packing: failed to reserve {} bytes for byte-to-f64 \
             conversion: {e}",
            count.saturating_mul(F64_WIDTH),
        )));
    }

    let from_bytes: fn([u8; F64_WIDTH]) -> f64 = match byte_order {
        ByteOrder::Big => f64::from_be_bytes,
        ByteOrder::Little => f64::from_le_bytes,
    };
    values.extend(bytes.chunks_exact(F64_WIDTH).map(|chunk| {
        let mut word = [0u8; F64_WIDTH];
        word.copy_from_slice(chunk);
        from_bytes(word)
    }));
    Ok(values)
}
/// Largest magnitude (2^53) at which every integer is exactly representable as f64.
const F64_EXACT_INT_BOUND: i128 = 1 << 53;

/// Read a required numeric param as `f64`.
///
/// CBOR floats are returned as-is. CBOR integers are accepted only when
/// their magnitude is at most 2^53, so the conversion is lossless.
pub(crate) fn get_f64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<f64> {
    let Some(value) = params.get(key) else {
        return Err(TensogramError::Metadata(format!(
            "missing required parameter: {key}"
        )));
    };
    match value {
        ciborium::Value::Float(f) => Ok(*f),
        ciborium::Value::Integer(raw) => {
            let n = i128::from(*raw);
            if n.abs() <= F64_EXACT_INT_BOUND {
                Ok(n as f64)
            } else {
                Err(TensogramError::Metadata(format!(
                    "{key}: integer value {n} is outside the f64 \
                     exact-representable range [-2^53, 2^53]; \
                     converting to f64 would silently lose precision. \
                     Supply a float literal or pick a parameter that \
                     accepts integers up to i64::MAX."
                )))
            }
        }
        other => Err(TensogramError::Metadata(format!(
            "expected number for {key}, got {kind}",
            kind = crate::metadata::cbor_value_kind(other),
        ))),
    }
}
/// Read a required integer param that must fit in `i64`.
pub(crate) fn get_i64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<i64> {
    let value = params.get(key).ok_or_else(|| {
        TensogramError::Metadata(format!("missing required parameter: {key}"))
    })?;
    let ciborium::Value::Integer(raw) = value else {
        return Err(TensogramError::Metadata(format!(
            "expected integer for {key}, got {kind}",
            kind = crate::metadata::cbor_value_kind(value),
        )));
    };
    let n = i128::from(*raw);
    i64::try_from(n).map_err(|_| {
        TensogramError::Metadata(format!("integer value {n} out of i64 range for {key}"))
    })
}
/// Read an optional integer param, falling back to `default` when absent.
///
/// A present-but-non-integer value is an error, not a silent fallback.
pub(crate) fn get_i64_param_or_default(
    params: &BTreeMap<String, ciborium::Value>,
    key: &str,
    default: i64,
) -> Result<i64> {
    match params.get(key) {
        None => Ok(default),
        Some(ciborium::Value::Integer(raw)) => {
            let n = i128::from(*raw);
            i64::try_from(n).map_err(|_| {
                TensogramError::Metadata(format!("integer value {n} out of i64 range for {key}"))
            })
        }
        Some(other) => Err(TensogramError::Metadata(format!(
            "expected integer for {key}, got {kind}; \
             if you meant to use the default ({default}), omit the key",
            kind = crate::metadata::cbor_value_kind(other),
        ))),
    }
}
/// Read a required integer param that must fit in `u64`.
pub(crate) fn get_u64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<u64> {
    let value = params.get(key).ok_or_else(|| {
        TensogramError::Metadata(format!("missing required parameter: {key}"))
    })?;
    let ciborium::Value::Integer(raw) = value else {
        return Err(TensogramError::Metadata(format!(
            "expected integer for {key}, got {kind}",
            kind = crate::metadata::cbor_value_kind(value),
        )));
    };
    let n = i128::from(*raw);
    u64::try_from(n).map_err(|_| {
        TensogramError::Metadata(format!("integer value {n} out of u64 range for {key}"))
    })
}
/// Read an optional text param, falling back to `default` when absent.
///
/// A present-but-non-text value is an error, not a silent fallback.
#[cfg(any(feature = "blosc2", test))]
pub(crate) fn get_text_param_or_default<'a>(
    params: &'a BTreeMap<String, ciborium::Value>,
    key: &str,
    default: &'a str,
) -> Result<&'a str> {
    let Some(value) = params.get(key) else {
        return Ok(default);
    };
    if let ciborium::Value::Text(text) = value {
        return Ok(text.as_str());
    }
    Err(TensogramError::Metadata(format!(
        "expected text for {key}, got {kind}; \
         if you meant to use the default ({default:?}), omit the key",
        kind = crate::metadata::cbor_value_kind(value),
    )))
}
/// Validate a caller-supplied `szip_block_offsets` param against a
/// pre-encoded payload of `encoded_bytes_len` bytes.
///
/// The offsets are checked as follows:
/// - They must be a non-empty CBOR array of non-negative integers.
/// - They must start at 0 and be strictly increasing.
/// - None may exceed the payload size in bits.
///
/// # Errors
/// Returns `TensogramError::Metadata` describing the first problem found.
pub(crate) fn validate_szip_block_offsets(
    params: &BTreeMap<String, ciborium::Value>,
    encoded_bytes_len: usize,
) -> Result<()> {
    let Some(value) = params.get("szip_block_offsets") else {
        return Err(TensogramError::Metadata(
            "missing required parameter: szip_block_offsets for szip compression".to_string(),
        ));
    };
    let ciborium::Value::Array(offsets) = value else {
        return Err(TensogramError::Metadata(format!(
            "szip_block_offsets must be an array, got {value:?}"
        )));
    };
    if offsets.is_empty() {
        return Err(TensogramError::Metadata(
            "szip_block_offsets must not be empty; first offset must be 0".to_string(),
        ));
    }

    // Offsets are bit positions, so the upper bound is the payload size in bits.
    let bit_bound = encoded_bytes_len.checked_mul(8).ok_or_else(|| {
        TensogramError::Metadata(format!(
            "encoded byte length {encoded_bytes_len} overflows bit-bound calculation"
        ))
    })?;
    let bit_bound_u64 = u64::try_from(bit_bound).map_err(|_| {
        TensogramError::Metadata(format!(
            "bit-bound {bit_bound} derived from {encoded_bytes_len} bytes does not fit in u64"
        ))
    })?;

    let mut previous: Option<u64> = None;
    for (idx, item) in offsets.iter().enumerate() {
        let ciborium::Value::Integer(raw) = item else {
            return Err(TensogramError::Metadata(format!(
                "szip_block_offsets[{idx}] must be an integer, got {item:?}"
            )));
        };
        let n = i128::from(*raw);
        let offset = u64::try_from(n).map_err(|_| {
            TensogramError::Metadata(format!(
                "szip_block_offsets[{idx}] = {n} out of u64 range"
            ))
        })?;
        if offset > bit_bound_u64 {
            return Err(TensogramError::Metadata(format!(
                "szip_block_offsets[{idx}] = {offset} exceeds bit bound {bit_bound_u64} (encoded_bytes_len = {encoded_bytes_len} bytes, {bit_bound_u64} bits)"
            )));
        }
        match previous {
            // First element: must be exactly 0.
            None if offset != 0 => {
                return Err(TensogramError::Metadata(format!(
                    "szip_block_offsets[0] must be 0, got {offset}"
                )));
            }
            Some(prev) if offset <= prev => {
                return Err(TensogramError::Metadata(format!(
                    "szip_block_offsets must be strictly increasing: szip_block_offsets[{}] = {}, szip_block_offsets[{idx}] = {offset}",
                    idx - 1,
                    prev
                )));
            }
            _ => {}
        }
        previous = Some(offset);
    }
    Ok(())
}
/// Reject a `szip_block_offsets` param on a descriptor whose compression is
/// not `"szip"`.
pub(crate) fn validate_no_szip_offsets_for_non_szip(desc: &DataObjectDescriptor) -> Result<()> {
    let is_szip = desc.compression == "szip";
    let carries_offsets = desc.params.contains_key("szip_block_offsets");
    if is_szip || !carries_offsets {
        return Ok(());
    }
    Err(TensogramError::Metadata(format!(
        "szip_block_offsets provided but compression is '{}', not 'szip'",
        desc.compression
    )))
}
/// Append the NaN, +Inf and -Inf masks (in that order, when present) after
/// the encoded payload.
///
/// Each mask is encoded with its requested method, which `encode_one_mask`
/// may downgrade for small masks. The returned metadata records each mask's
/// method, offset (relative to the start of the returned buffer), length and
/// params. When no masks exist, the payload is returned unchanged with `None`.
pub(crate) fn compose_payload_region(
    mut encoded_payload: Vec<u8>,
    masks: MaskSet,
    nan_method: &MaskMethod,
    pos_inf_method: &MaskMethod,
    neg_inf_method: &MaskMethod,
    small_threshold: usize,
) -> Result<(Vec<u8>, Option<MasksMetadata>)> {
    if masks.is_empty() {
        return Ok((encoded_payload, None));
    }

    let mut metadata = MasksMetadata::default();
    let plan = [
        (&mut metadata.nan, masks.nan.as_ref(), nan_method),
        (&mut metadata.pos_inf, masks.pos_inf.as_ref(), pos_inf_method),
        (&mut metadata.neg_inf, masks.neg_inf.as_ref(), neg_inf_method),
    ];
    for (slot, bits, method) in plan {
        let Some(bits) = bits else { continue };
        let (blob, used_method) = encode_one_mask(bits, method.clone(), small_threshold)?;
        *slot = Some(MaskDescriptor {
            method: used_method.name().to_string(),
            // Each mask starts where the region currently ends.
            offset: encoded_payload.len() as u64,
            length: blob.len() as u64,
            params: mask_params_cbor(&used_method),
        });
        encoded_payload.extend_from_slice(&blob);
    }
    Ok((encoded_payload, Some(metadata)))
}
/// Encode one boolean mask and report the method actually used.
///
/// If `small_threshold` is non-zero and the bit-packed size
/// (`ceil(bits / 8)` bytes) is at most `small_threshold`, the mask is stored
/// with `MaskMethod::None` regardless of `requested`.
fn encode_one_mask(
    bits: &[bool],
    requested: MaskMethod,
    small_threshold: usize,
) -> Result<(Vec<u8>, MaskMethod)> {
    use tensogram_encodings::bitmask;

    let packed_len = bits.len().div_ceil(8);
    let method = match small_threshold {
        0 => requested,
        limit if packed_len <= limit => MaskMethod::None,
        _ => requested,
    };

    let blob = match &method {
        MaskMethod::Rle => bitmask::rle::encode(bits),
        MaskMethod::None => bitmask::codecs::encode_none(bits)
            .map_err(|e| TensogramError::Encoding(format!("bitmask pack: {e}")))?,
        MaskMethod::Roaring => bitmask::roaring::encode(bits)
            .map_err(|e| TensogramError::Encoding(format!("roaring mask encode: {e}")))?,
        MaskMethod::Lz4 => bitmask::codecs::encode_lz4(bits)
            .map_err(|e| TensogramError::Encoding(format!("lz4 mask encode: {e}")))?,
        MaskMethod::Zstd { level } => bitmask::codecs::encode_zstd(bits, *level)
            .map_err(|e| TensogramError::Encoding(format!("zstd mask encode: {e}")))?,
        #[cfg(feature = "blosc2")]
        MaskMethod::Blosc2 { codec, level } => {
            bitmask::codecs::encode_blosc2(bits, *codec, *level)
                .map_err(|e| TensogramError::Encoding(format!("blosc2 mask encode: {e}")))?
        }
    };
    Ok((blob, method))
}
/// Build the CBOR `params` map recorded alongside a mask encoded with `method`.
///
/// - `None`, `Rle`, `Roaring` and `Lz4` take no parameters, so they yield an
///   empty map.
/// - `Zstd` records `"level"` only when a level is set.
/// - `Blosc2` (behind the `blosc2` feature) always records `"codec"`, as a
///   lowercase codec name, and `"level"`.
fn mask_params_cbor(method: &MaskMethod) -> BTreeMap<String, ciborium::Value> {
    let int_value = |n: i64| ciborium::Value::Integer(n.into());
    let mut out = BTreeMap::new();
    match method {
        MaskMethod::Zstd { level: Some(lvl) } => {
            out.insert("level".to_string(), int_value(*lvl as i64));
        }
        #[cfg(feature = "blosc2")]
        MaskMethod::Blosc2 { codec, level } => {
            let codec_name = match codec {
                Blosc2Codec::Blosclz => "blosclz",
                Blosc2Codec::Lz4 => "lz4",
                Blosc2Codec::Lz4hc => "lz4hc",
                Blosc2Codec::Zlib => "zlib",
                Blosc2Codec::Zstd => "zstd",
            };
            out.insert(
                "codec".to_string(),
                ciborium::Value::Text(codec_name.to_string()),
            );
            out.insert("level".to_string(), int_value(*level as i64));
        }
        MaskMethod::None
        | MaskMethod::Rle
        | MaskMethod::Roaring
        | MaskMethod::Lz4
        | MaskMethod::Zstd { level: None } => {}
    }
    out
}
#[cfg(test)]
mod tests {
use super::*;
use crate::decode::{DecodeOptions, decode};
use crate::types::{ByteOrder, GlobalMetadata};
use std::collections::BTreeMap;
/// Build an `"ntensor"` test descriptor for `shape`.
///
/// The descriptor is `float32`, native byte order, with no encoding, filter or
/// compression. Strides are row-major and counted in elements: the last axis
/// has stride 1, and every earlier axis has the product of all later extents.
fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
    let ndim = shape.len() as u64;
    let mut strides = vec![1u64; shape.len()];
    for axis in (1..shape.len()).rev() {
        strides[axis - 1] = strides[axis] * shape[axis];
    }
    DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim,
        shape,
        strides,
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    }
}
#[test]
fn test_base_more_entries_than_descriptors_rejected() {
    // Five `base` entries but only two descriptors: encoding must fail, and
    // the error must mention both counts.
    let meta = GlobalMetadata {
        base: vec![BTreeMap::new(); 5],
        ..Default::default()
    };
    let desc = make_descriptor(vec![4]);
    let payload = vec![0u8; 16];
    let opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let objects = [(&desc, payload.as_slice()), (&desc, payload.as_slice())];
    let err = encode(&meta, &objects, &opts)
        .expect_err("5 base entries with 2 descriptors should fail")
        .to_string();
    assert!(
        err.contains('5') && err.contains('2'),
        "error should mention counts: {err}"
    );
}
#[test]
fn test_base_fewer_entries_than_descriptors_auto_extended() {
    // An empty `base` with three descriptors is padded to three entries, each
    // carrying a `_reserved_` key after decoding.
    let meta = GlobalMetadata {
        base: Vec::new(),
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let objects: Vec<_> = (0..3).map(|_| (&desc, payload.as_slice())).collect();
    let msg = encode(&meta, &objects, &opts).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
    assert_eq!(decoded.base.len(), 3);
    assert!(
        decoded
            .base
            .iter()
            .all(|entry| entry.contains_key("_reserved_")),
        "auto-extended base entry should have _reserved_"
    );
}
#[test]
fn test_base_entry_with_top_level_key_names_no_collision() {
    // Base-entry keys that share names with top-level metadata keys
    // ("version", "base") must round-trip untouched.
    let text = |s: &str| ciborium::Value::Text(s.to_string());
    let entry = BTreeMap::from([
        ("version".to_string(), text("my-version")),
        ("base".to_string(), text("not-the-real-base")),
    ]);
    let meta = GlobalMetadata {
        base: vec![entry],
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let msg = encode(&meta, &[(&desc, payload.as_slice())], &opts).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
    let first = &decoded.base[0];
    assert_eq!(first.get("version"), Some(&text("my-version")));
    assert_eq!(first.get("base"), Some(&text("not-the-real-base")));
}
#[test]
fn test_base_entry_with_deeply_nested_reserved_allowed() {
    // `_reserved_` as a key inside a nested map (not at the top of a base
    // entry) is accepted, and the nested map survives the round trip.
    let nested = ciborium::Value::Map(vec![(
        ciborium::Value::Text("_reserved_".to_string()),
        ciborium::Value::Text("nested-is-ok".to_string()),
    )]);
    let meta = GlobalMetadata {
        base: vec![BTreeMap::from([("foo".to_string(), nested)])],
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let msg = encode(&meta, &[(&desc, payload.as_slice())], &opts).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
    let Some(ciborium::Value::Map(pairs)) = decoded.base[0].get("foo") else {
        panic!("expected map for foo");
    };
    assert_eq!(pairs.len(), 1);
}
#[test]
fn test_reserved_rejected_at_message_level() {
    // A caller-populated message-level `reserved` map is rejected. The error
    // must name `_reserved_` and "message level".
    let meta = GlobalMetadata {
        reserved: BTreeMap::from([(
            "rogue".to_string(),
            ciborium::Value::Text("bad".to_string()),
        )]),
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let err = encode(
        &meta,
        &[(&desc, payload.as_slice())],
        &EncodeOptions::default(),
    )
    .expect_err("caller-supplied _reserved_ must be rejected")
    .to_string();
    assert!(
        err.contains("_reserved_") && err.contains("message level"),
        "error: {err}"
    );
}
#[test]
fn test_reserved_rejected_in_base_entry() {
    // A top-level `_reserved_` key inside base[0] is rejected, and the error
    // names the offending entry.
    let meta = GlobalMetadata {
        base: vec![BTreeMap::from([(
            "_reserved_".to_string(),
            ciborium::Value::Map(vec![]),
        )])],
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let err = encode(
        &meta,
        &[(&desc, payload.as_slice())],
        &EncodeOptions::default(),
    )
    .expect_err("_reserved_ in base[0] must be rejected")
    .to_string();
    assert!(
        err.contains("_reserved_") && err.contains("base[0]"),
        "error: {err}"
    );
}
#[test]
fn test_reserved_tensor_has_four_keys_after_encode() {
    // After encoding, `base[0]._reserved_.tensor` holds exactly the four
    // text keys ndim / shape / strides / dtype.
    let meta = GlobalMetadata::default();
    let desc = make_descriptor(vec![3, 4]);
    // 3 x 4 float32 elements at 4 bytes each.
    let payload = vec![0u8; 3 * 4 * 4];
    let opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let msg = encode(&meta, &[(&desc, payload.as_slice())], &opts).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
    let reserved = decoded.base[0]
        .get("_reserved_")
        .expect("_reserved_ missing");
    let ciborium::Value::Map(pairs) = reserved else {
        panic!("_reserved_ is not a map");
    };
    let tensor_key = ciborium::Value::Text("tensor".to_string());
    let tensor_entry = pairs.iter().find(|(k, _)| *k == tensor_key);
    assert!(tensor_entry.is_some(), "missing tensor key in _reserved_");
    let Some((_, ciborium::Value::Map(tensor_pairs))) = tensor_entry else {
        panic!("tensor is not a map");
    };
    let keys: Vec<String> = tensor_pairs
        .iter()
        .filter_map(|(k, _)| match k {
            ciborium::Value::Text(s) => Some(s.clone()),
            _ => None,
        })
        .collect();
    assert_eq!(keys.len(), 4, "tensor should have 4 keys, got: {keys:?}");
    for wanted in ["ndim", "shape", "strides", "dtype"] {
        assert!(keys.iter().any(|k| k.as_str() == wanted));
    }
}
#[test]
fn test_reserved_tensor_scalar_ndim_zero() {
    // A rank-0 (scalar) tensor is recorded with ndim = 0 and shape = [].
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 0,
        shape: Vec::new(),
        strides: Vec::new(),
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    // A single float32 element occupies 4 bytes.
    let payload = vec![0u8; 4];
    let meta = GlobalMetadata::default();
    let opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let msg = encode(&meta, &[(&desc, payload.as_slice())], &opts).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
    let reserved = decoded.base[0]
        .get("_reserved_")
        .expect("_reserved_ missing");
    let ciborium::Value::Map(pairs) = reserved else {
        panic!("_reserved_ is not a map");
    };
    let tensor_key = ciborium::Value::Text("tensor".to_string());
    let Some((_, ciborium::Value::Map(tensor_pairs))) =
        pairs.iter().find(|(k, _)| *k == tensor_key)
    else {
        panic!("tensor missing or not a map");
    };
    let ndim_key = ciborium::Value::Text("ndim".to_string());
    let ndim = tensor_pairs.iter().find(|(k, _)| *k == ndim_key);
    assert!(
        matches!(ndim, Some((_, ciborium::Value::Integer(i))) if i128::from(*i) == 0),
        "ndim should be 0 for scalar"
    );
    let shape_key = ciborium::Value::Text("shape".to_string());
    let shape = tensor_pairs.iter().find(|(k, _)| *k == shape_key);
    assert!(
        matches!(shape, Some((_, ciborium::Value::Array(a))) if a.is_empty()),
        "shape should be [] for scalar"
    );
}
#[test]
fn test_extra_with_keys_colliding_with_base_entry_keys() {
    // The same key in a base entry and in `extra` keeps two independent
    // values.
    let text = |s: &str| ciborium::Value::Text(s.to_string());
    let meta = GlobalMetadata {
        base: vec![BTreeMap::from([("mars".to_string(), text("base-mars"))])],
        extra: BTreeMap::from([("mars".to_string(), text("extra-mars"))]),
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let msg = encode(&meta, &[(&desc, payload.as_slice())], &opts).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
    assert_eq!(decoded.base[0].get("mars"), Some(&text("base-mars")));
    assert_eq!(decoded.extra.get("mars"), Some(&text("extra-mars")));
}
#[test]
fn test_empty_extra_omitted_from_cbor() {
    // An empty `extra` decodes back as empty.
    // NOTE(review): the name says "omitted from cbor", but only the decoded
    // result is checked here, not the raw CBOR bytes.
    let meta = GlobalMetadata {
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let msg = encode(&meta, &[(&desc, payload.as_slice())], &opts).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
    assert!(decoded.extra.is_empty());
}
#[test]
fn test_extra_with_nested_maps_round_trips() {
    // A nested map stored in `extra` must come back with its full structure:
    // both keys, the integer value and the inner map. Checking only that the
    // key exists would miss values being dropped or flattened.
    let nested = ciborium::Value::Map(vec![
        (
            ciborium::Value::Text("key1".to_string()),
            ciborium::Value::Integer(42.into()),
        ),
        (
            ciborium::Value::Text("key2".to_string()),
            ciborium::Value::Map(vec![(
                ciborium::Value::Text("deep".to_string()),
                ciborium::Value::Bool(true),
            )]),
        ),
    ]);
    let mut extra = BTreeMap::new();
    extra.insert("nested".to_string(), nested.clone());
    let meta = GlobalMetadata {
        extra,
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let data = vec![0u8; 8];
    let options = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
    // Compare the whole value against the retained original.
    assert_eq!(
        decoded.extra.get("nested"),
        Some(&nested),
        "nested map in extra must round-trip unchanged"
    );
}
#[test]
fn test_legacy_top_level_keys_routed_to_extra() {
    use ciborium::Value;
    // Unrecognised top-level keys ("common", "payload", "version") leave
    // `base` and `reserved` empty and are preserved under `extra`.
    let legacy = Value::Map(
        [
            ("common", Value::Map(vec![])),
            ("payload", Value::Array(vec![])),
            ("version", Value::Integer(3.into())),
        ]
        .into_iter()
        .map(|(k, v)| (Value::Text(k.to_string()), v))
        .collect(),
    );
    let mut bytes = Vec::new();
    ciborium::into_writer(&legacy, &mut bytes).unwrap();
    let decoded: GlobalMetadata = crate::metadata::cbor_to_global_metadata(&bytes).unwrap();
    assert!(decoded.base.is_empty());
    assert!(decoded.reserved.is_empty());
    for key in ["common", "payload"] {
        assert!(decoded.extra.contains_key(key));
    }
    assert_eq!(
        decoded.extra.get("version"),
        Some(&Value::Integer(3.into()))
    );
}
#[test]
fn test_old_reserved_key_name_routed_to_extra() {
    use ciborium::Value;
    // The legacy key "reserved" (no underscores) is user data: it must land
    // in `extra` and leave the library-managed `reserved` empty.
    let inner = Value::Map(vec![(
        Value::Text("rogue".to_string()),
        Value::Text("value".to_string()),
    )]);
    let cbor = Value::Map(vec![(Value::Text("reserved".to_string()), inner)]);
    let mut bytes = Vec::new();
    ciborium::into_writer(&cbor, &mut bytes).unwrap();
    let decoded: GlobalMetadata = crate::metadata::cbor_to_global_metadata(&bytes).unwrap();
    assert!(
        decoded.reserved.is_empty(),
        "legacy 'reserved' must NOT bleed into library-managed `_reserved_`"
    );
    assert!(
        decoded.extra.contains_key("reserved"),
        "legacy 'reserved' key must land in `_extra_`"
    );
}
#[test]
fn test_reserved_rejected_in_second_base_entry_only() {
    // Only base[1] carries `_reserved_`. The error must point at base[1],
    // which shows every entry is checked, not just the first.
    let clean = BTreeMap::from([(
        "clean".to_string(),
        ciborium::Value::Text("ok".to_string()),
    )]);
    let dirty = BTreeMap::from([("_reserved_".to_string(), ciborium::Value::Map(vec![]))]);
    let meta = GlobalMetadata {
        base: vec![clean, dirty],
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let objects = [(&desc, payload.as_slice()), (&desc, payload.as_slice())];
    let err = encode(&meta, &objects, &EncodeOptions::default())
        .expect_err("_reserved_ in base[1] must be rejected")
        .to_string();
    assert!(
        err.contains("base[1]"),
        "error should mention base[1]: {err}"
    );
}
#[test]
fn test_reserved_accepted_when_all_base_entries_clean() {
    // Two clean base entries encode successfully, and each keeps its own key.
    let text = |s: &str| ciborium::Value::Text(s.to_string());
    let meta = GlobalMetadata {
        base: vec![
            BTreeMap::from([("key0".to_string(), text("val0"))]),
            BTreeMap::from([("key1".to_string(), text("val1"))]),
        ],
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let objects = [(&desc, payload.as_slice()), (&desc, payload.as_slice())];
    let msg = encode(&meta, &objects, &opts).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
    assert_eq!(decoded.base.len(), 2);
    assert!(decoded.base[0].contains_key("key0"));
    assert!(decoded.base[1].contains_key("key1"));
}
#[test]
fn test_reserved_tensor_dtype_strings_for_all_dtypes() {
    // Each listed dtype is recorded in `base[0]._reserved_.tensor.dtype`
    // under the expected lowercase name.
    let cases = [
        (Dtype::Float16, "float16"),
        (Dtype::Bfloat16, "bfloat16"),
        (Dtype::Float32, "float32"),
        (Dtype::Float64, "float64"),
        (Dtype::Complex64, "complex64"),
        (Dtype::Complex128, "complex128"),
        (Dtype::Int8, "int8"),
        (Dtype::Int16, "int16"),
        (Dtype::Int32, "int32"),
        (Dtype::Int64, "int64"),
        (Dtype::Uint8, "uint8"),
        (Dtype::Uint16, "uint16"),
        (Dtype::Uint32, "uint32"),
        (Dtype::Uint64, "uint64"),
    ];
    let meta = GlobalMetadata::default();
    let opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let tensor_key = ciborium::Value::Text("tensor".to_string());
    let dtype_key = ciborium::Value::Text("dtype".to_string());
    for (dtype, expected_str) in cases {
        // Four zero-filled elements of this dtype's byte width.
        let elements: u64 = 4;
        let payload = vec![0u8; elements as usize * dtype.byte_width()];
        let desc = DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: 1,
            shape: vec![elements],
            strides: vec![1],
            dtype,
            byte_order: ByteOrder::native(),
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            masks: None,
        };
        let msg = encode(&meta, &[(&desc, payload.as_slice())], &opts).unwrap();
        let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
        let reserved = decoded.base[0]
            .get("_reserved_")
            .unwrap_or_else(|| panic!("_reserved_ missing for dtype {dtype}"));
        let ciborium::Value::Map(pairs) = reserved else {
            panic!("_reserved_ is not a map for dtype {dtype}");
        };
        let Some((_, ciborium::Value::Map(tensor_pairs))) =
            pairs.iter().find(|(k, _)| *k == tensor_key)
        else {
            panic!("tensor missing or not a map for dtype {dtype}");
        };
        let dtype_val = tensor_pairs.iter().find(|(k, _)| *k == dtype_key);
        assert!(
            matches!(
                dtype_val,
                Some((_, ciborium::Value::Text(s))) if s == expected_str
            ),
            "dtype for {dtype} should be '{expected_str}', got: {dtype_val:?}"
        );
    }
}
#[test]
fn test_global_metadata_serde_all_fields_populated() {
    use ciborium::Value;
    // global_metadata_to_cbor followed by cbor_to_global_metadata preserves
    // `base`, `reserved` and `extra`.
    let meta = GlobalMetadata {
        base: vec![BTreeMap::from([(
            "key".to_string(),
            Value::Text("base_val".to_string()),
        )])],
        reserved: BTreeMap::from([("encoder".to_string(), Value::Text("test".to_string()))]),
        extra: BTreeMap::from([("custom".to_string(), Value::Integer(42.into()))]),
    };
    let cbor_bytes = crate::metadata::global_metadata_to_cbor(&meta).unwrap();
    let decoded: GlobalMetadata =
        crate::metadata::cbor_to_global_metadata(&cbor_bytes).unwrap();
    assert_eq!(decoded.base.len(), 1);
    assert_eq!(
        decoded.base[0].get("key"),
        Some(&Value::Text("base_val".to_string()))
    );
    assert!(decoded.reserved.contains_key("encoder"));
    assert_eq!(
        decoded.extra.get("custom"),
        Some(&Value::Integer(42.into()))
    );
}
#[test]
fn test_provenance_fields_present_after_encode() {
    // Encoding stamps `reserved` with:
    // - an `encoder` map holding "name" and "version",
    // - a 36-character `uuid` containing 4 hyphens,
    // - a `time` string that contains 'T' and ends with 'Z'.
    let meta = GlobalMetadata::default();
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let msg = encode(&meta, &[(&desc, payload.as_slice())], &opts).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
    let reserved = &decoded.reserved;
    for key in ["encoder", "time", "uuid"] {
        assert!(reserved.contains_key(key));
    }

    let Some(ciborium::Value::Map(encoder)) = reserved.get("encoder") else {
        panic!("encoder should be a map");
    };
    let name_key = ciborium::Value::Text("name".to_string());
    let version_key = ciborium::Value::Text("version".to_string());
    assert!(
        encoder.iter().any(|(k, _)| *k == name_key),
        "encoder map should have 'name' key"
    );
    assert!(
        encoder.iter().any(|(k, _)| *k == version_key),
        "encoder map should have 'version' key"
    );

    let Some(ciborium::Value::Text(uuid_str)) = reserved.get("uuid") else {
        panic!("uuid should be a text");
    };
    assert_eq!(uuid_str.len(), 36, "UUID should be 36 chars: {uuid_str}");
    assert_eq!(
        uuid_str.matches('-').count(),
        4,
        "UUID should have 4 hyphens: {uuid_str}"
    );

    let Some(ciborium::Value::Text(time_str)) = reserved.get("time") else {
        panic!("time should be a text");
    };
    assert!(
        time_str.ends_with('Z'),
        "time should end with Z: {time_str}"
    );
    assert!(
        time_str.contains('T'),
        "time should contain T separator: {time_str}"
    );
}
#[test]
fn test_both_reserved_and_reserved_underscore_only_new_captured() {
    use ciborium::Value;
    // When both "_reserved_" and the legacy "reserved" key are present, only
    // "_reserved_" fills `reserved`.
    let text = |s: &str| Value::Text(s.to_string());
    let cbor = Value::Map(vec![
        (
            text("_reserved_"),
            Value::Map(vec![(text("encoder"), text("tensogram"))]),
        ),
        (
            text("reserved"),
            Value::Map(vec![(text("old"), text("ignored"))]),
        ),
        (text("version"), Value::Integer(3.into())),
    ]);
    let mut bytes = Vec::new();
    ciborium::into_writer(&cbor, &mut bytes).unwrap();
    let decoded: GlobalMetadata = crate::metadata::cbor_to_global_metadata(&bytes).unwrap();
    assert!(decoded.reserved.contains_key("encoder"));
    assert!(!decoded.reserved.contains_key("old"));
}
#[test]
fn test_encode_pre_encoded_roundtrip_simple_packing() {
    // Re-encoding a decoded object with `encode_pre_encoded` and decoding it
    // again yields the same payload as the first decode.
    // NOTE(review): despite the name, `make_descriptor` uses encoding "none",
    // so simple_packing is not actually exercised here.
    let desc = make_descriptor(vec![4]);
    let raw = vec![0u8; 4 * 4];
    let meta = GlobalMetadata::default();
    let opts = EncodeOptions::default();
    let first_msg = encode(&meta, &[(&desc, raw.as_slice())], &opts).unwrap();
    let (_, first_objects) = decode(&first_msg, &DecodeOptions::default()).unwrap();
    let (first_desc, first_payload) = &first_objects[0];
    let second_msg =
        encode_pre_encoded(&meta, &[(first_desc, first_payload.as_slice())], &opts).unwrap();
    let (_, second_objects) = decode(&second_msg, &DecodeOptions::default()).unwrap();
    let (_, second_payload) = &second_objects[0];
    assert_eq!(
        first_payload, second_payload,
        "decoded payloads should be equal after encode/re-encode roundtrip"
    );
}
#[test]
fn test_encode_pre_encoded_populates_inline_hash_slot() {
    use crate::framing::{decode_message, scan};
    use crate::hash::verify_frame_hash;
    use crate::wire::{FrameHeader, MessageFlags, Preamble};
    // With default options, `encode_pre_encoded` sets HASHES_PRESENT in the
    // preamble, and every object frame's inline hash verifies.
    let desc = make_descriptor(vec![2]);
    let payload = vec![0xABu8; 8];
    let msg = encode_pre_encoded(
        &GlobalMetadata::default(),
        &[(&desc, payload.as_slice())],
        &EncodeOptions::default(),
    )
    .unwrap();
    let preamble = Preamble::read_from(&msg).unwrap();
    assert!(preamble.flags.has(MessageFlags::HASHES_PRESENT));
    let spans = scan(&msg);
    assert_eq!(spans.len(), 1);
    let (start, len) = spans[0];
    let message = &msg[start..start + len];
    let decoded = decode_message(message).unwrap();
    for (_, _, _, frame_offset) in &decoded.objects {
        let frame = &message[*frame_offset..];
        let header = FrameHeader::read_from(frame).unwrap();
        let frame_bytes = &frame[..header.total_length as usize];
        verify_frame_hash(frame_bytes, header.frame_type, None)
            .expect("inline hash slot must verify against body");
    }
}
#[test]
fn test_validate_szip_block_offsets_happy_path() {
    // A zero-based, increasing offset list within the bound is accepted.
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![0u64, 100, 200].into_iter().map(|n| n.into()).collect()),
    );
    assert!(validate_szip_block_offsets(&params, 100).is_ok());
}
#[test]
fn test_validate_szip_block_offsets_missing_key() {
    // An absent key is an error that names the key and says it is missing.
    let params = BTreeMap::new();
    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("missing") && err.contains("szip_block_offsets"),
        "error: {err}"
    );
}
#[test]
fn test_validate_szip_block_offsets_not_array() {
    // A non-array value is an error that mentions "array" and the key.
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Integer(0.into()),
    );
    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("array") && err.contains("szip_block_offsets"),
        "error: {err}"
    );
}
#[test]
fn test_validate_szip_block_offsets_non_integer_element() {
    // Any non-integer element is an error that mentions "integer" and the key.
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![
            ciborium::Value::Integer(0.into()),
            ciborium::Value::Text("x".to_string()),
        ]),
    );
    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("integer") && err.contains("szip_block_offsets"),
        "error: {err}"
    );
}
#[test]
fn test_validate_szip_block_offsets_nonzero_first() {
    // The first offset must be 0, and the error reports the value it got.
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![5u64, 100, 200].into_iter().map(|n| n.into()).collect()),
    );
    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("must be 0") && err.contains("got 5"),
        "error: {err}"
    );
}
#[test]
fn test_validate_szip_block_offsets_non_monotonic() {
    // A decreasing step (100 -> 50) is rejected.
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![0u64, 100, 50].into_iter().map(|n| n.into()).collect()),
    );
    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("increasing") || err.contains("monotonic"),
        "error: {err}"
    );
}
#[test]
fn test_validate_szip_block_offsets_offset_beyond_bound() {
    // With a bound argument of 100, offset 801 is rejected, and the error
    // mentions both 800 and 801.
    // NOTE(review): 800 is presumably the 100-byte bound expressed in bits;
    // confirm against validate_szip_block_offsets.
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![0u64, 100, 801].into_iter().map(|n| n.into()).collect()),
    );
    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(err.contains("800") && err.contains("801"), "error: {err}");
}
#[test]
fn test_validate_no_szip_offsets_for_non_szip_rejects() {
    // `szip_block_offsets` on an object compressed with something other than
    // szip (here zstd) is rejected. The error names the key and the codec.
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![0u64, 1].into_iter().map(|n| n.into()).collect()),
    )]);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![2],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "zstd".to_string(),
        params,
        masks: None,
    };
    let err = validate_no_szip_offsets_for_non_szip(&desc)
        .expect_err("szip_block_offsets on a zstd object must be rejected")
        .to_string();
    assert!(
        err.contains("szip_block_offsets") && err.contains("zstd"),
        "error: {err}"
    );
}
#[test]
fn test_validate_no_szip_offsets_for_non_szip_allows_szip() {
    // The same offsets on an szip-compressed object pass this check.
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![0u64, 1].into_iter().map(|n| n.into()).collect()),
    )]);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![2],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "szip".to_string(),
        params,
        masks: None,
    };
    assert!(validate_no_szip_offsets_for_non_szip(&desc).is_ok());
}
#[test]
fn aggregate_hash_policy_default_is_auto() {
    // `Auto` is the `#[default]` variant.
    assert_eq!(AggregateHashPolicy::Auto, AggregateHashPolicy::default());
}
#[test]
fn aggregate_hash_policy_buffered_resolves_auto_to_header() {
    // Buffered mode turns `Auto` into `Header`; the other listed variants
    // pass through unchanged.
    let cases = [
        (AggregateHashPolicy::Auto, AggregateHashPolicy::Header),
        (AggregateHashPolicy::None, AggregateHashPolicy::None),
        (AggregateHashPolicy::Footer, AggregateHashPolicy::Footer),
        (AggregateHashPolicy::Both, AggregateHashPolicy::Both),
    ];
    for (input, expected) in cases {
        assert_eq!(input.resolved_buffered(), expected, "policy {input:?}");
    }
}
#[test]
fn aggregate_hash_policy_streaming_rejects_header() {
    // `Header` cannot be honoured when streaming; it yields an Encoding error.
    let outcome = AggregateHashPolicy::Header.resolved_streaming();
    assert!(matches!(outcome, Err(TensogramError::Encoding(_))));
}
#[test]
fn aggregate_hash_policy_streaming_rejects_both() {
    // `Both` also yields an Encoding error in streaming mode.
    let outcome = AggregateHashPolicy::Both.resolved_streaming();
    assert!(matches!(outcome, Err(TensogramError::Encoding(_))));
}
#[test]
fn aggregate_hash_policy_streaming_resolves_auto_to_footer() {
    // Streaming mode turns `Auto` into `Footer`.
    let resolved = AggregateHashPolicy::Auto.resolved_streaming().unwrap();
    assert_eq!(resolved, AggregateHashPolicy::Footer);
}
#[test]
fn aggregate_hash_policy_streaming_accepts_explicit_footer_and_none() {
    // Explicit `Footer` and `None` are accepted unchanged when streaming.
    for policy in [AggregateHashPolicy::Footer, AggregateHashPolicy::None] {
        assert_eq!(policy.resolved_streaming().unwrap(), policy, "policy {policy:?}");
    }
}
#[test]
fn aggregate_hash_policy_emits_flags() {
    // Columns: policy, expected emits_header(), expected emits_footer().
    let table = [
        (AggregateHashPolicy::Header, true, false),
        (AggregateHashPolicy::Footer, false, true),
        (AggregateHashPolicy::Both, true, true),
        (AggregateHashPolicy::None, false, false),
    ];
    for (policy, header, footer) in table {
        assert_eq!(policy.emits_header(), header, "emits_header for {policy:?}");
        assert_eq!(policy.emits_footer(), footer, "emits_footer for {policy:?}");
    }
}
#[test]
fn get_i64_param_or_default_returns_default_on_absent() {
    // A missing key falls back to the supplied default.
    let params = BTreeMap::new();
    assert_eq!(
        get_i64_param_or_default(&params, "zstd_level", 3).unwrap(),
        3
    );
}
#[test]
fn get_i64_param_or_default_returns_present_value() {
    // A present integer is returned instead of the default.
    let mut params = BTreeMap::new();
    params.insert(
        "zstd_level".to_string(),
        ciborium::Value::Integer(7i64.into()),
    );
    assert_eq!(
        get_i64_param_or_default(&params, "zstd_level", 3).unwrap(),
        7
    );
}
#[test]
fn get_i64_param_or_default_rejects_wrong_type() {
    // A present non-integer value is a Metadata error. It is not silently
    // replaced by the default, and the message says so.
    let mut params = BTreeMap::new();
    params.insert(
        "zstd_level".to_string(),
        ciborium::Value::Text("high".to_string()),
    );
    let err = get_i64_param_or_default(&params, "zstd_level", 3).unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("expected integer"), "msg: {msg}");
            assert!(msg.contains("zstd_level"), "msg: {msg}");
            assert!(msg.contains("default"), "msg: {msg}");
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
#[test]
fn get_text_param_or_default_returns_default_on_absent() {
    // A missing key falls back to the supplied default text.
    let params = BTreeMap::new();
    assert_eq!(
        get_text_param_or_default(&params, "blosc2_codec", "lz4").unwrap(),
        "lz4"
    );
}
#[test]
fn get_text_param_or_default_returns_present_value() {
    // A present text value is returned instead of the default.
    let mut params = BTreeMap::new();
    params.insert(
        "blosc2_codec".to_string(),
        ciborium::Value::Text("zstd".to_string()),
    );
    assert_eq!(
        get_text_param_or_default(&params, "blosc2_codec", "lz4").unwrap(),
        "zstd"
    );
}
#[test]
fn get_text_param_or_default_rejects_wrong_type() {
    // A present non-text value is a Metadata error naming the key.
    let mut params = BTreeMap::new();
    params.insert(
        "blosc2_codec".to_string(),
        ciborium::Value::Integer(5i64.into()),
    );
    let err = get_text_param_or_default(&params, "blosc2_codec", "lz4").unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("expected text"), "msg: {msg}");
            assert!(msg.contains("blosc2_codec"), "msg: {msg}");
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// Build a [`MaskDescriptor`] with the given `method` and `params`, using a
/// placeholder offset of 0 and length of 1. The validation tests below only
/// vary `method` and `params`.
fn make_mask_desc(method: &str, params: BTreeMap<String, ciborium::Value>) -> MaskDescriptor {
    MaskDescriptor {
        params,
        method: method.to_owned(),
        offset: 0,
        length: 1,
    }
}
#[test]
fn validate_mask_params_accepts_empty_for_paramless_methods() {
    // Methods without parameters accept an empty params map.
    for method in ["none", "rle", "roaring", "lz4"] {
        let masks = MasksMetadata {
            nan: Some(make_mask_desc(method, BTreeMap::new())),
            ..Default::default()
        };
        let outcome = validate_mask_params(&masks);
        assert!(outcome.is_ok(), "method {method} must accept empty params");
    }
}
#[test]
fn validate_mask_params_accepts_zstd_level() {
    // zstd accepts an integer "level".
    let params = BTreeMap::from([("level".to_string(), ciborium::Value::Integer(3i64.into()))]);
    let masks = MasksMetadata {
        nan: Some(make_mask_desc("zstd", params)),
        ..Default::default()
    };
    assert!(validate_mask_params(&masks).is_ok());
}
#[test]
fn validate_mask_params_rejects_unknown_method() {
    // An unrecognised method name is rejected, and the error lists the
    // accepted choices.
    let masks = MasksMetadata {
        nan: Some(make_mask_desc("snappy", BTreeMap::new())),
        ..Default::default()
    };
    match validate_mask_params(&masks).unwrap_err() {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("unknown method"), "msg: {msg}");
            assert!(msg.contains("snappy"), "msg: {msg}");
            assert!(msg.contains("expected one of"), "msg: {msg}");
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
#[test]
fn validate_mask_params_rejects_unknown_param_for_paramless_method() {
    // "level" on rle (a parameterless method) is rejected. The error names
    // the param, the method and the mask kind tag ("inf+" for pos_inf).
    let params = BTreeMap::from([("level".to_string(), ciborium::Value::Integer(5i64.into()))]);
    let masks = MasksMetadata {
        pos_inf: Some(make_mask_desc("rle", params)),
        ..Default::default()
    };
    match validate_mask_params(&masks).unwrap_err() {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("unknown param"), "msg: {msg}");
            assert!(msg.contains("level"), "msg: {msg}");
            assert!(msg.contains("rle"), "msg: {msg}");
            assert!(msg.contains("inf+"), "kind tag missing: {msg}");
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
#[test]
fn get_f64_param_accepts_integer_within_exact_range() {
    // 2^53 is accepted and converted to the equal f64.
    let mut params = BTreeMap::new();
    params.insert(
        "tol".to_string(),
        ciborium::Value::Integer((1i64 << 53).into()),
    );
    assert_eq!(get_f64_param(&params, "tol").unwrap(), (1u64 << 53) as f64);
}
#[test]
fn get_f64_param_rejects_integer_beyond_exact_range() {
    let mut params = BTreeMap::new();
    // (2^30 - 1) << 24 == 2^54 - 2^24, which lies above the 2^53 magnitude
    // accepted by the test above, so it must be rejected.
    let too_big = i64::from((1u32 << 30) - 1) << 24;
    params.insert("tol".to_string(), ciborium::Value::Integer(too_big.into()));
    let err = get_f64_param(&params, "tol").unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("exact-representable"), "msg: {msg}");
            assert!(msg.contains("tol"), "msg: {msg}");
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
#[test]
fn get_f64_param_accepts_negative_integer_within_range() {
    // -2^53 is accepted, mirroring the positive limit.
    let mut params = BTreeMap::new();
    params.insert(
        "tol".to_string(),
        ciborium::Value::Integer((-(1i64 << 53)).into()),
    );
    assert_eq!(
        get_f64_param(&params, "tol").unwrap(),
        -((1u64 << 53) as f64)
    );
}
#[test]
fn get_f64_param_rejects_large_negative_integer() {
    // -(2^54 - 2^24) is below -2^53 and must be rejected.
    let mut params = BTreeMap::new();
    let too_neg = -(i64::from((1u32 << 30) - 1) << 24);
    params.insert("tol".to_string(), ciborium::Value::Integer(too_neg.into()));
    let err = get_f64_param(&params, "tol").unwrap_err();
    assert!(matches!(err, TensogramError::Metadata(_)));
}
#[test]
fn validate_mask_params_rejects_typo_param() {
    // A misspelled zstd parameter ("levle") is rejected rather than ignored.
    // The error names the bad key and the method.
    let params = BTreeMap::from([("levle".to_string(), ciborium::Value::Integer(3i64.into()))]);
    let masks = MasksMetadata {
        neg_inf: Some(make_mask_desc("zstd", params)),
        ..Default::default()
    };
    match validate_mask_params(&masks).unwrap_err() {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("unknown param"), "msg: {msg}");
            assert!(msg.contains("levle"), "msg: {msg}");
            assert!(msg.contains("zstd"), "msg: {msg}");
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
#[test]
fn validate_object_bitmask_data_len_mismatch_rejected() {
let desc = DataObjectDescriptor {
obj_type: "ntensor".to_string(),
ndim: 1,
shape: vec![20],
strides: vec![1],
dtype: Dtype::Bitmask,
byte_order: ByteOrder::native(),
encoding: "none".to_string(),
filter: "none".to_string(),
compression: "none".to_string(),
params: BTreeMap::new(),
masks: None,
};
let err = validate_object(&desc, 2).unwrap_err();
match err {
TensogramError::Metadata(msg) => {
assert!(msg.contains("bitmask"), "msg: {msg}");
assert!(msg.contains('3'), "expected ceil(20/8)=3: {msg}");
}
other => panic!("expected Metadata error, got: {other:?}"),
}
}
/// A 20-bit bitmask packed into exactly ceil(20 / 8) = 3 bytes is valid.
#[test]
fn validate_object_bitmask_data_len_match_ok() {
    // Same descriptor as the shared `bitmask_desc` helper produces.
    let desc = bitmask_desc("none", 20);
    assert!(validate_object(&desc, 3).is_ok());
}
/// An object whose NaN mask uses an unrecognised method ("bogus") must fail
/// validation with an "unknown method" Metadata error.
#[test]
fn validate_object_rejects_invalid_mask_descriptor() {
    let bad_mask = make_mask_desc("bogus", BTreeMap::new());
    let desc = DataObjectDescriptor {
        masks: Some(MasksMetadata {
            nan: Some(bad_mask),
            ..Default::default()
        }),
        params: BTreeMap::new(),
        compression: "none".to_string(),
        filter: "none".to_string(),
        encoding: "none".to_string(),
        byte_order: ByteOrder::native(),
        dtype: Dtype::Float32,
        strides: vec![1],
        shape: vec![2],
        ndim: 1,
        obj_type: "ntensor".to_string(),
    };
    // Two float32 elements -> 8 bytes, so only the mask can be at fault.
    let msg = match validate_object(&desc, 8).unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("unknown method"), "msg: {msg}");
}
/// Builds a one-element, 1-D float64 `ntensor` descriptor in native byte order.
/// It uses the given `encoding` and `params`, no filter, no compression and no
/// masks.
fn float64_desc(
    encoding: &str,
    params: BTreeMap<String, ciborium::Value>,
) -> DataObjectDescriptor {
    let none = || "none".to_string();
    DataObjectDescriptor {
        params,
        encoding: encoding.to_string(),
        filter: none(),
        compression: none(),
        dtype: Dtype::Float64,
        byte_order: ByteOrder::native(),
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![1],
        strides: vec![1],
        masks: None,
    }
}
/// `simple_packing` applied to a float32 object must fail with an Encoding
/// error that names the encoding and the required float64 dtype.
#[test]
fn resolve_encoding_simple_packing_rejects_non_float64() {
    let desc = DataObjectDescriptor {
        dtype: Dtype::Float32,
        ..float64_desc("simple_packing", BTreeMap::new())
    };
    let msg = match resolve_encoding(&desc, Dtype::Float32).unwrap_err() {
        TensogramError::Encoding(msg) => msg,
        other => panic!("expected Encoding error, got: {other:?}"),
    };
    assert!(msg.contains("simple_packing"), "msg: {msg}");
    assert!(msg.contains("float64"), "msg: {msg}");
}
/// An unrecognised encoding name must produce an "unknown encoding" error that
/// echoes the offending name.
#[test]
fn resolve_encoding_rejects_unknown_encoding() {
    let bogus = "totally_unknown";
    let desc = float64_desc(bogus, BTreeMap::new());
    let msg = match resolve_encoding(&desc, Dtype::Float64).unwrap_err() {
        TensogramError::Encoding(msg) => msg,
        other => panic!("expected Encoding error, got: {other:?}"),
    };
    assert!(msg.contains("unknown encoding"), "msg: {msg}");
    assert!(msg.contains(bogus), "msg: {msg}");
}
/// An unrecognised filter name must produce an "unknown filter" Encoding error
/// that echoes the offending name.
#[test]
fn resolve_filter_rejects_unknown_filter() {
    let desc = DataObjectDescriptor {
        filter: "rot13".to_string(),
        ..float64_desc("none", BTreeMap::new())
    };
    let msg = match resolve_filter(&desc).unwrap_err() {
        TensogramError::Encoding(msg) => msg,
        other => panic!("expected Encoding error, got: {other:?}"),
    };
    assert!(msg.contains("unknown filter"), "msg: {msg}");
    assert!(msg.contains("rot13"), "msg: {msg}");
}
/// The shuffle filter without a `shuffle_element_size` parameter is a
/// metadata error.
#[test]
fn resolve_filter_shuffle_missing_element_size_rejected() {
    let desc = DataObjectDescriptor {
        filter: "shuffle".to_string(),
        ..float64_desc("none", BTreeMap::new())
    };
    let outcome = resolve_filter(&desc);
    assert!(matches!(outcome.unwrap_err(), TensogramError::Metadata(_)));
}
/// A shuffle filter with `shuffle_element_size = 4` resolves to
/// `FilterType::Shuffle { element_size: 4 }`.
#[test]
fn resolve_filter_shuffle_happy_path() {
    let params = BTreeMap::from([(
        "shuffle_element_size".to_string(),
        ciborium::Value::Integer(4i64.into()),
    )]);
    let desc = DataObjectDescriptor {
        filter: "shuffle".to_string(),
        ..float64_desc("none", params)
    };
    let resolved = resolve_filter(&desc).unwrap();
    assert!(matches!(resolved, FilterType::Shuffle { element_size: 4 }));
}
/// An unrecognised compression name must produce an "unknown compression"
/// Encoding error that echoes the offending name.
#[test]
fn resolve_compression_rejects_unknown_compression() {
    let desc = DataObjectDescriptor {
        compression: "magic".to_string(),
        ..float64_desc("none", BTreeMap::new())
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Encoding(msg) => msg,
        other => panic!("expected Encoding error, got: {other:?}"),
    };
    assert!(msg.contains("unknown compression"), "msg: {msg}");
    assert!(msg.contains("magic"), "msg: {msg}");
}
/// RLE compression on a float64 object is rejected. The error names both
/// "rle" and "bitmask".
#[test]
fn resolve_compression_rle_rejects_non_bitmask() {
    let desc = DataObjectDescriptor {
        compression: "rle".to_string(),
        ..float64_desc("none", BTreeMap::new())
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Encoding(msg) => msg,
        other => panic!("expected Encoding error, got: {other:?}"),
    };
    for needle in ["rle", "bitmask"] {
        assert!(msg.contains(needle), "msg: {msg}");
    }
}
/// Roaring compression on a float64 object is rejected. The error names both
/// "roaring" and "bitmask".
#[test]
fn resolve_compression_roaring_rejects_non_bitmask() {
    let desc = DataObjectDescriptor {
        compression: "roaring".to_string(),
        ..float64_desc("none", BTreeMap::new())
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Encoding(msg) => msg,
        other => panic!("expected Encoding error, got: {other:?}"),
    };
    for needle in ["roaring", "bitmask"] {
        assert!(msg.contains(needle), "msg: {msg}");
    }
}
/// `szip_rsi` set to one past `u32::MAX` must be rejected with a Metadata
/// error naming the parameter.
#[cfg(any(feature = "szip", feature = "szip-pure"))]
#[test]
fn resolve_compression_szip_rsi_out_of_u32_range() {
    let just_too_big = u64::from(u32::MAX) + 1;
    let params = BTreeMap::from([
        ("szip_rsi".to_string(), ciborium::Value::Integer(just_too_big.into())),
        ("szip_block_size".to_string(), ciborium::Value::Integer(32i64.into())),
        ("szip_flags".to_string(), ciborium::Value::Integer(0i64.into())),
    ]);
    let desc = DataObjectDescriptor {
        compression: "szip".to_string(),
        ..float64_desc("none", params)
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("szip_rsi"), "msg: {msg}");
}
/// `szip_block_size` set to one past `u32::MAX` must be rejected with a
/// Metadata error naming the parameter.
#[cfg(any(feature = "szip", feature = "szip-pure"))]
#[test]
fn resolve_compression_szip_block_size_out_of_u32_range() {
    let just_too_big = u64::from(u32::MAX) + 1;
    let params = BTreeMap::from([
        ("szip_rsi".to_string(), ciborium::Value::Integer(128i64.into())),
        ("szip_block_size".to_string(), ciborium::Value::Integer(just_too_big.into())),
        ("szip_flags".to_string(), ciborium::Value::Integer(0i64.into())),
    ]);
    let desc = DataObjectDescriptor {
        compression: "szip".to_string(),
        ..float64_desc("none", params)
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("szip_block_size"), "msg: {msg}");
}
/// `szip_flags` set to one past `u32::MAX` must be rejected with a Metadata
/// error naming the parameter.
#[cfg(any(feature = "szip", feature = "szip-pure"))]
#[test]
fn resolve_compression_szip_flags_out_of_u32_range() {
    let just_too_big = u64::from(u32::MAX) + 1;
    let params = BTreeMap::from([
        ("szip_rsi".to_string(), ciborium::Value::Integer(128i64.into())),
        ("szip_block_size".to_string(), ciborium::Value::Integer(32i64.into())),
        ("szip_flags".to_string(), ciborium::Value::Integer(just_too_big.into())),
    ]);
    let desc = DataObjectDescriptor {
        compression: "szip".to_string(),
        ..float64_desc("none", params)
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("szip_flags"), "msg: {msg}");
}
/// `zstd_level` set to one past `i32::MAX` must be rejected. The error names
/// the parameter and its i32 bound.
#[cfg(any(feature = "zstd", feature = "zstd-pure"))]
#[test]
fn resolve_compression_zstd_level_out_of_i32_range() {
    let level = i64::from(i32::MAX) + 1;
    let params = BTreeMap::from([(
        "zstd_level".to_string(),
        ciborium::Value::Integer(level.into()),
    )]);
    let desc = DataObjectDescriptor {
        compression: "zstd".to_string(),
        ..float64_desc("none", params)
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    for needle in ["zstd_level", "i32"] {
        assert!(msg.contains(needle), "msg: {msg}");
    }
}
/// `blosc2_clevel` set to one past `i32::MAX` must be rejected with a
/// Metadata error naming the parameter.
#[cfg(feature = "blosc2")]
#[test]
fn resolve_compression_blosc2_clevel_out_of_i32_range() {
    let clevel = i64::from(i32::MAX) + 1;
    let params = BTreeMap::from([
        ("blosc2_codec".to_string(), ciborium::Value::Text("lz4".to_string())),
        ("blosc2_clevel".to_string(), ciborium::Value::Integer(clevel.into())),
    ]);
    let desc = DataObjectDescriptor {
        compression: "blosc2".to_string(),
        ..float64_desc("none", params)
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("blosc2_clevel"), "msg: {msg}");
}
/// An unsupported blosc2 codec name ("snappy") must produce an
/// "unknown blosc2 codec" Encoding error that echoes the name.
#[cfg(feature = "blosc2")]
#[test]
fn resolve_compression_blosc2_unknown_codec_rejected() {
    let params = BTreeMap::from([(
        "blosc2_codec".to_string(),
        ciborium::Value::Text("snappy".to_string()),
    )]);
    let desc = DataObjectDescriptor {
        compression: "blosc2".to_string(),
        ..float64_desc("none", params)
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Encoding(msg) => msg,
        other => panic!("expected Encoding error, got: {other:?}"),
    };
    assert!(msg.contains("unknown blosc2 codec"), "msg: {msg}");
    assert!(msg.contains("snappy"), "msg: {msg}");
}
/// zfp compression without a `zfp_mode` parameter is a Metadata error naming
/// the missing key.
#[cfg(feature = "zfp")]
#[test]
fn resolve_compression_zfp_missing_mode_rejected() {
    let desc = DataObjectDescriptor {
        compression: "zfp".to_string(),
        ..float64_desc("none", BTreeMap::new())
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("zfp_mode"), "msg: {msg}");
}
/// In fixed-precision mode, a `zfp_precision` one past `u32::MAX` must be
/// rejected with a Metadata error naming the parameter.
#[cfg(feature = "zfp")]
#[test]
fn resolve_compression_zfp_precision_out_of_u32_range() {
    let precision = u64::from(u32::MAX) + 1;
    let params = BTreeMap::from([
        (
            "zfp_mode".to_string(),
            ciborium::Value::Text("fixed_precision".to_string()),
        ),
        (
            "zfp_precision".to_string(),
            ciborium::Value::Integer(precision.into()),
        ),
    ]);
    let desc = DataObjectDescriptor {
        compression: "zfp".to_string(),
        ..float64_desc("none", params)
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("zfp_precision"), "msg: {msg}");
}
/// An unrecognised `zfp_mode` value must produce an "unknown zfp_mode"
/// Encoding error.
#[cfg(feature = "zfp")]
#[test]
fn resolve_compression_zfp_unknown_mode_rejected() {
    let params = BTreeMap::from([(
        "zfp_mode".to_string(),
        ciborium::Value::Text("wobble".to_string()),
    )]);
    let desc = DataObjectDescriptor {
        compression: "zfp".to_string(),
        ..float64_desc("none", params)
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Encoding(msg) => msg,
        other => panic!("expected Encoding error, got: {other:?}"),
    };
    assert!(msg.contains("unknown zfp_mode"), "msg: {msg}");
}
/// sz3 compression without `sz3_error_bound_mode` is a Metadata error naming
/// the missing key.
#[cfg(feature = "sz3")]
#[test]
fn resolve_compression_sz3_missing_mode_rejected() {
    let desc = DataObjectDescriptor {
        compression: "sz3".to_string(),
        ..float64_desc("none", BTreeMap::new())
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("sz3_error_bound_mode"), "msg: {msg}");
}
/// An unrecognised `sz3_error_bound_mode` must produce an
/// "unknown sz3_error_bound_mode" Encoding error, even when a valid bound is
/// present.
#[cfg(feature = "sz3")]
#[test]
fn resolve_compression_sz3_unknown_mode_rejected() {
    let params = BTreeMap::from([
        (
            "sz3_error_bound_mode".to_string(),
            ciborium::Value::Text("bogus".to_string()),
        ),
        ("sz3_error_bound".to_string(), ciborium::Value::Float(1e-3)),
    ]);
    let desc = DataObjectDescriptor {
        compression: "sz3".to_string(),
        ..float64_desc("none", params)
    };
    let outcome =
        resolve_compression(&desc, Dtype::Float64, &EncodingType::None, &FilterType::None);
    let msg = match outcome.unwrap_err() {
        TensogramError::Encoding(msg) => msg,
        other => panic!("expected Encoding error, got: {other:?}"),
    };
    assert!(msg.contains("unknown sz3_error_bound_mode"), "msg: {msg}");
}
/// A NaN `sp_reference_value` must be rejected with a Metadata error that
/// mentions finiteness. All other simple-packing keys are valid.
#[test]
fn extract_simple_packing_params_rejects_nan_reference() {
    let mut params = BTreeMap::new();
    params.insert(
        "sp_reference_value".to_string(),
        ciborium::Value::Float(f64::NAN),
    );
    params.insert(
        "sp_binary_scale_factor".to_string(),
        ciborium::Value::Integer(0i64.into()),
    );
    params.insert(
        "sp_decimal_scale_factor".to_string(),
        ciborium::Value::Integer(0i64.into()),
    );
    params.insert(
        "sp_bits_per_value".to_string(),
        ciborium::Value::Integer(16i64.into()),
    );
    let err = extract_simple_packing_params(&params).unwrap_err();
    match err {
        TensogramError::Metadata(msg) => assert!(msg.contains("finite"), "msg: {msg}"),
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// `sp_decimal_scale_factor` one past `i32::MAX` must be rejected with a
/// Metadata error naming the key.
#[test]
fn extract_simple_packing_params_rejects_decimal_scale_out_of_i32() {
    let mut params = BTreeMap::new();
    params.insert(
        "sp_reference_value".to_string(),
        ciborium::Value::Float(0.0),
    );
    params.insert(
        "sp_binary_scale_factor".to_string(),
        ciborium::Value::Integer(0i64.into()),
    );
    params.insert(
        "sp_decimal_scale_factor".to_string(),
        ciborium::Value::Integer((i64::from(i32::MAX) + 1).into()),
    );
    params.insert(
        "sp_bits_per_value".to_string(),
        ciborium::Value::Integer(16i64.into()),
    );
    let err = extract_simple_packing_params(&params).unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("sp_decimal_scale_factor"), "msg: {msg}")
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// `sp_binary_scale_factor` one past `i32::MAX` must be rejected with a
/// Metadata error naming the key.
#[test]
fn extract_simple_packing_params_rejects_binary_scale_out_of_i32() {
    let mut params = BTreeMap::new();
    params.insert(
        "sp_reference_value".to_string(),
        ciborium::Value::Float(0.0),
    );
    params.insert(
        "sp_binary_scale_factor".to_string(),
        ciborium::Value::Integer((i64::from(i32::MAX) + 1).into()),
    );
    params.insert(
        "sp_decimal_scale_factor".to_string(),
        ciborium::Value::Integer(0i64.into()),
    );
    params.insert(
        "sp_bits_per_value".to_string(),
        ciborium::Value::Integer(16i64.into()),
    );
    let err = extract_simple_packing_params(&params).unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("sp_binary_scale_factor"), "msg: {msg}")
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// `sp_bits_per_value` one past `u32::MAX` must be rejected with a Metadata
/// error naming the key.
#[test]
fn extract_simple_packing_params_rejects_bits_per_value_out_of_u32() {
    let mut params = BTreeMap::new();
    params.insert(
        "sp_reference_value".to_string(),
        ciborium::Value::Float(0.0),
    );
    params.insert(
        "sp_binary_scale_factor".to_string(),
        ciborium::Value::Integer(0i64.into()),
    );
    params.insert(
        "sp_decimal_scale_factor".to_string(),
        ciborium::Value::Integer(0i64.into()),
    );
    params.insert(
        "sp_bits_per_value".to_string(),
        ciborium::Value::Integer((u64::from(u32::MAX) + 1).into()),
    );
    let err = extract_simple_packing_params(&params).unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("sp_bits_per_value"), "msg: {msg}")
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// For an encoding other than `simple_packing`, resolution succeeds and leaves
/// `params` untouched (still empty).
#[test]
fn resolve_simple_packing_params_noop_for_non_simple_packing() {
    let payload = [0u8; 8];
    let mut desc = float64_desc("none", BTreeMap::new());
    resolve_simple_packing_params(&mut desc, &payload).unwrap();
    assert!(desc.params.is_empty());
}
/// When auto-computing simple-packing parameters, an explicit
/// `sp_decimal_scale_factor` outside i32 must be rejected with a Metadata error
/// naming the key.
#[test]
fn resolve_simple_packing_params_rejects_decimal_scale_out_of_i32() {
    let params = BTreeMap::from([
        (
            "sp_bits_per_value".to_string(),
            ciborium::Value::Integer(16i64.into()),
        ),
        (
            "sp_decimal_scale_factor".to_string(),
            ciborium::Value::Integer((i64::from(i32::MAX) + 1).into()),
        ),
    ]);
    let mut desc = float64_desc("simple_packing", params);
    // Four little-endian f64 values: 0.0, 1.0, 2.0, 3.0.
    let data: Vec<u8> = (0..4u8).map(f64::from).flat_map(f64::to_le_bytes).collect();
    let msg = match resolve_simple_packing_params(&mut desc, &data).unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("sp_decimal_scale_factor"), "msg: {msg}");
}
/// Given only `sp_bits_per_value`, resolution computes the remaining
/// parameters from the data. Afterwards all four `sp_*` keys are present in
/// the descriptor.
#[test]
fn resolve_simple_packing_params_auto_computes_and_stamps_four_keys() {
    let params = BTreeMap::from([(
        "sp_bits_per_value".to_string(),
        ciborium::Value::Integer(16i64.into()),
    )]);
    let mut desc = float64_desc("simple_packing", params);
    // Eight little-endian f64 values 0.0..=7.0.
    let data: Vec<u8> = (0..8u8).map(f64::from).flat_map(f64::to_le_bytes).collect();
    resolve_simple_packing_params(&mut desc, &data).unwrap();
    for key in [
        "sp_reference_value",
        "sp_binary_scale_factor",
        "sp_decimal_scale_factor",
        "sp_bits_per_value",
    ] {
        assert!(desc.params.contains_key(key));
    }
}
/// When the reference value and binary scale are supplied explicitly,
/// resolution fills in `sp_decimal_scale_factor = 0`.
#[test]
fn resolve_simple_packing_params_explicit_pair_defaults_decimal() {
    let params = BTreeMap::from([
        ("sp_reference_value".to_string(), ciborium::Value::Float(0.0)),
        (
            "sp_binary_scale_factor".to_string(),
            ciborium::Value::Integer(0i64.into()),
        ),
        (
            "sp_bits_per_value".to_string(),
            ciborium::Value::Integer(16i64.into()),
        ),
    ]);
    let mut desc = float64_desc("simple_packing", params);
    resolve_simple_packing_params(&mut desc, &[0u8; 8]).unwrap();
    let expected = ciborium::Value::Integer(0i64.into());
    assert_eq!(desc.params.get("sp_decimal_scale_factor"), Some(&expected));
}
/// Supplying `sp_reference_value` without `sp_binary_scale_factor` is an
/// incomplete explicit pair and must be rejected with a Metadata error.
#[test]
fn resolve_simple_packing_params_partial_pair_rejected() {
    let params = BTreeMap::from([
        ("sp_reference_value".to_string(), ciborium::Value::Float(0.0)),
        (
            "sp_bits_per_value".to_string(),
            ciborium::Value::Integer(16i64.into()),
        ),
    ]);
    let mut desc = float64_desc("simple_packing", params);
    let msg = match resolve_simple_packing_params(&mut desc, &[0u8; 8]).unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("simple_packing"), "msg: {msg}");
}
/// A byte slice whose length is not a multiple of 8 cannot be split into
/// `f64`s and must be rejected.
#[test]
fn bytes_as_f64_vec_rejects_non_multiple_of_8() {
    let seven_bytes = [0u8; 7];
    let msg = match bytes_as_f64_vec(&seven_bytes, ByteOrder::Little).unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("multiple of 8"), "msg: {msg}");
}
/// Big-endian bytes of 1.5 decode back to `[1.5]` when `ByteOrder::Big` is
/// requested.
#[test]
fn bytes_as_f64_vec_big_endian_roundtrip() {
    let be = 1.5f64.to_be_bytes();
    let decoded = bytes_as_f64_vec(&be, ByteOrder::Big).unwrap();
    assert_eq!(decoded, vec![1.5]);
}
/// A text value where a number is required yields an "expected number"
/// Metadata error.
#[test]
fn get_f64_param_rejects_wrong_type() {
    let mut params = BTreeMap::new();
    params.insert("tol".to_string(), ciborium::Value::Text("x".to_string()));
    let err = get_f64_param(&params, "tol").unwrap_err();
    match err {
        TensogramError::Metadata(msg) => assert!(msg.contains("expected number"), "msg: {msg}"),
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// Looking up an absent key yields a "missing required parameter" Metadata
/// error.
#[test]
fn get_f64_param_missing_key_rejected() {
    let params = BTreeMap::new();
    let err = get_f64_param(&params, "tol").unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("missing required parameter"), "msg: {msg}")
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// A float value where an integer is required yields an "expected integer"
/// Metadata error.
#[test]
fn get_i64_param_rejects_wrong_type() {
    let mut params = BTreeMap::new();
    params.insert("n".to_string(), ciborium::Value::Float(1.5));
    let err = get_i64_param(&params, "n").unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("expected integer"), "msg: {msg}")
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// Looking up an absent integer key yields a "missing required parameter"
/// Metadata error.
#[test]
fn get_i64_param_missing_key_rejected() {
    let params = BTreeMap::new();
    let err = get_i64_param(&params, "n").unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("missing required parameter"), "msg: {msg}")
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// `u64::MAX` does not fit in i64. The defaulting accessor must still report
/// it as out of range rather than falling back to the default.
#[test]
fn get_i64_param_or_default_rejects_out_of_i64_range() {
    let mut params = BTreeMap::new();
    params.insert("n".to_string(), ciborium::Value::Integer((u64::MAX).into()));
    let err = get_i64_param_or_default(&params, "n", 0).unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("out of i64 range"), "msg: {msg}")
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// `u64::MAX` does not fit in i64 and must be reported as out of range.
#[test]
fn get_i64_param_rejects_out_of_i64_range() {
    let mut params = BTreeMap::new();
    params.insert("n".to_string(), ciborium::Value::Integer((u64::MAX).into()));
    let err = get_i64_param(&params, "n").unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("out of i64 range"), "msg: {msg}")
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// A negative integer is outside u64 and must be reported as out of range.
#[test]
fn get_u64_param_rejects_negative() {
    let mut params = BTreeMap::new();
    params.insert("n".to_string(), ciborium::Value::Integer((-1i64).into()));
    let err = get_u64_param(&params, "n").unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("out of u64 range"), "msg: {msg}")
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// A float value, even an integral one like 1.0, is not accepted where a u64
/// integer is required.
#[test]
fn get_u64_param_rejects_wrong_type() {
    let mut params = BTreeMap::new();
    params.insert("n".to_string(), ciborium::Value::Float(1.0));
    let err = get_u64_param(&params, "n").unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("expected integer"), "msg: {msg}")
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// Looking up an absent u64 key yields a "missing required parameter"
/// Metadata error.
#[test]
fn get_u64_param_missing_key_rejected() {
    let params = BTreeMap::new();
    let err = get_u64_param(&params, "n").unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("missing required parameter"), "msg: {msg}")
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// An empty `szip_block_offsets` array is invalid and must be rejected.
#[test]
fn validate_szip_block_offsets_empty_rejected() {
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![]),
    );
    let err = validate_szip_block_offsets(&params, 100).unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("must not be empty"), "msg: {msg}")
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// A negative element in `szip_block_offsets` is outside u64 and must be
/// rejected.
#[test]
fn validate_szip_block_offsets_negative_element_rejected() {
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![ciborium::Value::Integer((-1i64).into())]),
    );
    let err = validate_szip_block_offsets(&params, 100).unwrap_err();
    match err {
        TensogramError::Metadata(msg) => {
            assert!(msg.contains("out of u64 range"), "msg: {msg}")
        }
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// With no masks, the composed payload region is the input payload unchanged
/// and no mask metadata is produced.
#[test]
fn compose_payload_region_empty_masks_is_passthrough() {
    let original = vec![1u8, 2, 3, 4];
    let roaring = MaskMethod::Roaring;
    let (region, meta) = compose_payload_region(
        original.clone(),
        MaskSet::empty(0),
        &roaring,
        &roaring,
        &roaring,
        128,
    )
    .unwrap();
    assert_eq!(region, original);
    assert!(meta.is_none());
}
/// A NaN in float32 input must make `encode` fail when `allow_nan` is off.
#[test]
fn encode_nan_without_allow_nan_is_hard_error() {
    let desc = make_descriptor(vec![2]);
    let data: Vec<u8> = [f32::NAN, 1.0f32]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    let options = EncodeOptions {
        hashing: false,
        allow_nan: false,
        allow_inf: false,
        ..Default::default()
    };
    let result = encode(&meta_default(), &[(&desc, data.as_slice())], &options);
    assert!(result.is_err(), "NaN without allow_nan must error");
}
/// With default options (hashing off), a +Inf in float32 input must make
/// `encode` fail.
#[test]
fn encode_inf_without_allow_inf_is_hard_error() {
    let desc = make_descriptor(vec![2]);
    let data: Vec<u8> = [f32::INFINITY, 1.0f32]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    let options = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let result = encode(&meta_default(), &[(&desc, data.as_slice())], &options);
    assert!(result.is_err(), "Inf without allow_inf must error");
}
/// With `allow_nan` on, NaNs are masked instead of rejected.
///
/// The decoded descriptor must carry a NaN mask (and no ±Inf masks, since the
/// input has none). The decoded payload must have the original length, and
/// every finite element must come back byte-for-byte unchanged.
#[test]
fn encode_nan_with_allow_nan_produces_mask_and_roundtrips() {
    let desc = make_descriptor(vec![2048]);
    let mut data = Vec::with_capacity(2048 * 4);
    for i in 0..2048u32 {
        if i % 3 == 0 {
            data.extend_from_slice(&f32::NAN.to_le_bytes());
        } else {
            data.extend_from_slice(&(i as f32).to_le_bytes());
        }
    }
    let options = EncodeOptions {
        hashing: false,
        allow_nan: true,
        // Threshold 0 so the configured mask method is used even for small masks.
        small_mask_threshold_bytes: 0,
        ..Default::default()
    };
    let msg = encode(&meta_default(), &[(&desc, data.as_slice())], &options).unwrap();
    let (_, objects) = decode(&msg, &DecodeOptions::default()).unwrap();
    assert_eq!(objects.len(), 1);
    let (decoded_desc, decoded) = &objects[0];

    let masks = decoded_desc
        .masks
        .as_ref()
        .expect("decoded descriptor should carry a NaN mask");
    assert!(masks.nan.is_some(), "NaN mask must be present");
    assert!(masks.pos_inf.is_none(), "input has no +Inf, so no +Inf mask");
    assert!(masks.neg_inf.is_none(), "input has no -Inf, so no -Inf mask");

    // The name promises a round-trip, so check the payload as well. NaN slots
    // are skipped because the exact NaN bit pattern is not guaranteed.
    assert_eq!(decoded.len(), data.len());
    for (i, (got, want)) in decoded
        .chunks_exact(4)
        .zip(data.chunks_exact(4))
        .enumerate()
    {
        if i % 3 != 0 {
            assert_eq!(got, want, "finite element {i} must round-trip");
        }
    }
}
fn meta_default() -> GlobalMetadata {
GlobalMetadata::default()
}
/// Roaring is requested, but the 4-bit mask is small and the threshold is 128.
/// The encoder therefore reports `MaskMethod::None`, i.e. the bits are stored
/// raw, and still emits a non-empty blob.
#[test]
fn encode_one_mask_none_method_packs_raw() {
    let tiny_mask = [true, false, true, false].to_vec();
    let (blob, used) = encode_one_mask(&tiny_mask, MaskMethod::Roaring, 128).unwrap();
    assert!(matches!(used, MaskMethod::None));
    assert!(!blob.is_empty());
}
/// With the threshold at 0, requesting RLE on an all-true mask uses RLE and
/// yields a non-empty blob.
#[test]
fn encode_one_mask_rle_method() {
    let all_set = vec![true; 2048];
    let (blob, used) = encode_one_mask(&all_set, MaskMethod::Rle, 0).unwrap();
    assert!(matches!(used, MaskMethod::Rle));
    assert!(!blob.is_empty());
}
/// With the threshold at 0, requesting Roaring on a sparse mask (bits 5 and
/// 1000 set) uses Roaring and yields a non-empty blob.
#[test]
fn encode_one_mask_roaring_method() {
    let sparse: Vec<bool> = (0..2048).map(|i| i == 5 || i == 1000).collect();
    let (blob, used) = encode_one_mask(&sparse, MaskMethod::Roaring, 0).unwrap();
    assert!(matches!(used, MaskMethod::Roaring));
    assert!(!blob.is_empty());
}
/// With the `lz4` feature, requesting LZ4 on an all-true mask uses LZ4 and
/// yields a non-empty blob.
#[cfg(feature = "lz4")]
#[test]
fn encode_one_mask_lz4_method() {
    let all_set = vec![true; 2048];
    let (blob, used) = encode_one_mask(&all_set, MaskMethod::Lz4, 0).unwrap();
    assert!(matches!(used, MaskMethod::Lz4));
    assert!(!blob.is_empty());
}
/// Requesting zstd (level 3) on an all-true mask uses zstd and yields a
/// non-empty blob.
///
/// Gated like this file's zstd pipeline tests and the lz4 mask test.
/// NOTE(review): this assumes zstd mask compression needs one of these
/// features. Drop the gate if mask zstd is always compiled in.
#[cfg(any(feature = "zstd", feature = "zstd-pure"))]
#[test]
fn encode_one_mask_zstd_method() {
    let bits = vec![true; 2048];
    let (blob, used) = encode_one_mask(&bits, MaskMethod::Zstd { level: Some(3) }, 0).unwrap();
    assert!(matches!(used, MaskMethod::Zstd { .. }));
    assert!(!blob.is_empty());
}
/// With the `blosc2` feature, requesting Blosc2/LZ4 (level 5) on an all-true
/// mask uses Blosc2 and yields a non-empty blob.
#[cfg(feature = "blosc2")]
#[test]
fn encode_one_mask_blosc2_method() {
    use tensogram_encodings::pipeline::Blosc2Codec;
    let all_set = vec![true; 2048];
    let requested = MaskMethod::Blosc2 {
        codec: Blosc2Codec::Lz4,
        level: 5,
    };
    let (blob, used) = encode_one_mask(&all_set, requested, 0).unwrap();
    assert!(matches!(used, MaskMethod::Blosc2 { .. }));
    assert!(!blob.is_empty());
}
/// Methods without tunables (None, RLE, Roaring, LZ4) serialise to an empty
/// CBOR params map.
#[test]
fn mask_params_cbor_paramless_methods_empty() {
    let paramless = [
        MaskMethod::None,
        MaskMethod::Rle,
        MaskMethod::Roaring,
        MaskMethod::Lz4,
    ];
    for method in &paramless {
        let params = mask_params_cbor(method);
        assert!(params.is_empty(), "method {method:?} must be paramless");
    }
}
/// A zstd mask method with an explicit level serialises it under "level".
#[test]
fn mask_params_cbor_zstd_with_level() {
    let expected = ciborium::Value::Integer(7i64.into());
    let params = mask_params_cbor(&MaskMethod::Zstd { level: Some(7) });
    assert_eq!(params.get("level"), Some(&expected));
}
/// A zstd mask method without a level serialises to an empty params map.
#[test]
fn mask_params_cbor_zstd_without_level_empty() {
    assert!(mask_params_cbor(&MaskMethod::Zstd { level: None }).is_empty());
}
/// Every Blosc2 codec serialises as its lowercase name under "codec". The
/// level is serialised under "level".
#[cfg(feature = "blosc2")]
#[test]
fn mask_params_cbor_blosc2_all_codecs() {
    use tensogram_encodings::pipeline::Blosc2Codec;
    let expected_level = ciborium::Value::Integer(4i64.into());
    for (codec, name) in [
        (Blosc2Codec::Blosclz, "blosclz"),
        (Blosc2Codec::Lz4, "lz4"),
        (Blosc2Codec::Lz4hc, "lz4hc"),
        (Blosc2Codec::Zlib, "zlib"),
        (Blosc2Codec::Zstd, "zstd"),
    ] {
        let expected_codec = ciborium::Value::Text(name.to_string());
        let params = mask_params_cbor(&MaskMethod::Blosc2 { codec, level: 4 });
        assert_eq!(params.get("codec"), Some(&expected_codec));
        assert_eq!(params.get("level"), Some(&expected_level));
    }
}
/// Composes a payload with NaN, +Inf and -Inf masks, each using a different
/// method.
///
/// Checks that each descriptor records its method, that the NaN mask starts
/// right after the payload, and that the mask offsets increase in nan → pos_inf
/// → neg_inf order. Gated on zstd because the -Inf mask uses zstd.
/// NOTE(review): drop the gate if mask zstd is always available.
#[cfg(any(feature = "zstd", feature = "zstd-pure"))]
#[test]
fn compose_payload_region_all_three_masks_with_distinct_methods() {
    let mk = |seed: usize| -> Vec<bool> {
        (0..2048).map(|i| (i + seed).is_multiple_of(7)).collect()
    };
    let masks = MaskSet {
        nan: Some(mk(0)),
        pos_inf: Some(mk(1)),
        neg_inf: Some(mk(2)),
        n_elements: 2048,
    };
    let payload = vec![0xAAu8; 64];
    let (region, meta) = compose_payload_region(
        payload.clone(),
        masks,
        &MaskMethod::Roaring,
        &MaskMethod::Rle,
        &MaskMethod::Zstd { level: Some(1) },
        0,
    )
    .unwrap();
    let meta = meta.expect("masks present");
    assert!(region.len() > payload.len());
    let nan = meta.nan.expect("nan");
    let pos = meta.pos_inf.expect("pos_inf");
    let neg = meta.neg_inf.expect("neg_inf");
    assert_eq!(nan.method, "roaring");
    assert_eq!(pos.method, "rle");
    assert_eq!(neg.method, "zstd");
    assert_eq!(nan.offset, payload.len() as u64);
    assert!(pos.offset > nan.offset);
    assert!(neg.offset > pos.offset);
}
/// An empty `obj_type` is rejected with a Metadata error naming the field.
#[test]
fn validate_object_rejects_empty_obj_type() {
    let desc = DataObjectDescriptor {
        obj_type: String::new(),
        ..make_descriptor(vec![2])
    };
    let msg = match validate_object(&desc, 8).unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("obj_type"), "msg: {msg}");
}
/// An `ndim` that disagrees with `shape.len()` is rejected with a Metadata
/// error naming `ndim`.
#[test]
fn validate_object_rejects_ndim_shape_mismatch() {
    let mut desc = make_descriptor(vec![2, 3]);
    // Shape has rank 2, so claim rank 3.
    desc.ndim = 3;
    let err = validate_object(&desc, 24).unwrap_err();
    match err {
        TensogramError::Metadata(msg) => assert!(msg.contains("ndim"), "msg: {msg}"),
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// A `strides` vector whose length differs from the shape's rank is rejected
/// with a Metadata error mentioning `strides.len()`.
#[test]
fn validate_object_rejects_strides_shape_mismatch() {
    let mut desc = make_descriptor(vec![2, 3]);
    // Rank-2 shape but only one stride.
    desc.strides = vec![1];
    let err = validate_object(&desc, 24).unwrap_err();
    match err {
        TensogramError::Metadata(msg) => assert!(msg.contains("strides.len()"), "msg: {msg}"),
        other => panic!("expected Metadata error, got: {other:?}"),
    }
}
/// A 4-element descriptor validated against an 8-byte payload must report that
/// the length does not match what is expected.
#[test]
fn validate_object_rejects_data_len_mismatch() {
    let four_elems = make_descriptor(vec![4]);
    let msg = match validate_object(&four_elems, 8).unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("does not match expected"), "msg: {msg}");
}
/// Auto-resolving simple-packing parameters for a float32 object must fail
/// with the "simple_packing only supports float64" Encoding error.
#[test]
fn resolve_simple_packing_params_rejects_non_float64_dtype() {
    let params = BTreeMap::from([(
        "sp_bits_per_value".to_string(),
        ciborium::Value::Integer(16i64.into()),
    )]);
    let mut desc = DataObjectDescriptor {
        dtype: Dtype::Float32,
        ..float64_desc("simple_packing", params)
    };
    let msg = match resolve_simple_packing_params(&mut desc, &[0u8; 4]).unwrap_err() {
        TensogramError::Encoding(msg) => msg,
        other => panic!("expected Encoding error, got: {other:?}"),
    };
    assert!(
        msg.contains("simple_packing only supports float64"),
        "msg: {msg}"
    );
}
/// A `simple_packing` descriptor without `sp_bits_per_value` must be rejected
/// with a Metadata error naming that key.
#[test]
fn resolve_simple_packing_params_rejects_missing_bits_per_value() {
    let mut desc = float64_desc("simple_packing", BTreeMap::new());
    let outcome = resolve_simple_packing_params(&mut desc, &[0u8; 8]);
    let msg = match outcome.unwrap_err() {
        TensogramError::Metadata(msg) => msg,
        other => panic!("expected Metadata error, got: {other:?}"),
    };
    assert!(msg.contains("sp_bits_per_value"), "msg: {msg}");
}
/// Auto-computing simple-packing parameters over data that contains NaN is an
/// Encoding error.
#[test]
fn resolve_simple_packing_params_rejects_non_finite_data() {
    let params = BTreeMap::from([(
        "sp_bits_per_value".to_string(),
        ciborium::Value::Integer(16i64.into()),
    )]);
    let mut desc = float64_desc("simple_packing", params);
    let data: Vec<u8> = [f64::NAN, 1.0f64]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    let outcome = resolve_simple_packing_params(&mut desc, &data);
    assert!(matches!(outcome.unwrap_err(), TensogramError::Encoding(_)));
}
/// Builds a 1-D bitmask `ntensor` descriptor of `n_bits` elements in native
/// byte order. It has no encoding, no filter, no params and no masks, and uses
/// the given `compression`.
fn bitmask_desc(compression: &str, n_bits: u64) -> DataObjectDescriptor {
    let none = || "none".to_string();
    DataObjectDescriptor {
        compression: compression.to_string(),
        shape: vec![n_bits],
        strides: vec![1],
        ndim: 1,
        dtype: Dtype::Bitmask,
        byte_order: ByteOrder::native(),
        obj_type: "ntensor".to_string(),
        encoding: none(),
        filter: none(),
        params: BTreeMap::new(),
        masks: None,
    }
}
/// A 64-bit all-ones bitmask compressed with RLE decodes back to identical
/// bytes.
#[test]
fn encode_bitmask_rle_round_trips() {
    let desc = bitmask_desc("rle", 64);
    let all_ones = vec![0xFFu8; 8];
    let options = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let encoded = encode(&meta_default(), &[(&desc, all_ones.as_slice())], &options).unwrap();
    let (_, decoded_objects) = decode(&encoded, &DecodeOptions::default()).unwrap();
    assert_eq!(decoded_objects[0].1, all_ones);
}
/// A sparse 64-bit bitmask (first byte 0b1010_1010, rest zero) compressed with
/// Roaring decodes back to identical bytes.
#[test]
fn encode_bitmask_roaring_round_trips() {
    let desc = bitmask_desc("roaring", 64);
    let mut bits = [0u8; 8].to_vec();
    bits[0] = 0b1010_1010;
    let options = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let encoded = encode(&meta_default(), &[(&desc, bits.as_slice())], &options).unwrap();
    let (_, decoded_objects) = decode(&encoded, &DecodeOptions::default()).unwrap();
    assert_eq!(decoded_objects[0].1, bits);
}
/// 256 float32 values, byte-shuffled (element size 4) and then szip-compressed,
/// decode back to the original bytes.
#[cfg(any(feature = "szip", feature = "szip-pure"))]
#[test]
fn encode_szip_with_shuffle_filter_round_trips() {
    let int = |v: i64| ciborium::Value::Integer(v.into());
    let params = BTreeMap::from([
        ("shuffle_element_size".to_string(), int(4)),
        ("szip_rsi".to_string(), int(128)),
        ("szip_block_size".to_string(), int(32)),
        ("szip_flags".to_string(), int(0)),
    ]);
    let desc = DataObjectDescriptor {
        params,
        masks: None,
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![256],
        strides: vec![1],
        encoding: "none".to_string(),
        filter: "shuffle".to_string(),
        compression: "szip".to_string(),
    };
    // 0.0, 1.0, ..., 255.0 as little-endian float32.
    let data: Vec<u8> = (0..256u16).map(f32::from).flat_map(f32::to_le_bytes).collect();
    let options = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let encoded = encode(&meta_default(), &[(&desc, data.as_slice())], &options).unwrap();
    let (_, decoded_objects) = decode(&encoded, &DecodeOptions::default()).unwrap();
    assert_eq!(decoded_objects[0].1, data);
}
#[cfg(feature = "blosc2")]
#[test]
fn encode_blosc2_with_shuffle_filter_round_trips() {
    // Round-trips 256 f32 values through a byte-shuffle filter (element size
    // 4 = one f32) followed by blosc2 with the "lz4" codec.
    let params = BTreeMap::from([
        (
            "shuffle_element_size".to_string(),
            ciborium::Value::Integer(4i64.into()),
        ),
        (
            "blosc2_codec".to_string(),
            ciborium::Value::Text("lz4".to_string()),
        ),
    ]);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![256],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "shuffle".to_string(),
        compression: "blosc2".to_string(),
        params,
        masks: None,
    };
    // The descriptor declares `ByteOrder::native()`, so the payload must be
    // serialised in native order as well. `to_le_bytes` would mislabel the
    // data on big-endian hosts.
    let data: Vec<u8> = (0..256u32)
        .flat_map(|i| (i as f32).to_ne_bytes())
        .collect();
    let options = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let msg = encode(&meta_default(), &[(&desc, data.as_slice())], &options).unwrap();
    let (_, objects) = decode(&msg, &DecodeOptions::default()).unwrap();
    assert_eq!(objects[0].1, data);
}
#[cfg(any(feature = "szip", feature = "szip-pure"))]
#[test]
fn resolve_compression_szip_none_none_uses_dtype_bit_width() {
    // With no encoding and no filter ahead of szip, the resolved sample width
    // is expected to be the dtype's own width: 64 bits for Float64.
    let int = |v: i64| ciborium::Value::Integer(v.into());
    let params = BTreeMap::from([
        ("szip_rsi".to_string(), int(128)),
        ("szip_block_size".to_string(), int(32)),
        ("szip_flags".to_string(), int(0)),
    ]);
    let mut desc = float64_desc("none", params);
    desc.compression = "szip".to_string();
    let resolved = resolve_compression(
        &desc,
        Dtype::Float64,
        &EncodingType::None,
        &FilterType::None,
    )
    .unwrap();
    if let CompressionType::Szip {
        bits_per_sample, ..
    } = &resolved
    {
        assert_eq!(*bits_per_sample, 64);
    } else {
        panic!("expected Szip, got: {resolved:?}");
    }
}
#[cfg(any(feature = "szip", feature = "szip-pure"))]
#[test]
fn resolve_compression_szip_simple_packing_uses_pack_bits() {
    use tensogram_encodings::simple_packing::SimplePackingParams;
    // When simple packing precedes szip, the szip sample width should follow
    // the packed bit width (12), not the Float64 dtype width.
    let int = |v: i64| ciborium::Value::Integer(v.into());
    let params = BTreeMap::from([
        ("szip_rsi".to_string(), int(128)),
        ("szip_block_size".to_string(), int(32)),
        ("szip_flags".to_string(), int(0)),
    ]);
    let mut desc = float64_desc("none", params);
    desc.compression = "szip".to_string();
    let packing = SimplePackingParams {
        reference_value: 0.0,
        binary_scale_factor: 0,
        decimal_scale_factor: 0,
        bits_per_value: 12,
    };
    let encoding = EncodingType::SimplePacking(packing);
    let resolved =
        resolve_compression(&desc, Dtype::Float64, &encoding, &FilterType::None).unwrap();
    if let CompressionType::Szip {
        bits_per_sample, ..
    } = &resolved
    {
        assert_eq!(*bits_per_sample, 12);
    } else {
        panic!("expected Szip, got: {resolved:?}");
    }
}
#[cfg(feature = "blosc2")]
#[test]
fn resolve_compression_blosc2_simple_packing_typesize_rounds_up() {
    use tensogram_encodings::simple_packing::SimplePackingParams;
    // 12 packed bits do not fit in one byte, so the blosc2 typesize is
    // expected to round up to 2 bytes.
    let params = BTreeMap::from([(
        "blosc2_codec".to_string(),
        ciborium::Value::Text("lz4".to_string()),
    )]);
    let mut desc = float64_desc("none", params);
    desc.compression = "blosc2".to_string();
    let packing = SimplePackingParams {
        reference_value: 0.0,
        binary_scale_factor: 0,
        decimal_scale_factor: 0,
        bits_per_value: 12,
    };
    let encoding = EncodingType::SimplePacking(packing);
    let resolved =
        resolve_compression(&desc, Dtype::Float64, &encoding, &FilterType::None).unwrap();
    if let CompressionType::Blosc2 { typesize, .. } = &resolved {
        assert_eq!(*typesize, 2);
    } else {
        panic!("expected Blosc2, got: {resolved:?}");
    }
}
#[cfg(feature = "zfp")]
#[test]
fn resolve_compression_zfp_fixed_rate_and_accuracy() {
    // Each ZFP mode is paired with the parameter key it reads. Both pairings
    // must resolve to a `CompressionType::Zfp`.
    let cases = [
        ("fixed_rate", "zfp_rate"),
        ("fixed_accuracy", "zfp_tolerance"),
    ];
    for (mode, key) in cases {
        let params = BTreeMap::from([
            (
                "zfp_mode".to_string(),
                ciborium::Value::Text(mode.to_string()),
            ),
            (key.to_string(), ciborium::Value::Float(8.0)),
        ]);
        let mut desc = float64_desc("none", params);
        desc.compression = "zfp".to_string();
        let resolved = resolve_compression(
            &desc,
            Dtype::Float64,
            &EncodingType::None,
            &FilterType::None,
        )
        .unwrap();
        assert!(matches!(resolved, CompressionType::Zfp { .. }), "mode {mode}");
    }
}
#[cfg(any(feature = "szip", feature = "szip-pure"))]
#[test]
fn encode_pre_encoded_szip_validates_block_offsets() {
    // A pre-encoded szip payload with a single block offset of 0 is expected
    // to pass offset validation and produce a non-empty message.
    let int = |v: i64| ciborium::Value::Integer(v.into());
    let params = BTreeMap::from([
        ("szip_rsi".to_string(), int(128)),
        ("szip_block_size".to_string(), int(32)),
        ("szip_flags".to_string(), int(0)),
        (
            "szip_block_offsets".to_string(),
            ciborium::Value::Array(vec![int(0)]),
        ),
    ]);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![4],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "szip".to_string(),
        params,
        masks: None,
    };
    // Four f32 elements' worth of zero bytes (16 bytes).
    let payload = vec![0u8; 4 * std::mem::size_of::<f32>()];
    let opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let encoded =
        encode_pre_encoded(&meta_default(), &[(&desc, &payload[..])], &opts).unwrap();
    assert!(!encoded.is_empty());
}
#[cfg(any(feature = "szip", feature = "szip-pure"))]
#[test]
fn encode_pre_encoded_szip_rejects_offset_beyond_bound() {
    // The second block offset (9999) is far past the end of a 16-byte
    // payload. The pre-encoded path must reject it with a Metadata error.
    let int = |v: i64| ciborium::Value::Integer(v.into());
    let params = BTreeMap::from([
        ("szip_rsi".to_string(), int(128)),
        ("szip_block_size".to_string(), int(32)),
        ("szip_flags".to_string(), int(0)),
        (
            "szip_block_offsets".to_string(),
            ciborium::Value::Array(vec![int(0), int(9999)]),
        ),
    ]);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![4],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "szip".to_string(),
        params,
        masks: None,
    };
    // Four f32 elements' worth of zero bytes (16 bytes).
    let payload = vec![0u8; 4 * std::mem::size_of::<f32>()];
    let opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };
    let err =
        encode_pre_encoded(&meta_default(), &[(&desc, &payload[..])], &opts).unwrap_err();
    assert!(matches!(err, TensogramError::Metadata(_)));
}
}