#![allow(clippy::multiple_crate_versions)]
use std::{borrow::Cow, io};
use ndarray::Array1;
use numcodecs::{
AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;
use zstd_sys as _;
/// Type-level version marker of the `ZstdCodec` encoding format.
///
/// It is stored both in the codec configuration (as `_version`) and in the
/// `CompressionHeader` that is written in front of the compressed data.
///
/// NOTE(review): the const parameters presumably spell out version 0.1.0
/// (major, minor, patch) — confirm against `StaticCodecVersion`.
type ZstdCodecVersion = StaticCodecVersion<0, 1, 0>;
/// Codec that compresses array data with Zstandard.
///
/// Unknown configuration fields are rejected during deserialization.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ZstdCodec {
/// Compression level that is passed to the Zstandard encoder.
pub level: ZstdLevel,
/// Encoding format version, (de)serialized under the key `_version`.
///
/// The default value is used when the key is absent.
#[serde(default, rename = "_version")]
pub version: ZstdCodecVersion,
}
/// Array codec implementation: encoding yields a one-dimensional byte array and
/// decoding only accepts one-dimensional byte arrays.
impl Codec for ZstdCodec {
    type Error = ZstdCodecError;

    /// Compresses `data` at the configured level and wraps the resulting bytes in a
    /// one-dimensional `u8` array.
    ///
    /// # Errors
    ///
    /// Forwards any error reported by [`compress`].
    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let compressed = compress(data.view(), self.level)?;
        Ok(AnyArray::U8(Array1::from_vec(compressed).into_dyn()))
    }

    /// Decompresses `encoded` into a newly allocated array.
    ///
    /// # Errors
    ///
    /// - [`ZstdCodecError::EncodedDataNotBytes`] if `encoded` is not a `u8` array
    /// - [`ZstdCodecError::EncodedDataNotOneDimensional`] if `encoded` does not have
    ///   exactly one dimension
    /// - any error reported by [`decompress`]
    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let bytes = match encoded {
            AnyCowArray::U8(bytes) => bytes,
            other => {
                return Err(ZstdCodecError::EncodedDataNotBytes {
                    dtype: other.dtype(),
                });
            }
        };

        if bytes.shape().len() != 1 {
            return Err(ZstdCodecError::EncodedDataNotOneDimensional {
                shape: bytes.shape().to_vec(),
            });
        }

        // Re-wrap the typed array to obtain its raw byte representation.
        let rewrapped = AnyCowArray::U8(bytes);
        let raw = rewrapped.as_bytes();
        decompress(&raw)
    }

    /// Decompresses `encoded` directly into the caller-provided `decoded` array.
    ///
    /// # Errors
    ///
    /// - [`ZstdCodecError::EncodedDataNotBytes`] if `encoded` is not a `u8` array
    /// - [`ZstdCodecError::EncodedDataNotOneDimensional`] if `encoded` does not have
    ///   exactly one dimension
    /// - any error reported by [`decompress_into`]
    fn decode_into(
        &self,
        encoded: AnyArrayView,
        decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        let bytes = match encoded {
            AnyArrayView::U8(bytes) => bytes,
            other => {
                return Err(ZstdCodecError::EncodedDataNotBytes {
                    dtype: other.dtype(),
                });
            }
        };

        if bytes.shape().len() != 1 {
            return Err(ZstdCodecError::EncodedDataNotOneDimensional {
                shape: bytes.shape().to_vec(),
            });
        }

        // Re-wrap the typed view to obtain its raw byte representation.
        let rewrapped = AnyArrayView::U8(bytes);
        let raw = rewrapped.as_bytes();
        decompress_into(&raw, decoded)
    }
}
/// Statically-typed codec description: the configuration type of `ZstdCodec` is the
/// codec struct itself.
impl StaticCodec for ZstdCodec {
/// Identifier of this codec.
const CODEC_ID: &'static str = "zstd.rs";
/// The codec is its own configuration.
type Config<'de> = Self;
/// Builds the codec from its configuration, which is returned unchanged.
fn from_config(config: Self::Config<'_>) -> Self {
config
}
/// Returns the configuration of this codec, wrapped from a reference to `self`.
fn get_config(&self) -> StaticCodecConfig<'_, Self> {
StaticCodecConfig::from(self)
}
}
/// Zstandard compression level.
///
/// The wrapped value is private. When deserialized, only levels inside
/// `zstd::compression_level_range()` are accepted.
#[derive(Clone, Copy, JsonSchema)]
#[schemars(transparent)]
pub struct ZstdLevel {
level: zstd::zstd_safe::CompressionLevel,
}
impl Serialize for ZstdLevel {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
self.level.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for ZstdLevel {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let level = Deserialize::deserialize(deserializer)?;
let level_range = zstd::compression_level_range();
if !level_range.contains(&level) {
return Err(serde::de::Error::custom(format!(
"level {level} is not in {}..={}",
level_range.start(),
level_range.end()
)));
}
Ok(Self { level })
}
}
#[derive(Debug, Error)]
pub enum ZstdCodecError {
#[error("Zstd failed to encode the header")]
HeaderEncodeFailed {
source: ZstdHeaderError,
},
#[error("Zstd failed to decode the encoded data")]
ZstdEncodeFailed {
source: ZstdCodingError,
},
#[error(
"Zstd can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
)]
EncodedDataNotBytes {
dtype: AnyArrayDType,
},
#[error(
"Zstd can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
)]
EncodedDataNotOneDimensional {
shape: Vec<usize>,
},
#[error("Zstd failed to decode the header")]
HeaderDecodeFailed {
source: ZstdHeaderError,
},
#[error("Zstd decode consumed less encoded data, which contains trailing junk")]
DecodeExcessiveEncodedData,
#[error("Zstd produced less decoded data than expected")]
DecodeProducedLess,
#[error("Zstd failed to decode the encoded data")]
ZstdDecodeFailed {
source: ZstdCodingError,
},
#[error("Zstd cannot decode into the provided array")]
MismatchedDecodeIntoArray {
#[from]
source: AnyArrayAssignError,
},
}
/// Opaque error wrapping the `postcard::Error` raised while serializing or
/// deserializing the header; its message is forwarded unchanged.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct ZstdHeaderError(postcard::Error);
/// Opaque error wrapping the `io::Error` reported by the Zstandard stream encoder or
/// decoder; its message is forwarded unchanged.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct ZstdCodingError(io::Error);
/// Compresses `array` with Zstandard at the given `level`.
///
/// The output starts with a postcard-serialized header that records the dtype, shape
/// and format version, followed by the Zstandard stream of the array's bytes.
///
/// # Errors
///
/// - [`ZstdCodecError::HeaderEncodeFailed`] if the header cannot be serialized
/// - [`ZstdCodecError::ZstdEncodeFailed`] if the Zstandard encoder reports an error
#[expect(clippy::needless_pass_by_value)]
pub fn compress(array: AnyArrayView, level: ZstdLevel) -> Result<Vec<u8>, ZstdCodecError> {
    let header = CompressionHeader {
        dtype: array.dtype(),
        shape: Cow::Borrowed(array.shape()),
        version: StaticCodecVersion,
    };

    // The header is written first; the compressed stream is appended to the same buffer.
    let mut encoded = match postcard::to_extend(&header, Vec::<u8>::new()) {
        Ok(buffer) => buffer,
        Err(err) => {
            return Err(ZstdCodecError::HeaderEncodeFailed {
                source: ZstdHeaderError(err),
            });
        }
    };

    let raw = array.as_bytes();
    if let Err(err) = zstd::stream::copy_encode(&*raw, &mut encoded, level.level) {
        return Err(ZstdCodecError::ZstdEncodeFailed {
            source: ZstdCodingError(err),
        });
    }

    Ok(encoded)
}
/// Decompresses `encoded` into a newly allocated array whose dtype and shape are read
/// from the header at the start of `encoded`.
///
/// # Errors
///
/// - [`ZstdCodecError::HeaderDecodeFailed`] if the header cannot be deserialized
/// - any error reported while decompressing the remaining bytes into the array
pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZstdCodecError> {
    let (header, payload) = match postcard::take_from_bytes::<CompressionHeader>(encoded) {
        Ok(parts) => parts,
        Err(err) => {
            return Err(ZstdCodecError::HeaderDecodeFailed {
                source: ZstdHeaderError(err),
            });
        }
    };

    // The array is created zero-filled and its bytes are overwritten by the decoder.
    let (array, outcome) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |buffer| {
        decompress_into_bytes(payload, buffer)
    });

    outcome?;
    Ok(array)
}
/// Decompresses `encoded` into the existing `decoded` array.
///
/// The dtype and shape recorded in the header must match those of `decoded`.
///
/// # Errors
///
/// - [`ZstdCodecError::HeaderDecodeFailed`] if the header cannot be deserialized
/// - [`ZstdCodecError::MismatchedDecodeIntoArray`] if the dtype or the shape of
///   `decoded` differs from the header (the dtype is checked first)
/// - any error reported while decompressing the remaining bytes into the array
pub fn decompress_into(encoded: &[u8], mut decoded: AnyArrayViewMut) -> Result<(), ZstdCodecError> {
    let (header, payload) = match postcard::take_from_bytes::<CompressionHeader>(encoded) {
        Ok(parts) => parts,
        Err(err) => {
            return Err(ZstdCodecError::HeaderDecodeFailed {
                source: ZstdHeaderError(err),
            });
        }
    };

    let target_dtype = decoded.dtype();
    if header.dtype != target_dtype {
        let mismatch = AnyArrayAssignError::DTypeMismatch {
            src: header.dtype,
            dst: target_dtype,
        };
        return Err(ZstdCodecError::MismatchedDecodeIntoArray { source: mismatch });
    }

    if &*header.shape != decoded.shape() {
        let mismatch = AnyArrayAssignError::ShapeMismatch {
            src: header.shape.into_owned(),
            dst: decoded.shape().to_vec(),
        };
        return Err(ZstdCodecError::MismatchedDecodeIntoArray { source: mismatch });
    }

    decoded.with_bytes_mut(|buffer| decompress_into_bytes(payload, buffer))
}
/// Decodes the Zstandard stream in `encoded` into the byte buffer `decoded`.
///
/// Both slices are handed to the decoder by mutable reference, so they are advanced as
/// bytes are read and written. Whatever remains afterwards is input that was not read
/// or output that was not filled.
///
/// # Errors
///
/// - [`ZstdCodecError::ZstdDecodeFailed`] if the Zstandard decoder reports an error
/// - [`ZstdCodecError::DecodeExcessiveEncodedData`] if input bytes are left over
/// - [`ZstdCodecError::DecodeProducedLess`] if the output buffer was not filled
fn decompress_into_bytes(mut encoded: &[u8], mut decoded: &mut [u8]) -> Result<(), ZstdCodecError> {
    if let Err(err) = zstd::stream::copy_decode(&mut encoded, &mut decoded) {
        return Err(ZstdCodecError::ZstdDecodeFailed {
            source: ZstdCodingError(err),
        });
    }

    // Leftover input is reported before an under-filled output.
    match (encoded.is_empty(), decoded.is_empty()) {
        (false, _) => Err(ZstdCodecError::DecodeExcessiveEncodedData),
        (true, false) => Err(ZstdCodecError::DecodeProducedLess),
        (true, true) => Ok(()),
    }
}
/// Header that is serialized with postcard in front of the Zstandard stream. It
/// records what is needed to rebuild (or validate) the array when decoding.
///
/// NOTE(review): postcard encoding is presumably positional, so the field order is
/// likely part of the on-disk format — confirm before reordering fields.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
/// Element type of the encoded array.
dtype: AnyArrayDType,
/// Shape of the encoded array; `compress` borrows it from the input array.
#[serde(borrow)]
shape: Cow<'a, [usize]>,
/// Format version marker of the codec that wrote the data.
version: ZstdCodecVersion,
}