// concision_neural/model/params/impl_model_params_serde.rs

/*
    appellation: impl_model_params_serde <module>
    authors: @FL03
*/
use crate::model::ModelParamsBase;

use crate::RawHidden;
use core::fmt;
use core::marker::PhantomData;
use ndarray::{Data, DataOwned, Dimension, RawData};
use serde::de::{Deserialize, Deserializer, Error, Visitor};
use serde::ser::{Serialize, SerializeStruct, Serializer};

/// Names of the fields of [`ModelParamsBase`], listed in the same order in which the
/// `Serialize` impl writes them (`input`, `hidden`, `output`).
///
/// Passed to `deserialize_struct` so the deserializer knows which fields to expect.
const FIELDS: [&str; 3] = [
    "input",
    "hidden",
    "output",
];

17struct ModelParamsBaseVisitor<S, D, H>
18where
19    D: Dimension,
20    S: RawData,
21    H: RawHidden<S, D>,
22{
23    marker: PhantomData<(S, D, H)>,
24}
25
26impl<'a, A, S, D, H> Visitor<'a> for ModelParamsBaseVisitor<S, D, H>
27where
28    A: Deserialize<'a>,
29    D: Dimension + Deserialize<'a>,
30    S: DataOwned<Elem = A>,
31    H: RawHidden<S, D> + Deserialize<'a>,
32    <D as Dimension>::Smaller: Deserialize<'a>,
33{
34    type Value = ModelParamsBase<S, D, H>;
35
36    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
37        formatter.write_str("The visitor is expecting to receive a `ModelParamsBase` object.")
38    }
39
40    fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
41    where
42        V: serde::de::SeqAccess<'a>,
43    {
44        let input = seq
45            .next_element()?
46            .ok_or_else(|| Error::invalid_length(1, &self))?;
47        let hidden = seq
48            .next_element()?
49            .ok_or_else(|| Error::invalid_length(2, &self))?;
50        let output = seq
51            .next_element()?
52            .ok_or_else(|| Error::invalid_length(3, &self))?;
53
54        Ok(ModelParamsBase {
55            input,
56            hidden,
57            output,
58        })
59    }
60}
61
62impl<'a, A, S, D, H> Deserialize<'a> for ModelParamsBase<S, D, H>
63where
64    A: Deserialize<'a>,
65    D: Dimension + Deserialize<'a>,
66    S: DataOwned<Elem = A>,
67    H: RawHidden<S, D> + Deserialize<'a>,
68    <D as Dimension>::Smaller: Deserialize<'a>,
69{
70    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
71    where
72        De: Deserializer<'a>,
73    {
74        deserializer.deserialize_struct(
75            "ModelParamsBase",
76            &FIELDS,
77            ModelParamsBaseVisitor {
78                marker: PhantomData,
79            },
80        )
81    }
82}
83
84impl<A, S, D, H> Serialize for ModelParamsBase<S, D, H>
85where
86    A: Serialize,
87    D: Dimension + Serialize,
88    S: Data<Elem = A>,
89    H: RawHidden<S, D> + Serialize,
90    <D as Dimension>::Smaller: Serialize,
91{
92    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
93    where
94        Se: Serializer,
95    {
96        let mut state = serializer.serialize_struct("ModelParamsBase", 3)?;
97        state.serialize_field("input", &self.input)?;
98        state.serialize_field("hidden", &self.hidden)?;
99        state.serialize_field("output", &self.output)?;
100        state.end()
101    }
102}