use std::collections::BTreeMap;
use std::io::Write;
use crate::encode::{
EncodeOptions, build_pipeline_config, populate_base_entries, populate_reserved_provenance,
validate_no_szip_offsets_for_non_szip, validate_object, validate_szip_block_offsets,
};
use crate::error::{Result, TensogramError};
use crate::framing::EncodedObject;
use crate::hash::HashAlgorithm;
use crate::metadata::{self, RESERVED_KEY};
use crate::types::{DataObjectDescriptor, GlobalMetadata, HashDescriptor, HashFrame, IndexFrame};
use crate::wire::{
FRAME_END, FRAME_HEADER_SIZE, FrameHeader, FrameType, MessageFlags, PREAMBLE_SIZE, Postamble,
Preamble,
};
use tensogram_encodings::pipeline;
/// Forward-only streaming encoder: writes the preamble and header metadata
/// up front, data-object frames as they arrive, and all footer frames plus
/// the postamble in [`StreamingEncoder::finish`].
pub struct StreamingEncoder<W: Write> {
    /// Destination sink; written strictly forward (never seeks back).
    writer: W,
    /// Start offset of each data-object frame, recorded for the footer index.
    object_offsets: Vec<u64>,
    /// Total on-wire length of each data-object frame, for the footer index.
    object_lengths: Vec<u64>,
    /// Per-object `(hash_type, hex_digest)` captured while streaming;
    /// `None` for objects written without hashing.
    hash_entries: Vec<Option<(String, String)>>,
    /// Descriptors of objects written so far (payloads deliberately empty);
    /// used by `finish` to enrich the footer's `base` entries.
    completed_objects: Vec<EncodedObject>,
    /// Running count of bytes emitted, kept in lockstep with the sink so
    /// frame padding/alignment can be computed without seeking.
    bytes_written: u64,
    /// Payload hash algorithm requested at construction, if any.
    hash_algorithm: Option<HashAlgorithm>,
    /// Global metadata captured at construction; re-emitted (enriched) in
    /// the footer metadata frame.
    global_meta: GlobalMetadata,
    /// True between a `write_preceder` call and the next object write;
    /// guards against consecutive or dangling preceders.
    pending_preceder: bool,
    /// One slot per written object: the preceder payload to merge into
    /// `base[i]` at finish, or `None` when the object had no preceder.
    preceder_payloads: Vec<Option<BTreeMap<String, ciborium::Value>>>,
    /// Thread budget resolved via `crate::parallel::resolve_budget`.
    intra_codec_threads: u32,
    /// Optional payload-size threshold for enabling parallel encoding.
    parallel_threshold_bytes: Option<usize>,
}
impl<W: Write> StreamingEncoder<W> {
    /// Create a streaming encoder, immediately writing the preamble and the
    /// HeaderMetadata frame (carrying `global_meta`) to `writer`.
    ///
    /// # Errors
    /// Fails if the metadata cannot be CBOR-encoded or the sink rejects
    /// the initial writes.
    pub fn new(
        mut writer: W,
        global_meta: &GlobalMetadata,
        options: &EncodeOptions,
    ) -> Result<Self> {
        let meta_cbor = metadata::global_metadata_to_cbor(global_meta)?;
        // Advertise every section this encoder can emit; FOOTER_HASHES only
        // when a hash algorithm was actually requested.
        let mut flags = MessageFlags::default();
        flags.set(MessageFlags::HEADER_METADATA);
        flags.set(MessageFlags::FOOTER_METADATA);
        flags.set(MessageFlags::FOOTER_INDEX);
        flags.set(MessageFlags::PRECEDER_METADATA);
        if options.hash_algorithm.is_some() {
            flags.set(MessageFlags::FOOTER_HASHES);
        }
        // total_length is left at 0: `W` is a forward-only `Write`, so it
        // cannot be backpatched once the footers are known.
        // NOTE(review): the decoder tolerates a zero total_length in the
        // round-trip tests below — confirm the wire spec sanctions this.
        let preamble = Preamble {
            version: 2,
            flags,
            reserved: 0,
            total_length: 0,
        };
        let preamble_bytes = preamble_to_bytes(&preamble);
        writer.write_all(&preamble_bytes)?;
        let mut bytes_written = PREAMBLE_SIZE as u64;
        let frame_bytes = build_frame(FrameType::HeaderMetadata, 1, 0, &meta_cbor);
        writer.write_all(&frame_bytes)?;
        bytes_written += frame_bytes.len() as u64;
        // Keep subsequent frames 8-byte aligned.
        write_padding(&mut writer, &mut bytes_written)?;
        let intra_codec_threads = crate::parallel::resolve_budget(options.threads);
        Ok(Self {
            writer,
            object_offsets: Vec::new(),
            object_lengths: Vec::new(),
            hash_entries: Vec::new(),
            completed_objects: Vec::new(),
            bytes_written,
            hash_algorithm: options.hash_algorithm,
            global_meta: global_meta.clone(),
            pending_preceder: false,
            preceder_payloads: Vec::new(),
            intra_codec_threads,
            parallel_threshold_bytes: options.parallel_threshold_bytes,
        })
    }

    /// Write a PrecederMetadata frame carrying per-object metadata for the
    /// *next* object. The payload is also retained in memory so `finish`
    /// can merge it into the footer's `base[i]` entry (preceder keys win
    /// over footer keys).
    ///
    /// # Errors
    /// - called twice without an intervening object write
    /// - `metadata` contains the library-reserved key
    pub fn write_preceder(&mut self, metadata: BTreeMap<String, ciborium::Value>) -> Result<()> {
        if self.pending_preceder {
            return Err(TensogramError::Framing(
                "write_preceder called twice without an intervening write_object/write_object_pre_encoded".to_string(),
            ));
        }
        if metadata.contains_key(RESERVED_KEY) {
            return Err(TensogramError::Metadata(format!(
                "client code must not write '{RESERVED_KEY}' in preceder metadata; \
                 this field is populated by the library"
            )));
        }
        // The preceder frame is a minimal GlobalMetadata whose base[0] is
        // the caller's map.
        let preceder_meta = GlobalMetadata {
            version: self.global_meta.version,
            base: vec![metadata.clone()],
            ..Default::default()
        };
        let cbor = crate::metadata::global_metadata_to_cbor(&preceder_meta)?;
        let frame_bytes = build_frame(FrameType::PrecederMetadata, 1, 0, &cbor);
        self.writer.write_all(&frame_bytes)?;
        self.bytes_written += frame_bytes.len() as u64;
        write_padding(&mut self.writer, &mut self.bytes_written)?;
        self.pending_preceder = true;
        self.preceder_payloads.push(Some(metadata));
        Ok(())
    }

    /// Encode `data` through the pipeline described by `desc` and write the
    /// result as a DataObject frame.
    ///
    /// Parallel intra-codec encoding is used when
    /// `crate::parallel::should_parallelise` says the payload clears the
    /// configured threshold. Block offsets reported by the codec are
    /// recorded in the descriptor's params under `szip_block_offsets`.
    pub fn write_object(&mut self, desc: &DataObjectDescriptor, data: &[u8]) -> Result<()> {
        validate_object(desc, data.len())?;
        // Element count = product of the shape, with overflow checked at
        // both the u64 multiplication and the usize conversion.
        let shape_product = desc
            .shape
            .iter()
            .try_fold(1u64, |acc, &x| acc.checked_mul(x))
            .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
        let num_elements = usize::try_from(shape_product)
            .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))?;
        let parallel = crate::parallel::should_parallelise(
            self.intra_codec_threads,
            data.len(),
            self.parallel_threshold_bytes,
        );
        // intra == 0 configures the pipeline for serial encoding.
        let intra = if parallel {
            self.intra_codec_threads
        } else {
            0
        };
        let config = crate::encode::build_pipeline_config_with_backend(
            desc,
            num_elements,
            desc.dtype,
            tensogram_encodings::pipeline::CompressionBackend::default(),
            intra,
        )?;
        let result =
            crate::parallel::run_maybe_pooled(self.intra_codec_threads, parallel, intra, || {
                pipeline::encode_pipeline(data, &config)
            })
            .map_err(|e| TensogramError::Encoding(e.to_string()))?;
        let mut final_desc = desc.clone();
        // Surface codec-produced block offsets in the descriptor params so
        // they travel with the object (validated on the pre-encoded path).
        if let Some(offsets) = &result.block_offsets {
            final_desc.params.insert(
                "szip_block_offsets".to_string(),
                ciborium::Value::Array(
                    offsets
                        .iter()
                        .map(|&o| ciborium::Value::Integer(o.into()))
                        .collect(),
                ),
            );
        }
        self.write_object_inner(final_desc, &result.encoded_bytes)
    }

    /// Write bytes that were already encoded elsewhere as a DataObject
    /// frame, after validating the descriptor (the pipeline config must
    /// build, and szip block offsets are only allowed — and then checked —
    /// for szip-compressed objects).
    #[tracing::instrument(skip(self, descriptor, pre_encoded_bytes))]
    pub fn write_object_pre_encoded(
        &mut self,
        descriptor: &DataObjectDescriptor,
        pre_encoded_bytes: &[u8],
    ) -> Result<()> {
        validate_object(descriptor, pre_encoded_bytes.len())?;
        let shape_product = descriptor
            .shape
            .iter()
            .try_fold(1u64, |acc, &x| acc.checked_mul(x))
            .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
        let num_elements = usize::try_from(shape_product)
            .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))?;
        // Built purely for validation; the bytes are already encoded, so
        // the config itself is discarded.
        build_pipeline_config(descriptor, num_elements, descriptor.dtype)?;
        validate_no_szip_offsets_for_non_szip(descriptor)?;
        if descriptor.compression == "szip" && descriptor.params.contains_key("szip_block_offsets")
        {
            validate_szip_block_offsets(&descriptor.params, pre_encoded_bytes.len())?;
        }
        self.write_object_inner(descriptor.clone(), pre_encoded_bytes)
    }

    /// Shared tail of both write paths: emit the frame (hashing the payload
    /// inline when enabled), record index/hash bookkeeping, and settle the
    /// preceder slot for this object.
    fn write_object_inner(
        &mut self,
        mut final_desc: DataObjectDescriptor,
        encoded_bytes: &[u8],
    ) -> Result<()> {
        let start_offset = self.bytes_written;
        let frame_len = write_data_object_frame_hashed(
            &mut self.writer,
            &mut final_desc,
            encoded_bytes,
            self.hash_algorithm,
        )?;
        self.bytes_written += frame_len;
        // The frame writer fills in `final_desc.hash`; capture it for the
        // footer hash frame.
        let hash_entry = final_desc
            .hash
            .as_ref()
            .map(|h| (h.hash_type.clone(), h.value.clone()));
        self.object_offsets.push(start_offset);
        self.object_lengths.push(frame_len);
        self.hash_entries.push(hash_entry);
        // The payload is not retained — only the descriptor is needed to
        // enrich the footer metadata at finish.
        self.completed_objects.push(EncodedObject {
            descriptor: final_desc,
            encoded_payload: Vec::new(),
        });
        // Keep preceder_payloads aligned 1:1 with objects: either consume
        // the pending preceder (its slot was pushed by write_preceder) or
        // push a None placeholder.
        if self.pending_preceder {
            self.pending_preceder = false;
        } else {
            self.preceder_payloads.push(None);
        }
        write_padding(&mut self.writer, &mut self.bytes_written)?;
        Ok(())
    }

    /// Emit the footer frames (metadata, optional hashes, index) and the
    /// postamble, flush, and return the underlying writer.
    ///
    /// # Errors
    /// Fails on a dangling preceder, CBOR encoding failure, an internal
    /// bookkeeping mismatch, or sink write errors.
    pub fn finish(mut self) -> Result<W> {
        if self.pending_preceder {
            return Err(TensogramError::Framing(
                "dangling PrecederMetadata: finish called without a following write_object/write_object_pre_encoded"
                    .to_string(),
            ));
        }
        let footer_start = self.bytes_written;
        {
            // Footer metadata = the caller's global metadata, enriched with
            // per-object base entries and library provenance, then overlaid
            // with the stored preceder payloads (preceder keys win).
            let mut enriched_meta = self.global_meta.clone();
            populate_base_entries(&mut enriched_meta.base, &self.completed_objects);
            populate_reserved_provenance(&mut enriched_meta.reserved);
            if self.preceder_payloads.len() != self.completed_objects.len() {
                return Err(TensogramError::Framing(format!(
                    "internal: preceder_payloads ({}) out of sync with completed_objects ({})",
                    self.preceder_payloads.len(),
                    self.completed_objects.len()
                )));
            }
            for (i, prec) in self.preceder_payloads.iter().enumerate() {
                if let Some(prec_map) = prec
                    && i < enriched_meta.base.len()
                {
                    for (k, v) in prec_map {
                        enriched_meta.base[i].insert(k.clone(), v.clone());
                    }
                }
            }
            let meta_cbor = metadata::global_metadata_to_cbor(&enriched_meta)?;
            let frame_bytes = build_frame(FrameType::FooterMetadata, 1, 0, &meta_cbor);
            self.writer.write_all(&frame_bytes)?;
            self.bytes_written += frame_bytes.len() as u64;
            write_padding(&mut self.writer, &mut self.bytes_written)?;
        }
        // A hash frame is emitted only if at least one object was hashed;
        // un-hashed objects contribute an empty-string placeholder.
        let has_hashes = self.hash_entries.iter().any(|e| e.is_some());
        if has_hashes {
            let hash_type = self
                .hash_algorithm
                .map(|a| a.as_str().to_string())
                .unwrap_or_default();
            let hashes: Vec<String> = self
                .hash_entries
                .iter()
                .map(|e| e.as_ref().map(|(_, v)| v.clone()).unwrap_or_default())
                .collect();
            let hash_frame = HashFrame {
                object_count: self.object_offsets.len() as u64,
                hash_type,
                hashes,
            };
            let hash_cbor = metadata::hash_frame_to_cbor(&hash_frame)?;
            let frame_bytes = build_frame(FrameType::FooterHash, 1, 0, &hash_cbor);
            self.writer.write_all(&frame_bytes)?;
            self.bytes_written += frame_bytes.len() as u64;
            write_padding(&mut self.writer, &mut self.bytes_written)?;
        }
        // Note: `object_count` is read before `offsets`/`lengths` move the
        // vectors — struct fields are evaluated in order.
        let index = IndexFrame {
            object_count: self.object_offsets.len() as u64,
            offsets: self.object_offsets,
            lengths: self.object_lengths,
        };
        let index_cbor = metadata::index_to_cbor(&index)?;
        let frame_bytes = build_frame(FrameType::FooterIndex, 1, 0, &index_cbor);
        self.writer.write_all(&frame_bytes)?;
        self.bytes_written += frame_bytes.len() as u64;
        write_padding(&mut self.writer, &mut self.bytes_written)?;
        // The postamble records where the footer section starts so a reader
        // can locate it from the end of the message.
        let postamble = Postamble {
            first_footer_offset: footer_start,
        };
        let mut postamble_bytes = Vec::with_capacity(16);
        postamble.write_to(&mut postamble_bytes);
        self.writer.write_all(&postamble_bytes)?;
        self.writer.flush()?;
        Ok(self.writer)
    }

    /// Number of data objects written so far.
    pub fn object_count(&self) -> usize {
        self.object_offsets.len()
    }

    /// Total bytes emitted to the sink so far (including alignment padding).
    pub fn bytes_written(&self) -> u64 {
        self.bytes_written
    }
}
/// Serialise a [`Preamble`] into a freshly allocated byte buffer of
/// exactly `PREAMBLE_SIZE` capacity.
fn preamble_to_bytes(preamble: &Preamble) -> Vec<u8> {
    let mut serialised = Vec::with_capacity(PREAMBLE_SIZE);
    preamble.write_to(&mut serialised);
    serialised
}
/// Assemble a complete non-data frame: header, CBOR payload, end marker.
///
/// `total_length` covers the header, the payload, and the end marker, and
/// is recorded inside the header itself.
fn build_frame(frame_type: FrameType, version: u16, flags: u16, payload: &[u8]) -> Vec<u8> {
    let body_len = payload.len() + FRAME_END.len();
    let total_length = (FRAME_HEADER_SIZE + body_len) as u64;
    let header = FrameHeader {
        frame_type,
        version,
        flags,
        total_length,
    };
    let mut frame = Vec::with_capacity(total_length as usize);
    header.write_to(&mut frame);
    frame.extend_from_slice(payload);
    frame.extend_from_slice(FRAME_END);
    frame
}
/// Largest possible padding run (alignment is 8, so at most 7 zero bytes).
const ZERO_PAD: [u8; 7] = [0; 7];

/// Write zero bytes until `*bytes_written` is a multiple of 8, updating
/// the running byte count to match what was emitted.
fn write_padding(writer: &mut impl Write, bytes_written: &mut u64) -> std::io::Result<()> {
    let remainder = (*bytes_written % 8) as usize;
    if remainder != 0 {
        let pad = 8 - remainder;
        writer.write_all(&ZERO_PAD[..pad])?;
        *bytes_written += pad as u64;
    }
    Ok(())
}
/// Write one DataObject frame, streaming `payload` through the hasher while
/// copying it to `writer`, then appending the CBOR descriptor *after* the
/// payload (CBOR_AFTER_PAYLOAD layout).
///
/// The frame header carries `total_length` and must be written before the
/// payload, so the descriptor's CBOR size has to be known before the digest
/// exists. To guarantee that, the descriptor is CBOR-encoded twice with
/// same-length placeholder digests (all-'0' and all-'f'); if the two
/// encodings differ in size the hash algorithm's CBOR length is
/// value-dependent and the call fails before any frame bytes are written.
///
/// On success `descriptor.hash` holds the computed digest (or `None` when
/// hashing is disabled) and the frame's total on-wire length is returned.
fn write_data_object_frame_hashed<W: Write>(
    writer: &mut W,
    descriptor: &mut DataObjectDescriptor,
    payload: &[u8],
    hash_algorithm: Option<HashAlgorithm>,
) -> Result<u64> {
    use crate::wire::{DATA_OBJECT_FOOTER_SIZE, DataObjectFlags, FRAME_END};
    let cbor_len_estimate = match hash_algorithm {
        None => {
            descriptor.hash = None;
            metadata::object_descriptor_to_cbor(descriptor)?.len()
        }
        Some(alg) => {
            // Probe with two extreme digests of the real hex length to prove
            // the CBOR size does not depend on the digest's value.
            let placeholder_len = alg.hex_digest_len();
            descriptor.hash = Some(HashDescriptor {
                hash_type: alg.as_str().to_string(),
                value: "0".repeat(placeholder_len),
            });
            let len_zeros = metadata::object_descriptor_to_cbor(descriptor)?.len();
            descriptor.hash = Some(HashDescriptor {
                hash_type: alg.as_str().to_string(),
                value: "f".repeat(placeholder_len),
            });
            let len_ones = metadata::object_descriptor_to_cbor(descriptor)?.len();
            if len_zeros != len_ones {
                return Err(TensogramError::Framing(format!(
                    "streaming encoder requires a hash algorithm with a \
                     value-independent CBOR encoding length; {} produced \
                     {len_zeros} bytes for an all-zero digest and \
                     {len_ones} bytes for an all-'f' digest of the same \
                     length ({placeholder_len} chars). No frame bytes \
                     have been written. Use the buffered encode() API \
                     for this hash algorithm, or extend \
                     write_data_object_frame_hashed to handle variable-\
                     length digests.",
                    alg.as_str()
                )));
            }
            len_zeros
        }
    };
    let payload_len = payload.len();
    // total_length = header + CBOR + payload + footer, overflow-checked.
    let total_length = (FRAME_HEADER_SIZE as u64)
        .checked_add(cbor_len_estimate as u64)
        .and_then(|n| n.checked_add(payload_len as u64))
        .and_then(|n| n.checked_add(DATA_OBJECT_FOOTER_SIZE as u64))
        .ok_or_else(|| {
            TensogramError::Framing(format!(
                "data object frame total_length overflows u64 \
                 (payload {payload_len} bytes, CBOR {cbor_len_estimate} bytes, \
                 framing {} bytes)",
                FRAME_HEADER_SIZE + DATA_OBJECT_FOOTER_SIZE
            ))
        })?;
    let mut header_bytes = Vec::with_capacity(FRAME_HEADER_SIZE);
    FrameHeader {
        frame_type: FrameType::DataObject,
        version: 1,
        flags: DataObjectFlags::CBOR_AFTER_PAYLOAD,
        total_length,
    }
    .write_to(&mut header_bytes);
    writer.write_all(&header_bytes)?;
    // Copy the payload in 64 KiB chunks, feeding the hasher as we go so the
    // payload is traversed exactly once.
    const CHUNK: usize = 64 * 1024;
    let mut inline_hasher: Option<(xxhash_rust::xxh3::Xxh3Default, HashAlgorithm)> = hash_algorithm
        .map(|alg| match alg {
            HashAlgorithm::Xxh3 => (xxhash_rust::xxh3::Xxh3Default::new(), alg),
        });
    let mut offset = 0;
    while offset < payload_len {
        let end = (offset + CHUNK).min(payload_len);
        let chunk = &payload[offset..end];
        if let Some((h, _)) = &mut inline_hasher {
            h.update(chunk);
        }
        writer.write_all(chunk)?;
        offset = end;
    }
    // Replace the placeholder with the real digest; the resulting CBOR
    // length must match the pre-write estimate (debug-asserted below).
    descriptor.hash = inline_hasher.map(|(h, alg)| match alg {
        HashAlgorithm::Xxh3 => HashDescriptor {
            hash_type: alg.as_str().to_string(),
            value: crate::hash::format_xxh3_digest(h.digest()),
        },
    });
    let cbor_bytes = metadata::object_descriptor_to_cbor(descriptor)?;
    debug_assert_eq!(
        cbor_bytes.len(),
        cbor_len_estimate,
        "write_data_object_frame_hashed: final CBOR length \
         ({}) differs from pre-write estimate ({cbor_len_estimate}) — \
         this indicates a non-deterministic CBOR serialiser, not a \
         hash-length problem (that would have been caught earlier)",
        cbor_bytes.len(),
    );
    writer.write_all(&cbor_bytes)?;
    // Frame footer: big-endian byte offset of the CBOR section within the
    // frame, followed by the end marker.
    let cbor_offset = (FRAME_HEADER_SIZE + payload_len) as u64;
    writer.write_all(&cbor_offset.to_be_bytes())?;
    writer.write_all(FRAME_END)?;
    Ok(total_length)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dtype;
    use crate::decode::{DecodeOptions, decode};
    use crate::encode::{EncodeOptions, encode};
    use crate::types::{ByteOrder, DataObjectDescriptor};
    use std::collections::BTreeMap;

    /// Build an uncompressed, unhashed float32 tensor descriptor with
    /// C-contiguous (row-major) strides for the given shape.
    fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
        let ndim = shape.len() as u64;
        // Row-major strides: last axis has stride 1; each earlier axis is
        // the product of all later dimensions.
        let mut strides = vec![0u64; shape.len()];
        if !shape.is_empty() {
            strides[shape.len() - 1] = 1;
            for i in (0..shape.len() - 1).rev() {
                strides[i] = strides[i + 1] * shape[i + 1];
            }
        }
        DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim,
            shape,
            strides,
            dtype: Dtype::Float32,
            byte_order: ByteOrder::native(),
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            hash: None,
        }
    }

    // One streamed object decodes back to the original payload bytes.
    #[test]
    fn streaming_single_object_round_trip() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 4 * 4];
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();
        let (decoded_meta, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(decoded_meta.version, 2);
        assert_eq!(objects.len(), 1);
        assert_eq!(objects[0].1, data);
    }

    // Two objects of different sizes round-trip in order.
    #[test]
    fn streaming_multi_object_round_trip() {
        let meta = GlobalMetadata::default();
        let desc1 = make_descriptor(vec![4]);
        let desc2 = make_descriptor(vec![8]);
        let data1 = vec![1u8; 4 * 4];
        let data2 = vec![2u8; 8 * 4];
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_object(&desc1, &data1).unwrap();
        enc.write_object(&desc2, &data2).unwrap();
        assert_eq!(enc.object_count(), 2);
        let result = enc.finish().unwrap();
        let (_, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(objects.len(), 2);
        assert_eq!(objects[0].1, data1);
        assert_eq!(objects[1].1, data2);
    }

    // Streaming output decodes to the same descriptor/payload/hash as the
    // buffered encode() API for the same input.
    #[test]
    fn streaming_matches_buffered_single_object() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![4]);
        let data = vec![42u8; 4 * 4];
        let options = EncodeOptions {
            compression_backend: Default::default(),
            hash_algorithm: Some(HashAlgorithm::Xxh3),
            emit_preceders: false,
            threads: 0,
            parallel_threshold_bytes: None,
        };
        let buffered = encode(&meta, &[(&desc, &data)], &options).unwrap();
        let (buf_meta, buf_objects) = decode(
            &buffered,
            &DecodeOptions {
                verify_hash: true,
                ..Default::default()
            },
        )
        .unwrap();
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &options).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let streamed = enc.finish().unwrap();
        let (str_meta, str_objects) = decode(
            &streamed,
            &DecodeOptions {
                verify_hash: true,
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(buf_meta.version, str_meta.version);
        assert_eq!(buf_objects.len(), str_objects.len());
        assert_eq!(buf_objects[0].0.shape, str_objects[0].0.shape);
        assert_eq!(buf_objects[0].0.dtype, str_objects[0].0.dtype);
        assert_eq!(buf_objects[0].1, str_objects[0].1);
        assert_eq!(
            buf_objects[0].0.hash.as_ref().unwrap().value,
            str_objects[0].0.hash.as_ref().unwrap().value
        );
    }

    // A streamed message with hashing enabled passes hash verification on
    // decode and carries a hash descriptor.
    #[test]
    fn streaming_hash_verification() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![4]);
        let data = vec![42u8; 4 * 4];
        let options = EncodeOptions {
            compression_backend: Default::default(),
            hash_algorithm: Some(HashAlgorithm::Xxh3),
            emit_preceders: false,
            threads: 0,
            parallel_threshold_bytes: None,
        };
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &options).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();
        let verify_opts = DecodeOptions {
            verify_hash: true,
            ..Default::default()
        };
        let (_, objects) = decode(&result, &verify_opts).unwrap();
        assert!(objects[0].0.hash.is_some());
    }

    // finish() with no objects still produces a decodable empty message.
    #[test]
    fn streaming_no_objects() {
        let meta = GlobalMetadata::default();
        let options = EncodeOptions {
            compression_backend: Default::default(),
            hash_algorithm: None,
            emit_preceders: false,
            threads: 0,
            parallel_threshold_bytes: None,
        };
        let buf = Vec::new();
        let enc = StreamingEncoder::new(buf, &meta, &options).unwrap();
        assert_eq!(enc.object_count(), 0);
        let result = enc.finish().unwrap();
        let (decoded_meta, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(decoded_meta.version, 2);
        assert_eq!(objects.len(), 0);
    }

    // Thread count must be transparent: encoded payloads are byte-identical
    // for threads = 0, 1, 2, 4, 8 (threshold forced to 0 so parallelism
    // always engages when threads > 0).
    #[test]
    fn streaming_threads_byte_identical_transparent() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![50_000]);
        let data: Vec<u8> = (0..50_000)
            .flat_map(|i| (250.0f32 + (i as f32).sin() * 30.0).to_ne_bytes())
            .collect();
        let mk = |threads: u32| -> Vec<u8> {
            let buf = Vec::new();
            let opts = EncodeOptions {
                threads,
                parallel_threshold_bytes: Some(0),
                ..Default::default()
            };
            let mut enc = StreamingEncoder::new(buf, &meta, &opts).unwrap();
            enc.write_object(&desc, &data).unwrap();
            enc.finish().unwrap()
        };
        // Extract just the encoded payload bytes from a full message.
        let payloads = |buf: &[u8]| -> Vec<Vec<u8>> {
            crate::framing::decode_message(buf)
                .unwrap()
                .objects
                .iter()
                .map(|(_, p, _)| p.to_vec())
                .collect()
        };
        let baseline = mk(0);
        let payloads_baseline = payloads(&baseline);
        for t in [1u32, 2, 4, 8] {
            let got = mk(t);
            assert_eq!(
                payloads_baseline,
                payloads(&got),
                "streaming threads={t} payload must match sequential"
            );
        }
    }

    // Extra top-level metadata supplied at construction survives the
    // round trip.
    #[test]
    fn streaming_with_metadata() {
        let mut extra = BTreeMap::new();
        extra.insert(
            "centre".to_string(),
            ciborium::Value::Text("ecmwf".to_string()),
        );
        let meta = GlobalMetadata {
            version: 2,
            extra,
            ..Default::default()
        };
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 4 * 4];
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();
        let (decoded_meta, _) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(
            decoded_meta.extra.get("centre"),
            Some(&ciborium::Value::Text("ecmwf".to_string()))
        );
    }

    // Preceder metadata written before an object ends up in base[0].
    #[test]
    fn streaming_preceder_round_trip() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![4]);
        let data = vec![42u8; 4 * 4];
        let mut prec = BTreeMap::new();
        prec.insert(
            "mars".to_string(),
            ciborium::Value::Map(vec![(
                ciborium::Value::Text("param".to_string()),
                ciborium::Value::Text("2t".to_string()),
            )]),
        );
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(prec).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();
        let (decoded_meta, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(objects.len(), 1);
        assert_eq!(objects[0].1, data);
        let mars = decoded_meta.base[0].get("mars");
        assert!(mars.is_some(), "mars key should be in base[0]");
    }

    // When a key exists in both the preceder and the footer base entry,
    // the preceder's value takes precedence.
    #[test]
    fn streaming_preceder_wins_over_footer() {
        let mut footer_base = BTreeMap::new();
        footer_base.insert(
            "source".to_string(),
            ciborium::Value::Text("footer".to_string()),
        );
        let meta = GlobalMetadata {
            version: 2,
            base: vec![footer_base],
            ..Default::default()
        };
        let mut prec = BTreeMap::new();
        prec.insert(
            "source".to_string(),
            ciborium::Value::Text("preceder".to_string()),
        );
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 4 * 4];
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(prec).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();
        let (decoded_meta, _) = decode(&result, &DecodeOptions::default()).unwrap();
        let source = decoded_meta.base[0].get("source").and_then(|v| match v {
            ciborium::Value::Text(s) => Some(s.as_str()),
            _ => None,
        });
        assert_eq!(source, Some("preceder"), "preceder should win over footer");
    }

    // Two preceders in a row (no object between them) must be rejected.
    #[test]
    fn streaming_consecutive_preceder_error() {
        let meta = GlobalMetadata::default();
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(BTreeMap::new()).unwrap();
        let result = enc.write_preceder(BTreeMap::new());
        assert!(
            result.is_err(),
            "two write_preceder calls without intervening write_object should fail"
        );
    }

    // finish() after a preceder with no following object must fail.
    #[test]
    fn streaming_dangling_preceder_error() {
        let meta = GlobalMetadata::default();
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(BTreeMap::new()).unwrap();
        let result = enc.finish();
        assert!(
            result.is_err(),
            "finish with a dangling preceder should fail"
        );
    }

    // A preceder attaches only to the object immediately following it.
    #[test]
    fn streaming_mixed_objects_with_and_without_preceders() {
        let meta = GlobalMetadata::default();
        let desc0 = make_descriptor(vec![4]);
        let desc1 = make_descriptor(vec![8]);
        let data0 = vec![1u8; 4 * 4];
        let data1 = vec![2u8; 8 * 4];
        let mut prec = BTreeMap::new();
        prec.insert(
            "note".to_string(),
            ciborium::Value::Text("only for obj 0".to_string()),
        );
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(prec).unwrap();
        enc.write_object(&desc0, &data0).unwrap();
        enc.write_object(&desc1, &data1).unwrap();
        let result = enc.finish().unwrap();
        let (decoded_meta, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(objects.len(), 2);
        assert_eq!(objects[0].1, data0);
        assert_eq!(objects[1].1, data1);
        assert!(decoded_meta.base[0].contains_key("note"));
        assert!(!decoded_meta.base[1].contains_key("note"));
    }

    // Nested preceder values survive intact, and the library adds its own
    // _reserved_ entry alongside them.
    #[test]
    fn streaming_preceder_metadata_preservation() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![2]);
        let data = vec![0u8; 2 * 4];
        let mut prec = BTreeMap::new();
        prec.insert("units".to_string(), ciborium::Value::Text("K".to_string()));
        prec.insert(
            "mars".to_string(),
            ciborium::Value::Map(vec![
                (
                    ciborium::Value::Text("param".to_string()),
                    ciborium::Value::Text("2t".to_string()),
                ),
                (
                    ciborium::Value::Text("levtype".to_string()),
                    ciborium::Value::Text("sfc".to_string()),
                ),
            ]),
        );
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(prec).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();
        let (decoded_meta, _) = decode(&result, &DecodeOptions::default()).unwrap();
        let p = &decoded_meta.base[0];
        assert_eq!(
            p.get("units"),
            Some(&ciborium::Value::Text("K".to_string()))
        );
        assert!(p.contains_key("mars"));
        assert!(p.contains_key("_reserved_"));
    }

    // Client-supplied _reserved_ in a preceder is rejected with a message
    // naming the key.
    #[test]
    fn streaming_preceder_with_reserved_rejected() {
        let meta = GlobalMetadata::default();
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        let mut prec = BTreeMap::new();
        prec.insert("_reserved_".to_string(), ciborium::Value::Map(vec![]));
        let result = enc.write_preceder(prec);
        assert!(result.is_err(), "_reserved_ in preceder should be rejected");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("_reserved_"),
            "error should mention _reserved_: {err}"
        );
    }

    // Builds a wire message by hand (bypassing the encoder's _reserved_
    // guard) so the preceder frame can smuggle a rogue _reserved_ entry;
    // the decoder must strip the rogue entry but keep the footer's own
    // _reserved_.tensor.
    #[test]
    fn streaming_preceder_reserved_stripped_on_decode() {
        let mut prec_entry = BTreeMap::new();
        prec_entry.insert(
            "mars".to_string(),
            ciborium::Value::Map(vec![(
                ciborium::Value::Text("param".to_string()),
                ciborium::Value::Text("2t".to_string()),
            )]),
        );
        prec_entry.insert(
            "_reserved_".to_string(),
            ciborium::Value::Map(vec![(
                ciborium::Value::Text("rogue".to_string()),
                ciborium::Value::Text("bad".to_string()),
            )]),
        );
        let preceder_meta = GlobalMetadata {
            version: 2,
            base: vec![prec_entry],
            ..Default::default()
        };
        let preceder_cbor = crate::metadata::global_metadata_to_cbor(&preceder_meta).unwrap();
        let desc_for_frame = make_descriptor(vec![4]);
        let payload = vec![0u8; 4 * 4];
        let frame =
            crate::framing::encode_data_object_frame(&desc_for_frame, &payload, false).unwrap();
        // Footer base entry with a legitimate _reserved_.tensor block.
        let mut footer_base = BTreeMap::new();
        let tensor_map = ciborium::Value::Map(vec![
            (
                ciborium::Value::Text("ndim".to_string()),
                ciborium::Value::Integer(1.into()),
            ),
            (
                ciborium::Value::Text("shape".to_string()),
                ciborium::Value::Array(vec![ciborium::Value::Integer(4.into())]),
            ),
            (
                ciborium::Value::Text("strides".to_string()),
                ciborium::Value::Array(vec![ciborium::Value::Integer(1.into())]),
            ),
            (
                ciborium::Value::Text("dtype".to_string()),
                ciborium::Value::Text("float32".to_string()),
            ),
        ]);
        footer_base.insert(
            "_reserved_".to_string(),
            ciborium::Value::Map(vec![(
                ciborium::Value::Text("tensor".to_string()),
                tensor_map,
            )]),
        );
        let footer_meta = GlobalMetadata {
            version: 2,
            base: vec![footer_base],
            ..Default::default()
        };
        let footer_cbor = crate::metadata::global_metadata_to_cbor(&footer_meta).unwrap();
        use crate::wire::*;
        let header_meta_cbor =
            crate::metadata::global_metadata_to_cbor(&GlobalMetadata::default()).unwrap();
        // Assemble the message: placeholder preamble, header metadata,
        // preceder, data object, footer metadata — each 8-byte aligned.
        let mut out = Vec::new();
        out.extend_from_slice(&[0u8; PREAMBLE_SIZE]);
        let total_length = (FRAME_HEADER_SIZE + header_meta_cbor.len() + FRAME_END.len()) as u64;
        let fh = FrameHeader {
            frame_type: FrameType::HeaderMetadata,
            version: 1,
            flags: 0,
            total_length,
        };
        fh.write_to(&mut out);
        out.extend_from_slice(&header_meta_cbor);
        out.extend_from_slice(FRAME_END);
        let pad = (8 - (out.len() % 8)) % 8;
        out.extend(std::iter::repeat_n(0u8, pad));
        let total_length = (FRAME_HEADER_SIZE + preceder_cbor.len() + FRAME_END.len()) as u64;
        let fh = FrameHeader {
            frame_type: FrameType::PrecederMetadata,
            version: 1,
            flags: 0,
            total_length,
        };
        fh.write_to(&mut out);
        out.extend_from_slice(&preceder_cbor);
        out.extend_from_slice(FRAME_END);
        let pad = (8 - (out.len() % 8)) % 8;
        out.extend(std::iter::repeat_n(0u8, pad));
        out.extend_from_slice(&frame);
        let pad = (8 - (out.len() % 8)) % 8;
        out.extend(std::iter::repeat_n(0u8, pad));
        let total_length = (FRAME_HEADER_SIZE + footer_cbor.len() + FRAME_END.len()) as u64;
        let fh = FrameHeader {
            frame_type: FrameType::FooterMetadata,
            version: 1,
            flags: 0,
            total_length,
        };
        fh.write_to(&mut out);
        out.extend_from_slice(&footer_cbor);
        out.extend_from_slice(FRAME_END);
        let pad = (8 - (out.len() % 8)) % 8;
        out.extend(std::iter::repeat_n(0u8, pad));
        let postamble_offset = out.len();
        let postamble = Postamble {
            first_footer_offset: postamble_offset as u64,
        };
        postamble.write_to(&mut out);
        // Backpatch the real preamble over the placeholder bytes.
        let total_length = out.len() as u64;
        let mut flags = MessageFlags::default();
        flags.set(MessageFlags::HEADER_METADATA);
        flags.set(MessageFlags::FOOTER_METADATA);
        flags.set(MessageFlags::PRECEDER_METADATA);
        let preamble = Preamble {
            version: 2,
            flags,
            reserved: 0,
            total_length,
        };
        let mut preamble_bytes = Vec::new();
        preamble.write_to(&mut preamble_bytes);
        out[0..PREAMBLE_SIZE].copy_from_slice(&preamble_bytes);
        let decoded = crate::framing::decode_message(&out).unwrap();
        let base0 = &decoded.global_metadata.base[0];
        assert!(
            base0.contains_key("mars"),
            "mars from preceder should survive"
        );
        let reserved = base0.get("_reserved_");
        assert!(
            reserved.is_some(),
            "_reserved_ from footer should be present"
        );
        if let Some(ciborium::Value::Map(pairs)) = reserved {
            let has_tensor = pairs
                .iter()
                .any(|(k, _)| *k == ciborium::Value::Text("tensor".to_string()));
            assert!(has_tensor, "tensor key from footer should be preserved");
            let has_rogue = pairs
                .iter()
                .any(|(k, _)| *k == ciborium::Value::Text("rogue".to_string()));
            assert!(
                !has_rogue,
                "rogue key from preceder's _reserved_ should have been stripped"
            );
        }
    }

    // write_object and write_object_pre_encoded can be interleaved in one
    // stream; all payloads round-trip in order.
    #[test]
    fn test_streaming_mixed_mode_pre_encoded() {
        let meta = GlobalMetadata::default();
        let desc0 = make_descriptor(vec![4]);
        let desc2 = make_descriptor(vec![6]);
        let desc1 = make_descriptor(vec![5]);
        let data0 = vec![1u8; 4 * 4];
        let pre_encoded1 = vec![2u8; 5 * 4];
        let data2 = vec![3u8; 6 * 4];
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_object(&desc0, &data0).unwrap();
        enc.write_object_pre_encoded(&desc1, &pre_encoded1).unwrap();
        enc.write_object(&desc2, &data2).unwrap();
        assert_eq!(enc.object_count(), 3);
        let result = enc.finish().unwrap();
        let (_, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(objects.len(), 3);
        assert_eq!(objects[0].1, data0, "object 0 payload mismatch");
        assert_eq!(objects[1].1, pre_encoded1, "object 1 payload mismatch");
        assert_eq!(objects[2].1, data2, "object 2 payload mismatch");
    }

    // A preceder also attaches to an object written via the pre-encoded
    // path.
    #[test]
    fn test_streaming_preceder_then_pre_encoded() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![4]);
        let pre_encoded = vec![42u8; 4 * 4];
        let mut prec = BTreeMap::new();
        prec.insert(
            "mars".to_string(),
            ciborium::Value::Map(vec![(
                ciborium::Value::Text("param".to_string()),
                ciborium::Value::Text("2t".to_string()),
            )]),
        );
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(prec).unwrap();
        enc.write_object_pre_encoded(&desc, &pre_encoded).unwrap();
        let result = enc.finish().unwrap();
        let (decoded_meta, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(objects.len(), 1);
        assert_eq!(objects[0].1, pre_encoded, "pre-encoded payload mismatch");
        let mars = decoded_meta.base[0].get("mars");
        assert!(
            mars.is_some(),
            "mars key from preceder should be in base[0]"
        );
    }

    // The preceder merge at finish() must not overwrite the library's
    // _reserved_.tensor entry in the footer base.
    #[test]
    fn streaming_finish_preserves_preceder_does_not_clobber_reserved_tensor() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![4]);
        let data = vec![42u8; 4 * 4];
        let mut prec = BTreeMap::new();
        prec.insert("units".to_string(), ciborium::Value::Text("K".to_string()));
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(prec).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();
        let (decoded_meta, _) = decode(&result, &DecodeOptions::default()).unwrap();
        let base0 = &decoded_meta.base[0];
        assert!(base0.contains_key("units"));
        let reserved = base0.get("_reserved_").expect("_reserved_ missing");
        if let ciborium::Value::Map(pairs) = reserved {
            let has_tensor = pairs
                .iter()
                .any(|(k, _)| *k == ciborium::Value::Text("tensor".to_string()));
            assert!(
                has_tensor,
                "_reserved_.tensor should be present after preceder merge"
            );
        } else {
            panic!("_reserved_ should be a map");
        }
    }
}