concision_core/params/impls/
impl_params_serde.rs

1/*
2    appellation: impl_params_serde <module>
3    authors: @FL03
4*/
5use crate::params::ParamsBase;
6use ndarray::{Data, DataOwned, Dimension, RawData};
7use serde::de::{Deserialize, Deserializer, Error, Visitor};
8use serde::ser::{Serialize, SerializeStruct, Serializer};
9
10use core::marker::PhantomData;
11
/// Names of the serialized `ParamsBase` fields, handed to
/// `Deserializer::deserialize_struct` as the expected field list. They match
/// the names (and order) written by the `Serialize` impl in this file.
const FIELDS: [&str; 2] = ["bias", "weights"];
13
14struct ParamsBaseVisitor<S, D>
15where
16    D: Dimension,
17    S: RawData,
18{
19    marker: PhantomData<(S, D)>,
20}
21
22impl<'a, A, S, D> Visitor<'a> for ParamsBaseVisitor<S, D>
23where
24    D: Dimension + Deserialize<'a>,
25    S: DataOwned<Elem = A>,
26    A: Deserialize<'a>,
27    <D as ndarray::Dimension>::Smaller: Deserialize<'a>,
28{
29    type Value = ParamsBase<S, D>;
30
31    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
32        formatter.write_str("a ParamsBase object")
33    }
34
35    fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
36    where
37        V: serde::de::SeqAccess<'a>,
38    {
39        let bias = seq
40            .next_element()?
41            .ok_or_else(|| Error::invalid_length(1, &self))?;
42        let weights = seq
43            .next_element()?
44            .ok_or_else(|| Error::invalid_length(2, &self))?;
45
46        Ok(ParamsBase { bias, weights })
47    }
48}
49
50impl<'a, A, S, D> Deserialize<'a> for ParamsBase<S, D>
51where
52    D: Dimension + Deserialize<'a>,
53    S: DataOwned<Elem = A>,
54    A: Deserialize<'a>,
55    <D as ndarray::Dimension>::Smaller: Deserialize<'a>,
56{
57    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
58    where
59        De: Deserializer<'a>,
60    {
61        deserializer.deserialize_struct(
62            "ParamsBase",
63            &FIELDS,
64            ParamsBaseVisitor {
65                marker: PhantomData,
66            },
67        )
68    }
69}
70
71impl<A, S, D> Serialize for ParamsBase<S, D>
72where
73    A: Serialize,
74    D: Dimension + Serialize,
75    S: Data<Elem = A>,
76    <D as ndarray::Dimension>::Smaller: Serialize,
77{
78    fn serialize<Ser>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>
79    where
80        Ser: Serializer,
81    {
82        let mut state = serializer.serialize_struct("ParamsBase", 2)?;
83        state.serialize_field("bias", self.bias())?;
84        state.serialize_field("weights", self.weights())?;
85        state.end()
86    }
87}