use ciborium::Value as CborValue;
use tensogram_encodings::simple_packing;
use crate::types::DataObjectDescriptor;
/// Stage selection for [`apply_pipeline`]: which encoding, filter and
/// compression to record on a data object's descriptor.
///
/// Stage names are plain strings and are validated by [`apply_pipeline`],
/// not at construction time. The [`Default`] value selects `"none"` for all
/// three stages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DataPipeline {
    /// Encoding stage name. [`apply_pipeline`] accepts `"none"` and
    /// `"simple_packing"`.
    pub encoding: String,
    /// Requested bits per value for `"simple_packing"`. When `None`,
    /// [`apply_pipeline`] falls back to 16.
    pub bits: Option<u32>,
    /// Filter stage name. [`apply_pipeline`] accepts `"none"` and `"shuffle"`.
    pub filter: String,
    /// Compression stage name. [`apply_pipeline`] accepts `"none"`, `"zstd"`,
    /// `"lz4"`, `"blosc2"` and `"szip"`.
    pub compression: String,
    /// Compression level. Recorded as `zstd_level` (default 3) for `"zstd"`
    /// and as `blosc2_clevel` (default 5) for `"blosc2"`; not read for the
    /// other compressors.
    pub compression_level: Option<i32>,
}
impl Default for DataPipeline {
fn default() -> Self {
Self {
encoding: "none".to_string(),
bits: None,
filter: "none".to_string(),
compression: "none".to_string(),
compression_level: None,
}
}
}
pub fn apply_pipeline(
desc: &mut DataObjectDescriptor,
values: Option<&[f64]>,
pipeline: &DataPipeline,
var_label: &str,
) -> Result<(), String> {
let mut applied_simple_packing = false;
match pipeline.encoding.as_str() {
"none" => {}
"simple_packing" => match values {
None => {
eprintln!(
"warning: skipping simple_packing for {var_label} \
(not a float64 payload)"
);
}
Some(values) => {
let bits = pipeline.bits.unwrap_or(16);
match simple_packing::compute_params(values, bits, 0) {
Ok(params) => {
desc.encoding = "simple_packing".to_string();
desc.params.insert(
"reference_value".to_string(),
CborValue::Float(params.reference_value),
);
desc.params.insert(
"binary_scale_factor".to_string(),
CborValue::Integer((i64::from(params.binary_scale_factor)).into()),
);
desc.params.insert(
"decimal_scale_factor".to_string(),
CborValue::Integer((i64::from(params.decimal_scale_factor)).into()),
);
desc.params.insert(
"bits_per_value".to_string(),
CborValue::Integer((i64::from(params.bits_per_value)).into()),
);
applied_simple_packing = true;
}
Err(e) => {
eprintln!("warning: skipping simple_packing for {var_label}: {e}");
}
}
}
},
other => {
return Err(format!(
"unknown encoding '{other}'; expected 'none' or 'simple_packing'"
));
}
}
match pipeline.filter.as_str() {
"none" => {}
"shuffle" => {
desc.filter = "shuffle".to_string();
let element_size = if applied_simple_packing {
let bpv = pipeline.bits.unwrap_or(16) as usize;
bpv.div_ceil(8).max(1)
} else {
desc.dtype.byte_width()
};
desc.params.insert(
"shuffle_element_size".to_string(),
CborValue::Integer((element_size as i64).into()),
);
}
other => {
return Err(format!(
"unknown filter '{other}'; expected 'none' or 'shuffle'"
));
}
}
match pipeline.compression.as_str() {
"none" => {}
"zstd" => {
desc.compression = "zstd".to_string();
let level = pipeline.compression_level.unwrap_or(3);
desc.params.insert(
"zstd_level".to_string(),
CborValue::Integer((i64::from(level)).into()),
);
}
"lz4" => {
desc.compression = "lz4".to_string();
}
"blosc2" => {
desc.compression = "blosc2".to_string();
let clevel = pipeline.compression_level.unwrap_or(5);
desc.params.insert(
"blosc2_clevel".to_string(),
CborValue::Integer((i64::from(clevel)).into()),
);
desc.params.insert(
"blosc2_codec".to_string(),
CborValue::Text("lz4".to_string()),
);
}
"szip" => {
desc.compression = "szip".to_string();
desc.params
.insert("szip_rsi".to_string(), CborValue::Integer(128.into()));
desc.params
.insert("szip_block_size".to_string(), CborValue::Integer(16.into()));
desc.params
.insert("szip_flags".to_string(), CborValue::Integer(8.into()));
}
other => {
return Err(format!(
"unknown compression '{other}'; expected one of: none, zstd, lz4, blosc2, szip"
));
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::*;
    use crate::types::ByteOrder;
    use crate::Dtype;

    /// Builds a fresh 1-D, four-element float64 descriptor with every stage
    /// set to `"none"` and no parameters.
    fn mk_desc() -> DataObjectDescriptor {
        DataObjectDescriptor {
            obj_type: String::from("ntensor"),
            ndim: 1,
            shape: vec![4],
            strides: vec![1],
            dtype: Dtype::Float64,
            byte_order: ByteOrder::Little,
            encoding: String::from("none"),
            filter: String::from("none"),
            compression: String::from("none"),
            params: BTreeMap::new(),
            hash: None,
        }
    }

    /// Reads an integer parameter from the descriptor, panicking when the key
    /// is missing or holds a non-integer value.
    fn int_param(desc: &DataObjectDescriptor, key: &str) -> i64 {
        match desc.params.get(key) {
            Some(CborValue::Integer(i)) => i128::from(*i) as i64,
            other => panic!("{key} not an integer: {other:?}"),
        }
    }

    /// Applies `pipeline` to a fresh descriptor and returns the descriptor,
    /// panicking if the pipeline is rejected.
    fn applied(
        pipeline: &DataPipeline,
        values: Option<&[f64]>,
        label: &str,
    ) -> DataObjectDescriptor {
        let mut desc = mk_desc();
        apply_pipeline(&mut desc, values, pipeline, label).unwrap();
        desc
    }

    /// Applies `pipeline` (without values) to a fresh descriptor and returns
    /// the error message, panicking if the pipeline is accepted.
    fn rejected(pipeline: &DataPipeline) -> String {
        let mut desc = mk_desc();
        apply_pipeline(&mut desc, None, pipeline, "x").unwrap_err()
    }

    /// Pipeline that only selects a compressor, optionally with a level.
    fn compressed_with(name: &str, level: Option<i32>) -> DataPipeline {
        DataPipeline {
            compression: name.to_string(),
            compression_level: level,
            ..Default::default()
        }
    }

    // ---- defaults -----------------------------------------------------------

    #[test]
    fn default_pipeline_is_all_none() {
        let DataPipeline {
            encoding,
            bits,
            filter,
            compression,
            compression_level,
        } = DataPipeline::default();
        assert_eq!(encoding, "none");
        assert_eq!(filter, "none");
        assert_eq!(compression, "none");
        assert!(bits.is_none());
        assert!(compression_level.is_none());
    }

    #[test]
    fn default_pipeline_leaves_descriptor_unchanged() {
        let values = [1.0, 2.0, 3.0, 4.0];
        let desc = applied(&DataPipeline::default(), Some(&values), "x");
        assert_eq!(desc.encoding, "none");
        assert_eq!(desc.filter, "none");
        assert_eq!(desc.compression, "none");
        assert!(desc.params.is_empty());
    }

    // ---- encoding -----------------------------------------------------------

    #[test]
    fn simple_packing_populates_four_params() {
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            bits: Some(16),
            ..Default::default()
        };
        let values = [0.0_f64, 1.0, 2.0, 3.0];
        let desc = applied(&p, Some(&values), "test");
        assert_eq!(desc.encoding, "simple_packing");
        assert_eq!(int_param(&desc, "bits_per_value"), 16);
        assert_eq!(int_param(&desc, "decimal_scale_factor"), 0);
        assert!(desc.params.contains_key("reference_value"));
        assert!(desc.params.contains_key("binary_scale_factor"));
    }

    #[test]
    fn simple_packing_with_no_values_skips_with_warning() {
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            ..Default::default()
        };
        let desc = applied(&p, None, "int_var");
        assert_eq!(desc.encoding, "none", "should skip, not set");
        assert!(desc.params.is_empty(), "no params should be inserted");
    }

    #[test]
    fn simple_packing_with_nan_values_skips_with_warning() {
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            ..Default::default()
        };
        let values = [1.0_f64, f64::NAN, 3.0];
        let desc = applied(&p, Some(&values), "nan_var");
        assert_eq!(desc.encoding, "none", "NaN → skip");
    }

    #[test]
    fn unknown_encoding_errors() {
        let p = DataPipeline {
            encoding: "magic_packing".to_string(),
            ..Default::default()
        };
        let err = rejected(&p);
        assert!(err.contains("magic_packing"));
        assert!(err.contains("simple_packing"));
    }

    // ---- filter -------------------------------------------------------------

    #[test]
    fn shuffle_on_raw_f64_uses_native_byte_width() {
        let p = DataPipeline {
            filter: "shuffle".to_string(),
            ..Default::default()
        };
        let desc = applied(&p, None, "x");
        assert_eq!(desc.filter, "shuffle");
        assert_eq!(int_param(&desc, "shuffle_element_size"), 8);
    }

    #[test]
    fn shuffle_on_simple_packed_uses_post_pack_byte_width() {
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            bits: Some(16),
            filter: "shuffle".to_string(),
            ..Default::default()
        };
        let values = [0.0_f64, 1.0, 2.0, 3.0];
        let desc = applied(&p, Some(&values), "x");
        assert_eq!(desc.filter, "shuffle");
        assert_eq!(
            int_param(&desc, "shuffle_element_size"),
            2,
            "16-bit packed → 2-byte elements"
        );
    }

    #[test]
    fn shuffle_with_24bit_packing_rounds_up() {
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            bits: Some(24),
            filter: "shuffle".to_string(),
            ..Default::default()
        };
        let values = [0.0_f64, 1.0, 2.0, 3.0];
        let desc = applied(&p, Some(&values), "x");
        assert_eq!(int_param(&desc, "shuffle_element_size"), 3);
    }

    #[test]
    fn unknown_filter_errors() {
        let p = DataPipeline {
            filter: "wibble".to_string(),
            ..Default::default()
        };
        assert!(rejected(&p).contains("wibble"));
    }

    // ---- compression --------------------------------------------------------

    #[test]
    fn zstd_with_default_level() {
        let desc = applied(&compressed_with("zstd", None), None, "x");
        assert_eq!(desc.compression, "zstd");
        assert_eq!(int_param(&desc, "zstd_level"), 3);
    }

    #[test]
    fn zstd_with_custom_level() {
        let desc = applied(&compressed_with("zstd", Some(9)), None, "x");
        assert_eq!(int_param(&desc, "zstd_level"), 9);
    }

    #[test]
    fn lz4_has_no_params() {
        let desc = applied(&compressed_with("lz4", None), None, "x");
        assert_eq!(desc.compression, "lz4");
        assert!(desc.params.is_empty());
    }

    #[test]
    fn blosc2_with_custom_level() {
        let desc = applied(&compressed_with("blosc2", Some(7)), None, "x");
        assert_eq!(desc.compression, "blosc2");
        assert_eq!(int_param(&desc, "blosc2_clevel"), 7);
        match desc.params.get("blosc2_codec") {
            Some(CborValue::Text(s)) => assert_eq!(s, "lz4"),
            other => panic!("blosc2_codec should be lz4: {other:?}"),
        }
    }

    #[test]
    fn szip_sets_defaults() {
        let desc = applied(&compressed_with("szip", None), None, "x");
        assert_eq!(desc.compression, "szip");
        assert_eq!(int_param(&desc, "szip_rsi"), 128);
        assert_eq!(int_param(&desc, "szip_block_size"), 16);
        assert_eq!(int_param(&desc, "szip_flags"), 8);
    }

    #[test]
    fn unknown_compression_errors() {
        assert!(rejected(&compressed_with("bogus", None)).contains("bogus"));
    }

    // ---- all stages together ------------------------------------------------

    #[test]
    fn full_pipeline_simple_packing_shuffle_zstd() {
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            bits: Some(24),
            filter: "shuffle".to_string(),
            compression: "zstd".to_string(),
            compression_level: Some(5),
        };
        let values = [1.0_f64, 2.0, 3.0, 4.0];
        let desc = applied(&p, Some(&values), "x");
        assert_eq!(desc.encoding, "simple_packing");
        assert_eq!(desc.filter, "shuffle");
        assert_eq!(desc.compression, "zstd");
        assert_eq!(int_param(&desc, "bits_per_value"), 24);
        assert_eq!(int_param(&desc, "shuffle_element_size"), 3);
        assert_eq!(int_param(&desc, "zstd_level"), 5);
    }
}