use std::collections::BTreeMap;
use std::io::Cursor;
use tensogram::framing::{decode_message, scan};
use tensogram::hash::{HashAlgorithm, verify_frame_hash};
use tensogram::streaming::StreamingEncoder;
use tensogram::wire::{FRAME_COMMON_FOOTER_SIZE, FrameHeader, MessageFlags, Preamble};
use tensogram::{
ByteOrder, DataObjectDescriptor, DecodeOptions, Dtype, EncodeOptions, GlobalMetadata, decode,
encode,
};
/// Test fixture: descriptor for a one-dimensional, four-element `Float32` object in native byte
/// order with encoding, filter, and compression all set to `"none"`, no params, and no masks.
fn simple_desc() -> DataObjectDescriptor {
    let passthrough = || String::from("none");
    DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![4],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: passthrough(),
        filter: passthrough(),
        compression: passthrough(),
        params: BTreeMap::new(),
        masks: None,
    }
}
fn simple_meta() -> GlobalMetadata {
GlobalMetadata::default()
}
/// Test fixture: the `f32` values 1.0, 2.0, 3.0, 4.0 serialised back to back as native-endian
/// bytes (16 bytes total).
fn simple_data() -> Vec<u8> {
    const SAMPLES: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let mut bytes = Vec::with_capacity(SAMPLES.len() * std::mem::size_of::<f32>());
    for sample in SAMPLES.iter() {
        bytes.extend_from_slice(&sample.to_ne_bytes());
    }
    bytes
}
/// Reads the inline hash slot of the first object frame found in `wire`.
///
/// The first message reported by `scan` is decoded, the first object's frame header is parsed,
/// and the 8-byte big-endian value that starts `FRAME_COMMON_FOOTER_SIZE` bytes before the end
/// of that frame (as given by the header's `total_length`) is returned.
///
/// Returns `None` when the slot holds zero. It also returns `None` — rather than panicking —
/// when no message is found, decoding or header parsing fails, the message has no objects, or
/// any offset/length falls outside the buffer.
fn first_inline_hash(wire: &[u8]) -> Option<u64> {
    let messages = scan(wire);
    let &(offset, len) = messages.first()?;
    // Checked slicing: a bogus (offset, len) pair yields `None` instead of an index panic.
    let msg = wire.get(offset..offset.checked_add(len)?)?;
    let decoded = decode_message(msg).ok()?;
    let (_, _, _, frame_offset) = decoded.objects.first()?;
    let frame = msg.get(*frame_offset..)?;
    let fh = FrameHeader::read_from(frame).ok()?;
    // Avoid a silently truncating `as` cast on the declared frame length.
    let total = usize::try_from(fh.total_length).ok()?;
    // A frame shorter than its own footer is malformed; `checked_sub` prevents underflow.
    let slot_start = total.checked_sub(FRAME_COMMON_FOOTER_SIZE)?;
    let slot_bytes = frame.get(slot_start..slot_start.checked_add(8)?)?;
    let slot = u64::from_be_bytes(slot_bytes.try_into().ok()?);
    if slot == 0 { None } else { Some(slot) }
}
/// Buffered `encode` with default options must set `HASHES_PRESENT` in the preamble and leave a
/// non-zero value in the first frame's inline hash slot.
#[test]
fn buffered_encode_populates_inline_slot() {
    let meta = simple_meta();
    let desc = simple_desc();
    let payload = simple_data();
    let objects = [(&desc, payload.as_slice())];
    let encoded = encode(&meta, &objects, &EncodeOptions::default()).unwrap();

    let header = Preamble::read_from(&encoded).unwrap();
    assert!(header.flags.has(MessageFlags::HASHES_PRESENT));

    let inline = first_inline_hash(&encoded).expect("inline slot must be populated");
    assert_ne!(inline, 0);
}
/// Every object frame produced by buffered `encode` must pass `verify_frame_hash` when handed
/// exactly `total_length` bytes of the frame.
#[test]
fn buffered_encode_inline_slot_verifies_against_body() {
    let msg = encode(
        &simple_meta(),
        &[(&simple_desc(), simple_data().as_slice())],
        &EncodeOptions::default(),
    )
    .unwrap();
    let messages = scan(&msg);
    let (offset, len) = *messages
        .first()
        .expect("scan must locate the encoded message");
    let only = &msg[offset..offset + len];
    let decoded = decode_message(only).unwrap();
    // Without this guard the loop below would pass vacuously if no object frames were decoded.
    assert!(
        !decoded.objects.is_empty(),
        "decoded message must contain at least one object frame"
    );
    for (_, _, _, frame_offset) in &decoded.objects {
        let frame = &only[*frame_offset..];
        let fh = FrameHeader::read_from(frame).unwrap();
        // Trim to the frame's declared length so trailing message bytes are excluded.
        let frame_bytes = &frame[..fh.total_length as usize];
        verify_frame_hash(frame_bytes, fh.frame_type)
            .expect("buffered inline slot must verify against body");
    }
}
/// The inline hash of the first frame must be identical whether the same object is written via
/// buffered `encode` or via `StreamingEncoder` finished with `finish_with_backfill`.
#[test]
fn streaming_encode_matches_buffered_inline_hash() {
    let meta = simple_meta();
    let desc = simple_desc();
    let payload = simple_data();

    // Buffered path.
    let buffered_hash = {
        let wire = encode(
            &meta,
            &[(&desc, payload.as_slice())],
            &EncodeOptions::default(),
        )
        .unwrap();
        first_inline_hash(&wire).expect("buffered hash")
    };

    // Streaming path into an in-memory cursor.
    let streamed_hash = {
        let sink = Cursor::new(Vec::<u8>::new());
        let mut encoder = StreamingEncoder::new(sink, &meta, &EncodeOptions::default()).unwrap();
        encoder.write_object(&desc, &payload).unwrap();
        let wire = encoder.finish_with_backfill().unwrap().into_inner();
        first_inline_hash(&wire).expect("streaming hash")
    };

    assert_eq!(
        buffered_hash, streamed_hash,
        "buffered and streaming inline hashes must match for the same input"
    );
}
/// With `hash_algorithm: None` the preamble must not advertise hashes, the inline slot must read
/// as absent, and the payload must still decode back to the original bytes.
#[test]
fn hash_algorithm_none_clears_flag_and_zeros_slot() {
    let no_hash = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };
    let meta = simple_meta();
    let desc = simple_desc();
    let payload = simple_data();
    let encoded = encode(&meta, &[(&desc, payload.as_slice())], &no_hash).unwrap();

    let header = Preamble::read_from(&encoded).unwrap();
    let hashes_flagged = header.flags.has(MessageFlags::HASHES_PRESENT);
    assert!(
        !hashes_flagged,
        "hash_algorithm = None must clear HASHES_PRESENT"
    );

    let inline = first_inline_hash(&encoded);
    assert!(
        inline.is_none(),
        "hash_algorithm = None must leave the inline slot at zero"
    );

    // The message must remain decodable without hashes.
    let (_meta, objects) = decode(&encoded, &DecodeOptions::default()).unwrap();
    assert_eq!(objects[0].1, simple_data());
}
/// `HashAlgorithm::Xxh3` maps to and from the string `"xxh3"`; `"sha256"` is rejected by `parse`.
#[test]
fn hash_algorithm_enum_parses_roundtrip() {
    let rendered = HashAlgorithm::Xxh3.as_str();
    assert_eq!(rendered, "xxh3");

    let parsed = HashAlgorithm::parse("xxh3").unwrap();
    assert_eq!(parsed, HashAlgorithm::Xxh3);

    let rejected = HashAlgorithm::parse("sha256");
    assert!(rejected.is_err());
}