#![allow(clippy::multiple_crate_versions)]
use std::{borrow::Cow, fmt};
use ndarray::{Array, Array1, ArrayView, Dimension, Zip};
use numcodecs::{
AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[cfg(test)]
use ::serde_json as _;
mod ffi;
/// Version of the encoding format produced by this codec.
///
/// It is the serde default of the `_version` field of [`ZfpCodec`] and is also
/// stored in the [`CompressionHeader`] that [`compress`] writes in front of
/// every encoded array.
// NOTE(review): the three const parameters presumably read as
// major.minor.patch (i.e. 0.2.0) — confirm against `StaticCodecVersion`.
type ZfpCodecVersion = StaticCodecVersion<0, 2, 0>;
// Codec configuration for compressing arrays with ZFP.
//
// The struct is its own (de)serializable configuration: see the `StaticCodec`
// impl, where `Config` is `Self`.
//
// NOTE(review): plain `//` comments are used on this type on purpose. It
// derives `JsonSchema`, and schemars turns `///` doc comments into schema
// descriptions, so doc comments here would change the generated schema.
// Convert them to `///` if descriptions in the schema are wanted.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
// Schema-only attribute: the generated JSON schema rejects unknown fields.
#[schemars(deny_unknown_fields)]
pub struct ZfpCodec {
// Compression mode. Flattened, so the `mode` tag and the parameters of the
// selected mode sit at the top level of the configuration.
#[serde(flatten)]
pub mode: ZfpCompressionMode,
// How non-finite floating point values are treated by `compress`.
// Optional in the configuration; defaults to `ZfpNonFiniteValuesMode::Deny`.
#[serde(default)]
pub non_finite: ZfpNonFiniteValuesMode,
// Encoding format version, (de)serialized under the key `_version`.
// Optional in the configuration; defaults to the current `ZfpCodecVersion`.
#[serde(default, rename = "_version")]
pub version: ZfpCodecVersion,
}
// ZFP compression mode, selected by the `mode` key of the configuration.
//
// The enum is internally tagged (`#[serde(tag = "mode")]`): the variant is
// chosen by the value of `mode` and the variant's fields are sibling keys.
//
// NOTE(review): plain `//` comments are used on purpose. This type derives
// `JsonSchema`, and schemars turns `///` doc comments into schema
// descriptions, so doc comments here would change the generated schema.
//
// NOTE(review): serde documents `deny_unknown_fields` as unsupported in
// combination with `flatten`, and this enum is flattened into `ZfpCodec` —
// confirm that rejecting unknown keys is covered by a test.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "mode")]
#[serde(deny_unknown_fields)]
pub enum ZfpCompressionMode {
// `"mode": "expert"`. The four parameters are passed to the zfp stream setup
// in the `ffi` module; `ZfpCodecError::InvalidExpertMode` exists for rejected
// parameter combinations.
// NOTE(review): presumably these are zfp's expert-mode parameters of the same
// names (bit budget per block, bit-plane precision, minimum exponent) — see
// the zfp documentation for their exact meaning.
#[serde(rename = "expert")]
Expert {
min_bits: u32,
max_bits: u32,
max_prec: u32,
min_exp: i32,
},
// `"mode": "fixed-rate"`.
// NOTE(review): `rate` is presumably the number of compressed bits per value
// — confirm against the zfp documentation.
#[serde(rename = "fixed-rate")]
FixedRate {
rate: f64,
},
// `"mode": "fixed-precision"`.
// NOTE(review): `precision` is presumably the number of bit planes kept per
// block — confirm against the zfp documentation.
#[serde(rename = "fixed-precision")]
FixedPrecision {
precision: u32,
},
// `"mode": "fixed-accuracy"`. `ZfpCodec::encode` rejects this mode for `i32`
// and `i64` data with `ZfpCodecError::FixedAccuracyModeIntegerData`.
// NOTE(review): `tolerance` is presumably the absolute error bound — confirm.
#[serde(rename = "fixed-accuracy")]
FixedAccuracy {
tolerance: f64,
},
// `"mode": "reversible"`. `compress` skips its non-finite value check in this
// mode, regardless of the configured `ZfpNonFiniteValuesMode`.
#[serde(rename = "reversible")]
Reversible,
}
// Policy for non-finite (infinite or NaN) values in the data passed to
// `compress`. The policy is only consulted when the compression mode is not
// `ZfpCompressionMode::Reversible`.
//
// NOTE(review): plain `//` comments are used on purpose. This type derives
// `JsonSchema`, and schemars turns `///` doc comments into schema
// descriptions, so doc comments here would change the generated schema.
#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, JsonSchema)]
pub enum ZfpNonFiniteValuesMode {
// `"deny"` (the default): `compress` scans the data and returns
// `ZfpCodecError::NonFiniteData` if any value is not finite.
#[default]
#[serde(rename = "deny")]
Deny,
// `"allow-unsafe"`: `compress` skips the scan and hands the data to zfp as is.
// NOTE(review): "unsafe" presumably refers to zfp not defining the result for
// non-finite input in its lossy modes — confirm against the zfp documentation.
#[serde(rename = "allow-unsafe")]
AllowUnsafe,
}
impl Codec for ZfpCodec {
    type Error = ZfpCodecError;

    /// Compresses `data` with the configured mode and returns the encoded
    /// bytes as a one-dimensional `u8` array.
    ///
    /// # Errors
    ///
    /// - [`ZfpCodecError::FixedAccuracyModeIntegerData`] if `data` is `i32` or
    ///   `i64` and the mode is [`ZfpCompressionMode::FixedAccuracy`]
    /// - [`ZfpCodecError::UnsupportedDtype`] if `data` is not an `i32`, `i64`,
    ///   `f32`, or `f64` array
    /// - any error returned by [`compress`]
    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        // The mode/dtype combination is checked before any data is touched.
        let integer_input = matches!(data.dtype(), AnyArrayDType::I32 | AnyArrayDType::I64);
        if integer_input && matches!(self.mode, ZfpCompressionMode::FixedAccuracy { .. }) {
            return Err(ZfpCodecError::FixedAccuracyModeIntegerData);
        }

        // `compress` is generic over the element type, so each supported
        // dtype needs its own arm.
        let compressed = match data {
            AnyCowArray::I32(values) => compress(values.view(), &self.mode, self.non_finite),
            AnyCowArray::I64(values) => compress(values.view(), &self.mode, self.non_finite),
            AnyCowArray::F32(values) => compress(values.view(), &self.mode, self.non_finite),
            AnyCowArray::F64(values) => compress(values.view(), &self.mode, self.non_finite),
            unsupported => Err(ZfpCodecError::UnsupportedDtype(unsupported.dtype())),
        };

        compressed.map(|bytes| AnyArray::U8(Array1::from(bytes).into_dyn()))
    }

    /// Decompresses `encoded`, which must be a one-dimensional `u8` array,
    /// into a newly allocated array.
    ///
    /// # Errors
    ///
    /// - [`ZfpCodecError::EncodedDataNotBytes`] if `encoded` is not a `u8`
    ///   array
    /// - [`ZfpCodecError::EncodedDataNotOneDimensional`] if `encoded` does not
    ///   have exactly one axis
    /// - any error returned by [`decompress`]
    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let bytes = match encoded {
            AnyCowArray::U8(bytes) => bytes,
            other => {
                return Err(ZfpCodecError::EncodedDataNotBytes {
                    dtype: other.dtype(),
                });
            }
        };

        if bytes.ndim() != 1 {
            return Err(ZfpCodecError::EncodedDataNotOneDimensional {
                shape: bytes.shape().to_vec(),
            });
        }

        decompress(&AnyCowArray::U8(bytes).as_bytes())
    }

    /// Decompresses `encoded`, which must be a one-dimensional `u8` array,
    /// into the existing `decoded` array.
    ///
    /// # Errors
    ///
    /// - [`ZfpCodecError::EncodedDataNotBytes`] if `encoded` is not a `u8`
    ///   array
    /// - [`ZfpCodecError::EncodedDataNotOneDimensional`] if `encoded` does not
    ///   have exactly one axis
    /// - any error returned by [`decompress_into`]
    fn decode_into(
        &self,
        encoded: AnyArrayView,
        decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        let bytes = match encoded {
            AnyArrayView::U8(bytes) => bytes,
            other => {
                return Err(ZfpCodecError::EncodedDataNotBytes {
                    dtype: other.dtype(),
                });
            }
        };

        if bytes.ndim() != 1 {
            return Err(ZfpCodecError::EncodedDataNotOneDimensional {
                shape: bytes.shape().to_vec(),
            });
        }

        decompress_into(&AnyArrayView::U8(bytes).as_bytes(), decoded)
    }
}
impl StaticCodec for ZfpCodec {
    /// Identifier under which this codec is registered.
    const CODEC_ID: &'static str = "zfp.rs";

    /// The codec is its own configuration type.
    type Config<'de> = Self;

    /// Builds the codec from its configuration, which is the codec itself.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Returns the configuration of this codec, converted from `&self`.
    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
        self.into()
    }
}
/// Errors that may occur when encoding or decoding with the [`ZfpCodec`].
///
/// Variants that are not constructed in this file are presumably raised by
/// the `ffi` module, whose calls are propagated with `?`.
#[derive(Debug, Error)]
pub enum ZfpCodecError {
/// The array passed to `encode` is not an `i32`, `i64`, `f32`, or `f64`
/// array.
#[error("Zfp does not support the dtype {0}")]
UnsupportedDtype(AnyArrayDType),
/// `encode` was called with `i32` or `i64` data while the codec is
/// configured with [`ZfpCompressionMode::FixedAccuracy`].
#[error("Zfp does not support the fixed accuracy mode for integer data")]
FixedAccuracyModeIntegerData,
/// The data has more dimensions than zfp supports; `shape` is the
/// offending shape.
#[error("Zfp only supports 1-4 dimensional data but found shape {shape:?}")]
ExcessiveDimensionality {
shape: Vec<usize>,
},
/// The [`ZfpCompressionMode::Expert`] parameters were rejected; `mode` is
/// the rejected configuration.
#[error("Zfp was configured with an invalid expert mode {mode:?}")]
InvalidExpertMode {
mode: ZfpCompressionMode,
},
/// [`compress`] found an infinite or NaN value while the mode is not
/// reversible and non-finite values are denied.
#[error(
"Zfp does not support non-finite (infinite or NaN) floating point data in non-reversible lossy compression"
)]
NonFiniteData,
/// The zfp header could not be written.
#[error("Zfp failed to encode the header")]
HeaderEncodeFailed,
/// The postcard-serialized [`CompressionHeader`] (dtype, shape, version)
/// could not be written; `source` is the underlying postcard error.
#[error("Zfp failed to encode the array metadata header")]
MetaHeaderEncodeFailed {
source: ZfpHeaderError,
},
/// zfp failed to compress the data.
#[error("Zfp failed to encode the data")]
ZfpEncodeFailed,
/// The array passed to `decode` or `decode_into` is not a `u8` array;
/// `dtype` is the dtype that was received.
#[error(
"Zfp can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
)]
EncodedDataNotBytes {
dtype: AnyArrayDType,
},
/// The byte array passed to `decode` or `decode_into` does not have
/// exactly one axis; `shape` is the shape that was received.
#[error(
"Zfp can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
)]
EncodedDataNotOneDimensional {
shape: Vec<usize>,
},
/// The zfp header could not be read.
#[error("Zfp failed to decode the header")]
HeaderDecodeFailed,
/// The postcard-serialized [`CompressionHeader`] could not be read from
/// the start of the encoded bytes; `source` is the underlying postcard
/// error.
#[error("Zfp failed to decode the array metadata header")]
MetaHeaderDecodeFailed {
source: ZfpHeaderError,
},
/// The output array passed to [`decompress_into`] does not match the
/// shape or dtype recorded in the encoded data.
///
/// `#[from]` also provides the `From<AnyArrayAssignError>` conversion.
#[error("ZfpCodec cannot decode into the provided array")]
MismatchedDecodeIntoArray {
#[from]
source: AnyArrayAssignError,
},
/// zfp failed to decompress the data.
#[error("Zfp failed to decode the data")]
ZfpDecodeFailed,
}
/// Opaque error for when encoding or decoding the [`CompressionHeader`]
/// fails.
///
/// Wraps the underlying [`postcard::Error`]; `#[error(transparent)]` forwards
/// `Display` and `source()` to it. The field is private, so the wrapped error
/// type is not part of the public interface.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct ZfpHeaderError(postcard::Error);
/// Compresses the input `data` array with zfp using the given `mode`.
///
/// The output starts with a postcard-serialized [`CompressionHeader`] that
/// records the dtype, shape, and format version. For a non-empty array it is
/// followed by the zfp header and the compressed data; for an empty array the
/// metadata header is the entire output.
///
/// Unless `mode` is [`ZfpCompressionMode::Reversible`] or `non_finite` is
/// [`ZfpNonFiniteValuesMode::AllowUnsafe`], every value is first checked to
/// be finite.
///
/// # Errors
///
/// - [`ZfpCodecError::NonFiniteData`] if the finiteness check fails
/// - [`ZfpCodecError::MetaHeaderEncodeFailed`] if the metadata header cannot
///   be serialized
/// - any error returned while creating the zfp field or stream, writing the
///   zfp header, or compressing the data
pub fn compress<T: ffi::ZfpCompressible, D: Dimension>(
    data: ArrayView<T, D>,
    mode: &ZfpCompressionMode,
    non_finite: ZfpNonFiniteValuesMode,
) -> Result<Vec<u8>, ZfpCodecError> {
    let is_reversible = matches!(mode, ZfpCompressionMode::Reversible);
    let allow_non_finite = matches!(non_finite, ZfpNonFiniteValuesMode::AllowUnsafe);
    let must_be_finite = !(is_reversible || allow_non_finite);

    // The scan over the data only runs when the check is actually required.
    if must_be_finite && !Zip::from(&data).all(|x| x.is_finite()) {
        return Err(ZfpCodecError::NonFiniteData);
    }

    // The metadata header is written first so that the original
    // (unsqueezed) shape can be restored when decoding.
    let serialized_header = postcard::to_extend(
        &CompressionHeader {
            dtype: <T as ffi::ZfpCompressible>::D_TYPE,
            shape: Cow::Borrowed(data.shape()),
            version: StaticCodecVersion,
        },
        Vec::new(),
    );
    let mut encoded = match serialized_header {
        Ok(bytes) => bytes,
        Err(err) => {
            return Err(ZfpCodecError::MetaHeaderEncodeFailed {
                source: ZfpHeaderError(err),
            });
        }
    };

    // Empty arrays are fully described by the metadata header.
    if data.is_empty() {
        return Ok(encoded);
    }

    // Length-one axes are dropped before the data is handed to zfp.
    let field = ffi::ZfpField::new(data.into_dyn().squeeze())?;
    let configured = ffi::ZfpCompressionStream::new(&field, mode)?;
    // The zfp header and compressed data are appended to `encoded`.
    let attached = configured.with_bitstream(field, &mut encoded);
    let with_header = attached.write_header()?;
    with_header.compress()?;

    Ok(encoded)
}
/// Decompresses the `encoded` bytes produced by [`compress`] into a newly
/// allocated array.
///
/// The dtype and shape of the result are taken from the postcard-serialized
/// [`CompressionHeader`] at the start of `encoded`. If the recorded shape has
/// a zero-length axis, an empty array of that shape is returned without
/// reading any zfp data.
///
/// # Errors
///
/// - [`ZfpCodecError::MetaHeaderDecodeFailed`] if the metadata header cannot
///   be deserialized
/// - any error returned while reading the zfp header or decompressing the
///   data
pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZfpCodecError> {
    let (header, encoded) =
        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
            ZfpCodecError::MetaHeaderDecodeFailed {
                source: ZfpHeaderError(err),
            }
        })?;

    // An array is empty exactly when one of its axes has length zero. The
    // shape comes from the (untrusted) encoded bytes, so the axis lengths are
    // not multiplied: the product can overflow, which panics when overflow
    // checks are enabled and otherwise wraps, possibly to exactly zero for a
    // shape that is not empty at all.
    if header.shape.contains(&0) {
        let decoded = match header.dtype {
            ZfpDType::I32 => AnyArray::I32(Array::zeros(&*header.shape)),
            ZfpDType::I64 => AnyArray::I64(Array::zeros(&*header.shape)),
            ZfpDType::F32 => AnyArray::F32(Array::zeros(&*header.shape)),
            ZfpDType::F64 => AnyArray::F64(Array::zeros(&*header.shape)),
        };
        return Ok(decoded);
    }

    let stream = ffi::ZfpDecompressionStream::new(encoded);
    let stream = stream.read_header()?;

    // NOTE(review): the output is allocated with the shape from the untrusted
    // header — confirm how oversized shapes should be rejected.
    //
    // The output has the full recorded shape; zfp decompresses into a view
    // with the length-one axes dropped, mirroring the squeeze in `compress`.
    match header.dtype {
        ZfpDType::I32 => {
            let mut decompressed = Array::zeros(&*header.shape);
            stream.decompress_into(decompressed.view_mut().squeeze())?;
            Ok(AnyArray::I32(decompressed))
        }
        ZfpDType::I64 => {
            let mut decompressed = Array::zeros(&*header.shape);
            stream.decompress_into(decompressed.view_mut().squeeze())?;
            Ok(AnyArray::I64(decompressed))
        }
        ZfpDType::F32 => {
            let mut decompressed = Array::zeros(&*header.shape);
            stream.decompress_into(decompressed.view_mut().squeeze())?;
            Ok(AnyArray::F32(decompressed))
        }
        ZfpDType::F64 => {
            let mut decompressed = Array::zeros(&*header.shape);
            stream.decompress_into(decompressed.view_mut().squeeze())?;
            Ok(AnyArray::F64(decompressed))
        }
    }
}
/// Decompresses the `encoded` bytes produced by [`compress`] into the
/// existing `decoded` array.
///
/// The shape recorded in the [`CompressionHeader`] must equal the shape of
/// `decoded`. If `decoded` is empty, nothing further is read and the function
/// returns successfully. Otherwise the recorded dtype must match the dtype of
/// `decoded`.
///
/// # Errors
///
/// - [`ZfpCodecError::MetaHeaderDecodeFailed`] if the metadata header cannot
///   be deserialized
/// - [`ZfpCodecError::MismatchedDecodeIntoArray`] if the shape or the dtype
///   of `decoded` does not match the encoded data
/// - any error returned while reading the zfp header or decompressing the
///   data
pub fn decompress_into(encoded: &[u8], decoded: AnyArrayViewMut) -> Result<(), ZfpCodecError> {
    let (header, payload) = match postcard::take_from_bytes::<CompressionHeader>(encoded) {
        Ok(parts) => parts,
        Err(err) => {
            return Err(ZfpCodecError::MetaHeaderDecodeFailed {
                source: ZfpHeaderError(err),
            });
        }
    };

    let expected_shape: &[usize] = &header.shape;
    if decoded.shape() != expected_shape {
        let source = AnyArrayAssignError::ShapeMismatch {
            src: header.shape.into_owned(),
            dst: decoded.shape().to_vec(),
        };
        return Err(ZfpCodecError::MismatchedDecodeIntoArray { source });
    }

    // Empty arrays carry no zfp payload, so there is nothing left to decode.
    if decoded.is_empty() {
        return Ok(());
    }

    let unread = ffi::ZfpDecompressionStream::new(payload);
    let stream = unread.read_header()?;

    // The recorded dtype must agree with the output array; zfp decompresses
    // into a view with the length-one axes dropped.
    match (header.dtype, decoded) {
        (ZfpDType::I32, AnyArrayViewMut::I32(target)) => stream.decompress_into(target.squeeze()),
        (ZfpDType::I64, AnyArrayViewMut::I64(target)) => stream.decompress_into(target.squeeze()),
        (ZfpDType::F32, AnyArrayViewMut::F32(target)) => stream.decompress_into(target.squeeze()),
        (ZfpDType::F64, AnyArrayViewMut::F64(target)) => stream.decompress_into(target.squeeze()),
        (recorded, mismatched) => Err(ZfpCodecError::MismatchedDecodeIntoArray {
            source: AnyArrayAssignError::DTypeMismatch {
                src: recorded.into_dtype(),
                dst: mismatched.dtype(),
            },
        }),
    }
}
/// Metadata header that [`compress`] serializes with postcard at the start of
/// the encoded bytes, and that [`decompress`] and [`decompress_into`] read
/// back before touching the zfp data.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
/// Element type of the original array.
dtype: ZfpDType,
/// Shape of the original array, before length-one axes are squeezed away
/// for zfp. Borrowed from the input array when encoding.
#[serde(borrow)]
shape: Cow<'a, [usize]>,
/// Encoding format version.
version: ZfpCodecVersion,
}
/// Array element types that can be compressed with zfp.
///
/// The dtype is recorded in the [`CompressionHeader`] of the encoded data.
/// Each variant serializes under its short name (e.g. `i32`) and additionally
/// accepts the long alias (e.g. `int32`) when deserializing.
// The variants are deliberately left without doc comments:
// `#[expect(missing_docs)]` expects the lint to fire on this item.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[expect(missing_docs)]
pub enum ZfpDType {
#[serde(rename = "i32", alias = "int32")]
I32,
#[serde(rename = "i64", alias = "int64")]
I64,
#[serde(rename = "f32", alias = "float32")]
F32,
#[serde(rename = "f64", alias = "float64")]
F64,
}
impl ZfpDType {
/// Converts the [`ZfpDType`] into the [`AnyArrayDType`] variant of the
/// same name.
///
/// Used by [`decompress_into`] to report the recorded dtype in a dtype
/// mismatch error.
#[must_use]
pub const fn into_dtype(self) -> AnyArrayDType {
match self {
Self::I32 => AnyArrayDType::I32,
Self::I64 => AnyArrayDType::I64,
Self::F32 => AnyArrayDType::F32,
Self::F64 => AnyArrayDType::F64,
}
}
}
impl fmt::Display for ZfpDType {
    /// Writes the short dtype name (`i32`, `i64`, `f32`, or `f64`).
    ///
    /// The name is written as is; width and alignment flags of the formatter
    /// are not applied.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let name = match *self {
            Self::I32 => "i32",
            Self::I64 => "i64",
            Self::F32 => "f32",
            Self::F64 => "f64",
        };
        fmt.write_str(name)
    }
}
// Round-trip tests for `compress` and `decompress`.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use ndarray::ArrayView1;
use super::*;
// An array with a zero-length axis round-trips to an empty array with the
// same dtype and the same full shape.
#[test]
fn zero_length() {
let encoded = compress(
Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])
.unwrap()
.view(),
&ZfpCompressionMode::FixedPrecision { precision: 7 },
ZfpNonFiniteValuesMode::Deny,
)
.unwrap();
let decoded = decompress(&encoded).unwrap();
assert_eq!(decoded.dtype(), AnyArrayDType::F32);
assert!(decoded.is_empty());
assert_eq!(decoded.shape(), &[1, 27, 0]);
}
// The input has six axes, of which only two are longer than one.
// `compress` squeezes the array before handing it to zfp, which (per the
// `ExcessiveDimensionality` message) only supports 1-4 dimensional data.
// The decoded array is expected to equal the input, including its original
// six-axis shape, even though the mode is lossy.
#[test]
fn one_dimension() {
let data = Array::from_shape_vec(
[2_usize, 1, 2, 1, 1, 1].as_slice(),
vec![1.0, 2.0, 3.0, 4.0],
)
.unwrap();
let encoded = compress(
data.view(),
&ZfpCompressionMode::FixedAccuracy { tolerance: 0.1 },
ZfpNonFiniteValuesMode::Deny,
)
.unwrap();
let decoded = decompress(&encoded).unwrap();
// The comparison against `AnyArray::F32` fixes the element type to `f32`.
assert_eq!(decoded, AnyArray::F32(data));
}
// One-dimensional `f64` inputs of length 0 through 4 round-trip exactly.
// NOTE(review): presumably this targets inputs that fill at most one zfp
// block — confirm the intent behind the name `small_state`.
#[test]
fn small_state() {
for data in [
&[][..],
&[0.0],
&[0.0, 1.0],
&[0.0, 1.0, 0.0],
&[0.0, 1.0, 0.0, 1.0],
] {
let encoded = compress(
ArrayView1::from(data),
&ZfpCompressionMode::FixedAccuracy { tolerance: 0.1 },
ZfpNonFiniteValuesMode::Deny,
)
.unwrap();
let decoded = decompress(&encoded).unwrap();
assert_eq!(
decoded,
AnyArray::F64(Array1::from_vec(data.to_vec()).into_dyn())
);
}
}
}