use crate::encode::build_pipeline_config_with_backend;
use crate::error::{Result, TensogramError};
use crate::framing;
use crate::types::{DataObjectDescriptor, DecodedObject, GlobalMetadata};
use tensogram_encodings::pipeline;
/// Reads the `szip_block_offsets` entry of a descriptor parameter map as a list of `u64`.
///
/// # Errors
///
/// * [`TensogramError::Compression`] when the key is absent.
/// * [`TensogramError::Metadata`] when the value is not an array, when an element is not an
///   integer, or when an integer does not fit in `u64` (for example because it is negative).
fn extract_block_offsets(
    params: &std::collections::BTreeMap<String, ciborium::Value>,
) -> Result<Vec<u64>> {
    // Guard 1: the key has to exist at all.
    let Some(raw) = params.get("szip_block_offsets") else {
        return Err(TensogramError::Compression(
            "missing szip_block_offsets in payload metadata (required for partial range decode)"
                .to_string(),
        ));
    };

    // Guard 2: the value has to be a CBOR array.
    let ciborium::Value::Array(entries) = raw else {
        return Err(TensogramError::Metadata(
            "szip_block_offsets must be an array".to_string(),
        ));
    };

    // Convert element by element, stopping at the first offending entry.
    let mut offsets = Vec::with_capacity(entries.len());
    for entry in entries {
        let ciborium::Value::Integer(int) = entry else {
            return Err(TensogramError::Metadata(
                "szip_block_offsets must contain integers".to_string(),
            ));
        };
        // Widen to i128 first so that both negative and very large values are representable
        // before the range check against u64.
        let wide = i128::from(*int);
        let offset = u64::try_from(wide).map_err(|_| {
            TensogramError::Metadata("szip_block_offset out of u64 range".to_string())
        })?;
        offsets.push(offset);
    }
    Ok(offsets)
}
/// Options shared by the decode entry points of this module.
#[derive(Debug, Clone)]
pub struct DecodeOptions {
/// Forwarded to the decode pipeline. When `true`, non-finite restoration uses the host's
/// native byte order; when `false`, it uses the descriptor's own `byte_order`.
pub native_byte_order: bool,
/// Compression backend handed to the pipeline configuration builder.
pub compression_backend: pipeline::CompressionBackend,
/// Requested thread count; resolved into a thread budget by `crate::parallel::resolve_budget`.
/// NOTE(review): the default of `0` presumably means "no explicit request" — confirm against
/// `resolve_budget`.
pub threads: u32,
/// Optional size threshold passed to `crate::parallel::should_parallelise`.
/// NOTE(review): `None` presumably selects that function's built-in default — confirm.
pub parallel_threshold_bytes: Option<usize>,
/// When `true`, decode functions apply the frame's mask region through `crate::restore`
/// to put non-finite values back into the decoded bytes.
pub restore_non_finite: bool,
/// When `true`, `decode`, `decode_object` and `decode_object_from_frame` check frame hashes
/// before decoding and return an error if the check fails or reports no hash.
pub verify_hash: bool,
}
impl Default for DecodeOptions {
fn default() -> Self {
Self {
native_byte_order: true,
compression_backend: pipeline::CompressionBackend::default(),
threads: 0,
parallel_threshold_bytes: None,
restore_non_finite: true,
verify_hash: false,
}
}
}
/// Verifies the stored hash of data-object frames in the first message of `buf`.
///
/// The region between the preamble and the postamble is walked frame by frame. Bytes that do
/// not start with `FRAME_MAGIC` are skipped one at a time. With `target_index == None` every
/// data-object frame is checked; with `Some(i)` only the `i`-th data-object frame (0-based,
/// counting data-object frames only) is checked and the walk stops right after it. If the walk
/// ends before index `i` is reached, `Ok(())` is returned; range checking of the index is left
/// to the caller.
///
/// # Errors
///
/// * [`TensogramError::Framing`] when the buffer is too short, a length field does not fit in
///   `usize` or in the buffer, a frame declares a length smaller than its own header, or a
///   frame runs past the end of the message.
/// * [`TensogramError::MissingHash`] when `crate::hash::check_frame_hash` returns `Ok(false)`
///   for a checked frame.
/// * Any error returned by `crate::hash::check_frame_hash`, tagged with the object index.
fn verify_data_object_frames(buf: &[u8], target_index: Option<usize>) -> Result<()> {
    use crate::wire::{
        FRAME_HEADER_SIZE, FRAME_MAGIC, FrameHeader, POSTAMBLE_SIZE, PREAMBLE_SIZE, Preamble,
    };
    if buf.len() < PREAMBLE_SIZE {
        return Err(TensogramError::Framing(format!(
            "buffer too short for preamble while verifying hashes: {} < {PREAMBLE_SIZE}",
            buf.len()
        )));
    }
    let preamble = Preamble::read_from(buf)?;

    // End of the frame region. A non-zero preamble length is authoritative; a zero length
    // falls back to "everything in the buffer except the postamble".
    let msg_end = if preamble.total_length > 0 {
        let total_len = usize::try_from(preamble.total_length).map_err(|_| {
            TensogramError::Framing(format!(
                "preamble total_length ({}) overflows usize while verifying hashes",
                preamble.total_length
            ))
        })?;
        if total_len > buf.len() {
            return Err(TensogramError::Framing(format!(
                "preamble total_length ({total_len}) exceeds buffer size ({}) \
                 while verifying hashes",
                buf.len()
            )));
        }
        total_len.saturating_sub(POSTAMBLE_SIZE)
    } else {
        buf.len().checked_sub(POSTAMBLE_SIZE).ok_or_else(|| {
            TensogramError::Framing(format!(
                "buffer too short for postamble while verifying hashes: {} < {POSTAMBLE_SIZE}",
                buf.len()
            ))
        })?
    };

    let mut pos = PREAMBLE_SIZE;
    let mut object_index: usize = 0;
    while pos + FRAME_HEADER_SIZE <= msg_end {
        // Not a frame start: advance byte by byte until the magic is found.
        if &buf[pos..pos + FRAME_MAGIC.len()] != FRAME_MAGIC {
            pos += 1;
            continue;
        }
        let fh = FrameHeader::read_from(&buf[pos..])?;
        let frame_total = usize::try_from(fh.total_length).map_err(|_| {
            TensogramError::Framing(format!(
                "frame total_length ({}) overflows usize on this target \
                 while verifying hash at offset {pos}",
                fh.total_length
            ))
        })?;
        // A frame can never be shorter than its own header. Without this check a declared
        // length of 0 at an aligned offset would leave `pos` unchanged below and the loop
        // would never terminate on crafted input.
        if frame_total < FRAME_HEADER_SIZE {
            return Err(TensogramError::Framing(format!(
                "frame total_length ({frame_total}) at offset {pos} is smaller than the frame \
                 header ({FRAME_HEADER_SIZE} bytes) while verifying hashes"
            )));
        }
        let frame_end = pos.checked_add(frame_total).ok_or_else(|| {
            TensogramError::Framing(format!(
                "frame total_length overflow at offset {pos} while verifying hashes"
            ))
        })?;
        if frame_end > msg_end {
            return Err(TensogramError::Framing(format!(
                "frame at offset {pos} runs past first-message end \
                 ({frame_end} > {msg_end}) while verifying hashes"
            )));
        }
        if fh.frame_type.is_data_object() {
            let should_check = target_index.is_none_or(|t| t == object_index);
            if should_check {
                let frame_bytes = &buf[pos..frame_end];
                match crate::hash::check_frame_hash(frame_bytes, fh.frame_type) {
                    Ok(true) => {}
                    Ok(false) => {
                        return Err(TensogramError::MissingHash { object_index });
                    }
                    Err(e) => {
                        return Err(crate::error::with_object_index(e, object_index));
                    }
                }
                // A single-object request is satisfied as soon as its frame has been checked.
                if target_index.is_some() {
                    return Ok(());
                }
            }
            object_index += 1;
        }
        // Continue at the next 8-byte boundary after this frame.
        pos = (frame_end + 7) & !7;
    }
    Ok(())
}
/// Decodes every data object of the message contained in `buf`.
///
/// When `options.verify_hash` is set, the hashes of all data-object frames are verified before
/// anything is decoded. When `options.restore_non_finite` is set, each decoded payload is
/// post-processed with the frame's mask region.
///
/// The thread budget resolved from `options.threads` is spent in one of two ways: either the
/// objects are decoded concurrently (when `crate::parallel::use_axis_a` selects that and the
/// `threads` feature is compiled in), or the objects are decoded one after another and the
/// budget is handed to the codec as its intra-codec thread count.
///
/// # Errors
///
/// Propagates errors from hash verification, framing, budget resolution, the decode pipeline
/// and non-finite restoration.
#[tracing::instrument(skip(buf, options), fields(buf_len = buf.len()))]
pub fn decode(buf: &[u8], options: &DecodeOptions) -> Result<(GlobalMetadata, Vec<DecodedObject>)> {
    if options.verify_hash {
        verify_data_object_frames(buf, None)?;
    }

    let message = framing::decode_message(buf)?;
    let thread_budget = crate::parallel::resolve_budget(options.threads)?;

    // Decide whether to go parallel at all, based on the combined payload size.
    let payload_total: usize = message
        .objects
        .iter()
        .map(|(_, payload, _, _)| payload.len())
        .sum();
    let go_parallel = crate::parallel::should_parallelise(
        thread_budget,
        payload_total,
        options.parallel_threshold_bytes,
    );

    // Decide between "parallel across objects" and "parallel inside the codec".
    let has_axis_b_object = message.objects.iter().any(|(d, _, _, _)| {
        crate::parallel::is_axis_b_friendly(&d.encoding, &d.filter, &d.compression)
    });
    let across_objects = go_parallel
        && crate::parallel::use_axis_a(message.objects.len(), thread_budget, has_axis_b_object);
    let codec_threads = if !go_parallel || across_objects {
        0
    } else {
        thread_budget
    };

    // Decodes one entry of `message.objects` and optionally restores non-finite values.
    let decode_entry = |entry: &(
        DataObjectDescriptor,
        &[u8],
        &[u8],
        usize,
    )|
     -> Result<DecodedObject> {
        let (desc, payload_bytes, mask_region, _frame_offset) = entry;
        let mut bytes = decode_single_object_with_backend(
            desc,
            payload_bytes,
            options,
            options.compression_backend,
            codec_threads,
        )?;
        if options.restore_non_finite {
            crate::restore::restore_non_finite_into(
                &mut bytes,
                desc,
                mask_region,
                output_byte_order(desc, options),
            )?;
        }
        Ok((desc.clone(), bytes))
    };

    let data_objects: Vec<DecodedObject> = if across_objects {
        #[cfg(feature = "threads")]
        {
            use rayon::prelude::*;
            crate::parallel::with_pool(thread_budget, || {
                message
                    .objects
                    .par_iter()
                    .map(&decode_entry)
                    .collect::<Result<Vec<_>>>()
            })?
        }
        #[cfg(not(feature = "threads"))]
        {
            message
                .objects
                .iter()
                .map(decode_entry)
                .collect::<Result<_>>()?
        }
    } else {
        crate::parallel::run_maybe_pooled(thread_budget, go_parallel, codec_threads, || {
            message
                .objects
                .iter()
                .map(decode_entry)
                .collect::<Result<_>>()
        })?
    };

    Ok((message.global_metadata, data_objects))
}
/// Decodes the global metadata of the message in `buf`.
///
/// Thin wrapper around `framing::decode_metadata_only`; its result, including any error, is
/// returned unchanged.
pub fn decode_metadata(buf: &[u8]) -> Result<GlobalMetadata> {
framing::decode_metadata_only(buf)
}
pub fn decode_with_masks(
buf: &[u8],
options: &DecodeOptions,
) -> Result<(GlobalMetadata, Vec<DecodedObjectWithMasks>)> {
let msg = framing::decode_message(buf)?;
let budget = crate::parallel::resolve_budget(options.threads)?;
let total_bytes: usize = msg.objects.iter().map(|(_, p, _, _)| p.len()).sum();
let parallel =
crate::parallel::should_parallelise(budget, total_bytes, options.parallel_threshold_bytes);
let intra_codec_threads = if parallel { budget } else { 0 };
let mut decode_opts = options.clone();
decode_opts.restore_non_finite = false;
let decode_one = |(desc, payload_bytes, mask_region, _offset): &(
DataObjectDescriptor,
&[u8],
&[u8],
usize,
)|
-> Result<DecodedObjectWithMasks> {
let payload = decode_single_object_with_backend(
desc,
payload_bytes,
&decode_opts,
options.compression_backend,
intra_codec_threads,
)?;
let masks = crate::restore::decode_mask_set(desc, mask_region)?;
Ok(DecodedObjectWithMasks {
descriptor: desc.clone(),
payload,
masks,
})
};
let objects: Vec<DecodedObjectWithMasks> =
crate::parallel::run_maybe_pooled(budget, parallel, intra_codec_threads, || {
msg.objects.iter().map(decode_one).collect::<Result<_>>()
})?;
Ok((msg.global_metadata, objects))
}
pub use crate::restore::{DecodedMaskSet, DecodedObjectWithMasks};
/// Returns the global metadata and the descriptor of every data object in `buf`.
///
/// The message is parsed with `framing::decode_message`; payload and mask slices are dropped
/// and only the descriptors are kept, in message order.
pub fn decode_descriptors(buf: &[u8]) -> Result<(GlobalMetadata, Vec<DataObjectDescriptor>)> {
    let message = framing::decode_message(buf)?;

    let mut descriptors = Vec::with_capacity(message.objects.len());
    for (descriptor, _payload, _mask_region, _offset) in message.objects {
        descriptors.push(descriptor);
    }

    Ok((message.global_metadata, descriptors))
}
/// Decodes the data object at position `index` of the message in `buf`.
///
/// When `options.verify_hash` is set, the hash of that object's frame is verified first. When
/// `options.restore_non_finite` is set, the decoded bytes are post-processed with the frame's
/// mask region.
///
/// # Errors
///
/// * [`TensogramError::Object`] when `index` is not smaller than the number of objects.
/// * Errors from hash verification, framing, budget resolution, the decode pipeline and
///   non-finite restoration are propagated.
pub fn decode_object(
    buf: &[u8],
    index: usize,
    options: &DecodeOptions,
) -> Result<(GlobalMetadata, DataObjectDescriptor, Vec<u8>)> {
    if options.verify_hash {
        verify_data_object_frames(buf, Some(index))?;
    }

    let message = framing::decode_message(buf)?;
    let Some((desc, payload_bytes, mask_region, _frame_offset)) = message.objects.get(index)
    else {
        return Err(TensogramError::Object(format!(
            "object index {} out of range (num_objects={})",
            index,
            message.objects.len()
        )));
    };

    let thread_budget = crate::parallel::resolve_budget(options.threads)?;
    let go_parallel = crate::parallel::should_parallelise(
        thread_budget,
        payload_bytes.len(),
        options.parallel_threshold_bytes,
    );
    let codec_threads = if go_parallel { thread_budget } else { 0 };

    let mut bytes =
        crate::parallel::run_maybe_pooled(thread_budget, go_parallel, codec_threads, || {
            decode_single_object_with_backend(
                desc,
                payload_bytes,
                options,
                options.compression_backend,
                codec_threads,
            )
        })?;

    if options.restore_non_finite {
        let byte_order = output_byte_order(desc, options);
        crate::restore::restore_non_finite_into(&mut bytes, desc, mask_region, byte_order)?;
    }

    Ok((message.global_metadata, desc.clone(), bytes))
}
/// Decodes a single data object from the bytes of its frame.
///
/// When `options.verify_hash` is set, the frame header's `total_length` must lie between
/// `FRAME_HEADER_SIZE` and `frame_bytes.len()` (both inclusive) and the hash check over exactly
/// that many bytes must succeed. When `options.restore_non_finite` is set, the decoded bytes
/// are post-processed with the frame's mask region.
///
/// # Errors
///
/// * [`TensogramError::Framing`] when the declared frame length is out of bounds.
/// * [`TensogramError::MissingHash`] (with `object_index: 0`) when the hash check returns
///   `Ok(false)`.
/// * Errors from the hash check, framing, budget resolution, the decode pipeline and
///   non-finite restoration are propagated.
pub fn decode_object_from_frame(
    frame_bytes: &[u8],
    options: &DecodeOptions,
) -> Result<(DataObjectDescriptor, Vec<u8>)> {
    if options.verify_hash {
        use crate::wire::{FRAME_HEADER_SIZE, FrameHeader};

        let header = FrameHeader::read_from(frame_bytes)?;
        let declared_len = usize::try_from(header.total_length).map_err(|_| {
            TensogramError::Framing(format!(
                "frame total_length ({}) overflows usize on this target; \
                 cannot verify hash on decode_object_from_frame",
                header.total_length
            ))
        })?;

        // The declared length must cover at least the header and at most the supplied bytes.
        let in_bounds = (FRAME_HEADER_SIZE..=frame_bytes.len()).contains(&declared_len);
        if !in_bounds {
            return Err(TensogramError::Framing(format!(
                "frame total_length ({}) outside bounds of supplied frame buffer \
                 ({} bytes); cannot verify hash on decode_object_from_frame",
                header.total_length,
                frame_bytes.len()
            )));
        }

        let hash_ok =
            crate::hash::check_frame_hash(&frame_bytes[..declared_len], header.frame_type)?;
        if !hash_ok {
            return Err(TensogramError::MissingHash { object_index: 0 });
        }
    }

    let (desc, payload_bytes, mask_region, _) = framing::decode_data_object_frame(frame_bytes)?;

    let thread_budget = crate::parallel::resolve_budget(options.threads)?;
    let go_parallel = crate::parallel::should_parallelise(
        thread_budget,
        payload_bytes.len(),
        options.parallel_threshold_bytes,
    );
    let codec_threads = if go_parallel { thread_budget } else { 0 };

    let mut bytes =
        crate::parallel::run_maybe_pooled(thread_budget, go_parallel, codec_threads, || {
            decode_single_object_with_backend(
                &desc,
                payload_bytes,
                options,
                options.compression_backend,
                codec_threads,
            )
        })?;

    if options.restore_non_finite {
        let byte_order = output_byte_order(&desc, options);
        crate::restore::restore_non_finite_into(&mut bytes, &desc, mask_region, byte_order)?;
    }

    Ok((desc, bytes))
}
/// Decodes the requested element ranges of a single data object given the bytes of its frame.
///
/// Each entry of `ranges` is forwarded to [`decode_range_from_payload`] and yields one output
/// buffer, in the same order. Non-finite values are restored into those buffers only when
/// `options.restore_non_finite` is set and the descriptor carries masks.
pub fn decode_range_from_frame(
    frame_bytes: &[u8],
    ranges: &[(u64, u64)],
    options: &DecodeOptions,
) -> Result<(DataObjectDescriptor, Vec<Vec<u8>>)> {
    let (desc, payload_bytes, mask_region, _) = framing::decode_data_object_frame(frame_bytes)?;
    let mut parts = decode_range_from_payload(&desc, payload_bytes, ranges, options)?;

    let needs_restore = options.restore_non_finite && desc.masks.is_some();
    if !needs_restore {
        return Ok((desc, parts));
    }

    let mask_set = crate::restore::decode_mask_set(&desc, mask_region)?;
    let byte_order = output_byte_order(&desc, options);
    crate::restore::restore_non_finite_into_ranges(
        &mut parts,
        &desc,
        ranges,
        &mask_set,
        byte_order,
    )?;
    Ok((desc, parts))
}
/// Decodes the requested element ranges of the object at `object_index` in the message `buf`.
///
/// Each entry of `ranges` is forwarded to [`decode_range_from_payload`] and yields one output
/// buffer, in the same order. Non-finite values are restored into those buffers only when
/// `options.restore_non_finite` is set and the descriptor carries masks.
///
/// # Errors
///
/// * [`TensogramError::Object`] when `object_index` is not smaller than the number of objects.
/// * Errors from framing, range decoding, mask decoding and restoration are propagated.
pub fn decode_range(
    buf: &[u8],
    object_index: usize,
    ranges: &[(u64, u64)],
    options: &DecodeOptions,
) -> Result<(DataObjectDescriptor, Vec<Vec<u8>>)> {
    let message = framing::decode_message(buf)?;
    let Some((desc, payload_bytes, mask_region, _)) = message.objects.get(object_index) else {
        return Err(TensogramError::Object(format!(
            "object index {} out of range (num_objects={})",
            object_index,
            message.objects.len()
        )));
    };

    let mut parts = decode_range_from_payload(desc, payload_bytes, ranges, options)?;

    let needs_restore = options.restore_non_finite && desc.masks.is_some();
    if needs_restore {
        let mask_set = crate::restore::decode_mask_set(desc, mask_region)?;
        let byte_order = output_byte_order(desc, options);
        crate::restore::restore_non_finite_into_ranges(
            &mut parts,
            desc,
            ranges,
            &mask_set,
            byte_order,
        )?;
    }

    Ok((desc.clone(), parts))
}
/// Picks the byte order used when restoring non-finite values: the host's native order when
/// `options.native_byte_order` is set, otherwise the order recorded in the descriptor.
fn output_byte_order(
    desc: &DataObjectDescriptor,
    options: &DecodeOptions,
) -> tensogram_encodings::ByteOrder {
    if !options.native_byte_order {
        return desc.byte_order;
    }
    tensogram_encodings::ByteOrder::native()
}
/// Decodes element ranges directly from an object's encoded payload.
///
/// `ranges` holds `(offset, count)` pairs that are forwarded unchanged to
/// `pipeline::decode_range_pipeline`; one output buffer is produced per pair, in input order.
/// For `szip`-compressed payloads the block offsets are read from the descriptor's
/// `szip_block_offsets` parameter.
///
/// The thread budget resolved from `options.threads` is spent either across ranges (when
/// `crate::parallel::use_axis_a` selects that and the `threads` feature is compiled in) or
/// inside the codec.
///
/// # Errors
///
/// * [`TensogramError::Encoding`] when the descriptor uses a filter other than `"none"`, when
///   the dtype has a byte width of 0 (bitmask), or when the range pipeline fails; in the last
///   case the message names the offending range.
/// * Errors from element counting, budget resolution, pipeline configuration and block-offset
///   extraction are propagated.
pub fn decode_range_from_payload(
    desc: &DataObjectDescriptor,
    payload_bytes: &[u8],
    ranges: &[(u64, u64)],
    options: &DecodeOptions,
) -> Result<Vec<Vec<u8>>> {
    if desc.filter != "none" {
        return Err(TensogramError::Encoding(
            "decode_range is not supported when a filter (e.g. shuffle) is applied".to_string(),
        ));
    }
    let elem_bytes = desc.dtype.byte_width();
    if elem_bytes == 0 {
        return Err(TensogramError::Encoding(
            "partial range decode not supported for bitmask dtype".to_string(),
        ));
    }
    let num_elements = desc.num_elements()?;
    let budget = crate::parallel::resolve_budget(options.threads)?;

    // Size estimate that only feeds the parallelism heuristic. The counts come straight from
    // the caller, so every step saturates instead of overflowing: a plain `sum()` of
    // already-saturated products would panic in debug builds and wrap in release builds, and
    // an `as usize` cast would silently truncate on 32-bit targets.
    let total_bytes = ranges.iter().fold(0usize, |acc, &(_, count)| {
        let count = usize::try_from(count).unwrap_or(usize::MAX);
        acc.saturating_add(count.saturating_mul(elem_bytes))
    });
    let parallel =
        crate::parallel::should_parallelise(budget, total_bytes, options.parallel_threshold_bytes);
    let axis_b_friendly =
        crate::parallel::is_axis_b_friendly(&desc.encoding, &desc.filter, &desc.compression);
    let use_axis_a = parallel && crate::parallel::use_axis_a(ranges.len(), budget, axis_b_friendly);
    let intra_codec_threads = if parallel && !use_axis_a { budget } else { 0 };

    let config = build_pipeline_config_with_backend(
        desc,
        num_elements,
        desc.dtype,
        options.compression_backend,
        intra_codec_threads,
    )?;

    // Only szip needs block offsets; every other codec gets an empty list.
    let block_offsets = if desc.compression == "szip" {
        extract_block_offsets(&desc.params)?
    } else {
        Vec::new()
    };

    // Decodes one range and tags any pipeline error with the range that caused it.
    let decode_one = |offset: u64, count: u64| -> Result<Vec<u8>> {
        pipeline::decode_range_pipeline(
            payload_bytes,
            &config,
            &block_offsets,
            offset,
            count,
            options.native_byte_order,
        )
        .map_err(|e| {
            TensogramError::Encoding(format!("range (offset={offset}, count={count}): {e}"))
        })
    };
    let run_seq = || -> Result<Vec<Vec<u8>>> {
        ranges
            .iter()
            .map(|&(offset, count)| decode_one(offset, count))
            .collect()
    };

    let results: Vec<Vec<u8>> = if use_axis_a {
        #[cfg(feature = "threads")]
        {
            use rayon::prelude::*;
            crate::parallel::with_pool(budget, || {
                ranges
                    .par_iter()
                    .map(|&(offset, count)| decode_one(offset, count))
                    .collect::<Result<Vec<_>>>()
            })?
        }
        #[cfg(not(feature = "threads"))]
        {
            run_seq()?
        }
    } else {
        crate::parallel::run_maybe_pooled(budget, parallel, intra_codec_threads, run_seq)?
    };
    Ok(results)
}
/// Runs the decode pipeline for one object payload.
///
/// A pipeline configuration is built from the descriptor, its element count, the requested
/// `backend` and `intra_codec_threads`; the payload is then decoded with
/// `options.native_byte_order` forwarded to the pipeline.
///
/// # Errors
///
/// Propagates errors from element counting and configuration building; pipeline failures are
/// reported as [`TensogramError::Encoding`] carrying the pipeline's message.
fn decode_single_object_with_backend(
    desc: &DataObjectDescriptor,
    payload_bytes: &[u8],
    options: &DecodeOptions,
    backend: pipeline::CompressionBackend,
    intra_codec_threads: u32,
) -> Result<Vec<u8>> {
    let element_count = desc.num_elements()?;
    let pipeline_config = build_pipeline_config_with_backend(
        desc,
        element_count,
        desc.dtype,
        backend,
        intra_codec_threads,
    )?;

    pipeline::decode_pipeline(payload_bytes, &pipeline_config, options.native_byte_order)
        .map_err(|e| TensogramError::Encoding(e.to_string()))
}
#[cfg(test)]
mod tests {
    //! Unit tests for the decode entry points: error paths, range decoding, hash verification,
    //! non-finite restoration and the parallel code paths.

    use super::*;
    use crate::dtype::Dtype;
    use crate::encode::{EncodeOptions, encode};
    use crate::types::ByteOrder;
    use std::collections::BTreeMap;

    /// Global metadata with an empty `extra` map and defaults everywhere else.
    fn make_global_meta() -> GlobalMetadata {
        GlobalMetadata {
            extra: BTreeMap::new(),
            ..Default::default()
        }
    }

    /// Float32, unencoded, unfiltered, uncompressed descriptor for `shape`, with strides
    /// computed so that the last dimension has stride 1.
    fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
        let strides = if shape.is_empty() {
            vec![]
        } else {
            let mut s = vec![1u64; shape.len()];
            for i in (0..shape.len() - 1).rev() {
                s[i] = s[i + 1] * shape[i + 1];
            }
            s
        };
        DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: shape.len() as u64,
            shape,
            strides,
            dtype: Dtype::Float32,
            byte_order: ByteOrder::native(),
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            masks: None,
        }
    }

    /// Arbitrary bytes are rejected by `decode`.
    #[test]
    fn test_decode_corrupt_message_bytes() {
        let garbage = vec![0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x01, 0x02, 0x03];
        let result = decode(&garbage, &DecodeOptions::default());
        assert!(result.is_err(), "decoding garbage should fail");
    }

    /// The first half of a valid message is rejected by `decode`.
    #[test]
    fn test_decode_truncated_message() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let truncated = &encoded[..encoded.len() / 2];
        let result = decode(truncated, &DecodeOptions::default());
        assert!(result.is_err(), "decoding truncated message should fail");
    }

    /// Overwriting 30 bytes starting at offset 40 with `0xFF` makes `decode` fail.
    #[test]
    fn test_decode_corrupted_cbor_in_message() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![42u8; 16];
        let mut encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        // NOTE(review): offset 40 presumably lands inside the CBOR metadata — confirm against
        // the wire layout.
        let cbor_start = 40;
        let corrupt_end = (cbor_start + 30).min(encoded.len());
        for byte in &mut encoded[cbor_start..corrupt_end] {
            *byte = 0xFF;
        }
        let result = decode(&encoded, &DecodeOptions::default());
        assert!(result.is_err(), "decoding corrupted CBOR should fail");
    }

    /// `decode_object` reports "out of range" for indices at and beyond the object count.
    #[test]
    fn test_decode_object_index_out_of_range() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let result = decode_object(&encoded, 1, &DecodeOptions::default());
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("out of range"),
            "expected 'out of range', got: {msg}"
        );
        let result = decode_object(&encoded, 999, &DecodeOptions::default());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("out of range"));
    }

    /// `decode_object` returns the matching descriptor and payload for each valid index.
    #[test]
    fn test_decode_object_valid_index() {
        let meta = make_global_meta();
        let desc0 = make_descriptor(vec![2]);
        let data0 = vec![10u8; 8];
        let desc1 = make_descriptor(vec![3]);
        let data1 = vec![20u8; 12];
        let encoded = encode(
            &meta,
            &[(&desc0, data0.as_slice()), (&desc1, data1.as_slice())],
            &EncodeOptions::default(),
        )
        .unwrap();
        let (_, ret_desc, ret_data) =
            decode_object(&encoded, 0, &DecodeOptions::default()).unwrap();
        assert_eq!(ret_desc.shape, vec![2]);
        assert_eq!(ret_data, data0);
        let (_, ret_desc, ret_data) =
            decode_object(&encoded, 1, &DecodeOptions::default()).unwrap();
        assert_eq!(ret_desc.shape, vec![3]);
        assert_eq!(ret_data, data1);
    }

    /// `decode_range` reports "out of range" for an object index beyond the object count.
    #[test]
    fn test_decode_range_object_index_out_of_range() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let result = decode_range(&encoded, 5, &[(0, 2)], &DecodeOptions::default());
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("out of range"),
            "expected 'out of range', got: {msg}"
        );
    }

    /// A range reaching past the 4 available elements is rejected.
    #[test]
    fn test_decode_range_exceeds_payload() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let result = decode_range(&encoded, 0, &[(2, 10)], &DecodeOptions::default());
        assert!(result.is_err(), "range exceeding payload should fail");
    }

    /// Four float32 elements decode to one 16-byte part.
    #[test]
    fn test_decode_range_valid() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![8]);
        let data: Vec<u8> = (0..32).collect();
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let (ret_desc, parts) =
            decode_range(&encoded, 0, &[(0, 4)], &DecodeOptions::default()).unwrap();
        assert_eq!(ret_desc.shape, vec![8]);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0].len(), 16);
    }

    /// An empty range list yields an empty part list.
    #[test]
    fn test_decode_range_empty_ranges() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let (_, parts) = decode_range(&encoded, 0, &[], &DecodeOptions::default()).unwrap();
        assert!(parts.is_empty());
    }

    /// `decode_metadata` succeeds on a valid message.
    #[test]
    fn test_decode_metadata_valid() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let _decoded_meta = decode_metadata(&encoded).unwrap();
    }

    /// `decode_metadata` fails on 50 bytes of `0xFF`.
    #[test]
    fn test_decode_metadata_corrupt() {
        let garbage = vec![0xFF; 50];
        let result = decode_metadata(&garbage);
        assert!(result.is_err(), "decode_metadata on garbage should fail");
    }

    /// `decode_descriptors` returns one descriptor per object, in message order.
    #[test]
    fn test_decode_descriptors_valid() {
        let meta = make_global_meta();
        let desc0 = make_descriptor(vec![4]);
        let desc1 = make_descriptor(vec![2, 3]);
        let data0 = vec![0u8; 16];
        let data1 = vec![0u8; 24];
        let encoded = encode(
            &meta,
            &[(&desc0, data0.as_slice()), (&desc1, data1.as_slice())],
            &EncodeOptions::default(),
        )
        .unwrap();
        let (_decoded_meta, descs) = decode_descriptors(&encoded).unwrap();
        assert_eq!(descs.len(), 2);
        assert_eq!(descs[0].shape, vec![4]);
        assert_eq!(descs[1].shape, vec![2, 3]);
    }

    /// Range decoding is refused for an object encoded with the shuffle filter.
    #[test]
    fn test_decode_range_filter_shuffle_rejected() {
        let meta = make_global_meta();
        let mut desc = make_descriptor(vec![100]);
        desc.filter = "shuffle".to_string();
        desc.params.insert(
            "shuffle_element_size".to_string(),
            ciborium::Value::Integer(4.into()),
        );
        let data: Vec<u8> = (0..100).flat_map(|i| (i as f32).to_ne_bytes()).collect();
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let result = decode_range(&encoded, 0, &[(0, 10)], &DecodeOptions::default());
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("filter") || msg.contains("shuffle"),
            "expected filter/shuffle error, got: {msg}"
        );
    }

    /// Range decoding is refused for the bitmask dtype.
    #[test]
    fn test_decode_range_bitmask_dtype_rejected() {
        let meta = make_global_meta();
        let desc = DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: 1,
            shape: vec![16],
            strides: vec![1],
            dtype: Dtype::Bitmask,
            byte_order: ByteOrder::native(),
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            masks: None,
        };
        let data = vec![0xFF; 2];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let result = decode_range(&encoded, 0, &[(0, 8)], &DecodeOptions::default());
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("bitmask"),
            "expected bitmask error, got: {msg}"
        );
    }

    /// The defaults enable native byte order and non-finite restoration.
    #[test]
    fn test_decode_options_defaults() {
        let opts = DecodeOptions::default();
        assert!(opts.native_byte_order);
        assert!(opts.restore_non_finite);
    }

    /// An unknown encoding name is rejected when the pipeline configuration is built.
    #[test]
    fn test_decode_unknown_encoding_in_descriptor() {
        let mut desc = make_descriptor(vec![4]);
        desc.encoding = "foobar".to_string();
        let result = crate::encode::build_pipeline_config_with_backend(
            &desc,
            4,
            Dtype::Float32,
            pipeline::CompressionBackend::default(),
            0,
        );
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("unknown encoding"),
            "expected 'unknown encoding', got: {msg}"
        );
    }

    /// An unknown compression name is rejected when the pipeline configuration is built.
    #[test]
    fn test_decode_unknown_compression_in_descriptor() {
        let mut desc = make_descriptor(vec![4]);
        desc.compression = "quantum_compress".to_string();
        let result = crate::encode::build_pipeline_config_with_backend(
            &desc,
            4,
            Dtype::Float32,
            pipeline::CompressionBackend::default(),
            0,
        );
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("unknown compression"),
            "expected 'unknown compression', got: {msg}"
        );
    }

    /// A parameter map without `szip_block_offsets` is an error that names the key.
    #[test]
    fn test_extract_block_offsets_missing() {
        let params = BTreeMap::new();
        let result = extract_block_offsets(&params);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("szip_block_offsets"),
            "expected szip_block_offsets error, got: {msg}"
        );
    }

    /// A non-array `szip_block_offsets` value is rejected.
    #[test]
    fn test_extract_block_offsets_wrong_type() {
        let mut params = BTreeMap::new();
        params.insert(
            "szip_block_offsets".to_string(),
            ciborium::Value::Text("not an array".to_string()),
        );
        let result = extract_block_offsets(&params);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("must be an array"),
            "expected 'must be an array', got: {msg}"
        );
    }

    /// An array holding a non-integer element is rejected.
    #[test]
    fn test_extract_block_offsets_non_integer_elements() {
        let mut params = BTreeMap::new();
        params.insert(
            "szip_block_offsets".to_string(),
            ciborium::Value::Array(vec![ciborium::Value::Float(1.5)]),
        );
        let result = extract_block_offsets(&params);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("integers"),
            "expected integers error, got: {msg}"
        );
    }

    /// A well-formed offset array is returned as `u64` values in order.
    #[test]
    fn test_extract_block_offsets_valid() {
        let mut params = BTreeMap::new();
        params.insert(
            "szip_block_offsets".to_string(),
            ciborium::Value::Array(vec![
                ciborium::Value::Integer(0.into()),
                ciborium::Value::Integer(100.into()),
                ciborium::Value::Integer(200.into()),
            ]),
        );
        let result = extract_block_offsets(&params).unwrap();
        assert_eq!(result, vec![0, 100, 200]);
    }

    /// A negative offset does not fit in `u64` and is rejected.
    #[test]
    fn test_extract_block_offsets_negative_out_of_u64_range() {
        let mut params = BTreeMap::new();
        params.insert(
            "szip_block_offsets".to_string(),
            ciborium::Value::Array(vec![ciborium::Value::Integer((-1).into())]),
        );
        let result = extract_block_offsets(&params);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("out of u64 range"),
            "expected 'out of u64 range', got: {msg}"
        );
    }

    /// Default decode options with hash verification switched on.
    fn opts_verify() -> DecodeOptions {
        DecodeOptions {
            verify_hash: true,
            ..Default::default()
        }
    }

    /// Encodes a single 4-element float32 object with hashing switched on or off.
    fn encode_one(hashing: bool) -> Vec<u8> {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![5u8; 16];
        let opts = EncodeOptions {
            hashing,
            ..Default::default()
        };
        encode(&meta, &[(&desc, &data)], &opts).unwrap()
    }

    /// Verification passes on a message encoded with hashing.
    #[test]
    fn test_decode_verify_hash_succeeds_on_hashed_message() {
        let msg = encode_one(true);
        let (_, objs) = decode(&msg, &opts_verify()).unwrap();
        assert_eq!(objs.len(), 1);
    }

    /// Verification reports `MissingHash` for object 0 on a message encoded without hashing.
    #[test]
    fn test_decode_verify_hash_missing_hash_on_unhashed_message() {
        let msg = encode_one(false);
        let err = decode(&msg, &opts_verify()).unwrap_err();
        assert!(
            matches!(err, TensogramError::MissingHash { object_index: 0 }),
            "expected MissingHash{{0}}, got: {err:?}"
        );
    }

    /// Flipping the byte right after the frame header is detected as a hash mismatch.
    #[test]
    fn test_decode_verify_hash_mismatch_on_tampered_payload() {
        let mut msg = encode_one(true);
        let dm = framing::decode_message(&msg).unwrap();
        let frame_off = dm.objects[0].3;
        msg[frame_off + crate::wire::FRAME_HEADER_SIZE] ^= 0xFF;
        let err = decode(&msg, &opts_verify()).unwrap_err();
        assert!(
            matches!(err, TensogramError::HashMismatch { .. }),
            "expected HashMismatch, got: {err:?}"
        );
    }

    /// A 10-byte buffer is too short for the preamble.
    #[test]
    fn test_decode_verify_hash_buffer_too_short_for_preamble() {
        let buf = vec![0u8; 10];
        let err = decode(&buf, &opts_verify()).unwrap_err();
        assert!(
            matches!(err, TensogramError::Framing(_)),
            "expected Framing error, got: {err:?}"
        );
        assert!(err.to_string().contains("too short for preamble"));
    }

    /// Dropping the last 16 bytes makes the preamble's total length exceed the buffer.
    #[test]
    fn test_decode_verify_hash_total_length_exceeds_buffer() {
        let msg = encode_one(true);
        let truncated = &msg[..msg.len() - 16];
        let err = decode(truncated, &opts_verify()).unwrap_err();
        assert!(
            matches!(err, TensogramError::Framing(_)),
            "expected Framing error, got: {err:?}"
        );
        assert!(err.to_string().contains("exceeds buffer"), "got: {}", err);
    }

    /// Shrinking the declared message length so that the first frame no longer fits is
    /// reported as a frame running past the message end.
    #[test]
    fn test_decode_verify_hash_frame_runs_past_message_end() {
        let mut msg = encode_one(true);
        let msg_end_target = crate::wire::PREAMBLE_SIZE + crate::wire::FRAME_HEADER_SIZE;
        let shrunk = (msg_end_target + crate::wire::POSTAMBLE_SIZE) as u64;
        // NOTE(review): bytes 16..24 are presumably the preamble's big-endian total_length
        // field — confirm against the wire layout.
        msg[16..24].copy_from_slice(&shrunk.to_be_bytes());
        let err = decode(&msg, &opts_verify()).unwrap_err();
        assert!(
            matches!(err, TensogramError::Framing(_)),
            "expected Framing error, got: {err:?}"
        );
        assert!(
            err.to_string().contains("runs past first-message end"),
            "got: {err}"
        );
    }

    /// A message body of 32 zero bytes contains no frame magic; the verifier skips over it
    /// and succeeds.
    #[test]
    fn test_decode_verify_hash_skips_non_frame_bytes() {
        use crate::wire::{MessageFlags, PREAMBLE_SIZE, Postamble, Preamble, WIRE_VERSION};
        let mut out = Vec::new();
        // Placeholder preamble, overwritten below once the real one has been serialised.
        out.extend_from_slice(&[0u8; PREAMBLE_SIZE]);
        out.extend(std::iter::repeat_n(0u8, 32));
        let postamble = Postamble {
            first_footer_offset: 0,
            total_length: 0,
        };
        postamble.write_to(&mut out);
        // total_length = 0 makes the verifier derive the message end from the buffer length.
        let preamble = Preamble {
            version: WIRE_VERSION,
            flags: MessageFlags::default(),
            reserved: 0,
            total_length: 0,
        };
        let mut pb = Vec::new();
        preamble.write_to(&mut pb);
        out[0..PREAMBLE_SIZE].copy_from_slice(&pb);
        assert!(verify_data_object_frames(&out, None).is_ok());
    }

    /// `decode_object` with verification decodes the second of two hashed objects.
    #[test]
    fn test_decode_object_verify_hash_pre_pass() {
        let meta = make_global_meta();
        let desc0 = make_descriptor(vec![2]);
        let desc1 = make_descriptor(vec![3]);
        let data0 = vec![1u8; 8];
        let data1 = vec![2u8; 12];
        let opts = EncodeOptions {
            hashing: true,
            ..Default::default()
        };
        let msg = encode(
            &meta,
            &[(&desc0, data0.as_slice()), (&desc1, data1.as_slice())],
            &opts,
        )
        .unwrap();
        let (_, ret_desc, ret_data) = decode_object(&msg, 1, &opts_verify()).unwrap();
        assert_eq!(ret_desc.shape, vec![3]);
        assert_eq!(ret_data, data1);
    }

    /// Encodes one 8-element float64 object containing NaN, +Inf and -Inf, with NaN/Inf
    /// allowed and `small_mask_threshold_bytes` set to 0. Returns the message and the source
    /// values.
    fn encode_nan_inf_object(hashing: bool) -> (Vec<u8>, Vec<f64>) {
        let values: Vec<f64> = vec![
            1.0,
            f64::NAN,
            3.0,
            f64::INFINITY,
            f64::NEG_INFINITY,
            6.0,
            7.0,
            8.0,
        ];
        let data: Vec<u8> = values.iter().flat_map(|v| v.to_ne_bytes()).collect();
        let desc = DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: 1,
            shape: vec![8],
            strides: vec![1],
            dtype: Dtype::Float64,
            byte_order: ByteOrder::native(),
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            masks: None,
        };
        let enc_opts = EncodeOptions {
            allow_nan: true,
            allow_inf: true,
            hashing,
            small_mask_threshold_bytes: 0,
            ..Default::default()
        };
        let msg = encode(&make_global_meta(), &[(&desc, &data)], &enc_opts).unwrap();
        (msg, values)
    }

    /// Copies the bytes of the `i`-th object frame out of `msg`, using the message index.
    fn extract_object_frame(msg: &[u8], i: usize) -> Vec<u8> {
        let dm = framing::decode_message(msg).unwrap();
        let idx = dm.index.as_ref().expect("message must carry an index");
        let off = idx.offsets[i] as usize;
        let len = idx.lengths[i] as usize;
        msg[off..off + len].to_vec()
    }

    /// `decode_object` puts NaN and both infinities back at their original positions.
    #[test]
    fn test_decode_object_restores_non_finite_masks() {
        let (msg, _) = encode_nan_inf_object(false);
        let (_, _desc, decoded) = decode_object(&msg, 0, &DecodeOptions::default()).unwrap();
        let vals: Vec<f64> = decoded
            .chunks_exact(8)
            .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
            .collect();
        assert_eq!(vals[0], 1.0);
        assert!(vals[1].is_nan(), "NaN must be restored");
        assert!(vals[3].is_infinite() && vals[3] > 0.0, "+Inf restored");
        assert!(vals[4].is_infinite() && vals[4] < 0.0, "-Inf restored");
    }

    /// `decode_object_from_frame` restores NaN and +Inf from the frame's masks.
    #[test]
    fn test_decode_object_from_frame_with_masks_restore() {
        let (msg, _) = encode_nan_inf_object(false);
        let frame = extract_object_frame(&msg, 0);
        let (_desc, decoded) = decode_object_from_frame(&frame, &DecodeOptions::default()).unwrap();
        let vals: Vec<f64> = decoded
            .chunks_exact(8)
            .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
            .collect();
        assert!(vals[1].is_nan());
        assert!(vals[3].is_infinite() && vals[3] > 0.0);
    }

    /// Frame-level verification passes on a hashed frame.
    #[test]
    fn test_decode_object_from_frame_verify_hash_succeeds() {
        let (msg, _) = encode_nan_inf_object(true);
        let frame = extract_object_frame(&msg, 0);
        let (_desc, decoded) = decode_object_from_frame(&frame, &opts_verify()).unwrap();
        assert!(!decoded.is_empty());
    }

    /// Frame-level verification reports `MissingHash` for an unhashed frame.
    #[test]
    fn test_decode_object_from_frame_verify_hash_missing_on_unhashed() {
        let (msg, _) = encode_nan_inf_object(false);
        let frame = extract_object_frame(&msg, 0);
        let err = decode_object_from_frame(&frame, &opts_verify()).unwrap_err();
        assert!(
            matches!(err, TensogramError::MissingHash { object_index: 0 }),
            "expected MissingHash{{0}}, got: {err:?}"
        );
    }

    /// `decode` also works with `native_byte_order` switched off.
    #[test]
    fn test_decode_non_native_byte_order_path() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let opts = DecodeOptions {
            native_byte_order: false,
            ..Default::default()
        };
        let (_, objs) = decode(&encoded, &opts).unwrap();
        assert_eq!(objs.len(), 1);
    }

    /// A frame buffer shorter than the header's declared length is rejected before hashing.
    #[test]
    fn test_decode_object_from_frame_verify_hash_total_out_of_bounds() {
        let (msg, _) = encode_nan_inf_object(true);
        let frame = extract_object_frame(&msg, 0);
        let short = &frame[..frame.len() - 8];
        let err = decode_object_from_frame(short, &opts_verify()).unwrap_err();
        assert!(
            matches!(err, TensogramError::Framing(_)),
            "expected Framing error, got: {err:?}"
        );
        assert!(err.to_string().contains("outside bounds"));
    }

    /// `decode_range_from_frame` restores non-finite values inside the requested range.
    #[test]
    fn test_decode_range_from_frame_restores_masks() {
        let (msg, _) = encode_nan_inf_object(false);
        let frame = extract_object_frame(&msg, 0);
        let (ret_desc, parts) =
            decode_range_from_frame(&frame, &[(0, 5)], &DecodeOptions::default()).unwrap();
        assert!(ret_desc.masks.is_some());
        let vals: Vec<f64> = parts[0]
            .chunks_exact(8)
            .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
            .collect();
        assert!(vals[1].is_nan());
        assert!(vals[3].is_infinite() && vals[3] > 0.0);
        assert!(vals[4].is_infinite() && vals[4] < 0.0);
    }

    /// `decode_range` restores non-finite values inside the requested range.
    #[test]
    fn test_decode_range_restores_masks() {
        let (msg, _) = encode_nan_inf_object(false);
        let (ret_desc, parts) =
            decode_range(&msg, 0, &[(0, 5)], &DecodeOptions::default()).unwrap();
        assert!(ret_desc.masks.is_some());
        let vals: Vec<f64> = parts[0]
            .chunks_exact(8)
            .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
            .collect();
        assert!(vals[1].is_nan());
        assert!(vals[3].is_infinite() && vals[3] > 0.0);
    }

    /// With 4 threads and a zero threshold, two objects still decode to the right payloads
    /// in the right order.
    #[test]
    fn test_decode_axis_a_parallel_multi_object() {
        let meta = make_global_meta();
        let desc0 = make_descriptor(vec![4]);
        let desc1 = make_descriptor(vec![4]);
        let data0 = vec![11u8; 16];
        let data1 = vec![22u8; 16];
        let encoded = encode(
            &meta,
            &[(&desc0, data0.as_slice()), (&desc1, data1.as_slice())],
            &EncodeOptions::default(),
        )
        .unwrap();
        let opts = DecodeOptions {
            threads: 4,
            parallel_threshold_bytes: Some(0),
            ..Default::default()
        };
        let (_, objs) = decode(&encoded, &opts).unwrap();
        assert_eq!(objs.len(), 2);
        assert_eq!(objs[0].1, data0);
        assert_eq!(objs[1].1, data1);
    }

    /// With 4 threads and a zero threshold, three ranges decode to three 16-byte parts.
    #[test]
    fn test_decode_range_axis_a_parallel_multi_range() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![16]);
        let data: Vec<u8> = (0..64).collect();
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let opts = DecodeOptions {
            threads: 4,
            parallel_threshold_bytes: Some(0),
            ..Default::default()
        };
        let (_, parts) = decode_range(&encoded, 0, &[(0, 4), (4, 4), (8, 4)], &opts).unwrap();
        assert_eq!(parts.len(), 3);
        assert!(parts.iter().all(|p| p.len() == 16));
    }

    /// Range decoding of an szip-compressed object yields 8 float32 elements (32 bytes).
    #[cfg(feature = "szip")]
    #[test]
    fn test_decode_range_szip_extracts_block_offsets() {
        let meta = make_global_meta();
        let mut desc = make_descriptor(vec![256]);
        desc.compression = "szip".to_string();
        let data: Vec<u8> = (0..256).flat_map(|i| (i as f32).to_ne_bytes()).collect();
        // If encoding fails there is nothing to decode and the test ends early.
        // NOTE(review): this also hides genuine encoder failures — confirm that is intended.
        let encoded = match encode(&meta, &[(&desc, &data)], &EncodeOptions::default()) {
            Ok(e) => e,
            Err(_) => return,
        };
        let (_, parts) = decode_range(&encoded, 0, &[(0, 8)], &DecodeOptions::default()).unwrap();
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0].len(), 32);
    }

    /// An empty shape gives empty strides, and a 4-byte payload round-trips as one object.
    #[test]
    fn test_make_descriptor_scalar_empty_shape() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![]);
        assert!(desc.strides.is_empty());
        let data = vec![0u8; 4];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let (_, objs) = decode(&encoded, &DecodeOptions::default()).unwrap();
        assert_eq!(objs.len(), 1);
    }

    /// `decode_with_masks` leaves the NaN position at 0.0 in the payload and returns all
    /// three masks.
    #[test]
    fn test_decode_with_masks_returns_raw_masks() {
        let (msg, _) = encode_nan_inf_object(false);
        let (_, objs) = decode_with_masks(&msg, &DecodeOptions::default()).unwrap();
        assert_eq!(objs.len(), 1);
        let vals: Vec<f64> = objs[0]
            .payload
            .chunks_exact(8)
            .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
            .collect();
        assert_eq!(
            vals[1], 0.0,
            "NaN should be 0.0 in decode_with_masks payload"
        );
        assert!(objs[0].masks.nan.is_some(), "NaN mask must be present");
        assert!(objs[0].masks.pos_inf.is_some(), "+Inf mask must be present");
        assert!(objs[0].masks.neg_inf.is_some(), "-Inf mask must be present");
    }
}