msgpack_numpy/
serde.rs

1use crate::core::{CowNDArray, NDArray, Scalar};
2use half::f16;
3use serde::de::{self, Visitor};
4use serde::ser::SerializeMap;
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use serde_bytes::{ByteBuf, Bytes};
7use std::borrow::Cow;
8use std::fmt;
9
10// DType
11
/// A NumPy dtype descriptor, as found in the `type` field of msgpack-numpy payloads.
enum DType {
    /// A simple dtype string such as `"<f4"` or `"|b1"`.
    String(String),
    /// A structured dtype: a list of `(field name, field dtype)` pairs.
    // Uses of this variant match it as `DType::Array(_)` and never read the pairs, so the
    // payload would otherwise trigger a dead-code warning.
    #[allow(dead_code)]
    Array(Vec<(String, String)>),
}
17
18impl<'de> Deserialize<'de> for DType {
19    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
20    where
21        D: Deserializer<'de>,
22    {
23        struct DTypeVisitor;
24
25        impl<'de> Visitor<'de> for DTypeVisitor {
26            type Value = DType;
27
28            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
29                formatter.write_str("a string or an array of tuples")
30            }
31
32            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
33            where
34                E: de::Error,
35            {
36                Ok(DType::String(value.to_string()))
37            }
38
39            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
40            where
41                A: de::SeqAccess<'de>,
42            {
43                let mut vec = Vec::new();
44                while let Some((name, dtype)) = seq.next_element()? {
45                    vec.push((name, dtype));
46                }
47                Ok(DType::Array(vec))
48            }
49        }
50
51        deserializer.deserialize_any(DTypeVisitor)
52    }
53}
54
55/***********************************************************************************************/
56// Scalar
57
58// impl Deserialize
59
60impl<'de> Deserialize<'de> for Scalar {
61    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
62    where
63        D: Deserializer<'de>,
64    {
65        struct ScalarVisitor;
66
67        impl<'de> Visitor<'de> for ScalarVisitor {
68            type Value = Scalar;
69
70            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
71                formatter.write_str("a numpy scaler in msgpack format")
72            }
73
74            // additional compatibility in case msgpack-python short-circuits during serialization
75            fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
76            where
77                E: de::Error,
78            {
79                Ok(Scalar::Bool(v))
80            }
81
82            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
83            where
84                E: de::Error,
85            {
86                Ok(Scalar::I64(v))
87            }
88
89            // msgpack-python indeed short-circuits this during serialization
90            fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
91            where
92                E: de::Error,
93            {
94                Ok(Scalar::F64(v))
95            }
96
97            // for NumPy's 'U' type
98            fn visit_str<E>(self, _v: &str) -> Result<Self::Value, E>
99            where
100                E: de::Error,
101            {
102                Ok(Scalar::Unsupported)
103            }
104
105            // for NumPy's 'S' type
106            fn visit_bytes<E>(self, _v: &[u8]) -> Result<Self::Value, E>
107            where
108                E: de::Error,
109            {
110                Ok(Scalar::Unsupported)
111            }
112
113            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
114            where
115                A: de::MapAccess<'de>,
116            {
117                let mut nd: Option<bool> = None;
118                let mut numpy_dtype: Option<DType> = None;
119                let mut data: Option<ByteBuf> = None;
120
121                while let Some(key) = map.next_key()? {
122                    match key {
123                        "nd" => nd = Some(map.next_value()?),
124                        "type" => numpy_dtype = Some(map.next_value()?),
125                        "data" => data = Some(map.next_value()?),
126                        _ => return Err(de::Error::unknown_field(key, &["nd", "type", "data"])),
127                    }
128                }
129
130                let nd = nd.ok_or_else(|| de::Error::missing_field("nd"))?;
131                let numpy_dtype = numpy_dtype.ok_or_else(|| de::Error::missing_field("type"))?;
132                let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
133
134                if nd {
135                    return Err(de::Error::custom("nd should be false for numpy scalars"));
136                }
137
138                // we only support primitive numeric types for now
139
140                match numpy_dtype {
141                    DType::String(dtype) => {
142                        match dtype.as_str() {
143                            // convert through u8 to conform to NumPy's serialization behavior of booleans
144                            "|b1" => TryInto::<[u8; 1]>::try_into(data.into_vec())
145                                .map(|bytes| Scalar::Bool(bytes[0] != 0))
146                                .map_err(|_| de::Error::custom("Invalid data for bool")),
147                            "|u1" => TryInto::<[u8; 1]>::try_into(data.into_vec())
148                                .map(|bytes| Scalar::U8(bytes[0]))
149                                .map_err(|_| de::Error::custom("Invalid data for u8")),
150                            "|i1" => data
151                                .into_vec()
152                                .try_into()
153                                .map(|bytes| Scalar::I8(i8::from_le_bytes(bytes)))
154                                .map_err(|_| de::Error::custom("Invalid data for i8")),
155                            "<u2" => data
156                                .into_vec()
157                                .try_into()
158                                .map(|bytes| Scalar::U16(u16::from_le_bytes(bytes)))
159                                .map_err(|_| de::Error::custom("Invalid data for u16")),
160                            "<i2" => data
161                                .into_vec()
162                                .try_into()
163                                .map(|bytes| Scalar::I16(i16::from_le_bytes(bytes)))
164                                .map_err(|_| de::Error::custom("Invalid data for i16")),
165                            "<f2" => data
166                                .into_vec()
167                                .try_into()
168                                .map(|bytes| Scalar::F16(f16::from_le_bytes(bytes)))
169                                .map_err(|_| de::Error::custom("Invalid data for f16")),
170                            "<u4" => data
171                                .into_vec()
172                                .try_into()
173                                .map(|bytes| Scalar::U32(u32::from_le_bytes(bytes)))
174                                .map_err(|_| de::Error::custom("Invalid data for u32")),
175                            "<i4" => data
176                                .into_vec()
177                                .try_into()
178                                .map(|bytes| Scalar::I32(i32::from_le_bytes(bytes)))
179                                .map_err(|_| de::Error::custom("Invalid data for i32")),
180                            "<f4" => data
181                                .into_vec()
182                                .try_into()
183                                .map(|bytes| Scalar::F32(f32::from_le_bytes(bytes)))
184                                .map_err(|_| de::Error::custom("Invalid data for f32")),
185                            "<u8" => data
186                                .into_vec()
187                                .try_into()
188                                .map(|bytes| Scalar::U64(u64::from_le_bytes(bytes)))
189                                .map_err(|_| de::Error::custom("Invalid data for u64")),
190                            "<i8" => data
191                                .into_vec()
192                                .try_into()
193                                .map(|bytes| Scalar::I64(i64::from_le_bytes(bytes)))
194                                .map_err(|_| de::Error::custom("Invalid data for i64")),
195                            "<f8" => data
196                                .into_vec()
197                                .try_into()
198                                .map(|bytes| Scalar::F64(f64::from_le_bytes(bytes)))
199                                .map_err(|_| de::Error::custom("Invalid data for f64")),
200                            _ => Ok(Scalar::Unsupported),
201                        }
202                    }
203                    DType::Array(_) => Ok(Scalar::Unsupported),
204                }
205            }
206        }
207
208        deserializer.deserialize_map(ScalarVisitor)
209    }
210}
211
212// impl Serialize
213
214impl Serialize for Scalar {
215    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
216    where
217        S: Serializer,
218    {
219        let mut state = serializer.serialize_map(Some(3))?;
220
221        state.serialize_entry(Bytes::new(b"nd"), &false)?;
222
223        match self {
224            // convert through u8 to conform to NumPy's serialization behavior of booleans
225            Scalar::Bool(val) => serialize_value(&mut state, "|b1", &[*val as u8]),
226            Scalar::U8(val) => serialize_value(&mut state, "|u1", &[*val]),
227            Scalar::I8(val) => serialize_value(&mut state, "|i1", &val.to_le_bytes()),
228            Scalar::U16(val) => serialize_value(&mut state, "<u2", &val.to_le_bytes()),
229            Scalar::I16(val) => serialize_value(&mut state, "<i2", &val.to_le_bytes()),
230            Scalar::F16(val) => serialize_value(&mut state, "<f2", &val.to_le_bytes()),
231            Scalar::U32(val) => serialize_value(&mut state, "<u4", &val.to_le_bytes()),
232            Scalar::I32(val) => serialize_value(&mut state, "<i4", &val.to_le_bytes()),
233            Scalar::F32(val) => serialize_value(&mut state, "<f4", &val.to_le_bytes()),
234            Scalar::U64(val) => serialize_value(&mut state, "<u8", &val.to_le_bytes()),
235            Scalar::I64(val) => serialize_value(&mut state, "<i8", &val.to_le_bytes()),
236            Scalar::F64(val) => serialize_value(&mut state, "<f8", &val.to_le_bytes()),
237            Scalar::Unsupported => {
238                return Err(serde::ser::Error::custom("Unsupported numpy dtype"));
239            }
240        }?;
241
242        state.end()
243    }
244}
245
246fn serialize_value<S>(state: &mut S, type_str: &str, val: &[u8]) -> Result<(), S::Error>
247where
248    S: SerializeMap,
249{
250    state.serialize_entry(Bytes::new(b"type"), type_str)?;
251    state.serialize_entry(Bytes::new(b"data"), Bytes::new(val))
252}
253
254/***********************************************************************************************/
255// NDArray
256
257use ndarray::{Array, ArrayBase, IxDyn};
258use std::mem;
259
/// Errors raised while turning a msgpack-numpy `data` buffer into an `ndarray` array.
///
/// Deserializers convert these with `de::Error::custom`, which only takes a `Display` value.
/// As a result, only the formatted message below reaches the caller.
#[derive(thiserror::Error, Debug)]
enum NDArrayError {
    /// The byte buffer could not be reinterpreted as elements of the requested type
    /// (its length is not a multiple of the element size).
    #[error("InvalidDataLength: {0}")]
    InvalidDataLength(String),

    /// `ndarray` rejected the requested shape (e.g. the element count does not match it).
    #[error("ArrayShapeError: {0}")]
    ArrayShapeError(ndarray::ShapeError),
}
268
269// impl Deserialize
270
271impl<'de> Deserialize<'de> for NDArray {
272    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
273    where
274        D: Deserializer<'de>,
275    {
276        struct NDArrayVisitor;
277
278        impl<'de> Visitor<'de> for NDArrayVisitor {
279            type Value = NDArray;
280
281            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
282                formatter.write_str("a numpy array in msgpack format")
283            }
284
285            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
286            where
287                A: de::MapAccess<'de>,
288            {
289                let mut nd: Option<bool> = None;
290                let mut numpy_dtype: Option<DType> = None;
291                let mut kind: Option<ByteBuf> = None;
292                let mut shape: Option<Vec<usize>> = None;
293                let mut data: Option<ByteBuf> = None;
294
295                while let Some(key) = map.next_key()? {
296                    match key {
297                        "nd" => nd = Some(map.next_value()?),
298                        "type" => numpy_dtype = Some(map.next_value()?),
299                        "kind" => kind = Some(map.next_value()?),
300                        "shape" => shape = Some(map.next_value()?),
301                        "data" => data = Some(map.next_value()?),
302                        _ => {
303                            return Err(de::Error::unknown_field(
304                                key,
305                                &["nd", "type", "kind", "shape", "data"],
306                            ))
307                        }
308                    }
309                }
310
311                let nd = nd.ok_or_else(|| de::Error::missing_field("nd"))?;
312                let numpy_dtype = numpy_dtype.ok_or_else(|| de::Error::missing_field("type"))?;
313                let _kind = kind.ok_or_else(|| de::Error::missing_field("kind"))?;
314                let shape = shape.ok_or_else(|| de::Error::missing_field("shape"))?;
315                let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
316
317                if !nd {
318                    return Err(de::Error::custom("nd should be true for numpy arrays"));
319                }
320
321                let shape = IxDyn(&shape);
322
323                // we only support primitive numeric types for now
324
325                match numpy_dtype {
326                    DType::String(dtype) => {
327                        match dtype.as_str() {
328                            // convert through u8 to conform to NumPy's serialization behavior of booleans
329                            "|b1" => Array::from_shape_vec(
330                                shape,
331                                data.into_iter().map(|v| v != 0).collect(),
332                            )
333                            .map(NDArray::Bool)
334                            .map_err(de::Error::custom),
335                            "|u1" => Array::from_shape_vec(shape, data.into_vec())
336                                .map(NDArray::U8)
337                                .map_err(de::Error::custom),
338                            "|i1" => create_ndarray_from_transmution::<i8>(data.into_vec(), shape)
339                                .map(NDArray::I8)
340                                .map_err(de::Error::custom),
341                            "<u2" => create_ndarray_from_transmution::<u16>(data.into_vec(), shape)
342                                .map(NDArray::U16)
343                                .map_err(de::Error::custom),
344                            "<i2" => create_ndarray_from_transmution::<i16>(data.into_vec(), shape)
345                                .map(NDArray::I16)
346                                .map_err(de::Error::custom),
347                            "<f2" => create_ndarray_from_transmution::<f16>(data.into_vec(), shape)
348                                .map(NDArray::F16)
349                                .map_err(de::Error::custom),
350                            "<u4" => create_ndarray_from_transmution::<u32>(data.into_vec(), shape)
351                                .map(NDArray::U32)
352                                .map_err(de::Error::custom),
353                            "<i4" => create_ndarray_from_transmution::<i32>(data.into_vec(), shape)
354                                .map(NDArray::I32)
355                                .map_err(de::Error::custom),
356                            "<f4" => create_ndarray_from_transmution::<f32>(data.into_vec(), shape)
357                                .map(NDArray::F32)
358                                .map_err(de::Error::custom),
359                            "<u8" => create_ndarray_from_transmution::<u64>(data.into_vec(), shape)
360                                .map(NDArray::U64)
361                                .map_err(de::Error::custom),
362                            "<i8" => create_ndarray_from_transmution::<i64>(data.into_vec(), shape)
363                                .map(NDArray::I64)
364                                .map_err(de::Error::custom),
365                            "<f8" => create_ndarray_from_transmution::<f64>(data.into_vec(), shape)
366                                .map(NDArray::F64)
367                                .map_err(de::Error::custom),
368                            _ => Ok(NDArray::Unsupported),
369                        }
370                    }
371                    DType::Array(_) => Ok(NDArray::Unsupported),
372                }
373            }
374        }
375
376        deserializer.deserialize_map(NDArrayVisitor)
377    }
378}
379
380/// Creates an n-dimensional array from raw byte data by transmuting it to the specified type.
381///
382/// # Type Parameters
383///
384/// * `T`: The target numeric type for transmutation (e.g., f32, i64).
385///
386/// # Arguments
387///
388/// * `data`: Raw bytes to be transmuted and reshaped.
389/// * `shape`: The desired shape of the output array.
390///
391/// # Returns
392///
393/// An n-dimensional array of type `T` with the specified shape, or an error.
394///
395/// # Errors
396///
397/// Returns an error if:
398/// * Transmutation fails (e.g., data length isn't a multiple of `size_of::<T>()`).
399/// * Specified shape doesn't match the transmuted data length.
400///
401/// # Safety
402///
403/// Caller must ensure:
404/// * Input data represents valid values of type `T`.
405/// * Data length is a multiple of `size_of::<T>()`.
406/// * Memory layout of `T` is compatible with the original data.
407fn create_ndarray_from_transmution<T>(
408    data: Vec<u8>,
409    shape: IxDyn,
410) -> Result<Array<T, IxDyn>, NDArrayError> {
411    let transmuted = unsafe { transmute_vec(data) }.ok_or_else(|| {
412        NDArrayError::InvalidDataLength(format!(
413            "Invalid data length for {} transmutation",
414            std::any::type_name::<T>()
415        ))
416    })?;
417
418    Array::from_shape_vec(shape, transmuted).map_err(|e| NDArrayError::ArrayShapeError(e))
419}
420
/// Converts a `Vec<u8>` of raw element bytes into a `Vec<T>`.
///
/// The bytes are copied into a new allocation made for `T`, read in the host's native byte
/// order.
///
/// Reusing the byte buffer in place with `Vec::from_raw_parts` would be undefined behaviour,
/// for three reasons:
/// - the buffer was allocated with `u8`'s alignment, but a `Vec<T>` must be deallocated with
///   `T`'s alignment;
/// - the `Vec<T>` must be deallocated with exactly the original byte capacity;
/// - its pointer must be aligned for `T`.
///
/// # Safety
///
/// Every bit pattern of `size_of::<T>()` bytes must be a valid `T`. This is true for the
/// primitive integer and float types, and false for e.g. `bool`, `char` or references.
///
/// # Type Parameters
///
/// * `T`: The target element type (e.g. `f32`, `i64`).
///
/// # Arguments
///
/// * `data` - A `Vec<u8>` containing the raw byte data.
///
/// # Returns
///
/// Returns `Some(Vec<T>)` with `data.len() / size_of::<T>()` elements. Returns `None` if the
/// input length is not a multiple of `size_of::<T>()`, or if `T` is zero-sized.
unsafe fn transmute_vec<T>(data: Vec<u8>) -> Option<Vec<T>> {
    let size_of_t = mem::size_of::<T>();
    // A zero-sized `T` has no meaningful element count (and would divide by zero below).
    if size_of_t == 0 || data.len() % size_of_t != 0 {
        return None;
    }
    let len = data.len() / size_of_t;

    let mut out: Vec<T> = Vec::with_capacity(len);
    // SAFETY: `out` owns room for `len` elements, i.e. exactly `data.len()` bytes, aligned for
    // `T`. Source and destination are distinct allocations. After the copy all `len` elements
    // are initialized and (per the caller's guarantee) valid, so `set_len` is sound.
    std::ptr::copy_nonoverlapping(data.as_ptr(), out.as_mut_ptr() as *mut u8, data.len());
    out.set_len(len);
    Some(out)
}
458
459// impl Serialize
460
461impl Serialize for NDArray {
462    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
463    where
464        S: Serializer,
465    {
466        let mut state = serializer.serialize_map(Some(5))?;
467
468        state.serialize_entry(Bytes::new(b"nd"), &true)?;
469
470        match self {
471            // convert through u8 to conform to NumPy's serialization behavior of booleans
472            NDArray::Bool(arr) => serialize_ndarray(&mut state, "|b1", &arr.mapv(|v| v as u8)),
473            NDArray::U8(arr) => serialize_ndarray(&mut state, "|u1", arr),
474            NDArray::I8(arr) => serialize_ndarray(&mut state, "|i1", arr),
475            NDArray::U16(arr) => serialize_ndarray(&mut state, "<u2", arr),
476            NDArray::I16(arr) => serialize_ndarray(&mut state, "<i2", arr),
477            NDArray::F16(arr) => serialize_ndarray(&mut state, "<f2", arr),
478            NDArray::U32(arr) => serialize_ndarray(&mut state, "<u4", arr),
479            NDArray::I32(arr) => serialize_ndarray(&mut state, "<i4", arr),
480            NDArray::F32(arr) => serialize_ndarray(&mut state, "<f4", arr),
481            NDArray::U64(arr) => serialize_ndarray(&mut state, "<u8", arr),
482            NDArray::I64(arr) => serialize_ndarray(&mut state, "<i8", arr),
483            NDArray::F64(arr) => serialize_ndarray(&mut state, "<f8", arr),
484            NDArray::Unsupported => {
485                return Err(serde::ser::Error::custom("Unsupported numpy dtype"));
486            }
487        }?;
488
489        state.end()
490    }
491}
492
493fn serialize_ndarray<S, A, T>(
494    state: &mut S,
495    type_str: &str,
496    arr: &ArrayBase<A, IxDyn>,
497) -> Result<(), S::Error>
498where
499    S: SerializeMap,
500    A: ndarray::RawData<Elem = T>,
501{
502    state.serialize_entry(Bytes::new(b"type"), type_str)?;
503    state.serialize_entry(Bytes::new(b"kind"), Bytes::new(b""))?;
504    state.serialize_entry(Bytes::new(b"shape"), &arr.shape())?;
505
506    let data = unsafe { transmute_array_to_slice(arr) };
507    state.serialize_entry(Bytes::new(b"data"), Bytes::new(data))
508}
509
510/// Converts an n-dimensional array to a byte slice without copying.
511///
512/// # Safety
513///
514/// This function is unsafe because:
515/// - It assumes the memory layout of the array buffer is contiguous with no padding.
516///
517/// # Type Parameters
518///
519/// * `A`: ndarray::RawData, having an associate type `T`.
520/// * `T`: The target numeric type for transmutation (e.g., f32, i64).
521///
522/// # Arguments
523///
524/// * `arr` - A reference to an n-dimensional array of type `T`.
525///
526/// # Returns
527///
528/// A byte slice (`&[u8]`) representing the raw memory of the input array.
529unsafe fn transmute_array_to_slice<A: ndarray::RawData<Elem = T>, T>(
530    arr: &ArrayBase<A, IxDyn>,
531) -> &[u8] {
532    let ptr = arr.as_ptr() as *const u8;
533    let len = arr.len() * mem::size_of::<T>();
534    std::slice::from_raw_parts(ptr, len)
535}
536
537/***********************************************************************************************/
538// CowNDArray
539
540use ndarray::{ArrayView, CowArray};
541
542// impl Deserialize
543
544impl<'de: 'a, 'a> Deserialize<'de> for CowNDArray<'a> {
545    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
546    where
547        D: Deserializer<'de>,
548    {
549        struct NDArrayVisitor<'a>(std::marker::PhantomData<&'a ()>);
550
551        impl<'de: 'a, 'a> Visitor<'de> for NDArrayVisitor<'a> {
552            type Value = CowNDArray<'a>;
553
554            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
555                formatter.write_str("a numpy array in msgpack format")
556            }
557
558            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
559            where
560                A: de::MapAccess<'de>,
561            {
562                let mut nd: Option<bool> = None;
563                let mut numpy_dtype: Option<DType> = None;
564                let mut kind: Option<&'a Bytes> = None;
565                let mut shape: Option<Vec<usize>> = None;
566                let mut data: Option<&'a Bytes> = None;
567
568                while let Some(key) = map.next_key()? {
569                    match key {
570                        "nd" => nd = Some(map.next_value()?),
571                        "type" => numpy_dtype = Some(map.next_value()?),
572                        "kind" => kind = Some(map.next_value()?),
573                        "shape" => shape = Some(map.next_value()?),
574                        "data" => data = Some(map.next_value()?),
575                        _ => {
576                            return Err(de::Error::unknown_field(
577                                key,
578                                &["nd", "type", "kind", "shape", "data"],
579                            ))
580                        }
581                    }
582                }
583
584                let nd = nd.ok_or_else(|| de::Error::missing_field("nd"))?;
585                let numpy_dtype = numpy_dtype.ok_or_else(|| de::Error::missing_field("type"))?;
586                let _kind = kind.ok_or_else(|| de::Error::missing_field("kind"))?;
587                let shape = shape.ok_or_else(|| de::Error::missing_field("shape"))?;
588                let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
589
590                if !nd {
591                    return Err(de::Error::custom("nd should be true for numpy arrays"));
592                }
593
594                let shape = IxDyn(&shape);
595
596                // we only support primitive numeric types for now
597
598                match numpy_dtype {
599                    DType::String(dtype) => {
600                        match dtype.as_str() {
601                            // convert through u8 to conform to NumPy's serialization behavior of booleans
602                            "|b1" => Array::from_shape_vec(
603                                shape,
604                                data.into_iter().map(|v| *v != 0).collect(),
605                            )
606                            .map(CowArray::from)
607                            .map(CowNDArray::Bool)
608                            .map_err(de::Error::custom),
609                            "|u1" => ArrayView::from_shape(shape, data)
610                                .map(CowArray::from)
611                                .map(CowNDArray::U8)
612                                .map_err(de::Error::custom),
613                            "|i1" => create_cowndarray_from_transmution::<i8>(data, shape)
614                                .map(CowNDArray::I8)
615                                .map_err(de::Error::custom),
616                            "<u2" => create_cowndarray_from_transmution::<u16>(data, shape)
617                                .map(CowNDArray::U16)
618                                .map_err(de::Error::custom),
619                            "<i2" => create_cowndarray_from_transmution::<i16>(data, shape)
620                                .map(CowNDArray::I16)
621                                .map_err(de::Error::custom),
622                            "<f2" => create_cowndarray_from_transmution::<f16>(data, shape)
623                                .map(CowNDArray::F16)
624                                .map_err(de::Error::custom),
625                            "<u4" => create_cowndarray_from_transmution::<u32>(data, shape)
626                                .map(CowNDArray::U32)
627                                .map_err(de::Error::custom),
628                            "<i4" => create_cowndarray_from_transmution::<i32>(data, shape)
629                                .map(CowNDArray::I32)
630                                .map_err(de::Error::custom),
631                            "<f4" => create_cowndarray_from_transmution::<f32>(data, shape)
632                                .map(CowNDArray::F32)
633                                .map_err(de::Error::custom),
634                            "<u8" => create_cowndarray_from_transmution::<u64>(data, shape)
635                                .map(CowNDArray::U64)
636                                .map_err(de::Error::custom),
637                            "<i8" => create_cowndarray_from_transmution::<i64>(data, shape)
638                                .map(CowNDArray::I64)
639                                .map_err(de::Error::custom),
640                            "<f8" => create_cowndarray_from_transmution::<f64>(data, shape)
641                                .map(CowNDArray::F64)
642                                .map_err(de::Error::custom),
643                            _ => Ok(CowNDArray::Unsupported),
644                        }
645                    }
646                    DType::Array(_) => Ok(CowNDArray::Unsupported),
647                }
648            }
649        }
650
651        deserializer.deserialize_map(NDArrayVisitor(std::marker::PhantomData))
652    }
653}
654
655fn create_cowndarray_from_transmution<T: Clone>(
656    data: &[u8],
657    shape: IxDyn,
658) -> Result<CowArray<T, IxDyn>, NDArrayError> {
659    let transmuted = unsafe { transmute_slice(data) }.ok_or_else(|| {
660        NDArrayError::InvalidDataLength(format!(
661            "Invalid data length for {} transmutation",
662            std::any::type_name::<T>()
663        ))
664    })?;
665
666    match transmuted {
667        Cow::Borrowed(slice) => ArrayView::from_shape(shape, slice).map(CowArray::from),
668        Cow::Owned(vec) => Array::from_shape_vec(shape, vec).map(CowArray::from),
669    }
670    .map_err(|e| NDArrayError::ArrayShapeError(e))
671}
672
/// Reinterprets a byte slice as a slice of `T`, borrowing when possible.
///
/// There are two paths, and in both the bytes are read in the host's native byte order:
/// - if `data` is suitably aligned for `T`, the result borrows it without copying
///   (`Cow::Borrowed`);
/// - otherwise the bytes are copied into a freshly allocated, properly aligned `Vec<T>`
///   (`Cow::Owned`).
///
/// Returns `None` in two cases:
/// - `data.len()` is not a multiple of `size_of::<T>()`;
/// - `T` is zero-sized, since there is no meaningful element count.
///
/// # Safety
///
/// Every bit pattern of `size_of::<T>()` bytes must be a valid `T`. This is true for the
/// primitive integer and float types, and false for e.g. `bool`, `char` or references.
unsafe fn transmute_slice<T: Clone>(data: &[u8]) -> Option<Cow<[T]>> {
    let size_of_t = mem::size_of::<T>();
    // Reject zero-sized types (they would divide by zero) and lengths that do not split
    // evenly into `T` elements.
    if size_of_t == 0 || data.len() % size_of_t != 0 {
        return None;
    }
    let len = data.len() / size_of_t;

    if (data.as_ptr() as usize) % mem::align_of::<T>() == 0 {
        // SAFETY: the pointer is aligned for `T` and `len * size_of::<T>()` bytes are in bounds
        // of `data`. The caller guarantees those bytes are valid `T` values, and the result
        // borrows `data`, so it cannot outlive the buffer.
        let ptr = data.as_ptr() as *const T;
        Some(Cow::Borrowed(std::slice::from_raw_parts(ptr, len)))
    } else {
        let mut aligned_vec: Vec<T> = Vec::with_capacity(len);
        // SAFETY: `aligned_vec` has room for `len` elements (= `data.len()` bytes) and does not
        // overlap `data`. After the copy all `len` elements are initialized and valid per the
        // caller's guarantee, so `set_len` is sound.
        std::ptr::copy_nonoverlapping(
            data.as_ptr(),
            aligned_vec.as_mut_ptr() as *mut u8,
            data.len(),
        );
        aligned_vec.set_len(len);
        Some(Cow::Owned(aligned_vec))
    }
}
702
703// impl Serialize
704
705impl<'a> Serialize for CowNDArray<'a> {
706    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
707    where
708        S: Serializer,
709    {
710        let mut state = serializer.serialize_map(Some(5))?;
711
712        state.serialize_entry(Bytes::new(b"nd"), &true)?;
713
714        match self {
715            // convert through u8 to conform to NumPy's serialization behavior of booleans
716            CowNDArray::Bool(arr) => serialize_ndarray(&mut state, "|b1", &arr.mapv(|v| v as u8)),
717            CowNDArray::U8(arr) => serialize_ndarray(&mut state, "|u1", arr),
718            CowNDArray::I8(arr) => serialize_ndarray(&mut state, "|i1", arr),
719            CowNDArray::U16(arr) => serialize_ndarray(&mut state, "<u2", arr),
720            CowNDArray::I16(arr) => serialize_ndarray(&mut state, "<i2", arr),
721            CowNDArray::F16(arr) => serialize_ndarray(&mut state, "<f2", arr),
722            CowNDArray::U32(arr) => serialize_ndarray(&mut state, "<u4", arr),
723            CowNDArray::I32(arr) => serialize_ndarray(&mut state, "<i4", arr),
724            CowNDArray::F32(arr) => serialize_ndarray(&mut state, "<f4", arr),
725            CowNDArray::U64(arr) => serialize_ndarray(&mut state, "<u8", arr),
726            CowNDArray::I64(arr) => serialize_ndarray(&mut state, "<i8", arr),
727            CowNDArray::F64(arr) => serialize_ndarray(&mut state, "<f8", arr),
728            CowNDArray::Unsupported => {
729                return Err(serde::ser::Error::custom("Unsupported numpy dtype"));
730            }
731        }?;
732
733        state.end()
734    }
735}
736
737/*********************************************************************************/
738// tests
739
#[cfg(test)]
mod tests {
    use crate::core::{CowNDArray, NDArray, Scalar};
    use half::f16;
    use ndarray::Array;

    /// Every supported scalar type must survive a msgpack round trip unchanged,
    /// including the extreme values of each integer type.
    #[test]
    fn test_scalar_serialization() {
        let cases = vec![
            Scalar::Bool(true),
            Scalar::U8(255),
            Scalar::I8(-128),
            Scalar::U16(65535),
            Scalar::I16(-32768),
            Scalar::F16(f16::from_f32(1.0)),
            Scalar::U32(4294967295),
            Scalar::I32(-2147483648),
            Scalar::F32(1.0),
            Scalar::U64(18446744073709551615),
            Scalar::I64(-9223372036854775808),
            Scalar::F64(1.0),
        ];

        for scalar in cases {
            let serialized = rmp_serde::to_vec_named(&scalar).unwrap();
            let deserialized: Scalar = rmp_serde::from_slice(&serialized).unwrap();
            assert_eq!(deserialized, scalar);
        }
    }

    /// Every supported owned-array dtype must survive a msgpack round trip unchanged.
    #[test]
    #[rustfmt::skip]
    fn test_ndarray_serialization() {
        let cases = vec![
            NDArray::Bool(Array::from_vec(vec![true, false]).into_dyn().into()),
            NDArray::U8(Array::from_vec(vec![1, 2, 3]).into_dyn().into()),
            NDArray::I8(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            NDArray::U16(Array::from_vec(vec![1, 2, 3]).into_dyn().into()),
            NDArray::I16(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            NDArray::F16(Array::from_vec(vec![1.0, 2.0]).into_dyn().mapv(f16::from_f32).into()),
            NDArray::U32(Array::from_vec(vec![1, 2, 3]).into_dyn().into()),
            NDArray::I32(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            NDArray::F32(Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn().into()),
            NDArray::U64(Array::from_vec(vec![1, 2]).into_dyn().into()),
            NDArray::I64(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            NDArray::F64(Array::from_vec(vec![1.0, 2.0]).into_dyn().into()),
        ];

        for ndarray in cases {
            let serialized = rmp_serde::to_vec_named(&ndarray).unwrap();
            let deserialized: NDArray = rmp_serde::from_slice(&serialized).unwrap();

            assert_eq!(deserialized, ndarray);
        }
    }

    /// Every supported copy-on-write array dtype must survive a msgpack round trip.
    /// Integer and bool arrays are compared directly. Float arrays are compared
    /// element-wise so that NaN and signed infinities compare as equal.
    #[test]
    #[rustfmt::skip]
    fn test_cowndarray_serialization() {
        /// Equality for floats that treats two NaNs as equal and requires infinities
        /// to have matching signs; all other values must compare exactly equal.
        fn assert_float_eq<T>(a: T, b: T)
        where
            T: num_traits::Float + std::fmt::Debug,
        {
            if a.is_nan() && b.is_nan() {
                return; // Both are NaN, consider them equal
            }
            if a.is_infinite() && b.is_infinite() {
                assert_eq!(
                    a.signum(),
                    b.signum(),
                    "Infinite values have different signs"
                );
                return;
            }
            assert_eq!(a, b);
        }
        let cases = vec![
            CowNDArray::Bool(Array::from_vec(vec![true, false]).into_dyn().into()),
            CowNDArray::U8(Array::from_vec(vec![1, 2, 3]).into_dyn().into()),
            CowNDArray::I8(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            CowNDArray::U16(Array::from_vec(vec![1, 2, 3]).into_dyn().into()),
            CowNDArray::I16(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            CowNDArray::F16(Array::from_vec(vec![1.0, 2.0]).into_dyn().mapv(f16::from_f32).into()),
            CowNDArray::U32(Array::from_vec(vec![1, 2, 3]).into_dyn().into()),
            CowNDArray::I32(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            CowNDArray::F32(Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn().into()),
            CowNDArray::U64(Array::from_vec(vec![1, 2]).into_dyn().into()),
            CowNDArray::I64(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            CowNDArray::F64(Array::from_vec(vec![1.0, 2.0]).into_dyn().into()),
        ];

        for ndarray in cases {
            let serialized = rmp_serde::to_vec_named(&ndarray).unwrap();
            let deserialized: CowNDArray = rmp_serde::from_slice(&serialized).unwrap();

            match (deserialized, ndarray) {
                (CowNDArray::Bool(a), CowNDArray::Bool(b)) => assert_eq!(a, b),
                (CowNDArray::U8(a), CowNDArray::U8(b)) => assert_eq!(a, b),
                (CowNDArray::U16(a), CowNDArray::U16(b)) => assert_eq!(a, b),
                (CowNDArray::U32(a), CowNDArray::U32(b)) => assert_eq!(a, b),
                (CowNDArray::U64(a), CowNDArray::U64(b)) => assert_eq!(a, b),
                (CowNDArray::I8(a), CowNDArray::I8(b)) => assert_eq!(a, b),
                (CowNDArray::I16(a), CowNDArray::I16(b)) => assert_eq!(a, b),
                (CowNDArray::I32(a), CowNDArray::I32(b)) => assert_eq!(a, b),
                (CowNDArray::I64(a), CowNDArray::I64(b)) => assert_eq!(a, b),
                (CowNDArray::F16(a), CowNDArray::F16(b)) => {
                    assert_eq!(a.shape(), b.shape());
                    a.iter().zip(b.iter()).for_each(|(x, y)| {
                        assert_float_eq(x.to_f32(), y.to_f32());
                    });
                }
                (CowNDArray::F32(a), CowNDArray::F32(b)) => {
                    assert_eq!(a.shape(), b.shape());
                    a.iter().zip(b.iter()).for_each(|(x, y)| {
                        assert_float_eq(*x, *y);
                    });
                }
                (CowNDArray::F64(a), CowNDArray::F64(b)) => {
                    assert_eq!(a.shape(), b.shape());
                    a.iter().zip(b.iter()).for_each(|(x, y)| {
                        assert_float_eq(*x, *y);
                    });
                }
                (CowNDArray::Unsupported, CowNDArray::Unsupported) => (),
                _ => panic!("Mismatched types"),
            }
        }
    }

    /// Serializing an array with an unsupported dtype must fail with an error
    /// rather than producing output.
    #[test]
    fn test_cowndarray_unsupported_serialization_fails() {
        let result = rmp_serde::to_vec_named(&CowNDArray::Unsupported);
        assert!(result.is_err(), "serializing an unsupported dtype should fail");
    }
}