use std::collections::BTreeMap;
use crate::error::{Result, TensogramError};
use crate::metadata::{self, RESERVED_KEY};
use crate::types::{DataObjectDescriptor, GlobalMetadata, HashFrame, IndexFrame};
use crate::wire::{
DATA_OBJECT_FOOTER_SIZE, DataObjectFlags, FRAME_END, FRAME_HEADER_SIZE, FrameHeader, FrameType,
MAGIC, MessageFlags, POSTAMBLE_SIZE, PREAMBLE_SIZE, Postamble, Preamble,
};
/// Appends one generic frame to `out`: frame header, `payload`, end marker.
///
/// The header's `total_length` covers the header, the payload and the end
/// marker, but not the alignment padding. When `align` is true, zero bytes
/// are appended afterwards until `out.len()` is a multiple of 8.
fn write_frame(
    out: &mut Vec<u8>,
    frame_type: FrameType,
    version: u16,
    flags: u16,
    payload: &[u8],
    align: bool,
) {
    let unpadded_len = FRAME_HEADER_SIZE + payload.len() + FRAME_END.len();
    FrameHeader {
        frame_type,
        version,
        flags,
        total_length: unpadded_len as u64,
    }
    .write_to(out);

    out.extend_from_slice(payload);
    out.extend_from_slice(FRAME_END);

    if align {
        // Zero-fill up to the next 8-byte boundary of the output buffer.
        let padded_len = out.len().next_multiple_of(8);
        out.resize(padded_len, 0);
    }
}
/// Parses the generic frame that starts at the beginning of `buf`.
///
/// Returns the frame header, the payload slice (the bytes between the header
/// and the end marker) and the number of bytes consumed. The consumed count
/// is the frame length rounded up to a multiple of 8 when that still fits
/// inside `buf`, otherwise the unpadded frame length.
///
/// # Errors
///
/// Returns a `Framing` error when `total_length` does not fit in `usize`, is
/// below the minimum frame size, exceeds `buf`, or when the end marker is
/// missing.
fn read_frame(buf: &[u8]) -> Result<(FrameHeader, &[u8], usize)> {
    let fh = FrameHeader::read_from(buf)?;

    let frame_total = match usize::try_from(fh.total_length) {
        Ok(len) => len,
        Err(_) => {
            return Err(TensogramError::Framing(format!(
                "frame total_length {} overflows usize",
                fh.total_length
            )));
        }
    };

    let min_frame_size = FRAME_HEADER_SIZE + FRAME_END.len();
    if frame_total < min_frame_size {
        return Err(TensogramError::Framing(format!(
            "frame total_length {} is smaller than minimum {min_frame_size}",
            frame_total
        )));
    }
    if frame_total > buf.len() {
        return Err(TensogramError::Framing(format!(
            "frame total_length {} exceeds buffer: {}",
            frame_total,
            buf.len()
        )));
    }

    // The frame must finish with the end marker.
    let endf_start = frame_total - FRAME_END.len();
    if !buf[..frame_total].ends_with(FRAME_END) {
        return Err(TensogramError::Framing(format!(
            "missing ENDF marker at offset {endf_start}"
        )));
    }

    // Skip trailing alignment padding only when the buffer actually holds it.
    let aligned = frame_total.next_multiple_of(8);
    let consumed = if aligned <= buf.len() {
        aligned
    } else {
        frame_total
    };

    Ok((fh, &buf[FRAME_HEADER_SIZE..endf_start], consumed))
}
/// Serialises one data object into a standalone (unpadded) frame.
///
/// Layout: frame header, then the CBOR descriptor and the payload (CBOR first
/// when `cbor_before` is true, payload first otherwise), then the big-endian
/// `u64` offset of the CBOR descriptor relative to the frame start, then the
/// end marker. When the CBOR follows the payload the
/// `DataObjectFlags::CBOR_AFTER_PAYLOAD` flag is set in the header.
///
/// # Errors
///
/// Propagates any error from encoding `descriptor` to CBOR.
pub fn encode_data_object_frame(
    descriptor: &DataObjectDescriptor,
    payload: &[u8],
    cbor_before: bool,
) -> Result<Vec<u8>> {
    let cbor_bytes = metadata::object_descriptor_to_cbor(descriptor)?;

    // Decide the body ordering once; everything below is layout-agnostic.
    let (flags, leading, trailing, cbor_offset): (_, &[u8], &[u8], usize) = if cbor_before {
        (0, &cbor_bytes[..], payload, FRAME_HEADER_SIZE)
    } else {
        (
            DataObjectFlags::CBOR_AFTER_PAYLOAD,
            payload,
            &cbor_bytes[..],
            FRAME_HEADER_SIZE + payload.len(),
        )
    };

    let body_len = cbor_bytes.len() + payload.len() + DATA_OBJECT_FOOTER_SIZE;
    let total_length = (FRAME_HEADER_SIZE + body_len) as u64;

    let mut out = Vec::with_capacity(total_length as usize);
    FrameHeader {
        frame_type: FrameType::DataObject,
        version: 1,
        flags,
        total_length,
    }
    .write_to(&mut out);

    out.extend_from_slice(leading);
    out.extend_from_slice(trailing);
    out.extend_from_slice(&(cbor_offset as u64).to_be_bytes());
    out.extend_from_slice(FRAME_END);

    debug_assert_eq!(out.len(), total_length as usize);
    Ok(out)
}
/// Decodes the data object frame that starts at the beginning of `buf`.
///
/// Returns the decoded descriptor, a slice of `buf` holding the payload, and
/// the number of bytes consumed (the frame length, rounded up to a multiple
/// of 8 when the rounded length still fits inside `buf`).
///
/// # Errors
///
/// Returns a `Framing` error when the frame is not a `DataObject` frame, when
/// `total_length` or `cbor_offset` are out of range, or when the end marker is
/// missing; returns a `Metadata` error (or propagates the error from
/// `metadata::cbor_to_object_descriptor`) when the descriptor cannot be
/// decoded.
pub fn decode_data_object_frame(buf: &[u8]) -> Result<(DataObjectDescriptor, &[u8], usize)> {
let fh = FrameHeader::read_from(buf)?;
if fh.frame_type != FrameType::DataObject {
return Err(TensogramError::Framing(format!(
"expected DataObject frame, got {:?}",
fh.frame_type
)));
}
let frame_total = usize::try_from(fh.total_length).map_err(|_| {
TensogramError::Framing(format!(
"data object frame total_length {} overflows usize",
fh.total_length
))
})?;
// A data object frame must at least hold the header and the fixed footer.
let min_frame_size = FRAME_HEADER_SIZE + DATA_OBJECT_FOOTER_SIZE;
if frame_total < min_frame_size {
return Err(TensogramError::Framing(format!(
"data object frame too small: {} < {}",
frame_total, min_frame_size
)));
}
if frame_total > buf.len() {
return Err(TensogramError::Framing(format!(
"data object frame total_length {} exceeds buffer: {}",
frame_total,
buf.len()
)));
}
// The last bytes of the frame must be the end marker.
let endf_start = frame_total - FRAME_END.len();
if &buf[endf_start..frame_total] != FRAME_END {
return Err(TensogramError::Framing(
"missing ENDF marker in data object frame".to_string(),
));
}
if endf_start < 8 {
return Err(TensogramError::Framing(format!(
"data object frame too small for cbor_offset: endf_start={endf_start} < 8"
)));
}
// The 8 bytes immediately before the end marker store the big-endian
// offset of the CBOR descriptor, measured from the start of the frame.
let cbor_offset_pos = endf_start - 8;
let cbor_offset_raw = crate::wire::read_u64_be(buf, cbor_offset_pos);
let cbor_offset = usize::try_from(cbor_offset_raw).map_err(|_| {
TensogramError::Framing(format!("cbor_offset {cbor_offset_raw} overflows usize"))
})?;
// The descriptor must lie between the frame header and the offset field.
if cbor_offset < FRAME_HEADER_SIZE || cbor_offset > cbor_offset_pos {
return Err(TensogramError::Framing(format!(
"cbor_offset {cbor_offset} out of valid range [{FRAME_HEADER_SIZE}, {cbor_offset_pos}]"
)));
}
let cbor_after = fh.flags & DataObjectFlags::CBOR_AFTER_PAYLOAD != 0;
let (descriptor, payload_slice) = if cbor_after {
// Layout: header | payload | CBOR | cbor_offset | end marker.
// The CBOR region is fully delimited, so it is decoded as a whole.
let payload_start = FRAME_HEADER_SIZE;
let cbor_start = cbor_offset;
let cbor_end = cbor_offset_pos;
let cbor_slice = &buf[cbor_start..cbor_end];
let desc = metadata::cbor_to_object_descriptor(cbor_slice)?;
(desc, &buf[payload_start..cbor_start])
} else {
// Layout: header | CBOR | payload | cbor_offset | end marker.
// The CBOR length is not stored, so one CBOR value is parsed through a
// cursor and the cursor position marks where the payload begins.
let cbor_start = cbor_offset;
let region = &buf[cbor_start..cbor_offset_pos];
let mut cursor = std::io::Cursor::new(region);
let cbor_value: ciborium::Value = ciborium::from_reader(&mut cursor).map_err(|e| {
TensogramError::Metadata(format!("failed to parse object descriptor CBOR: {e}"))
})?;
let cbor_len = usize::try_from(cursor.position()).map_err(|_| {
TensogramError::Metadata("CBOR descriptor length overflows usize".to_string())
})?;
let payload_start = cbor_start + cbor_len;
let desc: DataObjectDescriptor = cbor_value.deserialized().map_err(|e| {
TensogramError::Metadata(format!("failed to deserialize descriptor: {e}"))
})?;
(desc, &buf[payload_start..cbor_offset_pos])
};
// Report the 8-byte-aligned length only when the buffer contains the padding.
let mut consumed = frame_total;
let aligned = (consumed + 7) & !7;
if aligned <= buf.len() {
consumed = aligned;
}
Ok((descriptor, payload_slice, consumed))
}
/// One data object ready to be framed by [`encode_message`]: its descriptor
/// plus the payload bytes that are written into the frame as-is.
pub struct EncodedObject {
/// Descriptor serialised to CBOR inside the data object frame.
pub descriptor: DataObjectDescriptor,
/// Payload bytes copied verbatim into the data object frame.
pub encoded_payload: Vec<u8>,
}
/// Builds the CBOR body of the hash frame for `objects`.
///
/// Returns `Ok(None)` when no object carries a hash. Otherwise the hash type
/// is taken from the first object that has a hash, and the hash list has one
/// entry per object, using the default (empty) value for objects without one.
///
/// # Errors
///
/// Propagates any error from encoding the hash frame to CBOR.
fn build_hash_frame_cbor(objects: &[EncodedObject]) -> Result<Option<Vec<u8>>> {
    // The first hashed object decides both "is there a hash frame" and its type.
    let Some(first_hash) = objects.iter().find_map(|o| o.descriptor.hash.as_ref()) else {
        return Ok(None);
    };

    let mut hashes: Vec<String> = Vec::with_capacity(objects.len());
    for object in objects {
        let value = match object.descriptor.hash.as_ref() {
            Some(hash) => hash.value.clone(),
            None => Default::default(),
        };
        hashes.push(value);
    }

    let frame = HashFrame {
        object_count: objects.len() as u64,
        hash_type: first_hash.hash_type.clone(),
        hashes,
    };
    let cbor = metadata::hash_frame_to_cbor(&frame)?;
    Ok(Some(cbor))
}
/// Builds the complete, 8-byte-padded header index frame for `object_frames`.
///
/// `header_size_no_index` is the size of everything that precedes the data
/// objects except the index frame itself. Because the index stores absolute
/// object offsets, and those offsets depend on the size of the index frame,
/// the index is encoded with placeholder offsets first and re-encoded until
/// its padded size is stable (at most three encodings).
///
/// Returns `Ok(None)` when there are no object frames.
///
/// # Errors
///
/// Propagates CBOR encoding errors, and returns a `Framing` error if the
/// padded index size still changes on the third encoding.
fn build_index_frame(
    header_size_no_index: usize,
    object_frames: &[Vec<u8>],
) -> Result<Option<Vec<u8>>> {
    if object_frames.is_empty() {
        return Ok(None);
    }

    let object_count = object_frames.len() as u64;
    let lengths: Vec<u64> = object_frames
        .iter()
        .map(|frame| frame.len() as u64)
        .collect();

    // Encodes an index that differs only in its offsets.
    let encode_index = |offsets: Vec<u64>| {
        metadata::index_to_cbor(&IndexFrame {
            object_count,
            offsets,
            lengths: lengths.clone(),
        })
    };

    // Pass 1: all-zero offsets, only to learn an approximate index size.
    let placeholder_cbor = encode_index(vec![0u64; object_frames.len()])?;
    let first_guess_start =
        header_size_no_index + aligned_frame_total_size(placeholder_cbor.len());

    // Pass 2: offsets derived from the placeholder size.
    let candidate_cbor = encode_index(compute_object_offsets(first_guess_start, object_frames))?;

    let final_cbor = if candidate_cbor.len() == placeholder_cbor.len() {
        candidate_cbor
    } else {
        // Pass 3: the real offsets changed the encoded size, so shift the
        // objects accordingly and verify the padded size no longer moves.
        let settled_frame_size = aligned_frame_total_size(candidate_cbor.len());
        let settled_start = header_size_no_index + settled_frame_size;
        let settled_cbor = encode_index(compute_object_offsets(settled_start, object_frames))?;
        if aligned_frame_total_size(settled_cbor.len()) != settled_frame_size {
            return Err(TensogramError::Framing(
                "index CBOR size changed unexpectedly on third pass".to_string(),
            ));
        }
        settled_cbor
    };

    let mut idx_frame = Vec::new();
    write_frame(
        &mut idx_frame,
        FrameType::HeaderIndex,
        1,
        0,
        &final_cbor,
        true,
    );
    Ok(Some(idx_frame))
}
/// Computes the starting offset of every object frame when the frames are
/// laid out back to back from `start`, with the position rounded up to a
/// multiple of 8 after each frame.
///
/// The first offset is `start` itself (it is not rounded).
fn compute_object_offsets(start: usize, object_frames: &[Vec<u8>]) -> Vec<u64> {
    let mut next_offset = start;
    object_frames
        .iter()
        .map(|frame| {
            let this_offset = next_offset as u64;
            // Advance past the frame and its alignment padding.
            next_offset = (next_offset + frame.len()).next_multiple_of(8);
            this_offset
        })
        .collect()
}
/// Builds the preamble flags for a message.
///
/// `HEADER_METADATA` is always set; `HEADER_INDEX` and `HEADER_HASHES` are
/// set when `has_index` / `has_hashes` are true respectively.
fn compute_message_flags(has_index: bool, has_hashes: bool) -> MessageFlags {
    let mut flags = MessageFlags::default();
    flags.set(MessageFlags::HEADER_METADATA);

    let optional_bits = [
        (has_index, MessageFlags::HEADER_INDEX),
        (has_hashes, MessageFlags::HEADER_HASHES),
    ];
    for (enabled, bit) in optional_bits {
        if enabled {
            flags.set(bit);
        }
    }
    flags
}
/// Concatenates all parts of a message into its final byte layout.
///
/// Order: preamble, header metadata frame, optional pre-built index frame,
/// optional header hash frame, the object frames (zero-padded to 8 bytes
/// between frames, but not after the last one), and the postamble. The
/// preamble is written last because it records the total message length.
fn assemble_message(
    flags: MessageFlags,
    meta_cbor: &[u8],
    index_frame_bytes: Option<&[u8]>,
    hash_cbor: Option<&[u8]>,
    object_frames: &[Vec<u8>],
) -> Vec<u8> {
    // Reserve space for the preamble; it is filled in once the length is known.
    let mut out = vec![0u8; PREAMBLE_SIZE];

    write_frame(&mut out, FrameType::HeaderMetadata, 1, 0, meta_cbor, true);
    if let Some(idx_bytes) = index_frame_bytes {
        out.extend_from_slice(idx_bytes);
    }
    if let Some(h_cbor) = hash_cbor {
        write_frame(&mut out, FrameType::HeaderHash, 1, 0, h_cbor, true);
    }

    if let Some((last_frame, earlier_frames)) = object_frames.split_last() {
        for frame in earlier_frames {
            out.extend_from_slice(frame);
            let padded_len = out.len().next_multiple_of(8);
            out.resize(padded_len, 0);
        }
        // The final object frame is followed directly by the postamble.
        out.extend_from_slice(last_frame);
    }

    Postamble {
        first_footer_offset: out.len() as u64,
    }
    .write_to(&mut out);

    let preamble = Preamble {
        version: 2,
        flags,
        reserved: 0,
        total_length: out.len() as u64,
    };
    let mut preamble_bytes = Vec::with_capacity(PREAMBLE_SIZE);
    preamble.write_to(&mut preamble_bytes);
    out[..PREAMBLE_SIZE].copy_from_slice(&preamble_bytes);
    out
}
/// Encodes a complete message from global metadata and pre-encoded objects.
///
/// Every object becomes a data object frame with the CBOR descriptor placed
/// after the payload. A header hash frame is emitted when any object carries
/// a hash, and a header index frame when there is at least one object.
///
/// # Errors
///
/// Propagates metadata/descriptor CBOR encoding errors and index construction
/// errors.
pub fn encode_message(global_meta: &GlobalMetadata, objects: &[EncodedObject]) -> Result<Vec<u8>> {
    let meta_cbor = metadata::global_metadata_to_cbor(global_meta)?;

    let object_frames = objects
        .iter()
        .map(|obj| encode_data_object_frame(&obj.descriptor, &obj.encoded_payload, false))
        .collect::<Result<Vec<Vec<u8>>>>()?;

    let hash_cbor = build_hash_frame_cbor(objects)?;

    // Lay out the header without the index in a scratch buffer: only its
    // length is needed, to compute the absolute object offsets for the index.
    let mut scratch_header = vec![0u8; PREAMBLE_SIZE];
    write_frame(
        &mut scratch_header,
        FrameType::HeaderMetadata,
        1,
        0,
        &meta_cbor,
        true,
    );
    if let Some(h_cbor) = hash_cbor.as_ref() {
        write_frame(
            &mut scratch_header,
            FrameType::HeaderHash,
            1,
            0,
            h_cbor,
            true,
        );
    }

    let index_frame_bytes = build_index_frame(scratch_header.len(), &object_frames)?;
    let flags = compute_message_flags(index_frame_bytes.is_some(), hash_cbor.is_some());

    let message = assemble_message(
        flags,
        &meta_cbor,
        index_frame_bytes.as_deref(),
        hash_cbor.as_deref(),
        &object_frames,
    );
    Ok(message)
}
/// Result of [`decode_message`]; payload slices borrow from the input buffer.
#[derive(Debug)]
pub struct DecodedMessage<'a> {
/// Preamble read from the start of the buffer.
pub preamble: Preamble,
/// Global metadata from the last header/footer metadata frame seen, with
/// `base` extended to one entry per object and preceder entries merged in.
pub global_metadata: GlobalMetadata,
/// Index from the last header/footer index frame seen, if any.
pub index: Option<IndexFrame>,
/// Hash frame from the last header/footer hash frame seen, if any.
pub hash_frame: Option<HashFrame>,
/// One entry per data object: descriptor, payload slice, and the offset
/// within the input buffer at which the object's frame starts.
pub objects: Vec<(DataObjectDescriptor, &'a [u8], usize)>,
/// One entry per data object, parallel to `objects`: the preceder metadata
/// map that immediately preceded the object, or `None` if there was none.
pub preceder_payloads: Vec<Option<BTreeMap<String, ciborium::Value>>>,
}
/// Section of a message a frame belongs to. The derived ordering
/// (`Headers < DataObjects < Footers`) is what `decode_message` uses to
/// reject frames that appear out of order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum DecodePhase {
/// Header metadata, index and hash frames.
Headers = 0,
/// Data object frames and their preceder metadata frames.
DataObjects = 1,
/// Footer hash, index and metadata frames.
Footers = 2,
}
fn frame_phase(ft: FrameType) -> DecodePhase {
match ft {
FrameType::HeaderMetadata | FrameType::HeaderIndex | FrameType::HeaderHash => {
DecodePhase::Headers
}
FrameType::DataObject | FrameType::PrecederMetadata => DecodePhase::DataObjects,
FrameType::FooterHash | FrameType::FooterIndex | FrameType::FooterMetadata => {
DecodePhase::Footers
}
}
}
/// Decodes a complete message from `buf`.
///
/// Frames are read between the preamble and the postamble. Bytes that do not
/// start with the `FR` frame marker are skipped one at a time. Frames must
/// appear in phase order (headers, data objects, footers), and every
/// `PrecederMetadata` frame must be immediately followed by a `DataObject`
/// frame. After all frames are read, `global_metadata.base` is extended to one
/// entry per object and each preceder's entries are merged into the matching
/// `base` entry, overriding keys that already exist there.
///
/// When the preamble's `total_length` is zero, the frame region is taken to
/// end `POSTAMBLE_SIZE` bytes before the end of `buf` and the postamble is not
/// parsed.
///
/// # Errors
///
/// Returns a `Framing` error for an invalid `total_length` (does not fit in
/// `usize`, exceeds the buffer, or is too small to contain the postamble),
/// out-of-order frames, or a misplaced/dangling preceder. Returns a `Metadata`
/// error when no metadata frame is present, when a preceder does not carry
/// exactly one `base` entry, or when `base` has more entries than there are
/// objects. Errors from the individual frame decoders are propagated.
pub fn decode_message(buf: &[u8]) -> Result<DecodedMessage<'_>> {
    let preamble = Preamble::read_from(buf)?;

    // `msg_end` is the offset at which the frame region stops (the start of
    // the postamble). It is computed with checked arithmetic so a corrupt,
    // too-small `total_length` is reported as an error instead of underflowing.
    let msg_end = if preamble.total_length > 0 {
        let total_len = usize::try_from(preamble.total_length).map_err(|_| {
            TensogramError::Framing(format!(
                "total_length {} overflows usize",
                preamble.total_length
            ))
        })?;
        if total_len > buf.len() {
            return Err(TensogramError::Framing(format!(
                "total_length {} exceeds buffer size {}",
                preamble.total_length,
                buf.len()
            )));
        }
        let pa_offset = total_len.checked_sub(POSTAMBLE_SIZE).ok_or_else(|| {
            TensogramError::Framing(format!(
                "total_length {} too small for postamble",
                preamble.total_length
            ))
        })?;
        // Parsed only for validation; the decoded value itself is not used.
        let _postamble = Postamble::read_from(&buf[pa_offset..])?;
        pa_offset
    } else {
        buf.len().checked_sub(POSTAMBLE_SIZE).ok_or_else(|| {
            TensogramError::Framing(format!(
                "buffer too short for postamble: {} < {POSTAMBLE_SIZE}",
                buf.len()
            ))
        })?
    };

    let mut pos = PREAMBLE_SIZE;
    let mut global_metadata: Option<GlobalMetadata> = None;
    let mut index: Option<IndexFrame> = None;
    let mut hash_frame: Option<HashFrame> = None;
    let mut objects: Vec<(DataObjectDescriptor, &[u8], usize)> = Vec::new();
    let mut preceder_payloads: Vec<Option<BTreeMap<String, ciborium::Value>>> = Vec::new();
    let mut current_phase = DecodePhase::Headers;
    // Preceder metadata waiting for the data object it belongs to.
    let mut pending_preceder: Option<BTreeMap<String, ciborium::Value>> = None;

    while pos < msg_end {
        if pos + 2 > buf.len() {
            break;
        }
        // Skip padding / unknown bytes until the next frame marker.
        if &buf[pos..pos + 2] != b"FR" {
            pos += 1;
            continue;
        }
        if pos + FRAME_HEADER_SIZE > buf.len() {
            break;
        }

        let fh = FrameHeader::read_from(&buf[pos..])?;
        let phase = frame_phase(fh.frame_type);
        if phase < current_phase {
            return Err(TensogramError::Framing(format!(
                "unexpected {:?} frame after {:?} phase — frames must appear in order: headers, data objects, footers",
                fh.frame_type, current_phase
            )));
        }
        if pending_preceder.is_some() && fh.frame_type != FrameType::DataObject {
            return Err(TensogramError::Framing(format!(
                "PrecederMetadata must be followed by a DataObject frame, got {:?}",
                fh.frame_type
            )));
        }
        current_phase = phase;
        let frame_start = pos;

        match fh.frame_type {
            FrameType::HeaderMetadata | FrameType::FooterMetadata => {
                let (_, payload, consumed) = read_frame(&buf[pos..])?;
                global_metadata = Some(metadata::cbor_to_global_metadata(payload)?);
                pos += consumed;
            }
            FrameType::HeaderIndex | FrameType::FooterIndex => {
                let (_, payload, consumed) = read_frame(&buf[pos..])?;
                index = Some(metadata::cbor_to_index(payload)?);
                pos += consumed;
            }
            FrameType::HeaderHash | FrameType::FooterHash => {
                let (_, payload, consumed) = read_frame(&buf[pos..])?;
                hash_frame = Some(metadata::cbor_to_hash_frame(payload)?);
                pos += consumed;
            }
            FrameType::PrecederMetadata => {
                let (_, payload, consumed) = read_frame(&buf[pos..])?;
                let preceder_meta = metadata::cbor_to_global_metadata(payload)?;
                let n = preceder_meta.base.len();
                if n != 1 {
                    return Err(TensogramError::Metadata(format!(
                        "PrecederMetadata base must have exactly 1 entry, got {n}"
                    )));
                }
                let mut entry = preceder_meta.base.into_iter().next().unwrap_or_default();
                // The reserved key is dropped from preceder entries.
                entry.remove(RESERVED_KEY);
                pending_preceder = Some(entry);
                pos += consumed;
            }
            FrameType::DataObject => {
                let (desc, payload, consumed) = decode_data_object_frame(&buf[pos..])?;
                objects.push((desc, payload, frame_start));
                preceder_payloads.push(pending_preceder.take());
                pos += consumed;
            }
        }
    }

    if pending_preceder.is_some() {
        return Err(TensogramError::Framing(
            "dangling PrecederMetadata: no DataObject frame followed".to_string(),
        ));
    }

    let mut global_metadata = global_metadata.ok_or_else(|| {
        TensogramError::Metadata("no metadata frame found in message".to_string())
    })?;

    let obj_count = objects.len();
    if global_metadata.base.len() > obj_count {
        return Err(TensogramError::Metadata(format!(
            "metadata base has {} entries but message contains {} objects",
            global_metadata.base.len(),
            obj_count
        )));
    }
    if global_metadata.base.len() < obj_count {
        global_metadata.base.resize_with(obj_count, BTreeMap::new);
    }

    // `base` and `preceder_payloads` now both have one entry per object.
    // Preceder keys overwrite any key already present in the base entry.
    for (base_entry, preceder) in global_metadata.base.iter_mut().zip(&preceder_payloads) {
        if let Some(prec_map) = preceder {
            base_entry.extend(prec_map.iter().map(|(k, v)| (k.clone(), v.clone())));
        }
    }

    Ok(DecodedMessage {
        preamble,
        global_metadata,
        index,
        hash_frame,
        objects,
        preceder_payloads,
    })
}
pub fn decode_metadata_only(buf: &[u8]) -> Result<GlobalMetadata> {
let preamble = Preamble::read_from(buf)?;
let mut pos = PREAMBLE_SIZE;
let msg_end = if preamble.total_length > 0 {
let total_len = usize::try_from(preamble.total_length).map_err(|_| {
TensogramError::Framing(format!(
"total_length {} overflows usize",
preamble.total_length
))
})?;
total_len.checked_sub(POSTAMBLE_SIZE).ok_or_else(|| {
TensogramError::Framing(format!(
"total_length {} too small for postamble",
preamble.total_length
))
})?
} else {
buf.len().checked_sub(POSTAMBLE_SIZE).ok_or_else(|| {
TensogramError::Framing(format!(
"buffer too short for postamble: {} < {POSTAMBLE_SIZE}",
buf.len()
))
})?
};
while pos < msg_end {
if pos + 2 > buf.len() {
break;
}
if &buf[pos..pos + 2] != b"FR" {
pos += 1;
continue;
}
if pos + FRAME_HEADER_SIZE > buf.len() {
break;
}
let fh = FrameHeader::read_from(&buf[pos..])?;
match fh.frame_type {
FrameType::HeaderMetadata | FrameType::FooterMetadata => {
let (_, payload, _) = read_frame(&buf[pos..])?;
return metadata::cbor_to_global_metadata(payload);
}
_ => {
let frame_total = usize::try_from(fh.total_length).map_err(|_| {
TensogramError::Framing(format!(
"frame total_length {} overflows usize",
fh.total_length
))
})?;
pos += frame_total;
pos = (pos + 7) & !7; }
}
}
Err(TensogramError::Metadata(
"no metadata frame found in message".to_string(),
))
}
/// Scans `buf` for messages and returns `(offset, length)` for each one found.
///
/// A message starts at `MAGIC`. When its preamble declares a non-zero
/// `total_length`, the message is accepted only if that length is at least
/// `PREAMBLE_SIZE + POSTAMBLE_SIZE`, fits inside `buf`, and its last 8 bytes
/// equal `END_MAGIC`. When `total_length` is zero, the message extends to the
/// first `END_MAGIC` found after the preamble. Anything that does not match is
/// skipped one byte at a time, so corrupt input never causes a panic here.
#[tracing::instrument(skip(buf), fields(buf_len = buf.len()))]
pub fn scan(buf: &[u8]) -> Vec<(usize, usize)> {
    // Smallest possible message: a preamble directly followed by a postamble.
    let min_message_len = PREAMBLE_SIZE + POSTAMBLE_SIZE;
    let mut messages = Vec::new();
    let mut pos = 0;

    while pos + min_message_len <= buf.len() {
        if &buf[pos..pos + MAGIC.len()] == MAGIC {
            if let Ok(preamble) = Preamble::read_from(&buf[pos..]) {
                if preamble.total_length > 0 {
                    // Validate the declared length before doing arithmetic with
                    // it: too-small values would underflow `end - 8`, and huge
                    // values would overflow `pos + total`.
                    let msg_end = usize::try_from(preamble.total_length)
                        .ok()
                        .filter(|&total| total >= min_message_len)
                        .and_then(|total| pos.checked_add(total))
                        .filter(|&end| end <= buf.len());
                    if let Some(end) = msg_end {
                        if &buf[end - 8..end] == crate::wire::END_MAGIC {
                            messages.push((pos, end - pos));
                            pos = end;
                            continue;
                        }
                    }
                } else {
                    // Unknown length: search forward for the end magic.
                    let mut end_pos = pos + PREAMBLE_SIZE;
                    let mut found = false;
                    while end_pos + 8 <= buf.len() {
                        if &buf[end_pos..end_pos + 8] == crate::wire::END_MAGIC {
                            let msg_len = end_pos + 8 - pos;
                            messages.push((pos, msg_len));
                            pos = end_pos + 8;
                            found = true;
                            break;
                        }
                        end_pos += 1;
                    }
                    if found {
                        continue;
                    }
                }
            }
        }
        pos += 1;
    }
    messages
}
/// Scans a seekable reader for messages without loading it into memory.
///
/// Returns `(offset, length)` for each message found, using the same rules as
/// [`scan`]: a non-zero `total_length` must be at least
/// `PREAMBLE_SIZE + POSTAMBLE_SIZE`, fit inside the file, and end with
/// `END_MAGIC`; a zero `total_length` makes the message extend to the first
/// `END_MAGIC` found after the preamble (searched in 4096-byte chunks that
/// overlap by 7 bytes so a marker straddling a chunk boundary is not missed).
///
/// # Errors
///
/// Returns a `Framing` error when the file size does not fit in `usize`, and
/// propagates seek errors. Failed reads end the scan (or the current search)
/// without reporting an error.
pub fn scan_file(file: &mut (impl std::io::Read + std::io::Seek)) -> Result<Vec<(usize, usize)>> {
    use std::io::SeekFrom;

    const CHUNK_SIZE: usize = 4096;
    // Smallest possible message: a preamble directly followed by a postamble.
    let min_message_len = PREAMBLE_SIZE + POSTAMBLE_SIZE;

    let file_len_u64 = file.seek(SeekFrom::End(0))?;
    let file_len = usize::try_from(file_len_u64).map_err(|_| {
        TensogramError::Framing(format!("file size {file_len_u64} overflows usize"))
    })?;
    file.seek(SeekFrom::Start(0))?;

    let mut messages = Vec::new();
    let mut pos: usize = 0;
    let mut preamble_buf = [0u8; PREAMBLE_SIZE];
    // Reused by every end-magic search instead of being reallocated each time.
    let mut chunk = vec![0u8; CHUNK_SIZE];

    while pos + min_message_len <= file_len {
        file.seek(SeekFrom::Start(pos as u64))?;
        if file.read_exact(&mut preamble_buf).is_err() {
            break;
        }
        if &preamble_buf[..MAGIC.len()] == MAGIC
            && let Ok(preamble) = Preamble::read_from(&preamble_buf)
        {
            if preamble.total_length > 0 {
                // Validate the declared length before doing arithmetic with
                // it: too-small values would underflow `end - 8`, and huge
                // values would overflow `pos + total`.
                let msg_end = usize::try_from(preamble.total_length)
                    .ok()
                    .filter(|&total| total >= min_message_len)
                    .and_then(|total| pos.checked_add(total))
                    .filter(|&end| end <= file_len);
                if let Some(end) = msg_end {
                    file.seek(SeekFrom::Start((end - 8) as u64))?;
                    let mut end_buf = [0u8; 8];
                    if file.read_exact(&mut end_buf).is_ok() && &end_buf == crate::wire::END_MAGIC {
                        messages.push((pos, end - pos));
                        pos = end;
                        continue;
                    }
                }
            } else {
                // Unknown length: search forward, chunk by chunk, for the end magic.
                let mut search_pos = pos + PREAMBLE_SIZE;
                let mut found = false;
                while search_pos + 8 <= file_len {
                    file.seek(SeekFrom::Start(search_pos as u64))?;
                    let to_read = (file_len - search_pos).min(CHUNK_SIZE);
                    let window = &mut chunk[..to_read];
                    if file.read_exact(window).is_err() {
                        break;
                    }
                    if let Some(i) = window
                        .windows(8)
                        .position(|candidate| candidate == crate::wire::END_MAGIC)
                    {
                        let end_pos = search_pos + i;
                        messages.push((pos, end_pos + 8 - pos));
                        pos = end_pos + 8;
                        found = true;
                        break;
                    }
                    // Step back 7 bytes so a marker split across two chunks is
                    // still seen in full by the next read.
                    search_pos += to_read.saturating_sub(7);
                }
                if found {
                    continue;
                }
            }
        }
        pos += 1;
    }
    Ok(messages)
}
/// Size in bytes of an unpadded generic frame carrying `payload_len` payload
/// bytes: header + payload + end marker.
fn frame_total_size(payload_len: usize) -> usize {
    let framing_overhead = FRAME_HEADER_SIZE + FRAME_END.len();
    framing_overhead + payload_len
}
/// Size in bytes of a generic frame carrying `payload_len` payload bytes,
/// rounded up to the next multiple of 8.
fn aligned_frame_total_size(payload_len: usize) -> usize {
    frame_total_size(payload_len).next_multiple_of(8)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dtype::Dtype;
use crate::types::ByteOrder;
use std::collections::BTreeMap;
// Test helper: global metadata with `version` 2 and every other field at its default.
fn make_global_meta() -> GlobalMetadata {
GlobalMetadata {
version: 2,
..Default::default()
}
}
// Test helper: descriptor for a little-endian Float32 "ntensor" of the given
// shape, with no encoding, filter, compression, params or hash.
fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
// Strides in elements: the last dimension has stride 1 and each earlier
// dimension's stride is the product of the later dimension sizes.
let strides = {
let mut s = vec![1u64; shape.len()];
for i in (0..shape.len().saturating_sub(1)).rev() {
s[i] = s[i + 1] * shape[i + 1];
}
s
};
DataObjectDescriptor {
obj_type: "ntensor".to_string(),
ndim: shape.len() as u64,
shape,
strides,
dtype: Dtype::Float32,
byte_order: ByteOrder::Little,
encoding: "none".to_string(),
filter: "none".to_string(),
compression: "none".to_string(),
params: BTreeMap::new(),
hash: None,
}
}
// A data object frame encoded with the CBOR descriptor after the payload
// decodes back to the same shape, dtype and payload bytes.
#[test]
fn test_data_object_frame_round_trip_cbor_after() {
let desc = make_descriptor(vec![4]);
let payload = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let frame = encode_data_object_frame(&desc, &payload, false).unwrap();
let (decoded_desc, decoded_payload, consumed) = decode_data_object_frame(&frame).unwrap();
assert_eq!(decoded_desc.shape, vec![4]);
assert_eq!(decoded_desc.dtype, Dtype::Float32);
assert_eq!(decoded_payload, &payload[..]);
assert!(consumed >= frame.len());
}
// A data object frame encoded with the CBOR descriptor before the payload
// decodes back to the same shape and payload bytes.
#[test]
fn test_data_object_frame_round_trip_cbor_before() {
let desc = make_descriptor(vec![2, 3]);
let payload = vec![0xABu8; 24];
let frame = encode_data_object_frame(&desc, &payload, true).unwrap();
let (decoded_desc, decoded_payload, _) = decode_data_object_frame(&frame).unwrap();
assert_eq!(decoded_desc.shape, vec![2, 3]);
assert_eq!(decoded_payload, &payload[..]);
}
// A message with no objects starts with MAGIC, ends with END_MAGIC, and
// decodes to zero objects with no index frame.
#[test]
fn test_empty_message_round_trip() {
let meta = make_global_meta();
let msg = encode_message(&meta, &[]).unwrap();
assert_eq!(&msg[0..8], MAGIC);
assert_eq!(&msg[msg.len() - 8..], crate::wire::END_MAGIC);
let decoded = decode_message(&msg).unwrap();
assert_eq!(decoded.global_metadata.version, 2);
assert_eq!(decoded.objects.len(), 0);
assert!(decoded.index.is_none()); }
// A message with one object round-trips its descriptor and payload and
// carries an index frame reporting one object.
#[test]
fn test_single_object_message_round_trip() {
let meta = make_global_meta();
let desc = make_descriptor(vec![4]);
let payload = vec![42u8; 16];
let objects = vec![EncodedObject {
descriptor: desc,
encoded_payload: payload.clone(),
}];
let msg = encode_message(&meta, &objects).unwrap();
let decoded = decode_message(&msg).unwrap();
assert_eq!(decoded.global_metadata.version, 2);
assert_eq!(decoded.objects.len(), 1);
assert_eq!(decoded.objects[0].0.shape, vec![4]);
assert_eq!(decoded.objects[0].1, &payload[..]);
assert!(decoded.index.is_some());
assert_eq!(decoded.index.as_ref().unwrap().object_count, 1);
}
// A message with two objects round-trips both descriptors and payloads in
// order, and its index frame has two offsets.
#[test]
fn test_multi_object_message_round_trip() {
let meta = make_global_meta();
let desc0 = make_descriptor(vec![4]);
let desc1 = make_descriptor(vec![2, 3]);
let payload0 = vec![10u8; 16];
let payload1 = vec![20u8; 24];
let objects = vec![
EncodedObject {
descriptor: desc0,
encoded_payload: payload0.clone(),
},
EncodedObject {
descriptor: desc1,
encoded_payload: payload1.clone(),
},
];
let msg = encode_message(&meta, &objects).unwrap();
let decoded = decode_message(&msg).unwrap();
assert_eq!(decoded.objects.len(), 2);
assert_eq!(decoded.objects[0].0.shape, vec![4]);
assert_eq!(decoded.objects[0].1, &payload0[..]);
assert_eq!(decoded.objects[1].0.shape, vec![2, 3]);
assert_eq!(decoded.objects[1].1, &payload1[..]);
let idx = decoded.index.as_ref().unwrap();
assert_eq!(idx.object_count, 2);
assert_eq!(idx.offsets.len(), 2);
}
// `scan` finds two back-to-back messages and reports each one's offset and length.
#[test]
fn test_scan_multi_message() {
let meta = make_global_meta();
let msg1 = encode_message(
&meta,
&[EncodedObject {
descriptor: make_descriptor(vec![4]),
encoded_payload: vec![1u8; 16],
}],
)
.unwrap();
let msg2 = encode_message(
&meta,
&[EncodedObject {
descriptor: make_descriptor(vec![2]),
encoded_payload: vec![2u8; 8],
}],
)
.unwrap();
let mut buf = msg1.clone();
buf.extend_from_slice(&msg2);
let offsets = scan(&buf);
assert_eq!(offsets.len(), 2);
assert_eq!(offsets[0], (0, msg1.len()));
assert_eq!(offsets[1], (msg1.len(), msg2.len()));
}
// `scan` skips 10 leading and 5 trailing garbage bytes and finds the single
// message at offset 10.
#[test]
fn test_scan_with_garbage() {
let meta = make_global_meta();
let msg = encode_message(
&meta,
&[EncodedObject {
descriptor: make_descriptor(vec![4]),
encoded_payload: vec![1u8; 16],
}],
)
.unwrap();
let mut buf = vec![0xFF; 10];
buf.extend_from_slice(&msg);
buf.extend_from_slice(&[0xAA; 5]);
let offsets = scan(&buf);
assert_eq!(offsets.len(), 1);
assert_eq!(offsets[0], (10, msg.len()));
}
// `decode_metadata_only` returns the global metadata, including an extra key,
// from a message that also contains a data object.
#[test]
fn test_decode_metadata_only() {
let mut meta = make_global_meta();
meta.extra.insert(
"test_key".to_string(),
ciborium::Value::Text("test_value".to_string()),
);
let msg = encode_message(
&meta,
&[EncodedObject {
descriptor: make_descriptor(vec![4]),
encoded_payload: vec![0u8; 16],
}],
)
.unwrap();
let decoded_meta = decode_metadata_only(&msg).unwrap();
assert_eq!(decoded_meta.version, 2);
assert!(decoded_meta.extra.contains_key("test_key"));
}
// Test helper: hand-assembles a message from an arbitrary sequence of frames
// so tests can produce frame orders that `encode_message` never emits.
//
// For each `(content, frame_type)` pair:
// - `DataObject` ignores `content` and writes a fixed 16-byte zero payload
//   with a `[4]` descriptor, padded to 8 bytes;
// - any other type writes `content` as the frame payload, or the default
//   global metadata CBOR when `content` is empty.
// The preamble always has only the HEADER_METADATA flag set.
fn build_raw_message(frames: &[(&[u8], FrameType)]) -> Vec<u8> {
let meta = make_global_meta();
let meta_cbor = crate::metadata::global_metadata_to_cbor(&meta).unwrap();
let desc = make_descriptor(vec![4]);
let payload = vec![0u8; 16];
let mut out = Vec::new();
// Placeholder preamble, overwritten once the total length is known.
out.extend_from_slice(&[0u8; PREAMBLE_SIZE]);
for (content, frame_type) in frames {
match frame_type {
FrameType::DataObject => {
let frame = encode_data_object_frame(&desc, &payload, false).unwrap();
out.extend_from_slice(&frame);
let pad = (8 - (out.len() % 8)) % 8;
out.extend(std::iter::repeat_n(0u8, pad));
}
_ => {
let data = if content.is_empty() {
&meta_cbor
} else {
*content
};
write_frame(&mut out, *frame_type, 1, 0, data, true);
}
}
}
let postamble_offset = out.len();
let postamble = Postamble {
first_footer_offset: postamble_offset as u64,
};
postamble.write_to(&mut out);
let total_length = out.len() as u64;
let mut flags = MessageFlags::default();
flags.set(MessageFlags::HEADER_METADATA);
let preamble = Preamble {
version: 2,
flags,
reserved: 0,
total_length,
};
let mut preamble_bytes = Vec::new();
preamble.write_to(&mut preamble_bytes);
out[0..PREAMBLE_SIZE].copy_from_slice(&preamble_bytes);
out
}
// A header frame that appears after a data object frame is rejected with an
// ordering error.
#[test]
fn test_decode_rejects_header_after_data_object() {
let msg = build_raw_message(&[
(&[], FrameType::DataObject),
(&[], FrameType::HeaderMetadata),
]);
let result = decode_message(&msg);
assert!(
result.is_err(),
"header frame after data object should be rejected"
);
let err = result.unwrap_err().to_string();
assert!(
err.contains("order") || err.contains("unexpected"),
"error should mention ordering: {err}"
);
}
// A data object frame that appears after a footer frame is rejected.
#[test]
fn test_decode_rejects_data_object_after_footer() {
let meta = make_global_meta();
let meta_cbor = crate::metadata::global_metadata_to_cbor(&meta).unwrap();
let hf = HashFrame {
object_count: 0,
hash_type: "xxh3".to_string(),
hashes: vec![],
};
let hash_cbor = crate::metadata::hash_frame_to_cbor(&hf).unwrap();
let msg = build_raw_message(&[
(&meta_cbor, FrameType::HeaderMetadata),
(&hash_cbor, FrameType::FooterHash),
(&[], FrameType::DataObject),
]);
let result = decode_message(&msg);
assert!(
result.is_err(),
"data object after footer should be rejected"
);
}
// Header metadata followed by a data object is a valid frame order.
#[test]
fn test_decode_accepts_valid_frame_order() {
let msg = build_raw_message(&[
(&[], FrameType::HeaderMetadata),
(&[], FrameType::DataObject),
]);
let result = decode_message(&msg);
assert!(
result.is_ok(),
"valid frame order should be accepted: {:?}",
result.err()
);
}
// `scan_file` over a cursor returns the same results as `scan` over the
// same bytes for two concatenated messages.
#[test]
fn test_scan_file_matches_scan_buffer() {
let meta = make_global_meta();
let msg1 = encode_message(
&meta,
&[EncodedObject {
descriptor: make_descriptor(vec![4]),
encoded_payload: vec![1u8; 16],
}],
)
.unwrap();
let msg2 = encode_message(
&meta,
&[EncodedObject {
descriptor: make_descriptor(vec![2]),
encoded_payload: vec![2u8; 8],
}],
)
.unwrap();
let mut buf = msg1.clone();
buf.extend_from_slice(&msg2);
let buffer_offsets = scan(&buf);
let mut cursor = std::io::Cursor::new(&buf);
let file_offsets = scan_file(&mut cursor).unwrap();
assert_eq!(buffer_offsets, file_offsets);
}
// `scan_file` skips 10 leading and 5 trailing garbage bytes and finds the
// single message at offset 10.
#[test]
fn test_scan_file_with_garbage() {
let meta = make_global_meta();
let msg = encode_message(
&meta,
&[EncodedObject {
descriptor: make_descriptor(vec![4]),
encoded_payload: vec![1u8; 16],
}],
)
.unwrap();
let mut buf = vec![0xFF; 10];
buf.extend_from_slice(&msg);
buf.extend_from_slice(&[0xAA; 5]);
let mut cursor = std::io::Cursor::new(&buf);
let offsets = scan_file(&mut cursor).unwrap();
assert_eq!(offsets.len(), 1);
assert_eq!(offsets[0], (10, msg.len()));
}
// `scan_file` on an empty input succeeds and finds no messages.
#[test]
fn test_scan_file_empty() {
let buf: Vec<u8> = Vec::new();
let mut cursor = std::io::Cursor::new(&buf);
let offsets = scan_file(&mut cursor).unwrap();
assert!(offsets.is_empty());
}
// A footer metadata frame following the data objects is accepted.
#[test]
fn test_decode_accepts_footer_after_data_objects() {
let meta = make_global_meta();
let meta_cbor = crate::metadata::global_metadata_to_cbor(&meta).unwrap();
let msg = build_raw_message(&[
(&meta_cbor, FrameType::HeaderMetadata),
(&[], FrameType::DataObject),
(&meta_cbor, FrameType::FooterMetadata),
]);
let result = decode_message(&msg);
assert!(
result.is_ok(),
"footer after data objects should be accepted: {:?}",
result.err()
);
}
// Test helper: CBOR for a preceder frame payload, i.e. global metadata
// (version 2) whose `base` holds exactly the one given entry map.
fn make_preceder_cbor(entries: std::collections::BTreeMap<String, ciborium::Value>) -> Vec<u8> {
let meta = GlobalMetadata {
version: 2,
base: vec![entries],
..Default::default()
};
crate::metadata::global_metadata_to_cbor(&meta).unwrap()
}
// A preceder frame placed before a data object is recorded for that object
// and its keys are merged into the object's `base` entry.
#[test]
fn test_decode_preceder_before_data_object() {
let mut entries = BTreeMap::new();
entries.insert(
"mars".to_string(),
ciborium::Value::Map(vec![(
ciborium::Value::Text("param".to_string()),
ciborium::Value::Text("2t".to_string()),
)]),
);
let preceder_cbor = make_preceder_cbor(entries);
let msg = build_raw_message(&[
(&[], FrameType::HeaderMetadata),
(&preceder_cbor, FrameType::PrecederMetadata),
(&[], FrameType::DataObject),
]);
let decoded = decode_message(&msg).unwrap();
assert_eq!(decoded.objects.len(), 1);
assert_eq!(decoded.preceder_payloads.len(), 1);
assert!(decoded.preceder_payloads[0].is_some());
assert_eq!(decoded.global_metadata.base.len(), 1);
assert!(decoded.global_metadata.base[0].contains_key("mars"));
}
// Two preceder frames in a row are rejected: a preceder must be followed
// directly by a data object.
#[test]
fn test_decode_consecutive_preceders_rejected() {
let entries = BTreeMap::new();
let preceder_cbor = make_preceder_cbor(entries);
let msg = build_raw_message(&[
(&[], FrameType::HeaderMetadata),
(&preceder_cbor, FrameType::PrecederMetadata),
(&preceder_cbor, FrameType::PrecederMetadata),
(&[], FrameType::DataObject),
]);
let result = decode_message(&msg);
assert!(result.is_err(), "consecutive preceders should be rejected");
let err = result.unwrap_err().to_string();
assert!(
err.contains("PrecederMetadata") && err.contains("DataObject"),
"error should explain preceder must precede DataObject: {err}"
);
}
// A preceder frame at the end of the message, with no data object after it,
// is rejected as dangling.
#[test]
fn test_decode_dangling_preceder_rejected() {
let entries = BTreeMap::new();
let preceder_cbor = make_preceder_cbor(entries);
let msg = build_raw_message(&[
(&[], FrameType::HeaderMetadata),
(&[], FrameType::DataObject),
(&preceder_cbor, FrameType::PrecederMetadata),
]);
let result = decode_message(&msg);
assert!(result.is_err(), "dangling preceder should be rejected");
let err = result.unwrap_err().to_string();
assert!(
err.contains("dangling"),
"error should mention dangling: {err}"
);
}
// A preceder whose metadata has two `base` entries is rejected; exactly one
// entry is required.
#[test]
fn test_decode_preceder_with_multiple_base_entries_rejected() {
let meta = GlobalMetadata {
version: 2,
base: vec![BTreeMap::new(), BTreeMap::new()],
..Default::default()
};
let bad_cbor = crate::metadata::global_metadata_to_cbor(&meta).unwrap();
let msg = build_raw_message(&[
(&[], FrameType::HeaderMetadata),
(&bad_cbor, FrameType::PrecederMetadata),
(&[], FrameType::DataObject),
]);
let result = decode_message(&msg);
assert!(
result.is_err(),
"preceder with 2 payload entries should be rejected"
);
let err = result.unwrap_err().to_string();
assert!(
err.contains("exactly 1"),
"error should mention 'exactly 1': {err}"
);
}
// A preceder whose metadata has no `base` entries is rejected, and the error
// reports the count that was found.
#[test]
fn test_decode_preceder_with_zero_base_entries_rejected() {
let meta = GlobalMetadata {
version: 2,
base: vec![],
..Default::default()
};
let bad_cbor = crate::metadata::global_metadata_to_cbor(&meta).unwrap();
let msg = build_raw_message(&[
(&[], FrameType::HeaderMetadata),
(&bad_cbor, FrameType::PrecederMetadata),
(&[], FrameType::DataObject),
]);
let result = decode_message(&msg);
assert!(
result.is_err(),
"preceder with 0 payload entries should be rejected"
);
let err = result.unwrap_err().to_string();
assert!(
err.contains("exactly 1") && err.contains("got 0"),
"error should mention 'exactly 1' and 'got 0': {err}"
);
}
// A preceder frame followed by a footer frame instead of a data object is rejected.
#[test]
fn test_decode_preceder_followed_by_footer_rejected() {
let entries = BTreeMap::new();
let preceder_cbor = make_preceder_cbor(entries);
let meta = make_global_meta();
let meta_cbor = crate::metadata::global_metadata_to_cbor(&meta).unwrap();
let msg = build_raw_message(&[
(&meta_cbor, FrameType::HeaderMetadata),
(&preceder_cbor, FrameType::PrecederMetadata),
(&meta_cbor, FrameType::FooterMetadata),
]);
let result = decode_message(&msg);
assert!(
result.is_err(),
"preceder followed by footer should be rejected"
);
}
// With a preceder before the first of two objects only, the preceder's key
// lands in the first object's `base` entry and not in the second's.
#[test]
fn test_decode_mixed_preceder_and_no_preceder() {
let mut entries = BTreeMap::new();
entries.insert(
"note".to_string(),
ciborium::Value::Text("from preceder".to_string()),
);
let preceder_cbor = make_preceder_cbor(entries);
let msg = build_raw_message(&[
(&[], FrameType::HeaderMetadata),
(&preceder_cbor, FrameType::PrecederMetadata),
(&[], FrameType::DataObject),
(&[], FrameType::DataObject),
]);
let decoded = decode_message(&msg).unwrap();
assert_eq!(decoded.objects.len(), 2);
assert_eq!(decoded.preceder_payloads.len(), 2);
assert!(decoded.preceder_payloads[0].is_some());
assert!(decoded.preceder_payloads[1].is_none());
assert!(decoded.global_metadata.base[0].contains_key("note"));
assert!(!decoded.global_metadata.base[1].contains_key("note"));
}
// When a preceder and the footer metadata both define the same key for an
// object, the preceder's value is the one kept after decoding.
#[test]
fn test_decode_preceder_wins_over_footer_payload() {
let mut prec_entries = BTreeMap::new();
prec_entries.insert(
"source".to_string(),
ciborium::Value::Text("preceder".to_string()),
);
let preceder_cbor = make_preceder_cbor(prec_entries);
let mut footer_base = BTreeMap::new();
footer_base.insert(
"source".to_string(),
ciborium::Value::Text("footer".to_string()),
);
let footer_meta = GlobalMetadata {
version: 2,
base: vec![footer_base],
..Default::default()
};
let footer_cbor = crate::metadata::global_metadata_to_cbor(&footer_meta).unwrap();
let msg = build_raw_message(&[
(&[], FrameType::HeaderMetadata),
(&preceder_cbor, FrameType::PrecederMetadata),
(&[], FrameType::DataObject),
(&footer_cbor, FrameType::FooterMetadata),
]);
let decoded = decode_message(&msg).unwrap();
let source = decoded.global_metadata.base[0]
.get("source")
.and_then(|v| match v {
ciborium::Value::Text(s) => Some(s.as_str()),
_ => None,
});
assert_eq!(source, Some("preceder"), "preceder should win over footer");
}
// Metadata whose `base` has three entries is rejected when the message holds
// only one object. (The metadata is written as a header frame here, despite
// the `footer_*` variable names.)
#[test]
fn test_decode_rejects_base_count_exceeding_objects() {
let footer_meta = GlobalMetadata {
version: 2,
base: vec![BTreeMap::new(), BTreeMap::new(), BTreeMap::new()],
..Default::default()
};
let footer_cbor = crate::metadata::global_metadata_to_cbor(&footer_meta).unwrap();
let msg = build_raw_message(&[
(&footer_cbor, FrameType::HeaderMetadata),
(&[], FrameType::DataObject),
]);
let result = decode_message(&msg);
assert!(
result.is_err(),
"base with more entries than objects should be rejected"
);
let err = result.unwrap_err().to_string();
assert!(
err.contains("3") && err.contains("1"),
"error should mention counts: {err}"
);
}
}