use std::collections::BTreeMap;
use tensogram::*;
use tensogram_encodings::simple_packing;
/// Builds a version-2 `GlobalMetadata` and a matching `Float32` "ntensor"
/// descriptor for `shape`.
///
/// The descriptor uses native byte order, `"none"` for encoding, filter and
/// compression, strides from `compute_strides`, empty params and no hash.
fn make_float32_descriptor(shape: Vec<u64>) -> (GlobalMetadata, DataObjectDescriptor) {
    let descriptor = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: shape.len() as u64,
        // Derived from `shape` before it is moved into the struct below.
        strides: compute_strides(&shape),
        shape,
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: String::from("none"),
        filter: String::from("none"),
        compression: String::from("none"),
        params: BTreeMap::new(),
        hash: None,
    };
    let metadata = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    (metadata, descriptor)
}
/// Builds a version-2 `GlobalMetadata` whose single `base` entry holds a
/// `"mars"` map (`class = "od"`, `type = "fc"`, `date = "20260401"`), paired
/// with a plain `Float32` "ntensor" descriptor for `shape`.
///
/// The descriptor's params contain one entry, `mars_param`, set to `param`.
fn make_mars_pair(shape: Vec<u64>, param: &str) -> (GlobalMetadata, DataObjectDescriptor) {
    let text = |s: &str| ciborium::Value::Text(s.to_string());

    // Go through a BTreeMap so the key/value pairs are emitted in key order.
    let mars_keys: BTreeMap<String, ciborium::Value> = BTreeMap::from([
        ("class".to_string(), text("od")),
        ("type".to_string(), text("fc")),
        ("date".to_string(), text("20260401")),
    ]);
    let mut mars_pairs = Vec::with_capacity(mars_keys.len());
    for (key, value) in mars_keys {
        mars_pairs.push((ciborium::Value::Text(key), value));
    }
    let base_entry = BTreeMap::from([("mars".to_string(), ciborium::Value::Map(mars_pairs))]);

    let descriptor = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: shape.len() as u64,
        // Derived from `shape` before it is moved into the struct below.
        strides: compute_strides(&shape),
        shape,
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: String::from("none"),
        filter: String::from("none"),
        compression: String::from("none"),
        params: BTreeMap::from([("mars_param".to_string(), text(param))]),
        hash: None,
    };
    let metadata = GlobalMetadata {
        version: 2,
        base: vec![base_entry],
        ..Default::default()
    };
    (metadata, descriptor)
}
/// Returns element strides for `shape` where the last dimension has stride 1
/// and each earlier stride is the product of all later extents.
///
/// An empty shape yields an empty vector. The first extent never enters a
/// product, so it cannot cause an overflow.
fn compute_strides(shape: &[u64]) -> Vec<u64> {
    let mut strides = Vec::with_capacity(shape.len());
    let mut step = 1u64;
    // Walk from the innermost dimension outwards, accumulating the product.
    for (axis, &extent) in shape.iter().enumerate().rev() {
        strides.push(step);
        if axis > 0 {
            step *= extent;
        }
    }
    strides.reverse();
    strides
}
/// Encodes one zero-filled 10x20 float32 object, checks that the message
/// starts with `TENSOGRM` and ends with `39277777`, then decodes it and
/// compares version, object count and payload.
#[test]
fn test_full_round_trip_single_object() {
    let payload = vec![0u8; 10 * 20 * 4];
    let (global, desc) = make_float32_descriptor(vec![10, 20]);
    let message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();

    // Leading and trailing 8-byte markers.
    assert_eq!(&message[0..8], b"TENSOGRM");
    assert_eq!(&message[message.len() - 8..], b"39277777");

    let (meta_out, objects_out) = decode(&message, &DecodeOptions::default()).unwrap();
    assert_eq!(meta_out.version, 2);
    assert_eq!(objects_out.len(), 1);
    assert_eq!(objects_out[0].1, payload);
}
/// Encodes a message holding a 4x5 float32 object and a 3-element
/// little-endian float64 object, then checks both payloads decode unchanged.
#[test]
fn test_multi_object_message() {
    let first_payload = vec![1u8; 4 * 5 * 4];
    let second_payload = vec![2u8; 3 * 8];

    let first_desc = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 2,
        shape: vec![4, 5],
        strides: compute_strides(&[4, 5]),
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: String::from("none"),
        filter: String::from("none"),
        compression: String::from("none"),
        params: BTreeMap::new(),
        hash: None,
    };
    let second_desc = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![3],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order: ByteOrder::Little,
        encoding: String::from("none"),
        filter: String::from("none"),
        compression: String::from("none"),
        params: BTreeMap::new(),
        hash: None,
    };
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let message = encode(
        &global,
        &[(&first_desc, &first_payload), (&second_desc, &second_payload)],
        &EncodeOptions::default(),
    )
    .unwrap();

    // The decoded global metadata is not inspected by this test.
    let (_meta, objects) = decode(&message, &DecodeOptions::default()).unwrap();
    assert_eq!(objects.len(), 2);
    assert_eq!(objects[0].1, first_payload);
    assert_eq!(objects[1].1, second_payload);
}
/// Encodes a message built with `make_mars_pair` and checks that
/// `decode_metadata` reports version 2.
#[test]
fn test_decode_metadata_only() {
    let payload = vec![0u8; 10 * 4];
    let (global, desc) = make_mars_pair(vec![10], "2t");
    let message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();

    let metadata = decode_metadata(&message).unwrap();
    assert_eq!(metadata.version, 2);
}
/// Encodes two float32 objects with distinct fill bytes and checks that
/// `decode_object` with index 1 returns the second payload.
#[test]
fn test_decode_single_object_by_index() {
    let first_payload = vec![0xAA; 2 * 4];
    let second_payload = vec![0xBB; 3 * 4];

    let first_desc = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![2],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: String::from("none"),
        filter: String::from("none"),
        compression: String::from("none"),
        params: BTreeMap::new(),
        hash: None,
    };
    let second_desc = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![3],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: String::from("none"),
        filter: String::from("none"),
        compression: String::from("none"),
        params: BTreeMap::new(),
        hash: None,
    };
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let message = encode(
        &global,
        &[(&first_desc, &first_payload), (&second_desc, &second_payload)],
        &EncodeOptions::default(),
    )
    .unwrap();

    // Only the payload of the selected object is checked.
    let (_, _returned_desc, bytes) =
        decode_object(&message, 1, &DecodeOptions::default()).unwrap();
    assert_eq!(bytes, second_payload);
}
/// Encodes a message with no data objects and checks that decoding yields an
/// empty object list.
#[test]
fn test_zero_object_message() {
    let metadata = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let message = encode(&metadata, &[], &EncodeOptions::default()).unwrap();

    let (_, decoded) = decode(&message, &DecodeOptions::default()).unwrap();
    assert_eq!(decoded.len(), 0);
}
/// Encodes with default options and decodes with `verify_hash` enabled,
/// expecting success and an unchanged payload.
#[test]
fn test_hash_verification_passes() {
    let payload = vec![42u8; 4 * 4];
    let (global, desc) = make_float32_descriptor(vec![4]);
    let message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();

    let verifying = DecodeOptions {
        verify_hash: true,
        ..Default::default()
    };
    let (_, decoded) = decode(&message, &verifying).unwrap();
    assert_eq!(decoded[0].1, payload);
}
/// Flips one byte inside an encoded message and checks that decoding with
/// `verify_hash` enabled returns an error.
///
/// The byte flipped sits 16 bytes after the first occurrence of the 4-byte
/// sequence `F`, `R`, 0x00, 0x04, which this test treats as the start of the
/// data-object frame.
#[test]
fn test_hash_verification_fails_on_corruption() {
    let payload = vec![42u8; 100 * 4];
    let (global, desc) = make_float32_descriptor(vec![100]);
    let mut message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();

    let marker: &[u8] = b"FR\x00\x04";
    let marker_at = message
        .windows(4)
        .position(|window| window == marker)
        .expect("DataObject frame not found in encoded message");

    // Invert every bit of the chosen byte so the stored content changes.
    let target = marker_at + 16;
    message[target] ^= 0xFF;

    let verifying = DecodeOptions {
        verify_hash: true,
        ..Default::default()
    };
    let outcome = decode(&message, &verifying);
    assert!(
        outcome.is_err(),
        "expected hash verification failure after corruption"
    );
}
/// Round-trips 100 float64 values through the `simple_packing` encoding
/// (16 bits per value) and checks each decoded value is within 0.01 of the
/// original.
#[test]
fn test_simple_packing_round_trip() {
    let values: Vec<f64> = (0..100).map(|i| 250.0 + i as f64 * 0.1).collect();
    let raw: Vec<u8> = values.iter().flat_map(|v| v.to_ne_bytes()).collect();

    let params = simple_packing::compute_params(&values, 16, 0).unwrap();
    let packing_params = BTreeMap::from([
        (
            "reference_value".to_string(),
            ciborium::Value::Float(params.reference_value),
        ),
        (
            "binary_scale_factor".to_string(),
            ciborium::Value::Integer((params.binary_scale_factor as i64).into()),
        ),
        (
            "decimal_scale_factor".to_string(),
            ciborium::Value::Integer((params.decimal_scale_factor as i64).into()),
        ),
        (
            "bits_per_value".to_string(),
            ciborium::Value::Integer((params.bits_per_value as i64).into()),
        ),
    ]);

    let desc = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![100],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order: ByteOrder::native(),
        encoding: String::from("simple_packing"),
        filter: String::from("none"),
        compression: String::from("none"),
        params: packing_params,
        hash: None,
    };
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let message = encode(&global, &[(&desc, &raw)], &EncodeOptions::default()).unwrap();
    let (_, objects) = decode(&message, &DecodeOptions::default()).unwrap();

    // Reinterpret the decoded bytes as native-endian f64 values.
    let mut decoded_values: Vec<f64> = Vec::new();
    for chunk in objects[0].1.chunks_exact(8) {
        decoded_values.push(f64::from_ne_bytes(chunk.try_into().unwrap()));
    }

    assert_eq!(decoded_values.len(), 100);
    for (orig, dec) in values.iter().zip(decoded_values.iter()) {
        assert!((orig - dec).abs() < 0.01, "orig={orig}, dec={dec}");
    }
}
/// Round-trips 40 distinct bytes through the `shuffle` filter (element size 4)
/// and checks the decoded payload is byte-identical.
#[test]
fn test_shuffle_round_trip() {
    let payload: Vec<u8> = (0..40).collect();

    let desc = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![10],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: String::from("none"),
        filter: String::from("shuffle"),
        compression: String::from("none"),
        params: BTreeMap::from([(
            "shuffle_element_size".to_string(),
            ciborium::Value::Integer(4.into()),
        )]),
        hash: None,
    };
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();
    let (_, decoded) = decode(&message, &DecodeOptions::default()).unwrap();
    assert_eq!(decoded[0].1, payload);
}
/// Concatenates two encoded messages, checks that `scan` reports two entries,
/// and decodes each reported slice to confirm the object shapes.
///
/// Each entry from `scan` is used here as a (start offset, length) pair.
#[test]
fn test_scan_multi_message_buffer() {
    let (global_a, desc_a) = make_mars_pair(vec![4], "2t");
    let (global_b, desc_b) = make_mars_pair(vec![8], "10u");
    let payload_a = vec![0u8; 4 * 4];
    let payload_b = vec![0u8; 8 * 4];

    let message_a = encode(&global_a, &[(&desc_a, &payload_a)], &EncodeOptions::default()).unwrap();
    let message_b = encode(&global_b, &[(&desc_b, &payload_b)], &EncodeOptions::default()).unwrap();

    let mut stream = Vec::new();
    stream.extend_from_slice(&message_a);
    stream.extend_from_slice(&message_b);

    let located = scan(&stream);
    assert_eq!(located.len(), 2);

    let (start_a, len_a) = (located[0].0, located[0].1);
    let (_, decoded_a) = decode(
        &stream[start_a..start_a + len_a],
        &DecodeOptions::default(),
    )
    .unwrap();

    let (start_b, len_b) = (located[1].0, located[1].1);
    let (_, decoded_b) = decode(
        &stream[start_b..start_b + len_b],
        &DecodeOptions::default(),
    )
    .unwrap();

    assert_eq!(decoded_a[0].0.shape, vec![4]);
    assert_eq!(decoded_b[0].0.shape, vec![8]);
}
/// Requests the range `(3, 3)` from an unencoded 10-element float32 object and
/// checks the single returned part equals the bytes of elements 3..6.
#[test]
fn test_partial_range_decode_uncompressed() {
    let values: Vec<f32> = (0..10).map(|i| i as f32 * 1.5).collect();
    let mut raw: Vec<u8> = Vec::new();
    for value in &values {
        raw.extend_from_slice(&value.to_ne_bytes());
    }

    let (global, desc) = make_float32_descriptor(vec![10]);
    let message = encode(&global, &[(&desc, &raw)], &EncodeOptions::default()).unwrap();

    let (_, parts) =
        decode_range(&message, 0, &[(3, 3)], &DecodeOptions::default()).expect("decode_range");
    assert_eq!(parts.len(), 1, "expected 1 part for 1 range");

    let expected: Vec<u8> = values[3..6].iter().flat_map(|v| v.to_ne_bytes()).collect();
    assert_eq!(parts[0], expected);

    // Flattening all parts must give the same bytes when there is one range.
    let flattened: Vec<u8> = parts.into_iter().flatten().collect();
    assert_eq!(flattened, expected);
}
/// Checks that `decode_range` fails on an object using the `shuffle` filter
/// and that the error text mentions "shuffle" or "filter".
#[test]
fn test_decode_range_shuffle_rejected() {
    let payload: Vec<u8> = (0..40).collect();

    let desc = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![10],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: String::from("none"),
        filter: String::from("shuffle"),
        compression: String::from("none"),
        params: BTreeMap::from([(
            "shuffle_element_size".to_string(),
            ciborium::Value::Integer(4.into()),
        )]),
        hash: None,
    };
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();
    let msg = decode_range(&message, 0, &[(3, 3)], &DecodeOptions::default())
        .unwrap_err()
        .to_string();

    let names_the_filter = ["shuffle", "filter"]
        .iter()
        .any(|needle| msg.contains(*needle));
    assert!(names_the_filter, "{msg}");
}
/// Appends three messages to a `TensogramFile` in a temporary directory, then
/// checks the message count and that each message decodes to its fill byte.
#[test]
fn test_file_multi_message_round_trip() {
    let scratch = tempfile::tempdir().unwrap();
    let target = scratch.path().join("multi.tgm");
    let mut file = TensogramFile::create(&target).unwrap();

    // Message `i` carries a payload filled with the byte `i`.
    for fill in 0..3u8 {
        let payload = vec![fill; 4 * 4];
        let (global, desc) = make_float32_descriptor(vec![4]);
        file.append(&global, &[(&desc, &payload)], &EncodeOptions::default())
            .unwrap();
    }

    assert_eq!(file.message_count().unwrap(), 3);

    for fill in 0..3u8 {
        let (_, decoded) = file
            .decode_message(fill as usize, &DecodeOptions::default())
            .unwrap();
        assert_eq!(decoded[0].1, vec![fill; 4 * 4]);
    }
}
/// Encodes a message built with `make_mars_pair` and checks that the decoded
/// metadata still has a `"mars"` map in its first `base` entry with
/// `class = "od"`.
///
/// The test fails explicitly when `"mars"` is present but is not a map, so a
/// wrong value type cannot pass unnoticed.
#[test]
fn test_namespaced_metadata_round_trip() {
    let (global, desc) = make_mars_pair(vec![4], "wave_spectra");
    let data = vec![0u8; 4 * 4];
    let encoded = encode(&global, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

    let meta = decode_metadata(&encoded).unwrap();
    assert!(meta.base[0].contains_key("mars"));

    let entries = match &meta.base[0]["mars"] {
        ciborium::Value::Map(entries) => entries,
        other => panic!("expected `mars` to be a map, got {other:?}"),
    };
    let class_val = entries
        .iter()
        .find(|(k, _)| matches!(k, ciborium::Value::Text(s) if s == "class"))
        .map(|(_, v)| v);
    assert!(
        matches!(class_val, Some(ciborium::Value::Text(s)) if s == "od"),
        "expected mars.class == \"od\", got {class_val:?}"
    );
}
/// Checks that `encode` returns an error for a descriptor whose shape is
/// `[u64::MAX, 2]` paired with a 64-byte payload.
#[test]
fn test_validate_object_overflow() {
    let payload = vec![0u8; 64];
    let oversized = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 2,
        shape: vec![u64::MAX, 2],
        strides: vec![2, 1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: String::from("none"),
        filter: String::from("none"),
        compression: String::from("none"),
        params: BTreeMap::new(),
        hash: None,
    };
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let outcome = encode(&global, &[(&oversized, &payload)], &EncodeOptions::default());
    assert!(outcome.is_err(), "expected Err but got Ok");
}
/// Encodes the same 50 float64 values twice with `simple_packing`, once
/// declared big-endian and once little-endian, and checks that:
///
/// * both decode (default options) to values within 0.01 of the originals,
/// * both default decodes produce identical bytes,
/// * decoding with `native_byte_order: false` gives different bytes for the
///   two messages.
///
/// Decoded lengths are asserted before the element-wise comparison because
/// `zip` stops at the shorter side and would otherwise let a truncated decode
/// pass.
#[test]
fn test_cross_endian_round_trip() {
    let values: Vec<f64> = (0..50).map(|i| 100.0 + i as f64 * 0.5).collect();
    let params = tensogram_encodings::simple_packing::compute_params(&values, 16, 0).unwrap();
    let packing_params = BTreeMap::from([
        (
            "reference_value".to_string(),
            ciborium::Value::Float(params.reference_value),
        ),
        (
            "binary_scale_factor".to_string(),
            ciborium::Value::Integer((params.binary_scale_factor as i64).into()),
        ),
        (
            "decimal_scale_factor".to_string(),
            ciborium::Value::Integer((params.decimal_scale_factor as i64).into()),
        ),
        (
            "bits_per_value".to_string(),
            ciborium::Value::Integer((params.bits_per_value as i64).into()),
        ),
    ]);
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    // The two descriptors differ only in their declared byte order.
    let descriptor_for = |byte_order: ByteOrder| DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![50],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order,
        encoding: "simple_packing".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: packing_params.clone(),
        hash: None,
    };
    // Reinterprets decoded bytes as native-endian f64 values.
    let to_native_f64 = |bytes: &[u8]| -> Vec<f64> {
        bytes
            .chunks_exact(8)
            .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
            .collect()
    };

    // Big-endian input.
    let be_data: Vec<u8> = values.iter().flat_map(|v| v.to_be_bytes()).collect();
    let be_desc = descriptor_for(ByteOrder::Big);
    let encoded_be = encode(&global, &[(&be_desc, &be_data)], &EncodeOptions::default()).unwrap();
    let (_, objects_be) = decode(&encoded_be, &DecodeOptions::default()).unwrap();
    let decoded_be_values = to_native_f64(&objects_be[0].1);
    assert_eq!(
        decoded_be_values.len(),
        values.len(),
        "BE decode returned the wrong number of values"
    );
    for (orig, dec) in values.iter().zip(decoded_be_values.iter()) {
        assert!(
            (orig - dec).abs() < 0.01,
            "BE→native round-trip: orig={orig}, dec={dec}"
        );
    }

    // Little-endian input.
    let le_data: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
    let le_desc = descriptor_for(ByteOrder::Little);
    let encoded_le = encode(&global, &[(&le_desc, &le_data)], &EncodeOptions::default()).unwrap();
    let (_, objects_le) = decode(&encoded_le, &DecodeOptions::default()).unwrap();
    let decoded_le_values = to_native_f64(&objects_le[0].1);
    assert_eq!(
        decoded_le_values.len(),
        values.len(),
        "LE decode returned the wrong number of values"
    );
    for (orig, dec) in values.iter().zip(decoded_le_values.iter()) {
        assert!(
            (orig - dec).abs() < 0.01,
            "LE→native round-trip: orig={orig}, dec={dec}"
        );
    }

    assert_eq!(
        objects_be[0].1, objects_le[0].1,
        "BE and LE should both decode to identical native-endian bytes"
    );

    // With native conversion switched off the two decodes must differ.
    let wire_opts = DecodeOptions {
        native_byte_order: false,
        ..Default::default()
    };
    let (_, wire_be) = decode(&encoded_be, &wire_opts).unwrap();
    let (_, wire_le) = decode(&encoded_le, &wire_opts).unwrap();
    assert_ne!(
        wire_be[0].1, wire_le[0].1,
        "wire-order BE and LE bytes should differ"
    );
}
/// Encodes 20 float32 values declared big-endian, requests the range `(5, 5)`
/// with default decode options, and checks that reading the returned bytes as
/// native-endian floats reproduces elements 5..10 exactly.
#[test]
fn test_decode_range_cross_endian_native() {
    let values: Vec<f32> = (0..20).map(|i| i as f32 * 1.5).collect();
    let mut big_endian_bytes: Vec<u8> = Vec::new();
    for value in &values {
        big_endian_bytes.extend_from_slice(&value.to_be_bytes());
    }

    let desc = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![20],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Big,
        encoding: String::from("none"),
        filter: String::from("none"),
        compression: String::from("none"),
        params: BTreeMap::new(),
        hash: None,
    };
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let message =
        encode(&global, &[(&desc, &big_endian_bytes)], &EncodeOptions::default()).unwrap();
    let (_, parts) =
        decode_range(&message, 0, &[(5, 5)], &DecodeOptions::default()).expect("decode_range");

    let mut part_values: Vec<f32> = Vec::new();
    for chunk in parts[0].chunks_exact(4) {
        part_values.push(f32::from_ne_bytes(chunk.try_into().unwrap()));
    }

    assert_eq!(part_values.len(), 5);
    for (orig, dec) in values[5..10].iter().zip(part_values.iter()) {
        assert_eq!(*orig, *dec, "cross-endian decode_range mismatch");
    }
}
/// Encodes 20 float32 values declared big-endian, requests the range `(5, 5)`
/// with `native_byte_order: false`, and checks that reading the returned bytes
/// as big-endian floats reproduces elements 5..10 exactly.
///
/// The number of returned values is asserted first: `zip` stops at the shorter
/// side, so an empty or truncated part would otherwise pass the comparison.
#[test]
fn test_decode_range_wire_byte_order_opt_out() {
    let values: Vec<f32> = (0..20).map(|i| i as f32 * 1.5).collect();
    let data: Vec<u8> = values.iter().flat_map(|v| v.to_be_bytes()).collect();
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![20],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Big,
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        hash: None,
    };
    let encoded = encode(&global, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

    let wire_opts = DecodeOptions {
        native_byte_order: false,
        ..Default::default()
    };
    let (_, parts) = decode_range(&encoded, 0, &[(5, 5)], &wire_opts).expect("decode_range");

    // Bytes are expected in the descriptor's declared (big-endian) order.
    let wire_values: Vec<f32> = parts[0]
        .chunks_exact(4)
        .map(|c| f32::from_be_bytes(c.try_into().unwrap()))
        .collect();
    assert_eq!(
        wire_values.len(),
        5,
        "wire-order decode_range returned the wrong number of values"
    );
    for (orig, dec) in values[5..10].iter().zip(wire_values.iter()) {
        assert_eq!(*orig, *dec, "wire-order decode_range mismatch");
    }
}
/// Checks that `encode` rejects a `simple_packing` descriptor whose dtype is
/// `Float32`, and that the error text mentions "float64" or "f64".
#[test]
fn test_simple_packing_rejects_non_f64() {
    let values: Vec<f64> = (0..10).map(|i| i as f64).collect();
    let params = simple_packing::compute_params(&values, 16, 0).unwrap();
    let packing_params = BTreeMap::from([
        (
            "reference_value".to_string(),
            ciborium::Value::Float(params.reference_value),
        ),
        (
            "binary_scale_factor".to_string(),
            ciborium::Value::Integer((params.binary_scale_factor as i64).into()),
        ),
        (
            "decimal_scale_factor".to_string(),
            ciborium::Value::Integer((params.decimal_scale_factor as i64).into()),
        ),
        (
            "bits_per_value".to_string(),
            ciborium::Value::Integer((params.bits_per_value as i64).into()),
        ),
    ]);

    let payload = vec![0u8; 10 * 4];
    let desc = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![10],
        strides: vec![1],
        // Deliberately not Float64.
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: String::from("simple_packing"),
        filter: String::from("none"),
        compression: String::from("none"),
        params: packing_params,
        hash: None,
    };
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let outcome = encode(&global, &[(&desc, &payload)], &EncodeOptions::default());
    assert!(
        outcome.is_err(),
        "expected error for simple_packing with Float32 dtype"
    );

    let msg = outcome.unwrap_err().to_string();
    let names_f64 = ["float64", "f64"].iter().any(|needle| msg.contains(*needle));
    assert!(names_f64, "expected 'float64' in error, got: {msg}");
}
/// Checks that `encode` returns an error when `ndim` (3) disagrees with the
/// length of `shape` (2).
#[test]
fn test_validate_ndim_mismatch() {
    let payload = vec![0u8; 4 * 5 * 4];
    let inconsistent = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 3,
        shape: vec![4, 5],
        strides: vec![5, 1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: String::from("none"),
        filter: String::from("none"),
        compression: String::from("none"),
        params: BTreeMap::new(),
        hash: None,
    };
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let outcome = encode(&global, &[(&inconsistent, &payload)], &EncodeOptions::default());
    assert!(outcome.is_err(), "expected Err but got Ok");
}
/// Checks that `encode` rejects a `simple_packing` descriptor whose
/// `binary_scale_factor` is `i64::MAX`, and that the error names that
/// parameter.
#[test]
fn test_param_out_of_bounds() {
    let packing_params = BTreeMap::from([
        ("reference_value".to_string(), ciborium::Value::Float(0.0)),
        (
            "binary_scale_factor".to_string(),
            ciborium::Value::Integer(i64::MAX.into()),
        ),
        (
            "decimal_scale_factor".to_string(),
            ciborium::Value::Integer(0.into()),
        ),
        (
            "bits_per_value".to_string(),
            ciborium::Value::Integer(16.into()),
        ),
    ]);

    let payload = vec![0u8; 4 * 8];
    let desc = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![4],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order: ByteOrder::native(),
        encoding: String::from("simple_packing"),
        filter: String::from("none"),
        compression: String::from("none"),
        params: packing_params,
        hash: None,
    };
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let outcome = encode(&global, &[(&desc, &payload)], &EncodeOptions::default());
    assert!(outcome.is_err(), "expected Err but got Ok");

    let msg = outcome.unwrap_err().to_string();
    assert!(
        msg.contains("binary_scale_factor"),
        "expected 'binary_scale_factor' in error, got: {msg}"
    );
}
/// Builds a version-2 `GlobalMetadata` and a 1-D `Float64` descriptor of
/// `num_values` elements that uses `simple_packing` encoding with `szip`
/// compression.
///
/// The descriptor params carry the four packing fields taken from `params`
/// plus `szip_rsi = 128`, `szip_block_size = 16` and `szip_flags = 8`.
fn make_szip_packing_pair(
    num_values: u64,
    params: &simple_packing::SimplePackingParams,
) -> (GlobalMetadata, DataObjectDescriptor) {
    let int = |v: i64| ciborium::Value::Integer(v.into());

    let descriptor_params = BTreeMap::from([
        (
            "reference_value".to_string(),
            ciborium::Value::Float(params.reference_value),
        ),
        (
            "binary_scale_factor".to_string(),
            int(params.binary_scale_factor as i64),
        ),
        (
            "decimal_scale_factor".to_string(),
            int(params.decimal_scale_factor as i64),
        ),
        ("bits_per_value".to_string(), int(params.bits_per_value as i64)),
        ("szip_rsi".to_string(), ciborium::Value::Integer(128.into())),
        (
            "szip_block_size".to_string(),
            ciborium::Value::Integer(16.into()),
        ),
        ("szip_flags".to_string(), int(8_i64)),
    ]);

    let descriptor = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![num_values],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order: ByteOrder::native(),
        encoding: String::from("simple_packing"),
        filter: String::from("none"),
        compression: String::from("szip"),
        params: descriptor_params,
        hash: None,
    };
    let metadata = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    (metadata, descriptor)
}
/// Builds a version-2 `GlobalMetadata` and a 1-D descriptor of `num_values`
/// elements of `dtype` that uses `szip` compression with no encoding or
/// filter.
///
/// The descriptor params are `szip_rsi = 128`, `szip_block_size = 16` and
/// `szip_flags = 8`.
fn make_szip_raw_pair(num_values: u64, dtype: Dtype) -> (GlobalMetadata, DataObjectDescriptor) {
    let szip_params = BTreeMap::from([
        ("szip_rsi".to_string(), ciborium::Value::Integer(128.into())),
        (
            "szip_block_size".to_string(),
            ciborium::Value::Integer(16.into()),
        ),
        (
            "szip_flags".to_string(),
            ciborium::Value::Integer(8_i64.into()),
        ),
    ]);

    let descriptor = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![num_values],
        strides: vec![1],
        dtype,
        byte_order: ByteOrder::native(),
        encoding: String::from("none"),
        filter: String::from("none"),
        compression: String::from("szip"),
        params: szip_params,
        hash: None,
    };
    let metadata = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    (metadata, descriptor)
}
/// Round-trips 4096 float64 values through `simple_packing` + `szip`, checks
/// the decoded descriptor params contain `szip_block_offsets`, and checks each
/// decoded value is within 0.01 of the original.
#[test]
fn test_szip_simple_packing_round_trip() {
    let values: Vec<f64> = (0..4096).map(|i| 250.0 + i as f64 * 0.1).collect();
    let mut raw: Vec<u8> = Vec::new();
    for value in &values {
        raw.extend_from_slice(&value.to_ne_bytes());
    }

    let packing = simple_packing::compute_params(&values, 16, 0).unwrap();
    let (global, desc) = make_szip_packing_pair(4096, &packing);
    let message = encode(&global, &[(&desc, &raw)], &EncodeOptions::default()).unwrap();

    let (_, decoded) = decode(&message, &DecodeOptions::default()).unwrap();
    assert!(decoded[0].0.params.contains_key("szip_block_offsets"));

    let mut decoded_values: Vec<f64> = Vec::new();
    for chunk in decoded[0].1.chunks_exact(8) {
        decoded_values.push(f64::from_ne_bytes(chunk.try_into().unwrap()));
    }

    assert_eq!(decoded_values.len(), 4096);
    for (orig, dec) in values.iter().zip(decoded_values.iter()) {
        assert!((orig - dec).abs() < 0.01, "orig={orig}, dec={dec}");
    }
}
/// Compares `decode_range` for `(100, 500)` against elements 100..600 of a
/// full decode of the same `simple_packing` + `szip` object; values must agree
/// to within 1e-10.
#[test]
fn test_szip_simple_packing_decode_range_vs_full() {
    let values: Vec<f64> = (0..4096).map(|i| 100.0 + i as f64 * 0.5).collect();
    let mut raw: Vec<u8> = Vec::new();
    for value in &values {
        raw.extend_from_slice(&value.to_ne_bytes());
    }

    let packing = simple_packing::compute_params(&values, 16, 0).unwrap();
    let (global, desc) = make_szip_packing_pair(4096, &packing);
    let message = encode(&global, &[(&desc, &raw)], &EncodeOptions::default()).unwrap();

    // Reference: every value from a full decode.
    let (_, decoded) = decode(&message, &DecodeOptions::default()).unwrap();
    let mut full_values: Vec<f64> = Vec::new();
    for chunk in decoded[0].1.chunks_exact(8) {
        full_values.push(f64::from_ne_bytes(chunk.try_into().unwrap()));
    }

    let (_, range_parts) =
        decode_range(&message, 0, &[(100, 500)], &DecodeOptions::default()).expect("decode_range");
    assert_eq!(range_parts.len(), 1, "expected 1 part for 1 range");

    let range_bytes: Vec<u8> = range_parts.into_iter().flatten().collect();
    let mut partial_values: Vec<f64> = Vec::new();
    for chunk in range_bytes.chunks_exact(8) {
        partial_values.push(f64::from_ne_bytes(chunk.try_into().unwrap()));
    }

    assert_eq!(partial_values.len(), 500);
    for (full, partial) in full_values[100..600].iter().zip(partial_values.iter()) {
        assert!(
            (full - partial).abs() < 1e-10,
            "full={full}, partial={partial}"
        );
    }
}
/// Compares `decode_range` for `(0, 10)` against the first 10 elements of a
/// full decode of the same `simple_packing` + `szip` object; values must agree
/// to within 1e-10.
#[test]
fn test_szip_simple_packing_decode_range_first_elements() {
    let values: Vec<f64> = (0..4096).map(|i| i as f64).collect();
    let mut raw: Vec<u8> = Vec::new();
    for value in &values {
        raw.extend_from_slice(&value.to_ne_bytes());
    }

    let packing = simple_packing::compute_params(&values, 16, 0).unwrap();
    let (global, desc) = make_szip_packing_pair(4096, &packing);
    let message = encode(&global, &[(&desc, &raw)], &EncodeOptions::default()).unwrap();

    // Reference: every value from a full decode.
    let (_, decoded) = decode(&message, &DecodeOptions::default()).unwrap();
    let mut full_values: Vec<f64> = Vec::new();
    for chunk in decoded[0].1.chunks_exact(8) {
        full_values.push(f64::from_ne_bytes(chunk.try_into().unwrap()));
    }

    let (_, range_parts) =
        decode_range(&message, 0, &[(0, 10)], &DecodeOptions::default()).expect("decode_range");
    assert_eq!(range_parts.len(), 1, "expected 1 part for 1 range");

    let range_bytes: Vec<u8> = range_parts.into_iter().flatten().collect();
    let mut partial_values: Vec<f64> = Vec::new();
    for chunk in range_bytes.chunks_exact(8) {
        partial_values.push(f64::from_ne_bytes(chunk.try_into().unwrap()));
    }

    assert_eq!(partial_values.len(), 10);
    for (full, partial) in full_values[..10].iter().zip(partial_values.iter()) {
        assert!(
            (full - partial).abs() < 1e-10,
            "full={full}, partial={partial}"
        );
    }
}
/// Compares `decode_range` for `(4046, 50)` against the last 50 elements of a
/// full decode of the same `simple_packing` + `szip` object; values must agree
/// to within 1e-10.
#[test]
fn test_szip_simple_packing_decode_range_last_elements() {
    let values: Vec<f64> = (0..4096).map(|i| i as f64 * 3.125).collect();
    let mut raw: Vec<u8> = Vec::new();
    for value in &values {
        raw.extend_from_slice(&value.to_ne_bytes());
    }

    let packing = simple_packing::compute_params(&values, 16, 0).unwrap();
    let (global, desc) = make_szip_packing_pair(4096, &packing);
    let message = encode(&global, &[(&desc, &raw)], &EncodeOptions::default()).unwrap();

    // Reference: every value from a full decode.
    let (_, decoded) = decode(&message, &DecodeOptions::default()).unwrap();
    let mut full_values: Vec<f64> = Vec::new();
    for chunk in decoded[0].1.chunks_exact(8) {
        full_values.push(f64::from_ne_bytes(chunk.try_into().unwrap()));
    }

    let (_, range_parts) =
        decode_range(&message, 0, &[(4046, 50)], &DecodeOptions::default()).expect("decode_range");
    assert_eq!(range_parts.len(), 1, "expected 1 part for 1 range");

    let range_bytes: Vec<u8> = range_parts.into_iter().flatten().collect();
    let mut partial_values: Vec<f64> = Vec::new();
    for chunk in range_bytes.chunks_exact(8) {
        partial_values.push(f64::from_ne_bytes(chunk.try_into().unwrap()));
    }

    assert_eq!(partial_values.len(), 50);
    for (full, partial) in full_values[4046..].iter().zip(partial_values.iter()) {
        assert!(
            (full - partial).abs() < 1e-10,
            "full={full}, partial={partial}"
        );
    }
}
/// Round-trips 4096 `Uint8` values (a repeating 0..=255 pattern) through
/// `szip` compression and checks the decoded payload is byte-identical.
#[test]
fn test_szip_raw_u8_round_trip() {
    let payload: Vec<u8> = (0..4096).map(|i| (i % 256) as u8).collect();
    let (global, desc) = make_szip_raw_pair(4096, Dtype::Uint8);

    let message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();
    let (_, decoded) = decode(&message, &DecodeOptions::default()).unwrap();
    assert_eq!(decoded[0].1, payload);
}
/// Round-trips 4096 bytes (declared as 1024 float32 elements) through the
/// `shuffle` filter followed by `szip` compression and checks the decoded
/// payload is byte-identical.
#[test]
fn test_szip_shuffle_round_trip() {
    let payload: Vec<u8> = (0..4096).map(|i| (i % 256) as u8).collect();

    let filter_and_szip = BTreeMap::from([
        (
            "shuffle_element_size".to_string(),
            ciborium::Value::Integer(4.into()),
        ),
        ("szip_rsi".to_string(), ciborium::Value::Integer(128.into())),
        (
            "szip_block_size".to_string(),
            ciborium::Value::Integer(16.into()),
        ),
        (
            "szip_flags".to_string(),
            ciborium::Value::Integer(8_i64.into()),
        ),
    ]);

    let desc = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![1024],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: String::from("none"),
        filter: String::from("shuffle"),
        compression: String::from("szip"),
        params: filter_and_szip,
        hash: None,
    };
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();
    let (_, decoded) = decode(&message, &DecodeOptions::default()).unwrap();
    assert_eq!(decoded[0].1, payload);
}
/// Checks that `decode_range` fails on an object using the `shuffle` filter
/// together with `szip` compression, and that the error text mentions
/// "shuffle" or "filter".
///
/// The error is formatted once and included in the assertion message, so a
/// failure shows what the error actually said.
#[test]
fn test_szip_shuffle_decode_range_rejected() {
    let data: Vec<u8> = (0..4096).map(|i| (i % 256) as u8).collect();
    let mut params = BTreeMap::new();
    params.insert(
        "shuffle_element_size".to_string(),
        ciborium::Value::Integer(4.into()),
    );
    params.insert("szip_rsi".to_string(), ciborium::Value::Integer(128.into()));
    params.insert(
        "szip_block_size".to_string(),
        ciborium::Value::Integer(16.into()),
    );
    params.insert(
        "szip_flags".to_string(),
        ciborium::Value::Integer(8_i64.into()),
    );
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![1024],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "shuffle".to_string(),
        compression: "szip".to_string(),
        params,
        hash: None,
    };
    let encoded = encode(&global, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

    let err = decode_range(&encoded, 0, &[(0, 10)], &DecodeOptions::default()).unwrap_err();
    let msg = err.to_string();
    assert!(msg.contains("shuffle") || msg.contains("filter"), "{msg}");
}
/// Encodes one message holding an uncompressed float32 object and a
/// `simple_packing` + `szip` float64 object, then checks the first payload is
/// byte-identical and the second decodes to within 0.01 of the originals.
///
/// The decoded value count is asserted before the element-wise comparison
/// because `zip` stops at the shorter side and would otherwise let a truncated
/// decode pass.
#[test]
fn test_szip_multi_object_mixed_compression() {
    let values: Vec<f64> = (0..2048).map(|i| 100.0 + i as f64 * 0.1).collect();
    let packing = simple_packing::compute_params(&values, 16, 0).unwrap();
    let raw_data = vec![0u8; 10 * 4];
    let packed_data: Vec<u8> = values.iter().flat_map(|v| v.to_ne_bytes()).collect();

    let mut packing_params = BTreeMap::new();
    packing_params.insert(
        "reference_value".to_string(),
        ciborium::Value::Float(packing.reference_value),
    );
    packing_params.insert(
        "binary_scale_factor".to_string(),
        ciborium::Value::Integer((packing.binary_scale_factor as i64).into()),
    );
    packing_params.insert(
        "decimal_scale_factor".to_string(),
        ciborium::Value::Integer((packing.decimal_scale_factor as i64).into()),
    );
    packing_params.insert(
        "bits_per_value".to_string(),
        ciborium::Value::Integer((packing.bits_per_value as i64).into()),
    );
    packing_params.insert("szip_rsi".to_string(), ciborium::Value::Integer(128.into()));
    packing_params.insert(
        "szip_block_size".to_string(),
        ciborium::Value::Integer(16.into()),
    );
    packing_params.insert(
        "szip_flags".to_string(),
        ciborium::Value::Integer(8_i64.into()),
    );

    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let raw_desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![10],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        hash: None,
    };
    let packed_desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![2048],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order: ByteOrder::native(),
        encoding: "simple_packing".to_string(),
        filter: "none".to_string(),
        compression: "szip".to_string(),
        params: packing_params,
        hash: None,
    };

    let encoded = encode(
        &global,
        &[(&raw_desc, &raw_data), (&packed_desc, &packed_data)],
        &EncodeOptions::default(),
    )
    .unwrap();
    let (_, objects) = decode(&encoded, &DecodeOptions::default()).unwrap();
    assert_eq!(objects.len(), 2);
    assert_eq!(objects[0].1, raw_data);

    let decoded_values: Vec<f64> = objects[1]
        .1
        .chunks_exact(8)
        .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
        .collect();
    assert_eq!(
        decoded_values.len(),
        values.len(),
        "packed object decoded to the wrong number of values"
    );
    for (orig, dec) in values.iter().zip(decoded_values.iter()) {
        assert!((orig - dec).abs() < 0.01, "orig={orig}, dec={dec}");
    }
}
/// Round-trips an object built by `make_szip_packing_pair` and decodes it with
/// `verify_hash` switched on; the decode must succeed and yield 8 bytes per
/// input value.
#[test]
fn test_szip_hash_verification() {
    let samples: Vec<f64> = (0..2048).map(|i| i as f64).collect();

    // Native-endian byte image of the samples, used as the encode payload.
    let mut payload: Vec<u8> = Vec::with_capacity(samples.len() * 8);
    for sample in &samples {
        payload.extend_from_slice(&sample.to_ne_bytes());
    }

    let packing = simple_packing::compute_params(&samples, 16, 0).unwrap();
    let (global, desc) = make_szip_packing_pair(2048, &packing);
    let message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();

    let verifying = DecodeOptions {
        verify_hash: true,
        ..Default::default()
    };
    let (_, objects) = decode(&message, &verifying).unwrap();
    assert_eq!(objects[0].1.len(), 2048 * 8);
}
/// Requests two disjoint element ranges, `(10, 10)` and `(3000, 50)`, from an
/// object built by `make_szip_packing_pair` in a single `decode_range` call,
/// and checks every returned part against the same elements of a full decode.
#[test]
fn test_szip_decode_range_multiple_ranges() {
    /// Reinterprets native-endian bytes as `f64` values, 8 bytes per value.
    fn as_f64s(bytes: &[u8]) -> Vec<f64> {
        bytes
            .chunks_exact(8)
            .map(|chunk| f64::from_ne_bytes(chunk.try_into().unwrap()))
            .collect()
    }

    let samples: Vec<f64> = (0..4096).map(|i| i as f64 * 2.5).collect();
    let payload: Vec<u8> = samples.iter().flat_map(|s| s.to_ne_bytes()).collect();
    let packing = simple_packing::compute_params(&samples, 16, 0).unwrap();
    let (global, desc) = make_szip_packing_pair(4096, &packing);
    let message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();

    // Reference values: the whole object, fully decoded.
    let (_, objects) = decode(&message, &DecodeOptions::default()).unwrap();
    let reference = as_f64s(&objects[0].1);

    let (_, parts) = decode_range(
        &message,
        0,
        &[(10, 10), (3000, 50)],
        &DecodeOptions::default(),
    )
    .expect("decode_range");
    assert_eq!(parts.len(), 2, "expected 2 parts for 2 ranges");
    assert_eq!(parts[0].len(), 10 * 8);
    assert_eq!(parts[1].len(), 50 * 8);

    // Index windows of the full decode that correspond to the two requests.
    let windows: [std::ops::Range<usize>; 2] = [10..20, 3000..3050];
    for (part, window) in parts.iter().zip(windows) {
        let got = as_f64s(part);
        for (full, partial) in reference[window].iter().zip(got.iter()) {
            assert!((full - partial).abs() < 1e-10);
        }
    }

    let total_bytes = parts.iter().fold(0usize, |sum, part| sum + part.len());
    assert_eq!(total_bytes, 60 * 8);
}
/// `encode` must return an error when the descriptor's `obj_type` is the
/// empty string.
#[test]
fn test_validate_empty_obj_type() {
    let payload = vec![0u8; 16];

    // Same layout as the other Float32 descriptors in this file, except for
    // the empty `obj_type`.
    let invalid_desc = DataObjectDescriptor {
        obj_type: String::new(),
        ndim: 1,
        shape: vec![4],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        hash: None,
    };
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let outcome = encode(
        &global,
        &[(&invalid_desc, &payload)],
        &EncodeOptions::default(),
    );
    assert!(outcome.is_err(), "expected Err but got Ok");
}
/// Round-trips a message whose single base entry carries user keys
/// (`centre`, `date`, `custom_key`) and checks that:
///
/// * every user key comes back with its original value, including the
///   integer-valued `date`;
/// * the decoded base entry additionally contains `_reserved_`;
/// * the decoded top-level `reserved` map contains `encoder`, `time` and
///   `uuid`;
/// * `extra` stays empty.
#[test]
fn test_metadata_base_reserved_round_trip() {
    let (_, desc) = make_float32_descriptor(vec![4]);
    let data = vec![0u8; 4 * 4];

    let mut base_entry = BTreeMap::new();
    base_entry.insert(
        "centre".to_string(),
        ciborium::Value::Text("ecmwf".to_string()),
    );
    base_entry.insert(
        "date".to_string(),
        ciborium::Value::Integer(20260404.into()),
    );
    base_entry.insert(
        "custom_key".to_string(),
        ciborium::Value::Text("hello".to_string()),
    );
    let global = GlobalMetadata {
        version: 2,
        base: vec![base_entry],
        ..Default::default()
    };

    let msg = encode(&global, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
    let (decoded_meta, _) = decode(&msg, &DecodeOptions::default()).unwrap();

    assert_eq!(decoded_meta.version, 2);
    assert_eq!(decoded_meta.base.len(), 1);

    // User-supplied keys must survive unchanged.
    assert_eq!(
        decoded_meta.base[0].get("centre"),
        Some(&ciborium::Value::Text("ecmwf".to_string()))
    );
    assert_eq!(
        decoded_meta.base[0].get("date"),
        Some(&ciborium::Value::Integer(20260404.into())),
        "integer-valued base keys must survive the round trip"
    );
    assert_eq!(
        decoded_meta.base[0].get("custom_key"),
        Some(&ciborium::Value::Text("hello".to_string()))
    );

    // Keys this test did not supply but expects to find after decoding.
    assert!(
        decoded_meta.base[0].contains_key("_reserved_"),
        "encoder must auto-populate _reserved_ in base entries"
    );
    assert!(
        decoded_meta.reserved.contains_key("encoder"),
        "reserved must contain encoder provenance"
    );
    assert!(
        decoded_meta.reserved.contains_key("time"),
        "reserved must contain time provenance"
    );
    assert!(
        decoded_meta.reserved.contains_key("uuid"),
        "reserved must contain uuid provenance"
    );
    assert!(decoded_meta.extra.is_empty());
}
/// Encodes a message that supplies only `extra` metadata (a `mars` map) and
/// no base entries, then checks what the decoded metadata contains: exactly
/// one base entry holding `_reserved_`, an `encoder` key under `reserved`,
/// and the original `mars` key under `extra`.
#[test]
fn test_metadata_empty_sections_not_serialized() {
    let (_, desc) = make_float32_descriptor(vec![4]);
    let payload = vec![0u8; 16];

    let mars = ciborium::Value::Map(vec![(
        ciborium::Value::Text("class".to_string()),
        ciborium::Value::Text("od".to_string()),
    )]);
    let global = GlobalMetadata {
        version: 2,
        extra: BTreeMap::from([("mars".to_string(), mars)]),
        ..Default::default()
    };

    let message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();
    let (meta, _) = decode(&message, &DecodeOptions::default()).unwrap();

    assert_eq!(meta.version, 2);
    assert_eq!(
        meta.base.len(),
        1,
        "encoder must create one base entry per object"
    );

    let first_base = &meta.base[0];
    assert!(
        first_base.contains_key("_reserved_"),
        "encoder must auto-populate _reserved_.tensor in base entries"
    );
    assert!(
        meta.reserved.contains_key("encoder"),
        "reserved must contain encoder provenance"
    );
    assert!(meta.extra.contains_key("mars"));
}
/// Looks up `key` among the text keys of a CBOR map value.
///
/// Returns the value paired with the first `Text` key equal to `key`, or
/// `None` when `map` is not a `Value::Map` or holds no such key. Non-text
/// keys are skipped.
fn cbor_map_lookup<'a>(map: &'a ciborium::Value, key: &str) -> Option<&'a ciborium::Value> {
    match map {
        ciborium::Value::Map(pairs) => pairs.iter().find_map(|(name, value)| match name {
            ciborium::Value::Text(text) if text == key => Some(value),
            _ => None,
        }),
        _ => None,
    }
}
/// Stores a chain of `max_depth` nested single-entry CBOR maps
/// (`depth_1` → `depth_2` → … → leaf text) under two base-entry keys,
/// `depth_0` and `nested`, then walks both chains in the decoded metadata
/// down to the leaf.
#[test]
fn test_deep_nested_metadata_round_trip() {
    let (_, desc) = make_float32_descriptor(vec![4]);
    let payload = vec![0u8; 16];
    let max_depth: usize = 20;

    // Build the chain inside-out: the leaf is wrapped first by
    // `depth_{max_depth}` and last by `depth_1`, which ends up outermost.
    let chain = (1..=max_depth).rev().fold(
        ciborium::Value::Text(format!("leaf_at_{max_depth}")),
        |inner, level| {
            ciborium::Value::Map(vec![(
                ciborium::Value::Text(format!("depth_{level}")),
                inner,
            )])
        },
    );
    let base_entry = BTreeMap::from([
        ("depth_0".to_string(), chain.clone()),
        ("nested".to_string(), chain),
    ]);
    let global = GlobalMetadata {
        version: 2,
        base: vec![base_entry],
        ..Default::default()
    };

    let message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();
    let (decoded, _) = decode(&message, &DecodeOptions::default()).unwrap();
    let expected_leaf = ciborium::Value::Text(format!("leaf_at_{max_depth}"));

    // Chain stored under `depth_0`.
    let top = decoded.base[0]
        .get("depth_0")
        .expect("base[0] must have depth_0");
    let leaf = (1..=max_depth).fold(top, |node, d| {
        let key = format!("depth_{d}");
        cbor_map_lookup(node, &key).unwrap_or_else(|| panic!("missing key {key} at depth {d}"))
    });
    assert_eq!(
        leaf, &expected_leaf,
        "leaf value must survive round-trip at depth {max_depth}"
    );

    // Chain stored under `nested`.
    let top = decoded.base[0]
        .get("nested")
        .expect("base[0] entry must have 'nested'");
    let leaf = (1..=max_depth).fold(top, |node, d| {
        let key = format!("depth_{d}");
        cbor_map_lookup(node, &key)
            .unwrap_or_else(|| panic!("missing key {key} in nested at depth {d}"))
    });
    assert_eq!(
        leaf, &expected_leaf,
        "nested leaf must survive round-trip at depth {max_depth}"
    );
}