ndtensor/impls/
impl_tensor_serde.rs

1/*
2    appellation: impl_tensor_serde <module>
3    authors: @FL03
4*/
5use crate::tensor::TensorBase;
6
7use core::marker::PhantomData;
8use ndarray::{Data, DataOwned, Dimension, RawData};
9use serde::de::{Deserialize, Deserializer, Visitor};
10use serde::ser::{Serialize, SerializeStruct, Serializer};
11
/// The serialized field names of [`TensorBase`], passed to `deserialize_struct`.
///
/// This must agree with the key written by the `Serialize` impl (`"data"`); previously it
/// listed `"store"`, so the declared field set did not match what was actually serialized.
const FIELDS: [&str; 1] = ["data"];
14
/// A serde [`Visitor`] that rebuilds a [`TensorBase`] from its serialized store.
///
/// The type parameters mirror those of `TensorBase<S, D>`. The visitor carries no data;
/// it holds only a [`PhantomData`] marker. Trait bounds live on the `Visitor` impl rather
/// than on the struct definition, so the type itself places no requirements on `S` or `D`.
pub struct TensorBaseVisitor<S, D> {
    _phantom: PhantomData<(S, D)>,
}
22
23impl<'a, A, S, D> Visitor<'a> for TensorBaseVisitor<S, D>
24where
25    A: Deserialize<'a>,
26    D: Dimension + Deserialize<'a>,
27    S: DataOwned<Elem = A>,
28{
29    type Value = TensorBase<S, D>;
30
31    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
32        formatter.write_str("a tensor with data")
33    }
34
35    fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
36    where
37        V: serde::de::SeqAccess<'a>,
38    {
39        let store = seq
40            .next_element()?
41            .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
42        Ok(TensorBase { store })
43    }
44}
45
46impl<'a, A, S, D> Deserialize<'a> for TensorBase<S, D>
47where
48    A: Deserialize<'a>,
49    D: Dimension + Deserialize<'a>,
50    S: DataOwned<Elem = A>,
51{
52    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
53    where
54        De: Deserializer<'a>,
55    {
56        deserializer.deserialize_struct(
57            "TensorBase",
58            &FIELDS,
59            TensorBaseVisitor {
60                _phantom: PhantomData,
61            },
62        )
63    }
64}
65
66impl<A, S, D> Serialize for TensorBase<S, D>
67where
68    A: Serialize,
69    D: Dimension + Serialize,
70    S: Data<Elem = A>,
71{
72    fn serialize<Ser>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>
73    where
74        Ser: Serializer,
75    {
76        let mut state = serializer.serialize_struct("TensorBase", 1)?;
77        state.serialize_field("data", self.store())?;
78        state.end()
79    }
80}