use std::collections::BTreeMap;
use tensogram::*;
use tensogram_encodings::simple_packing;
/// Builds version-3 global metadata (empty `extra`) and a raw `Float32`
/// descriptor for `shape`: no encoding, filter or compression, native byte
/// order, and row-major element strides from `compute_strides`.
fn make_float32_descriptor(shape: Vec<u64>) -> (GlobalMetadata, DataObjectDescriptor) {
    // `ndim` and `strides` must be derived before `shape` is moved into the descriptor.
    let ndim = shape.len() as u64;
    let strides = compute_strides(&shape);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim,
        shape,
        strides,
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    (global, desc)
}
/// Builds version-3 global metadata whose single `base` entry holds a `mars`
/// CBOR map (`class = "od"`, `type = "fc"`, `date = "20260401"`). It also builds
/// a raw native-endian `Float32` descriptor for `shape` whose params carry
/// `mars_param = param`.
fn make_mars_pair(shape: Vec<u64>, param: &str) -> (GlobalMetadata, DataObjectDescriptor) {
    // Going through a BTreeMap keeps the CBOR map keys in sorted order
    // (class, date, type) regardless of how the entries are listed here.
    let mars_fields = BTreeMap::from([
        ("class".to_string(), ciborium::Value::Text("od".to_string())),
        ("type".to_string(), ciborium::Value::Text("fc".to_string())),
        (
            "date".to_string(),
            ciborium::Value::Text("20260401".to_string()),
        ),
    ]);
    let mars_value = ciborium::Value::Map(
        mars_fields
            .into_iter()
            .map(|(key, val)| (ciborium::Value::Text(key), val))
            .collect(),
    );
    let object_params = BTreeMap::from([(
        "mars_param".to_string(),
        ciborium::Value::Text(param.to_string()),
    )]);

    let ndim = shape.len() as u64;
    let strides = compute_strides(&shape);
    let global = GlobalMetadata {
        version: 3,
        base: vec![BTreeMap::from([("mars".to_string(), mars_value)])],
        ..Default::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim,
        shape,
        strides,
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: object_params,
        masks: None,
    };
    (global, desc)
}
/// Returns row-major element strides for `shape`.
///
/// The last axis has stride 1, and each earlier axis has the product of all
/// later extents. The first extent is never multiplied in, so a huge leading
/// dimension does not overflow. An empty shape yields an empty vector.
fn compute_strides(shape: &[u64]) -> Vec<u64> {
    let tail = match shape.split_first() {
        Some((_, rest)) => rest,
        None => return Vec::new(),
    };
    let mut out = Vec::with_capacity(shape.len());
    out.push(1u64);
    let mut running = 1u64;
    // Walk the trailing extents from innermost outward, accumulating products.
    for &extent in tail.iter().rev() {
        running *= extent;
        out.push(running);
    }
    out.reverse();
    out
}
/// Encodes one zero-filled 10x20 `Float32` tensor, checks the first and last
/// 8 bytes of the message, and checks that decoding returns the payload unchanged.
#[test]
fn test_full_round_trip_single_object() {
    let (global, desc) = make_float32_descriptor(vec![10, 20]);
    let payload = vec![0u8; 10 * 20 * 4];
    let bytes = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();

    // The message starts with the `TENSOGRM` magic and ends with `39277777`.
    let (head, tail) = (&bytes[0..8], &bytes[bytes.len() - 8..]);
    assert_eq!(head, b"TENSOGRM");
    assert_eq!(tail, b"39277777");

    let (meta, objects) = decode(&bytes, &DecodeOptions::default()).unwrap();
    assert_eq!(meta.version, 3);
    assert_eq!(objects.len(), 1);
    assert_eq!(objects[0].1, payload);
}
/// Encodes two objects with different dtypes, shapes and byte orders into one
/// message. Checks that the global version, each descriptor's shape and each
/// payload survive decoding.
#[test]
fn test_multi_object_message() {
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let desc1 = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 2,
        shape: vec![4, 5],
        strides: compute_strides(&[4, 5]),
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let desc2 = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![3],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order: ByteOrder::Little,
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    // Each payload is a single repeated byte value. Any byte-order conversion
    // on decode leaves it unchanged, so exact equality holds on every host.
    let data1 = vec![1u8; 4 * 5 * 4];
    let data2 = vec![2u8; 3 * 8];
    let encoded = encode(
        &global,
        &[(&desc1, &data1), (&desc2, &data2)],
        &EncodeOptions::default(),
    )
    .unwrap();
    let (meta, objects) = decode(&encoded, &DecodeOptions::default()).unwrap();
    assert_eq!(meta.version, 3);
    assert_eq!(objects.len(), 2);
    assert_eq!(objects[0].0.shape, vec![4, 5]);
    assert_eq!(objects[1].0.shape, vec![3]);
    assert_eq!(objects[0].1, data1);
    assert_eq!(objects[1].1, data2);
}
/// `decode_metadata` must recover the global metadata of a MARS-tagged
/// message without decoding its objects.
#[test]
fn test_decode_metadata_only() {
    let (global, desc) = make_mars_pair(vec![10], "2t");
    let zeros = vec![0u8; 10 * 4];
    let message = encode(&global, &[(&desc, &zeros)], &EncodeOptions::default()).unwrap();
    let recovered = decode_metadata(&message).unwrap();
    assert_eq!(recovered.version, 3);
}
/// Encodes two 1-D `Float32` objects (lengths 2 and 3) and decodes each one
/// individually by index. Checks both the returned descriptor and the payload,
/// so a descriptor/payload mix-up between indices is caught.
#[test]
fn test_decode_single_object_by_index() {
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    // The two descriptors differ only in their length.
    let make_desc = |len: u64| DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![len],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let desc1 = make_desc(2);
    let desc2 = make_desc(3);
    let data1 = vec![0xAA; 2 * 4];
    let data2 = vec![0xBB; 3 * 4];
    let encoded = encode(
        &global,
        &[(&desc1, &data1), (&desc2, &data2)],
        &EncodeOptions::default(),
    )
    .unwrap();

    let (_, second_desc, second_obj) =
        decode_object(&encoded, 1, &DecodeOptions::default()).unwrap();
    assert_eq!(second_desc.shape, vec![3], "index 1 returned the wrong descriptor");
    assert_eq!(second_obj, data2);

    let (_, first_desc, first_obj) =
        decode_object(&encoded, 0, &DecodeOptions::default()).unwrap();
    assert_eq!(first_desc.shape, vec![2], "index 0 returned the wrong descriptor");
    assert_eq!(first_obj, data1);
}
/// A message with global metadata but no data objects must encode and
/// decode to an empty object list.
#[test]
fn test_zero_object_message() {
    let meta = GlobalMetadata {
        extra: BTreeMap::new(),
        version: 3,
        ..Default::default()
    };
    let message = encode(&meta, &[], &EncodeOptions::default()).unwrap();
    let decoded_objects = decode(&message, &DecodeOptions::default()).unwrap().1;
    assert!(decoded_objects.is_empty());
}
/// Decoding an untouched message with `verify_hash: true` must succeed and
/// return the original payload.
#[test]
fn test_hash_verification_passes() {
    let (global, desc) = make_float32_descriptor(vec![4]);
    let payload = vec![42u8; 4 * 4];
    let message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();
    let verifying = DecodeOptions {
        verify_hash: true,
        ..Default::default()
    };
    let decoded = decode(&message, &verifying).unwrap().1;
    assert_eq!(decoded[0].1, payload);
}
/// Flips one byte after the first data-object frame marker and expects
/// `validate_message` to report a `HashMismatch`.
#[test]
fn test_hash_verification_fails_on_corruption() {
    use tensogram::validate::{IssueCode, ValidateOptions, validate_message};
    let (global, desc) = make_float32_descriptor(vec![100]);
    let payload = vec![42u8; 100 * 4];
    let mut message = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();

    // Marker bytes: 'F', 'R', 0x00, 0x09.
    let marker: &[u8] = b"FR\x00\x09";
    let marker_at = message
        .windows(marker.len())
        .position(|window| window == marker)
        .expect("NTensorFrame not found in encoded message");
    // NOTE(review): +16 assumes the tampered byte lands inside hashed frame
    // content; confirm against the frame header layout.
    message[marker_at + 16] ^= 0xFF;

    let report = validate_message(&message, &ValidateOptions::default());
    let saw_mismatch = report
        .issues
        .iter()
        .any(|issue| issue.code == IssueCode::HashMismatch);
    assert!(
        saw_mismatch,
        "expected HashMismatch after payload tamper, got: {:?}",
        report.issues
    );
}
/// Packs 100 f64 values at 16 bits per value with `simple_packing` and checks
/// that every decoded value is within 0.01 of the original.
#[test]
fn test_simple_packing_round_trip() {
    let values: Vec<f64> = (0..100).map(|i| 250.0 + i as f64 * 0.1).collect();
    let raw: Vec<u8> = values.iter().flat_map(|v| v.to_ne_bytes()).collect();
    let sp = tensogram_encodings::simple_packing::compute_params(&values, 16, 0).unwrap();
    let int = |n: i64| ciborium::Value::Integer(n.into());
    let packing_params = BTreeMap::from([
        (
            "reference_value".to_string(),
            ciborium::Value::Float(sp.reference_value),
        ),
        (
            "binary_scale_factor".to_string(),
            int(sp.binary_scale_factor as i64),
        ),
        (
            "decimal_scale_factor".to_string(),
            int(sp.decimal_scale_factor as i64),
        ),
        ("bits_per_value".to_string(), int(sp.bits_per_value as i64)),
    ]);
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![100],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order: ByteOrder::native(),
        encoding: "simple_packing".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: packing_params,
        masks: None,
    };
    let message = encode(&global, &[(&desc, &raw)], &EncodeOptions::default()).unwrap();
    let objects = decode(&message, &DecodeOptions::default()).unwrap().1;
    let decoded: Vec<f64> = objects[0]
        .1
        .chunks_exact(8)
        .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
        .collect();
    assert_eq!(decoded.len(), 100);
    values.iter().zip(decoded.iter()).for_each(|(orig, dec)| {
        assert!((orig - dec).abs() < 0.01, "orig={orig}, dec={dec}");
    });
}
/// A `shuffle` filter with 4-byte elements must round-trip 40 distinct bytes
/// exactly.
#[test]
fn test_shuffle_round_trip() {
    let bytes: Vec<u8> = (0..40).collect();
    let shuffle_params = BTreeMap::from([(
        "shuffle_element_size".to_string(),
        ciborium::Value::Integer(4.into()),
    )]);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![10],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "shuffle".to_string(),
        compression: "none".to_string(),
        params: shuffle_params,
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let message = encode(&global, &[(&desc, &bytes)], &EncodeOptions::default()).unwrap();
    let decoded = decode(&message, &DecodeOptions::default()).unwrap().1;
    assert_eq!(decoded[0].1, bytes);
}
/// Concatenates two independently encoded messages. Checks that `scan` finds
/// both, that the reported `(offset, length)` spans tile the buffer exactly,
/// and that each span decodes on its own.
#[test]
fn test_scan_multi_message_buffer() {
    let (global1, desc1) = make_mars_pair(vec![4], "2t");
    let (global2, desc2) = make_mars_pair(vec![8], "10u");
    let data1 = vec![0u8; 4 * 4];
    let data2 = vec![0u8; 8 * 4];
    let msg1 = encode(&global1, &[(&desc1, &data1)], &EncodeOptions::default()).unwrap();
    let msg2 = encode(&global2, &[(&desc2, &data2)], &EncodeOptions::default()).unwrap();
    let mut buf = Vec::new();
    buf.extend_from_slice(&msg1);
    buf.extend_from_slice(&msg2);

    let offsets = scan(&buf);
    assert_eq!(offsets.len(), 2);
    // msg1 starts at 0; msg2 starts immediately after it.
    assert_eq!((offsets[0].0, offsets[0].1), (0, msg1.len()));
    assert_eq!((offsets[1].0, offsets[1].1), (msg1.len(), msg2.len()));

    let (_, objects1) = decode(
        &buf[offsets[0].0..offsets[0].0 + offsets[0].1],
        &DecodeOptions::default(),
    )
    .unwrap();
    let (_, objects2) = decode(
        &buf[offsets[1].0..offsets[1].0 + offsets[1].1],
        &DecodeOptions::default(),
    )
    .unwrap();
    assert_eq!(objects1[0].0.shape, vec![4]);
    assert_eq!(objects2[0].0.shape, vec![8]);
}
/// `decode_range` over an uncompressed Float32 object must return exactly
/// elements 3..6 for the single range `(3, 3)`.
#[test]
fn test_partial_range_decode_uncompressed() {
    let values: Vec<f32> = (0..10).map(|i| i as f32 * 1.5).collect();
    let to_bytes = |xs: &[f32]| -> Vec<u8> { xs.iter().flat_map(|v| v.to_ne_bytes()).collect() };
    let raw = to_bytes(&values);
    let (global, desc) = make_float32_descriptor(vec![10]);
    let message = encode(&global, &[(&desc, &raw)], &EncodeOptions::default()).unwrap();

    let parts = decode_range(&message, 0, &[(3, 3)], &DecodeOptions::default())
        .expect("decode_range")
        .1;
    assert_eq!(parts.len(), 1, "expected 1 part for 1 range");

    let wanted = to_bytes(&values[3..6]);
    assert_eq!(parts[0], wanted);
    let flattened: Vec<u8> = parts.into_iter().flatten().collect();
    assert_eq!(flattened, wanted);
}
/// `decode_range` on a shuffle-filtered object must fail with an error that
/// mentions the shuffle filter.
#[test]
fn test_decode_range_shuffle_rejected() {
    let bytes: Vec<u8> = (0..40).collect();
    let shuffle_params = BTreeMap::from([(
        "shuffle_element_size".to_string(),
        ciborium::Value::Integer(4.into()),
    )]);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![10],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "shuffle".to_string(),
        compression: "none".to_string(),
        params: shuffle_params,
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let message = encode(&global, &[(&desc, &bytes)], &EncodeOptions::default()).unwrap();
    let msg = decode_range(&message, 0, &[(3, 3)], &DecodeOptions::default())
        .unwrap_err()
        .to_string();
    let names_filter = ["shuffle", "filter"]
        .iter()
        .any(|&needle| msg.contains(needle));
    assert!(names_filter, "{msg}");
}
/// Appends three messages to a `TensogramFile` and reads each back by index.
/// Message `i` carries a 4-element Float32 payload filled with byte `i`.
#[test]
fn test_file_multi_message_round_trip() {
    let tmp = tempfile::tempdir().unwrap();
    let path = tmp.path().join("multi.tgm");
    let mut file = TensogramFile::create(&path).unwrap();

    for fill in 0..3u8 {
        let (global, desc) = make_float32_descriptor(vec![4]);
        let payload = vec![fill; 4 * 4];
        file.append(&global, &[(&desc, &payload)], &EncodeOptions::default())
            .unwrap();
    }
    assert_eq!(file.message_count().unwrap(), 3);

    for (index, fill) in (0..3u8).enumerate() {
        let objects = file
            .decode_message(index, &DecodeOptions::default())
            .unwrap()
            .1;
        assert_eq!(objects[0].1, vec![fill; 4 * 4]);
    }
}
/// Round-trips MARS-namespaced metadata and checks that the `mars` map in the
/// first `base` entry still carries `class = "od"`.
#[test]
fn test_namespaced_metadata_round_trip() {
    let (global, desc) = make_mars_pair(vec![4], "wave_spectra");
    let data = vec![0u8; 4 * 4];
    let encoded = encode(&global, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
    let meta = decode_metadata(&encoded).unwrap();

    let mars = meta.base[0]
        .get("mars")
        .expect("base[0] should contain a `mars` namespace");
    // A non-map value must fail the test rather than silently skip the checks.
    let entries = match mars {
        ciborium::Value::Map(entries) => entries,
        other => panic!("expected `mars` to be a CBOR map, got: {other:?}"),
    };
    let class_val = entries
        .iter()
        .find(|(k, _)| matches!(k, ciborium::Value::Text(s) if s == "class"))
        .map(|(_, v)| v);
    assert!(
        matches!(class_val, Some(ciborium::Value::Text(s)) if s == "od"),
        "expected mars.class == \"od\", got: {class_val:?}"
    );
}
/// A shape of `[u64::MAX, 2]` makes the element count overflow `u64`, so
/// `encode` must return an error.
#[test]
fn test_validate_object_overflow() {
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 2,
        shape: vec![u64::MAX, 2],
        strides: vec![2, 1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        extra: BTreeMap::new(),
        version: 3,
        ..Default::default()
    };
    let buffer = vec![0u8; 64];
    let outcome = encode(&global, &[(&desc, &buffer)], &EncodeOptions::default());
    assert!(outcome.is_err(), "expected Err but got Ok");
}
/// Encodes the same f64 series with `simple_packing`, once from big-endian
/// input and once from little-endian input. Checks that:
/// * with default options both decode to values within 0.01 of the originals,
///   with the full element count,
/// * both default decodes are byte-identical, and
/// * with `native_byte_order: false` the two outputs differ.
#[test]
fn test_cross_endian_round_trip() {
    /// Reinterprets native-endian bytes as a vector of f64.
    fn to_f64s(bytes: &[u8]) -> Vec<f64> {
        bytes
            .chunks_exact(8)
            .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
            .collect()
    }

    let values: Vec<f64> = (0..50).map(|i| 100.0 + i as f64 * 0.5).collect();
    let params = tensogram_encodings::simple_packing::compute_params(&values, 16, 0).unwrap();
    let mut packing_params = BTreeMap::new();
    packing_params.insert(
        "reference_value".to_string(),
        ciborium::Value::Float(params.reference_value),
    );
    packing_params.insert(
        "binary_scale_factor".to_string(),
        ciborium::Value::Integer((params.binary_scale_factor as i64).into()),
    );
    packing_params.insert(
        "decimal_scale_factor".to_string(),
        ciborium::Value::Integer((params.decimal_scale_factor as i64).into()),
    );
    packing_params.insert(
        "bits_per_value".to_string(),
        ciborium::Value::Integer((params.bits_per_value as i64).into()),
    );
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let be_data: Vec<u8> = values.iter().flat_map(|v| v.to_be_bytes()).collect();
    let be_desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![50],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order: ByteOrder::Big,
        encoding: "simple_packing".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: packing_params.clone(),
        masks: None,
    };
    let encoded_be = encode(&global, &[(&be_desc, &be_data)], &EncodeOptions::default()).unwrap();
    let (_, objects_be) = decode(&encoded_be, &DecodeOptions::default()).unwrap();
    let decoded_be_values = to_f64s(&objects_be[0].1);
    // Guard against a short decode, which would make the zip below pass vacuously.
    assert_eq!(
        decoded_be_values.len(),
        values.len(),
        "BE→native round-trip: wrong element count"
    );
    for (orig, dec) in values.iter().zip(decoded_be_values.iter()) {
        assert!(
            (orig - dec).abs() < 0.01,
            "BE→native round-trip: orig={orig}, dec={dec}"
        );
    }

    let le_data: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
    let le_desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![50],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order: ByteOrder::Little,
        encoding: "simple_packing".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        // Last use of the params map: move it instead of cloning.
        params: packing_params,
        masks: None,
    };
    let encoded_le = encode(&global, &[(&le_desc, &le_data)], &EncodeOptions::default()).unwrap();
    let (_, objects_le) = decode(&encoded_le, &DecodeOptions::default()).unwrap();
    let decoded_le_values = to_f64s(&objects_le[0].1);
    assert_eq!(
        decoded_le_values.len(),
        values.len(),
        "LE→native round-trip: wrong element count"
    );
    for (orig, dec) in values.iter().zip(decoded_le_values.iter()) {
        assert!(
            (orig - dec).abs() < 0.01,
            "LE→native round-trip: orig={orig}, dec={dec}"
        );
    }

    assert_eq!(
        objects_be[0].1, objects_le[0].1,
        "BE and LE should both decode to identical native-endian bytes"
    );

    let wire_opts = DecodeOptions {
        native_byte_order: false,
        ..Default::default()
    };
    let (_, wire_be) = decode(&encoded_be, &wire_opts).unwrap();
    let (_, wire_le) = decode(&encoded_le, &wire_opts).unwrap();
    assert_ne!(
        wire_be[0].1, wire_le[0].1,
        "wire-order BE and LE bytes should differ"
    );
}
/// `decode_range` on a big-endian Float32 object with default options must
/// return native-endian values equal to elements 5..10.
#[test]
fn test_decode_range_cross_endian_native() {
    let values: Vec<f32> = (0..20).map(|i| i as f32 * 1.5).collect();
    let big_endian: Vec<u8> = values.iter().flat_map(|v| v.to_be_bytes()).collect();
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![20],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Big,
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let message = encode(&global, &[(&desc, &big_endian)], &EncodeOptions::default()).unwrap();
    let parts = decode_range(&message, 0, &[(5, 5)], &DecodeOptions::default())
        .expect("decode_range")
        .1;
    let got: Vec<f32> = parts[0]
        .chunks_exact(4)
        .map(|c| f32::from_ne_bytes(c.try_into().unwrap()))
        .collect();
    assert_eq!(got.len(), 5);
    for (offset, dec) in got.iter().enumerate() {
        assert_eq!(values[5 + offset], *dec, "cross-endian decode_range mismatch");
    }
}
/// With `native_byte_order: false`, `decode_range` on a big-endian Float32
/// object must return the bytes in big-endian order for elements 5..10.
#[test]
fn test_decode_range_wire_byte_order_opt_out() {
    let values: Vec<f32> = (0..20).map(|i| i as f32 * 1.5).collect();
    let data: Vec<u8> = values.iter().flat_map(|v| v.to_be_bytes()).collect();
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![20],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Big,
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let encoded = encode(&global, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
    let wire_opts = DecodeOptions {
        native_byte_order: false,
        ..Default::default()
    };
    let (_, parts) = decode_range(&encoded, 0, &[(5, 5)], &wire_opts).expect("decode_range");
    let wire_values: Vec<f32> = parts[0]
        .chunks_exact(4)
        .map(|c| f32::from_be_bytes(c.try_into().unwrap()))
        .collect();
    // Without this check, a short or empty range would make the zip below pass vacuously.
    assert_eq!(wire_values.len(), 5, "expected 5 elements for range (5, 5)");
    for (orig, dec) in values[5..10].iter().zip(wire_values.iter()) {
        assert_eq!(*orig, *dec, "wire-order decode_range mismatch");
    }
}
/// `simple_packing` on a `Float32` descriptor must be rejected, and the error
/// must mention float64.
#[test]
fn test_simple_packing_rejects_non_f64() {
    let values: Vec<f64> = (0..10).map(|i| i as f64).collect();
    let sp = tensogram_encodings::simple_packing::compute_params(&values, 16, 0).unwrap();
    let int = |n: i64| ciborium::Value::Integer(n.into());
    let packing_params = BTreeMap::from([
        (
            "reference_value".to_string(),
            ciborium::Value::Float(sp.reference_value),
        ),
        (
            "binary_scale_factor".to_string(),
            int(sp.binary_scale_factor as i64),
        ),
        (
            "decimal_scale_factor".to_string(),
            int(sp.decimal_scale_factor as i64),
        ),
        ("bits_per_value".to_string(), int(sp.bits_per_value as i64)),
    ]);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![10],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "simple_packing".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: packing_params,
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let zeros = vec![0u8; 10 * 4];
    let msg = match encode(&global, &[(&desc, &zeros)], &EncodeOptions::default()) {
        Ok(_) => panic!("expected error for simple_packing with Float32 dtype"),
        Err(err) => err.to_string(),
    };
    assert!(
        msg.contains("float64") || msg.contains("f64"),
        "expected 'float64' in error, got: {msg}"
    );
}
/// A descriptor that claims `ndim = 3` while its shape has two axes must be
/// rejected by `encode`.
#[test]
fn test_validate_ndim_mismatch() {
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 3,
        shape: vec![4, 5],
        strides: vec![5, 1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        extra: BTreeMap::new(),
        version: 3,
        ..Default::default()
    };
    let payload = vec![0u8; 4 * 5 * 4];
    let outcome = encode(&global, &[(&desc, &payload)], &EncodeOptions::default());
    assert!(outcome.is_err(), "expected Err but got Ok");
}
/// A `binary_scale_factor` of `i64::MAX` must be rejected by `encode`, and the
/// error must name the offending parameter.
#[test]
fn test_param_out_of_bounds() {
    let packing_params = BTreeMap::from([
        ("reference_value".to_string(), ciborium::Value::Float(0.0)),
        (
            "binary_scale_factor".to_string(),
            ciborium::Value::Integer(i64::MAX.into()),
        ),
        (
            "decimal_scale_factor".to_string(),
            ciborium::Value::Integer(0.into()),
        ),
        (
            "bits_per_value".to_string(),
            ciborium::Value::Integer(16.into()),
        ),
    ]);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![4],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order: ByteOrder::native(),
        encoding: "simple_packing".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: packing_params,
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let zeros = vec![0u8; 4 * 8];
    let msg = match encode(&global, &[(&desc, &zeros)], &EncodeOptions::default()) {
        Ok(_) => panic!("expected Err but got Ok"),
        Err(err) => err.to_string(),
    };
    assert!(
        msg.contains("binary_scale_factor"),
        "expected 'binary_scale_factor' in error, got: {msg}"
    );
}
/// Builds version-3 global metadata and a 1-D native-endian `Float64`
/// descriptor of `num_values` elements. The descriptor uses `simple_packing`
/// with `params`, then `szip` compression with `szip_rsi = 128`,
/// `szip_block_size = 16` and `szip_flags = 8`.
fn make_szip_packing_pair(
    num_values: u64,
    params: &simple_packing::SimplePackingParams,
) -> (GlobalMetadata, DataObjectDescriptor) {
    let descriptor_params = BTreeMap::from([
        (
            "reference_value".to_string(),
            ciborium::Value::Float(params.reference_value),
        ),
        (
            "binary_scale_factor".to_string(),
            ciborium::Value::Integer((params.binary_scale_factor as i64).into()),
        ),
        (
            "decimal_scale_factor".to_string(),
            ciborium::Value::Integer((params.decimal_scale_factor as i64).into()),
        ),
        (
            "bits_per_value".to_string(),
            ciborium::Value::Integer((params.bits_per_value as i64).into()),
        ),
        ("szip_rsi".to_string(), ciborium::Value::Integer(128.into())),
        (
            "szip_block_size".to_string(),
            ciborium::Value::Integer(16.into()),
        ),
        (
            "szip_flags".to_string(),
            ciborium::Value::Integer(8_i64.into()),
        ),
    ]);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![num_values],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order: ByteOrder::native(),
        encoding: "simple_packing".to_string(),
        filter: "none".to_string(),
        compression: "szip".to_string(),
        params: descriptor_params,
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    (global, desc)
}
/// Builds version-3 global metadata and a 1-D native-endian descriptor of
/// `num_values` elements of `dtype`. The data is unencoded and unfiltered but
/// `szip`-compressed (`szip_rsi = 128`, `szip_block_size = 16`, `szip_flags = 8`).
fn make_szip_raw_pair(num_values: u64, dtype: Dtype) -> (GlobalMetadata, DataObjectDescriptor) {
    let szip_params = BTreeMap::from([
        ("szip_rsi".to_string(), ciborium::Value::Integer(128.into())),
        (
            "szip_block_size".to_string(),
            ciborium::Value::Integer(16.into()),
        ),
        (
            "szip_flags".to_string(),
            ciborium::Value::Integer(8_i64.into()),
        ),
    ]);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![num_values],
        strides: vec![1],
        dtype,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "szip".to_string(),
        params: szip_params,
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    (global, desc)
}
/// Round-trips 4096 f64 values through simple_packing + szip. Checks that the
/// decoded descriptor gains `szip_block_offsets` and that every value is within
/// 0.01 of the original.
#[test]
fn test_szip_simple_packing_round_trip() {
    let values: Vec<f64> = (0..4096).map(|i| 250.0 + i as f64 * 0.1).collect();
    let packing = simple_packing::compute_params(&values, 16, 0).unwrap();
    let (global, desc) = make_szip_packing_pair(4096, &packing);
    let raw: Vec<u8> = values.iter().flat_map(|v| v.to_ne_bytes()).collect();
    let message = encode(&global, &[(&desc, &raw)], &EncodeOptions::default()).unwrap();

    let objects = decode(&message, &DecodeOptions::default()).unwrap().1;
    let (decoded_desc, decoded_bytes) = (&objects[0].0, &objects[0].1);
    assert!(decoded_desc.params.contains_key("szip_block_offsets"));

    let decoded: Vec<f64> = decoded_bytes
        .chunks_exact(8)
        .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
        .collect();
    assert_eq!(decoded.len(), 4096);
    for (orig, dec) in values.iter().zip(&decoded) {
        assert!((orig - dec).abs() < 0.01, "orig={orig}, dec={dec}");
    }
}
/// A mid-array `decode_range` of 500 elements starting at 100 must agree
/// (to 1e-10) with the same window taken from a full decode.
#[test]
fn test_szip_simple_packing_decode_range_vs_full() {
    fn as_f64s(bytes: &[u8]) -> Vec<f64> {
        bytes
            .chunks_exact(8)
            .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
            .collect()
    }

    let values: Vec<f64> = (0..4096).map(|i| 100.0 + i as f64 * 0.5).collect();
    let packing = simple_packing::compute_params(&values, 16, 0).unwrap();
    let (global, desc) = make_szip_packing_pair(4096, &packing);
    let raw: Vec<u8> = values.iter().flat_map(|v| v.to_ne_bytes()).collect();
    let message = encode(&global, &[(&desc, &raw)], &EncodeOptions::default()).unwrap();

    let objects = decode(&message, &DecodeOptions::default()).unwrap().1;
    let full = as_f64s(&objects[0].1);

    let parts = decode_range(&message, 0, &[(100, 500)], &DecodeOptions::default())
        .expect("decode_range")
        .1;
    assert_eq!(parts.len(), 1, "expected 1 part for 1 range");
    let joined: Vec<u8> = parts.into_iter().flatten().collect();
    let window = as_f64s(&joined);
    assert_eq!(window.len(), 500);

    for (full, partial) in full[100..600].iter().zip(&window) {
        assert!(
            (full - partial).abs() < 1e-10,
            "full={full}, partial={partial}"
        );
    }
}
/// A `decode_range` covering the first 10 elements must agree (to 1e-10)
/// with the head of a full decode.
#[test]
fn test_szip_simple_packing_decode_range_first_elements() {
    fn as_f64s(bytes: &[u8]) -> Vec<f64> {
        bytes
            .chunks_exact(8)
            .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
            .collect()
    }

    let values: Vec<f64> = (0..4096).map(|i| i as f64).collect();
    let packing = simple_packing::compute_params(&values, 16, 0).unwrap();
    let (global, desc) = make_szip_packing_pair(4096, &packing);
    let raw: Vec<u8> = values.iter().flat_map(|v| v.to_ne_bytes()).collect();
    let message = encode(&global, &[(&desc, &raw)], &EncodeOptions::default()).unwrap();

    let objects = decode(&message, &DecodeOptions::default()).unwrap().1;
    let full = as_f64s(&objects[0].1);

    let parts = decode_range(&message, 0, &[(0, 10)], &DecodeOptions::default())
        .expect("decode_range")
        .1;
    assert_eq!(parts.len(), 1, "expected 1 part for 1 range");
    let joined: Vec<u8> = parts.into_iter().flatten().collect();
    let head = as_f64s(&joined);
    assert_eq!(head.len(), 10);

    for (full, partial) in full[..10].iter().zip(&head) {
        assert!(
            (full - partial).abs() < 1e-10,
            "full={full}, partial={partial}"
        );
    }
}
/// A `decode_range` covering the final 50 elements (4046..4096) must agree
/// (to 1e-10) with the tail of a full decode.
#[test]
fn test_szip_simple_packing_decode_range_last_elements() {
    fn as_f64s(bytes: &[u8]) -> Vec<f64> {
        bytes
            .chunks_exact(8)
            .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
            .collect()
    }

    let values: Vec<f64> = (0..4096).map(|i| i as f64 * 3.125).collect();
    let packing = simple_packing::compute_params(&values, 16, 0).unwrap();
    let (global, desc) = make_szip_packing_pair(4096, &packing);
    let raw: Vec<u8> = values.iter().flat_map(|v| v.to_ne_bytes()).collect();
    let message = encode(&global, &[(&desc, &raw)], &EncodeOptions::default()).unwrap();

    let objects = decode(&message, &DecodeOptions::default()).unwrap().1;
    let full = as_f64s(&objects[0].1);

    let parts = decode_range(&message, 0, &[(4046, 50)], &DecodeOptions::default())
        .expect("decode_range")
        .1;
    assert_eq!(parts.len(), 1, "expected 1 part for 1 range");
    let joined: Vec<u8> = parts.into_iter().flatten().collect();
    let tail = as_f64s(&joined);
    assert_eq!(tail.len(), 50);

    for (full, partial) in full[4046..].iter().zip(&tail) {
        assert!(
            (full - partial).abs() < 1e-10,
            "full={full}, partial={partial}"
        );
    }
}
/// 4096 raw `Uint8` bytes compressed with szip (no encoding or filter) must
/// round-trip byte-for-byte.
#[test]
fn test_szip_raw_u8_round_trip() {
    let (global, desc) = make_szip_raw_pair(4096, Dtype::Uint8);
    // 1024 f32 values in native byte order give 4096 bytes of varied content.
    let bytes: Vec<u8> = (0..1024)
        .flat_map(|i| (i as f32).to_ne_bytes())
        .collect();
    let message = encode(&global, &[(&desc, &bytes)], &EncodeOptions::default()).unwrap();
    let decoded = decode(&message, &DecodeOptions::default()).unwrap().1;
    assert_eq!(decoded[0].1, bytes);
}
/// A 4-byte `shuffle` filter followed by szip compression must round-trip
/// 1024 Float32 values exactly.
#[test]
fn test_szip_shuffle_round_trip() {
    let bytes: Vec<u8> = (0..1024)
        .flat_map(|i| (i as f32).to_ne_bytes())
        .collect();
    let pipeline_params = BTreeMap::from([
        (
            "shuffle_element_size".to_string(),
            ciborium::Value::Integer(4.into()),
        ),
        ("szip_rsi".to_string(), ciborium::Value::Integer(128.into())),
        (
            "szip_block_size".to_string(),
            ciborium::Value::Integer(16.into()),
        ),
        (
            "szip_flags".to_string(),
            ciborium::Value::Integer(8_i64.into()),
        ),
    ]);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![1024],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "shuffle".to_string(),
        compression: "szip".to_string(),
        params: pipeline_params,
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let message = encode(&global, &[(&desc, &bytes)], &EncodeOptions::default()).unwrap();
    let decoded = decode(&message, &DecodeOptions::default()).unwrap().1;
    assert_eq!(decoded[0].1, bytes);
}
/// `decode_range` on a shuffle + szip object must fail with an error that
/// mentions the shuffle filter.
#[test]
fn test_szip_shuffle_decode_range_rejected() {
    let bytes: Vec<u8> = (0..1024)
        .flat_map(|i| (i as f32).to_ne_bytes())
        .collect();
    let pipeline_params = BTreeMap::from([
        (
            "shuffle_element_size".to_string(),
            ciborium::Value::Integer(4.into()),
        ),
        ("szip_rsi".to_string(), ciborium::Value::Integer(128.into())),
        (
            "szip_block_size".to_string(),
            ciborium::Value::Integer(16.into()),
        ),
        (
            "szip_flags".to_string(),
            ciborium::Value::Integer(8_i64.into()),
        ),
    ]);
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![1024],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "shuffle".to_string(),
        compression: "szip".to_string(),
        params: pipeline_params,
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let message = encode(&global, &[(&desc, &bytes)], &EncodeOptions::default()).unwrap();
    let text = decode_range(&message, 0, &[(0, 10)], &DecodeOptions::default())
        .unwrap_err()
        .to_string();
    assert!(text.contains("shuffle") || text.contains("filter"));
}
/// Encodes an uncompressed Float32 object and a simple_packing + szip
/// Float64 object in one message. Checks that the raw payload is exact and
/// that every packed value (full count) is within 0.01 of the original.
#[test]
fn test_szip_multi_object_mixed_compression() {
    let values: Vec<f64> = (0..2048).map(|i| 100.0 + i as f64 * 0.1).collect();
    let packing = simple_packing::compute_params(&values, 16, 0).unwrap();
    let raw_data = vec![0u8; 10 * 4];
    let mut packing_params = BTreeMap::new();
    packing_params.insert(
        "reference_value".to_string(),
        ciborium::Value::Float(packing.reference_value),
    );
    packing_params.insert(
        "binary_scale_factor".to_string(),
        ciborium::Value::Integer((packing.binary_scale_factor as i64).into()),
    );
    packing_params.insert(
        "decimal_scale_factor".to_string(),
        ciborium::Value::Integer((packing.decimal_scale_factor as i64).into()),
    );
    packing_params.insert(
        "bits_per_value".to_string(),
        ciborium::Value::Integer((packing.bits_per_value as i64).into()),
    );
    packing_params.insert("szip_rsi".to_string(), ciborium::Value::Integer(128.into()));
    packing_params.insert(
        "szip_block_size".to_string(),
        ciborium::Value::Integer(16.into()),
    );
    packing_params.insert(
        "szip_flags".to_string(),
        ciborium::Value::Integer(8_i64.into()),
    );
    let packed_data: Vec<u8> = values.iter().flat_map(|v| v.to_ne_bytes()).collect();
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let raw_desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![10],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let packed_desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![2048],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order: ByteOrder::native(),
        encoding: "simple_packing".to_string(),
        filter: "none".to_string(),
        compression: "szip".to_string(),
        params: packing_params,
        masks: None,
    };
    let encoded = encode(
        &global,
        &[(&raw_desc, &raw_data), (&packed_desc, &packed_data)],
        &EncodeOptions::default(),
    )
    .unwrap();
    let (_, objects) = decode(&encoded, &DecodeOptions::default()).unwrap();
    assert_eq!(objects.len(), 2);
    assert_eq!(objects[0].1, raw_data);
    let decoded_values: Vec<f64> = objects[1]
        .1
        .chunks_exact(8)
        .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
        .collect();
    // Without this check, a short decode would make the zip below pass vacuously.
    assert_eq!(decoded_values.len(), values.len());
    for (orig, dec) in values.iter().zip(decoded_values.iter()) {
        assert!((orig - dec).abs() < 0.01, "orig={orig}, dec={dec}");
    }
}
/// Round-trips a `simple_packing` + `szip` object and decodes it with hash
/// verification switched on. The decode must succeed and return the full
/// 2048 × 8-byte payload.
#[test]
fn test_szip_hash_verification() {
    let values: Vec<f64> = (0..2048).map(|i| i as f64).collect();
    let packing = simple_packing::compute_params(&values, 16, 0).unwrap();
    let data: Vec<u8> = values.iter().copied().flat_map(f64::to_ne_bytes).collect();
    let (global, desc) = make_szip_packing_pair(2048, &packing);
    let message = encode(&global, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

    let verifying = DecodeOptions {
        verify_hash: true,
        ..Default::default()
    };
    let (_, decoded) = decode(&message, &verifying).unwrap();
    assert_eq!(decoded[0].1.len(), 2048 * 8);
}
/// Decoding two disjoint element ranges of a `simple_packing` + `szip` object
/// with `decode_range` yields one part per range, of the right byte length,
/// whose values match the corresponding slice of a full decode.
#[test]
fn test_szip_decode_range_multiple_ranges() {
    /// Reinterprets native-endian bytes as consecutive `f64` values.
    fn to_f64s(bytes: &[u8]) -> Vec<f64> {
        bytes
            .chunks_exact(8)
            .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
            .collect()
    }

    let values: Vec<f64> = (0..4096).map(|i| i as f64 * 2.5).collect();
    let data: Vec<u8> = values.iter().flat_map(|v| v.to_ne_bytes()).collect();
    let packing = simple_packing::compute_params(&values, 16, 0).unwrap();
    let (global, desc) = make_szip_packing_pair(4096, &packing);
    let encoded = encode(&global, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

    let (_, objects) = decode(&encoded, &DecodeOptions::default()).unwrap();
    let full_values = to_f64s(&objects[0].1);

    let (_, parts) = decode_range(
        &encoded,
        0,
        &[(10, 10), (3000, 50)],
        &DecodeOptions::default(),
    )
    .expect("decode_range");
    assert_eq!(parts.len(), 2, "expected 2 parts for 2 ranges");
    assert_eq!(parts[0].len(), 10 * 8);
    assert_eq!(parts[1].len(), 50 * 8);

    // Each partial decode must agree with the matching window of the full decode.
    for (window, part) in [(10..20, &parts[0]), (3000..3050, &parts[1])] {
        for (full, partial) in full_values[window].iter().zip(to_f64s(part).iter()) {
            assert!((full - partial).abs() < 1e-10);
        }
    }

    let total_bytes: usize = parts.iter().map(|p| p.len()).sum();
    assert_eq!(total_bytes, 60 * 8);
}
/// `encode` must refuse a descriptor whose `obj_type` is the empty string.
#[test]
fn test_validate_empty_obj_type() {
    let desc = DataObjectDescriptor {
        obj_type: String::new(),
        ndim: 1,
        shape: vec![4],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".into(),
        filter: "none".into(),
        compression: "none".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let payload = vec![0u8; 16];
    let outcome = encode(&global, &[(&desc, &payload)], &EncodeOptions::default());
    assert!(outcome.is_err(), "expected Err but got Ok");
}
/// User keys placed in a `base` entry survive a round-trip unchanged. The
/// encoder adds a `_reserved_` key to that entry and fills the top-level
/// `reserved` section with `encoder`, `time` and `uuid` provenance. `extra`
/// stays empty because nothing was put there.
#[test]
fn test_metadata_base_reserved_round_trip() {
    let (_, desc) = make_float32_descriptor(vec![4]);
    let payload = vec![0u8; 16];
    let entry = BTreeMap::from([
        ("centre".to_string(), ciborium::Value::Text("ecmwf".to_string())),
        ("date".to_string(), ciborium::Value::Integer(20260404.into())),
        ("custom_key".to_string(), ciborium::Value::Text("hello".to_string())),
    ]);
    let global = GlobalMetadata {
        version: 3,
        base: vec![entry],
        ..Default::default()
    };
    let encoded = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();
    let (meta, _) = decode(&encoded, &DecodeOptions::default()).unwrap();

    assert_eq!(meta.version, 3);
    assert_eq!(meta.base.len(), 1);
    let first = &meta.base[0];
    assert_eq!(
        first.get("centre"),
        Some(&ciborium::Value::Text("ecmwf".to_string()))
    );
    assert_eq!(
        first.get("custom_key"),
        Some(&ciborium::Value::Text("hello".to_string()))
    );
    assert!(
        first.contains_key("_reserved_"),
        "encoder must auto-populate _reserved_ in base entries"
    );
    for key in ["encoder", "time", "uuid"] {
        assert!(
            meta.reserved.contains_key(key),
            "reserved must contain {key} provenance"
        );
    }
    assert!(meta.extra.is_empty());
}
/// The caller supplies metadata only under `extra` (no `base` entries). The
/// encoder must still emit one `base` entry for the single object, carrying
/// `_reserved_`. It must also record `encoder` provenance in `reserved` and
/// preserve the `mars` key in `extra`.
#[test]
fn test_metadata_empty_sections_not_serialized() {
    let (_, desc) = make_float32_descriptor(vec![4]);
    let payload = vec![0u8; 16];
    let mars = ciborium::Value::Map(vec![(
        ciborium::Value::Text("class".to_string()),
        ciborium::Value::Text("od".to_string()),
    )]);
    let global = GlobalMetadata {
        version: 3,
        extra: BTreeMap::from([("mars".to_string(), mars)]),
        ..Default::default()
    };
    let encoded = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();
    let (meta, _) = decode(&encoded, &DecodeOptions::default()).unwrap();

    assert_eq!(meta.version, 3);
    assert_eq!(
        meta.base.len(),
        1,
        "encoder must create one base entry per object"
    );
    assert!(
        meta.base[0].contains_key("_reserved_"),
        "encoder must auto-populate _reserved_.tensor in base entries"
    );
    assert!(
        meta.reserved.contains_key("encoder"),
        "reserved must contain encoder provenance"
    );
    assert!(meta.extra.contains_key("mars"));
}
/// Looks up `key` among the text-keyed entries of a CBOR map.
///
/// Returns the value of the first entry whose key is `Text(key)`. Returns
/// `None` when `map` is not a `Map` or no such entry exists. Entries with
/// non-text keys are skipped.
fn cbor_map_lookup<'a>(map: &'a ciborium::Value, key: &str) -> Option<&'a ciborium::Value> {
    match map {
        ciborium::Value::Map(entries) => entries.iter().find_map(|(k, v)| match k {
            ciborium::Value::Text(s) if s == key => Some(v),
            _ => None,
        }),
        _ => None,
    }
}
/// A 20-level chain of single-entry CBOR maps (`depth_1` → … → `depth_20` →
/// text leaf) stored under two `base` keys (`depth_0` and `nested`) must come
/// back intact: every level is present and the leaf text is unchanged.
#[test]
fn test_deep_nested_metadata_round_trip() {
    let (_, desc) = make_float32_descriptor(vec![4]);
    let data = vec![0u8; 4 * 4];
    let max_depth: usize = 20;
    let leaf = || ciborium::Value::Text(format!("leaf_at_{max_depth}"));

    // Wrap the leaf from the inside out: `depth_{max_depth}` is innermost,
    // `depth_1` outermost.
    let value = (1..=max_depth).rev().fold(leaf(), |inner, d| {
        ciborium::Value::Map(vec![(ciborium::Value::Text(format!("depth_{d}")), inner)])
    });
    let base_entry = BTreeMap::from([
        ("depth_0".to_string(), value.clone()),
        ("nested".to_string(), value),
    ]);
    let global = GlobalMetadata {
        version: 3,
        base: vec![base_entry],
        ..Default::default()
    };
    let msg = encode(&global, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();

    let expected_leaf = leaf();
    // (root key, missing-root message, location infix, leaf label)
    let checks = [
        ("depth_0", "base[0] must have depth_0", "", "leaf value"),
        ("nested", "base[0] entry must have 'nested'", " in nested", "nested leaf"),
    ];
    for (root, missing_root, location, label) in checks {
        let mut current = decoded.base[0].get(root).expect(missing_root);
        for d in 1..=max_depth {
            let key = format!("depth_{d}");
            current = cbor_map_lookup(current, &key)
                .unwrap_or_else(|| panic!("missing key {key}{location} at depth {d}"));
        }
        assert_eq!(
            current,
            &expected_leaf,
            "{label} must survive round-trip at depth {max_depth}"
        );
    }
}
/// In a buffered (`encode`) message, the preamble and the postamble both carry
/// the real message length in their `total_length` slots.
#[test]
fn buffered_postamble_total_length_equals_message_length() {
    use tensogram::wire::POSTAMBLE_SIZE;

    /// Reads the big-endian `u64` stored at `bytes[offset..offset + 8]`.
    fn be_u64_at(bytes: &[u8], offset: usize) -> u64 {
        u64::from_be_bytes(bytes[offset..offset + 8].try_into().unwrap())
    }

    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![4],
        strides: vec![4],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "none".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let payload = vec![0u8; 16];
    let msg = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();
    let expected = msg.len() as u64;

    let postamble_total = be_u64_at(&msg, msg.len() - POSTAMBLE_SIZE + 8);
    assert_eq!(postamble_total, expected, "postamble must mirror length");
    let preamble_total = be_u64_at(&msg, 16);
    assert_eq!(preamble_total, expected);
    assert_eq!(preamble_total, postamble_total);
}
/// A streaming encoder writing into a plain `Vec<u8>` (which cannot seek back)
/// leaves `total_length` at 0 in both the preamble and the postamble. The
/// resulting bytes must still decode to one object.
#[test]
fn streaming_non_seekable_zero_total_length_preamble_and_postamble() {
    use tensogram::streaming::StreamingEncoder;
    use tensogram::wire::POSTAMBLE_SIZE;

    /// Reads the big-endian `u64` stored at `bytes[offset..offset + 8]`.
    fn be_u64_at(bytes: &[u8], offset: usize) -> u64 {
        u64::from_be_bytes(bytes[offset..offset + 8].try_into().unwrap())
    }

    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![2],
        strides: vec![4],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "none".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let payload = vec![0u8; 8];
    let mut encoder =
        StreamingEncoder::new(Vec::<u8>::new(), &global, &EncodeOptions::default()).unwrap();
    encoder.write_object(&desc, &payload).unwrap();
    let bytes = encoder.finish().unwrap();

    let preamble_total = be_u64_at(&bytes, 16);
    assert_eq!(preamble_total, 0, "streaming preamble total_length must be 0");
    let postamble_total = be_u64_at(&bytes, bytes.len() - POSTAMBLE_SIZE + 8);
    assert_eq!(
        postamble_total, 0,
        "streaming postamble total_length must be 0 on non-seekable sinks"
    );

    let (_meta, objects) = decode(&bytes, &DecodeOptions::default()).unwrap();
    assert_eq!(objects.len(), 1);
}
/// With a seekable sink, `finish_with_backfill` writes the final message length
/// into both `total_length` slots (preamble and postamble). The result must
/// still decode to one object.
#[test]
fn streaming_seekable_backfill_patches_both_total_length_slots() {
    use std::io::Cursor;
    use tensogram::streaming::StreamingEncoder;
    use tensogram::wire::POSTAMBLE_SIZE;

    /// Reads the big-endian `u64` stored at `bytes[offset..offset + 8]`.
    fn be_u64_at(bytes: &[u8], offset: usize) -> u64 {
        u64::from_be_bytes(bytes[offset..offset + 8].try_into().unwrap())
    }

    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![3],
        strides: vec![4],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "none".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let payload = vec![0u8; 12];
    let mut encoder = StreamingEncoder::new(
        Cursor::new(Vec::<u8>::new()),
        &global,
        &EncodeOptions::default(),
    )
    .unwrap();
    encoder.write_object(&desc, &payload).unwrap();
    let bytes = encoder.finish_with_backfill().unwrap().into_inner();
    let expected = bytes.len() as u64;

    let preamble_total = be_u64_at(&bytes, 16);
    assert_eq!(preamble_total, expected);
    let postamble_total = be_u64_at(&bytes, bytes.len() - POSTAMBLE_SIZE + 8);
    assert_eq!(postamble_total, expected);
    assert_eq!(preamble_total, postamble_total);

    let (_meta, objects) = decode(&bytes, &DecodeOptions::default()).unwrap();
    assert_eq!(objects.len(), 1);
}
/// The final 8 bytes of a message are the end magic `39277777`, however the
/// message was produced: buffered `encode`, streaming into a `Vec`, or
/// streaming into a seekable `Cursor` with backfill.
#[test]
fn postamble_end_magic_always_at_last_8_bytes() {
    use std::io::Cursor;
    use tensogram::streaming::StreamingEncoder;

    fn assert_trailing_magic(bytes: &[u8]) {
        assert_eq!(&bytes[bytes.len() - 8..], b"39277777");
    }

    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![1],
        strides: vec![4],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "none".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let payload = vec![0u8; 4];

    let buffered = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();
    assert_trailing_magic(&buffered);

    let mut streaming =
        StreamingEncoder::new(Vec::<u8>::new(), &global, &EncodeOptions::default()).unwrap();
    streaming.write_object(&desc, &payload).unwrap();
    assert_trailing_magic(&streaming.finish().unwrap());

    let mut seekable = StreamingEncoder::new(
        Cursor::new(Vec::<u8>::new()),
        &global,
        &EncodeOptions::default(),
    )
    .unwrap();
    seekable.write_object(&desc, &payload).unwrap();
    assert_trailing_magic(&seekable.finish_with_backfill().unwrap().into_inner());
}
/// Encodes a one-object message whose payload is four little-endian `u32`s:
/// `version_tag`, `version_tag + 1`, `version_tag + 2`, `version_tag + 3`.
///
/// Different tags give byte-distinct messages, which the scan tests
/// concatenate into multi-message buffers.
fn encode_sample(version_tag: u32) -> Vec<u8> {
    let payload: Vec<u8> = (0..4u32)
        .map(|offset| version_tag + offset)
        .flat_map(u32::to_le_bytes)
        .collect();
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![4],
        strides: vec![4],
        dtype: Dtype::Uint32,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "none".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap()
}
/// Over ten concatenated messages, the default (bidirectional) scan and a
/// forward-only scan both find all ten and return identical results in the
/// same order.
#[test]
fn scan_bidirectional_matches_forward_on_10_messages() {
    let buf: Vec<u8> = (0..10u32).flat_map(|i| encode_sample(i * 100)).collect();
    let forward_only = ScanOptions {
        bidirectional: false,
        ..ScanOptions::default()
    };
    let fwd_only = scan_with_options(&buf, &forward_only);
    let bidir = scan_with_options(&buf, &ScanOptions::default());
    assert_eq!(fwd_only.len(), 10);
    assert_eq!(bidir.len(), 10);
    assert_eq!(bidir, fwd_only, "bidir scan must preserve order");
}
/// `scan` with default options must be bidirectional and must return exactly
/// what a forward-only scan returns.
///
/// Both halves of the claim are asserted explicitly:
/// - the default options enable bidirectional scanning;
/// - the results agree and are non-empty. An equality check alone would also
///   pass if both scanners found nothing.
#[test]
fn scan_default_is_bidirectional_and_equivalent_to_forward() {
    assert!(
        ScanOptions::default().bidirectional,
        "ScanOptions::default() must enable bidirectional scanning"
    );
    let mut buf = Vec::new();
    for i in 0..5 {
        buf.extend_from_slice(&encode_sample(i));
    }
    let default = scan(&buf);
    let fwd = scan_with_options(
        &buf,
        &ScanOptions {
            bidirectional: false,
            ..ScanOptions::default()
        },
    );
    assert_eq!(default.len(), 5, "all 5 messages must be found");
    assert_eq!(default, fwd);
}
/// If a message in the middle of the buffer has its postamble `total_length`
/// zeroed (as a non-seekable streaming encoder leaves it), the bidirectional
/// scan must still find all five messages and agree with the forward scan.
#[test]
fn scan_bidirectional_falls_back_on_mid_file_streaming_marker() {
    use tensogram::wire::POSTAMBLE_SIZE;
    // The tamper below assumes the postamble layout these tests rely on
    // elsewhere: 24 bytes, with the big-endian total_length at bytes 8..16.
    // Check the size up front so a layout change fails here, before the tamper.
    assert_eq!(POSTAMBLE_SIZE, 24);

    let mut msgs: Vec<Vec<u8>> = (0..5).map(encode_sample).collect();
    let mid = &mut msgs[2];
    let slot = mid.len() - POSTAMBLE_SIZE + 8;
    mid[slot..slot + 8].copy_from_slice(&0u64.to_be_bytes());

    let mut buf = Vec::new();
    for m in &msgs {
        buf.extend_from_slice(m);
    }
    let bidir = scan_with_options(&buf, &ScanOptions::default());
    let fwd = scan_with_options(
        &buf,
        &ScanOptions {
            bidirectional: false,
            ..ScanOptions::default()
        },
    );
    assert_eq!(bidir.len(), 5, "all 5 messages must be found");
    assert_eq!(bidir, fwd);
}
/// A buffer holding exactly one message yields one hit in forward-only mode,
/// and the default (bidirectional) mode returns the same result.
#[test]
fn scan_single_message_works_in_both_modes() {
    let buf = encode_sample(42);
    let forward_only = ScanOptions {
        bidirectional: false,
        ..ScanOptions::default()
    };
    let fwd = scan_with_options(&buf, &forward_only);
    let bidir = scan_with_options(&buf, &ScanOptions::default());
    assert_eq!(fwd.len(), 1);
    assert_eq!(bidir, fwd);
}
/// `framing::scan_file` over a seekable reader reports the same eight messages
/// as the in-memory `scan` over the same bytes.
#[test]
fn scan_file_bidirectional_matches_in_memory() {
    use std::io::Cursor;
    use tensogram::framing;

    let buf: Vec<u8> = (0..8u32).flat_map(|i| encode_sample(i * 17)).collect();
    let in_memory = scan(&buf);
    let mut reader = Cursor::new(buf.clone());
    let from_file = framing::scan_file(&mut reader).unwrap();
    assert_eq!(from_file, in_memory);
    assert_eq!(from_file.len(), 8);
}
/// Encodes a one-object message (three zeroed `float32`s) with otherwise
/// default options. `create_header` and `create_footer` select whether the
/// hash aggregate is requested in the header and/or the footer.
fn encode_sample_hashed(create_header: bool, create_footer: bool) -> Vec<u8> {
    let opts = EncodeOptions {
        create_header_hashes: create_header,
        create_footer_hashes: create_footer,
        ..EncodeOptions::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![3],
        strides: vec![4],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "none".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let payload = vec![0u8; 12];
    encode(&global, &[(&desc, &payload)], &opts).unwrap()
}
/// With header hashes requested and footer hashes not requested:
/// - the preamble sets HEADER_HASHES and leaves FOOTER_HASHES clear;
/// - the decoded message has one `xxh3` HashFrame with a single hash;
/// - that hash equals the inline hash slot in the data-object frame's footer.
#[test]
fn buffered_encode_default_emits_header_hash_only() {
    use tensogram::wire::{FRAME_COMMON_FOOTER_SIZE, FrameHeader, MessageFlags, Preamble};

    let msg = encode_sample_hashed(true, false);

    // "Header only" must be visible in the preamble flags, not just inferred
    // from the presence of a HashFrame.
    let preamble = Preamble::read_from(&msg).unwrap();
    assert!(
        preamble.flags.has(MessageFlags::HEADER_HASHES),
        "HEADER_HASHES flag must be set"
    );
    assert!(
        !preamble.flags.has(MessageFlags::FOOTER_HASHES),
        "FOOTER_HASHES flag must be clear"
    );

    let decoded = tensogram::framing::decode_message(&msg).unwrap();
    let hf = decoded
        .hash_frame
        .as_ref()
        .expect("HeaderHash must be emitted by default");
    assert_eq!(hf.algorithm, "xxh3");
    assert_eq!(hf.hashes.len(), 1);

    // The inline hash is the first 8 bytes of the frame's common footer.
    let (_desc, _payload, _masks, frame_offset) = &decoded.objects[0];
    let fh = FrameHeader::read_from(&msg[*frame_offset..]).unwrap();
    let frame_end = *frame_offset + fh.total_length as usize;
    let slot_start = frame_end - FRAME_COMMON_FOOTER_SIZE;
    let inline = u64::from_be_bytes(msg[slot_start..slot_start + 8].try_into().unwrap());
    assert_eq!(
        format!("{inline:016x}"),
        hf.hashes[0],
        "aggregate hex must equal the inline slot"
    );
}
/// Requesting only footer hashes sets FOOTER_HASHES and leaves HEADER_HASHES
/// clear in the preamble. The decoded message still exposes a HashFrame.
#[test]
fn buffered_encode_footer_hashes_goes_to_footer() {
    use tensogram::framing::decode_message;
    use tensogram::wire::{MessageFlags, Preamble};

    let msg = encode_sample_hashed(false, true);
    let preamble = Preamble::read_from(&msg).unwrap();
    assert!(
        preamble.flags.has(MessageFlags::FOOTER_HASHES),
        "FOOTER_HASHES flag must be set"
    );
    assert!(
        !preamble.flags.has(MessageFlags::HEADER_HASHES),
        "HEADER_HASHES flag must be clear"
    );

    let decoded = decode_message(&msg).unwrap();
    assert!(decoded.hash_frame.is_some());
}
/// Requesting both header and footer hashes sets both preamble flags.
#[test]
fn buffered_encode_both_flags_emits_both_aggregates() {
    use tensogram::wire::{MessageFlags, Preamble};

    let msg = encode_sample_hashed(true, true);
    let flags = &Preamble::read_from(&msg).unwrap().flags;
    assert!(flags.has(MessageFlags::HEADER_HASHES));
    assert!(flags.has(MessageFlags::FOOTER_HASHES));
}
/// With `hash_algorithm: None`, no hash flags appear in the preamble, even
/// though header and footer hashes were requested.
#[test]
fn buffered_encode_without_hashing_clears_aggregate() {
    use tensogram::wire::{MessageFlags, Preamble};

    let opts = EncodeOptions {
        hash_algorithm: None,
        create_header_hashes: true,
        create_footer_hashes: true,
        ..EncodeOptions::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![1],
        strides: vec![4],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "none".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let payload = vec![0u8; 4];
    let msg = encode(&global, &[(&desc, &payload)], &opts).unwrap();

    let preamble = Preamble::read_from(&msg).unwrap();
    assert!(!preamble.flags.has(MessageFlags::HASHES_PRESENT));
    assert!(!preamble.flags.has(MessageFlags::HEADER_HASHES));
    assert!(!preamble.flags.has(MessageFlags::FOOTER_HASHES));
}
/// Returns 16 bytes (128 bits) of the alternating bit pattern `0b1010_1010`.
/// The bitmask round-trip tests use it as the payload of a 128-element bitmask
/// tensor.
fn bitmask_payload_128_bits() -> Vec<u8> {
    [0xAA_u8; 128 / 8].to_vec()
}
/// A 128-element bitmask compressed with `rle` decodes back to the original
/// bytes, and the decoded descriptor still reports `rle` compression.
#[test]
fn compression_rle_round_trips_bitmask() {
    let mask = bitmask_payload_128_bits();
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![128],
        strides: vec![1],
        dtype: Dtype::Bitmask,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "rle".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let encoded = encode(&global, &[(&desc, &mask)], &EncodeOptions::default()).unwrap();
    let (_meta, objects) = decode(&encoded, &DecodeOptions::default()).unwrap();
    assert_eq!(objects[0].1, mask);
    assert_eq!(objects[0].0.compression, "rle");
}
/// A 128-element bitmask compressed with `roaring` decodes back to the
/// original bytes, and the decoded descriptor still reports `roaring`.
#[test]
fn compression_roaring_round_trips_bitmask() {
    let mask = bitmask_payload_128_bits();
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![128],
        strides: vec![1],
        dtype: Dtype::Bitmask,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "roaring".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let encoded = encode(&global, &[(&desc, &mask)], &EncodeOptions::default()).unwrap();
    let (_meta, objects) = decode(&encoded, &DecodeOptions::default()).unwrap();
    assert_eq!(objects[0].1, mask);
    assert_eq!(objects[0].0.compression, "roaring");
}
/// `rle` compression on a `float32` tensor must fail to encode. The error text
/// must name both the codec (`"rle"`) and the required `dtype=bitmask`.
#[test]
fn compression_rle_rejects_non_bitmask_dtype() {
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![4],
        strides: vec![4],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "rle".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let payload = vec![0u8; 16];
    let text = encode(&global, &[(&desc, &payload)], &EncodeOptions::default())
        .unwrap_err()
        .to_string();
    let names_codec = text.contains("\"rle\"");
    let names_dtype = text.contains("dtype=bitmask");
    assert!(
        names_codec && names_dtype,
        "expected rle-dtype error, got: {text}"
    );
}
/// `roaring` compression on a `uint32` tensor must fail to encode. The error
/// text must name both the codec (`"roaring"`) and the required
/// `dtype=bitmask`.
#[test]
fn compression_roaring_rejects_non_bitmask_dtype() {
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![2],
        strides: vec![4],
        dtype: Dtype::Uint32,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "roaring".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let payload = vec![0u8; 8];
    let text = encode(&global, &[(&desc, &payload)], &EncodeOptions::default())
        .unwrap_err()
        .to_string();
    let names_codec = text.contains("\"roaring\"");
    let names_dtype = text.contains("dtype=bitmask");
    assert!(
        names_codec && names_dtype,
        "expected roaring-dtype error, got: {text}"
    );
}
/// `max_message_size` bounds the bidirectional scanner. Both a generous cap
/// (4 GiB) and a 1-byte cap must produce exactly the forward-only result over
/// three concatenated messages.
#[test]
fn scan_max_message_size_caps_backward_walker() {
    let buf: Vec<u8> = (0..3u32).flat_map(|i| encode_sample(i * 7)).collect();
    let capped_at = |max_message_size: u64| ScanOptions {
        bidirectional: true,
        max_message_size,
    };
    let via_big = scan_with_options(&buf, &capped_at(4 << 30));
    let via_tiny = scan_with_options(&buf, &capped_at(1));
    let via_fwd = scan_with_options(
        &buf,
        &ScanOptions {
            bidirectional: false,
            ..ScanOptions::default()
        },
    );
    assert_eq!(via_big.len(), 3);
    assert_eq!(via_tiny, via_fwd);
    assert_eq!(via_big, via_fwd);
}
/// Junk is appended after three valid messages. It ends in the end magic and
/// contains length-like fields that point back into the buffer. This must not
/// confuse the bidirectional scanner: it must agree with the forward scan and
/// report exactly the three real messages.
#[test]
fn scan_bidirectional_shrugs_off_spurious_end_magic() {
    let mut buf: Vec<u8> = (0..3u32).flat_map(|i| encode_sample(i * 11)).collect();
    let good_end = buf.len() as u64;
    // Bogus 32-byte tail: 8 filler bytes, two big-endian u64 fields
    // (good_end / 3, then good_end), then the 8-byte end magic.
    buf.extend_from_slice(&[0xAB; 8]);
    buf.extend_from_slice(&(good_end / 3).to_be_bytes());
    buf.extend_from_slice(&good_end.to_be_bytes());
    buf.extend_from_slice(b"39277777");

    let bidir = scan(&buf);
    let fwd = scan_with_options(
        &buf,
        &ScanOptions {
            bidirectional: false,
            ..ScanOptions::default()
        },
    );
    assert_eq!(bidir, fwd, "bidirectional must match forward on bogus tail");
    assert_eq!(bidir.len(), 3);
}
/// A 47-byte all-zero buffer contains no message. Both the default and the
/// forward-only scan return nothing.
#[test]
fn scan_tiny_buffer_returns_empty() {
    let tiny = vec![0u8; 47];
    let forward_only = ScanOptions {
        bidirectional: false,
        ..ScanOptions::default()
    };
    assert!(scan(&tiny).is_empty());
    assert!(scan_with_options(&tiny, &forward_only).is_empty());
}
/// An empty buffer yields no messages from either the default or the
/// forward-only scan.
#[test]
fn scan_empty_buffer_returns_empty() {
    let empty: Vec<u8> = Vec::new();
    let forward_only = ScanOptions {
        bidirectional: false,
        ..ScanOptions::default()
    };
    assert!(scan(&empty).is_empty());
    assert!(scan_with_options(&empty, &forward_only).is_empty());
}
/// Two caps are tried: one equal to the message length and one a byte
/// shorter. Both must still find the single message, with identical results.
#[test]
fn scan_max_message_size_off_by_one() {
    let msg = encode_sample(1);
    let len = msg.len() as u64;
    let capped_at = |max_message_size: u64| ScanOptions {
        bidirectional: true,
        max_message_size,
    };
    let via_exact = scan_with_options(&msg, &capped_at(len));
    let via_short = scan_with_options(&msg, &capped_at(len - 1));
    assert_eq!(via_exact.len(), 1);
    assert_eq!(via_short.len(), 1);
    assert_eq!(via_exact, via_short);
}
/// One hex digit of the aggregate hash stored in the HashFrame is corrupted.
/// A plain `decode` does not notice and still returns the object, but
/// `validate_message` must report a `HashMismatch`.
#[test]
fn validate_detects_hash_frame_aggregate_tamper() {
    use tensogram::framing::{decode_message, scan};
    use tensogram::validate::{IssueCode, ValidateOptions, validate_message};
    use tensogram::wire::{FRAME_COMMON_FOOTER_SIZE, FrameHeader};
    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![4],
        strides: vec![4],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Little,
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let data = vec![0x11u8; 16];
    let mut msg = encode(
        &global,
        &[(&desc, data.as_slice())],
        &EncodeOptions {
            create_header_hashes: true,
            ..EncodeOptions::default()
        },
    )
    .unwrap();
    let messages = scan(&msg);
    let (msg_off, msg_len) = messages[0];
    let msg_slice = &msg[msg_off..msg_off + msg_len];
    let decoded = decode_message(msg_slice).unwrap();
    let (_, _, _, frame_offset) = decoded.objects[0];
    // The object's inline hash is the first 8 bytes of its frame's common footer.
    let fh = FrameHeader::read_from(&msg_slice[frame_offset..]).unwrap();
    let total = fh.total_length as usize;
    let slot_start = frame_offset + total - FRAME_COMMON_FOOTER_SIZE;
    let inline = u64::from_be_bytes(msg_slice[slot_start..slot_start + 8].try_into().unwrap());
    let inline_hex = format!("{inline:016x}");
    // The aggregate stores the same hash as lowercase hex text in the HashFrame CBOR.
    let aggregate_pos = msg
        .windows(inline_hex.len())
        .position(|w| w == inline_hex.as_bytes())
        .expect("aggregate hex must appear in HashFrame CBOR");
    // Change one digit but stay inside the lowercase hex alphabet. The aggregate
    // remains well-formed hex and only its value is wrong. A plain `+ 1` would
    // map b'9' to b':', which is not a hex digit.
    msg[aggregate_pos] = match msg[aggregate_pos] {
        b'9' => b'a',
        b'f' => b'0',
        digit => digit + 1,
    };
    let (_meta, objects) = decode(&msg, &DecodeOptions::default()).unwrap();
    assert_eq!(objects.len(), 1);
    let report = validate_message(&msg, &ValidateOptions::default());
    let any_hash_mismatch = report
        .issues
        .iter()
        .any(|i| i.code == IssueCode::HashMismatch);
    assert!(
        any_hash_mismatch,
        "aggregate HashFrame tamper must trigger HashMismatch, got: {:?}",
        report.issues
    );
}
/// When header and footer HashFrames are both emitted (with identical
/// contents), validation must raise neither `HashMismatch` nor
/// `UnknownHashAlgorithm`.
#[test]
fn validate_accepts_both_hash_aggregates_when_identical() {
    use tensogram::validate::{IssueCode, ValidateOptions, validate_message};

    let opts = EncodeOptions {
        create_header_hashes: true,
        create_footer_hashes: true,
        ..EncodeOptions::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![2],
        strides: vec![4],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "none".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let payload = vec![0u8; 8];
    let msg = encode(&global, &[(&desc, &payload)], &opts).unwrap();

    let report = validate_message(&msg, &ValidateOptions::default());
    let flagged = report.issues.iter().any(|issue| {
        matches!(
            issue.code,
            IssueCode::HashMismatch | IssueCode::UnknownHashAlgorithm
        )
    });
    assert!(
        !flagged,
        "dual aggregate HashFrames with identical hashes must not flag anything, got: {:?}",
        report.issues
    );
}
/// With default options, each data object carries an inline hash. Two objects
/// with identical descriptors and payloads get identical hashes.
#[test]
fn data_object_inline_hashes_happy_path() {
    use tensogram::framing::data_object_inline_hashes;

    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![1],
        strides: vec![4],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "none".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let payload = vec![42u8; 4];
    let pair = [(&desc, payload.as_slice()), (&desc, payload.as_slice())];
    let msg = encode(&global, &pair, &EncodeOptions::default()).unwrap();

    let hashes = data_object_inline_hashes(&msg).unwrap();
    assert_eq!(hashes.len(), 2);
    assert!(hashes.iter().all(Option::is_some));
    assert_eq!(hashes[0], hashes[1]);
}
/// With `hash_algorithm: None`, `data_object_inline_hashes` still reports one
/// entry per object, but that entry is `None`.
#[test]
fn data_object_inline_hashes_none_when_hashing_disabled() {
    use tensogram::framing::data_object_inline_hashes;

    let opts = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![1],
        strides: vec![4],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "none".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let payload = vec![0u8; 4];
    let msg = encode(&global, &[(&desc, &payload)], &opts).unwrap();

    let hashes = data_object_inline_hashes(&msg).unwrap();
    assert_eq!(hashes.len(), 1);
    assert!(hashes[0].is_none());
}
/// A message truncated to half its length (cutting through the data-object
/// frame) must make `data_object_inline_hashes` return `Err` rather than
/// guess.
#[test]
fn data_object_inline_hashes_rejects_truncated_buffer() {
    use tensogram::framing::data_object_inline_hashes;

    let desc = DataObjectDescriptor {
        obj_type: "ntensor".into(),
        ndim: 1,
        shape: vec![100],
        strides: vec![4],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Little,
        encoding: "none".into(),
        filter: "none".into(),
        compression: "none".into(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let payload = vec![0u8; 400];
    let mut msg = encode(&global, &[(&desc, &payload)], &EncodeOptions::default()).unwrap();
    let keep = msg.len() / 2;
    msg.truncate(keep);

    let result = data_object_inline_hashes(&msg);
    assert!(
        result.is_err(),
        "mid-frame truncation must return Err, got: {:?}",
        result
    );
}
/// The test overwrites the 8-byte field that sits 16 bytes before the data-object
/// frame's end marker (by this test's intent, the descriptor CBOR offset) with
/// `u64::MAX / 2`. `validate_message` must then report `CborOffsetInvalid` or
/// `DataObjectTooSmall`.
#[test]
fn validate_cbor_offset_out_of_range_detected() {
    use tensogram::validate::{IssueCode, ValidateOptions, validate_message};
    use tensogram::wire::FRAME_END;
    let (global, desc) = make_float32_descriptor(vec![2]);
    let data = vec![0u8; 8];
    let mut msg = encode(&global, &[(&desc, &data)], &EncodeOptions::default()).unwrap();
    // Start of the NTensor data-object frame: "FR" followed by frame type 0x0009.
    let marker: &[u8] = &[b'F', b'R', 0x00, 0x09];
    let frame_start = msg.windows(4).position(|w| w == marker).unwrap();
    // The frame's total length is the big-endian u64 at bytes 8..16 of its header.
    let total_length =
        u64::from_be_bytes(msg[frame_start + 8..frame_start + 16].try_into().unwrap()) as usize;
    let endf_offset = frame_start + total_length - FRAME_END.len();
    // NOTE(review): assumes the CBOR offset is the 8 bytes 16 before the end marker.
    let cbor_offset_pos = endf_offset - 16;
    msg[cbor_offset_pos..cbor_offset_pos + 8].copy_from_slice(&(u64::MAX / 2).to_be_bytes());
    let report = validate_message(&msg, &ValidateOptions::default());
    assert!(
        report.issues.iter().any(|i| matches!(
            i.code,
            IssueCode::CborOffsetInvalid | IssueCode::DataObjectTooSmall
        )),
        "expected CborOffsetInvalid or DataObjectTooSmall, got: {:?}",
        report.issues
    );
}
/// The test writes a non-zero byte into the alignment padding between two
/// NTensor data-object frames. `validate_message` must then report
/// `NonZeroPadding`.
///
/// The test asserts that padding exists before tampering. Otherwise, if no
/// padding existed, the tamper would be silently skipped and the test would
/// fail later with a misleading message.
#[test]
fn validate_non_zero_alignment_padding_detected() {
    use tensogram::validate::{IssueCode, ValidateOptions, validate_message};
    let (global, desc) = make_float32_descriptor(vec![3]);
    let data = vec![0u8; 12];
    let mut msg = encode(
        &global,
        &[(&desc, &data), (&desc, &data)],
        &EncodeOptions::default(),
    )
    .unwrap();
    // Start offsets of every NTensor frame: "FR" followed by frame type 0x0009.
    let marker: &[u8] = &[b'F', b'R', 0x00, 0x09];
    let frame_starts: Vec<usize> = msg
        .windows(marker.len())
        .enumerate()
        .filter(|(_, w)| *w == marker)
        .map(|(i, _)| i)
        .collect();
    assert!(frame_starts.len() >= 2, "expected 2 NTensorFrames");
    // The first frame's total length is the big-endian u64 at bytes 8..16 of its header.
    let first_len = u64::from_be_bytes(
        msg[frame_starts[0] + 8..frame_starts[0] + 16]
            .try_into()
            .unwrap(),
    ) as usize;
    let pad_start = frame_starts[0] + first_len;
    let pad_end = frame_starts[1];
    assert!(
        pad_end > pad_start,
        "expected alignment padding between the two NTensorFrames, found none"
    );
    msg[pad_start] = 0xFF;
    let report = validate_message(&msg, &ValidateOptions::default());
    assert!(
        report
            .issues
            .iter()
            .any(|i| i.code == IssueCode::NonZeroPadding),
        "expected NonZeroPadding warning, got: {:?}",
        report.issues
    );
}
/// `scan_file` on a real file holding one message from a non-seekable
/// streaming encoder finds that message at offset 0, spanning the whole file.
///
/// Such a message has both `total_length` fields left at 0.
///
/// The temp file name includes the process id, so concurrent test processes do
/// not race on one shared path. A drop guard removes the file even if an
/// assertion panics.
#[test]
fn scan_file_handles_streaming_message() {
    use tensogram::framing::scan_file;
    use tensogram::streaming::StreamingEncoder;

    /// Deletes the wrapped path when dropped (best effort).
    struct RemoveOnDrop(std::path::PathBuf);
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            // Ignore errors: the file may never have been created if an earlier step failed.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    let global = GlobalMetadata {
        version: 3,
        ..Default::default()
    };
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![2],
        strides: vec![4],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Little,
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let data = vec![0u8; 8];
    let mut enc =
        StreamingEncoder::new(Vec::<u8>::new(), &global, &EncodeOptions::default()).unwrap();
    enc.write_object(&desc, &data).unwrap();
    let streamed_bytes = enc.finish().unwrap();

    let path = std::env::temp_dir().join(format!(
        "tensogram_coverage_streaming_{}.tgm",
        std::process::id()
    ));
    // Declared before `f` so it is dropped after the file handle is closed.
    let _cleanup = RemoveOnDrop(path.clone());
    std::fs::write(&path, &streamed_bytes).unwrap();

    let mut f = std::fs::File::open(&path).unwrap();
    let results = scan_file(&mut f).unwrap();
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].0, 0);
    assert_eq!(results[0].1, streamed_bytes.len());
}