numcodecs_linear_quantize/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-linear-quantize
10//! [crates.io]: https://crates.io/crates/numcodecs-linear-quantize
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-linear-quantize
13//! [docs.rs]: https://docs.rs/numcodecs-linear-quantize/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_linear_quantize
17//!
18//! Linear Quantization codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // FIXME: twofloat -> hexf -> syn 1.0
21
22use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, ArrayD, ArrayViewMutD, Data, Dimension, ShapeError, Zip};
25use num_traits::{ConstOne, ConstZero, Float};
26use numcodecs::{
27    AnyArray, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, StaticCodec,
28    StaticCodecConfig,
29};
30use schemars::{JsonSchema, JsonSchema_repr};
31use serde::{de::DeserializeOwned, Deserialize, Serialize};
32use serde_repr::{Deserialize_repr, Serialize_repr};
33use thiserror::Error;
34use twofloat::TwoFloat;
35
36#[derive(Clone, Serialize, Deserialize, JsonSchema)]
37#[serde(deny_unknown_fields)]
38/// Lossy codec to reduce the precision of floating point data.
39///
40/// The data is quantized to unsigned integers of the best-fitting type.
41/// The range and shape of the input data is stored in-band.
42pub struct LinearQuantizeCodec {
43    /// Dtype of the decoded data
44    pub dtype: LinearQuantizeDType,
45    /// Binary precision of the encoded data where `$bits = \log_{2}(bins)$`
46    pub bits: LinearQuantizeBins,
47}
48
49/// Data types which the [`LinearQuantizeCodec`] can quantize
50#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema)]
51#[schemars(extend("enum" = ["f32", "float32", "f64", "float64"]))]
52#[allow(missing_docs)]
53pub enum LinearQuantizeDType {
54    #[serde(rename = "f32", alias = "float32")]
55    F32,
56    #[serde(rename = "f64", alias = "float64")]
57    F64,
58}
59
60impl fmt::Display for LinearQuantizeDType {
61    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
62        fmt.write_str(match self {
63            Self::F32 => "f32",
64            Self::F64 => "f64",
65        })
66    }
67}
68
69/// Number of bins for quantization, written in base-2 scientific notation.
70///
71/// The binary `#[repr(u8)]` value of each variant is equivalent to the binary
72/// logarithm of the number of bins, i.e. the binary precision or the number of
73/// bits used.
74#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, JsonSchema_repr)]
75#[repr(u8)]
76#[rustfmt::skip]
77#[allow(missing_docs)]
78pub enum LinearQuantizeBins {
79    _1B1 = 1, _1B2, _1B3, _1B4, _1B5, _1B6, _1B7, _1B8,
80    _1B9, _1B10, _1B11, _1B12, _1B13, _1B14, _1B15, _1B16,
81    _1B17, _1B18, _1B19, _1B20, _1B21, _1B22, _1B23, _1B24,
82    _1B25, _1B26, _1B27, _1B28, _1B29, _1B30, _1B31, _1B32,
83    _1B33, _1B34, _1B35, _1B36, _1B37, _1B38, _1B39, _1B40,
84    _1B41, _1B42, _1B43, _1B44, _1B45, _1B46, _1B47, _1B48,
85    _1B49, _1B50, _1B51, _1B52, _1B53, _1B54, _1B55, _1B56,
86    _1B57, _1B58, _1B59, _1B60, _1B61, _1B62, _1B63, _1B64,
87}
88
89impl Codec for LinearQuantizeCodec {
90    type Error = LinearQuantizeCodecError;
91
92    #[allow(clippy::too_many_lines)]
93    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
94        let encoded = match (&data, self.dtype) {
95            (AnyCowArray::F32(data), LinearQuantizeDType::F32) => match self.bits as u8 {
96                bits @ ..=8 => AnyArray::U8(
97                    Array1::from_vec(quantize(data, |x| {
98                        let max = f32::from(u8::MAX >> (8 - bits));
99                        let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
100                        #[allow(unsafe_code)]
101                        // Safety: x is clamped beforehand
102                        unsafe {
103                            x.to_int_unchecked::<u8>()
104                        }
105                    })?)
106                    .into_dyn(),
107                ),
108                bits @ 9..=16 => AnyArray::U16(
109                    Array1::from_vec(quantize(data, |x| {
110                        let max = f32::from(u16::MAX >> (16 - bits));
111                        let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
112                        #[allow(unsafe_code)]
113                        // Safety: x is clamped beforehand
114                        unsafe {
115                            x.to_int_unchecked::<u16>()
116                        }
117                    })?)
118                    .into_dyn(),
119                ),
120                bits @ 17..=32 => AnyArray::U32(
121                    Array1::from_vec(quantize(data, |x| {
122                        // we need to use f64 here to have sufficient precision
123                        let max = f64::from(u32::MAX >> (32 - bits));
124                        let x = f64::from(x)
125                            .mul_add(scale_for_bits::<f64>(bits), 0.5)
126                            .clamp(0.0, max);
127                        #[allow(unsafe_code)]
128                        // Safety: x is clamped beforehand
129                        unsafe {
130                            x.to_int_unchecked::<u32>()
131                        }
132                    })?)
133                    .into_dyn(),
134                ),
135                bits @ 33.. => AnyArray::U64(
136                    Array1::from_vec(quantize(data, |x| {
137                        // we need to use TwoFloat here to have sufficient precision
138                        let max = TwoFloat::from(u64::MAX >> (64 - bits));
139                        let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
140                            + TwoFloat::from(0.5))
141                        .max(TwoFloat::from(0.0))
142                        .min(max);
143                        #[allow(unsafe_code)]
144                        // Safety: x is clamped beforehand
145                        unsafe {
146                            u64::try_from(x).unwrap_unchecked()
147                        }
148                    })?)
149                    .into_dyn(),
150                ),
151            },
152            (AnyCowArray::F64(data), LinearQuantizeDType::F64) => match self.bits as u8 {
153                bits @ ..=8 => AnyArray::U8(
154                    Array1::from_vec(quantize(data, |x| {
155                        let max = f64::from(u8::MAX >> (8 - bits));
156                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
157                        #[allow(unsafe_code)]
158                        // Safety: x is clamped beforehand
159                        unsafe {
160                            x.to_int_unchecked::<u8>()
161                        }
162                    })?)
163                    .into_dyn(),
164                ),
165                bits @ 9..=16 => AnyArray::U16(
166                    Array1::from_vec(quantize(data, |x| {
167                        let max = f64::from(u16::MAX >> (16 - bits));
168                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
169                        #[allow(unsafe_code)]
170                        // Safety: x is clamped beforehand
171                        unsafe {
172                            x.to_int_unchecked::<u16>()
173                        }
174                    })?)
175                    .into_dyn(),
176                ),
177                bits @ 17..=32 => AnyArray::U32(
178                    Array1::from_vec(quantize(data, |x| {
179                        let max = f64::from(u32::MAX >> (32 - bits));
180                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
181                        #[allow(unsafe_code)]
182                        // Safety: x is clamped beforehand
183                        unsafe {
184                            x.to_int_unchecked::<u32>()
185                        }
186                    })?)
187                    .into_dyn(),
188                ),
189                bits @ 33.. => AnyArray::U64(
190                    Array1::from_vec(quantize(data, |x| {
191                        // we need to use TwoFloat here to have sufficient precision
192                        let max = TwoFloat::from(u64::MAX >> (64 - bits));
193                        let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
194                            + TwoFloat::from(0.5))
195                        .max(TwoFloat::from(0.0))
196                        .min(max);
197                        #[allow(unsafe_code)]
198                        // Safety: x is clamped beforehand
199                        unsafe {
200                            u64::try_from(x).unwrap_unchecked()
201                        }
202                    })?)
203                    .into_dyn(),
204                ),
205            },
206            (data, dtype) => {
207                return Err(LinearQuantizeCodecError::MismatchedEncodeDType {
208                    configured: dtype,
209                    provided: data.dtype(),
210                });
211            }
212        };
213
214        Ok(encoded)
215    }
216
217    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
218        #[allow(clippy::option_if_let_else)]
219        fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
220            array: &ArrayBase<S, D>,
221        ) -> Cow<[T]> {
222            if let Some(data) = array.as_slice() {
223                Cow::Borrowed(data)
224            } else {
225                Cow::Owned(array.iter().copied().collect())
226            }
227        }
228
229        if !matches!(encoded.shape(), [_]) {
230            return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
231                shape: encoded.shape().to_vec(),
232            });
233        }
234
235        let decoded = match (&encoded, self.dtype) {
236            (AnyCowArray::U8(encoded), LinearQuantizeDType::F32) => {
237                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
238                    f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
239                })?)
240            }
241            (AnyCowArray::U16(encoded), LinearQuantizeDType::F32) => {
242                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
243                    f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
244                })?)
245            }
246            (AnyCowArray::U32(encoded), LinearQuantizeDType::F32) => {
247                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
248                    // we need to use f64 here to have sufficient precision
249                    let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
250                    #[allow(clippy::cast_possible_truncation)]
251                    let x = x as f32;
252                    x
253                })?)
254            }
255            (AnyCowArray::U64(encoded), LinearQuantizeDType::F32) => {
256                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
257                    // we need to use TwoFloat here to have sufficient precision
258                    let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
259                    f32::from(x)
260                })?)
261            }
262            (AnyCowArray::U8(encoded), LinearQuantizeDType::F64) => {
263                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
264                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
265                })?)
266            }
267            (AnyCowArray::U16(encoded), LinearQuantizeDType::F64) => {
268                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
269                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
270                })?)
271            }
272            (AnyCowArray::U32(encoded), LinearQuantizeDType::F64) => {
273                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
274                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
275                })?)
276            }
277            (AnyCowArray::U64(encoded), LinearQuantizeDType::F64) => {
278                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
279                    // we need to use TwoFloat here to have sufficient precision
280                    let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
281                    f64::from(x)
282                })?)
283            }
284            (encoded, _dtype) => {
285                return Err(LinearQuantizeCodecError::InvalidEncodedDType {
286                    dtype: encoded.dtype(),
287                })
288            }
289        };
290
291        Ok(decoded)
292    }
293
294    fn decode_into(
295        &self,
296        encoded: AnyArrayView,
297        decoded: AnyArrayViewMut,
298    ) -> Result<(), Self::Error> {
299        fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
300            array: &ArrayBase<S, D>,
301        ) -> Cow<[T]> {
302            #[allow(clippy::option_if_let_else)]
303            if let Some(data) = array.as_slice() {
304                Cow::Borrowed(data)
305            } else {
306                Cow::Owned(array.iter().copied().collect())
307            }
308        }
309
310        if !matches!(encoded.shape(), [_]) {
311            return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
312                shape: encoded.shape().to_vec(),
313            });
314        }
315
316        match (decoded, self.dtype) {
317            (AnyArrayViewMut::F32(decoded), LinearQuantizeDType::F32) => {
318                match &encoded {
319                    AnyArrayView::U8(encoded) => {
320                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
321                            f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
322                        })
323                    }
324                    AnyArrayView::U16(encoded) => {
325                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
326                            f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
327                        })
328                    }
329                    AnyArrayView::U32(encoded) => {
330                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
331                            // we need to use f64 here to have sufficient precision
332                            let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
333                            #[allow(clippy::cast_possible_truncation)]
334                            let x = x as f32;
335                            x
336                        })
337                    }
338                    AnyArrayView::U64(encoded) => {
339                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
340                            // we need to use TwoFloat here to have sufficient precision
341                            let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
342                            f32::from(x)
343                        })
344                    }
345                    encoded => {
346                        return Err(LinearQuantizeCodecError::InvalidEncodedDType {
347                            dtype: encoded.dtype(),
348                        })
349                    }
350                }
351            }
352            (AnyArrayViewMut::F64(decoded), LinearQuantizeDType::F64) => {
353                match &encoded {
354                    AnyArrayView::U8(encoded) => {
355                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
356                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
357                        })
358                    }
359                    AnyArrayView::U16(encoded) => {
360                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
361                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
362                        })
363                    }
364                    AnyArrayView::U32(encoded) => {
365                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
366                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
367                        })
368                    }
369                    AnyArrayView::U64(encoded) => {
370                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
371                            // we need to use TwoFloat here to have sufficient precision
372                            let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
373                            f64::from(x)
374                        })
375                    }
376                    encoded => {
377                        return Err(LinearQuantizeCodecError::InvalidEncodedDType {
378                            dtype: encoded.dtype(),
379                        })
380                    }
381                }
382            }
383            (decoded, dtype) => {
384                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
385                    configured: dtype,
386                    provided: decoded.dtype(),
387                })
388            }
389        }?;
390
391        Ok(())
392    }
393}
394
395impl StaticCodec for LinearQuantizeCodec {
396    const CODEC_ID: &'static str = "linear-quantize";
397
398    type Config<'de> = Self;
399
400    fn from_config(config: Self::Config<'_>) -> Self {
401        config
402    }
403
404    fn get_config(&self) -> StaticCodecConfig<Self> {
405        StaticCodecConfig::from(self)
406    }
407}
408
409#[derive(Debug, Error)]
410/// Errors that may occur when applying the [`LinearQuantizeCodec`].
411pub enum LinearQuantizeCodecError {
412    /// [`LinearQuantizeCodec`] cannot encode the provided dtype which differs
413    /// from the configured dtype
414    #[error("LinearQuantize cannot encode the provided dtype {provided} which differs from the configured dtype {configured}")]
415    MismatchedEncodeDType {
416        /// Dtype of the `configured` `dtype`
417        configured: LinearQuantizeDType,
418        /// Dtype of the `provided` array from which the data is to be encoded
419        provided: AnyArrayDType,
420    },
421    /// [`LinearQuantizeCodec`] does not support non-finite (infinite or NaN) floating
422    /// point data
423    #[error("LinearQuantize does not support non-finite (infinite or NaN) floating point data")]
424    NonFiniteData,
425    /// [`LinearQuantizeCodec`] failed to encode the header
426    #[error("LinearQuantize failed to encode the header")]
427    HeaderEncodeFailed {
428        /// Opaque source error
429        source: LinearQuantizeHeaderError,
430    },
431    /// [`LinearQuantizeCodec`] can only decode one-dimensional arrays but
432    /// received an array of a different shape
433    #[error("LinearQuantize can only decode one-dimensional arrays but received an array of shape {shape:?}")]
434    EncodedDataNotOneDimensional {
435        /// The unexpected shape of the encoded array
436        shape: Vec<usize>,
437    },
438    /// [`LinearQuantizeCodec`] failed to decode the header
439    #[error("LinearQuantize failed to decode the header")]
440    HeaderDecodeFailed {
441        /// Opaque source error
442        source: LinearQuantizeHeaderError,
443    },
444    /// [`LinearQuantizeCodec`] decoded an invalid array shape header which does
445    /// not fit the decoded data
446    #[error(
447        "LinearQuantize decoded an invalid array shape header which does not fit the decoded data"
448    )]
449    DecodeInvalidShapeHeader {
450        /// Source error
451        #[from]
452        source: ShapeError,
453    },
454    /// [`LinearQuantizeCodec`] cannot decode the provided dtype
455    #[error("LinearQuantize cannot decode the provided dtype {dtype}")]
456    InvalidEncodedDType {
457        /// Dtype of the provided array from which the data is to be decoded
458        dtype: AnyArrayDType,
459    },
460    /// [`LinearQuantizeCodec`] cannot decode the provided dtype which differs
461    /// from the configured dtype
462    #[error("LinearQuantize cannot decode the provided dtype {provided} which differs from the configured dtype {configured}")]
463    MismatchedDecodeIntoDtype {
464        /// Dtype of the `configured` `dtype`
465        configured: LinearQuantizeDType,
466        /// Dtype of the `provided` array into which the data is to be decoded
467        provided: AnyArrayDType,
468    },
469    /// [`LinearQuantizeCodec`] cannot decode the decoded array into the provided
470    /// array of a different shape
471    #[error("LinearQuantize cannot decode the decoded array of shape {decoded:?} into the provided array of shape {provided:?}")]
472    MismatchedDecodeIntoShape {
473        /// Shape of the `decoded` data
474        decoded: Vec<usize>,
475        /// Shape of the `provided` array into which the data is to be decoded
476        provided: Vec<usize>,
477    },
478}
479
480#[derive(Debug, Error)]
481#[error(transparent)]
482/// Opaque error for when encoding or decoding the header fails
483pub struct LinearQuantizeHeaderError(postcard::Error);
484
485/// Linear-quantize the elements in the `data` array using the `quantize`
486/// closure.
487///
488/// # Errors
489///
490/// Errors with
491/// - [`LinearQuantizeCodecError::NonFiniteData`] if any data element is non-
492///   finite (infinite or NaN)
493/// - [`LinearQuantizeCodecError::HeaderEncodeFailed`] if encoding the header
494///   failed
495pub fn quantize<
496    T: Float + ConstZero + ConstOne + Serialize,
497    Q: Unsigned,
498    S: Data<Elem = T>,
499    D: Dimension,
500>(
501    data: &ArrayBase<S, D>,
502    quantize: impl Fn(T) -> Q,
503) -> Result<Vec<Q>, LinearQuantizeCodecError> {
504    if !Zip::from(data).all(|x| x.is_finite()) {
505        return Err(LinearQuantizeCodecError::NonFiniteData);
506    }
507
508    let (minimum, maximum) = data.first().map_or((T::ZERO, T::ONE), |first| {
509        (
510            Zip::from(data).fold(*first, |a, b| a.min(*b)),
511            Zip::from(data).fold(*first, |a, b| a.max(*b)),
512        )
513    });
514
515    let header = postcard::to_extend(
516        &CompressionHeader {
517            shape: Cow::Borrowed(data.shape()),
518            minimum,
519            maximum,
520        },
521        Vec::new(),
522    )
523    .map_err(|err| LinearQuantizeCodecError::HeaderEncodeFailed {
524        source: LinearQuantizeHeaderError(err),
525    })?;
526
527    let mut encoded: Vec<Q> = vec![Q::ZERO; header.len().div_ceil(std::mem::size_of::<Q>())];
528    #[allow(unsafe_code)]
529    // Safety: encoded is at least header.len() bytes long and properly aligned for Q
530    unsafe {
531        std::ptr::copy_nonoverlapping(header.as_ptr(), encoded.as_mut_ptr().cast(), header.len());
532    }
533    encoded.reserve(data.len());
534
535    if maximum == minimum {
536        encoded.resize(encoded.len() + data.len(), quantize(T::ZERO));
537    } else {
538        encoded.extend(
539            data.iter()
540                .map(|x| quantize((*x - minimum) / (maximum - minimum))),
541        );
542    }
543
544    Ok(encoded)
545}
546
547/// Reconstruct the linear-quantized `encoded` array using the `floatify`
548/// closure.
549///
550/// # Errors
551///
552/// Errors with
553/// - [`LinearQuantizeCodecError::HeaderDecodeFailed`] if decoding the header
554///   failed
555pub fn reconstruct<T: Float + DeserializeOwned, Q: Unsigned>(
556    encoded: &[Q],
557    floatify: impl Fn(Q) -> T,
558) -> Result<ArrayD<T>, LinearQuantizeCodecError> {
559    #[allow(unsafe_code)]
560    // Safety: data is data.len()*size_of::<Q> bytes long and properly aligned for Q
561    let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
562        std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
563    })
564    .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
565        source: LinearQuantizeHeaderError(err),
566    })?;
567
568    let encoded = encoded
569        .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
570        .unwrap_or(&[]);
571
572    let decoded = encoded
573        .iter()
574        .map(|x| header.minimum + (floatify(*x) * (header.maximum - header.minimum)))
575        .map(|x| x.clamp(header.minimum, header.maximum))
576        .collect();
577
578    let decoded = Array::from_shape_vec(&*header.shape, decoded)?;
579
580    Ok(decoded)
581}
582
583#[allow(clippy::needless_pass_by_value)]
584/// Reconstruct the linear-quantized `encoded` array using the `floatify`
585/// closure into the `decoded` array.
586///
587/// # Errors
588///
589/// Errors with
590/// - [`LinearQuantizeCodecError::HeaderDecodeFailed`] if decoding the header
591///   failed
592/// - [`LinearQuantizeCodecError::MismatchedDecodeIntoShape`] if the `decoded`
593///   array is of the wrong shape
594pub fn reconstruct_into<T: Float + DeserializeOwned, Q: Unsigned>(
595    encoded: &[Q],
596    mut decoded: ArrayViewMutD<T>,
597    floatify: impl Fn(Q) -> T,
598) -> Result<(), LinearQuantizeCodecError> {
599    #[allow(unsafe_code)]
600    // Safety: data is data.len()*size_of::<Q> bytes long and properly aligned for Q
601    let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
602        std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
603    })
604    .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
605        source: LinearQuantizeHeaderError(err),
606    })?;
607
608    let encoded = encoded
609        .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
610        .unwrap_or(&[]);
611
612    if decoded.shape() != &*header.shape {
613        return Err(LinearQuantizeCodecError::MismatchedDecodeIntoShape {
614            decoded: header.shape.into_owned(),
615            provided: decoded.shape().to_vec(),
616        });
617    }
618
619    // iteration must occur in synchronised (standard) order
620    for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
621        *d = (header.minimum + (floatify(*e) * (header.maximum - header.minimum)))
622            .clamp(header.minimum, header.maximum);
623    }
624
625    Ok(())
626}
627
628/// Returns `${2.0}^{bits} - 1.0$`
629fn scale_for_bits<T: Float + From<u8> + ConstOne>(bits: u8) -> T {
630    <T as From<u8>>::from(bits).exp2() - T::ONE
631}
632
/// Unsigned binary types.
pub trait Unsigned: Copy {
    /// `0x0`
    const ZERO: Self;
}

/// Implements [`Unsigned`] for each listed primitive integer type.
macro_rules! impl_unsigned {
    ($($ty:ty),* $(,)?) => {
        $(
            impl Unsigned for $ty {
                const ZERO: Self = 0;
            }
        )*
    };
}

impl_unsigned!(u8, u16, u32, u64);
654
655#[derive(Serialize, Deserialize)]
656struct CompressionHeader<'a, T> {
657    #[serde(borrow)]
658    shape: Cow<'a, [usize]>,
659    minimum: T,
660    maximum: T,
661}
662
#[cfg(test)]
mod tests {
    use ndarray::CowArray;

    use super::*;

    /// Builds a codec for `dtype` that quantizes to `bits` binary digits.
    fn codec_with_bits(dtype: LinearQuantizeDType, bits: u8) -> LinearQuantizeCodec {
        LinearQuantizeCodec {
            dtype,
            #[allow(unsafe_code)]
            // Safety: every test below passes `bits` from within 1..=64, which
            //         are exactly the discriminants of `LinearQuantizeBins`
            bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
        }
    }

    /// Encodes and decodes `data` with an `f32` codec of the given `bits` and
    /// asserts that every element is reproduced bit-for-bit.
    fn assert_exact_roundtrip_f32(bits: u8, data: &[f32]) -> Result<(), LinearQuantizeCodecError> {
        let codec = codec_with_bits(LinearQuantizeDType::F32, bits);

        let encoded = codec.encode(AnyCowArray::F32(CowArray::from(data).into_dyn()))?;

        let decoded = match codec.decode(encoded.cow())? {
            AnyArray::F32(decoded) => decoded,
            other => {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F32,
                    provided: other.dtype(),
                })
            }
        };

        for (original, roundtripped) in data.iter().zip(decoded.iter()) {
            assert_eq!(original.to_bits(), roundtripped.to_bits());
        }

        Ok(())
    }

    /// Encodes and decodes `data` with an `f64` codec of the given `bits` and
    /// asserts that every element is reproduced bit-for-bit.
    fn assert_exact_roundtrip_f64(bits: u8, data: &[f64]) -> Result<(), LinearQuantizeCodecError> {
        let codec = codec_with_bits(LinearQuantizeDType::F64, bits);

        let encoded = codec.encode(AnyCowArray::F64(CowArray::from(data).into_dyn()))?;

        let decoded = match codec.decode(encoded.cow())? {
            AnyArray::F64(decoded) => decoded,
            other => {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F64,
                    provided: other.dtype(),
                })
            }
        };

        for (original, roundtripped) in data.iter().zip(decoded.iter()) {
            assert_eq!(original.to_bits(), roundtripped.to_bits());
        }

        Ok(())
    }

    #[test]
    fn exact_roundtrip_f32_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=16_u8 {
            // largest value representable with `bits` bits
            let top = u16::MAX >> (16 - bits);

            // at most 256 evenly spaced samples, always ending with `top`
            let data: Vec<f32> = (0..top)
                .step_by(1 << (bits.max(8) - 8))
                .chain(std::iter::once(top))
                .map(f32::from)
                .collect();

            assert_exact_roundtrip_f32(bits, &data)?;
        }

        Ok(())
    }

    #[test]
    fn exact_roundtrip_f32_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64_u8 {
            // largest value representable with `bits` bits
            let top = u64::MAX >> (64 - bits);

            // at most 256 evenly spaced samples, always ending with `top`
            #[allow(clippy::cast_precision_loss)]
            let data: Vec<f32> = (0..top)
                .step_by(1 << (bits.max(8) - 8))
                .chain(std::iter::once(top))
                .map(|x| x as f32)
                .collect();

            assert_exact_roundtrip_f32(bits, &data)?;
        }

        Ok(())
    }

    #[test]
    fn exact_roundtrip_f64_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=32_u8 {
            // largest value representable with `bits` bits
            let top = u32::MAX >> (32 - bits);

            // at most 256 evenly spaced samples, always ending with `top`
            let data: Vec<f64> = (0..top)
                .step_by(1 << (bits.max(8) - 8))
                .chain(std::iter::once(top))
                .map(f64::from)
                .collect();

            assert_exact_roundtrip_f64(bits, &data)?;
        }

        Ok(())
    }

    #[test]
    fn exact_roundtrip_f64_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64_u8 {
            // largest value representable with `bits` bits
            let top = u64::MAX >> (64 - bits);

            // at most 256 evenly spaced samples, always ending with `top`
            #[allow(clippy::cast_precision_loss)]
            let data: Vec<f64> = (0..top)
                .step_by(1 << (bits.max(8) - 8))
                .chain(std::iter::once(top))
                .map(|x| x as f64)
                .collect();

            assert_exact_roundtrip_f64(bits, &data)?;
        }

        Ok(())
    }

    #[test]
    fn const_data_roundtrip() -> Result<(), LinearQuantizeCodecError> {
        // constant data exercises the `maximum == minimum` encoding path
        let data = [42.0, 42.0, 42.0, 42.0];

        for bits in 1..=64_u8 {
            assert_exact_roundtrip_f64(bits, &data)?;
        }

        Ok(())
    }
}