numcodecs_sz3/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-sz3
10//! [crates.io]: https://crates.io/crates/numcodecs-sz3
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-sz3
13//! [docs.rs]: https://docs.rs/numcodecs-sz3/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_sz3
17//!
18//! SZ3 codec implementation for the [`numcodecs`] API.
19
20use std::{borrow::Cow, fmt};
21
22use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
23use numcodecs::{
24    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
25    Codec, StaticCodec, StaticCodecConfig,
26};
27use schemars::JsonSchema;
28use serde::{Deserialize, Serialize};
29use thiserror::Error;
30
31// Only included to explicitly enable the `no_wasm_shim` feature for
32// sz3-sys/Sz3-sys
33use ::zstd_sys as _;
34
35#[cfg(test)]
36use ::serde_json as _;
37
38#[derive(Clone, Serialize, Deserialize, JsonSchema)]
39// serde cannot deny unknown fields because of the flatten
40#[schemars(deny_unknown_fields)]
41/// Codec providing compression using SZ3
42pub struct Sz3Codec {
43    /// Predictor
44    #[serde(default = "default_predictor")]
45    pub predictor: Option<Sz3Predictor>,
46    /// SZ3 error bound
47    #[serde(flatten)]
48    pub error_bound: Sz3ErrorBound,
49    /// Encoder
50    #[serde(default = "default_encoder")]
51    pub encoder: Option<Sz3Encoder>,
52    /// Lossless compressor
53    #[serde(default = "default_lossless_compressor")]
54    pub lossless: Option<Sz3LosslessCompressor>,
55}
56
57/// SZ3 error bound
58#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
59#[serde(tag = "eb_mode")]
60#[serde(deny_unknown_fields)]
61pub enum Sz3ErrorBound {
62    /// Errors are bounded by *both* the absolute and relative error, i.e. by
63    /// whichever bound is stricter
64    #[serde(rename = "abs-and-rel")]
65    AbsoluteAndRelative {
66        /// Absolute error bound
67        #[serde(rename = "eb_abs")]
68        abs: f64,
69        /// Relative error bound
70        #[serde(rename = "eb_rel")]
71        rel: f64,
72    },
73    /// Errors are bounded by *either* the absolute or relative error, i.e. by
74    /// whichever bound is weaker
75    #[serde(rename = "abs-or-rel")]
76    AbsoluteOrRelative {
77        /// Absolute error bound
78        #[serde(rename = "eb_abs")]
79        abs: f64,
80        /// Relative error bound
81        #[serde(rename = "eb_rel")]
82        rel: f64,
83    },
84    /// Absolute error bound
85    #[serde(rename = "abs")]
86    Absolute {
87        /// Absolute error bound
88        #[serde(rename = "eb_abs")]
89        abs: f64,
90    },
91    /// Relative error bound
92    #[serde(rename = "rel")]
93    Relative {
94        /// Relative error bound
95        #[serde(rename = "eb_rel")]
96        rel: f64,
97    },
98    /// Peak signal to noise ratio error bound
99    #[serde(rename = "psnr")]
100    PS2NR {
101        /// Peak signal to noise ratio error bound
102        #[serde(rename = "eb_psnr")]
103        psnr: f64,
104    },
105    /// Peak L2 norm error bound
106    #[serde(rename = "l2")]
107    L2Norm {
108        /// Peak L2 norm error bound
109        #[serde(rename = "eb_l2")]
110        l2: f64,
111    },
112}
113
114/// SZ3 predictor
115#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
116#[serde(deny_unknown_fields)]
117pub enum Sz3Predictor {
118    /// Linear interpolation
119    #[serde(rename = "linear-interpolation")]
120    LinearInterpolation,
121    /// Cubic interpolation
122    #[serde(rename = "cubic-interpolation")]
123    CubicInterpolation,
124    /// Linear interpolation + Lorenzo predictor
125    #[serde(rename = "linear-interpolation-lorenzo")]
126    LinearInterpolationLorenzo,
127    /// Cubic interpolation + Lorenzo predictor
128    #[serde(rename = "cubic-interpolation-lorenzo")]
129    CubicInterpolationLorenzo,
130    /// Lorenzo predictor + regression
131    #[serde(rename = "lorenzo-regression")]
132    LorenzoRegression,
133}
134
135#[allow(clippy::unnecessary_wraps)]
136const fn default_predictor() -> Option<Sz3Predictor> {
137    Some(Sz3Predictor::CubicInterpolationLorenzo)
138}
139
140/// SZ3 encoder
141#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
142#[serde(deny_unknown_fields)]
143pub enum Sz3Encoder {
144    /// Huffman coding
145    #[serde(rename = "huffman")]
146    Huffman,
147    /// Arithmetic coding
148    #[serde(rename = "arithmetic")]
149    Arithmetic,
150}
151
152#[allow(clippy::unnecessary_wraps)]
153const fn default_encoder() -> Option<Sz3Encoder> {
154    Some(Sz3Encoder::Huffman)
155}
156
157/// SZ3 lossless compressor
158#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
159#[serde(deny_unknown_fields)]
160pub enum Sz3LosslessCompressor {
161    /// Zstandard
162    #[serde(rename = "zstd")]
163    Zstd,
164}
165
166#[allow(clippy::unnecessary_wraps)]
167const fn default_lossless_compressor() -> Option<Sz3LosslessCompressor> {
168    Some(Sz3LosslessCompressor::Zstd)
169}
170
171impl Codec for Sz3Codec {
172    type Error = Sz3CodecError;
173
174    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
175        match data {
176            AnyCowArray::I32(data) => Ok(AnyArray::U8(
177                Array1::from(compress(
178                    data,
179                    self.predictor.as_ref(),
180                    &self.error_bound,
181                    self.encoder.as_ref(),
182                    self.lossless.as_ref(),
183                )?)
184                .into_dyn(),
185            )),
186            AnyCowArray::I64(data) => Ok(AnyArray::U8(
187                Array1::from(compress(
188                    data,
189                    self.predictor.as_ref(),
190                    &self.error_bound,
191                    self.encoder.as_ref(),
192                    self.lossless.as_ref(),
193                )?)
194                .into_dyn(),
195            )),
196            AnyCowArray::F32(data) => Ok(AnyArray::U8(
197                Array1::from(compress(
198                    data,
199                    self.predictor.as_ref(),
200                    &self.error_bound,
201                    self.encoder.as_ref(),
202                    self.lossless.as_ref(),
203                )?)
204                .into_dyn(),
205            )),
206            AnyCowArray::F64(data) => Ok(AnyArray::U8(
207                Array1::from(compress(
208                    data,
209                    self.predictor.as_ref(),
210                    &self.error_bound,
211                    self.encoder.as_ref(),
212                    self.lossless.as_ref(),
213                )?)
214                .into_dyn(),
215            )),
216            encoded => Err(Sz3CodecError::UnsupportedDtype(encoded.dtype())),
217        }
218    }
219
220    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
221        let AnyCowArray::U8(encoded) = encoded else {
222            return Err(Sz3CodecError::EncodedDataNotBytes {
223                dtype: encoded.dtype(),
224            });
225        };
226
227        if !matches!(encoded.shape(), [_]) {
228            return Err(Sz3CodecError::EncodedDataNotOneDimensional {
229                shape: encoded.shape().to_vec(),
230            });
231        }
232
233        decompress(&AnyCowArray::U8(encoded).as_bytes())
234    }
235
236    fn decode_into(
237        &self,
238        encoded: AnyArrayView,
239        mut decoded: AnyArrayViewMut,
240    ) -> Result<(), Self::Error> {
241        let decoded_in = self.decode(encoded.cow())?;
242
243        Ok(decoded.assign(&decoded_in)?)
244    }
245}
246
247impl StaticCodec for Sz3Codec {
248    const CODEC_ID: &'static str = "sz3";
249
250    type Config<'de> = Self;
251
252    fn from_config(config: Self::Config<'_>) -> Self {
253        config
254    }
255
256    fn get_config(&self) -> StaticCodecConfig<Self> {
257        StaticCodecConfig::from(self)
258    }
259}
260
261#[derive(Debug, Error)]
262/// Errors that may occur when applying the [`Sz3Codec`].
263pub enum Sz3CodecError {
264    /// [`Sz3Codec`] does not support the dtype
265    #[error("Sz3 does not support the dtype {0}")]
266    UnsupportedDtype(AnyArrayDType),
267    /// [`Sz3Codec`] failed to encode the header
268    #[error("Sz3 failed to encode the header")]
269    HeaderEncodeFailed {
270        /// Opaque source error
271        source: Sz3HeaderError,
272    },
273    /// [`Sz3Codec`] cannot encode an array of `shape`
274    #[error("Sz3 cannot encode an array of shape {shape:?}")]
275    InvalidEncodeShape {
276        /// Opaque source error
277        source: Sz3CodingError,
278        /// The invalid shape of the encoded array
279        shape: Vec<usize>,
280    },
281    /// [`Sz3Codec`] failed to encode the data
282    #[error("Sz3 failed to encode the data")]
283    Sz3EncodeFailed {
284        /// Opaque source error
285        source: Sz3CodingError,
286    },
287    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
288    /// an array of a different dtype
289    #[error(
290        "Sz3 can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
291    )]
292    EncodedDataNotBytes {
293        /// The unexpected dtype of the encoded array
294        dtype: AnyArrayDType,
295    },
296    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
297    /// an array of a different shape
298    #[error("Sz3 can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
299    EncodedDataNotOneDimensional {
300        /// The unexpected shape of the encoded array
301        shape: Vec<usize>,
302    },
303    /// [`Sz3Codec`] failed to decode the header
304    #[error("Sz3 failed to decode the header")]
305    HeaderDecodeFailed {
306        /// Opaque source error
307        source: Sz3HeaderError,
308    },
309    /// [`Sz3Codec`] decoded an invalid array shape header which does not fit
310    /// the decoded data
311    #[error("Sz3 decoded an invalid array shape header which does not fit the decoded data")]
312    DecodeInvalidShapeHeader {
313        /// Source error
314        #[from]
315        source: ShapeError,
316    },
317    /// [`Sz3Codec`] cannot decode into the provided array
318    #[error("Sz3 cannot decode into the provided array")]
319    MismatchedDecodeIntoArray {
320        /// The source of the error
321        #[from]
322        source: AnyArrayAssignError,
323    },
324}
325
326#[derive(Debug, Error)]
327#[error(transparent)]
328/// Opaque error for when encoding or decoding the header fails
329pub struct Sz3HeaderError(postcard::Error);
330
331#[derive(Debug, Error)]
332#[error(transparent)]
333/// Opaque error for when encoding or decoding with SZ3 fails
334pub struct Sz3CodingError(sz3::SZ3Error);
335
336#[allow(clippy::needless_pass_by_value)]
337/// Compresses the input `data` array using SZ3, which consists of an optional
338/// `predictor`, an `error_bound`, an optional `encoder`, and an optional
339/// `lossless` compressor.
340///
341/// # Errors
342///
343/// Errors with
344/// - [`Sz3CodecError::HeaderEncodeFailed`] if encoding the header failed
345/// - [`Sz3CodecError::InvalidEncodeShape`] if the array shape is invalid
346/// - [`Sz3CodecError::Sz3EncodeFailed`] if encoding failed with an opaque error
347pub fn compress<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
348    data: ArrayBase<S, D>,
349    predictor: Option<&Sz3Predictor>,
350    error_bound: &Sz3ErrorBound,
351    encoder: Option<&Sz3Encoder>,
352    lossless: Option<&Sz3LosslessCompressor>,
353) -> Result<Vec<u8>, Sz3CodecError> {
354    let mut encoded_bytes = postcard::to_extend(
355        &CompressionHeader {
356            dtype: <T as Sz3Element>::DTYPE,
357            shape: Cow::Borrowed(data.shape()),
358        },
359        Vec::new(),
360    )
361    .map_err(|err| Sz3CodecError::HeaderEncodeFailed {
362        source: Sz3HeaderError(err),
363    })?;
364
365    // sz3::DimensionedDataBuilder cannot handle zero-length dimensions
366    if data.is_empty() {
367        return Ok(encoded_bytes);
368    }
369
370    #[allow(clippy::option_if_let_else)]
371    let data_cow = if let Some(data) = data.as_slice() {
372        Cow::Borrowed(data)
373    } else {
374        Cow::Owned(data.iter().copied().collect())
375    };
376    let mut builder = sz3::DimensionedData::build(&data_cow);
377
378    for length in data.shape() {
379        // Sz3 ignores dimensions of length 1 and panics on length zero
380        // Since they carry no information for Sz3 and we already encode them
381        //  in our custom header, we just skip them here
382        if *length > 1 {
383            builder = builder
384                .dim(*length)
385                .map_err(|err| Sz3CodecError::InvalidEncodeShape {
386                    source: Sz3CodingError(err),
387                    shape: data.shape().to_vec(),
388                })?;
389        }
390    }
391
392    if data.len() == 1 {
393        // If there is only one element, all dimensions will have been skipped,
394        //  so we explicitly encode one dimension of size 1 here
395        builder = builder
396            .dim(1)
397            .map_err(|err| Sz3CodecError::InvalidEncodeShape {
398                source: Sz3CodingError(err),
399                shape: data.shape().to_vec(),
400            })?;
401    }
402
403    let data = builder
404        .finish()
405        .map_err(|err| Sz3CodecError::InvalidEncodeShape {
406            source: Sz3CodingError(err),
407            shape: data.shape().to_vec(),
408        })?;
409
410    // configure the error bound
411    let error_bound = match error_bound {
412        Sz3ErrorBound::AbsoluteAndRelative { abs, rel } => sz3::ErrorBound::AbsoluteAndRelative {
413            absolute_bound: *abs,
414            relative_bound: *rel,
415        },
416        Sz3ErrorBound::AbsoluteOrRelative { abs, rel } => sz3::ErrorBound::AbsoluteOrRelative {
417            absolute_bound: *abs,
418            relative_bound: *rel,
419        },
420        Sz3ErrorBound::Absolute { abs } => sz3::ErrorBound::Absolute(*abs),
421        Sz3ErrorBound::Relative { rel } => sz3::ErrorBound::Relative(*rel),
422        Sz3ErrorBound::PS2NR { psnr } => sz3::ErrorBound::PSNR(*psnr),
423        Sz3ErrorBound::L2Norm { l2 } => sz3::ErrorBound::L2Norm(*l2),
424    };
425    let mut config = sz3::Config::new(error_bound);
426
427    // configure the interpolation mode, if necessary
428    let interpolation = match predictor {
429        Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::LinearInterpolationLorenzo) => {
430            Some(sz3::InterpolationAlgorithm::Linear)
431        }
432        Some(Sz3Predictor::CubicInterpolation | Sz3Predictor::CubicInterpolationLorenzo) => {
433            Some(sz3::InterpolationAlgorithm::Cubic)
434        }
435        Some(Sz3Predictor::LorenzoRegression) | None => None,
436    };
437    if let Some(interpolation) = interpolation {
438        config = config.interpolation_algorithm(interpolation);
439    }
440
441    // configure the predictor (compression algorithm)
442    let predictor = match predictor {
443        Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::CubicInterpolation) => {
444            sz3::CompressionAlgorithm::Interpolation
445        }
446        Some(
447            Sz3Predictor::LinearInterpolationLorenzo | Sz3Predictor::CubicInterpolationLorenzo,
448        ) => sz3::CompressionAlgorithm::InterpolationLorenzo,
449        Some(Sz3Predictor::LorenzoRegression) => sz3::CompressionAlgorithm::lorenzo_regression(),
450        None => sz3::CompressionAlgorithm::NoPrediction,
451    };
452    config = config.compression_algorithm(predictor);
453
454    // configure the encoder
455    let encoder = match encoder {
456        None => sz3::Encoder::SkipEncoder,
457        Some(Sz3Encoder::Huffman) => sz3::Encoder::HuffmanEncoder,
458        Some(Sz3Encoder::Arithmetic) => sz3::Encoder::ArithmeticEncoder,
459    };
460    config = config.encoder(encoder);
461
462    // configure the lossless compressor
463    let lossless = match lossless {
464        None => sz3::LossLess::LossLessBypass,
465        Some(Sz3LosslessCompressor::Zstd) => sz3::LossLess::ZSTD,
466    };
467    config = config.lossless(lossless);
468
469    // TODO: avoid extra allocation here
470    let compressed = sz3::compress_with_config(&data, &config).map_err(|err| {
471        Sz3CodecError::Sz3EncodeFailed {
472            source: Sz3CodingError(err),
473        }
474    })?;
475    encoded_bytes.extend_from_slice(&compressed);
476
477    Ok(encoded_bytes)
478}
479
480/// Decompresses the `encoded` data into an array.
481///
482/// # Errors
483///
484/// Errors with
485/// - [`Sz3CodecError::HeaderDecodeFailed`] if decoding the header failed
486pub fn decompress(encoded: &[u8]) -> Result<AnyArray, Sz3CodecError> {
487    let (header, data) =
488        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
489            Sz3CodecError::HeaderDecodeFailed {
490                source: Sz3HeaderError(err),
491            }
492        })?;
493
494    let decoded = if header.shape.iter().copied().product::<usize>() == 0 {
495        match header.dtype {
496            Sz3DType::I32 => {
497                AnyArray::I32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
498            }
499            Sz3DType::I64 => {
500                AnyArray::I64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
501            }
502            Sz3DType::F32 => {
503                AnyArray::F32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
504            }
505            Sz3DType::F64 => {
506                AnyArray::F64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
507            }
508        }
509    } else {
510        // TODO: avoid extra allocation here
511        match header.dtype {
512            Sz3DType::I32 => AnyArray::I32(Array::from_shape_vec(
513                &*header.shape,
514                Vec::from(sz3::decompress(data).1.data()),
515            )?),
516            Sz3DType::I64 => AnyArray::I64(Array::from_shape_vec(
517                &*header.shape,
518                Vec::from(sz3::decompress(data).1.data()),
519            )?),
520            Sz3DType::F32 => AnyArray::F32(Array::from_shape_vec(
521                &*header.shape,
522                Vec::from(sz3::decompress(data).1.data()),
523            )?),
524            Sz3DType::F64 => AnyArray::F64(Array::from_shape_vec(
525                &*header.shape,
526                Vec::from(sz3::decompress(data).1.data()),
527            )?),
528        }
529    };
530
531    Ok(decoded)
532}
533
534/// Array element types which can be compressed with SZ3.
535pub trait Sz3Element: Copy + sz3::SZ3Compressible {
536    /// The dtype representation of the type
537    const DTYPE: Sz3DType;
538}
539
540impl Sz3Element for i32 {
541    const DTYPE: Sz3DType = Sz3DType::I32;
542}
543
544impl Sz3Element for i64 {
545    const DTYPE: Sz3DType = Sz3DType::I64;
546}
547
548impl Sz3Element for f32 {
549    const DTYPE: Sz3DType = Sz3DType::F32;
550}
551
552impl Sz3Element for f64 {
553    const DTYPE: Sz3DType = Sz3DType::F64;
554}
555
556#[derive(Serialize, Deserialize)]
557struct CompressionHeader<'a> {
558    dtype: Sz3DType,
559    #[serde(borrow)]
560    shape: Cow<'a, [usize]>,
561}
562
563/// Dtypes that SZ3 can compress and decompress
564#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
565#[allow(missing_docs)]
566pub enum Sz3DType {
567    #[serde(rename = "i32", alias = "int32")]
568    I32,
569    #[serde(rename = "i64", alias = "int64")]
570    I64,
571    #[serde(rename = "f32", alias = "float32")]
572    F32,
573    #[serde(rename = "f64", alias = "float64")]
574    F64,
575}
576
577impl fmt::Display for Sz3DType {
578    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
579        fmt.write_str(match self {
580            Self::I32 => "i32",
581            Self::I64 => "i64",
582            Self::F32 => "f32",
583            Self::F64 => "f64",
584        })
585    }
586}
587
#[cfg(test)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    /// Compresses `data` under `error_bound`, using the default predictor,
    /// encoder, and lossless compressor, and immediately decompresses it.
    fn roundtrip<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
        data: ArrayBase<S, D>,
        error_bound: &Sz3ErrorBound,
    ) -> Result<AnyArray, Sz3CodecError> {
        let encoded = compress(
            data,
            default_predictor().as_ref(),
            error_bound,
            default_encoder().as_ref(),
            default_lossless_compressor().as_ref(),
        )?;
        decompress(&encoded)
    }

    #[test]
    fn zero_length() -> Result<(), Sz3CodecError> {
        let data = Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])?;
        let decoded = roundtrip(data, &Sz3ErrorBound::L2Norm { l2: 27.0 })?;

        // The dtype and the full (empty) shape survive via the header alone.
        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);

        Ok(())
    }

    #[test]
    fn one_dimension() -> Result<(), Sz3CodecError> {
        // Length-1 axes are skipped for SZ3 but must be restored on decode.
        let data = Array::from_shape_vec([2_usize, 1, 2, 1].as_slice(), vec![1_i32, 2, 3, 4])?;
        let decoded = roundtrip(data.view(), &Sz3ErrorBound::Absolute { abs: 0.1 })?;

        assert_eq!(decoded, AnyArray::I32(data));

        Ok(())
    }

    #[test]
    fn small_state() -> Result<(), Sz3CodecError> {
        let inputs = [
            &[][..],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ];

        for data in inputs {
            let decoded = roundtrip(
                ArrayView1::from(data),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let expected = AnyArray::F64(Array1::from_vec(data.to_vec()).into_dyn());

            assert_eq!(decoded, expected);
        }

        Ok(())
    }
}