numcodecs_sz3/
lib.rs

//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
//!
//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
//!
//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
//! [repo]: https://github.com/juntyr/numcodecs-rs
//!
//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-sz3
//! [crates.io]: https://crates.io/crates/numcodecs-sz3
//!
//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-sz3
//! [docs.rs]: https://docs.rs/numcodecs-sz3/
//!
//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_sz3
//!
//! SZ3 codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // embedded-io
21
22use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
25use numcodecs::{
26    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27    Codec, StaticCodec, StaticCodecConfig,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use thiserror::Error;
32
33// Only included to explicitly enable the `no_wasm_shim` feature for
34// sz3-sys/Sz3-sys
35use ::zstd_sys as _;
36
37#[cfg(test)]
38use ::serde_json as _;
39
40#[derive(Clone, Serialize, Deserialize, JsonSchema)]
41// serde cannot deny unknown fields because of the flatten
42#[schemars(deny_unknown_fields)]
43/// Codec providing compression using SZ3
44pub struct Sz3Codec {
45    /// Predictor
46    #[serde(default = "default_predictor")]
47    pub predictor: Option<Sz3Predictor>,
48    /// SZ3 error bound
49    #[serde(flatten)]
50    pub error_bound: Sz3ErrorBound,
51}
52
53/// SZ3 error bound
54#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
55#[serde(tag = "eb_mode")]
56#[serde(deny_unknown_fields)]
57pub enum Sz3ErrorBound {
58    /// Errors are bounded by *both* the absolute and relative error, i.e. by
59    /// whichever bound is stricter
60    #[serde(rename = "abs-and-rel")]
61    AbsoluteAndRelative {
62        /// Absolute error bound
63        #[serde(rename = "eb_abs")]
64        abs: f64,
65        /// Relative error bound
66        #[serde(rename = "eb_rel")]
67        rel: f64,
68    },
69    /// Errors are bounded by *either* the absolute or relative error, i.e. by
70    /// whichever bound is weaker
71    #[serde(rename = "abs-or-rel")]
72    AbsoluteOrRelative {
73        /// Absolute error bound
74        #[serde(rename = "eb_abs")]
75        abs: f64,
76        /// Relative error bound
77        #[serde(rename = "eb_rel")]
78        rel: f64,
79    },
80    /// Absolute error bound
81    #[serde(rename = "abs")]
82    Absolute {
83        /// Absolute error bound
84        #[serde(rename = "eb_abs")]
85        abs: f64,
86    },
87    /// Relative error bound
88    #[serde(rename = "rel")]
89    Relative {
90        /// Relative error bound
91        #[serde(rename = "eb_rel")]
92        rel: f64,
93    },
94    /// Peak signal to noise ratio error bound
95    #[serde(rename = "psnr")]
96    PS2NR {
97        /// Peak signal to noise ratio error bound
98        #[serde(rename = "eb_psnr")]
99        psnr: f64,
100    },
101    /// Peak L2 norm error bound
102    #[serde(rename = "l2")]
103    L2Norm {
104        /// Peak L2 norm error bound
105        #[serde(rename = "eb_l2")]
106        l2: f64,
107    },
108}
109
110/// SZ3 predictor
111#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
112#[serde(deny_unknown_fields)]
113pub enum Sz3Predictor {
114    /// Linear interpolation
115    #[serde(rename = "linear-interpolation")]
116    LinearInterpolation,
117    /// Cubic interpolation
118    #[serde(rename = "cubic-interpolation")]
119    CubicInterpolation,
120    /// Linear interpolation + Lorenzo predictor
121    #[serde(rename = "linear-interpolation-lorenzo")]
122    LinearInterpolationLorenzo,
123    /// Cubic interpolation + Lorenzo predictor
124    #[serde(rename = "cubic-interpolation-lorenzo")]
125    CubicInterpolationLorenzo,
126    /// Lorenzo predictor + regression
127    #[serde(rename = "lorenzo-regression")]
128    LorenzoRegression,
129}
130
131#[expect(clippy::unnecessary_wraps)]
132const fn default_predictor() -> Option<Sz3Predictor> {
133    Some(Sz3Predictor::CubicInterpolationLorenzo)
134}
135
136impl Codec for Sz3Codec {
137    type Error = Sz3CodecError;
138
139    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
140        match data {
141            AnyCowArray::I32(data) => Ok(AnyArray::U8(
142                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
143                    .into_dyn(),
144            )),
145            AnyCowArray::I64(data) => Ok(AnyArray::U8(
146                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
147                    .into_dyn(),
148            )),
149            AnyCowArray::F32(data) => Ok(AnyArray::U8(
150                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
151                    .into_dyn(),
152            )),
153            AnyCowArray::F64(data) => Ok(AnyArray::U8(
154                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
155                    .into_dyn(),
156            )),
157            encoded => Err(Sz3CodecError::UnsupportedDtype(encoded.dtype())),
158        }
159    }
160
161    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
162        let AnyCowArray::U8(encoded) = encoded else {
163            return Err(Sz3CodecError::EncodedDataNotBytes {
164                dtype: encoded.dtype(),
165            });
166        };
167
168        if !matches!(encoded.shape(), [_]) {
169            return Err(Sz3CodecError::EncodedDataNotOneDimensional {
170                shape: encoded.shape().to_vec(),
171            });
172        }
173
174        decompress(&AnyCowArray::U8(encoded).as_bytes())
175    }
176
177    fn decode_into(
178        &self,
179        encoded: AnyArrayView,
180        mut decoded: AnyArrayViewMut,
181    ) -> Result<(), Self::Error> {
182        let decoded_in = self.decode(encoded.cow())?;
183
184        Ok(decoded.assign(&decoded_in)?)
185    }
186}
187
188impl StaticCodec for Sz3Codec {
189    const CODEC_ID: &'static str = "sz3";
190
191    type Config<'de> = Self;
192
193    fn from_config(config: Self::Config<'_>) -> Self {
194        config
195    }
196
197    fn get_config(&self) -> StaticCodecConfig<Self> {
198        StaticCodecConfig::from(self)
199    }
200}
201
202#[derive(Debug, Error)]
203/// Errors that may occur when applying the [`Sz3Codec`].
204pub enum Sz3CodecError {
205    /// [`Sz3Codec`] does not support the dtype
206    #[error("Sz3 does not support the dtype {0}")]
207    UnsupportedDtype(AnyArrayDType),
208    /// [`Sz3Codec`] failed to encode the header
209    #[error("Sz3 failed to encode the header")]
210    HeaderEncodeFailed {
211        /// Opaque source error
212        source: Sz3HeaderError,
213    },
214    /// [`Sz3Codec`] cannot encode an array of `shape`
215    #[error("Sz3 cannot encode an array of shape {shape:?}")]
216    InvalidEncodeShape {
217        /// Opaque source error
218        source: Sz3CodingError,
219        /// The invalid shape of the encoded array
220        shape: Vec<usize>,
221    },
222    /// [`Sz3Codec`] failed to encode the data
223    #[error("Sz3 failed to encode the data")]
224    Sz3EncodeFailed {
225        /// Opaque source error
226        source: Sz3CodingError,
227    },
228    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
229    /// an array of a different dtype
230    #[error(
231        "Sz3 can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
232    )]
233    EncodedDataNotBytes {
234        /// The unexpected dtype of the encoded array
235        dtype: AnyArrayDType,
236    },
237    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
238    /// an array of a different shape
239    #[error("Sz3 can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
240    EncodedDataNotOneDimensional {
241        /// The unexpected shape of the encoded array
242        shape: Vec<usize>,
243    },
244    /// [`Sz3Codec`] failed to decode the header
245    #[error("Sz3 failed to decode the header")]
246    HeaderDecodeFailed {
247        /// Opaque source error
248        source: Sz3HeaderError,
249    },
250    /// [`Sz3Codec`] decoded an invalid array shape header which does not fit
251    /// the decoded data
252    #[error("Sz3 decoded an invalid array shape header which does not fit the decoded data")]
253    DecodeInvalidShapeHeader {
254        /// Source error
255        #[from]
256        source: ShapeError,
257    },
258    /// [`Sz3Codec`] cannot decode into the provided array
259    #[error("Sz3 cannot decode into the provided array")]
260    MismatchedDecodeIntoArray {
261        /// The source of the error
262        #[from]
263        source: AnyArrayAssignError,
264    },
265}
266
267#[derive(Debug, Error)]
268#[error(transparent)]
269/// Opaque error for when encoding or decoding the header fails
270pub struct Sz3HeaderError(postcard::Error);
271
272#[derive(Debug, Error)]
273#[error(transparent)]
274/// Opaque error for when encoding or decoding with SZ3 fails
275pub struct Sz3CodingError(sz3::SZ3Error);
276
277#[expect(clippy::needless_pass_by_value)]
278/// Compresses the input `data` array using SZ3, which consists of an optional
279/// `predictor`, an `error_bound`, an optional `encoder`, and an optional
280/// `lossless` compressor.
281///
282/// # Errors
283///
284/// Errors with
285/// - [`Sz3CodecError::HeaderEncodeFailed`] if encoding the header failed
286/// - [`Sz3CodecError::InvalidEncodeShape`] if the array shape is invalid
287/// - [`Sz3CodecError::Sz3EncodeFailed`] if encoding failed with an opaque error
288pub fn compress<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
289    data: ArrayBase<S, D>,
290    predictor: Option<&Sz3Predictor>,
291    error_bound: &Sz3ErrorBound,
292) -> Result<Vec<u8>, Sz3CodecError> {
293    let mut encoded_bytes = postcard::to_extend(
294        &CompressionHeader {
295            dtype: <T as Sz3Element>::DTYPE,
296            shape: Cow::Borrowed(data.shape()),
297        },
298        Vec::new(),
299    )
300    .map_err(|err| Sz3CodecError::HeaderEncodeFailed {
301        source: Sz3HeaderError(err),
302    })?;
303
304    // sz3::DimensionedDataBuilder cannot handle zero-length dimensions
305    if data.is_empty() {
306        return Ok(encoded_bytes);
307    }
308
309    #[expect(clippy::option_if_let_else)]
310    let data_cow = if let Some(data) = data.as_slice() {
311        Cow::Borrowed(data)
312    } else {
313        Cow::Owned(data.iter().copied().collect())
314    };
315    let mut builder = sz3::DimensionedData::build(&data_cow);
316
317    for length in data.shape() {
318        // Sz3 ignores dimensions of length 1 and panics on length zero
319        // Since they carry no information for Sz3 and we already encode them
320        //  in our custom header, we just skip them here
321        if *length > 1 {
322            builder = builder
323                .dim(*length)
324                .map_err(|err| Sz3CodecError::InvalidEncodeShape {
325                    source: Sz3CodingError(err),
326                    shape: data.shape().to_vec(),
327                })?;
328        }
329    }
330
331    if data.len() == 1 {
332        // If there is only one element, all dimensions will have been skipped,
333        //  so we explicitly encode one dimension of size 1 here
334        builder = builder
335            .dim(1)
336            .map_err(|err| Sz3CodecError::InvalidEncodeShape {
337                source: Sz3CodingError(err),
338                shape: data.shape().to_vec(),
339            })?;
340    }
341
342    let data = builder
343        .finish()
344        .map_err(|err| Sz3CodecError::InvalidEncodeShape {
345            source: Sz3CodingError(err),
346            shape: data.shape().to_vec(),
347        })?;
348
349    // configure the error bound
350    let error_bound = match error_bound {
351        Sz3ErrorBound::AbsoluteAndRelative { abs, rel } => sz3::ErrorBound::AbsoluteAndRelative {
352            absolute_bound: *abs,
353            relative_bound: *rel,
354        },
355        Sz3ErrorBound::AbsoluteOrRelative { abs, rel } => sz3::ErrorBound::AbsoluteOrRelative {
356            absolute_bound: *abs,
357            relative_bound: *rel,
358        },
359        Sz3ErrorBound::Absolute { abs } => sz3::ErrorBound::Absolute(*abs),
360        Sz3ErrorBound::Relative { rel } => sz3::ErrorBound::Relative(*rel),
361        Sz3ErrorBound::PS2NR { psnr } => sz3::ErrorBound::PSNR(*psnr),
362        Sz3ErrorBound::L2Norm { l2 } => sz3::ErrorBound::L2Norm(*l2),
363    };
364    let mut config = sz3::Config::new(error_bound);
365
366    // configure the interpolation mode, if necessary
367    let interpolation = match predictor {
368        Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::LinearInterpolationLorenzo) => {
369            Some(sz3::InterpolationAlgorithm::Linear)
370        }
371        Some(Sz3Predictor::CubicInterpolation | Sz3Predictor::CubicInterpolationLorenzo) => {
372            Some(sz3::InterpolationAlgorithm::Cubic)
373        }
374        Some(Sz3Predictor::LorenzoRegression) | None => None,
375    };
376    if let Some(interpolation) = interpolation {
377        config = config.interpolation_algorithm(interpolation);
378    }
379
380    // configure the predictor (compression algorithm)
381    let predictor = match predictor {
382        Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::CubicInterpolation) => {
383            sz3::CompressionAlgorithm::Interpolation
384        }
385        Some(
386            Sz3Predictor::LinearInterpolationLorenzo | Sz3Predictor::CubicInterpolationLorenzo,
387        ) => sz3::CompressionAlgorithm::InterpolationLorenzo,
388        Some(Sz3Predictor::LorenzoRegression) => sz3::CompressionAlgorithm::lorenzo_regression(),
389        None => sz3::CompressionAlgorithm::NoPrediction,
390    };
391    config = config.compression_algorithm(predictor);
392
393    // TODO: avoid extra allocation here
394    let compressed = sz3::compress_with_config(&data, &config).map_err(|err| {
395        Sz3CodecError::Sz3EncodeFailed {
396            source: Sz3CodingError(err),
397        }
398    })?;
399    encoded_bytes.extend_from_slice(&compressed);
400
401    Ok(encoded_bytes)
402}
403
404/// Decompresses the `encoded` data into an array.
405///
406/// # Errors
407///
408/// Errors with
409/// - [`Sz3CodecError::HeaderDecodeFailed`] if decoding the header failed
410pub fn decompress(encoded: &[u8]) -> Result<AnyArray, Sz3CodecError> {
411    let (header, data) =
412        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
413            Sz3CodecError::HeaderDecodeFailed {
414                source: Sz3HeaderError(err),
415            }
416        })?;
417
418    let decoded = if header.shape.iter().copied().product::<usize>() == 0 {
419        match header.dtype {
420            Sz3DType::I32 => {
421                AnyArray::I32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
422            }
423            Sz3DType::I64 => {
424                AnyArray::I64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
425            }
426            Sz3DType::F32 => {
427                AnyArray::F32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
428            }
429            Sz3DType::F64 => {
430                AnyArray::F64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
431            }
432        }
433    } else {
434        // TODO: avoid extra allocation here
435        match header.dtype {
436            Sz3DType::I32 => AnyArray::I32(Array::from_shape_vec(
437                &*header.shape,
438                Vec::from(sz3::decompress(data).1.data()),
439            )?),
440            Sz3DType::I64 => AnyArray::I64(Array::from_shape_vec(
441                &*header.shape,
442                Vec::from(sz3::decompress(data).1.data()),
443            )?),
444            Sz3DType::F32 => AnyArray::F32(Array::from_shape_vec(
445                &*header.shape,
446                Vec::from(sz3::decompress(data).1.data()),
447            )?),
448            Sz3DType::F64 => AnyArray::F64(Array::from_shape_vec(
449                &*header.shape,
450                Vec::from(sz3::decompress(data).1.data()),
451            )?),
452        }
453    };
454
455    Ok(decoded)
456}
457
458/// Array element types which can be compressed with SZ3.
459pub trait Sz3Element: Copy + sz3::SZ3Compressible {
460    /// The dtype representation of the type
461    const DTYPE: Sz3DType;
462}
463
464impl Sz3Element for i32 {
465    const DTYPE: Sz3DType = Sz3DType::I32;
466}
467
468impl Sz3Element for i64 {
469    const DTYPE: Sz3DType = Sz3DType::I64;
470}
471
472impl Sz3Element for f32 {
473    const DTYPE: Sz3DType = Sz3DType::F32;
474}
475
476impl Sz3Element for f64 {
477    const DTYPE: Sz3DType = Sz3DType::F64;
478}
479
480#[derive(Serialize, Deserialize)]
481struct CompressionHeader<'a> {
482    dtype: Sz3DType,
483    #[serde(borrow)]
484    shape: Cow<'a, [usize]>,
485}
486
487/// Dtypes that SZ3 can compress and decompress
488#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
489#[expect(missing_docs)]
490pub enum Sz3DType {
491    #[serde(rename = "i32", alias = "int32")]
492    I32,
493    #[serde(rename = "i64", alias = "int64")]
494    I64,
495    #[serde(rename = "f32", alias = "float32")]
496    F32,
497    #[serde(rename = "f64", alias = "float64")]
498    F64,
499}
500
501impl fmt::Display for Sz3DType {
502    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
503        fmt.write_str(match self {
504            Self::I32 => "i32",
505            Self::I64 => "i64",
506            Self::F32 => "f32",
507            Self::F64 => "f64",
508        })
509    }
510}
511
#[cfg(test)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    #[test]
    fn zero_length() -> Result<(), Sz3CodecError> {
        let empty = Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])?;
        let error_bound = Sz3ErrorBound::L2Norm { l2: 27.0 };

        let encoded = compress(empty, default_predictor().as_ref(), &error_bound)?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);

        Ok(())
    }

    #[test]
    fn one_dimension() -> Result<(), Sz3CodecError> {
        let shape = [2_usize, 1, 2, 1];
        let original = Array::from_shape_vec(shape.as_slice(), vec![1_i32, 2, 3, 4])?;
        let error_bound = Sz3ErrorBound::Absolute { abs: 0.1 };

        let encoded = compress(original.view(), default_predictor().as_ref(), &error_bound)?;
        let roundtripped = decompress(&encoded)?;

        assert_eq!(roundtripped, AnyArray::I32(original));

        Ok(())
    }

    #[test]
    fn small_state() -> Result<(), Sz3CodecError> {
        let cases: [&[f64]; 5] = [
            &[],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ];
        let error_bound = Sz3ErrorBound::Absolute { abs: 0.1 };

        for values in cases {
            let expected = AnyArray::F64(Array1::from_vec(values.to_vec()).into_dyn());

            let encoded = compress(
                ArrayView1::from(values),
                default_predictor().as_ref(),
                &error_bound,
            )?;

            assert_eq!(decompress(&encoded)?, expected);
        }

        Ok(())
    }
}