numcodecs_reinterpret/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.76.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-reinterpret
10//! [crates.io]: https://crates.io/crates/numcodecs-reinterpret
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-reinterpret
13//! [docs.rs]: https://docs.rs/numcodecs-reinterpret/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_reinterpret
17//!
18//! Binary reinterpret codec implementation for the [`numcodecs`] API.
19
20use ndarray::{Array, ArrayBase, ArrayView, Data, DataMut, Dimension, ViewRepr};
21use numcodecs::{
22    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
23    ArrayDType, Codec, StaticCodec, StaticCodecConfig,
24};
25use schemars::JsonSchema;
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27use thiserror::Error;
28
29#[derive(Clone, JsonSchema)]
30#[serde(deny_unknown_fields)]
31/// Codec to reinterpret data between different compatible types.
32///
33/// Note that no conversion happens, only the meaning of the bits changes.
34///
35/// Reinterpreting to bytes, or to a same-sized unsigned integer type, or
36/// without the changing the dtype are supported.
37pub struct ReinterpretCodec {
38    /// Dtype of the encoded data.
39    encode_dtype: AnyArrayDType,
40    /// Dtype of the decoded data
41    decode_dtype: AnyArrayDType,
42}
43
44impl ReinterpretCodec {
45    /// Try to create a [`ReinterpretCodec`] that reinterprets the input data
46    /// from `decode_dtype` to `encode_dtype` on encoding, and from
47    /// `encode_dtype` back to `decode_dtype` on decoding.
48    ///
49    /// # Errors
50    ///
51    /// Errors with [`ReinterpretCodecError::InvalidReinterpret`] if
52    /// `encode_dtype` and `decode_dtype` are incompatible.
53    pub fn try_new(
54        encode_dtype: AnyArrayDType,
55        decode_dtype: AnyArrayDType,
56    ) -> Result<Self, ReinterpretCodecError> {
57        #[allow(clippy::match_same_arms)]
58        match (decode_dtype, encode_dtype) {
59            // performing no conversion always works
60            (ty_a, ty_b) if ty_a == ty_b => (),
61            // converting to bytes always works
62            (_, AnyArrayDType::U8) => (),
63            // converting from signed / floating to same-size binary always works
64            (AnyArrayDType::I16, AnyArrayDType::U16)
65            | (AnyArrayDType::I32 | AnyArrayDType::F32, AnyArrayDType::U32)
66            | (AnyArrayDType::I64 | AnyArrayDType::F64, AnyArrayDType::U64) => (),
67            (decode_dtype, encode_dtype) => {
68                return Err(ReinterpretCodecError::InvalidReinterpret {
69                    decode_dtype,
70                    encode_dtype,
71                })
72            }
73        };
74
75        Ok(Self {
76            encode_dtype,
77            decode_dtype,
78        })
79    }
80
81    #[must_use]
82    /// Create a [`ReinterpretCodec`] that does not change the `dtype`.
83    pub const fn passthrough(dtype: AnyArrayDType) -> Self {
84        Self {
85            encode_dtype: dtype,
86            decode_dtype: dtype,
87        }
88    }
89
90    #[must_use]
91    /// Create a [`ReinterpretCodec`] that reinterprets `dtype` as
92    /// [bytes][`AnyArrayDType::U8`].
93    pub const fn to_bytes(dtype: AnyArrayDType) -> Self {
94        Self {
95            encode_dtype: AnyArrayDType::U8,
96            decode_dtype: dtype,
97        }
98    }
99
100    #[must_use]
101    /// Create a  [`ReinterpretCodec`] that reinterprets `dtype` as its
102    /// [binary][`AnyArrayDType::to_binary`] equivalent.
103    pub const fn to_binary(dtype: AnyArrayDType) -> Self {
104        Self {
105            encode_dtype: dtype.to_binary(),
106            decode_dtype: dtype,
107        }
108    }
109}
110
111impl Codec for ReinterpretCodec {
112    type Error = ReinterpretCodecError;
113
114    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
115        if data.dtype() != self.decode_dtype {
116            return Err(ReinterpretCodecError::MismatchedEncodeDType {
117                configured: self.decode_dtype,
118                provided: data.dtype(),
119            });
120        }
121
122        let encoded = match (data, self.encode_dtype) {
123            (data, dtype) if data.dtype() == dtype => data.into_owned(),
124            (data, AnyArrayDType::U8) => {
125                let mut shape = data.shape().to_vec();
126                if let Some(last) = shape.last_mut() {
127                    *last *= data.dtype().size();
128                }
129                #[allow(unsafe_code)]
130                // Safety: the shape is extended to match the expansion into bytes
131                let encoded =
132                    unsafe { Array::from_shape_vec_unchecked(shape, data.as_bytes().into_owned()) };
133                AnyArray::U8(encoded)
134            }
135            (AnyCowArray::I16(data), AnyArrayDType::U16) => {
136                AnyArray::U16(reinterpret_array(data, |x| {
137                    u16::from_ne_bytes(x.to_ne_bytes())
138                }))
139            }
140            (AnyCowArray::I32(data), AnyArrayDType::U32) => {
141                AnyArray::U32(reinterpret_array(data, |x| {
142                    u32::from_ne_bytes(x.to_ne_bytes())
143                }))
144            }
145            (AnyCowArray::F32(data), AnyArrayDType::U32) => {
146                AnyArray::U32(reinterpret_array(data, f32::to_bits))
147            }
148            (AnyCowArray::I64(data), AnyArrayDType::U64) => {
149                AnyArray::U64(reinterpret_array(data, |x| {
150                    u64::from_ne_bytes(x.to_ne_bytes())
151                }))
152            }
153            (AnyCowArray::F64(data), AnyArrayDType::U64) => {
154                AnyArray::U64(reinterpret_array(data, f64::to_bits))
155            }
156            (data, dtype) => {
157                return Err(ReinterpretCodecError::InvalidReinterpret {
158                    decode_dtype: data.dtype(),
159                    encode_dtype: dtype,
160                });
161            }
162        };
163
164        Ok(encoded)
165    }
166
167    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
168        if encoded.dtype() != self.encode_dtype {
169            return Err(ReinterpretCodecError::MismatchedDecodeDType {
170                configured: self.encode_dtype,
171                provided: encoded.dtype(),
172            });
173        }
174
175        let decoded = match (encoded, self.decode_dtype) {
176            (encoded, dtype) if encoded.dtype() == dtype => encoded.into_owned(),
177            (AnyCowArray::U8(encoded), dtype) => {
178                let mut shape = encoded.shape().to_vec();
179
180                if (encoded.len() % dtype.size()) != 0 {
181                    return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
182                }
183
184                if let Some(last) = shape.last_mut() {
185                    *last /= dtype.size();
186                }
187
188                let (decoded, ()) = AnyArray::with_zeros_bytes(dtype, &shape, |bytes| {
189                    bytes.copy_from_slice(&AnyCowArray::U8(encoded).as_bytes());
190                });
191
192                decoded
193            }
194            (AnyCowArray::U16(encoded), AnyArrayDType::I16) => {
195                AnyArray::I16(reinterpret_array(encoded, |x| {
196                    i16::from_ne_bytes(x.to_ne_bytes())
197                }))
198            }
199            (AnyCowArray::U32(encoded), AnyArrayDType::I32) => {
200                AnyArray::I32(reinterpret_array(encoded, |x| {
201                    i32::from_ne_bytes(x.to_ne_bytes())
202                }))
203            }
204            (AnyCowArray::U32(encoded), AnyArrayDType::F32) => {
205                AnyArray::F32(reinterpret_array(encoded, f32::from_bits))
206            }
207            (AnyCowArray::U64(encoded), AnyArrayDType::U64) => {
208                AnyArray::I64(reinterpret_array(encoded, |x| {
209                    i64::from_ne_bytes(x.to_ne_bytes())
210                }))
211            }
212            (AnyCowArray::U64(encoded), AnyArrayDType::F64) => {
213                AnyArray::F64(reinterpret_array(encoded, f64::from_bits))
214            }
215            (encoded, dtype) => {
216                return Err(ReinterpretCodecError::InvalidReinterpret {
217                    decode_dtype: dtype,
218                    encode_dtype: encoded.dtype(),
219                });
220            }
221        };
222
223        Ok(decoded)
224    }
225
226    #[allow(clippy::too_many_lines)]
227    fn decode_into(
228        &self,
229        encoded: AnyArrayView,
230        mut decoded: AnyArrayViewMut,
231    ) -> Result<(), Self::Error> {
232        if encoded.dtype() != self.encode_dtype {
233            return Err(ReinterpretCodecError::MismatchedDecodeDType {
234                configured: self.encode_dtype,
235                provided: encoded.dtype(),
236            });
237        }
238
239        match (encoded, self.decode_dtype) {
240            (encoded, dtype) if encoded.dtype() == dtype => Ok(decoded.assign(&encoded)?),
241            (AnyArrayView::U8(encoded), dtype) => {
242                if decoded.dtype() != dtype {
243                    return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
244                        source: AnyArrayAssignError::DTypeMismatch {
245                            src: dtype,
246                            dst: decoded.dtype(),
247                        },
248                    });
249                }
250
251                let mut shape = encoded.shape().to_vec();
252
253                if (encoded.len() % dtype.size()) != 0 {
254                    return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
255                }
256
257                if let Some(last) = shape.last_mut() {
258                    *last /= dtype.size();
259                }
260
261                if decoded.shape() != shape {
262                    return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
263                        source: AnyArrayAssignError::ShapeMismatch {
264                            src: shape,
265                            dst: decoded.shape().to_vec(),
266                        },
267                    });
268                }
269
270                let () = decoded.with_bytes_mut(|bytes| {
271                    bytes.copy_from_slice(&AnyArrayView::U8(encoded).as_bytes());
272                });
273
274                Ok(())
275            }
276            (AnyArrayView::U16(encoded), AnyArrayDType::I16) => {
277                reinterpret_array_into(encoded, |x| i16::from_ne_bytes(x.to_ne_bytes()), decoded)
278            }
279            (AnyArrayView::U32(encoded), AnyArrayDType::I32) => {
280                reinterpret_array_into(encoded, |x| i32::from_ne_bytes(x.to_ne_bytes()), decoded)
281            }
282            (AnyArrayView::U32(encoded), AnyArrayDType::F32) => {
283                reinterpret_array_into(encoded, f32::from_bits, decoded)
284            }
285            (AnyArrayView::U64(encoded), AnyArrayDType::U64) => {
286                reinterpret_array_into(encoded, |x| i64::from_ne_bytes(x.to_ne_bytes()), decoded)
287            }
288            (AnyArrayView::U64(encoded), AnyArrayDType::F64) => {
289                reinterpret_array_into(encoded, f64::from_bits, decoded)
290            }
291            (encoded, dtype) => Err(ReinterpretCodecError::InvalidReinterpret {
292                decode_dtype: dtype,
293                encode_dtype: encoded.dtype(),
294            }),
295        }?;
296
297        Ok(())
298    }
299}
300
301impl StaticCodec for ReinterpretCodec {
302    const CODEC_ID: &'static str = "reinterpret";
303
304    type Config<'de> = Self;
305
306    fn from_config(config: Self::Config<'_>) -> Self {
307        config
308    }
309
310    fn get_config(&self) -> StaticCodecConfig<Self> {
311        StaticCodecConfig::from(self)
312    }
313}
314
315impl Serialize for ReinterpretCodec {
316    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
317        ReinterpretCodecConfig {
318            encode_dtype: self.encode_dtype,
319            decode_dtype: self.decode_dtype,
320        }
321        .serialize(serializer)
322    }
323}
324
325impl<'de> Deserialize<'de> for ReinterpretCodec {
326    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
327        let config = ReinterpretCodecConfig::deserialize(deserializer)?;
328
329        Self::try_new(config.encode_dtype, config.decode_dtype).map_err(serde::de::Error::custom)
330    }
331}
332
/// Serialization proxy for [`ReinterpretCodec`].
///
/// [`ReinterpretCodec`] is serialized through this struct, and deserialized
/// through it before the dtype pair is validated with
/// [`ReinterpretCodec::try_new`].
///
/// NOTE(review): the `JsonSchema` derive on `ReinterpretCodec` carries
/// `#[serde(deny_unknown_fields)]`, but this struct does not, so unknown
/// fields appear to be accepted during deserialization — confirm whether
/// that is intended.
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename = "ReinterpretCodec")]
struct ReinterpretCodecConfig {
    /// Dtype of the encoded data.
    encode_dtype: AnyArrayDType,
    /// Dtype of the decoded data.
    decode_dtype: AnyArrayDType,
}
339
340#[derive(Debug, Error)]
341/// Errors that may occur when applying the [`ReinterpretCodec`].
342pub enum ReinterpretCodecError {
343    /// [`ReinterpretCodec`] cannot cannot bitcast the `decode_dtype` as
344    /// `encode_dtype`
345    #[error("Reinterpret cannot bitcast {decode_dtype} as {encode_dtype}")]
346    InvalidReinterpret {
347        /// Dtype of the configured `decode_dtype`
348        decode_dtype: AnyArrayDType,
349        /// Dtype of the configured `encode_dtype`
350        encode_dtype: AnyArrayDType,
351    },
352    /// [`ReinterpretCodec`] cannot encode the provided dtype which differs
353    /// from the configured dtype
354    #[error("Reinterpret cannot encode the provided dtype {provided} which differs from the configured dtype {configured}")]
355    MismatchedEncodeDType {
356        /// Dtype of the `configured` `decode_dtype`
357        configured: AnyArrayDType,
358        /// Dtype of the `provided` array from which the data is to be encoded
359        provided: AnyArrayDType,
360    },
361    /// [`ReinterpretCodec`] cannot decode the provided dtype which differs
362    /// from the configured dtype
363    #[error("Reinterpret cannot decode the provided dtype {provided} which differs from the configured dtype {configured}")]
364    MismatchedDecodeDType {
365        /// Dtype of the `configured` `encode_dtype`
366        configured: AnyArrayDType,
367        /// Dtype of the `provided` array from which the data is to be decoded
368        provided: AnyArrayDType,
369    },
370    /// [`ReinterpretCodec`] cannot decode a byte array with `shape` into an array of `dtype`s
371    #[error(
372        "Reinterpret cannot decode a byte array of shape {shape:?} into an array of {dtype}-s"
373    )]
374    InvalidEncodedShape {
375        /// Shape of the encoded array
376        shape: Vec<usize>,
377        /// Dtype of the array into which the encoded data is to be decoded
378        dtype: AnyArrayDType,
379    },
380    /// [`ReinterpretCodec`] cannot decode into the provided array
381    #[error("Reinterpret cannot decode into the provided array")]
382    MismatchedDecodeIntoArray {
383        /// The source of the error
384        #[from]
385        source: AnyArrayAssignError,
386    },
387}
388
389/// Reinterpret the data elements of the `array` using the provided `reinterpret`
390/// closure. The shape of the data is preserved.
391#[inline]
392pub fn reinterpret_array<T: Copy, U, S: Data<Elem = T>, D: Dimension>(
393    array: ArrayBase<S, D>,
394    reinterpret: impl Fn(T) -> U,
395) -> Array<U, D> {
396    let array = array.into_owned();
397    let (shape, data) = (array.raw_dim(), array.into_raw_vec_and_offset().0);
398
399    let data = data.into_iter().map(reinterpret).collect();
400
401    #[allow(unsafe_code)]
402    // Safety: we have preserved the shape, which comes from a valid array
403    let array = unsafe { Array::from_shape_vec_unchecked(shape, data) };
404
405    array
406}
407
408#[allow(clippy::needless_pass_by_value)]
409/// Reinterpret the data elements of the `encoded` array using the provided
410/// `reinterpret` closure into the `decoded` array.
411///
412/// # Errors
413///
414/// Errors with
415/// - [`ReinterpretCodecError::MismatchedDecodeIntoArray`] if `decoded` does not
416///   contain an array with elements of type `U` or its shape does not match the
417///   `encoded` array's shape
418#[inline]
419pub fn reinterpret_array_into<'a, T: Copy, U: ArrayDType, D: Dimension>(
420    encoded: ArrayView<T, D>,
421    reinterpret: impl Fn(T) -> U,
422    mut decoded: AnyArrayViewMut<'a>,
423) -> Result<(), ReinterpretCodecError>
424where
425    U::RawData<ViewRepr<&'a mut ()>>: DataMut,
426{
427    let Some(decoded) = decoded.as_typed_mut::<U>() else {
428        return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
429            source: AnyArrayAssignError::DTypeMismatch {
430                src: U::DTYPE,
431                dst: decoded.dtype(),
432            },
433        });
434    };
435
436    if encoded.shape() != decoded.shape() {
437        return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
438            source: AnyArrayAssignError::ShapeMismatch {
439                src: encoded.shape().to_vec(),
440                dst: decoded.shape().to_vec(),
441            },
442        });
443    }
444
445    // iterate over the elements in standard order
446    for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
447        *d = reinterpret(*e);
448    }
449
450    Ok(())
451}