use crate::encode::build_pipeline_config_with_backend;
use crate::error::{Result, TensogramError};
use crate::framing;
use crate::hash;
use crate::types::{DataObjectDescriptor, DecodedObject, GlobalMetadata};
use tensogram_encodings::pipeline;
/// Reads the `szip_block_offsets` array out of a payload's parameter map.
///
/// Returns the offsets as `u64` values. Fails when the key is missing
/// (it is required for partial range decode of szip payloads), when the
/// value is not an array, or when any entry is not an integer that fits
/// in `u64`.
fn extract_block_offsets(
    params: &std::collections::BTreeMap<String, ciborium::Value>,
) -> Result<Vec<u64>> {
    let value = params.get("szip_block_offsets").ok_or_else(|| {
        TensogramError::Compression(
            "missing szip_block_offsets in payload metadata (required for partial range decode)"
                .to_string(),
        )
    })?;
    let ciborium::Value::Array(entries) = value else {
        return Err(TensogramError::Metadata(
            "szip_block_offsets must be an array".to_string(),
        ));
    };
    let mut offsets = Vec::with_capacity(entries.len());
    for entry in entries {
        let ciborium::Value::Integer(raw) = entry else {
            return Err(TensogramError::Metadata(
                "szip_block_offsets must contain integers".to_string(),
            ));
        };
        // CBOR integers are widened to i128 before the checked narrowing.
        let wide: i128 = (*raw).into();
        let offset = u64::try_from(wide).map_err(|_| {
            TensogramError::Metadata("szip_block_offset out of u64 range".to_string())
        })?;
        offsets.push(offset);
    }
    Ok(offsets)
}
/// Options controlling how a tensogram message is decoded.
#[derive(Debug, Clone)]
pub struct DecodeOptions {
    /// Verify each payload's hash (when its descriptor carries one)
    /// before decoding.
    pub verify_hash: bool,
    /// Forwarded to the decode pipeline; presumably requests output in
    /// the host's native byte order — confirm against `decode_pipeline`.
    pub native_byte_order: bool,
    /// Compression backend implementation used to build the pipeline.
    pub compression_backend: pipeline::CompressionBackend,
    /// Requested thread budget, resolved via
    /// `crate::parallel::resolve_budget`; 0 selects that function's
    /// default behaviour.
    pub threads: u32,
    /// Minimum total payload size in bytes before parallel decoding is
    /// considered; `None` uses the library default threshold.
    pub parallel_threshold_bytes: Option<usize>,
}
impl Default for DecodeOptions {
    /// Defaults: no hash verification, native byte order output,
    /// default compression backend, automatic thread budget (0), and
    /// the library's default parallel threshold.
    fn default() -> Self {
        Self {
            verify_hash: false,
            native_byte_order: true,
            compression_backend: pipeline::CompressionBackend::default(),
            threads: 0,
            parallel_threshold_bytes: None,
        }
    }
}
/// Decodes a complete framed message: global metadata plus every data
/// object, each returned as `(descriptor, decoded bytes)`.
///
/// Parallelism is chosen per message: "axis A" decodes whole objects on
/// separate rayon threads, while the alternative hands the thread
/// budget to each codec (`intra_codec_threads`) so it can parallelise
/// internally. Without the `threads` feature everything runs
/// sequentially.
///
/// # Errors
/// Propagates framing errors, hash-verification failures (when
/// `options.verify_hash` is set), and pipeline decode errors.
#[tracing::instrument(skip(buf, options), fields(buf_len = buf.len()))]
pub fn decode(buf: &[u8], options: &DecodeOptions) -> Result<(GlobalMetadata, Vec<DecodedObject>)> {
    let msg = framing::decode_message(buf)?;
    // Resolve the effective thread budget and decide whether the total
    // payload size justifies parallel decoding at all.
    let budget = crate::parallel::resolve_budget(options.threads);
    let total_bytes: usize = msg.objects.iter().map(|(_, p, _)| p.len()).sum();
    let parallel =
        crate::parallel::should_parallelise(budget, total_bytes, options.parallel_threshold_bytes);
    // Intra-codec ("axis B") threading only helps if at least one object
    // uses an encoding/filter/compression combination that supports it.
    let any_axis_b = msg.objects.iter().any(|(d, _, _)| {
        crate::parallel::is_axis_b_friendly(&d.encoding, &d.filter, &d.compression)
    });
    let use_axis_a = parallel && crate::parallel::use_axis_a(msg.objects.len(), budget, any_axis_b);
    // When axis A owns the threads, codecs must stay single-threaded (0)
    // to avoid oversubscription.
    let intra_codec_threads = if parallel && !use_axis_a { budget } else { 0 };
    // Decodes one (descriptor, payload, offset) entry; shared by both the
    // parallel and sequential paths below.
    let decode_one = |(desc, payload_bytes, _offset): &(DataObjectDescriptor, &[u8], usize)|
        -> Result<DecodedObject> {
        let decoded = decode_single_object_with_backend(
            desc,
            payload_bytes,
            options,
            options.compression_backend,
            intra_codec_threads,
        )?;
        Ok((desc.clone(), decoded))
    };
    let data_objects: Vec<DecodedObject> = if use_axis_a {
        #[cfg(feature = "threads")]
        {
            use rayon::prelude::*;
            crate::parallel::with_pool(budget, || {
                msg.objects
                    .par_iter()
                    .map(&decode_one)
                    .collect::<Result<Vec<_>>>()
            })?
        }
        // Sequential fallback keeps the build valid without the feature
        // (presumably `use_axis_a` is never true then — TODO confirm in
        // `crate::parallel`).
        #[cfg(not(feature = "threads"))]
        {
            msg.objects.iter().map(decode_one).collect::<Result<_>>()?
        }
    } else {
        crate::parallel::run_maybe_pooled(budget, parallel, intra_codec_threads, || {
            msg.objects.iter().map(decode_one).collect::<Result<_>>()
        })?
    };
    Ok((msg.global_metadata, data_objects))
}
/// Decodes only the global metadata from a framed message, without
/// touching any object payloads.
pub fn decode_metadata(buf: &[u8]) -> Result<GlobalMetadata> {
    framing::decode_metadata_only(buf)
}
/// Decodes the global metadata and every object's descriptor, discarding
/// the payload bytes entirely.
pub fn decode_descriptors(buf: &[u8]) -> Result<(GlobalMetadata, Vec<DataObjectDescriptor>)> {
    let msg = framing::decode_message(buf)?;
    let mut descriptors = Vec::with_capacity(msg.objects.len());
    for (desc, _payload, _offset) in msg.objects {
        descriptors.push(desc);
    }
    Ok((msg.global_metadata, descriptors))
}
/// Decodes a single object (selected by `index`) from a framed message,
/// returning the global metadata, the object's descriptor, and its
/// decoded bytes.
///
/// # Errors
/// Fails when `index` is out of range, the frame is malformed, hash
/// verification fails (if enabled), or the decode pipeline errors.
pub fn decode_object(
    buf: &[u8],
    index: usize,
    options: &DecodeOptions,
) -> Result<(GlobalMetadata, DataObjectDescriptor, Vec<u8>)> {
    let msg = framing::decode_message(buf)?;
    if index >= msg.objects.len() {
        return Err(TensogramError::Object(format!(
            "object index {} out of range (num_objects={})",
            index,
            msg.objects.len()
        )));
    }
    let (desc, payload_bytes, _) = &msg.objects[index];
    // Only one object to decode, so any available parallelism is handed
    // to the codec itself (intra-codec threads), never across objects.
    let budget = crate::parallel::resolve_budget(options.threads);
    let parallel = crate::parallel::should_parallelise(
        budget,
        payload_bytes.len(),
        options.parallel_threshold_bytes,
    );
    let intra_codec_threads = if parallel { budget } else { 0 };
    let decoded = crate::parallel::run_maybe_pooled(budget, parallel, intra_codec_threads, || {
        decode_single_object_with_backend(
            desc,
            payload_bytes,
            options,
            options.compression_backend,
            intra_codec_threads,
        )
    })?;
    Ok((msg.global_metadata, desc.clone(), decoded))
}
/// Decodes one or more element ranges from a single object inside a
/// framed message, without materialising the full payload.
///
/// `ranges` is a list of `(element_offset, element_count)` pairs; one
/// decoded buffer is returned per requested range, in order.
pub fn decode_range(
    buf: &[u8],
    object_index: usize,
    ranges: &[(u64, u64)],
    options: &DecodeOptions,
) -> Result<(DataObjectDescriptor, Vec<Vec<u8>>)> {
    let msg = framing::decode_message(buf)?;
    let num_objects = msg.objects.len();
    let Some((desc, payload_bytes, _offset)) = msg.objects.get(object_index) else {
        return Err(TensogramError::Object(format!(
            "object index {object_index} out of range (num_objects={num_objects})"
        )));
    };
    let parts = decode_range_from_payload(desc, payload_bytes, ranges, options)?;
    Ok((desc.clone(), parts))
}
/// Decodes a set of element ranges directly from one object's payload
/// bytes, without decoding the whole object.
///
/// `ranges` is a list of `(element_offset, element_count)` pairs; one
/// output buffer is produced per range, in the order given. For szip
/// compression the descriptor must carry a `szip_block_offsets` table.
///
/// # Errors
/// Rejects payloads with a non-`none` filter (range decode cannot undo
/// shuffle-style filters) and bitmask dtypes (elements are not
/// byte-addressable). Also fails on hash mismatch when
/// `options.verify_hash` is set, on shape-product overflow, on a missing
/// szip block-offset table, or when the pipeline cannot decode a range.
pub fn decode_range_from_payload(
    desc: &DataObjectDescriptor,
    payload_bytes: &[u8],
    ranges: &[(u64, u64)],
    options: &DecodeOptions,
) -> Result<Vec<Vec<u8>>> {
    if desc.filter != "none" {
        return Err(TensogramError::Encoding(
            "decode_range is not supported when a filter (e.g. shuffle) is applied".to_string(),
        ));
    }
    // Bitmask elements are sub-byte (byte_width == 0), so element offsets
    // cannot be mapped to byte offsets.
    if desc.dtype.byte_width() == 0 {
        return Err(TensogramError::Encoding(
            "partial range decode not supported for bitmask dtype".to_string(),
        ));
    }
    // Hash verification runs over the raw (still-encoded) payload.
    if options.verify_hash
        && let Some(ref hash_desc) = desc.hash
    {
        hash::verify_hash(payload_bytes, hash_desc)?;
    }
    // Total element count = product of shape dims, checked for overflow.
    let shape_product = desc
        .shape
        .iter()
        .try_fold(1u64, |acc, &x| acc.checked_mul(x))
        .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
    let num_elements = usize::try_from(shape_product)
        .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))?;
    let budget = crate::parallel::resolve_budget(options.threads);
    let elem_bytes = desc.dtype.byte_width();
    // Estimate the decoded output size for the parallelism heuristic.
    // Saturate instead of truncating/overflowing: a `u64` count can
    // exceed `usize` on 32-bit targets (a plain `as` cast would silently
    // truncate), and the sum itself could overflow. The estimate only
    // needs to be "big enough", so saturation is the right semantics.
    let total_bytes: usize = ranges
        .iter()
        .map(|&(_, count)| {
            usize::try_from(count)
                .unwrap_or(usize::MAX)
                .saturating_mul(elem_bytes)
        })
        .fold(0usize, usize::saturating_add);
    let parallel =
        crate::parallel::should_parallelise(budget, total_bytes, options.parallel_threshold_bytes);
    let axis_b_friendly =
        crate::parallel::is_axis_b_friendly(&desc.encoding, &desc.filter, &desc.compression);
    // Axis A = one range per thread; otherwise the budget is handed to
    // the codec (intra-codec threads). Never both, to avoid
    // oversubscription.
    let use_axis_a = parallel && crate::parallel::use_axis_a(ranges.len(), budget, axis_b_friendly);
    let intra_codec_threads = if parallel && !use_axis_a { budget } else { 0 };
    let config = build_pipeline_config_with_backend(
        desc,
        num_elements,
        desc.dtype,
        options.compression_backend,
        intra_codec_threads,
    )?;
    // szip is block-compressed; the offset table is required to seek to
    // the blocks that cover a requested range.
    let block_offsets = if desc.compression == "szip" {
        extract_block_offsets(&desc.params)?
    } else {
        Vec::new()
    };
    let decode_one = |offset: u64, count: u64| -> Result<Vec<u8>> {
        pipeline::decode_range_pipeline(
            payload_bytes,
            &config,
            &block_offsets,
            offset,
            count,
            options.native_byte_order,
        )
        .map_err(|e| {
            TensogramError::Encoding(format!("range (offset={offset}, count={count}): {e}"))
        })
    };
    let run_seq = || -> Result<Vec<Vec<u8>>> {
        ranges
            .iter()
            .map(|&(offset, count)| decode_one(offset, count))
            .collect()
    };
    let results: Vec<Vec<u8>> = if use_axis_a {
        #[cfg(feature = "threads")]
        {
            use rayon::prelude::*;
            crate::parallel::with_pool(budget, || {
                ranges
                    .par_iter()
                    .map(|&(offset, count)| decode_one(offset, count))
                    .collect::<Result<Vec<_>>>()
            })?
        }
        #[cfg(not(feature = "threads"))]
        {
            run_seq()?
        }
    } else {
        crate::parallel::run_maybe_pooled(budget, parallel, intra_codec_threads, run_seq)?
    };
    Ok(results)
}
/// Decodes a single object's payload using the backend taken from
/// `options`, with no intra-codec threading. Only compiled (and used)
/// with the `remote` feature.
#[cfg(feature = "remote")]
pub(crate) fn decode_single_object(
    desc: &DataObjectDescriptor,
    payload_bytes: &[u8],
    options: &DecodeOptions,
) -> Result<Vec<u8>> {
    decode_single_object_with_backend(desc, payload_bytes, options, options.compression_backend, 0)
}
/// Shared decode path for one object's payload.
///
/// Optionally verifies the payload hash, derives the element count from
/// the descriptor's shape (with overflow checks), builds the decode
/// pipeline for the given backend and thread settings, and runs it.
fn decode_single_object_with_backend(
    desc: &DataObjectDescriptor,
    payload_bytes: &[u8],
    options: &DecodeOptions,
    backend: pipeline::CompressionBackend,
    intra_codec_threads: u32,
) -> Result<Vec<u8>> {
    // Hash verification runs over the raw (still-encoded) payload.
    if options.verify_hash
        && let Some(ref hash_desc) = desc.hash
    {
        hash::verify_hash(payload_bytes, hash_desc)?;
    }
    // Total element count = product of shape dims, checked for overflow.
    let shape_product = desc
        .shape
        .iter()
        .try_fold(1u64, |acc, &x| acc.checked_mul(x))
        .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
    let num_elements = usize::try_from(shape_product)
        .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))?;
    let config = build_pipeline_config_with_backend(
        desc,
        num_elements,
        desc.dtype,
        backend,
        intra_codec_threads,
    )?;
    let decoded = pipeline::decode_pipeline(payload_bytes, &config, options.native_byte_order)
        .map_err(|e| TensogramError::Encoding(e.to_string()))?;
    Ok(decoded)
}
#[cfg(test)]
mod tests {
    //! Round-trip and error-path tests for the decode entry points,
    //! exercised via `encode` from the sibling module.
    use super::*;
    use crate::dtype::Dtype;
    use crate::encode::{EncodeOptions, encode};
    use crate::types::ByteOrder;
    use std::collections::BTreeMap;

    /// Minimal version-2 global metadata for round-trip tests.
    fn make_global_meta() -> GlobalMetadata {
        GlobalMetadata {
            version: 2,
            extra: BTreeMap::new(),
            ..Default::default()
        }
    }

    /// Float32 descriptor with row-major (C-order) strides and no
    /// encoding, filter, or compression applied.
    fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
        let strides = if shape.is_empty() {
            vec![]
        } else {
            // Row-major strides: innermost dimension has stride 1.
            let mut s = vec![1u64; shape.len()];
            for i in (0..shape.len() - 1).rev() {
                s[i] = s[i + 1] * shape[i + 1];
            }
            s
        };
        DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: shape.len() as u64,
            shape,
            strides,
            dtype: Dtype::Float32,
            byte_order: ByteOrder::native(),
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            hash: None,
        }
    }

    #[test]
    fn test_decode_corrupt_message_bytes() {
        let garbage = vec![0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x01, 0x02, 0x03];
        let result = decode(&garbage, &DecodeOptions::default());
        assert!(result.is_err(), "decoding garbage should fail");
    }

    #[test]
    fn test_decode_truncated_message() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        // Cut the frame in half; decoding the fragment must not panic.
        let truncated = &encoded[..encoded.len() / 2];
        let result = decode(truncated, &DecodeOptions::default());
        assert!(result.is_err(), "decoding truncated message should fail");
    }

    #[test]
    fn test_decode_corrupted_cbor_in_message() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![42u8; 16];
        let mut encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        // Offset chosen to land inside the CBOR section — assumes the
        // frame header is shorter than 40 bytes; TODO confirm against
        // the framing layout.
        let cbor_start = 40;
        let corrupt_end = (cbor_start + 30).min(encoded.len());
        for byte in &mut encoded[cbor_start..corrupt_end] {
            *byte = 0xFF;
        }
        let result = decode(&encoded, &DecodeOptions::default());
        assert!(result.is_err(), "decoding corrupted CBOR should fail");
    }

    #[test]
    fn test_decode_object_index_out_of_range() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        // One object encoded, so index 1 is already out of range.
        let result = decode_object(&encoded, 1, &DecodeOptions::default());
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("out of range"),
            "expected 'out of range', got: {msg}"
        );
        let result = decode_object(&encoded, 999, &DecodeOptions::default());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("out of range"));
    }

    #[test]
    fn test_decode_object_valid_index() {
        let meta = make_global_meta();
        let desc0 = make_descriptor(vec![2]);
        let data0 = vec![10u8; 8];
        let desc1 = make_descriptor(vec![3]);
        let data1 = vec![20u8; 12];
        let encoded = encode(
            &meta,
            &[(&desc0, data0.as_slice()), (&desc1, data1.as_slice())],
            &EncodeOptions::default(),
        )
        .unwrap();
        // Each index should yield its own descriptor and payload.
        let (_, ret_desc, ret_data) =
            decode_object(&encoded, 0, &DecodeOptions::default()).unwrap();
        assert_eq!(ret_desc.shape, vec![2]);
        assert_eq!(ret_data, data0);
        let (_, ret_desc, ret_data) =
            decode_object(&encoded, 1, &DecodeOptions::default()).unwrap();
        assert_eq!(ret_desc.shape, vec![3]);
        assert_eq!(ret_data, data1);
    }

    #[test]
    fn test_decode_range_object_index_out_of_range() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let result = decode_range(&encoded, 5, &[(0, 2)], &DecodeOptions::default());
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("out of range"),
            "expected 'out of range', got: {msg}"
        );
    }

    #[test]
    fn test_decode_range_exceeds_payload() {
        let meta = make_global_meta();
        // 4 float32 elements; requesting 10 from offset 2 overruns.
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let result = decode_range(&encoded, 0, &[(2, 10)], &DecodeOptions::default());
        assert!(result.is_err(), "range exceeding payload should fail");
    }

    #[test]
    fn test_decode_range_valid() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![8]);
        let data: Vec<u8> = (0..32).collect();
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let (ret_desc, parts) =
            decode_range(&encoded, 0, &[(0, 4)], &DecodeOptions::default()).unwrap();
        assert_eq!(ret_desc.shape, vec![8]);
        assert_eq!(parts.len(), 1);
        // 4 elements * 4 bytes per float32 = 16 bytes.
        assert_eq!(parts[0].len(), 16);
    }

    #[test]
    fn test_decode_range_empty_ranges() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        // No ranges requested -> no parts, and no error.
        let (_, parts) = decode_range(&encoded, 0, &[], &DecodeOptions::default()).unwrap();
        assert!(parts.is_empty());
    }

    #[test]
    fn test_decode_metadata_valid() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let decoded_meta = decode_metadata(&encoded).unwrap();
        assert_eq!(decoded_meta.version, 2);
    }

    #[test]
    fn test_decode_metadata_corrupt() {
        let garbage = vec![0xFF; 50];
        let result = decode_metadata(&garbage);
        assert!(result.is_err(), "decode_metadata on garbage should fail");
    }

    #[test]
    fn test_decode_descriptors_valid() {
        let meta = make_global_meta();
        let desc0 = make_descriptor(vec![4]);
        let desc1 = make_descriptor(vec![2, 3]);
        let data0 = vec![0u8; 16];
        let data1 = vec![0u8; 24];
        let encoded = encode(
            &meta,
            &[(&desc0, data0.as_slice()), (&desc1, data1.as_slice())],
            &EncodeOptions::default(),
        )
        .unwrap();
        let (decoded_meta, descs) = decode_descriptors(&encoded).unwrap();
        assert_eq!(decoded_meta.version, 2);
        assert_eq!(descs.len(), 2);
        assert_eq!(descs[0].shape, vec![4]);
        assert_eq!(descs[1].shape, vec![2, 3]);
    }

    #[test]
    fn test_decode_range_filter_shuffle_rejected() {
        let meta = make_global_meta();
        let mut desc = make_descriptor(vec![100]);
        // Any non-"none" filter must be rejected by decode_range.
        desc.filter = "shuffle".to_string();
        desc.params.insert(
            "shuffle_element_size".to_string(),
            ciborium::Value::Integer(4.into()),
        );
        let data: Vec<u8> = (0..400).map(|i| (i % 256) as u8).collect();
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let result = decode_range(&encoded, 0, &[(0, 10)], &DecodeOptions::default());
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("filter") || msg.contains("shuffle"),
            "expected filter/shuffle error, got: {msg}"
        );
    }

    #[test]
    fn test_decode_range_bitmask_dtype_rejected() {
        let meta = make_global_meta();
        // Bitmask has byte_width 0, so partial decode is unsupported.
        let desc = DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: 1,
            shape: vec![16],
            strides: vec![1],
            dtype: Dtype::Bitmask,
            byte_order: ByteOrder::native(),
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            hash: None,
        };
        let data = vec![0xFF; 2];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
        let result = decode_range(&encoded, 0, &[(0, 8)], &DecodeOptions::default());
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("bitmask"),
            "expected bitmask error, got: {msg}"
        );
    }

    #[test]
    fn test_decode_options_defaults() {
        let opts = DecodeOptions::default();
        assert!(!opts.verify_hash);
        assert!(opts.native_byte_order);
    }

    #[test]
    fn test_decode_unknown_encoding_in_descriptor() {
        let mut desc = make_descriptor(vec![4]);
        desc.encoding = "foobar".to_string();
        let result = crate::encode::build_pipeline_config_with_backend(
            &desc,
            4,
            Dtype::Float32,
            pipeline::CompressionBackend::default(),
            0,
        );
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("unknown encoding"),
            "expected 'unknown encoding', got: {msg}"
        );
    }

    #[test]
    fn test_decode_unknown_compression_in_descriptor() {
        let mut desc = make_descriptor(vec![4]);
        desc.compression = "quantum_compress".to_string();
        let result = crate::encode::build_pipeline_config_with_backend(
            &desc,
            4,
            Dtype::Float32,
            pipeline::CompressionBackend::default(),
            0,
        );
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("unknown compression"),
            "expected 'unknown compression', got: {msg}"
        );
    }

    #[test]
    fn test_extract_block_offsets_missing() {
        let params = BTreeMap::new();
        let result = extract_block_offsets(&params);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("szip_block_offsets"),
            "expected szip_block_offsets error, got: {msg}"
        );
    }

    #[test]
    fn test_extract_block_offsets_wrong_type() {
        let mut params = BTreeMap::new();
        params.insert(
            "szip_block_offsets".to_string(),
            ciborium::Value::Text("not an array".to_string()),
        );
        let result = extract_block_offsets(&params);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("must be an array"),
            "expected 'must be an array', got: {msg}"
        );
    }

    #[test]
    fn test_extract_block_offsets_non_integer_elements() {
        let mut params = BTreeMap::new();
        params.insert(
            "szip_block_offsets".to_string(),
            ciborium::Value::Array(vec![
                ciborium::Value::Float(1.5),
            ]),
        );
        let result = extract_block_offsets(&params);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("integers"),
            "expected integers error, got: {msg}"
        );
    }

    #[test]
    fn test_extract_block_offsets_valid() {
        let mut params = BTreeMap::new();
        params.insert(
            "szip_block_offsets".to_string(),
            ciborium::Value::Array(vec![
                ciborium::Value::Integer(0.into()),
                ciborium::Value::Integer(100.into()),
                ciborium::Value::Integer(200.into()),
            ]),
        );
        let result = extract_block_offsets(&params).unwrap();
        assert_eq!(result, vec![0, 100, 200]);
    }
}