/// Error bound that [`compress`] asks tthresh to target.
///
/// The variant selects which `tthresh_sys::Target_*` metric is passed to the
/// library and the payload is the value of the bound. [`compress`] rejects
/// negative values with [`Error::NegativeErrorBound`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ErrorBound {
/// Bound forwarded with `tthresh_sys::Target_eps`.
///
/// NOTE(review): presumably the relative error — confirm against the
/// tthresh documentation.
Eps(f64),
/// Bound forwarded with `tthresh_sys::Target_rmse` (going by the name, the
/// root mean square error).
RMSE(f64),
/// Bound forwarded with `tthresh_sys::Target_psnr` (going by the name, the
/// peak signal-to-noise ratio).
PSNR(f64),
}
/// Decompressed data returned by [`decompress`], tagged with its element type.
///
/// There is one variant per element type that implements [`Element`]; which
/// variant is produced depends on the I/O type that tthresh reports for the
/// compressed stream.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
#[allow(missing_docs)]
pub enum Buffer {
// Produced for `tthresh_sys::IOType_uchar_`.
U8(Vec<u8>),
// Produced for `tthresh_sys::IOType_ushort_`.
U16(Vec<u16>),
// Produced for `tthresh_sys::IOType_int_`.
I32(Vec<i32>),
// Produced for `tthresh_sys::IOType_float_`.
F32(Vec<f32>),
// Produced for `tthresh_sys::IOType_double_`.
F64(Vec<f64>),
}
/// Compresses the array `data`, interpreted with the given `shape`, using
/// tthresh.
///
/// `target` selects the error metric and the bound that tthresh is asked to
/// meet. `verbose` and `debug` are forwarded unchanged to the tthresh library.
///
/// On success, the compressed byte stream produced by tthresh is returned.
///
/// # Errors
///
/// - [`Error::InsufficientDimensionality`] if `shape` has fewer than three
///   dimensions.
/// - [`Error::InvalidShape`] if the product of `shape` differs from
///   `data.len()`, including when that product does not fit into a `usize`.
/// - [`Error::ExcessiveSize`] if any dimension does not fit into a `u32`.
/// - [`Error::NegativeErrorBound`] if the bound carried by `target` is
///   negative or NaN.
pub fn compress<T: Element>(
    data: &[T],
    shape: &[usize],
    target: ErrorBound,
    verbose: bool,
    debug: bool,
) -> Result<Vec<u8>, Error> {
    if shape.len() < 3 {
        return Err(Error::InsufficientDimensionality);
    }

    // Compute the element count with overflow checks. An unchecked product
    // panics in debug builds and wraps in release builds, where a wrapped
    // value could equal `data.len()` and let a shape that is far larger than
    // the buffer reach the C library. A zero dimension makes the product zero
    // regardless of how large the other dimensions are, so handle it first.
    let element_count = if shape.contains(&0) {
        Some(0)
    } else {
        shape
            .iter()
            .try_fold(1_usize, |count, &dim| count.checked_mul(dim))
    };
    if element_count != Some(data.len()) {
        return Err(Error::InvalidShape);
    }

    // tthresh takes the shape as an array of `u32` dimensions.
    let shape = shape
        .iter()
        .copied()
        .map(u32::try_from)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|_| Error::ExcessiveSize)?;

    let target_value = match target {
        ErrorBound::Eps(v) | ErrorBound::RMSE(v) | ErrorBound::PSNR(v) => v,
    };
    // NaN compares false with everything, so it has to be rejected explicitly
    // instead of slipping past the `< 0.0` check into the C library.
    if target_value.is_nan() || target_value < 0.0 {
        return Err(Error::NegativeErrorBound);
    }

    // Out-parameters that tthresh fills in with the compressed allocation.
    let mut output = std::ptr::null_mut();
    let mut output_size = 0;

    // SAFETY: `data` and `shape` are live slices passed together with their
    // element type / length, the shape's product was checked above to equal
    // `data.len()`, and both out-pointers refer to locals that outlive the
    // call. NOTE(review): that tthresh reads no more than the described
    // number of elements is assumed from its interface — confirm in
    // tthresh-sys.
    #[allow(unsafe_code)]
    unsafe {
        tthresh_sys::compress_buffer(
            data.as_ptr().cast::<std::ffi::c_char>(),
            T::IO_TYPE,
            shape.as_ptr(),
            shape.len(),
            std::ptr::from_mut(&mut output),
            std::ptr::from_mut(&mut output_size),
            match target {
                ErrorBound::Eps(_) => tthresh_sys::Target_eps,
                ErrorBound::RMSE(_) => tthresh_sys::Target_rmse,
                ErrorBound::PSNR(_) => tthresh_sys::Target_psnr,
            },
            target_value,
            Some(alloc),
            verbose,
            debug,
        );
    }

    // SAFETY: relies on tthresh having obtained `output` from the `alloc`
    // callback passed above as a byte allocation of exactly `output_size`
    // bytes, so that the `Vec` may take ownership of it.
    // NOTE(review): this contract is not visible here — confirm in
    // tthresh-sys.
    #[allow(unsafe_code)]
    let compressed = unsafe { Vec::from_raw_parts(output, output_size, output_size) };

    Ok(compressed)
}
pub fn decompress(
compressed: &[u8],
verbose: bool,
debug: bool,
) -> Result<(Buffer, Vec<usize>), Error> {
let mut shape = std::ptr::null_mut();
let mut shape_size = 0;
let mut output = std::ptr::null_mut();
let mut output_type = 0;
let mut output_length = 0;
#[allow(unsafe_code)] let ok = unsafe {
tthresh_sys::decompress_buffer(
compressed.as_ptr(),
compressed.len(),
std::ptr::from_mut(&mut output),
std::ptr::from_mut(&mut output_type),
std::ptr::from_mut(&mut output_length),
std::ptr::from_mut(&mut shape),
std::ptr::from_mut(&mut shape_size),
Some(alloc),
verbose,
debug,
)
};
if !ok {
return Err(Error::CorruptedCompressedBytes);
}
#[allow(unsafe_code)]
let shape = unsafe { Vec::from_raw_parts(shape, shape_size, shape_size) };
#[allow(unsafe_code)]
let decompressed = match output_type {
tthresh_sys::IOType_uchar_ => {
Buffer::U8(unsafe { Vec::from_raw_parts(output.cast(), output_length, output_length) })
}
tthresh_sys::IOType_ushort_ => {
Buffer::U16(unsafe { Vec::from_raw_parts(output.cast(), output_length, output_length) })
}
tthresh_sys::IOType_int_ => {
Buffer::I32(unsafe { Vec::from_raw_parts(output.cast(), output_length, output_length) })
}
tthresh_sys::IOType_float_ => {
Buffer::F32(unsafe { Vec::from_raw_parts(output.cast(), output_length, output_length) })
}
tthresh_sys::IOType_double_ => {
Buffer::F64(unsafe { Vec::from_raw_parts(output.cast(), output_length, output_length) })
}
#[allow(clippy::unreachable)]
_ => unreachable!("tthresh decompression returned an unknown output type"),
};
let shape = shape
.into_iter()
.map(usize::try_from)
.collect::<Result<Vec<_>, _>>()
.map_err(|_| Error::ExcessiveSize)?;
Ok((decompressed, shape))
}
/// Errors that [`compress`] and [`decompress`] can return.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// The shape passed to [`compress`] has fewer than three dimensions.
#[error("data must be at least three-dimensional")]
InsufficientDimensionality,
/// The product of the shape passed to [`compress`] does not equal the
/// number of elements in the data buffer.
#[error("shape does not match the provided buffer")]
InvalidShape,
/// A dimension could not be converted between `usize` and the integer type
/// that tthresh uses for shapes.
#[error("data shape sizes must fit within [0; 2^32 - 1]")]
ExcessiveSize,
/// The error bound passed to [`compress`] is negative.
#[error("error bound must be non-negative")]
NegativeErrorBound,
/// tthresh reported a failure while decompressing in [`decompress`].
#[error("compressed bytes have been corrupted")]
CorruptedCompressedBytes,
}
/// Marker trait for the element types that [`compress`] accepts.
///
/// It requires the `sealed::Element` supertrait, which lives in a private
/// module, so the trait cannot be implemented outside of this crate.
pub trait Element: sealed::Element {}
// Private module holding the supertrait of the public `Element` trait, so that
// code outside of this crate cannot name it and hence cannot implement
// `Element` for further types.
mod sealed {
// `Copy` element type together with the tthresh I/O type tag that describes it.
pub trait Element: Copy {
// Tag passed to tthresh by `compress` to identify the element type of the
// input data.
const IO_TYPE: tthresh_sys::IOType;
}
}
impl Element for u8 {}
impl sealed::Element for u8 {
const IO_TYPE: tthresh_sys::IOType = tthresh_sys::IOType_uchar_;
}
impl Element for u16 {}
impl sealed::Element for u16 {
const IO_TYPE: tthresh_sys::IOType = tthresh_sys::IOType_ushort_;
}
impl Element for i32 {}
impl sealed::Element for i32 {
const IO_TYPE: tthresh_sys::IOType = tthresh_sys::IOType_int_;
}
impl Element for f32 {}
impl sealed::Element for f32 {
const IO_TYPE: tthresh_sys::IOType = tthresh_sys::IOType_float_;
}
impl Element for f64 {}
impl sealed::Element for f64 {
const IO_TYPE: tthresh_sys::IOType = tthresh_sys::IOType_double_;
}
/// Allocation callback handed to tthresh by `compress` and `decompress`.
///
/// Returns `size` zero-initialised bytes aligned to `align`, obtained from
/// Rust's global allocator. Requests of size zero are not allocated; they
/// yield a dangling pointer whose address is `align`. The result is null if
/// the global allocator fails.
///
/// Panics if `size` and `align` do not form a valid [`std::alloc::Layout`];
/// this is checked before the zero-size case.
extern "C" fn alloc(size: usize, align: usize) -> *mut std::ffi::c_void {
    #[allow(clippy::unwrap_used)]
    let request = std::alloc::Layout::from_size_align(size, align).unwrap();

    let block = if request.size() == 0 {
        // Nothing to allocate: build a non-null, suitably aligned pointer
        // from the alignment itself without touching the allocator.
        std::ptr::null_mut::<u8>().wrapping_add(align)
    } else {
        // SAFETY: `request` has a non-zero size, as `alloc_zeroed` requires.
        #[allow(unsafe_code)]
        let allocated = unsafe { std::alloc::alloc_zeroed(request) };
        allocated
    };

    block.cast()
}
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    /// Compresses the bundled 64x64x64 `u8` sphere volume under `target` and
    /// checks that decompression yields a `u8` buffer of the same shape.
    fn roundtrip_sphere(target: ErrorBound) {
        let volume = std::fs::read("tthresh-sys/tthresh/data/3D_sphere_64_uchar.raw")
            .expect("input file should not be missing");

        let bytes = compress(&volume, &[64, 64, 64], target, true, true)
            .expect("compression should not fail");
        let (decompressed, shape) =
            decompress(&bytes, true, true).expect("decompression should not fail");

        assert!(matches!(decompressed, Buffer::U8(_)));
        assert_eq!(shape, &[64, 64, 64]);
    }

    /// Compresses the single-element array `[value]` of shape 1x1x1 with an
    /// RMSE bound of zero and checks that decompression yields `expected`.
    fn roundtrip_single_element<T: Element>(value: T, expected: &Buffer) {
        let bytes = compress(&[value], &[1, 1, 1], ErrorBound::RMSE(0.0), true, true)
            .expect("compression should not fail");
        let (decompressed, shape) =
            decompress(&bytes, true, true).expect("decompression should not fail");

        assert_eq!(&decompressed, expected);
        assert_eq!(shape, &[1, 1, 1]);
    }

    #[test]
    fn compress_decompress_eps() {
        roundtrip_sphere(ErrorBound::Eps(0.5));
    }

    #[test]
    fn compress_decompress_rmse() {
        roundtrip_sphere(ErrorBound::RMSE(0.1));
    }

    #[test]
    fn compress_decompress_psnr() {
        roundtrip_sphere(ErrorBound::PSNR(30.0));
    }

    #[test]
    fn compress_decompress_u8() {
        roundtrip_single_element(42_u8, &Buffer::U8(vec![42]));
    }

    #[test]
    fn compress_decompress_u16() {
        roundtrip_single_element(42_u16, &Buffer::U16(vec![42]));
    }

    #[test]
    fn compress_decompress_i32() {
        roundtrip_single_element(42_i32, &Buffer::I32(vec![42]));
    }

    #[test]
    fn compress_decompress_f32() {
        roundtrip_single_element(42.0_f32, &Buffer::F32(vec![42.0]));
    }

    #[test]
    fn compress_decompress_f64() {
        roundtrip_single_element(42.0_f64, &Buffer::F64(vec![42.0]));
    }

    #[test]
    fn decompress_empty_garbage() {
        let outcome = decompress(&[0], true, true);
        assert!(matches!(outcome, Err(Error::CorruptedCompressedBytes)));
    }

    #[test]
    fn decompress_full_garbage() {
        let garbage = vec![1; 1024];
        let outcome = decompress(&garbage, true, true);
        assert!(matches!(outcome, Err(Error::CorruptedCompressedBytes)));
    }
}