use std::collections::BTreeMap;
use proptest::prelude::*;
use tensogram::{
ByteOrder, DataObjectDescriptor, DecodeOptions, Dtype, EncodeOptions, GlobalMetadata, decode,
encode,
};
/// Compression codecs exercised by the property tests in this file.
///
/// The enum is fieldless, so it derives `Copy` and equality in addition to
/// `Debug`/`Clone`; values can be compared directly and passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Compression {
    /// No compression (wire name `"none"`).
    None,
    /// LZ4 (wire name `"lz4"`).
    Lz4,
    /// Blosc2 (wire name `"blosc2"`).
    Blosc2,
    /// Zstandard (wire name `"zstd"`).
    Zstd,
}
impl Compression {
fn wire_name(&self) -> &'static str {
match self {
Compression::None => "none",
Compression::Lz4 => "lz4",
Compression::Blosc2 => "blosc2",
Compression::Zstd => "zstd",
}
}
fn is_transparent(&self) -> bool {
matches!(self, Compression::None | Compression::Lz4)
}
}
fn compression_strategy() -> impl Strategy<Value = Compression> {
prop_oneof![
Just(Compression::None),
Just(Compression::Lz4),
Just(Compression::Blosc2),
Just(Compression::Zstd),
]
}
/// Builds an `"ntensor"` descriptor for a `Float32` array of the given
/// `shape`, using the requested compression codec.
///
/// Strides are expressed in elements and follow row-major order: the last
/// axis has stride 1 and each earlier axis has the stride of the next axis
/// multiplied by that axis' extent. Encoding and filter are both `"none"`,
/// byte order is the native one, and no hash is attached.
fn make_descriptor(shape: Vec<u64>, compression: &Compression) -> DataObjectDescriptor {
    use ciborium::Value;

    let rank = shape.len();

    // Walk the axes from innermost to outermost, filling in each stride from
    // the already-computed stride of the axis to its right.
    let mut strides = vec![1u64; rank];
    for axis in (1..rank).rev() {
        strides[axis - 1] = strides[axis] * shape[axis];
    }

    // Codec-specific tuning parameters; codecs without parameters get an
    // empty map.
    let params = match compression {
        Compression::Blosc2 => BTreeMap::from([
            ("blosc2_clevel".to_string(), Value::Integer(3.into())),
            ("blosc2_codec".to_string(), Value::Text("lz4".to_string())),
        ]),
        Compression::Zstd => {
            BTreeMap::from([("zstd_level".to_string(), Value::Integer(3.into()))])
        }
        Compression::None | Compression::Lz4 => BTreeMap::new(),
    };

    DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: rank as u64,
        shape,
        strides,
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: compression.wire_name().to_string(),
        params,
        hash: None,
    }
}
/// Produces `len` deterministic pseudo-random-looking bytes.
///
/// Byte `i` is the low 8 bits of `(i as u32) * 2654435761` (wrapping)
/// XOR-ed with `seed`, so different seeds give different payloads of the
/// same length.
fn make_payload(len: usize, seed: u8) -> Vec<u8> {
    const MIX_MULTIPLIER: u32 = 2654435761;

    let seed_bits = seed as u32;
    let mut bytes = Vec::with_capacity(len);
    for index in 0..len {
        // The index is deliberately truncated to 32 bits before mixing.
        let mixed = (index as u32).wrapping_mul(MIX_MULTIPLIER) ^ seed_bits;
        bytes.push(mixed as u8);
    }
    bytes
}
/// Decodes the framing of an encoded message and returns an owned copy of
/// each object's payload bytes (the middle element of every object tuple),
/// in message order.
///
/// Panics if `buf` cannot be parsed by `tensogram::framing::decode_message`.
fn encoded_payloads(buf: &[u8]) -> Vec<Vec<u8>> {
    let message = tensogram::framing::decode_message(buf).unwrap();

    let mut payloads = Vec::new();
    for (_, payload, _) in message.objects.iter() {
        payloads.push(payload.to_vec());
    }
    payloads
}
/// Creates one test object: a 1-D descriptor with `n_elements` entries and a
/// matching deterministic payload derived from `seed`.
fn make_object(
    n_elements: usize,
    compression: &Compression,
    seed: u8,
) -> (DataObjectDescriptor, Vec<u8>) {
    // `make_descriptor` always declares `Dtype::Float32`, hence 4 bytes per
    // element in the payload.
    let shape = vec![n_elements as u64];
    (
        make_descriptor(shape, compression),
        make_payload(n_elements * 4, seed),
    )
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(32))]

    // Encoding and then decoding with the same `threads` setting must hand
    // back every object's payload unchanged, for every codec produced by
    // `compression_strategy()`.
    #[test]
    fn roundtrip_across_threads(
        n_objects in 1usize..5,
        n_elements in 16usize..4096,
        compression in compression_strategy(),
        threads in prop_oneof![Just(0u32), Just(1), Just(2), Just(4), Just(8)],
    ) {
        let meta = GlobalMetadata::default();

        // Each object gets its index as seed so the payloads differ.
        let mut objects: Vec<(DataObjectDescriptor, Vec<u8>)> = Vec::with_capacity(n_objects);
        for index in 0..n_objects {
            objects.push(make_object(n_elements, &compression, index as u8));
        }
        let mut pairs: Vec<(&DataObjectDescriptor, &[u8])> = Vec::with_capacity(objects.len());
        for (descriptor, payload) in &objects {
            pairs.push((descriptor, payload.as_slice()));
        }

        // NOTE(review): a threshold of 0 presumably forces the threaded path
        // even for small inputs; confirm against the tensogram option docs.
        let enc_opts = EncodeOptions {
            threads,
            parallel_threshold_bytes: Some(0),
            ..EncodeOptions::default()
        };
        let dec_opts = DecodeOptions {
            threads,
            parallel_threshold_bytes: Some(0),
            ..DecodeOptions::default()
        };

        let msg = encode(&meta, &pairs, &enc_opts).expect("encode");
        let (_meta, decoded) = decode(&msg, &dec_opts).expect("decode");

        prop_assert_eq!(decoded.len(), n_objects);
        for ((_desc, bytes), (_, expected)) in decoded.iter().zip(objects.iter()) {
            prop_assert_eq!(bytes, expected);
        }
    }

    // For codecs where `Compression::is_transparent()` is true, the encoded
    // payloads produced with default options must be byte-identical to those
    // produced with an explicit thread count.
    #[test]
    fn transparent_codec_byte_identical(
        n_objects in 1usize..5,
        n_elements in 16usize..4096,
        compression in compression_strategy(),
        threads in prop_oneof![Just(1u32), Just(2), Just(4), Just(8)],
    ) {
        // Cases drawn with a non-transparent codec are discarded.
        prop_assume!(compression.is_transparent());

        let meta = GlobalMetadata::default();

        let mut objects: Vec<(DataObjectDescriptor, Vec<u8>)> = Vec::with_capacity(n_objects);
        for index in 0..n_objects {
            objects.push(make_object(n_elements, &compression, index as u8));
        }
        let mut pairs: Vec<(&DataObjectDescriptor, &[u8])> = Vec::with_capacity(objects.len());
        for (descriptor, payload) in &objects {
            pairs.push((descriptor, payload.as_slice()));
        }

        let seq_opts = EncodeOptions::default();
        let par_opts = EncodeOptions {
            threads,
            parallel_threshold_bytes: Some(0),
            ..EncodeOptions::default()
        };

        let seq_msg = encode(&meta, &pairs, &seq_opts).expect("seq encode");
        let par_msg = encode(&meta, &pairs, &par_opts).expect("par encode");

        let seq_payloads = encoded_payloads(&seq_msg);
        let par_payloads = encoded_payloads(&par_msg);
        prop_assert_eq!(seq_payloads, par_payloads);
    }
}