numcodecs_zstd/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.76.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-zstd
10//! [crates.io]: https://crates.io/crates/numcodecs-zstd
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-zstd
13//! [docs.rs]: https://docs.rs/numcodecs-zstd/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_zstd
17//!
18//! Zstandard codec implementation for the [`numcodecs`] API.
19
20use schemars::JsonSchema;
21// Only used to explicitly enable the `no_wasm_shim` feature in zstd/zstd-sys
22use zstd_sys as _;
23
24use std::{borrow::Cow, io};
25
26use ndarray::Array1;
27use numcodecs::{
28    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
29    Codec, StaticCodec, StaticCodecConfig,
30};
31use serde::{Deserialize, Deserializer, Serialize, Serializer};
32use thiserror::Error;
33
34#[derive(Clone, Serialize, Deserialize, JsonSchema)]
35#[serde(deny_unknown_fields)]
36/// Codec providing compression using Zstandard
37pub struct ZstdCodec {
38    /// Zstandard compression level.
39    ///
40    /// The level ranges from small (fastest) to large (best compression).
41    pub level: ZstdLevel,
42}
43
44impl Codec for ZstdCodec {
45    type Error = ZstdCodecError;
46
47    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
48        compress(data.view(), self.level)
49            .map(|bytes| AnyArray::U8(Array1::from_vec(bytes).into_dyn()))
50    }
51
52    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
53        let AnyCowArray::U8(encoded) = encoded else {
54            return Err(ZstdCodecError::EncodedDataNotBytes {
55                dtype: encoded.dtype(),
56            });
57        };
58
59        if !matches!(encoded.shape(), [_]) {
60            return Err(ZstdCodecError::EncodedDataNotOneDimensional {
61                shape: encoded.shape().to_vec(),
62            });
63        }
64
65        decompress(&AnyCowArray::U8(encoded).as_bytes())
66    }
67
68    fn decode_into(
69        &self,
70        encoded: AnyArrayView,
71        decoded: AnyArrayViewMut,
72    ) -> Result<(), Self::Error> {
73        let AnyArrayView::U8(encoded) = encoded else {
74            return Err(ZstdCodecError::EncodedDataNotBytes {
75                dtype: encoded.dtype(),
76            });
77        };
78
79        if !matches!(encoded.shape(), [_]) {
80            return Err(ZstdCodecError::EncodedDataNotOneDimensional {
81                shape: encoded.shape().to_vec(),
82            });
83        }
84
85        decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
86    }
87}
88
89impl StaticCodec for ZstdCodec {
90    const CODEC_ID: &'static str = "zstd";
91
92    type Config<'de> = Self;
93
94    fn from_config(config: Self::Config<'_>) -> Self {
95        config
96    }
97
98    fn get_config(&self) -> StaticCodecConfig<Self> {
99        StaticCodecConfig::from(self)
100    }
101}
102
103#[derive(Clone, Copy, JsonSchema)]
104#[schemars(transparent)]
105/// Zstandard compression level.
106///
107/// The level ranges from small (fastest) to large (best compression).
108pub struct ZstdLevel {
109    level: zstd::zstd_safe::CompressionLevel,
110}
111
112impl Serialize for ZstdLevel {
113    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
114        self.level.serialize(serializer)
115    }
116}
117
118impl<'de> Deserialize<'de> for ZstdLevel {
119    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
120        let level = Deserialize::deserialize(deserializer)?;
121
122        let level_range = zstd::compression_level_range();
123
124        if !level_range.contains(&level) {
125            return Err(serde::de::Error::custom(format!(
126                "level {level} is not in {}..={}",
127                level_range.start(),
128                level_range.end()
129            )));
130        }
131
132        Ok(Self { level })
133    }
134}
135
136#[derive(Debug, Error)]
137/// Errors that may occur when applying the [`ZstdCodec`].
138pub enum ZstdCodecError {
139    /// [`ZstdCodec`] failed to encode the header
140    #[error("Zstd failed to encode the header")]
141    HeaderEncodeFailed {
142        /// Opaque source error
143        source: ZstdHeaderError,
144    },
145    /// [`ZstdCodec`] failed to encode the encoded data
146    #[error("Zstd failed to decode the encoded data")]
147    ZstdEncodeFailed {
148        /// Opaque source error
149        source: ZstdCodingError,
150    },
151    /// [`ZstdCodec`] can only decode one-dimensional byte arrays but received
152    /// an array of a different dtype
153    #[error(
154        "Zstd can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
155    )]
156    EncodedDataNotBytes {
157        /// The unexpected dtype of the encoded array
158        dtype: AnyArrayDType,
159    },
160    /// [`ZstdCodec`] can only decode one-dimensional byte arrays but received
161    /// an array of a different shape
162    #[error("Zstd can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
163    EncodedDataNotOneDimensional {
164        /// The unexpected shape of the encoded array
165        shape: Vec<usize>,
166    },
167    /// [`ZstdCodec`] failed to encode the header
168    #[error("Zstd failed to decode the header")]
169    HeaderDecodeFailed {
170        /// Opaque source error
171        source: ZstdHeaderError,
172    },
173    /// [`ZstdCodec`] decode consumed less encoded data, which contains trailing
174    /// junk
175    #[error("Zstd decode consumed less encoded data, which contains trailing junk")]
176    DecodeExcessiveEncodedData,
177    /// [`ZstdCodec`] produced less decoded data than expected
178    #[error("Zstd produced less decoded data than expected")]
179    DecodeProducedLess,
180    /// [`ZstdCodec`] failed to decode the encoded data
181    #[error("Zstd failed to decode the encoded data")]
182    ZstdDecodeFailed {
183        /// Opaque source error
184        source: ZstdCodingError,
185    },
186    /// [`ZstdCodec`] cannot decode into the provided array
187    #[error("Zstd cannot decode into the provided array")]
188    MismatchedDecodeIntoArray {
189        /// The source of the error
190        #[from]
191        source: AnyArrayAssignError,
192    },
193}
194
195#[derive(Debug, Error)]
196#[error(transparent)]
197/// Opaque error for when encoding or decoding the header fails
198pub struct ZstdHeaderError(postcard::Error);
199
200#[derive(Debug, Error)]
201#[error(transparent)]
202/// Opaque error for when encoding or decoding with Zstandard fails
203pub struct ZstdCodingError(io::Error);
204
205#[allow(clippy::needless_pass_by_value)]
206/// Compress the `array` using Zstandard with the provided `level`.
207///
208/// # Errors
209///
210/// Errors with
211/// - [`ZstdCodecError::HeaderEncodeFailed`] if encoding the header to the
212///   output bytevec failed
213/// - [`ZstdCodecError::ZstdEncodeFailed`] if an opaque encoding error occurred
214///
215/// # Panics
216///
217/// Panics if the infallible encoding with Zstd fails.
218pub fn compress(array: AnyArrayView, level: ZstdLevel) -> Result<Vec<u8>, ZstdCodecError> {
219    let mut encoded = postcard::to_extend(
220        &CompressionHeader {
221            dtype: array.dtype(),
222            shape: Cow::Borrowed(array.shape()),
223        },
224        Vec::new(),
225    )
226    .map_err(|err| ZstdCodecError::HeaderEncodeFailed {
227        source: ZstdHeaderError(err),
228    })?;
229
230    zstd::stream::copy_encode(&*array.as_bytes(), &mut encoded, level.level).map_err(|err| {
231        ZstdCodecError::ZstdEncodeFailed {
232            source: ZstdCodingError(err),
233        }
234    })?;
235
236    Ok(encoded)
237}
238
239/// Decompress the `encoded` data into an array using Zstandard.
240///
241/// # Errors
242///
243/// Errors with
244/// - [`ZstdCodecError::HeaderDecodeFailed`] if decoding the header failed
245/// - [`ZstdCodecError::DecodeExcessiveEncodedData`] if the encoded data
246///   contains excessive trailing data junk
247/// - [`ZstdCodecError::DecodeProducedLess`] if decoding produced less data than
248///   expected
249/// - [`ZstdCodecError::ZstdDecodeFailed`] if an opaque decoding error occurred
250pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZstdCodecError> {
251    let (header, encoded) =
252        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
253            ZstdCodecError::HeaderDecodeFailed {
254                source: ZstdHeaderError(err),
255            }
256        })?;
257
258    let (decoded, result) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |decoded| {
259        decompress_into_bytes(encoded, decoded)
260    });
261
262    result.map(|()| decoded)
263}
264
265/// Decompress the `encoded` data into a `decoded` array using Zstandard.
266///
267/// # Errors
268///
269/// Errors with
270/// - [`ZstdCodecError::HeaderDecodeFailed`] if decoding the header failed
271/// - [`ZstdCodecError::MismatchedDecodeIntoArray`] if the `decoded` array is of
272///   the wrong dtype or shape
273/// - [`ZstdCodecError::HeaderDecodeFailed`] if decoding the header failed
274/// - [`ZstdCodecError::DecodeExcessiveEncodedData`] if the encoded data
275///   contains excessive trailing data junk
276/// - [`ZstdCodecError::DecodeProducedLess`] if decoding produced less data than
277///   expected
278/// - [`ZstdCodecError::ZstdDecodeFailed`] if an opaque decoding error occurred
279pub fn decompress_into(encoded: &[u8], mut decoded: AnyArrayViewMut) -> Result<(), ZstdCodecError> {
280    let (header, encoded) =
281        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
282            ZstdCodecError::HeaderDecodeFailed {
283                source: ZstdHeaderError(err),
284            }
285        })?;
286
287    if header.dtype != decoded.dtype() {
288        return Err(ZstdCodecError::MismatchedDecodeIntoArray {
289            source: AnyArrayAssignError::DTypeMismatch {
290                src: header.dtype,
291                dst: decoded.dtype(),
292            },
293        });
294    }
295
296    if header.shape != decoded.shape() {
297        return Err(ZstdCodecError::MismatchedDecodeIntoArray {
298            source: AnyArrayAssignError::ShapeMismatch {
299                src: header.shape.into_owned(),
300                dst: decoded.shape().to_vec(),
301            },
302        });
303    }
304
305    decoded.with_bytes_mut(|decoded| decompress_into_bytes(encoded, decoded))
306}
307
308fn decompress_into_bytes(mut encoded: &[u8], mut decoded: &mut [u8]) -> Result<(), ZstdCodecError> {
309    #[allow(clippy::needless_borrows_for_generic_args)]
310    // we want to check encoded and decoded for full consumption after the decoding
311    zstd::stream::copy_decode(&mut encoded, &mut decoded).map_err(|err| {
312        ZstdCodecError::ZstdDecodeFailed {
313            source: ZstdCodingError(err),
314        }
315    })?;
316
317    if !encoded.is_empty() {
318        return Err(ZstdCodecError::DecodeExcessiveEncodedData);
319    }
320
321    if !decoded.is_empty() {
322        return Err(ZstdCodecError::DecodeProducedLess);
323    }
324
325    Ok(())
326}
327
328#[derive(Serialize, Deserialize)]
329struct CompressionHeader<'a> {
330    dtype: AnyArrayDType,
331    #[serde(borrow)]
332    shape: Cow<'a, [usize]>,
333}