1use ndarray::{Array, ArrayBase, ArrayView, Data, DataMut, Dimension, ViewRepr};
21use numcodecs::{
22 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
23 ArrayDType, Codec, StaticCodec, StaticCodecConfig,
24};
25use schemars::JsonSchema;
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27use thiserror::Error;
28
/// Codec that bit-reinterprets arrays of one dtype as arrays of another dtype.
///
/// Encoding converts arrays of `decode_dtype` into arrays of `encode_dtype`,
/// and decoding performs the inverse conversion. Instances are validated by
/// [`ReinterpretCodec::try_new`] (also on deserialization), so only supported
/// dtype pairs can be constructed through the checked paths.
#[derive(Clone, JsonSchema)]
// NOTE(review): this attribute only shapes the generated JSON schema; actual
// deserialization goes through the separate config struct — confirm the two
// stay in sync.
#[serde(deny_unknown_fields)]
pub struct ReinterpretCodec {
    /// Dtype of the encoded data (the output of encoding).
    encode_dtype: AnyArrayDType,
    /// Dtype of the decoded data (the input of encoding, output of decoding).
    decode_dtype: AnyArrayDType,
}
43
44impl ReinterpretCodec {
45 pub fn try_new(
54 encode_dtype: AnyArrayDType,
55 decode_dtype: AnyArrayDType,
56 ) -> Result<Self, ReinterpretCodecError> {
57 #[allow(clippy::match_same_arms)]
58 match (decode_dtype, encode_dtype) {
59 (ty_a, ty_b) if ty_a == ty_b => (),
61 (_, AnyArrayDType::U8) => (),
63 (AnyArrayDType::I16, AnyArrayDType::U16)
65 | (AnyArrayDType::I32 | AnyArrayDType::F32, AnyArrayDType::U32)
66 | (AnyArrayDType::I64 | AnyArrayDType::F64, AnyArrayDType::U64) => (),
67 (decode_dtype, encode_dtype) => {
68 return Err(ReinterpretCodecError::InvalidReinterpret {
69 decode_dtype,
70 encode_dtype,
71 })
72 }
73 };
74
75 Ok(Self {
76 encode_dtype,
77 decode_dtype,
78 })
79 }
80
81 #[must_use]
82 pub const fn passthrough(dtype: AnyArrayDType) -> Self {
84 Self {
85 encode_dtype: dtype,
86 decode_dtype: dtype,
87 }
88 }
89
90 #[must_use]
91 pub const fn to_bytes(dtype: AnyArrayDType) -> Self {
94 Self {
95 encode_dtype: AnyArrayDType::U8,
96 decode_dtype: dtype,
97 }
98 }
99
100 #[must_use]
101 pub const fn to_binary(dtype: AnyArrayDType) -> Self {
104 Self {
105 encode_dtype: dtype.to_binary(),
106 decode_dtype: dtype,
107 }
108 }
109}
110
111impl Codec for ReinterpretCodec {
112 type Error = ReinterpretCodecError;
113
114 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
115 if data.dtype() != self.decode_dtype {
116 return Err(ReinterpretCodecError::MismatchedEncodeDType {
117 configured: self.decode_dtype,
118 provided: data.dtype(),
119 });
120 }
121
122 let encoded = match (data, self.encode_dtype) {
123 (data, dtype) if data.dtype() == dtype => data.into_owned(),
124 (data, AnyArrayDType::U8) => {
125 let mut shape = data.shape().to_vec();
126 if let Some(last) = shape.last_mut() {
127 *last *= data.dtype().size();
128 }
129 #[allow(unsafe_code)]
130 let encoded =
132 unsafe { Array::from_shape_vec_unchecked(shape, data.as_bytes().into_owned()) };
133 AnyArray::U8(encoded)
134 }
135 (AnyCowArray::I16(data), AnyArrayDType::U16) => {
136 AnyArray::U16(reinterpret_array(data, |x| {
137 u16::from_ne_bytes(x.to_ne_bytes())
138 }))
139 }
140 (AnyCowArray::I32(data), AnyArrayDType::U32) => {
141 AnyArray::U32(reinterpret_array(data, |x| {
142 u32::from_ne_bytes(x.to_ne_bytes())
143 }))
144 }
145 (AnyCowArray::F32(data), AnyArrayDType::U32) => {
146 AnyArray::U32(reinterpret_array(data, f32::to_bits))
147 }
148 (AnyCowArray::I64(data), AnyArrayDType::U64) => {
149 AnyArray::U64(reinterpret_array(data, |x| {
150 u64::from_ne_bytes(x.to_ne_bytes())
151 }))
152 }
153 (AnyCowArray::F64(data), AnyArrayDType::U64) => {
154 AnyArray::U64(reinterpret_array(data, f64::to_bits))
155 }
156 (data, dtype) => {
157 return Err(ReinterpretCodecError::InvalidReinterpret {
158 decode_dtype: data.dtype(),
159 encode_dtype: dtype,
160 });
161 }
162 };
163
164 Ok(encoded)
165 }
166
167 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
168 if encoded.dtype() != self.encode_dtype {
169 return Err(ReinterpretCodecError::MismatchedDecodeDType {
170 configured: self.encode_dtype,
171 provided: encoded.dtype(),
172 });
173 }
174
175 let decoded = match (encoded, self.decode_dtype) {
176 (encoded, dtype) if encoded.dtype() == dtype => encoded.into_owned(),
177 (AnyCowArray::U8(encoded), dtype) => {
178 let mut shape = encoded.shape().to_vec();
179
180 if (encoded.len() % dtype.size()) != 0 {
181 return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
182 }
183
184 if let Some(last) = shape.last_mut() {
185 *last /= dtype.size();
186 }
187
188 let (decoded, ()) = AnyArray::with_zeros_bytes(dtype, &shape, |bytes| {
189 bytes.copy_from_slice(&AnyCowArray::U8(encoded).as_bytes());
190 });
191
192 decoded
193 }
194 (AnyCowArray::U16(encoded), AnyArrayDType::I16) => {
195 AnyArray::I16(reinterpret_array(encoded, |x| {
196 i16::from_ne_bytes(x.to_ne_bytes())
197 }))
198 }
199 (AnyCowArray::U32(encoded), AnyArrayDType::I32) => {
200 AnyArray::I32(reinterpret_array(encoded, |x| {
201 i32::from_ne_bytes(x.to_ne_bytes())
202 }))
203 }
204 (AnyCowArray::U32(encoded), AnyArrayDType::F32) => {
205 AnyArray::F32(reinterpret_array(encoded, f32::from_bits))
206 }
207 (AnyCowArray::U64(encoded), AnyArrayDType::U64) => {
208 AnyArray::I64(reinterpret_array(encoded, |x| {
209 i64::from_ne_bytes(x.to_ne_bytes())
210 }))
211 }
212 (AnyCowArray::U64(encoded), AnyArrayDType::F64) => {
213 AnyArray::F64(reinterpret_array(encoded, f64::from_bits))
214 }
215 (encoded, dtype) => {
216 return Err(ReinterpretCodecError::InvalidReinterpret {
217 decode_dtype: dtype,
218 encode_dtype: encoded.dtype(),
219 });
220 }
221 };
222
223 Ok(decoded)
224 }
225
226 #[allow(clippy::too_many_lines)]
227 fn decode_into(
228 &self,
229 encoded: AnyArrayView,
230 mut decoded: AnyArrayViewMut,
231 ) -> Result<(), Self::Error> {
232 if encoded.dtype() != self.encode_dtype {
233 return Err(ReinterpretCodecError::MismatchedDecodeDType {
234 configured: self.encode_dtype,
235 provided: encoded.dtype(),
236 });
237 }
238
239 match (encoded, self.decode_dtype) {
240 (encoded, dtype) if encoded.dtype() == dtype => Ok(decoded.assign(&encoded)?),
241 (AnyArrayView::U8(encoded), dtype) => {
242 if decoded.dtype() != dtype {
243 return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
244 source: AnyArrayAssignError::DTypeMismatch {
245 src: dtype,
246 dst: decoded.dtype(),
247 },
248 });
249 }
250
251 let mut shape = encoded.shape().to_vec();
252
253 if (encoded.len() % dtype.size()) != 0 {
254 return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
255 }
256
257 if let Some(last) = shape.last_mut() {
258 *last /= dtype.size();
259 }
260
261 if decoded.shape() != shape {
262 return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
263 source: AnyArrayAssignError::ShapeMismatch {
264 src: shape,
265 dst: decoded.shape().to_vec(),
266 },
267 });
268 }
269
270 let () = decoded.with_bytes_mut(|bytes| {
271 bytes.copy_from_slice(&AnyArrayView::U8(encoded).as_bytes());
272 });
273
274 Ok(())
275 }
276 (AnyArrayView::U16(encoded), AnyArrayDType::I16) => {
277 reinterpret_array_into(encoded, |x| i16::from_ne_bytes(x.to_ne_bytes()), decoded)
278 }
279 (AnyArrayView::U32(encoded), AnyArrayDType::I32) => {
280 reinterpret_array_into(encoded, |x| i32::from_ne_bytes(x.to_ne_bytes()), decoded)
281 }
282 (AnyArrayView::U32(encoded), AnyArrayDType::F32) => {
283 reinterpret_array_into(encoded, f32::from_bits, decoded)
284 }
285 (AnyArrayView::U64(encoded), AnyArrayDType::U64) => {
286 reinterpret_array_into(encoded, |x| i64::from_ne_bytes(x.to_ne_bytes()), decoded)
287 }
288 (AnyArrayView::U64(encoded), AnyArrayDType::F64) => {
289 reinterpret_array_into(encoded, f64::from_bits, decoded)
290 }
291 (encoded, dtype) => Err(ReinterpretCodecError::InvalidReinterpret {
292 decode_dtype: dtype,
293 encode_dtype: encoded.dtype(),
294 }),
295 }?;
296
297 Ok(())
298 }
299}
300
impl StaticCodec for ReinterpretCodec {
    /// Identifier under which this codec is registered.
    const CODEC_ID: &'static str = "reinterpret";

    /// The codec is its own configuration type (its `Serialize`/`Deserialize`
    /// impls go through a validated config struct).
    type Config<'de> = Self;

    /// The configuration already is the codec, so it is returned unchanged.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Wrap (a reference to) this codec as its static configuration.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
314
315impl Serialize for ReinterpretCodec {
316 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
317 ReinterpretCodecConfig {
318 encode_dtype: self.encode_dtype,
319 decode_dtype: self.decode_dtype,
320 }
321 .serialize(serializer)
322 }
323}
324
325impl<'de> Deserialize<'de> for ReinterpretCodec {
326 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
327 let config = ReinterpretCodecConfig::deserialize(deserializer)?;
328
329 Self::try_new(config.encode_dtype, config.decode_dtype).map_err(serde::de::Error::custom)
330 }
331}
332
/// Unvalidated (de)serialization form of [`ReinterpretCodec`].
///
/// Deserializing a [`ReinterpretCodec`] reads this struct first and then
/// validates the dtype pair with [`ReinterpretCodec::try_new`].
#[derive(Clone, Serialize, Deserialize)]
// reuse the public codec's name in the serialized form
#[serde(rename = "ReinterpretCodec")]
struct ReinterpretCodecConfig {
    /// Dtype of the encoded data.
    encode_dtype: AnyArrayDType,
    /// Dtype of the decoded data.
    decode_dtype: AnyArrayDType,
}
339
/// Errors that may occur when constructing or applying the [`ReinterpretCodec`].
#[derive(Debug, Error)]
pub enum ReinterpretCodecError {
    /// The `decode_dtype` cannot be bit-reinterpreted as the `encode_dtype`.
    #[error("Reinterpret cannot bitcast {decode_dtype} as {encode_dtype}")]
    InvalidReinterpret {
        /// Dtype of the decoded (unencoded) data.
        decode_dtype: AnyArrayDType,
        /// Dtype of the encoded data.
        encode_dtype: AnyArrayDType,
    },
    /// The array passed to encoding does not have the configured decode dtype.
    #[error("Reinterpret cannot encode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedEncodeDType {
        /// Dtype the codec was configured with.
        configured: AnyArrayDType,
        /// Dtype of the array that was provided.
        provided: AnyArrayDType,
    },
    /// The array passed to decoding does not have the configured encode dtype.
    #[error("Reinterpret cannot decode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedDecodeDType {
        /// Dtype the codec was configured with.
        configured: AnyArrayDType,
        /// Dtype of the array that was provided.
        provided: AnyArrayDType,
    },
    /// A byte array cannot be split evenly into elements of `dtype`.
    #[error(
        "Reinterpret cannot decode a byte array of shape {shape:?} into an array of {dtype}-s"
    )]
    InvalidEncodedShape {
        /// Shape of the encoded byte array.
        shape: Vec<usize>,
        /// Dtype the bytes were to be decoded into.
        dtype: AnyArrayDType,
    },
    /// The output array passed to `decode_into` has the wrong dtype or shape.
    #[error("Reinterpret cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The underlying dtype or shape mismatch.
        #[from]
        source: AnyArrayAssignError,
    },
}
388
389#[inline]
392pub fn reinterpret_array<T: Copy, U, S: Data<Elem = T>, D: Dimension>(
393 array: ArrayBase<S, D>,
394 reinterpret: impl Fn(T) -> U,
395) -> Array<U, D> {
396 let array = array.into_owned();
397 let (shape, data) = (array.raw_dim(), array.into_raw_vec_and_offset().0);
398
399 let data = data.into_iter().map(reinterpret).collect();
400
401 #[allow(unsafe_code)]
402 let array = unsafe { Array::from_shape_vec_unchecked(shape, data) };
404
405 array
406}
407
408#[allow(clippy::needless_pass_by_value)]
409#[inline]
419pub fn reinterpret_array_into<'a, T: Copy, U: ArrayDType, D: Dimension>(
420 encoded: ArrayView<T, D>,
421 reinterpret: impl Fn(T) -> U,
422 mut decoded: AnyArrayViewMut<'a>,
423) -> Result<(), ReinterpretCodecError>
424where
425 U::RawData<ViewRepr<&'a mut ()>>: DataMut,
426{
427 let Some(decoded) = decoded.as_typed_mut::<U>() else {
428 return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
429 source: AnyArrayAssignError::DTypeMismatch {
430 src: U::DTYPE,
431 dst: decoded.dtype(),
432 },
433 });
434 };
435
436 if encoded.shape() != decoded.shape() {
437 return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
438 source: AnyArrayAssignError::ShapeMismatch {
439 src: encoded.shape().to_vec(),
440 dst: decoded.shape().to_vec(),
441 },
442 });
443 }
444
445 for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
447 *d = reinterpret(*e);
448 }
449
450 Ok(())
451}