// concision_core/params/impls/impl_params_serde.rs
use core::marker::PhantomData;

use ndarray::{Data, DataOwned, Dimension, RawData};
use serde::de::{Deserialize, Deserializer, Error, Visitor};
use serde::ser::{Serialize, SerializeStruct, Serializer};

use crate::params::ParamsBase;
12const FIELDS: [&str; 2] = ["bias", "weights"];
13
14struct ParamsBaseVisitor<S, D>
15where
16 D: Dimension,
17 S: RawData,
18{
19 marker: PhantomData<(S, D)>,
20}
21
22impl<'a, A, S, D> Visitor<'a> for ParamsBaseVisitor<S, D>
23where
24 D: Dimension + Deserialize<'a>,
25 S: DataOwned<Elem = A>,
26 A: Deserialize<'a>,
27 <D as ndarray::Dimension>::Smaller: Deserialize<'a>,
28{
29 type Value = ParamsBase<S, D>;
30
31 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
32 formatter.write_str("a ParamsBase object")
33 }
34
35 fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
36 where
37 V: serde::de::SeqAccess<'a>,
38 {
39 let bias = seq
40 .next_element()?
41 .ok_or_else(|| Error::invalid_length(1, &self))?;
42 let weights = seq
43 .next_element()?
44 .ok_or_else(|| Error::invalid_length(2, &self))?;
45
46 Ok(ParamsBase { bias, weights })
47 }
48}
49
50impl<'a, A, S, D> Deserialize<'a> for ParamsBase<S, D>
51where
52 D: Dimension + Deserialize<'a>,
53 S: DataOwned<Elem = A>,
54 A: Deserialize<'a>,
55 <D as ndarray::Dimension>::Smaller: Deserialize<'a>,
56{
57 fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
58 where
59 De: Deserializer<'a>,
60 {
61 deserializer.deserialize_struct(
62 "ParamsBase",
63 &FIELDS,
64 ParamsBaseVisitor {
65 marker: PhantomData,
66 },
67 )
68 }
69}
70
71impl<A, S, D> Serialize for ParamsBase<S, D>
72where
73 A: Serialize,
74 D: Dimension + Serialize,
75 S: Data<Elem = A>,
76 <D as ndarray::Dimension>::Smaller: Serialize,
77{
78 fn serialize<Ser>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>
79 where
80 Ser: Serializer,
81 {
82 let mut state = serializer.serialize_struct("ParamsBase", 2)?;
83 state.serialize_field("bias", self.bias())?;
84 state.serialize_field("weights", self.weights())?;
85 state.end()
86 }
87}