1use schemars::JsonSchema;
21use zstd_sys as _;
23
24use std::{borrow::Cow, io};
25
26use ndarray::Array1;
27use numcodecs::{
28 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
29 Codec, StaticCodec, StaticCodecConfig,
30};
31use serde::{Deserialize, Deserializer, Serialize, Serializer};
32use thiserror::Error;
33
34#[derive(Clone, Serialize, Deserialize, JsonSchema)]
35#[serde(deny_unknown_fields)]
36pub struct ZstdCodec {
38 pub level: ZstdLevel,
42}
43
44impl Codec for ZstdCodec {
45 type Error = ZstdCodecError;
46
47 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
48 compress(data.view(), self.level)
49 .map(|bytes| AnyArray::U8(Array1::from_vec(bytes).into_dyn()))
50 }
51
52 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
53 let AnyCowArray::U8(encoded) = encoded else {
54 return Err(ZstdCodecError::EncodedDataNotBytes {
55 dtype: encoded.dtype(),
56 });
57 };
58
59 if !matches!(encoded.shape(), [_]) {
60 return Err(ZstdCodecError::EncodedDataNotOneDimensional {
61 shape: encoded.shape().to_vec(),
62 });
63 }
64
65 decompress(&AnyCowArray::U8(encoded).as_bytes())
66 }
67
68 fn decode_into(
69 &self,
70 encoded: AnyArrayView,
71 decoded: AnyArrayViewMut,
72 ) -> Result<(), Self::Error> {
73 let AnyArrayView::U8(encoded) = encoded else {
74 return Err(ZstdCodecError::EncodedDataNotBytes {
75 dtype: encoded.dtype(),
76 });
77 };
78
79 if !matches!(encoded.shape(), [_]) {
80 return Err(ZstdCodecError::EncodedDataNotOneDimensional {
81 shape: encoded.shape().to_vec(),
82 });
83 }
84
85 decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
86 }
87}
88
89impl StaticCodec for ZstdCodec {
90 const CODEC_ID: &'static str = "zstd";
91
92 type Config<'de> = Self;
93
94 fn from_config(config: Self::Config<'_>) -> Self {
95 config
96 }
97
98 fn get_config(&self) -> StaticCodecConfig<Self> {
99 StaticCodecConfig::from(self)
100 }
101}
102
103#[derive(Clone, Copy, JsonSchema)]
104#[schemars(transparent)]
105pub struct ZstdLevel {
109 level: zstd::zstd_safe::CompressionLevel,
110}
111
112impl Serialize for ZstdLevel {
113 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
114 self.level.serialize(serializer)
115 }
116}
117
118impl<'de> Deserialize<'de> for ZstdLevel {
119 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
120 let level = Deserialize::deserialize(deserializer)?;
121
122 let level_range = zstd::compression_level_range();
123
124 if !level_range.contains(&level) {
125 return Err(serde::de::Error::custom(format!(
126 "level {level} is not in {}..={}",
127 level_range.start(),
128 level_range.end()
129 )));
130 }
131
132 Ok(Self { level })
133 }
134}
135
136#[derive(Debug, Error)]
137pub enum ZstdCodecError {
139 #[error("Zstd failed to encode the header")]
141 HeaderEncodeFailed {
142 source: ZstdHeaderError,
144 },
145 #[error("Zstd failed to decode the encoded data")]
147 ZstdEncodeFailed {
148 source: ZstdCodingError,
150 },
151 #[error(
154 "Zstd can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
155 )]
156 EncodedDataNotBytes {
157 dtype: AnyArrayDType,
159 },
160 #[error("Zstd can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
163 EncodedDataNotOneDimensional {
164 shape: Vec<usize>,
166 },
167 #[error("Zstd failed to decode the header")]
169 HeaderDecodeFailed {
170 source: ZstdHeaderError,
172 },
173 #[error("Zstd decode consumed less encoded data, which contains trailing junk")]
176 DecodeExcessiveEncodedData,
177 #[error("Zstd produced less decoded data than expected")]
179 DecodeProducedLess,
180 #[error("Zstd failed to decode the encoded data")]
182 ZstdDecodeFailed {
183 source: ZstdCodingError,
185 },
186 #[error("Zstd cannot decode into the provided array")]
188 MismatchedDecodeIntoArray {
189 #[from]
191 source: AnyArrayAssignError,
192 },
193}
194
195#[derive(Debug, Error)]
196#[error(transparent)]
197pub struct ZstdHeaderError(postcard::Error);
199
200#[derive(Debug, Error)]
201#[error(transparent)]
202pub struct ZstdCodingError(io::Error);
204
205#[allow(clippy::needless_pass_by_value)]
206pub fn compress(array: AnyArrayView, level: ZstdLevel) -> Result<Vec<u8>, ZstdCodecError> {
219 let mut encoded = postcard::to_extend(
220 &CompressionHeader {
221 dtype: array.dtype(),
222 shape: Cow::Borrowed(array.shape()),
223 },
224 Vec::new(),
225 )
226 .map_err(|err| ZstdCodecError::HeaderEncodeFailed {
227 source: ZstdHeaderError(err),
228 })?;
229
230 zstd::stream::copy_encode(&*array.as_bytes(), &mut encoded, level.level).map_err(|err| {
231 ZstdCodecError::ZstdEncodeFailed {
232 source: ZstdCodingError(err),
233 }
234 })?;
235
236 Ok(encoded)
237}
238
239pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZstdCodecError> {
251 let (header, encoded) =
252 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
253 ZstdCodecError::HeaderDecodeFailed {
254 source: ZstdHeaderError(err),
255 }
256 })?;
257
258 let (decoded, result) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |decoded| {
259 decompress_into_bytes(encoded, decoded)
260 });
261
262 result.map(|()| decoded)
263}
264
265pub fn decompress_into(encoded: &[u8], mut decoded: AnyArrayViewMut) -> Result<(), ZstdCodecError> {
280 let (header, encoded) =
281 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
282 ZstdCodecError::HeaderDecodeFailed {
283 source: ZstdHeaderError(err),
284 }
285 })?;
286
287 if header.dtype != decoded.dtype() {
288 return Err(ZstdCodecError::MismatchedDecodeIntoArray {
289 source: AnyArrayAssignError::DTypeMismatch {
290 src: header.dtype,
291 dst: decoded.dtype(),
292 },
293 });
294 }
295
296 if header.shape != decoded.shape() {
297 return Err(ZstdCodecError::MismatchedDecodeIntoArray {
298 source: AnyArrayAssignError::ShapeMismatch {
299 src: header.shape.into_owned(),
300 dst: decoded.shape().to_vec(),
301 },
302 });
303 }
304
305 decoded.with_bytes_mut(|decoded| decompress_into_bytes(encoded, decoded))
306}
307
308fn decompress_into_bytes(mut encoded: &[u8], mut decoded: &mut [u8]) -> Result<(), ZstdCodecError> {
309 #[allow(clippy::needless_borrows_for_generic_args)]
310 zstd::stream::copy_decode(&mut encoded, &mut decoded).map_err(|err| {
312 ZstdCodecError::ZstdDecodeFailed {
313 source: ZstdCodingError(err),
314 }
315 })?;
316
317 if !encoded.is_empty() {
318 return Err(ZstdCodecError::DecodeExcessiveEncodedData);
319 }
320
321 if !decoded.is_empty() {
322 return Err(ZstdCodecError::DecodeProducedLess);
323 }
324
325 Ok(())
326}
327
328#[derive(Serialize, Deserialize)]
329struct CompressionHeader<'a> {
330 dtype: AnyArrayDType,
331 #[serde(borrow)]
332 shape: Cow<'a, [usize]>,
333}