use crate::compression::CompressionError;
use crate::pipeline::ZfpMode;
/// Builds the zfp flavour of [`CompressionError`] from any string-like message.
fn err(msg: impl Into<String>) -> CompressionError {
    let text: String = msg.into();
    CompressionError::Zfp(text)
}
/// Compresses `values` as a one-dimensional zfp field of doubles.
///
/// An empty slice yields an empty buffer without calling into zfp. Otherwise
/// the stream is configured from `mode`, the samples are compressed into a
/// buffer sized by `zfp_stream_maximum_size`, and that buffer is truncated to
/// the byte count reported by `zfp_compress`.
///
/// # Errors
///
/// Returns [`CompressionError::Zfp`] when zfp cannot allocate its field,
/// stream or bit stream, when `set_mode` rejects `mode`, when zfp reports a
/// maximum size of 0, or when `zfp_compress` returns 0.
pub fn zfp_compress_f64(values: &[f64], mode: &ZfpMode) -> Result<Vec<u8>, CompressionError> {
    let num_values = values.len();
    if num_values == 0 {
        return Ok(Vec::new());
    }
    // SAFETY: `field` points at `values`, which stays borrowed for the whole
    // call; the `*mut` cast only satisfies the C signature (compression is
    // expected to read, not write, the samples). The output buffer outlives
    // the bit stream that wraps it. Every zfp handle is null-checked before
    // use and released exactly once on every path.
    unsafe {
        let ztype = zfp_sys_cc::zfp_type_zfp_type_double;
        let field =
            zfp_sys_cc::zfp_field_1d(values.as_ptr() as *mut std::ffi::c_void, ztype, num_values);
        if field.is_null() {
            return Err(err("zfp_field_1d failed"));
        }
        let zfp = zfp_sys_cc::zfp_stream_open(std::ptr::null_mut());
        if zfp.is_null() {
            zfp_sys_cc::zfp_field_free(field);
            return Err(err("zfp_stream_open failed"));
        }

        // All fallible steps run inside this closure so that `zfp` and `field`
        // are released below no matter which step fails (previously a
        // `set_mode` error returned early and leaked both handles).
        let outcome = (|| -> Result<Vec<u8>, CompressionError> {
            set_mode(zfp, mode, ztype)?;
            let bufsize = zfp_sys_cc::zfp_stream_maximum_size(zfp, field);
            if bufsize == 0 {
                return Err(err("zfp_stream_maximum_size returned 0"));
            }
            let mut buffer = vec![0u8; bufsize as usize];
            let stream =
                zfp_sys_cc::stream_open(buffer.as_mut_ptr() as *mut std::ffi::c_void, bufsize);
            if stream.is_null() {
                return Err(err("stream_open failed"));
            }
            zfp_sys_cc::zfp_stream_set_bit_stream(zfp, stream);
            zfp_sys_cc::zfp_stream_rewind(zfp);
            let compressed_size = zfp_sys_cc::zfp_compress(zfp, field);
            // The bit stream is only a view over `buffer`; release it before
            // inspecting the result so both outcomes share one cleanup.
            zfp_sys_cc::stream_close(stream);
            if compressed_size == 0 {
                return Err(err("zfp_compress returned 0"));
            }
            buffer.truncate(compressed_size as usize);
            Ok(buffer)
        })();

        zfp_sys_cc::zfp_stream_close(zfp);
        zfp_sys_cc::zfp_field_free(field);
        outcome
    }
}
/// Decompresses a zfp payload into exactly `num_values` doubles.
///
/// `mode` is applied to the stream through `set_mode` before decoding; the
/// payload is read as a raw bit stream with no header parsing in this
/// function, so `num_values` and `mode` are taken on trust from the caller.
///
/// `num_values == 0` together with an empty payload returns an empty vector
/// without calling into zfp.
///
/// # Errors
///
/// Returns [`CompressionError::Zfp`] when:
/// - `num_values` is 0 but `compressed` is not empty, or vice versa;
/// - the output buffer for `num_values` doubles cannot be reserved;
/// - zfp cannot allocate its field, stream or bit stream;
/// - `set_mode` rejects `mode`;
/// - `zfp_decompress` returns 0.
pub fn zfp_decompress_f64(
    compressed: &[u8],
    num_values: usize,
    mode: &ZfpMode,
) -> Result<Vec<f64>, CompressionError> {
    match (num_values, compressed.is_empty()) {
        (0, true) => return Ok(Vec::new()),
        (0, false) => {
            return Err(err(
                "num_values=0 with non-empty compressed stream (malformed zfp descriptor)",
            ));
        }
        (_, true) => {
            return Err(err(format!(
                "num_values={num_values} with empty compressed stream (truncated or malformed payload)"
            )));
        }
        (_, false) => {}
    }

    // Reserve fallibly so an absurd `num_values` becomes an error, not an abort.
    let mut output: Vec<f64> = Vec::new();
    output.try_reserve_exact(num_values).map_err(|e| {
        err(format!(
            "failed to reserve {} bytes for zfp decompression: {e}",
            num_values.saturating_mul(std::mem::size_of::<f64>()),
        ))
    })?;
    output.resize(num_values, 0.0);

    // SAFETY: `field` points at `output`, which holds `num_values` initialised
    // doubles and is not touched from Rust until zfp is done with it. The bit
    // stream wraps `compressed`, which stays borrowed for the whole call; the
    // `*mut` cast only satisfies the C signature. Every zfp handle is
    // null-checked before use and released exactly once on every path.
    // NOTE(review): nothing here bounds zfp's reads to `compressed.len()`; a
    // truncated payload presumably relies on zfp itself to stay in bounds —
    // confirm against the zfp bit stream implementation.
    unsafe {
        let ztype = zfp_sys_cc::zfp_type_zfp_type_double;
        let field = zfp_sys_cc::zfp_field_1d(
            output.as_mut_ptr() as *mut std::ffi::c_void,
            ztype,
            num_values,
        );
        if field.is_null() {
            return Err(err("zfp_field_1d failed"));
        }
        let zfp = zfp_sys_cc::zfp_stream_open(std::ptr::null_mut());
        if zfp.is_null() {
            zfp_sys_cc::zfp_field_free(field);
            return Err(err("zfp_stream_open failed"));
        }

        // All fallible steps run inside this closure so that `zfp` and `field`
        // are released below no matter which step fails (previously a
        // `set_mode` error returned early and leaked both handles).
        let outcome = (|| -> Result<(), CompressionError> {
            set_mode(zfp, mode, ztype)?;
            let stream = zfp_sys_cc::stream_open(
                compressed.as_ptr() as *mut std::ffi::c_void,
                compressed.len(),
            );
            if stream.is_null() {
                return Err(err("stream_open failed"));
            }
            zfp_sys_cc::zfp_stream_set_bit_stream(zfp, stream);
            zfp_sys_cc::zfp_stream_rewind(zfp);
            let ret = zfp_sys_cc::zfp_decompress(zfp, field);
            zfp_sys_cc::stream_close(stream);
            if ret == 0 {
                return Err(err("zfp_decompress returned 0"));
            }
            Ok(())
        })();

        zfp_sys_cc::zfp_stream_close(zfp);
        zfp_sys_cc::zfp_field_free(field);
        outcome?;
    }
    Ok(output)
}
/// Decompresses the payload and returns `sample_count` values starting at
/// `sample_offset`.
///
/// The whole payload is decoded via [`zfp_decompress_f64`] and the requested
/// window is then copied out. The range is validated against `total_values`
/// *before* decoding, so an invalid request fails without paying for the
/// full-size allocation and decompression.
///
/// # Errors
///
/// Returns [`CompressionError::Zfp`] when `sample_offset + sample_count`
/// overflows or exceeds `total_values`, when decompression fails, or when the
/// output buffer cannot be reserved.
pub fn zfp_decompress_range_f64(
    compressed: &[u8],
    total_values: usize,
    mode: &ZfpMode,
    sample_offset: usize,
    sample_count: usize,
) -> Result<Vec<f64>, CompressionError> {
    let end = sample_offset.checked_add(sample_count).ok_or_else(|| {
        err(format!(
            "range end overflow: sample_offset {sample_offset} + sample_count {sample_count}"
        ))
    })?;
    let out_of_range = || {
        err(format!(
            "range ({sample_offset}, {sample_count}) exceeds total values {total_values}"
        ))
    };
    if end > total_values {
        return Err(out_of_range());
    }

    let all = zfp_decompress_f64(compressed, total_values, mode)?;
    // A successful decode yields `total_values` samples, so this lookup is
    // expected to succeed; `get` keeps a violated assumption an error rather
    // than a panic.
    let window = all.get(sample_offset..end).ok_or_else(out_of_range)?;

    let mut out: Vec<f64> = Vec::new();
    out.try_reserve_exact(sample_count).map_err(|e| {
        err(format!(
            "failed to reserve {} bytes for zfp range output: {e}",
            sample_count.saturating_mul(std::mem::size_of::<f64>()),
        ))
    })?;
    out.extend_from_slice(window);
    Ok(out)
}
/// Applies the compression parameters described by `mode` to the stream `zfp`.
///
/// The values returned by `zfp_stream_set_rate` and `zfp_stream_set_precision`
/// are discarded; only the accuracy setter's return value is inspected.
///
/// # Errors
///
/// Returns [`CompressionError::Zfp`] when `zfp_stream_set_accuracy` returns 0.
///
/// # Safety
///
/// `zfp` is handed straight to the zfp C API, so it must be a pointer that
/// API accepts (the callers in this file pass a non-null result of
/// `zfp_stream_open`).
unsafe fn set_mode(
    zfp: *mut zfp_sys_cc::zfp_stream,
    mode: &ZfpMode,
    ztype: zfp_sys_cc::zfp_type,
) -> Result<(), CompressionError> {
    match mode {
        ZfpMode::FixedRate { rate } => {
            // The trailing `1, 0` fill zfp's dimensionality and alignment
            // arguments (see the zfp C API reference).
            // SAFETY: the caller guarantees `zfp` is a valid stream handle.
            unsafe { zfp_sys_cc::zfp_stream_set_rate(zfp, *rate, ztype, 1, 0) };
        }
        ZfpMode::FixedPrecision { precision } => {
            // SAFETY: the caller guarantees `zfp` is a valid stream handle.
            unsafe { zfp_sys_cc::zfp_stream_set_precision(zfp, *precision) };
        }
        ZfpMode::FixedAccuracy { tolerance } => {
            // SAFETY: the caller guarantees `zfp` is a valid stream handle.
            let applied = unsafe { zfp_sys_cc::zfp_stream_set_accuracy(zfp, *tolerance) };
            if applied == 0.0 {
                return Err(err("zfp_stream_set_accuracy returned 0"));
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Samples `sin(x)` at `n` points, with `x = i / n * π` for `i` in `0..n`.
    fn smooth_data(n: usize) -> Vec<f64> {
        let mut samples = Vec::with_capacity(n);
        for i in 0..n {
            let phase = i as f64 / n as f64 * std::f64::consts::PI;
            samples.push(phase.sin());
        }
        samples
    }

    /// Fixed-rate output is smaller than the raw doubles and every sample
    /// round-trips to within 0.1.
    #[test]
    fn zfp_round_trip_fixed_rate() {
        let mode = ZfpMode::FixedRate { rate: 16.0 };
        let values = smooth_data(1024);
        let packed = zfp_compress_f64(&values, &mode).unwrap();
        assert!(packed.len() < values.len() * 8);
        let decompressed = zfp_decompress_f64(&packed, values.len(), &mode).unwrap();
        assert_eq!(decompressed.len(), values.len());
        values.iter().zip(&decompressed).for_each(|(orig, dec)| {
            assert!(
                (orig - dec).abs() < 0.1,
                "orig={orig}, dec={dec}, diff={}",
                (orig - dec).abs()
            );
        });
    }

    /// Fixed-precision (32) round-trips every sample to within 0.001.
    #[test]
    fn zfp_round_trip_fixed_precision() {
        let mode = ZfpMode::FixedPrecision { precision: 32 };
        let values = smooth_data(256);
        let packed = zfp_compress_f64(&values, &mode).unwrap();
        let decompressed = zfp_decompress_f64(&packed, values.len(), &mode).unwrap();
        values.iter().zip(&decompressed).for_each(|(orig, dec)| {
            assert!((orig - dec).abs() < 0.001, "orig={orig}, dec={dec}");
        });
    }

    /// Fixed-accuracy keeps every sample within the requested tolerance.
    #[test]
    fn zfp_round_trip_fixed_accuracy() {
        let tol = 1e-6;
        let mode = ZfpMode::FixedAccuracy { tolerance: tol };
        let values = smooth_data(256);
        let packed = zfp_compress_f64(&values, &mode).unwrap();
        let decompressed = zfp_decompress_f64(&packed, values.len(), &mode).unwrap();
        values.iter().zip(&decompressed).for_each(|(orig, dec)| {
            assert!(
                (orig - dec).abs() <= tol,
                "orig={orig}, dec={dec}, diff={}, tol={tol}",
                (orig - dec).abs()
            );
        });
    }

    /// A range decode matches the same window of a full decode.
    #[test]
    fn zfp_range_decode() {
        let mode = ZfpMode::FixedRate { rate: 16.0 };
        let values = smooth_data(512);
        let packed = zfp_compress_f64(&values, &mode).unwrap();
        let full = zfp_decompress_f64(&packed, values.len(), &mode).unwrap();
        let partial = zfp_decompress_range_f64(&packed, values.len(), &mode, 100, 200).unwrap();
        assert_eq!(partial.len(), 200);
        assert_eq!(partial.as_slice(), &full[100..300]);
    }

    /// Compressing no samples produces no bytes.
    #[test]
    fn zfp_compress_empty() {
        let mode = ZfpMode::FixedRate { rate: 16.0 };
        let packed = zfp_compress_f64(&[], &mode).unwrap();
        assert!(packed.is_empty());
    }

    /// Decompressing an empty payload for zero samples produces no values.
    #[test]
    fn zfp_decompress_empty() {
        let mode = ZfpMode::FixedRate { rate: 16.0 };
        let decoded = zfp_decompress_f64(&[], 0, &mode).unwrap();
        assert!(decoded.is_empty());
    }

    /// A window running past the end of the data is rejected.
    #[test]
    fn zfp_range_exceeds_total() {
        let mode = ZfpMode::FixedRate { rate: 16.0 };
        let values = smooth_data(128);
        let packed = zfp_compress_f64(&values, &mode).unwrap();
        let outcome = zfp_decompress_range_f64(&packed, values.len(), &mode, 100, 100);
        assert!(outcome.is_err());
    }

    /// Accuracy mode with a loose tolerance preserves length and stays within it.
    #[test]
    fn zfp_accuracy_mode_roundtrip() {
        let mode = ZfpMode::FixedAccuracy { tolerance: 0.01 };
        let values = smooth_data(256);
        let packed = zfp_compress_f64(&values, &mode).unwrap();
        let decoded = zfp_decompress_f64(&packed, values.len(), &mode).unwrap();
        assert_eq!(decoded.len(), values.len());
        values.iter().zip(&decoded).for_each(|(orig, dec)| {
            assert!((orig - dec).abs() <= 0.01);
        });
    }

    /// Precision mode returns as many samples as were compressed.
    #[test]
    fn zfp_precision_mode_roundtrip() {
        let mode = ZfpMode::FixedPrecision { precision: 32 };
        let values = smooth_data(256);
        let packed = zfp_compress_f64(&values, &mode).unwrap();
        let decoded = zfp_decompress_f64(&packed, values.len(), &mode).unwrap();
        assert_eq!(decoded.len(), values.len());
    }

    /// An impossible `num_values` surfaces as a reservation error.
    #[test]
    fn zfp_decompress_rejects_pathological_num_values() {
        let mode = ZfpMode::FixedRate { rate: 16.0 };
        let failure = zfp_decompress_f64(&[1u8, 2, 3, 4], usize::MAX, &mode)
            .expect_err("usize::MAX num_values must fail the capacity check");
        let msg = failure.to_string();
        assert!(
            msg.contains("failed to reserve"),
            "error should report allocation failure, got: {msg}"
        );
    }
}