numcodecs_sz3/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-sz3
10//! [crates.io]: https://crates.io/crates/numcodecs-sz3
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-sz3
13//! [docs.rs]: https://docs.rs/numcodecs-sz3/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_sz3
17//!
18//! SZ3 codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // embedded-io
21
22use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
25use numcodecs::{
26    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27    Codec, StaticCodec, StaticCodecConfig,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use thiserror::Error;
32
33// Only included to explicitly enable the `no_wasm_shim` feature for
34// sz3-sys/Sz3-sys
35use ::zstd_sys as _;
36
37#[cfg(test)]
38use ::serde_json as _;
39
40#[derive(Clone, Serialize, Deserialize, JsonSchema)]
41// serde cannot deny unknown fields because of the flatten
42#[schemars(deny_unknown_fields)]
43/// Codec providing compression using SZ3
44pub struct Sz3Codec {
45    /// Predictor
46    #[serde(default = "default_predictor")]
47    pub predictor: Option<Sz3Predictor>,
48    /// SZ3 error bound
49    #[serde(flatten)]
50    pub error_bound: Sz3ErrorBound,
51}
52
/// SZ3 error bound
// Serialized as an internally tagged enum: the `eb_mode` field selects the
// variant, and the variant's fields are stored under their `eb_*` names.
// Each variant is translated into the matching `sz3::ErrorBound` in
// `compress`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "eb_mode")]
#[serde(deny_unknown_fields)]
pub enum Sz3ErrorBound {
    /// Errors are bounded by *both* the absolute and relative error, i.e. by
    /// whichever bound is stricter
    #[serde(rename = "abs-and-rel")]
    AbsoluteAndRelative {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Errors are bounded by *either* the absolute or relative error, i.e. by
    /// whichever bound is weaker
    #[serde(rename = "abs-or-rel")]
    AbsoluteOrRelative {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Absolute error bound
    #[serde(rename = "abs")]
    Absolute {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
    },
    /// Relative error bound
    #[serde(rename = "rel")]
    Relative {
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Peak signal to noise ratio error bound
    // Selected by the tag value `psnr` and passed on as `sz3::ErrorBound::PSNR`
    #[serde(rename = "psnr")]
    PS2NR {
        /// Peak signal to noise ratio error bound
        #[serde(rename = "eb_psnr")]
        psnr: f64,
    },
    /// Peak L2 norm error bound
    #[serde(rename = "l2")]
    L2Norm {
        /// Peak L2 norm error bound
        #[serde(rename = "eb_l2")]
        l2: f64,
    },
}
109
110/// SZ3 predictor
111#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
112#[serde(deny_unknown_fields)]
113pub enum Sz3Predictor {
114    /// Linear interpolation
115    #[serde(rename = "linear-interpolation")]
116    LinearInterpolation,
117    /// Cubic interpolation
118    #[serde(rename = "cubic-interpolation")]
119    CubicInterpolation,
120    /// Linear interpolation + Lorenzo predictor
121    #[serde(rename = "linear-interpolation-lorenzo")]
122    LinearInterpolationLorenzo,
123    /// Cubic interpolation + Lorenzo predictor
124    #[serde(rename = "cubic-interpolation-lorenzo")]
125    CubicInterpolationLorenzo,
126    /// 1st order regression
127    #[serde(rename = "regression")]
128    Regression,
129    /// 2nd order regression
130    #[serde(rename = "regression2")]
131    RegressionSecondOrder,
132    /// 1st+2nd order regression
133    #[serde(rename = "regression-regression2")]
134    RegressionFirstSecondOrder,
135    /// 2nd order Lorenzo predictor
136    #[serde(rename = "lorenzo2")]
137    LorenzoSecondOrder,
138    /// 2nd order Lorenzo predictor + 2nd order regression
139    #[serde(rename = "lorenzo2-regression2")]
140    LorenzoSecondOrderRegressionSecondOrder,
141    /// 2nd order Lorenzo predictor + 1st order regression
142    #[serde(rename = "lorenzo2-regression")]
143    LorenzoSecondOrderRegression,
144    /// 2nd order Lorenzo predictor + 1st order regression
145    #[serde(rename = "lorenzo2-regression-regression2")]
146    LorenzoSecondOrderRegressionFirstSecondOrder,
147    /// 1st order Lorenzo predictor
148    #[serde(rename = "lorenzo")]
149    Lorenzo,
150    /// 1st order Lorenzo predictor + 2nd order regression
151    #[serde(rename = "lorenzo-regression2")]
152    LorenzoRegressionSecondOrder,
153    /// 1st order Lorenzo predictor + 1st order regression
154    #[serde(rename = "lorenzo-regression")]
155    LorenzoRegression,
156    /// 1st order Lorenzo predictor + 1st and 2nd order regression
157    #[serde(rename = "lorenzo-regression-regression2")]
158    LorenzoRegressionFirstSecondOrder,
159    /// 1st+2nd order Lorenzo predictor
160    #[serde(rename = "lorenzo-lorenzo2")]
161    LorenzoFirstSecondOrder,
162    /// 1st+2nd order Lorenzo predictor + 2nd order regression
163    #[serde(rename = "lorenzo-lorenzo2-regression2")]
164    LorenzoFirstSecondOrderRegressionSecondOrder,
165    /// 1st+2nd order Lorenzo predictor + 1st order regression
166    #[serde(rename = "lorenzo-lorenzo2-regression")]
167    LorenzoFirstSecondOrderRegression,
168    /// 1st+2nd order Lorenzo predictor + 1st+2nd order regression
169    #[serde(rename = "lorenzo-lorenzo2-regression-regression2")]
170    LorenzoFirstSecondOrderRegressionFirstSecondOrder,
171}
172
/// Default for [`Sz3Codec::predictor`], which serde uses when the field is
/// missing: cubic interpolation combined with the Lorenzo predictor.
// The `Option` wrapper is required since the function named in
// `#[serde(default = "...")]` has to return the field's type, which is
// `Option<Sz3Predictor>`.
#[expect(clippy::unnecessary_wraps)]
const fn default_predictor() -> Option<Sz3Predictor> {
    Some(Sz3Predictor::CubicInterpolationLorenzo)
}
177
178impl Codec for Sz3Codec {
179    type Error = Sz3CodecError;
180
181    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
182        match data {
183            AnyCowArray::I32(data) => Ok(AnyArray::U8(
184                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
185                    .into_dyn(),
186            )),
187            AnyCowArray::I64(data) => Ok(AnyArray::U8(
188                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
189                    .into_dyn(),
190            )),
191            AnyCowArray::F32(data) => Ok(AnyArray::U8(
192                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
193                    .into_dyn(),
194            )),
195            AnyCowArray::F64(data) => Ok(AnyArray::U8(
196                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
197                    .into_dyn(),
198            )),
199            encoded => Err(Sz3CodecError::UnsupportedDtype(encoded.dtype())),
200        }
201    }
202
203    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
204        let AnyCowArray::U8(encoded) = encoded else {
205            return Err(Sz3CodecError::EncodedDataNotBytes {
206                dtype: encoded.dtype(),
207            });
208        };
209
210        if !matches!(encoded.shape(), [_]) {
211            return Err(Sz3CodecError::EncodedDataNotOneDimensional {
212                shape: encoded.shape().to_vec(),
213            });
214        }
215
216        decompress(&AnyCowArray::U8(encoded).as_bytes())
217    }
218
219    fn decode_into(
220        &self,
221        encoded: AnyArrayView,
222        mut decoded: AnyArrayViewMut,
223    ) -> Result<(), Self::Error> {
224        let decoded_in = self.decode(encoded.cow())?;
225
226        Ok(decoded.assign(&decoded_in)?)
227    }
228}
229
impl StaticCodec for Sz3Codec {
    // The codec's identifier
    const CODEC_ID: &'static str = "sz3";

    // The codec doubles as its own configuration type
    type Config<'de> = Self;

    // Since the configuration *is* the codec, it is returned unchanged
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    // The configuration is built from a borrow of the codec itself
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
243
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`Sz3Codec`].
pub enum Sz3CodecError {
    /// [`Sz3Codec`] does not support the dtype
    // Returned when encoding an array that is not i32, i64, f32, or f64
    #[error("Sz3 does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`Sz3Codec`] failed to encode the header
    // The header stores the dtype and shape and is written by `compress`
    #[error("Sz3 failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// [`Sz3Codec`] cannot encode an array of `shape`
    // Returned by `compress` when SZ3's dimensioned data builder rejects one
    // of the dimensions or the final shape
    #[error("Sz3 cannot encode an array of shape {shape:?}")]
    InvalidEncodeShape {
        /// Opaque source error
        source: Sz3CodingError,
        /// The invalid shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`Sz3Codec`] failed to encode the data
    #[error("Sz3 failed to encode the data")]
    Sz3EncodeFailed {
        /// Opaque source error
        source: Sz3CodingError,
    },
    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
    /// an array of a different dtype
    #[error(
        "Sz3 can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
    /// an array of a different shape
    #[error("Sz3 can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`Sz3Codec`] failed to decode the header
    #[error("Sz3 failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// [`Sz3Codec`] decoded an invalid array shape header which does not fit
    /// the decoded data
    // Converted automatically (via `#[from]`) from the `ShapeError` that
    // building the decoded array in `decompress` may produce
    #[error("Sz3 decoded an invalid array shape header which does not fit the decoded data")]
    DecodeInvalidShapeHeader {
        /// Source error
        #[from]
        source: ShapeError,
    },
    /// [`Sz3Codec`] cannot decode into the provided array
    // Converted automatically (via `#[from]`) from the error of assigning
    // the decoded array to the output array in `decode_into`
    #[error("Sz3 cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
308
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding the header fails
// Wraps the `postcard::Error` in a private field; `#[error(transparent)]`
// forwards the error's display and source to the wrapped error.
pub struct Sz3HeaderError(postcard::Error);
313
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding with SZ3 fails
// Wraps the `sz3::SZ3Error` in a private field; `#[error(transparent)]`
// forwards the error's display and source to the wrapped error.
pub struct Sz3CodingError(sz3::SZ3Error);
318
#[expect(clippy::needless_pass_by_value, clippy::too_many_lines)]
/// Compresses the input `data` array using SZ3, which is configured with an
/// optional `predictor` and an `error_bound`.
///
/// The returned bytes consist of a header, which stores the dtype and shape
/// of `data`, followed by the SZ3-compressed payload. For an empty array,
/// only the header is returned.
///
/// # Errors
///
/// Errors with
/// - [`Sz3CodecError::HeaderEncodeFailed`] if encoding the header failed
/// - [`Sz3CodecError::InvalidEncodeShape`] if the array shape is invalid
/// - [`Sz3CodecError::Sz3EncodeFailed`] if encoding failed with an opaque error
pub fn compress<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
    data: ArrayBase<S, D>,
    predictor: Option<&Sz3Predictor>,
    error_bound: &Sz3ErrorBound,
) -> Result<Vec<u8>, Sz3CodecError> {
    // the output starts with the header, the compressed payload is appended
    //  to the same buffer further below
    let mut encoded_bytes = postcard::to_extend(
        &CompressionHeader {
            dtype: <T as Sz3Element>::DTYPE,
            shape: Cow::Borrowed(data.shape()),
        },
        Vec::new(),
    )
    .map_err(|err| Sz3CodecError::HeaderEncodeFailed {
        source: Sz3HeaderError(err),
    })?;

    // sz3::DimensionedDataBuilder cannot handle zero-length dimensions
    if data.is_empty() {
        return Ok(encoded_bytes);
    }

    // borrow the elements if the array can be viewed as a slice, otherwise
    //  copy them into a new vector in iteration order
    #[expect(clippy::option_if_let_else)]
    let data_cow = if let Some(data) = data.as_slice() {
        Cow::Borrowed(data)
    } else {
        Cow::Owned(data.iter().copied().collect())
    };
    let mut builder = sz3::DimensionedData::build(&data_cow);

    for length in data.shape() {
        // Sz3 ignores dimensions of length 1 and panics on length zero
        // Since they carry no information for Sz3 and we already encode them
        //  in our custom header, we just skip them here
        if *length > 1 {
            builder = builder
                .dim(*length)
                .map_err(|err| Sz3CodecError::InvalidEncodeShape {
                    source: Sz3CodingError(err),
                    shape: data.shape().to_vec(),
                })?;
        }
    }

    if data.len() == 1 {
        // If there is only one element, all dimensions will have been skipped,
        //  so we explicitly encode one dimension of size 1 here
        builder = builder
            .dim(1)
            .map_err(|err| Sz3CodecError::InvalidEncodeShape {
                source: Sz3CodingError(err),
                shape: data.shape().to_vec(),
            })?;
    }

    // from here on, `data` is shadowed by the dimensioned data for SZ3
    let data = builder
        .finish()
        .map_err(|err| Sz3CodecError::InvalidEncodeShape {
            source: Sz3CodingError(err),
            shape: data.shape().to_vec(),
        })?;

    // configure the error bound
    let error_bound = match error_bound {
        Sz3ErrorBound::AbsoluteAndRelative { abs, rel } => sz3::ErrorBound::AbsoluteAndRelative {
            absolute_bound: *abs,
            relative_bound: *rel,
        },
        Sz3ErrorBound::AbsoluteOrRelative { abs, rel } => sz3::ErrorBound::AbsoluteOrRelative {
            absolute_bound: *abs,
            relative_bound: *rel,
        },
        Sz3ErrorBound::Absolute { abs } => sz3::ErrorBound::Absolute(*abs),
        Sz3ErrorBound::Relative { rel } => sz3::ErrorBound::Relative(*rel),
        Sz3ErrorBound::PS2NR { psnr } => sz3::ErrorBound::PSNR(*psnr),
        Sz3ErrorBound::L2Norm { l2 } => sz3::ErrorBound::L2Norm(*l2),
    };
    let mut config = sz3::Config::new(error_bound);

    // configure the interpolation mode, if necessary
    // the match lists every predictor explicitly, without a wildcard arm, so
    //  that adding a new predictor variant fails to compile here
    let interpolation = match predictor {
        Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::LinearInterpolationLorenzo) => {
            Some(sz3::InterpolationAlgorithm::Linear)
        }
        Some(Sz3Predictor::CubicInterpolation | Sz3Predictor::CubicInterpolationLorenzo) => {
            Some(sz3::InterpolationAlgorithm::Cubic)
        }
        Some(
            Sz3Predictor::Regression
            | Sz3Predictor::RegressionSecondOrder
            | Sz3Predictor::RegressionFirstSecondOrder
            | Sz3Predictor::LorenzoSecondOrder
            | Sz3Predictor::LorenzoSecondOrderRegressionSecondOrder
            | Sz3Predictor::LorenzoSecondOrderRegression
            | Sz3Predictor::LorenzoSecondOrderRegressionFirstSecondOrder
            | Sz3Predictor::Lorenzo
            | Sz3Predictor::LorenzoRegressionSecondOrder
            | Sz3Predictor::LorenzoRegression
            | Sz3Predictor::LorenzoRegressionFirstSecondOrder
            | Sz3Predictor::LorenzoFirstSecondOrder
            | Sz3Predictor::LorenzoFirstSecondOrderRegressionSecondOrder
            | Sz3Predictor::LorenzoFirstSecondOrderRegression
            | Sz3Predictor::LorenzoFirstSecondOrderRegressionFirstSecondOrder,
        )
        | None => None,
    };
    if let Some(interpolation) = interpolation {
        config = config.interpolation_algorithm(interpolation);
    }

    // configure the predictor (compression algorithm)
    // for the Lorenzo / regression predictors, the four flags mirror the
    //  components named in the predictor variant
    let predictor = match predictor {
        Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::CubicInterpolation) => {
            sz3::CompressionAlgorithm::Interpolation
        }
        Some(
            Sz3Predictor::LinearInterpolationLorenzo | Sz3Predictor::CubicInterpolationLorenzo,
        ) => sz3::CompressionAlgorithm::InterpolationLorenzo,
        Some(Sz3Predictor::RegressionSecondOrder) => sz3::CompressionAlgorithm::LorenzoRegression {
            lorenzo: false,
            lorenzo_second_order: false,
            regression: false,
            regression_second_order: true,
            prediction_dimension: None,
        },
        Some(Sz3Predictor::Regression) => sz3::CompressionAlgorithm::LorenzoRegression {
            lorenzo: false,
            lorenzo_second_order: false,
            regression: true,
            regression_second_order: false,
            prediction_dimension: None,
        },
        Some(Sz3Predictor::RegressionFirstSecondOrder) => {
            sz3::CompressionAlgorithm::LorenzoRegression {
                lorenzo: false,
                lorenzo_second_order: false,
                regression: true,
                regression_second_order: true,
                prediction_dimension: None,
            }
        }
        Some(Sz3Predictor::LorenzoSecondOrder) => sz3::CompressionAlgorithm::LorenzoRegression {
            lorenzo: false,
            lorenzo_second_order: true,
            regression: false,
            regression_second_order: false,
            prediction_dimension: None,
        },
        Some(Sz3Predictor::LorenzoSecondOrderRegressionSecondOrder) => {
            sz3::CompressionAlgorithm::LorenzoRegression {
                lorenzo: false,
                lorenzo_second_order: true,
                regression: false,
                regression_second_order: true,
                prediction_dimension: None,
            }
        }
        Some(Sz3Predictor::LorenzoSecondOrderRegression) => {
            sz3::CompressionAlgorithm::LorenzoRegression {
                lorenzo: false,
                lorenzo_second_order: true,
                regression: true,
                regression_second_order: false,
                prediction_dimension: None,
            }
        }
        Some(Sz3Predictor::LorenzoSecondOrderRegressionFirstSecondOrder) => {
            sz3::CompressionAlgorithm::LorenzoRegression {
                lorenzo: false,
                lorenzo_second_order: true,
                regression: true,
                regression_second_order: true,
                prediction_dimension: None,
            }
        }
        Some(Sz3Predictor::Lorenzo) => sz3::CompressionAlgorithm::LorenzoRegression {
            lorenzo: true,
            lorenzo_second_order: false,
            regression: false,
            regression_second_order: false,
            prediction_dimension: None,
        },
        Some(Sz3Predictor::LorenzoRegressionSecondOrder) => {
            sz3::CompressionAlgorithm::LorenzoRegression {
                lorenzo: true,
                lorenzo_second_order: false,
                regression: false,
                regression_second_order: true,
                prediction_dimension: None,
            }
        }
        Some(Sz3Predictor::LorenzoRegression) => sz3::CompressionAlgorithm::LorenzoRegression {
            lorenzo: true,
            lorenzo_second_order: false,
            regression: true,
            regression_second_order: false,
            prediction_dimension: None,
        },
        Some(Sz3Predictor::LorenzoRegressionFirstSecondOrder) => {
            sz3::CompressionAlgorithm::LorenzoRegression {
                lorenzo: true,
                lorenzo_second_order: false,
                regression: true,
                regression_second_order: true,
                prediction_dimension: None,
            }
        }
        Some(Sz3Predictor::LorenzoFirstSecondOrder) => {
            sz3::CompressionAlgorithm::LorenzoRegression {
                lorenzo: true,
                lorenzo_second_order: true,
                regression: false,
                regression_second_order: false,
                prediction_dimension: None,
            }
        }
        Some(Sz3Predictor::LorenzoFirstSecondOrderRegressionSecondOrder) => {
            sz3::CompressionAlgorithm::LorenzoRegression {
                lorenzo: true,
                lorenzo_second_order: true,
                regression: false,
                regression_second_order: true,
                prediction_dimension: None,
            }
        }
        Some(Sz3Predictor::LorenzoFirstSecondOrderRegression) => {
            sz3::CompressionAlgorithm::LorenzoRegression {
                lorenzo: true,
                lorenzo_second_order: true,
                regression: true,
                regression_second_order: false,
                prediction_dimension: None,
            }
        }
        Some(Sz3Predictor::LorenzoFirstSecondOrderRegressionFirstSecondOrder) => {
            sz3::CompressionAlgorithm::LorenzoRegression {
                lorenzo: true,
                lorenzo_second_order: true,
                regression: true,
                regression_second_order: true,
                prediction_dimension: None,
            }
        }
        None => sz3::CompressionAlgorithm::NoPrediction,
    };
    config = config.compression_algorithm(predictor);

    // TODO: avoid extra allocation here
    let compressed = sz3::compress_with_config(&data, &config).map_err(|err| {
        Sz3CodecError::Sz3EncodeFailed {
            source: Sz3CodingError(err),
        }
    })?;
    encoded_bytes.extend_from_slice(&compressed);

    Ok(encoded_bytes)
}
586
587/// Decompresses the `encoded` data into an array.
588///
589/// # Errors
590///
591/// Errors with
592/// - [`Sz3CodecError::HeaderDecodeFailed`] if decoding the header failed
593pub fn decompress(encoded: &[u8]) -> Result<AnyArray, Sz3CodecError> {
594    let (header, data) =
595        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
596            Sz3CodecError::HeaderDecodeFailed {
597                source: Sz3HeaderError(err),
598            }
599        })?;
600
601    let decoded = if header.shape.iter().copied().product::<usize>() == 0 {
602        match header.dtype {
603            Sz3DType::I32 => {
604                AnyArray::I32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
605            }
606            Sz3DType::I64 => {
607                AnyArray::I64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
608            }
609            Sz3DType::F32 => {
610                AnyArray::F32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
611            }
612            Sz3DType::F64 => {
613                AnyArray::F64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
614            }
615        }
616    } else {
617        // TODO: avoid extra allocation here
618        match header.dtype {
619            Sz3DType::I32 => AnyArray::I32(Array::from_shape_vec(
620                &*header.shape,
621                Vec::from(sz3::decompress(data).1.data()),
622            )?),
623            Sz3DType::I64 => AnyArray::I64(Array::from_shape_vec(
624                &*header.shape,
625                Vec::from(sz3::decompress(data).1.data()),
626            )?),
627            Sz3DType::F32 => AnyArray::F32(Array::from_shape_vec(
628                &*header.shape,
629                Vec::from(sz3::decompress(data).1.data()),
630            )?),
631            Sz3DType::F64 => AnyArray::F64(Array::from_shape_vec(
632                &*header.shape,
633                Vec::from(sz3::decompress(data).1.data()),
634            )?),
635        }
636    };
637
638    Ok(decoded)
639}
640
/// Array element types which can be compressed with SZ3.
///
/// This crate implements the trait for `i32`, `i64`, `f32`, and `f64`.
pub trait Sz3Element: Copy + sz3::SZ3Compressible {
    /// The dtype representation of the type
    const DTYPE: Sz3DType;
}
646
// Each supported element type is tied to the `Sz3DType` tag that `compress`
// writes into the header for arrays of this type.

impl Sz3Element for i32 {
    const DTYPE: Sz3DType = Sz3DType::I32;
}

impl Sz3Element for i64 {
    const DTYPE: Sz3DType = Sz3DType::I64;
}

impl Sz3Element for f32 {
    const DTYPE: Sz3DType = Sz3DType::F32;
}

impl Sz3Element for f64 {
    const DTYPE: Sz3DType = Sz3DType::F64;
}
662
/// Header that `compress` writes, using postcard, in front of the SZ3
/// payload and that `decompress` reads back to restore the array.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    /// Element type of the array
    dtype: Sz3DType,
    /// Full shape of the array, including any axes of length zero or one
    // A `Cow` so that `compress` can borrow the shape of the input array
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
}
669
/// Dtypes that SZ3 can compress and decompress
// The variants deliberately carry no doc comments, as `missing_docs` is
// expected to fire on them. Each variant is renamed to its short Rust type
// name and has the longer `int*` / `float*` spelling as a serde alias.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[expect(missing_docs)]
pub enum Sz3DType {
    #[serde(rename = "i32", alias = "int32")]
    I32,
    #[serde(rename = "i64", alias = "int64")]
    I64,
    #[serde(rename = "f32", alias = "float32")]
    F32,
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
683
684impl fmt::Display for Sz3DType {
685    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
686        fmt.write_str(match self {
687            Self::I32 => "i32",
688            Self::I64 => "i64",
689            Self::F32 => "f32",
690            Self::F64 => "f64",
691        })
692    }
693}
694
#[cfg(test)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    // An array with a zero-length axis roundtrips with its dtype and full
    // shape intact
    #[test]
    fn zero_length() -> Result<(), Sz3CodecError> {
        let encoded = compress(
            Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])?,
            default_predictor().as_ref(),
            &Sz3ErrorBound::L2Norm { l2: 27.0 },
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);

        Ok(())
    }

    // Axes of length one, which `compress` does not pass on to SZ3, are
    // restored when decoding
    #[test]
    fn one_dimension() -> Result<(), Sz3CodecError> {
        let data = Array::from_shape_vec([2_usize, 1, 2, 1].as_slice(), vec![1, 2, 3, 4])?;

        let encoded = compress(
            data.view(),
            default_predictor().as_ref(),
            &Sz3ErrorBound::Absolute { abs: 0.1 },
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded, AnyArray::I32(data));

        Ok(())
    }

    // Very short one-dimensional inputs, including the empty and the
    // single-element array, roundtrip exactly
    #[test]
    fn small_state() -> Result<(), Sz3CodecError> {
        for data in [
            &[][..],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ] {
            let encoded = compress(
                ArrayView1::from(data),
                default_predictor().as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let decoded = decompress(&encoded)?;

            assert_eq!(
                decoded,
                AnyArray::F64(Array1::from_vec(data.to_vec()).into_dyn())
            );
        }

        Ok(())
    }

    // Smoke test: encoding and decoding succeed for the listed predictors,
    // the decoded values themselves are not checked
    // NOTE(review): the four interpolation-based predictors are not listed
    // here; only the default `CubicInterpolationLorenzo` is exercised by the
    // other tests - confirm whether the omission is intentional
    #[test]
    fn all_predictors() -> Result<(), Sz3CodecError> {
        let data = Array::linspace(-42.0, 42.0, 85);

        for predictor in [
            None,
            Some(Sz3Predictor::Regression),
            Some(Sz3Predictor::RegressionSecondOrder),
            Some(Sz3Predictor::RegressionFirstSecondOrder),
            Some(Sz3Predictor::LorenzoSecondOrder),
            Some(Sz3Predictor::LorenzoSecondOrderRegressionSecondOrder),
            Some(Sz3Predictor::LorenzoSecondOrderRegression),
            Some(Sz3Predictor::LorenzoSecondOrderRegressionFirstSecondOrder),
            Some(Sz3Predictor::Lorenzo),
            Some(Sz3Predictor::LorenzoRegressionSecondOrder),
            Some(Sz3Predictor::LorenzoRegression),
            Some(Sz3Predictor::LorenzoRegressionFirstSecondOrder),
            Some(Sz3Predictor::LorenzoFirstSecondOrder),
            Some(Sz3Predictor::LorenzoFirstSecondOrderRegressionSecondOrder),
            Some(Sz3Predictor::LorenzoFirstSecondOrderRegression),
            Some(Sz3Predictor::LorenzoFirstSecondOrderRegressionFirstSecondOrder),
        ] {
            let encoded = compress(
                data.view(),
                predictor.as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let _decoded = decompress(&encoded)?;
        }

        Ok(())
    }
}