Skip to main content

numcodecs_lc/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.87.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-lc
10//! [crates.io]: https://crates.io/crates/numcodecs-lc
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-lc
13//! [docs.rs]: https://docs.rs/numcodecs-lc/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_lc
17//!
18//! LC codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // embedded-io
21
22use std::borrow::Cow;
23
24use ndarray::Array1;
25use numcodecs::{
26    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use schemars::{JsonSchema, JsonSchema_repr};
30use serde::{Deserialize, Deserializer, Serialize};
31use serde_repr::{Deserialize_repr, Serialize_repr};
32use thiserror::Error;
33
34#[cfg(test)]
35use ::serde_json as _;
36
// Encoding format version of this codec, expressed as a type so it can be
// stored both in the codec configuration and in the compressed data header.
// The const parameters `<0, 1, 0>` are presumably major.minor.patch - confirm
// against the `StaticCodecVersion` documentation.
type LcCodecVersion = StaticCodecVersion<0, 1, 0>;
38
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
// Unknown configuration keys are rejected instead of being silently ignored.
#[serde(deny_unknown_fields)]
/// Codec providing compression using LC
pub struct LcCodec {
    /// LC preprocessors
    // Optional in the configuration: a missing key deserialises to an empty
    // list.
    #[serde(default)]
    pub preprocessors: Vec<LcPreprocessor>,
    /// LC components
    // Deserialised through `deserialize_components`, which rejects an empty
    // list and any list longer than `lc_framework::MAX_COMPONENTS`; the schema
    // attribute below advertises the same bounds.
    #[serde(deserialize_with = "deserialize_components")]
    #[schemars(length(min = 1, max = lc_framework::MAX_COMPONENTS))]
    pub components: Vec<LcComponent>,
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    // Stored under the `_version` key and filled with the default value when
    // the key is absent.
    #[serde(default, rename = "_version")]
    pub version: LcCodecVersion,
}
54
55fn deserialize_components<'de, D: Deserializer<'de>>(
56    deserializer: D,
57) -> Result<Vec<LcComponent>, D::Error> {
58    let components = Vec::<LcComponent>::deserialize(deserializer)?;
59
60    if components.is_empty() {
61        return Err(serde::de::Error::custom("expected at least one component"));
62    }
63
64    if components.len() > lc_framework::MAX_COMPONENTS {
65        return Err(serde::de::Error::custom(format_args!(
66            "expected at most {} components",
67            lc_framework::MAX_COMPONENTS
68        )));
69    }
70
71    Ok(components)
72}
73
// Variant and field docs are deliberately omitted; the lint expectation below
// records that.
#[expect(missing_docs)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
// Internally tagged: the `id` key selects the variant via the renamed
// identifiers below, the remaining keys are that variant's parameters.
#[serde(tag = "id")]
/// LC preprocessor
pub enum LcPreprocessor {
    // Carries no parameters.
    #[serde(rename = "NUL")]
    Noop,
    // Parameterised only by the element dtype.
    #[serde(rename = "LOR")]
    Lorenzo1D { dtype: LcLorenzoDtype },
    // `kind` picks the type of error bound and `error_bound` its magnitude;
    // `threshold` is optional.
    // NOTE(review): the exact meaning of `threshold` and `decorrelation` is
    // defined by `lc_framework` - confirm there before documenting publicly.
    #[serde(rename = "QUANT")]
    QuantizeErrorBound {
        dtype: LcQuantizeDType,
        kind: LcErrorKind,
        error_bound: f64,
        threshold: Option<f64>,
        decorrelation: LcDecorrelation,
    },
}
93
94impl LcPreprocessor {
95    const fn into_lc(self) -> lc_framework::Preprocessor {
96        match self {
97            Self::Noop => lc_framework::Preprocessor::Noop,
98            Self::Lorenzo1D { dtype } => lc_framework::Preprocessor::Lorenzo1D {
99                dtype: dtype.into_lc(),
100            },
101            Self::QuantizeErrorBound {
102                dtype,
103                kind,
104                error_bound,
105                threshold,
106                decorrelation,
107            } => lc_framework::Preprocessor::QuantizeErrorBound {
108                dtype: dtype.into_lc(),
109                kind: kind.into_lc(),
110                error_bound,
111                threshold,
112                decorrelation: decorrelation.into_lc(),
113            },
114        }
115    }
116}
117
#[derive(
    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, JsonSchema,
)]
// Used as the `kind` parameter of `LcPreprocessor::QuantizeErrorBound`.
// Each variant is serialised as the upper-case identifier given in its
// `rename` attribute.
/// LC error bound kind
pub enum LcErrorKind {
    /// pointwise absolute error bound
    #[serde(rename = "ABS")]
    Abs,
    /// pointwise normalised absolute / data-range-relative error bound
    #[serde(rename = "NOA")]
    Noa,
    /// pointwise relative error bound
    #[serde(rename = "REL")]
    Rel,
}
133
134impl LcErrorKind {
135    const fn into_lc(self) -> lc_framework::ErrorKind {
136        match self {
137            Self::Abs => lc_framework::ErrorKind::Abs,
138            Self::Noa => lc_framework::ErrorKind::Noa,
139            Self::Rel => lc_framework::ErrorKind::Rel,
140        }
141    }
142}
143
// Variant docs are deliberately omitted; the lint expectation below records
// that.
#[expect(missing_docs)]
#[derive(
    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, JsonSchema,
)]
// Used as the `decorrelation` parameter of
// `LcPreprocessor::QuantizeErrorBound`; serialised as the strings "0" and "R".
/// LC quantisation decorrelation mode
pub enum LcDecorrelation {
    #[serde(rename = "0")]
    Zero,
    #[serde(rename = "R")]
    Random,
}
155
156impl LcDecorrelation {
157    const fn into_lc(self) -> lc_framework::Decorrelation {
158        match self {
159            Self::Zero => lc_framework::Decorrelation::Zero,
160            Self::Random => lc_framework::Decorrelation::Random,
161        }
162    }
163}
164
// Variant docs are deliberately omitted; the lint expectation below records
// that.
#[expect(missing_docs)]
#[derive(
    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, JsonSchema,
)]
// Used as the `dtype` parameter of `LcPreprocessor::Lorenzo1D`. Only a single
// dtype is currently offered; it is serialised as the string "i32".
/// LC Lorenzo preprocessor dtype
pub enum LcLorenzoDtype {
    #[serde(rename = "i32")]
    I32,
}
174
175impl LcLorenzoDtype {
176    const fn into_lc(self) -> lc_framework::LorenzoDtype {
177        match self {
178            Self::I32 => lc_framework::LorenzoDtype::I32,
179        }
180    }
181}
182
// Variant docs are deliberately omitted; the lint expectation below records
// that.
#[expect(missing_docs)]
#[derive(
    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, JsonSchema,
)]
// Used as the `dtype` parameter of `LcPreprocessor::QuantizeErrorBound`;
// serialised as the strings "f32" and "f64".
/// LC quantization dtype
pub enum LcQuantizeDType {
    #[serde(rename = "f32")]
    F32,
    #[serde(rename = "f64")]
    F64,
}
194
195impl LcQuantizeDType {
196    const fn into_lc(self) -> lc_framework::QuantizeDType {
197        match self {
198            Self::F32 => lc_framework::QuantizeDType::F32,
199            Self::F64 => lc_framework::QuantizeDType::F64,
200        }
201    }
202}
203
// Variant and field docs are deliberately omitted; the lint expectation below
// records that.
#[expect(missing_docs)]
#[derive(
    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, JsonSchema,
)]
#[serde(deny_unknown_fields)]
// Internally tagged: the `id` key selects the variant via the renamed
// identifiers below. Every variant except `Noop` additionally carries a `size`
// parameter, whose type differs between the general, float-only, and tuple
// components.
#[serde(tag = "id")]
/// LC component
pub enum LcComponent {
    #[serde(rename = "NUL")]
    Noop,
    // mutators
    #[serde(rename = "TCMS")]
    TwosComplementToSignMagnitude { size: LcElemSize },
    #[serde(rename = "TCNB")]
    TwosComplementToNegaBinary { size: LcElemSize },
    // the two debiased-exponent mutators only accept float element sizes
    #[serde(rename = "DBEFS")]
    DebiasedExponentFractionSign { size: LcFloatSize },
    #[serde(rename = "DBESF")]
    DebiasedExponentSignFraction { size: LcFloatSize },
    // shufflers
    #[serde(rename = "BIT")]
    BitShuffle { size: LcElemSize },
    // the tuple shuffler takes a combined element size and tuple length
    #[serde(rename = "TUPL")]
    Tuple { size: LcTupleSize },
    // predictors
    #[serde(rename = "DIFF")]
    Delta { size: LcElemSize },
    #[serde(rename = "DIFFMS")]
    DeltaAsSignMagnitude { size: LcElemSize },
    #[serde(rename = "DIFFNB")]
    DeltaAsNegaBinary { size: LcElemSize },
    // reducers
    #[serde(rename = "CLOG")]
    Clog { size: LcElemSize },
    #[serde(rename = "HCLOG")]
    HClog { size: LcElemSize },
    #[serde(rename = "RARE")]
    Rare { size: LcElemSize },
    #[serde(rename = "RAZE")]
    Raze { size: LcElemSize },
    #[serde(rename = "RLE")]
    RunLengthEncoding { size: LcElemSize },
    #[serde(rename = "RRE")]
    RepetitionRunBitmapEncoding { size: LcElemSize },
    #[serde(rename = "RZE")]
    ZeroRunBitmapEncoding { size: LcElemSize },
}
251
252impl LcComponent {
253    const fn into_lc(self) -> lc_framework::Component {
254        match self {
255            Self::Noop => lc_framework::Component::Noop,
256            // mutators
257            Self::TwosComplementToSignMagnitude { size } => {
258                lc_framework::Component::TwosComplementToSignMagnitude {
259                    size: size.into_lc(),
260                }
261            }
262            Self::TwosComplementToNegaBinary { size } => {
263                lc_framework::Component::TwosComplementToNegaBinary {
264                    size: size.into_lc(),
265                }
266            }
267            Self::DebiasedExponentFractionSign { size } => {
268                lc_framework::Component::DebiasedExponentFractionSign {
269                    size: size.into_lc(),
270                }
271            }
272            Self::DebiasedExponentSignFraction { size } => {
273                lc_framework::Component::DebiasedExponentSignFraction {
274                    size: size.into_lc(),
275                }
276            }
277            // shufflers
278            Self::BitShuffle { size } => lc_framework::Component::BitShuffle {
279                size: size.into_lc(),
280            },
281            Self::Tuple { size } => lc_framework::Component::Tuple {
282                size: size.into_lc(),
283            },
284            // predictors
285            Self::Delta { size } => lc_framework::Component::Delta {
286                size: size.into_lc(),
287            },
288            Self::DeltaAsSignMagnitude { size } => lc_framework::Component::DeltaAsSignMagnitude {
289                size: size.into_lc(),
290            },
291            Self::DeltaAsNegaBinary { size } => lc_framework::Component::DeltaAsNegaBinary {
292                size: size.into_lc(),
293            },
294            // reducers
295            Self::Clog { size } => lc_framework::Component::Clog {
296                size: size.into_lc(),
297            },
298            Self::HClog { size } => lc_framework::Component::HClog {
299                size: size.into_lc(),
300            },
301            Self::Rare { size } => lc_framework::Component::Rare {
302                size: size.into_lc(),
303            },
304            Self::Raze { size } => lc_framework::Component::Raze {
305                size: size.into_lc(),
306            },
307            Self::RunLengthEncoding { size } => lc_framework::Component::RunLengthEncoding {
308                size: size.into_lc(),
309            },
310            Self::RepetitionRunBitmapEncoding { size } => {
311                lc_framework::Component::RepetitionRunBitmapEncoding {
312                    size: size.into_lc(),
313                }
314            }
315            Self::ZeroRunBitmapEncoding { size } => {
316                lc_framework::Component::ZeroRunBitmapEncoding {
317                    size: size.into_lc(),
318                }
319            }
320        }
321    }
322}
323
// Variant docs are deliberately omitted; the lint expectation below records
// that.
#[expect(missing_docs)]
// The `*_repr` derives (de)serialise this enum through its `u8` discriminant,
// i.e. as the numbers 1, 2, 4, or 8 rather than as variant names.
#[derive(
    Copy,
    Clone,
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize_repr,
    Deserialize_repr,
    JsonSchema_repr,
)]
/// LC component element size, in bytes
#[repr(u8)]
pub enum LcElemSize {
    // Each discriminant equals the element size in bytes.
    S1 = 1,
    S2 = 2,
    S4 = 4,
    S8 = 8,
}
346
347impl LcElemSize {
348    const fn into_lc(self) -> lc_framework::ElemSize {
349        match self {
350            Self::S1 => lc_framework::ElemSize::S1,
351            Self::S2 => lc_framework::ElemSize::S2,
352            Self::S4 => lc_framework::ElemSize::S4,
353            Self::S8 => lc_framework::ElemSize::S8,
354        }
355    }
356}
357
// Variant docs are deliberately omitted; the lint expectation below records
// that.
#[expect(missing_docs)]
// The `*_repr` derives (de)serialise this enum through its `u8` discriminant,
// i.e. as the numbers 4 or 8 rather than as variant names.
#[derive(
    Copy,
    Clone,
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize_repr,
    Deserialize_repr,
    JsonSchema_repr,
)]
/// LC component float element size, in bytes
#[repr(u8)]
pub enum LcFloatSize {
    // Each discriminant equals the element size in bytes.
    S4 = 4,
    S8 = 8,
}
378
379impl LcFloatSize {
380    const fn into_lc(self) -> lc_framework::FloatSize {
381        match self {
382            Self::S4 => lc_framework::FloatSize::S4,
383            Self::S8 => lc_framework::FloatSize::S8,
384        }
385    }
386}
387
// Variant docs are deliberately omitted; the lint expectation below records
// that.
#[expect(missing_docs)]
#[derive(
    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, JsonSchema,
)]
// Naming: a Rust variant `S<bytes>x<length>` is serialised as the string
// "<length>_<bytes>", e.g. `S1x12` <-> "12_1". The Rust doc below follows the
// variant naming, while the explicit schema description follows the
// serialised naming - the differing order is intentional.
/// LC tuple component element size, in bytes x tuple length
#[schemars(description = "LC tuple component element size, in tuple length _ bytes")]
pub enum LcTupleSize {
    // 1-byte elements
    #[serde(rename = "2_1")]
    S1x2,
    #[serde(rename = "3_1")]
    S1x3,
    #[serde(rename = "4_1")]
    S1x4,
    #[serde(rename = "6_1")]
    S1x6,
    #[serde(rename = "8_1")]
    S1x8,
    #[serde(rename = "12_1")]
    S1x12,
    // 2-byte elements
    #[serde(rename = "2_2")]
    S2x2,
    #[serde(rename = "3_2")]
    S2x3,
    #[serde(rename = "4_2")]
    S2x4,
    #[serde(rename = "6_2")]
    S2x6,
    // 4-byte elements
    #[serde(rename = "2_4")]
    S4x2,
    #[serde(rename = "6_4")]
    S4x6,
    // 8-byte elements
    #[serde(rename = "3_8")]
    S8x3,
    #[serde(rename = "6_8")]
    S8x6,
}
424
425impl LcTupleSize {
426    const fn into_lc(self) -> lc_framework::TupleSize {
427        match self {
428            Self::S1x2 => lc_framework::TupleSize::S1x2,
429            Self::S1x3 => lc_framework::TupleSize::S1x3,
430            Self::S1x4 => lc_framework::TupleSize::S1x4,
431            Self::S1x6 => lc_framework::TupleSize::S1x6,
432            Self::S1x8 => lc_framework::TupleSize::S1x8,
433            Self::S1x12 => lc_framework::TupleSize::S1x12,
434            Self::S2x2 => lc_framework::TupleSize::S2x2,
435            Self::S2x3 => lc_framework::TupleSize::S2x3,
436            Self::S2x4 => lc_framework::TupleSize::S2x4,
437            Self::S2x6 => lc_framework::TupleSize::S2x6,
438            Self::S4x2 => lc_framework::TupleSize::S4x2,
439            Self::S4x6 => lc_framework::TupleSize::S4x6,
440            Self::S8x3 => lc_framework::TupleSize::S8x3,
441            Self::S8x6 => lc_framework::TupleSize::S8x6,
442        }
443    }
444}
445
446impl Codec for LcCodec {
447    type Error = LcCodecError;
448
449    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
450        compress(data.view(), &self.preprocessors, &self.components)
451            .map(|bytes| AnyArray::U8(Array1::from_vec(bytes).into_dyn()))
452    }
453
454    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
455        let AnyCowArray::U8(encoded) = encoded else {
456            return Err(LcCodecError::EncodedDataNotBytes {
457                dtype: encoded.dtype(),
458            });
459        };
460
461        if !matches!(encoded.shape(), [_]) {
462            return Err(LcCodecError::EncodedDataNotOneDimensional {
463                shape: encoded.shape().to_vec(),
464            });
465        }
466
467        decompress(
468            &self.preprocessors,
469            &self.components,
470            &AnyCowArray::U8(encoded).as_bytes(),
471        )
472    }
473
474    fn decode_into(
475        &self,
476        encoded: AnyArrayView,
477        decoded: AnyArrayViewMut,
478    ) -> Result<(), Self::Error> {
479        let AnyArrayView::U8(encoded) = encoded else {
480            return Err(LcCodecError::EncodedDataNotBytes {
481                dtype: encoded.dtype(),
482            });
483        };
484
485        if !matches!(encoded.shape(), [_]) {
486            return Err(LcCodecError::EncodedDataNotOneDimensional {
487                shape: encoded.shape().to_vec(),
488            });
489        }
490
491        decompress_into(
492            &self.preprocessors,
493            &self.components,
494            &AnyArrayView::U8(encoded).as_bytes(),
495            decoded,
496        )
497    }
498}
499
500impl StaticCodec for LcCodec {
501    const CODEC_ID: &'static str = "lc.rs";
502
503    type Config<'de> = Self;
504
505    fn from_config(config: Self::Config<'_>) -> Self {
506        config
507    }
508
509    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
510        StaticCodecConfig::from(self)
511    }
512}
513
514#[derive(Debug, Error)]
515/// Errors that may occur when applying the [`LcCodec`].
516pub enum LcCodecError {
517    /// [`LcCodec`] failed to encode the header
518    #[error("Lc failed to encode the header")]
519    HeaderEncodeFailed {
520        /// Opaque source error
521        source: LcHeaderError,
522    },
523    /// [`LcCodec`] failed to encode the encoded data
524    #[error("Lc failed to decode the encoded data")]
525    LcEncodeFailed {
526        /// Opaque source error
527        source: LcCodingError,
528    },
529    /// [`LcCodec`] can only decode one-dimensional byte arrays but received
530    /// an array of a different dtype
531    #[error(
532        "Lc can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
533    )]
534    EncodedDataNotBytes {
535        /// The unexpected dtype of the encoded array
536        dtype: AnyArrayDType,
537    },
538    /// [`LcCodec`] can only decode one-dimensional byte arrays but received
539    /// an array of a different shape
540    #[error(
541        "Lc can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
542    )]
543    EncodedDataNotOneDimensional {
544        /// The unexpected shape of the encoded array
545        shape: Vec<usize>,
546    },
547    /// [`LcCodec`] failed to encode the header
548    #[error("Lc failed to decode the header")]
549    HeaderDecodeFailed {
550        /// Opaque source error
551        source: LcHeaderError,
552    },
553    /// [`LcCodec`] decode produced a different number of bytes than expected
554    #[error("Lc decode produced a different number of bytes than expected")]
555    DecodeDataLengthMismatch,
556    /// [`LcCodec`] failed to decode the encoded data
557    #[error("Lc failed to decode the encoded data")]
558    LcDecodeFailed {
559        /// Opaque source error
560        source: LcCodingError,
561    },
562    /// [`LcCodec`] cannot decode into the provided array
563    #[error("Lc cannot decode into the provided array")]
564    MismatchedDecodeIntoArray {
565        /// The source of the error
566        #[from]
567        source: AnyArrayAssignError,
568    },
569}
570
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding the header fails
///
/// Wraps the underlying [`postcard::Error`] and forwards its message
/// transparently. The inner field is private, so the concrete error type is
/// not part of this crate's public API.
pub struct LcHeaderError(postcard::Error);
575
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding with LC fails
///
/// Wraps the underlying `lc_framework::Error` and forwards its message
/// transparently. The inner field is private, so the concrete error type is
/// not part of this crate's public API.
pub struct LcCodingError(lc_framework::Error);
580
581#[expect(clippy::needless_pass_by_value)]
582/// Compress the `array` using LC with the provided `preprocessors` and
583/// `components`.
584///
585/// # Errors
586///
587/// Errors with
588/// - [`LcCodecError::HeaderEncodeFailed`] if encoding the header to the
589///   output bytevec failed
590/// - [`LcCodecError::LcEncodeFailed`] if an opaque encoding error occurred
591pub fn compress(
592    array: AnyArrayView,
593    preprocessors: &[LcPreprocessor],
594    components: &[LcComponent],
595) -> Result<Vec<u8>, LcCodecError> {
596    let mut encoded = postcard::to_extend(
597        &CompressionHeader {
598            dtype: array.dtype(),
599            shape: Cow::Borrowed(array.shape()),
600            version: StaticCodecVersion,
601        },
602        Vec::new(),
603    )
604    .map_err(|err| LcCodecError::HeaderEncodeFailed {
605        source: LcHeaderError(err),
606    })?;
607
608    // LC does not support empty input, so skip encoding
609    if array.is_empty() {
610        return Ok(encoded);
611    }
612
613    let preprocessors = preprocessors
614        .iter()
615        .cloned()
616        .map(LcPreprocessor::into_lc)
617        .collect::<Vec<_>>();
618    let components = components
619        .iter()
620        .copied()
621        .map(LcComponent::into_lc)
622        .collect::<Vec<_>>();
623
624    encoded.append(
625        &mut lc_framework::compress(&preprocessors, &components, &array.as_bytes()).map_err(
626            |err| LcCodecError::LcEncodeFailed {
627                source: LcCodingError(err),
628            },
629        )?,
630    );
631
632    Ok(encoded)
633}
634
635/// Decompress the `encoded` data into an array using LC.
636///
637/// # Errors
638///
639/// Errors with
640/// - [`LcCodecError::HeaderDecodeFailed`] if decoding the header failed
641/// - [`LcCodecError::DecodeDataLengthMismatch`] if decoding produced a
642///   different number of bytes than expected
643/// - [`LcCodecError::LcDecodeFailed`] if an opaque decoding error occurred
644pub fn decompress(
645    preprocessors: &[LcPreprocessor],
646    components: &[LcComponent],
647    encoded: &[u8],
648) -> Result<AnyArray, LcCodecError> {
649    let (header, encoded) =
650        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
651            LcCodecError::HeaderDecodeFailed {
652                source: LcHeaderError(err),
653            }
654        })?;
655
656    let (decoded, result) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |decoded| {
657        decompress_into_bytes(preprocessors, components, encoded, decoded)
658    });
659
660    result.map(|()| decoded)
661}
662
663/// Decompress the `encoded` data into a `decoded` array using LC.
664///
665/// # Errors
666///
667/// Errors with
668/// - [`LcCodecError::HeaderDecodeFailed`] if decoding the header failed
669/// - [`LcCodecError::MismatchedDecodeIntoArray`] if the `decoded` array is of
670///   the wrong dtype or shape
671/// - [`LcCodecError::HeaderDecodeFailed`] if decoding the header failed
672/// - [`LcCodecError::DecodeDataLengthMismatch`] if decoding produced a
673///   different number of bytes than expected
674/// - [`LcCodecError::LcDecodeFailed`] if an opaque decoding error occurred
675pub fn decompress_into(
676    preprocessors: &[LcPreprocessor],
677    components: &[LcComponent],
678    encoded: &[u8],
679    mut decoded: AnyArrayViewMut,
680) -> Result<(), LcCodecError> {
681    let (header, encoded) =
682        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
683            LcCodecError::HeaderDecodeFailed {
684                source: LcHeaderError(err),
685            }
686        })?;
687
688    if header.dtype != decoded.dtype() {
689        return Err(LcCodecError::MismatchedDecodeIntoArray {
690            source: AnyArrayAssignError::DTypeMismatch {
691                src: header.dtype,
692                dst: decoded.dtype(),
693            },
694        });
695    }
696
697    if header.shape != decoded.shape() {
698        return Err(LcCodecError::MismatchedDecodeIntoArray {
699            source: AnyArrayAssignError::ShapeMismatch {
700                src: header.shape.into_owned(),
701                dst: decoded.shape().to_vec(),
702            },
703        });
704    }
705
706    decoded.with_bytes_mut(|decoded| {
707        decompress_into_bytes(preprocessors, components, encoded, decoded)
708    })
709}
710
711fn decompress_into_bytes(
712    preprocessors: &[LcPreprocessor],
713    components: &[LcComponent],
714    encoded: &[u8],
715    decoded: &mut [u8],
716) -> Result<(), LcCodecError> {
717    // LC does not support empty input, so skip decoding
718    if decoded.is_empty() && encoded.is_empty() {
719        return Ok(());
720    }
721
722    let preprocessors = preprocessors
723        .iter()
724        .cloned()
725        .map(LcPreprocessor::into_lc)
726        .collect::<Vec<_>>();
727    let components = components
728        .iter()
729        .copied()
730        .map(LcComponent::into_lc)
731        .collect::<Vec<_>>();
732
733    let dec = lc_framework::decompress(&preprocessors, &components, encoded).map_err(|err| {
734        LcCodecError::LcDecodeFailed {
735            source: LcCodingError(err),
736        }
737    })?;
738
739    if dec.len() != decoded.len() {
740        return Err(LcCodecError::DecodeDataLengthMismatch);
741    }
742
743    decoded.copy_from_slice(&dec);
744
745    Ok(())
746}
747
/// Header that `compress` writes (via `postcard`) in front of the compressed
/// bytes and that `decompress` / `decompress_into` read back to recover the
/// array's dtype and shape.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    /// Element type of the original array
    dtype: AnyArrayDType,
    /// Shape of the original array; `compress` borrows it from the input array
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    /// Encoding format version of the codec that produced the data
    version: LcCodecVersion,
}
755
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// 10x10 grid holding the cosine of 100 evenly spaced values in `[0, pi]`.
    fn cosine_grid() -> ndarray::Array2<f32> {
        let samples: Array1<f32> = ndarray::linspace(0.0, std::f32::consts::PI, 100).collect();

        samples.into_shape_with_order((10, 10)).unwrap().cos()
    }

    /// Component pipeline shared by the tests: bit shuffle followed by
    /// run-length encoding, both on 4-byte elements.
    fn shuffle_then_rle() -> [LcComponent; 2] {
        [
            LcComponent::BitShuffle {
                size: LcElemSize::S4,
            },
            LcComponent::RunLengthEncoding {
                size: LcElemSize::S4,
            },
        ]
    }

    /// Compresses `data` and immediately decompresses it again.
    fn roundtrip(
        data: &ndarray::Array2<f32>,
        preprocessors: &[LcPreprocessor],
        components: &[LcComponent],
    ) -> AnyArray {
        let compressed = compress(
            AnyArrayView::F32(data.view().into_dyn()),
            preprocessors,
            components,
        )
        .unwrap();

        decompress(preprocessors, components, &compressed).unwrap()
    }

    #[test]
    fn lossless() {
        let original = cosine_grid();

        // without preprocessors, the roundtrip must be bit-exact
        let restored = roundtrip(&original, &[], &shuffle_then_rle());

        assert_eq!(restored, AnyArray::F32(original.into_dyn()));
    }

    #[test]
    fn abs_error() {
        let original = cosine_grid();

        let quantize = [LcPreprocessor::QuantizeErrorBound {
            dtype: LcQuantizeDType::F32,
            kind: LcErrorKind::Abs,
            error_bound: 0.1,
            threshold: None,
            decorrelation: LcDecorrelation::Zero,
        }];

        let restored = roundtrip(&original, &quantize, &shuffle_then_rle());

        let AnyArray::F32(restored) = restored else {
            panic!("unexpected decompressed dtype {}", restored.dtype());
        };
        assert_eq!(restored.shape(), original.shape());

        // every element must stay within the absolute error bound
        for (expected, actual) in original.into_iter().zip(restored) {
            assert!((expected - actual).abs() <= 0.1);
        }
    }
}
831}