ndtensor/impls/
impl_tensor_serde.rs1use crate::tensor::TensorBase;
6
7use core::marker::PhantomData;
8use ndarray::{Data, DataOwned, Dimension, RawData};
9use serde::de::{Deserialize, Deserializer, Visitor};
10use serde::ser::{Serialize, SerializeStruct, Serializer};
11
/// Name of the single field a `TensorBase` is (de)serialized under.
const STORE_FIELD: &str = "store";

/// Field names handed to `Deserializer::deserialize_struct` for `TensorBase`.
const FIELDS: [&str; 1] = [STORE_FIELD];
14
/// Serde [`Visitor`] that rebuilds a [`TensorBase`] from serialized input.
///
/// The visitor holds no data. Its type parameters only pick the storage type
/// (`S`) and dimension type (`D`) of the tensor being produced. Trait bounds
/// live on the `Visitor` impl rather than on the struct, so naming this type
/// imposes no requirements on `S` or `D`.
pub struct TensorBaseVisitor<S, D> {
    _phantom: PhantomData<(S, D)>,
}

impl<S, D> TensorBaseVisitor<S, D> {
    /// Creates a visitor. This is the only way to build one outside this
    /// module, because the marker field is private.
    pub const fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<S, D> Default for TensorBaseVisitor<S, D> {
    fn default() -> Self {
        Self::new()
    }
}
22
23impl<'a, A, S, D> Visitor<'a> for TensorBaseVisitor<S, D>
24where
25 A: Deserialize<'a>,
26 D: Dimension + Deserialize<'a>,
27 S: DataOwned<Elem = A>,
28{
29 type Value = TensorBase<S, D>;
30
31 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
32 formatter.write_str("a tensor with data")
33 }
34
35 fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
36 where
37 V: serde::de::SeqAccess<'a>,
38 {
39 let store = seq
40 .next_element()?
41 .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
42 Ok(TensorBase { store })
43 }
44}
45
46impl<'a, A, S, D> Deserialize<'a> for TensorBase<S, D>
47where
48 A: Deserialize<'a>,
49 D: Dimension + Deserialize<'a>,
50 S: DataOwned<Elem = A>,
51{
52 fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
53 where
54 De: Deserializer<'a>,
55 {
56 deserializer.deserialize_struct(
57 "TensorBase",
58 &FIELDS,
59 TensorBaseVisitor {
60 _phantom: PhantomData,
61 },
62 )
63 }
64}
65
66impl<A, S, D> Serialize for TensorBase<S, D>
67where
68 A: Serialize,
69 D: Dimension + Serialize,
70 S: Data<Elem = A>,
71{
72 fn serialize<Ser>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>
73 where
74 Ser: Serializer,
75 {
76 let mut state = serializer.serialize_struct("TensorBase", 1)?;
77 state.serialize_field("data", self.store())?;
78 state.end()
79 }
80}