1use crate::core::{CowNDArray, NDArray, Scalar};
2use half::f16;
3use serde::de::{self, Visitor};
4use serde::ser::SerializeMap;
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use serde_bytes::{ByteBuf, Bytes};
7use std::borrow::Cow;
8use std::fmt;
9
/// The `type` field of a numpy msgpack map.
///
/// Either a plain numpy dtype string (the deserializers in this file match values such
/// as `"|b1"`, `"<i4"`, `"<f8"`) or a sequence of `(name, type)` string pairs, which is
/// what the `Deserialize` impl accepts for the sequence form.
enum DType {
    /// A single dtype string such as `"<u2"`.
    String(String),
    // The field list is parsed only so that such inputs deserialize; every consumer in
    // this file maps `DType::Array(_)` straight to an `Unsupported` value without reading
    // the pairs, which is why `dead_code` is allowed here.
    #[allow(dead_code)]
    Array(Vec<(String, String)>),
}
17
18impl<'de> Deserialize<'de> for DType {
19 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
20 where
21 D: Deserializer<'de>,
22 {
23 struct DTypeVisitor;
24
25 impl<'de> Visitor<'de> for DTypeVisitor {
26 type Value = DType;
27
28 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
29 formatter.write_str("a string or an array of tuples")
30 }
31
32 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
33 where
34 E: de::Error,
35 {
36 Ok(DType::String(value.to_string()))
37 }
38
39 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
40 where
41 A: de::SeqAccess<'de>,
42 {
43 let mut vec = Vec::new();
44 while let Some((name, dtype)) = seq.next_element()? {
45 vec.push((name, dtype));
46 }
47 Ok(DType::Array(vec))
48 }
49 }
50
51 deserializer.deserialize_any(DTypeVisitor)
52 }
53}
54
55impl<'de> Deserialize<'de> for Scalar {
61 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
62 where
63 D: Deserializer<'de>,
64 {
65 struct ScalarVisitor;
66
67 impl<'de> Visitor<'de> for ScalarVisitor {
68 type Value = Scalar;
69
70 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
71 formatter.write_str("a numpy scaler in msgpack format")
72 }
73
74 fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
76 where
77 E: de::Error,
78 {
79 Ok(Scalar::Bool(v))
80 }
81
82 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
83 where
84 E: de::Error,
85 {
86 Ok(Scalar::I64(v))
87 }
88
89 fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
91 where
92 E: de::Error,
93 {
94 Ok(Scalar::F64(v))
95 }
96
97 fn visit_str<E>(self, _v: &str) -> Result<Self::Value, E>
99 where
100 E: de::Error,
101 {
102 Ok(Scalar::Unsupported)
103 }
104
105 fn visit_bytes<E>(self, _v: &[u8]) -> Result<Self::Value, E>
107 where
108 E: de::Error,
109 {
110 Ok(Scalar::Unsupported)
111 }
112
113 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
114 where
115 A: de::MapAccess<'de>,
116 {
117 let mut nd: Option<bool> = None;
118 let mut numpy_dtype: Option<DType> = None;
119 let mut data: Option<ByteBuf> = None;
120
121 while let Some(key) = map.next_key()? {
122 match key {
123 "nd" => nd = Some(map.next_value()?),
124 "type" => numpy_dtype = Some(map.next_value()?),
125 "data" => data = Some(map.next_value()?),
126 _ => return Err(de::Error::unknown_field(key, &["nd", "type", "data"])),
127 }
128 }
129
130 let nd = nd.ok_or_else(|| de::Error::missing_field("nd"))?;
131 let numpy_dtype = numpy_dtype.ok_or_else(|| de::Error::missing_field("type"))?;
132 let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
133
134 if nd {
135 return Err(de::Error::custom("nd should be false for numpy scalars"));
136 }
137
138 match numpy_dtype {
141 DType::String(dtype) => {
142 match dtype.as_str() {
143 "|b1" => TryInto::<[u8; 1]>::try_into(data.into_vec())
145 .map(|bytes| Scalar::Bool(bytes[0] != 0))
146 .map_err(|_| de::Error::custom("Invalid data for bool")),
147 "|u1" => TryInto::<[u8; 1]>::try_into(data.into_vec())
148 .map(|bytes| Scalar::U8(bytes[0]))
149 .map_err(|_| de::Error::custom("Invalid data for u8")),
150 "|i1" => data
151 .into_vec()
152 .try_into()
153 .map(|bytes| Scalar::I8(i8::from_le_bytes(bytes)))
154 .map_err(|_| de::Error::custom("Invalid data for i8")),
155 "<u2" => data
156 .into_vec()
157 .try_into()
158 .map(|bytes| Scalar::U16(u16::from_le_bytes(bytes)))
159 .map_err(|_| de::Error::custom("Invalid data for u16")),
160 "<i2" => data
161 .into_vec()
162 .try_into()
163 .map(|bytes| Scalar::I16(i16::from_le_bytes(bytes)))
164 .map_err(|_| de::Error::custom("Invalid data for i16")),
165 "<f2" => data
166 .into_vec()
167 .try_into()
168 .map(|bytes| Scalar::F16(f16::from_le_bytes(bytes)))
169 .map_err(|_| de::Error::custom("Invalid data for f16")),
170 "<u4" => data
171 .into_vec()
172 .try_into()
173 .map(|bytes| Scalar::U32(u32::from_le_bytes(bytes)))
174 .map_err(|_| de::Error::custom("Invalid data for u32")),
175 "<i4" => data
176 .into_vec()
177 .try_into()
178 .map(|bytes| Scalar::I32(i32::from_le_bytes(bytes)))
179 .map_err(|_| de::Error::custom("Invalid data for i32")),
180 "<f4" => data
181 .into_vec()
182 .try_into()
183 .map(|bytes| Scalar::F32(f32::from_le_bytes(bytes)))
184 .map_err(|_| de::Error::custom("Invalid data for f32")),
185 "<u8" => data
186 .into_vec()
187 .try_into()
188 .map(|bytes| Scalar::U64(u64::from_le_bytes(bytes)))
189 .map_err(|_| de::Error::custom("Invalid data for u64")),
190 "<i8" => data
191 .into_vec()
192 .try_into()
193 .map(|bytes| Scalar::I64(i64::from_le_bytes(bytes)))
194 .map_err(|_| de::Error::custom("Invalid data for i64")),
195 "<f8" => data
196 .into_vec()
197 .try_into()
198 .map(|bytes| Scalar::F64(f64::from_le_bytes(bytes)))
199 .map_err(|_| de::Error::custom("Invalid data for f64")),
200 _ => Ok(Scalar::Unsupported),
201 }
202 }
203 DType::Array(_) => Ok(Scalar::Unsupported),
204 }
205 }
206 }
207
208 deserializer.deserialize_map(ScalarVisitor)
209 }
210}
211
212impl Serialize for Scalar {
215 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
216 where
217 S: Serializer,
218 {
219 let mut state = serializer.serialize_map(Some(3))?;
220
221 state.serialize_entry(Bytes::new(b"nd"), &false)?;
222
223 match self {
224 Scalar::Bool(val) => serialize_value(&mut state, "|b1", &[*val as u8]),
226 Scalar::U8(val) => serialize_value(&mut state, "|u1", &[*val]),
227 Scalar::I8(val) => serialize_value(&mut state, "|i1", &val.to_le_bytes()),
228 Scalar::U16(val) => serialize_value(&mut state, "<u2", &val.to_le_bytes()),
229 Scalar::I16(val) => serialize_value(&mut state, "<i2", &val.to_le_bytes()),
230 Scalar::F16(val) => serialize_value(&mut state, "<f2", &val.to_le_bytes()),
231 Scalar::U32(val) => serialize_value(&mut state, "<u4", &val.to_le_bytes()),
232 Scalar::I32(val) => serialize_value(&mut state, "<i4", &val.to_le_bytes()),
233 Scalar::F32(val) => serialize_value(&mut state, "<f4", &val.to_le_bytes()),
234 Scalar::U64(val) => serialize_value(&mut state, "<u8", &val.to_le_bytes()),
235 Scalar::I64(val) => serialize_value(&mut state, "<i8", &val.to_le_bytes()),
236 Scalar::F64(val) => serialize_value(&mut state, "<f8", &val.to_le_bytes()),
237 Scalar::Unsupported => {
238 return Err(serde::ser::Error::custom("Unsupported numpy dtype"));
239 }
240 }?;
241
242 state.end()
243 }
244}
245
246fn serialize_value<S>(state: &mut S, type_str: &str, val: &[u8]) -> Result<(), S::Error>
247where
248 S: SerializeMap,
249{
250 state.serialize_entry(Bytes::new(b"type"), type_str)?;
251 state.serialize_entry(Bytes::new(b"data"), Bytes::new(val))
252}
253
254use ndarray::{Array, ArrayBase, IxDyn};
258use std::mem;
259
/// Errors raised while turning a raw numpy byte payload into an `ndarray` array.
#[derive(thiserror::Error, Debug)]
enum NDArrayError {
    /// The byte payload is not a whole number of elements of the target type; the
    /// string names the target type.
    #[error("InvalidDataLength: {0}")]
    InvalidDataLength(String),

    /// `ndarray` rejected the declared shape for the decoded elements (e.g. the
    /// element count does not match the shape).
    #[error("ArrayShapeError: {0}")]
    ArrayShapeError(ndarray::ShapeError),
}
268
269impl<'de> Deserialize<'de> for NDArray {
272 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
273 where
274 D: Deserializer<'de>,
275 {
276 struct NDArrayVisitor;
277
278 impl<'de> Visitor<'de> for NDArrayVisitor {
279 type Value = NDArray;
280
281 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
282 formatter.write_str("a numpy array in msgpack format")
283 }
284
285 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
286 where
287 A: de::MapAccess<'de>,
288 {
289 let mut nd: Option<bool> = None;
290 let mut numpy_dtype: Option<DType> = None;
291 let mut kind: Option<ByteBuf> = None;
292 let mut shape: Option<Vec<usize>> = None;
293 let mut data: Option<ByteBuf> = None;
294
295 while let Some(key) = map.next_key()? {
296 match key {
297 "nd" => nd = Some(map.next_value()?),
298 "type" => numpy_dtype = Some(map.next_value()?),
299 "kind" => kind = Some(map.next_value()?),
300 "shape" => shape = Some(map.next_value()?),
301 "data" => data = Some(map.next_value()?),
302 _ => {
303 return Err(de::Error::unknown_field(
304 key,
305 &["nd", "type", "kind", "shape", "data"],
306 ))
307 }
308 }
309 }
310
311 let nd = nd.ok_or_else(|| de::Error::missing_field("nd"))?;
312 let numpy_dtype = numpy_dtype.ok_or_else(|| de::Error::missing_field("type"))?;
313 let _kind = kind.ok_or_else(|| de::Error::missing_field("kind"))?;
314 let shape = shape.ok_or_else(|| de::Error::missing_field("shape"))?;
315 let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
316
317 if !nd {
318 return Err(de::Error::custom("nd should be true for numpy arrays"));
319 }
320
321 let shape = IxDyn(&shape);
322
323 match numpy_dtype {
326 DType::String(dtype) => {
327 match dtype.as_str() {
328 "|b1" => Array::from_shape_vec(
330 shape,
331 data.into_iter().map(|v| v != 0).collect(),
332 )
333 .map(NDArray::Bool)
334 .map_err(de::Error::custom),
335 "|u1" => Array::from_shape_vec(shape, data.into_vec())
336 .map(NDArray::U8)
337 .map_err(de::Error::custom),
338 "|i1" => create_ndarray_from_transmution::<i8>(data.into_vec(), shape)
339 .map(NDArray::I8)
340 .map_err(de::Error::custom),
341 "<u2" => create_ndarray_from_transmution::<u16>(data.into_vec(), shape)
342 .map(NDArray::U16)
343 .map_err(de::Error::custom),
344 "<i2" => create_ndarray_from_transmution::<i16>(data.into_vec(), shape)
345 .map(NDArray::I16)
346 .map_err(de::Error::custom),
347 "<f2" => create_ndarray_from_transmution::<f16>(data.into_vec(), shape)
348 .map(NDArray::F16)
349 .map_err(de::Error::custom),
350 "<u4" => create_ndarray_from_transmution::<u32>(data.into_vec(), shape)
351 .map(NDArray::U32)
352 .map_err(de::Error::custom),
353 "<i4" => create_ndarray_from_transmution::<i32>(data.into_vec(), shape)
354 .map(NDArray::I32)
355 .map_err(de::Error::custom),
356 "<f4" => create_ndarray_from_transmution::<f32>(data.into_vec(), shape)
357 .map(NDArray::F32)
358 .map_err(de::Error::custom),
359 "<u8" => create_ndarray_from_transmution::<u64>(data.into_vec(), shape)
360 .map(NDArray::U64)
361 .map_err(de::Error::custom),
362 "<i8" => create_ndarray_from_transmution::<i64>(data.into_vec(), shape)
363 .map(NDArray::I64)
364 .map_err(de::Error::custom),
365 "<f8" => create_ndarray_from_transmution::<f64>(data.into_vec(), shape)
366 .map(NDArray::F64)
367 .map_err(de::Error::custom),
368 _ => Ok(NDArray::Unsupported),
369 }
370 }
371 DType::Array(_) => Ok(NDArray::Unsupported),
372 }
373 }
374 }
375
376 deserializer.deserialize_map(NDArrayVisitor)
377 }
378}
379
380fn create_ndarray_from_transmution<T>(
408 data: Vec<u8>,
409 shape: IxDyn,
410) -> Result<Array<T, IxDyn>, NDArrayError> {
411 let transmuted = unsafe { transmute_vec(data) }.ok_or_else(|| {
412 NDArrayError::InvalidDataLength(format!(
413 "Invalid data length for {} transmutation",
414 std::any::type_name::<T>()
415 ))
416 })?;
417
418 Array::from_shape_vec(shape, transmuted).map_err(|e| NDArrayError::ArrayShapeError(e))
419}
420
/// Reinterprets a buffer of raw bytes as a `Vec<T>`.
///
/// The bytes are copied into a freshly allocated vector whose allocation is made for
/// `T`, so its alignment and size are correct.
///
/// Returns `None` if `T` is zero-sized (the element count would be undefined) or if
/// `data.len()` is not a multiple of `size_of::<T>()`.
///
/// # Safety
///
/// Every bit pattern of `size_of::<T>()` bytes must be a valid `T`. This holds for
/// plain-old-data such as the fixed-width integers, `f32`, `f64` and `half::f16`.
/// Bytes are interpreted in the host's native byte order.
unsafe fn transmute_vec<T>(data: Vec<u8>) -> Option<Vec<T>> {
    let size_of_t = mem::size_of::<T>();
    if size_of_t == 0 || data.len() % size_of_t != 0 {
        return None;
    }
    let len = data.len() / size_of_t;

    // The byte buffer's allocation has alignment 1 and a capacity that need not be a
    // multiple of `size_of::<T>()`. Handing it to `Vec<T>` through `Vec::from_raw_parts`
    // would violate that function's layout requirements (undefined behaviour on access
    // and on deallocation), so copy into an allocation made for `T` instead.
    let mut out: Vec<T> = Vec::with_capacity(len);
    // SAFETY: `out` has room for `len` elements, i.e. exactly `data.len()` bytes; the two
    // buffers are distinct allocations and cannot overlap; after the copy the first `len`
    // elements are initialised because the caller guarantees any bit pattern is valid.
    std::ptr::copy_nonoverlapping(data.as_ptr(), out.as_mut_ptr().cast::<u8>(), data.len());
    out.set_len(len);
    Some(out)
}
458
459impl Serialize for NDArray {
462 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
463 where
464 S: Serializer,
465 {
466 let mut state = serializer.serialize_map(Some(5))?;
467
468 state.serialize_entry(Bytes::new(b"nd"), &true)?;
469
470 match self {
471 NDArray::Bool(arr) => serialize_ndarray(&mut state, "|b1", &arr.mapv(|v| v as u8)),
473 NDArray::U8(arr) => serialize_ndarray(&mut state, "|u1", arr),
474 NDArray::I8(arr) => serialize_ndarray(&mut state, "|i1", arr),
475 NDArray::U16(arr) => serialize_ndarray(&mut state, "<u2", arr),
476 NDArray::I16(arr) => serialize_ndarray(&mut state, "<i2", arr),
477 NDArray::F16(arr) => serialize_ndarray(&mut state, "<f2", arr),
478 NDArray::U32(arr) => serialize_ndarray(&mut state, "<u4", arr),
479 NDArray::I32(arr) => serialize_ndarray(&mut state, "<i4", arr),
480 NDArray::F32(arr) => serialize_ndarray(&mut state, "<f4", arr),
481 NDArray::U64(arr) => serialize_ndarray(&mut state, "<u8", arr),
482 NDArray::I64(arr) => serialize_ndarray(&mut state, "<i8", arr),
483 NDArray::F64(arr) => serialize_ndarray(&mut state, "<f8", arr),
484 NDArray::Unsupported => {
485 return Err(serde::ser::Error::custom("Unsupported numpy dtype"));
486 }
487 }?;
488
489 state.end()
490 }
491}
492
493fn serialize_ndarray<S, A, T>(
494 state: &mut S,
495 type_str: &str,
496 arr: &ArrayBase<A, IxDyn>,
497) -> Result<(), S::Error>
498where
499 S: SerializeMap,
500 A: ndarray::RawData<Elem = T>,
501{
502 state.serialize_entry(Bytes::new(b"type"), type_str)?;
503 state.serialize_entry(Bytes::new(b"kind"), Bytes::new(b""))?;
504 state.serialize_entry(Bytes::new(b"shape"), &arr.shape())?;
505
506 let data = unsafe { transmute_array_to_slice(arr) };
507 state.serialize_entry(Bytes::new(b"data"), Bytes::new(data))
508}
509
510unsafe fn transmute_array_to_slice<A: ndarray::RawData<Elem = T>, T>(
530 arr: &ArrayBase<A, IxDyn>,
531) -> &[u8] {
532 let ptr = arr.as_ptr() as *const u8;
533 let len = arr.len() * mem::size_of::<T>();
534 std::slice::from_raw_parts(ptr, len)
535}
536
537use ndarray::{ArrayView, CowArray};
541
542impl<'de: 'a, 'a> Deserialize<'de> for CowNDArray<'a> {
545 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
546 where
547 D: Deserializer<'de>,
548 {
549 struct NDArrayVisitor<'a>(std::marker::PhantomData<&'a ()>);
550
551 impl<'de: 'a, 'a> Visitor<'de> for NDArrayVisitor<'a> {
552 type Value = CowNDArray<'a>;
553
554 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
555 formatter.write_str("a numpy array in msgpack format")
556 }
557
558 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
559 where
560 A: de::MapAccess<'de>,
561 {
562 let mut nd: Option<bool> = None;
563 let mut numpy_dtype: Option<DType> = None;
564 let mut kind: Option<&'a Bytes> = None;
565 let mut shape: Option<Vec<usize>> = None;
566 let mut data: Option<&'a Bytes> = None;
567
568 while let Some(key) = map.next_key()? {
569 match key {
570 "nd" => nd = Some(map.next_value()?),
571 "type" => numpy_dtype = Some(map.next_value()?),
572 "kind" => kind = Some(map.next_value()?),
573 "shape" => shape = Some(map.next_value()?),
574 "data" => data = Some(map.next_value()?),
575 _ => {
576 return Err(de::Error::unknown_field(
577 key,
578 &["nd", "type", "kind", "shape", "data"],
579 ))
580 }
581 }
582 }
583
584 let nd = nd.ok_or_else(|| de::Error::missing_field("nd"))?;
585 let numpy_dtype = numpy_dtype.ok_or_else(|| de::Error::missing_field("type"))?;
586 let _kind = kind.ok_or_else(|| de::Error::missing_field("kind"))?;
587 let shape = shape.ok_or_else(|| de::Error::missing_field("shape"))?;
588 let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
589
590 if !nd {
591 return Err(de::Error::custom("nd should be true for numpy arrays"));
592 }
593
594 let shape = IxDyn(&shape);
595
596 match numpy_dtype {
599 DType::String(dtype) => {
600 match dtype.as_str() {
601 "|b1" => Array::from_shape_vec(
603 shape,
604 data.into_iter().map(|v| *v != 0).collect(),
605 )
606 .map(CowArray::from)
607 .map(CowNDArray::Bool)
608 .map_err(de::Error::custom),
609 "|u1" => ArrayView::from_shape(shape, data)
610 .map(CowArray::from)
611 .map(CowNDArray::U8)
612 .map_err(de::Error::custom),
613 "|i1" => create_cowndarray_from_transmution::<i8>(data, shape)
614 .map(CowNDArray::I8)
615 .map_err(de::Error::custom),
616 "<u2" => create_cowndarray_from_transmution::<u16>(data, shape)
617 .map(CowNDArray::U16)
618 .map_err(de::Error::custom),
619 "<i2" => create_cowndarray_from_transmution::<i16>(data, shape)
620 .map(CowNDArray::I16)
621 .map_err(de::Error::custom),
622 "<f2" => create_cowndarray_from_transmution::<f16>(data, shape)
623 .map(CowNDArray::F16)
624 .map_err(de::Error::custom),
625 "<u4" => create_cowndarray_from_transmution::<u32>(data, shape)
626 .map(CowNDArray::U32)
627 .map_err(de::Error::custom),
628 "<i4" => create_cowndarray_from_transmution::<i32>(data, shape)
629 .map(CowNDArray::I32)
630 .map_err(de::Error::custom),
631 "<f4" => create_cowndarray_from_transmution::<f32>(data, shape)
632 .map(CowNDArray::F32)
633 .map_err(de::Error::custom),
634 "<u8" => create_cowndarray_from_transmution::<u64>(data, shape)
635 .map(CowNDArray::U64)
636 .map_err(de::Error::custom),
637 "<i8" => create_cowndarray_from_transmution::<i64>(data, shape)
638 .map(CowNDArray::I64)
639 .map_err(de::Error::custom),
640 "<f8" => create_cowndarray_from_transmution::<f64>(data, shape)
641 .map(CowNDArray::F64)
642 .map_err(de::Error::custom),
643 _ => Ok(CowNDArray::Unsupported),
644 }
645 }
646 DType::Array(_) => Ok(CowNDArray::Unsupported),
647 }
648 }
649 }
650
651 deserializer.deserialize_map(NDArrayVisitor(std::marker::PhantomData))
652 }
653}
654
655fn create_cowndarray_from_transmution<T: Clone>(
656 data: &[u8],
657 shape: IxDyn,
658) -> Result<CowArray<T, IxDyn>, NDArrayError> {
659 let transmuted = unsafe { transmute_slice(data) }.ok_or_else(|| {
660 NDArrayError::InvalidDataLength(format!(
661 "Invalid data length for {} transmutation",
662 std::any::type_name::<T>()
663 ))
664 })?;
665
666 match transmuted {
667 Cow::Borrowed(slice) => ArrayView::from_shape(shape, slice).map(CowArray::from),
668 Cow::Owned(vec) => Array::from_shape_vec(shape, vec).map(CowArray::from),
669 }
670 .map_err(|e| NDArrayError::ArrayShapeError(e))
671}
672
/// Reinterprets a byte slice as a slice of `T`, borrowing when possible.
///
/// Returns `Cow::Borrowed` when `data` is aligned for `T`. Otherwise the bytes are
/// copied into a newly allocated (and therefore aligned) `Vec<T>` and returned as
/// `Cow::Owned`. Returns `None` if `T` is zero-sized or `data.len()` is not a multiple
/// of `size_of::<T>()`.
///
/// # Safety
///
/// Every bit pattern of `size_of::<T>()` bytes must be a valid `T`. This holds for
/// plain-old-data such as the fixed-width integers, floats and `half::f16`. Bytes are
/// interpreted in the host's native byte order.
unsafe fn transmute_slice<T: Clone>(data: &[u8]) -> Option<Cow<[T]>> {
    let size_of_t = mem::size_of::<T>();
    // A zero-sized `T` has no defined element count (and would make `%` / `/` panic).
    if size_of_t == 0 || data.len() % size_of_t != 0 {
        return None;
    }
    let len = data.len() / size_of_t;

    if (data.as_ptr() as usize) % mem::align_of::<T>() == 0 {
        // SAFETY: the pointer is aligned for `T`, `len * size_of::<T>() == data.len()`
        // bytes are in bounds, and the slice's lifetime is tied to `data`. Bit-pattern
        // validity is the caller's obligation.
        Some(Cow::Borrowed(std::slice::from_raw_parts(
            data.as_ptr().cast::<T>(),
            len,
        )))
    } else {
        // Misaligned input cannot be viewed as `&[T]`; copy into an aligned buffer.
        let mut aligned: Vec<T> = Vec::with_capacity(len);
        // SAFETY: `aligned` has room for exactly `data.len()` bytes, the buffers are
        // distinct allocations, and after the copy the first `len` elements are
        // initialised.
        std::ptr::copy_nonoverlapping(
            data.as_ptr(),
            aligned.as_mut_ptr().cast::<u8>(),
            data.len(),
        );
        aligned.set_len(len);
        Some(Cow::Owned(aligned))
    }
}
702
703impl<'a> Serialize for CowNDArray<'a> {
706 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
707 where
708 S: Serializer,
709 {
710 let mut state = serializer.serialize_map(Some(5))?;
711
712 state.serialize_entry(Bytes::new(b"nd"), &true)?;
713
714 match self {
715 CowNDArray::Bool(arr) => serialize_ndarray(&mut state, "|b1", &arr.mapv(|v| v as u8)),
717 CowNDArray::U8(arr) => serialize_ndarray(&mut state, "|u1", arr),
718 CowNDArray::I8(arr) => serialize_ndarray(&mut state, "|i1", arr),
719 CowNDArray::U16(arr) => serialize_ndarray(&mut state, "<u2", arr),
720 CowNDArray::I16(arr) => serialize_ndarray(&mut state, "<i2", arr),
721 CowNDArray::F16(arr) => serialize_ndarray(&mut state, "<f2", arr),
722 CowNDArray::U32(arr) => serialize_ndarray(&mut state, "<u4", arr),
723 CowNDArray::I32(arr) => serialize_ndarray(&mut state, "<i4", arr),
724 CowNDArray::F32(arr) => serialize_ndarray(&mut state, "<f4", arr),
725 CowNDArray::U64(arr) => serialize_ndarray(&mut state, "<u8", arr),
726 CowNDArray::I64(arr) => serialize_ndarray(&mut state, "<i8", arr),
727 CowNDArray::F64(arr) => serialize_ndarray(&mut state, "<f8", arr),
728 CowNDArray::Unsupported => {
729 return Err(serde::ser::Error::custom("Unsupported numpy dtype"));
730 }
731 }?;
732
733 state.end()
734 }
735}
736
// Round-trip tests: each value is encoded with `rmp_serde::to_vec_named`, decoded with
// `rmp_serde::from_slice` (a borrowing, in-memory deserializer) and compared to the input.
#[cfg(test)]
mod tests {
    use crate::core::{CowNDArray, NDArray, Scalar};
    use half::f16;
    use ndarray::Array;

    // Covers every scalar dtype, using boundary values for the integer types.
    #[test]
    fn test_scalar_serialization() {
        let cases = vec![
            Scalar::Bool(true),
            Scalar::U8(255),
            Scalar::I8(-128),
            Scalar::U16(65535),
            Scalar::I16(-32768),
            Scalar::F16(f16::from_f32(1.0)),
            Scalar::U32(4294967295),
            Scalar::I32(-2147483648),
            Scalar::F32(1.0),
            Scalar::U64(18446744073709551615),
            Scalar::I64(-9223372036854775808),
            Scalar::F64(1.0),
        ];

        for scalar in cases {
            let serialized = rmp_serde::to_vec_named(&scalar).unwrap();
            let deserialized: Scalar = rmp_serde::from_slice(&serialized).unwrap();
            assert_eq!(deserialized, scalar);
        }
    }

    // One small 1-D array per supported dtype, compared with `PartialEq` on `NDArray`.
    #[test]
    #[rustfmt::skip]
    fn test_ndarray_serialization() {
        let cases = vec![
            NDArray::Bool(Array::from_vec(vec![true, false]).into_dyn().into()),
            NDArray::U8(Array::from_vec(vec![1, 2, 3]).into_dyn().into()),
            NDArray::I8(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            NDArray::U16(Array::from_vec(vec![1, 2, 3]).into_dyn().into()),
            NDArray::I16(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            NDArray::F16(Array::from_vec(vec![1.0, 2.0]).into_dyn().mapv(f16::from_f32).into()),
            NDArray::U32(Array::from_vec(vec![1, 2, 3]).into_dyn().into()),
            NDArray::I32(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            NDArray::F32(Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn().into()),
            NDArray::U64(Array::from_vec(vec![1, 2]).into_dyn().into()),
            NDArray::I64(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            NDArray::F64(Array::from_vec(vec![1.0, 2.0]).into_dyn().into()),
        ];

        for ndarray in cases {
            let serialized = rmp_serde::to_vec_named(&ndarray).unwrap();
            let deserialized: NDArray = rmp_serde::from_slice(&serialized).unwrap();

            assert_eq!(deserialized, ndarray);
        }
    }

    // Same cases for the borrowing `CowNDArray`. The comparison is done per variant;
    // float variants compare shapes and then elements via `assert_float_eq`.
    #[test]
    #[rustfmt::skip]
    fn test_cowndarray_serialization() {
        // Equality that treats two NaNs as equal, and two infinities as equal when their
        // signs match.
        fn assert_float_eq<T>(a: T, b: T)
        where
            T: num_traits::Float + std::fmt::Debug,
        {
            if a.is_nan() && b.is_nan() {
                return;
            }
            if a.is_infinite() && b.is_infinite() {
                assert_eq!(
                    a.signum(),
                    b.signum(),
                    "Infinite values have different signs"
                );
                return;
            }
            assert_eq!(a, b);
        }
        let cases = vec![
            CowNDArray::Bool(Array::from_vec(vec![true, false]).into_dyn().into()),
            CowNDArray::U8(Array::from_vec(vec![1, 2, 3]).into_dyn().into()),
            CowNDArray::I8(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            CowNDArray::U16(Array::from_vec(vec![1, 2, 3]).into_dyn().into()),
            CowNDArray::I16(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            CowNDArray::F16(Array::from_vec(vec![1.0, 2.0]).into_dyn().mapv(f16::from_f32).into()),
            CowNDArray::U32(Array::from_vec(vec![1, 2, 3]).into_dyn().into()),
            CowNDArray::I32(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            CowNDArray::F32(Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn().into()),
            CowNDArray::U64(Array::from_vec(vec![1, 2]).into_dyn().into()),
            CowNDArray::I64(Array::from_vec(vec![-1, 0, 1]).into_dyn().into()),
            CowNDArray::F64(Array::from_vec(vec![1.0, 2.0]).into_dyn().into()),
        ];

        for ndarray in cases {
            let serialized = rmp_serde::to_vec_named(&ndarray).unwrap();
            let deserialized: CowNDArray = rmp_serde::from_slice(&serialized).unwrap();

            // A variant mismatch (e.g. decoding to the wrong dtype) hits the final arm.
            match (deserialized, ndarray) {
                (CowNDArray::Bool(a), CowNDArray::Bool(b)) => assert_eq!(a, b),
                (CowNDArray::U8(a), CowNDArray::U8(b)) => assert_eq!(a, b),
                (CowNDArray::U16(a), CowNDArray::U16(b)) => assert_eq!(a, b),
                (CowNDArray::U32(a), CowNDArray::U32(b)) => assert_eq!(a, b),
                (CowNDArray::U64(a), CowNDArray::U64(b)) => assert_eq!(a, b),
                (CowNDArray::I8(a), CowNDArray::I8(b)) => assert_eq!(a, b),
                (CowNDArray::I16(a), CowNDArray::I16(b)) => assert_eq!(a, b),
                (CowNDArray::I32(a), CowNDArray::I32(b)) => assert_eq!(a, b),
                (CowNDArray::I64(a), CowNDArray::I64(b)) => assert_eq!(a, b),
                (CowNDArray::F16(a), CowNDArray::F16(b)) => {
                    assert_eq!(a.shape(), b.shape());
                    a.iter().zip(b.iter()).for_each(|(x, y)| {
                        assert_float_eq(x.to_f32(), y.to_f32());
                    });
                }
                (CowNDArray::F32(a), CowNDArray::F32(b)) => {
                    assert_eq!(a.shape(), b.shape());
                    a.iter().zip(b.iter()).for_each(|(x, y)| {
                        assert_float_eq(*x, *y);
                    });
                }
                (CowNDArray::F64(a), CowNDArray::F64(b)) => {
                    assert_eq!(a.shape(), b.shape());
                    a.iter().zip(b.iter()).for_each(|(x, y)| {
                        assert_float_eq(*x, *y);
                    });
                }
                (CowNDArray::Unsupported, CowNDArray::Unsupported) => (),
                _ => panic!("Mismatched types"),
            }
        }
    }
}