use std::collections::BTreeMap;
use std::path::PathBuf;
use tensogram::decode::{
DecodeOptions, decode, decode_object, decode_object_from_frame, decode_range,
decode_range_from_frame,
};
use tensogram::encode::{EncodeOptions, encode};
use tensogram::types::{ByteOrder, DataObjectDescriptor, GlobalMetadata};
use tensogram::{Dtype, TensogramError, framing, wire};
/// Directory holding the checked-in golden `.tgm` fixtures (`<crate root>/tests/golden`).
fn golden_dir() -> PathBuf {
    let manifest_dir = env!("CARGO_MANIFEST_DIR");
    PathBuf::from(manifest_dir).join("tests/golden")
}
/// Reads the golden fixture `name` from [`golden_dir`].
///
/// # Panics
///
/// Panics with `"<name>: <io error>"` if the file cannot be read.
fn read_golden(name: &str) -> Vec<u8> {
    let path = golden_dir().join(name);
    match std::fs::read(path) {
        Ok(bytes) => bytes,
        Err(e) => panic!("{name}: {e}"),
    }
}
/// Default decode options with `verify_hash` explicitly set to `false`.
fn opts_no_verify() -> DecodeOptions {
    let defaults = DecodeOptions::default();
    DecodeOptions {
        verify_hash: false,
        ..defaults
    }
}
/// Default decode options with `verify_hash` explicitly set to `true`.
fn opts_verify() -> DecodeOptions {
    let defaults = DecodeOptions::default();
    DecodeOptions {
        verify_hash: true,
        ..defaults
    }
}
/// Encodes a one-object message with `EncodeOptions { hashing: false, .. }`.
///
/// The single object is a 1-D `ntensor` of shape `[4]` holding
/// `[1.0, 2.0, 3.0, 4.0]` as big-endian `f32`, with no encoding, filter or
/// compression. Default global metadata is used.
fn build_unhashed_single_object_message() -> Vec<u8> {
    // Raw big-endian f32 bytes, matching `byte_order: ByteOrder::Big` below.
    let payload: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
        .iter()
        .flat_map(|v| v.to_be_bytes())
        .collect();

    let descriptor = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![4],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::Big,
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let encode_opts = EncodeOptions {
        hashing: false,
        ..Default::default()
    };

    encode(
        &GlobalMetadata::default(),
        &[(&descriptor, &payload)],
        &encode_opts,
    )
    .unwrap()
}
/// Locates the `object_index`-th data-object frame in a single-message buffer.
///
/// Returns `(frame_start, slot_offset_within_frame)`:
/// - `frame_start` is the absolute offset in `buf` of the frame's `FR` marker;
/// - `slot_offset_within_frame` is the frame's `total_length` minus
///   `wire::FRAME_COMMON_FOOTER_SIZE`, i.e. where the common footer begins
///   (the tamper tests treat that byte as the hash slot).
///
/// # Panics
///
/// - `buf` does not contain exactly one message;
/// - a frame header cannot be parsed;
/// - the target frame's `total_length` is smaller than the common footer;
/// - fewer than `object_index + 1` data-object frames are found.
fn locate_object_frame(buf: &[u8], object_index: usize) -> (usize, usize) {
    // Offset within the message where frame scanning starts.
    // NOTE(review): presumably the fixed preamble length — confirm against `wire`.
    const FIRST_FRAME_OFFSET: usize = 24;

    let messages = framing::scan(buf);
    assert_eq!(
        messages.len(),
        1,
        "this test helper assumes single-message buffers"
    );
    let (msg_offset, msg_len) = messages[0];
    let msg = &buf[msg_offset..msg_offset + msg_len];

    let mut pos = FIRST_FRAME_OFFSET;
    let mut found_objects = 0;
    while pos + wire::FRAME_HEADER_SIZE <= msg.len() {
        // Not on a frame marker: resync byte-by-byte (e.g. across alignment padding).
        if &msg[pos..pos + 2] != b"FR" {
            pos += 1;
            continue;
        }
        let fh = wire::FrameHeader::read_from(&msg[pos..])
            .unwrap_or_else(|e| panic!("frame header at message offset {pos}: {e:?}"));
        let frame_len = fh.total_length as usize;
        if fh.frame_type.is_data_object() {
            if found_objects == object_index {
                let slot_offset_within_frame = frame_len
                    .checked_sub(wire::FRAME_COMMON_FOOTER_SIZE)
                    .unwrap_or_else(|| {
                        panic!(
                            "data-object frame {object_index} has total_length {frame_len}, \
                             shorter than the common footer ({} bytes)",
                            wire::FRAME_COMMON_FOOTER_SIZE
                        )
                    });
                return (msg_offset + pos, slot_offset_within_frame);
            }
            found_objects += 1;
        }
        // Skip to the end of this frame, rounded up to the next 8-byte boundary.
        let next = (pos + frame_len + 7) & !7;
        // A zero/short `total_length` at an aligned `pos` would give `next == pos`
        // and spin forever; always advance by at least one byte so a corrupt header
        // cannot hang the test.
        pos = next.max(pos + 1).min(msg.len());
    }
    panic!(
        "data-object frame index {object_index} not found in buffer (have {found_objects} objects)"
    );
}
/// Absolute offset in `buf` of the first byte after the frame header of
/// data-object frame `object_index`. The tamper tests flip this byte as the
/// start of the object's payload.
fn locate_payload_start(buf: &[u8], object_index: usize) -> usize {
    let frame_start = locate_object_frame(buf, object_index).0;
    wire::FRAME_HEADER_SIZE + frame_start
}
/// Cell A: a hashed message decodes without verification and yields one object.
#[test]
fn cell_a_decode_no_verify_succeeds_on_hashed_message() {
    let message = read_golden("hash_xxh3.tgm");
    let decoded_objects = decode(&message, &opts_no_verify()).unwrap().1;
    assert_eq!(decoded_objects.len(), 1);
}
/// Cell B: a hashed message decodes with verification on and yields one object.
#[test]
fn cell_b_decode_with_verify_succeeds_on_hashed_message() {
    let message = read_golden("hash_xxh3.tgm");
    let decoded_objects = decode(&message, &opts_verify()).unwrap().1;
    assert_eq!(decoded_objects.len(), 1);
}
/// Cell B (single object): `decode_object(.., 0, ..)` verifies on a hashed message.
#[test]
fn cell_b_decode_object_with_verify_succeeds_on_hashed_message() {
    let message = read_golden("hash_xxh3.tgm");
    decode_object(&message, 0, &opts_verify()).unwrap();
}
/// Cell C: verifying an unhashed message must fail with `MissingHash` on object 0.
#[test]
fn cell_c_decode_verify_on_unhashed_message_returns_missing_hash() {
    let unhashed = build_unhashed_single_object_message();
    match decode(&unhashed, &opts_verify()).unwrap_err() {
        TensogramError::MissingHash { object_index: idx } => {
            assert_eq!(idx, 0, "single-object message must surface index 0");
        }
        unexpected => panic!("expected MissingHash, got: {unexpected:?}"),
    }
}
/// Cell C (single object): `decode_object` with verification on an unhashed
/// message must fail with `MissingHash { object_index: 0 }`.
#[test]
fn cell_c_decode_object_verify_on_unhashed_message_returns_missing_hash() {
    let unhashed = build_unhashed_single_object_message();
    match decode_object(&unhashed, 0, &opts_verify()).unwrap_err() {
        TensogramError::MissingHash { object_index: idx } => assert_eq!(idx, 0),
        unexpected => panic!("expected MissingHash, got: {unexpected:?}"),
    }
}
/// Cell C (no verify): an unhashed message decodes fine when verification is off.
#[test]
fn cell_c_no_verify_silently_decodes_unhashed_message() {
    let unhashed = build_unhashed_single_object_message();
    let decoded_objects = decode(&unhashed, &opts_no_verify()).unwrap().1;
    assert_eq!(decoded_objects.len(), 1);
}
/// Cell D: flipping the hash-slot byte of object 0 must make `decode` report
/// `HashMismatch` for object 0 with differing expected/actual values.
#[test]
fn cell_d_decode_verify_reports_hash_mismatch_on_tampered_slot() {
    let mut message = read_golden("hash_xxh3.tgm");
    let (frame_start, slot_offset) = locate_object_frame(&message, 0);
    let slot_byte = frame_start + slot_offset;
    message[slot_byte] ^= 0xFF;
    match decode(&message, &opts_verify()).unwrap_err() {
        TensogramError::HashMismatch {
            object_index: idx,
            expected: want,
            actual: got,
        } => {
            assert_eq!(idx, Some(0));
            assert_ne!(want, got);
        }
        unexpected => panic!("expected HashMismatch, got: {unexpected:?}"),
    }
}
/// Cell D (single object): same tampered hash slot, observed via `decode_object`.
#[test]
fn cell_d_decode_object_verify_reports_hash_mismatch_on_tampered_slot() {
    let mut message = read_golden("hash_xxh3.tgm");
    let (frame_start, slot_offset) = locate_object_frame(&message, 0);
    let slot_byte = frame_start + slot_offset;
    message[slot_byte] ^= 0xFF;
    match decode_object(&message, 0, &opts_verify()).unwrap_err() {
        TensogramError::HashMismatch {
            object_index: idx,
            expected: want,
            actual: got,
        } => {
            assert_eq!(idx, Some(0));
            assert_ne!(want, got);
        }
        unexpected => panic!("expected HashMismatch, got: {unexpected:?}"),
    }
}
/// Cell E: flipping the first payload byte of object 0 must make `decode`
/// report `HashMismatch` for object 0.
#[test]
fn cell_e_decode_verify_reports_hash_mismatch_on_tampered_payload() {
    let mut message = read_golden("hash_xxh3.tgm");
    let first_payload_byte = locate_payload_start(&message, 0);
    message[first_payload_byte] ^= 0xFF;
    match decode(&message, &opts_verify()).unwrap_err() {
        TensogramError::HashMismatch {
            object_index: idx,
            expected: want,
            actual: got,
        } => {
            assert_eq!(idx, Some(0));
            assert_ne!(want, got);
        }
        unexpected => panic!("expected HashMismatch, got: {unexpected:?}"),
    }
}
/// Cell F: in a multi-object message, tampering object 1's payload must make
/// `decode` report `HashMismatch` with `object_index == Some(1)`.
#[test]
fn cell_f_decode_verify_reports_correct_object_index_on_multi_object() {
    let mut message = read_golden("multi_object_xxh3.tgm");
    let tampered_byte = locate_payload_start(&message, 1);
    message[tampered_byte] ^= 0xFF;
    match decode(&message, &opts_verify()).unwrap_err() {
        TensogramError::HashMismatch {
            object_index: idx,
            expected: want,
            actual: got,
        } => {
            assert_eq!(
                idx,
                Some(1),
                "must surface the *tampered* object's index, not 0"
            );
            assert_ne!(want, got);
        }
        unexpected => panic!("expected HashMismatch on object 1, got: {unexpected:?}"),
    }
}
/// Cell F (single object): with object 1 tampered, `decode_object` still
/// verifies objects 0 and 2 but reports `HashMismatch` for object 1.
#[test]
fn cell_f_decode_object_verify_targets_specific_object() {
    let mut message = read_golden("multi_object_xxh3.tgm");
    let tampered_byte = locate_payload_start(&message, 1);
    message[tampered_byte] ^= 0xFF;

    // The untouched neighbours must be unaffected by the tampering.
    for untouched in [0, 2] {
        decode_object(&message, untouched, &opts_verify()).unwrap();
    }

    match decode_object(&message, 1, &opts_verify()).unwrap_err() {
        TensogramError::HashMismatch {
            object_index: idx, ..
        } => assert_eq!(idx, Some(1)),
        unexpected => panic!("expected HashMismatch, got: {unexpected:?}"),
    }
}
/// Cell F (flag): clearing `HASH_PRESENT` on object 1's frame leaves objects
/// 0 and 2 verifiable, while full-message `decode` reports
/// `MissingHash { object_index: 1 }`.
#[test]
fn cell_f_clearing_per_frame_flag_yields_missing_hash_with_correct_index() {
    let mut message = read_golden("multi_object_xxh3.tgm");
    let frame_start = locate_object_frame(&message, 1).0;

    // The flags are read/written as a big-endian u16 at byte 6 of the frame
    // header. NOTE(review): offset 6 is hard-coded here — confirm against `wire`.
    let flags_at = frame_start + 6;
    let raw_flags: [u8; 2] = message[flags_at..flags_at + 2].try_into().unwrap();
    let mut patched = u16::from_be_bytes(raw_flags);
    patched &= !wire::FrameFlags::HASH_PRESENT;
    message[flags_at..flags_at + 2].copy_from_slice(&patched.to_be_bytes());

    decode_object(&message, 0, &opts_verify()).unwrap();
    decode_object(&message, 2, &opts_verify()).unwrap();

    match decode(&message, &opts_verify()).unwrap_err() {
        TensogramError::MissingHash { object_index: idx } => assert_eq!(idx, 1),
        unexpected => panic!("expected MissingHash on object 1, got: {unexpected:?}"),
    }
}
/// `decode_range` with `verify_hash = true` still succeeds on `simple_f32.tgm`
/// and returns one part for one requested range.
#[test]
fn decode_range_ignores_verify_hash_on_unhashed_message() {
    let simple = read_golden("simple_f32.tgm");
    let range_parts = decode_range(&simple, 0, &[(0, 4)], &opts_verify())
        .unwrap()
        .1;
    assert_eq!(range_parts.len(), 1);
}
/// `decode_range` must never surface `HashMismatch`/`MissingHash`, even on a
/// tampered payload with `verify_hash = true`. Success and any other error are
/// both accepted.
#[test]
fn decode_range_ignores_verify_hash_on_tampered_payload() {
    let mut message = read_golden("hash_xxh3.tgm");
    let first_payload_byte = locate_payload_start(&message, 0);
    message[first_payload_byte] ^= 0xFF;
    let outcome = decode_range(&message, 0, &[(0, 1)], &opts_verify());
    let is_verify_error = matches!(
        outcome,
        Err(TensogramError::HashMismatch { .. }) | Err(TensogramError::MissingHash { .. })
    );
    if is_verify_error {
        panic!(
            "decode_range must not surface verify_hash errors — verify_hash is documented as ignored"
        );
    }
}
/// Cell E variant: flipping a byte 32 bytes before the end of object 0's frame
/// (checked to lie past the 16-byte f32 payload) must still surface as
/// `HashMismatch` for object 0, i.e. verification runs first.
#[test]
fn cell_e_variant_decode_verify_reports_hash_mismatch_on_tampered_cbor() {
    let mut message = read_golden("hash_xxh3.tgm");
    let (frame_start, slot_offset) = locate_object_frame(&message, 0);
    let frame_end = frame_start + slot_offset + wire::FRAME_COMMON_FOOTER_SIZE;
    let payload_end = frame_start + wire::FRAME_HEADER_SIZE + 16;
    let cbor_byte = frame_end - 32;
    assert!(
        cbor_byte > payload_end,
        "tamper offset must land past the 16-byte encoded f32 payload region"
    );
    message[cbor_byte] ^= 0xff;
    match decode(&message, &opts_verify()).unwrap_err() {
        TensogramError::HashMismatch {
            object_index: idx, ..
        } => assert_eq!(idx, Some(0)),
        unexpected => panic!(
            "expected HashMismatch (verify-first ordering) on tampered \
             CBOR byte, got: {unexpected:?}"
        ),
    }
}
/// `decode` on `[hashed message][unhashed message]` must verify only the first message.
///
/// The test first establishes two baselines:
/// - the hashed message verifies on its own;
/// - the unhashed message fails with `MissingHash { object_index: 0 }`.
///
/// It then checks that their concatenation decodes successfully with
/// `verify_hash = true` and yields exactly the first message's single object.
#[test]
fn decode_verify_respects_first_message_total_length_on_concat() {
    let hashed = read_golden("hash_xxh3.tgm");
    let unhashed = build_unhashed_single_object_message();
    decode(&hashed, &opts_verify()).expect("hashed message verifies");
    let unhashed_err = decode(&unhashed, &opts_verify()).unwrap_err();
    assert!(
        matches!(
            unhashed_err,
            TensogramError::MissingHash { object_index: 0 }
        ),
        "baseline: unhashed message must fail verify with MissingHash, got {unhashed_err:?}"
    );
    // `hashed` is not used again, so take ownership instead of cloning it.
    let mut concat = hashed;
    concat.extend_from_slice(&unhashed);
    let (_meta, objects) = decode(&concat, &opts_verify()).unwrap_or_else(|e| {
        panic!(
            "decode(concat[hashed,unhashed], verify_hash=true) must \
             not leak into msg2: {e:?}"
        )
    });
    assert_eq!(
        objects.len(),
        1,
        "decode is bounded to first message; helper must match"
    );
}
/// `decode_object(.., 0, ..)` on `[hashed message][unhashed message]` must
/// verify the first message's object 0 and not trip over the trailing
/// unhashed message.
#[test]
fn decode_object_verify_respects_first_message_total_length_on_concat() {
    let hashed = read_golden("hash_xxh3.tgm");
    let unhashed = build_unhashed_single_object_message();
    // `hashed` is not used again, so take ownership instead of cloning it.
    let mut concat = hashed;
    concat.extend_from_slice(&unhashed);
    decode_object(&concat, 0, &opts_verify())
        .expect("decode_object(0) on concat must verify msg1's first object");
}
/// A 10-byte all-zero buffer must fail with a `Framing` error whose message
/// mentions "preamble" or "too short".
#[test]
fn verify_hash_on_buffer_shorter_than_preamble_returns_framing_error() {
    let tiny = [0u8; 10].to_vec();
    let err = decode(&tiny, &opts_verify()).unwrap_err();
    let TensogramError::Framing(msg) = &err else {
        panic!("expected Framing error, got: {err:?}");
    };
    let mentions_preamble = ["preamble", "too short"]
        .iter()
        .any(|needle| msg.contains(*needle));
    assert!(
        mentions_preamble,
        "expected preamble/too-short error, got: {msg}"
    );
}
/// An empty buffer must fail with a `Framing` error when verifying.
#[test]
fn verify_hash_on_empty_buffer_returns_framing_error() {
    let empty = Vec::<u8>::new();
    let err = decode(&empty, &opts_verify()).unwrap_err();
    let is_framing = matches!(&err, TensogramError::Framing(_));
    assert!(
        is_framing,
        "expected Framing error on empty buffer, got: {err:?}"
    );
}
/// Copies data-object frame `object_index` out of `buf`. The copy runs from the
/// frame's `FR` marker through the end of its common footer.
fn extract_raw_object_frame(buf: &[u8], object_index: usize) -> Vec<u8> {
    let (start, footer_offset) = locate_object_frame(buf, object_index);
    let end = start + footer_offset + wire::FRAME_COMMON_FOOTER_SIZE;
    Vec::from(&buf[start..end])
}
/// A standalone hashed frame verifies. It decodes to shape `[4]` with a
/// non-empty payload.
#[test]
fn decode_object_from_frame_verify_hash_succeeds_on_hashed_frame() {
    let golden = read_golden("hash_xxh3.tgm");
    let frame = extract_raw_object_frame(&golden, 0);
    let (descriptor, payload) = decode_object_from_frame(&frame, &opts_verify())
        .expect("verify_hash on a valid hashed frame must succeed");
    assert_eq!(descriptor.shape, vec![4]);
    assert!(!payload.is_empty());
}
/// A standalone unhashed frame must fail verification with
/// `MissingHash { object_index: 0 }`.
#[test]
fn decode_object_from_frame_verify_hash_rejects_unhashed_frame() {
    let unhashed = build_unhashed_single_object_message();
    let frame = extract_raw_object_frame(&unhashed, 0);
    match decode_object_from_frame(&frame, &opts_verify()).unwrap_err() {
        TensogramError::MissingHash { object_index: idx } => assert_eq!(idx, 0),
        unexpected => panic!("expected MissingHash, got: {unexpected:?}"),
    }
}
/// Flipping the first byte after the header of a standalone hashed frame must
/// fail verification with `HashMismatch`.
#[test]
fn decode_object_from_frame_verify_hash_rejects_tampered_payload() {
    let golden = read_golden("hash_xxh3.tgm");
    let mut frame = extract_raw_object_frame(&golden, 0);
    let first_payload_byte = wire::FRAME_HEADER_SIZE;
    frame[first_payload_byte] ^= 0xFF;
    let err = decode_object_from_frame(&frame, &opts_verify()).unwrap_err();
    let is_mismatch = matches!(err, TensogramError::HashMismatch { .. });
    assert!(is_mismatch, "expected HashMismatch, got: {err:?}");
}
/// With verification off, a standalone unhashed frame decodes to shape `[4]`
/// with a non-empty payload.
#[test]
fn decode_object_from_frame_no_verify_succeeds_on_unhashed_frame() {
    let unhashed = build_unhashed_single_object_message();
    let frame = extract_raw_object_frame(&unhashed, 0);
    let (descriptor, payload) = decode_object_from_frame(&frame, &opts_no_verify())
        .expect("no-verify on unhashed frame must succeed");
    assert_eq!(descriptor.shape, vec![4]);
    assert!(!payload.is_empty());
}
/// `decode_range_from_frame` must honour `restore_non_finite`.
///
/// Setup:
/// - eight native-endian `f64` values, with NaN at index 1 and `+Inf` at index 3;
/// - encoded with `allow_nan`, `allow_inf`, `hashing` and `small_mask_threshold_bytes: 0`;
/// - the raw data-object frame is then extracted.
///
/// Decoding elements `(0, 4)` from that frame must:
/// - return a descriptor that carries masks;
/// - restore NaN and `+Inf` when `restore_non_finite: true`;
/// - read `0.0` at those positions when `restore_non_finite: false`.
#[test]
fn decode_range_from_frame_restores_non_finite_masks() {
    /// Reinterprets a native-endian byte buffer as consecutive `f64` values.
    fn as_f64s(bytes: &[u8]) -> Vec<f64> {
        bytes
            .chunks_exact(8)
            .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
            .collect()
    }

    let values: Vec<f64> = vec![1.0, f64::NAN, 3.0, f64::INFINITY, 5.0, 6.0, 7.0, 8.0];
    let data: Vec<u8> = values.iter().flat_map(|v| v.to_ne_bytes()).collect();
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![8],
        strides: vec![1],
        dtype: Dtype::Float64,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let enc_opts = EncodeOptions {
        allow_nan: true,
        allow_inf: true,
        hashing: true,
        small_mask_threshold_bytes: 0,
        ..Default::default()
    };
    let msg = encode(&GlobalMetadata::default(), &[(&desc, &data)], &enc_opts).unwrap();
    let frame = extract_raw_object_frame(&msg, 0);

    // restore_non_finite = true: NaN / +Inf come back at their original positions.
    let opts_restore = DecodeOptions {
        restore_non_finite: true,
        ..Default::default()
    };
    let (ret_desc, parts) = decode_range_from_frame(&frame, &[(0, 4)], &opts_restore).unwrap();
    assert!(
        ret_desc.masks.is_some(),
        "descriptor must carry masks for NaN/Inf data"
    );
    assert_eq!(parts.len(), 1);
    let decoded = as_f64s(&parts[0]);
    assert_eq!(
        decoded.len(),
        4,
        "range (0, 4) must yield exactly four f64 values"
    );
    assert_eq!(decoded[0], 1.0);
    assert!(decoded[1].is_nan(), "NaN must be restored");
    assert_eq!(decoded[2], 3.0);
    assert!(
        decoded[3].is_infinite() && decoded[3] > 0.0,
        "+Inf must be restored"
    );

    // restore_non_finite = false: masked positions read as 0.0.
    let opts_no_restore = DecodeOptions {
        restore_non_finite: false,
        ..Default::default()
    };
    let (_, parts_no) = decode_range_from_frame(&frame, &[(0, 4)], &opts_no_restore).unwrap();
    assert_eq!(parts_no.len(), 1);
    let decoded_no = as_f64s(&parts_no[0]);
    assert_eq!(
        decoded_no.len(),
        4,
        "range (0, 4) must yield exactly four f64 values"
    );
    assert_eq!(decoded_no[0], 1.0);
    assert_eq!(decoded_no[1], 0.0, "NaN should be 0.0 when restore is off");
    assert_eq!(decoded_no[2], 3.0);
    assert_eq!(decoded_no[3], 0.0, "Inf should be 0.0 when restore is off");
}