#![expect(clippy::multiple_crate_versions)]
use std::borrow::Cow;
use ndarray::Array1;
use numcodecs::{
AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
};
use schemars::{JsonSchema, JsonSchema_repr};
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use thiserror::Error;
/// Encoding format version of the Zlib codec.
///
/// It is stored in the codec configuration (as `_version`) and in the header
/// that is written in front of every encoded buffer.
// NOTE(review): the const parameters are presumably (major, minor, patch),
// i.e. version 0.1.0 -- confirm against `StaticCodecVersion`.
type ZlibCodecVersion = StaticCodecVersion<0, 1, 0>;
/// Codec providing compression using Zlib.
///
/// The configuration is (de)serialised with serde and rejects unknown fields.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ZlibCodec {
    /// Zlib compression level.
    ///
    /// The level ranges from 0, no compression, to 9, best compression.
    pub level: ZlibLevel,
    /// The codec's encoding format version.
    ///
    /// It is (de)serialised under the name `_version` and falls back to its
    /// default value when absent from the configuration.
    #[serde(default, rename = "_version")]
    pub version: ZlibCodecVersion,
}
/// Zlib compression level.
///
/// The level ranges from 0, no compression, to 9, best compression, and is
/// (de)serialised as its `u8` discriminant.
#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, JsonSchema_repr)]
#[repr(u8)]
// The variants are deliberately left without rustdoc: their names and
// discriminants are self-describing, and documenting them would leave the
// lint expectation below unfulfilled.
#[expect(missing_docs)]
pub enum ZlibLevel {
    ZNoCompression = 0,
    ZBestSpeed = 1,
    ZLevel2 = 2,
    ZLevel3 = 3,
    ZLevel4 = 4,
    ZLevel5 = 5,
    ZLevel6 = 6,
    ZLevel7 = 7,
    ZLevel8 = 8,
    ZBestCompression = 9,
}
impl Codec for ZlibCodec {
    type Error = ZlibCodecError;

    /// Compresses `data` at the configured level and returns the encoded
    /// bytes as a one-dimensional `u8` array.
    ///
    /// # Errors
    ///
    /// Forwards any error reported by [`compress`].
    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let compressed = compress(data.view(), self.level)?;

        Ok(AnyArray::U8(Array1::from_vec(compressed).into_dyn()))
    }

    /// Decompresses a one-dimensional `u8` array into a freshly allocated
    /// array.
    ///
    /// # Errors
    ///
    /// Errors with
    /// - [`ZlibCodecError::EncodedDataNotBytes`] if `encoded` is not of dtype
    ///   `u8`
    /// - [`ZlibCodecError::EncodedDataNotOneDimensional`] if `encoded` is not
    ///   one-dimensional
    /// - any error reported by [`decompress`]
    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        match encoded {
            AnyCowArray::U8(bytes) if bytes.shape().len() == 1 => {
                // Re-wrap the byte array so that its contents can be viewed
                // as a plain byte slice.
                let rewrapped = AnyCowArray::U8(bytes);
                let raw = rewrapped.as_bytes();
                decompress(&raw)
            }
            AnyCowArray::U8(bytes) => Err(ZlibCodecError::EncodedDataNotOneDimensional {
                shape: bytes.shape().to_vec(),
            }),
            other => Err(ZlibCodecError::EncodedDataNotBytes {
                dtype: other.dtype(),
            }),
        }
    }

    /// Decompresses a one-dimensional `u8` array view into the preallocated
    /// `decoded` array.
    ///
    /// # Errors
    ///
    /// Errors with
    /// - [`ZlibCodecError::EncodedDataNotBytes`] if `encoded` is not of dtype
    ///   `u8`
    /// - [`ZlibCodecError::EncodedDataNotOneDimensional`] if `encoded` is not
    ///   one-dimensional
    /// - any error reported by [`decompress_into`]
    fn decode_into(
        &self,
        encoded: AnyArrayView,
        decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        match encoded {
            AnyArrayView::U8(bytes) if bytes.shape().len() == 1 => {
                // Re-wrap the byte view so that its contents can be viewed
                // as a plain byte slice.
                let rewrapped = AnyArrayView::U8(bytes);
                let raw = rewrapped.as_bytes();
                decompress_into(&raw, decoded)
            }
            AnyArrayView::U8(bytes) => Err(ZlibCodecError::EncodedDataNotOneDimensional {
                shape: bytes.shape().to_vec(),
            }),
            other => Err(ZlibCodecError::EncodedDataNotBytes {
                dtype: other.dtype(),
            }),
        }
    }
}
impl StaticCodec for ZlibCodec {
    /// Identifier under which this codec is registered.
    const CODEC_ID: &'static str = "zlib.rs";

    /// The codec is its own configuration type.
    type Config<'de> = Self;

    /// Builds the codec from its configuration, which already is the codec.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Returns the codec's configuration, borrowed from the codec itself.
    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
        self.into()
    }
}
/// Errors that may occur when applying the [`ZlibCodec`].
#[derive(Debug, Error)]
pub enum ZlibCodecError {
    /// [`ZlibCodec`] failed to encode the header
    #[error("Zlib failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: ZlibHeaderError,
    },
    /// [`ZlibCodec`] can only decode one-dimensional byte arrays but received
    /// an array of a different dtype
    #[error(
        "Zlib can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`ZlibCodec`] can only decode one-dimensional byte arrays but received
    /// a byte array of a different shape
    #[error(
        "Zlib can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`ZlibCodec`] failed to decode the header
    #[error("Zlib failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: ZlibHeaderError,
    },
    /// [`ZlibCodec`] decode consumed less encoded data than was provided,
    /// i.e. the encoded data contains trailing junk
    #[error("Zlib decode consumed less encoded data, which contains trailing junk")]
    DecodeExcessiveEncodedData,
    /// [`ZlibCodec`] produced less decoded data than the output array holds
    #[error("Zlib produced less decoded data than expected")]
    DecodeProducedLess,
    /// [`ZlibCodec`] failed to decode the encoded data
    #[error("Zlib failed to decode the encoded data")]
    ZlibDecodeFailed {
        /// Opaque source error
        source: ZlibDecodeError,
    },
    /// [`ZlibCodec`] cannot decode into the provided array, e.g. because its
    /// dtype or shape differs from the one recorded in the header
    #[error("Zlib cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error; also convertible into this variant via
        /// `From`
        #[from]
        source: AnyArrayAssignError,
    },
}
/// Opaque error for when encoding or decoding the header fails.
///
/// Wraps the underlying `postcard` error and forwards its display and source
/// transparently.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct ZlibHeaderError(postcard::Error);
/// Opaque error for when decoding with Zlib fails.
///
/// Wraps the underlying `miniz_oxide` decompression error and forwards its
/// display and source transparently.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct ZlibDecodeError(miniz_oxide::inflate::DecompressError);
#[expect(clippy::needless_pass_by_value)]
pub fn compress(array: AnyArrayView, level: ZlibLevel) -> Result<Vec<u8>, ZlibCodecError> {
let data = array.as_bytes();
let mut encoded = postcard::to_extend(
&CompressionHeader {
dtype: array.dtype(),
shape: Cow::Borrowed(array.shape()),
version: StaticCodecVersion,
},
Vec::new(),
)
.map_err(|err| ZlibCodecError::HeaderEncodeFailed {
source: ZlibHeaderError(err),
})?;
let mut in_pos = 0;
let mut out_pos = encoded.len();
let flags =
miniz_oxide::deflate::core::create_comp_flags_from_zip_params((level as u8).into(), 1, 0);
let mut compressor = miniz_oxide::deflate::core::CompressorOxide::new(flags);
encoded.resize(encoded.len() + (data.len() / 2).max(2), 0);
loop {
let (Some(data_left), Some(encoded_left)) =
(data.get(in_pos..), encoded.get_mut(out_pos..))
else {
#[expect(clippy::panic)] {
panic!("Zlib encode bug: input or output is out of bounds")
}
};
let (status, bytes_in, bytes_out) = miniz_oxide::deflate::core::compress(
&mut compressor,
data_left,
encoded_left,
miniz_oxide::deflate::core::TDEFLFlush::Finish,
);
out_pos += bytes_out;
in_pos += bytes_in;
match status {
miniz_oxide::deflate::core::TDEFLStatus::Okay => {
if encoded.len().saturating_sub(out_pos) < 30 {
encoded.resize(encoded.len() * 2, 0);
}
}
miniz_oxide::deflate::core::TDEFLStatus::Done => {
encoded.truncate(out_pos);
assert!(
in_pos == data.len(),
"Zlib encode bug: consumed less input than expected"
);
return Ok(encoded);
}
#[expect(clippy::panic)] err => panic!("Zlib encode bug: {err:?}"),
}
}
}
pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZlibCodecError> {
let (header, encoded) =
postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
ZlibCodecError::HeaderDecodeFailed {
source: ZlibHeaderError(err),
}
})?;
let (decoded, result) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |decoded| {
decompress_into_bytes(encoded, decoded)
});
result.map(|()| decoded)
}
pub fn decompress_into(encoded: &[u8], mut decoded: AnyArrayViewMut) -> Result<(), ZlibCodecError> {
let (header, encoded) =
postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
ZlibCodecError::HeaderDecodeFailed {
source: ZlibHeaderError(err),
}
})?;
if header.dtype != decoded.dtype() {
return Err(ZlibCodecError::MismatchedDecodeIntoArray {
source: AnyArrayAssignError::DTypeMismatch {
src: header.dtype,
dst: decoded.dtype(),
},
});
}
if header.shape != decoded.shape() {
return Err(ZlibCodecError::MismatchedDecodeIntoArray {
source: AnyArrayAssignError::ShapeMismatch {
src: header.shape.into_owned(),
dst: decoded.shape().to_vec(),
},
});
}
decoded.with_bytes_mut(|decoded| decompress_into_bytes(encoded, decoded))
}
/// Inflates the zlib stream in `encoded` so that it fills `decoded` exactly.
///
/// # Errors
///
/// Errors with
/// - [`ZlibCodecError::ZlibDecodeFailed`] if the decompressor did not finish
///   with the `Done` status
/// - [`ZlibCodecError::DecodeExcessiveEncodedData`] if not all of `encoded`
///   was consumed
/// - [`ZlibCodecError::DecodeProducedLess`] if fewer than `decoded.len()`
///   bytes were produced
fn decompress_into_bytes(encoded: &[u8], decoded: &mut [u8]) -> Result<(), ZlibCodecError> {
    // The stream carries a zlib header, and the whole output goes into one
    // buffer that is not used as a wrapping window.
    let parse_zlib_header = miniz_oxide::inflate::core::inflate_flags::TINFL_FLAG_PARSE_ZLIB_HEADER;
    let non_wrapping_output =
        miniz_oxide::inflate::core::inflate_flags::TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF;
    let flags = parse_zlib_header | non_wrapping_output;

    let mut state = Box::<miniz_oxide::inflate::core::DecompressorOxide>::default();
    let (status, bytes_read, bytes_written) =
        miniz_oxide::inflate::core::decompress(&mut state, encoded, decoded, 0, flags);

    if !matches!(status, miniz_oxide::inflate::TINFLStatus::Done) {
        return Err(ZlibCodecError::ZlibDecodeFailed {
            source: ZlibDecodeError(miniz_oxide::inflate::DecompressError {
                status,
                output: Vec::new(),
            }),
        });
    }

    // Trailing input is reported before a short output.
    if bytes_read != encoded.len() {
        return Err(ZlibCodecError::DecodeExcessiveEncodedData);
    }
    if bytes_written != decoded.len() {
        return Err(ZlibCodecError::DecodeProducedLess);
    }

    Ok(())
}
/// Header that is postcard-encoded in front of the compressed bytes.
///
/// It records what is needed to rebuild the array when decoding. The field
/// order is part of the encoded format and must not be changed.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    /// Element type of the array that was encoded.
    dtype: AnyArrayDType,
    /// Shape of the array that was encoded; borrowed from the array when
    /// encoding.
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    /// Encoding format version the data was written with.
    version: ZlibCodecVersion,
}