msgpack_numpy/
core.rs

1use half::f16;
2use num_traits::{NumCast, ToPrimitive};
3
4/*********************************************************************************/
5// Scalar
6
/// De-/serialization target for a NumPy scalar.
///
/// Each numeric variant stores the value in the Rust type whose width and signedness match
/// the variant name (`f16` comes from the `half` crate). Use the `to_*` accessors to read the
/// value as a specific type; they cast between numeric types where the value fits.
///
/// `Unsupported` carries no value. Presumably it stands in for a scalar whose dtype has no
/// variant here; confirm against the deserializer.
#[derive(Debug, Clone, PartialEq)]
pub enum Scalar {
    Bool(bool),
    U8(u8),
    I8(i8),
    U16(u16),
    I16(i16),
    F16(f16),
    U32(u32),
    I32(i32),
    F32(f32),
    U64(u64),
    I64(i64),
    F64(f64),
    Unsupported,
}
24
25impl Scalar {
26    pub fn to_bool(&self) -> Option<bool> {
27        match self {
28            Scalar::Bool(v) => Some(*v),
29            _ => None,
30        }
31    }
32
33    pub fn to_u8(&self) -> Option<u8> {
34        self.to()
35    }
36
37    pub fn to_i8(&self) -> Option<i8> {
38        self.to()
39    }
40
41    pub fn to_u16(&self) -> Option<u16> {
42        self.to()
43    }
44
45    pub fn to_i16(&self) -> Option<i16> {
46        self.to()
47    }
48
49    pub fn to_f16(&self) -> Option<f16> {
50        self.to()
51    }
52
53    pub fn to_u32(&self) -> Option<u32> {
54        self.to()
55    }
56
57    pub fn to_i32(&self) -> Option<i32> {
58        self.to()
59    }
60
61    pub fn to_f32(&self) -> Option<f32> {
62        self.to()
63    }
64
65    pub fn to_u64(&self) -> Option<u64> {
66        self.to()
67    }
68
69    pub fn to_i64(&self) -> Option<i64> {
70        self.to()
71    }
72
73    pub fn to_f64(&self) -> Option<f64> {
74        self.to()
75    }
76
77    fn to<T: NumCast>(&self) -> Option<T> {
78        match self {
79            // bool doesn't implement ToPrimitive, so we need to convert it to u8 first
80            Scalar::Bool(v) => NumCast::from(*v as u8),
81            Scalar::U8(v) => NumCast::from(*v),
82            Scalar::I8(v) => NumCast::from(*v),
83            Scalar::U16(v) => NumCast::from(*v),
84            Scalar::I16(v) => NumCast::from(*v),
85            Scalar::F16(v) => NumCast::from(*v),
86            Scalar::U32(v) => NumCast::from(*v),
87            Scalar::I32(v) => NumCast::from(*v),
88            Scalar::F32(v) => NumCast::from(*v),
89            Scalar::U64(v) => NumCast::from(*v),
90            Scalar::I64(v) => NumCast::from(*v),
91            Scalar::F64(v) => NumCast::from(*v),
92            Scalar::Unsupported => None,
93        }
94    }
95}
96
97/*********************************************************************************/
98// NDArray
99
100use ndarray::{Array, IxDyn};
101
/// De-/serialization target for a NumPy array that uses owned Array for deserialization.
///
/// Each numeric variant owns a dynamically-dimensioned (`IxDyn`) `ndarray` array. Its element
/// type matches the variant name (`f16` comes from the `half` crate). Use the
/// `into_*_array` methods to take the data out, casting to another element type if needed.
///
/// `Unsupported` carries no data. Presumably it stands in for an array whose dtype has no
/// variant here; confirm against the deserializer.
#[derive(Debug, Clone, PartialEq)]
pub enum NDArray {
    Bool(Array<bool, IxDyn>),
    U8(Array<u8, IxDyn>),
    I8(Array<i8, IxDyn>),
    U16(Array<u16, IxDyn>),
    I16(Array<i16, IxDyn>),
    F16(Array<f16, IxDyn>),
    U32(Array<u32, IxDyn>),
    I32(Array<i32, IxDyn>),
    F32(Array<f32, IxDyn>),
    U64(Array<u64, IxDyn>),
    I64(Array<i64, IxDyn>),
    F64(Array<f64, IxDyn>),
    Unsupported,
}
119
120impl NDArray {
121    pub fn into_bool_array(self) -> Option<Array<bool, IxDyn>> {
122        match self {
123            NDArray::Bool(arr) => Some(arr),
124            _ => None,
125        }
126    }
127
128    pub fn into_u8_array(self) -> Option<Array<u8, IxDyn>> {
129        match self {
130            NDArray::U8(arr) => Some(arr),
131            _ => self.convert_into::<u8>(),
132        }
133    }
134
135    pub fn into_i8_array(self) -> Option<Array<i8, IxDyn>> {
136        match self {
137            NDArray::I8(arr) => Some(arr),
138            _ => self.convert_into::<i8>(),
139        }
140    }
141
142    pub fn into_u16_array(self) -> Option<Array<u16, IxDyn>> {
143        match self {
144            NDArray::U16(arr) => Some(arr),
145            _ => self.convert_into::<u16>(),
146        }
147    }
148
149    pub fn into_i16_array(self) -> Option<Array<i16, IxDyn>> {
150        match self {
151            NDArray::I16(arr) => Some(arr),
152            _ => self.convert_into::<i16>(),
153        }
154    }
155
156    pub fn into_f16_array(self) -> Option<Array<f16, IxDyn>> {
157        match self {
158            NDArray::F16(arr) => Some(arr),
159            _ => self.convert_into::<f16>(),
160        }
161    }
162
163    pub fn into_u32_array(self) -> Option<Array<u32, IxDyn>> {
164        match self {
165            NDArray::U32(arr) => Some(arr),
166            _ => self.convert_into::<u32>(),
167        }
168    }
169
170    pub fn into_i32_array(self) -> Option<Array<i32, IxDyn>> {
171        match self {
172            NDArray::I32(arr) => Some(arr),
173            _ => self.convert_into::<i32>(),
174        }
175    }
176
177    pub fn into_f32_array(self) -> Option<Array<f32, IxDyn>> {
178        match self {
179            NDArray::F32(arr) => Some(arr),
180            _ => self.convert_into::<f32>(),
181        }
182    }
183
184    pub fn into_u64_array(self) -> Option<Array<u64, IxDyn>> {
185        match self {
186            NDArray::U64(arr) => Some(arr),
187            _ => self.convert_into::<u64>(),
188        }
189    }
190
191    pub fn into_i64_array(self) -> Option<Array<i64, IxDyn>> {
192        match self {
193            NDArray::I64(arr) => Some(arr),
194            _ => self.convert_into::<i64>(),
195        }
196    }
197
198    pub fn into_f64_array(self) -> Option<Array<f64, IxDyn>> {
199        match self {
200            NDArray::F64(arr) => Some(arr),
201            _ => self.convert_into::<f64>(),
202        }
203    }
204
205    fn convert_into<T: NumCast + Copy>(self) -> Option<Array<T, IxDyn>> {
206        match self {
207            NDArray::Bool(arr) => Self::convert_bool_array(arr),
208            NDArray::U8(arr) => Self::convert_array(arr),
209            NDArray::I8(arr) => Self::convert_array(arr),
210            NDArray::U16(arr) => Self::convert_array(arr),
211            NDArray::I16(arr) => Self::convert_array(arr),
212            NDArray::F16(arr) => Self::convert_array(arr),
213            NDArray::U32(arr) => Self::convert_array(arr),
214            NDArray::I32(arr) => Self::convert_array(arr),
215            NDArray::F32(arr) => Self::convert_array(arr),
216            NDArray::U64(arr) => Self::convert_array(arr),
217            NDArray::I64(arr) => Self::convert_array(arr),
218            NDArray::F64(arr) => Self::convert_array(arr),
219            NDArray::Unsupported => None,
220        }
221    }
222
223    fn convert_array<S: Copy + ToPrimitive, T: NumCast>(
224        arr: Array<S, IxDyn>,
225    ) -> Option<Array<T, IxDyn>> {
226        let raw_dim = arr.raw_dim();
227        arr.into_iter()
228            .map(|v| NumCast::from(v).ok_or(()))
229            .collect::<Result<Vec<_>, _>>()
230            .ok()
231            .map(|vec| Array::from_shape_vec(raw_dim, vec).unwrap())
232    }
233
234    fn convert_bool_array<T: NumCast>(arr: Array<bool, IxDyn>) -> Option<Array<T, IxDyn>> {
235        let raw_dim = arr.raw_dim();
236        arr.into_iter()
237            .map(|v| NumCast::from(v as u8).ok_or(()))
238            .collect::<Result<Vec<_>, _>>()
239            .ok()
240            .map(|vec| Array::from_shape_vec(raw_dim, vec).unwrap())
241    }
242}
243
244/*********************************************************************************/
245// CowNDArray
246
247use ndarray::CowArray;
248
/// De-/serialization target for a NumPy array that uses CowArray for zero-copy
/// deserialization (when array buffer alignment is good).
///
/// Each numeric variant holds an `IxDyn` `CowArray` that either borrows its elements for `'a`
/// or owns them. Which of the two happens is decided by whatever builds the value, not here.
/// Use the `into_*_array` methods to take the data out, casting to another element type if
/// needed (a cast always produces owned data).
///
/// `Unsupported` carries no data. Presumably it stands in for an array whose dtype has no
/// variant here; confirm against the deserializer.
#[derive(Debug, Clone, PartialEq)]
pub enum CowNDArray<'a> {
    Bool(CowArray<'a, bool, IxDyn>),
    U8(CowArray<'a, u8, IxDyn>),
    I8(CowArray<'a, i8, IxDyn>),
    U16(CowArray<'a, u16, IxDyn>),
    I16(CowArray<'a, i16, IxDyn>),
    F16(CowArray<'a, f16, IxDyn>),
    U32(CowArray<'a, u32, IxDyn>),
    I32(CowArray<'a, i32, IxDyn>),
    F32(CowArray<'a, f32, IxDyn>),
    U64(CowArray<'a, u64, IxDyn>),
    I64(CowArray<'a, i64, IxDyn>),
    F64(CowArray<'a, f64, IxDyn>),
    Unsupported,
}
266
267impl<'a> CowNDArray<'a> {
268    pub fn into_bool_array(self) -> Option<CowArray<'a, bool, IxDyn>> {
269        match self {
270            CowNDArray::Bool(arr) => Some(arr),
271            _ => None,
272        }
273    }
274
275    pub fn into_u8_array(self) -> Option<CowArray<'a, u8, IxDyn>> {
276        match self {
277            CowNDArray::U8(arr) => Some(arr),
278            _ => self.convert_into::<u8>(),
279        }
280    }
281
282    pub fn into_i8_array(self) -> Option<CowArray<'a, i8, IxDyn>> {
283        match self {
284            CowNDArray::I8(arr) => Some(arr),
285            _ => self.convert_into::<i8>(),
286        }
287    }
288
289    pub fn into_u16_array(self) -> Option<CowArray<'a, u16, IxDyn>> {
290        match self {
291            CowNDArray::U16(arr) => Some(arr),
292            _ => self.convert_into::<u16>(),
293        }
294    }
295
296    pub fn into_i16_array(self) -> Option<CowArray<'a, i16, IxDyn>> {
297        match self {
298            CowNDArray::I16(arr) => Some(arr),
299            _ => self.convert_into::<i16>(),
300        }
301    }
302
303    pub fn into_f16_array(self) -> Option<CowArray<'a, f16, IxDyn>> {
304        match self {
305            CowNDArray::F16(arr) => Some(arr),
306            // round trip through f32 if not already f16
307            _ => self.convert_into::<f16>(),
308        }
309    }
310
311    pub fn into_u32_array(self) -> Option<CowArray<'a, u32, IxDyn>> {
312        match self {
313            CowNDArray::U32(arr) => Some(arr),
314            _ => self.convert_into::<u32>(),
315        }
316    }
317
318    pub fn into_i32_array(self) -> Option<CowArray<'a, i32, IxDyn>> {
319        match self {
320            CowNDArray::I32(arr) => Some(arr),
321            _ => self.convert_into::<i32>(),
322        }
323    }
324
325    pub fn into_f32_array(self) -> Option<CowArray<'a, f32, IxDyn>> {
326        match self {
327            CowNDArray::F32(arr) => Some(arr),
328            _ => self.convert_into::<f32>(),
329        }
330    }
331
332    pub fn into_u64_array(self) -> Option<CowArray<'a, u64, IxDyn>> {
333        match self {
334            CowNDArray::U64(arr) => Some(arr),
335            _ => self.convert_into::<u64>(),
336        }
337    }
338
339    pub fn into_i64_array(self) -> Option<CowArray<'a, i64, IxDyn>> {
340        match self {
341            CowNDArray::I64(arr) => Some(arr),
342            _ => self.convert_into::<i64>(),
343        }
344    }
345
346    pub fn into_f64_array(self) -> Option<CowArray<'a, f64, IxDyn>> {
347        match self {
348            CowNDArray::F64(arr) => Some(arr),
349            _ => self.convert_into::<f64>(),
350        }
351    }
352
353    fn convert_into<T: NumCast + Copy>(self) -> Option<CowArray<'a, T, IxDyn>> {
354        match self {
355            CowNDArray::Bool(arr) => Self::convert_bool_array(arr),
356            CowNDArray::U8(arr) => Self::convert_array(arr),
357            CowNDArray::I8(arr) => Self::convert_array(arr),
358            CowNDArray::U16(arr) => Self::convert_array(arr),
359            CowNDArray::I16(arr) => Self::convert_array(arr),
360            CowNDArray::F16(arr) => Self::convert_array(arr),
361            CowNDArray::U32(arr) => Self::convert_array(arr),
362            CowNDArray::I32(arr) => Self::convert_array(arr),
363            CowNDArray::F32(arr) => Self::convert_array(arr),
364            CowNDArray::U64(arr) => Self::convert_array(arr),
365            CowNDArray::I64(arr) => Self::convert_array(arr),
366            CowNDArray::F64(arr) => Self::convert_array(arr),
367            CowNDArray::Unsupported => None,
368        }
369    }
370
371    fn convert_array<S: Copy + ToPrimitive, T: NumCast>(
372        arr: CowArray<S, IxDyn>,
373    ) -> Option<CowArray<T, IxDyn>> {
374        let raw_dim = arr.raw_dim();
375        arr.into_iter()
376            .map(|v| NumCast::from(v).ok_or(()))
377            .collect::<Result<Vec<_>, _>>()
378            .ok()
379            .map(|vec| Array::from_shape_vec(raw_dim, vec).unwrap().into())
380    }
381
382    fn convert_bool_array<T: NumCast>(arr: CowArray<bool, IxDyn>) -> Option<CowArray<T, IxDyn>> {
383        let raw_dim = arr.raw_dim();
384        arr.into_iter()
385            .map(|v| NumCast::from(v as u8).ok_or(()))
386            .collect::<Result<Vec<_>, _>>()
387            .ok()
388            .map(|vec| Array::from_shape_vec(raw_dim, vec).unwrap().into())
389    }
390}