concision_neural/model/params/
impl_model_params_serde.rs1use crate::model::ModelParamsBase;
6
7use crate::RawHidden;
8use core::marker::PhantomData;
9use ndarray::{Data, DataOwned, Dimension, RawData};
10use serde::de::{Deserialize, Deserializer, Error, Visitor};
11use serde::ser::{Serialize, SerializeStruct, Serializer};
12
/// Field names of a serialized `ModelParamsBase`, in the order they are
/// written by its `Serialize` impl and announced to `deserialize_struct`.
const FIELDS: [&str; 3] = ["input", "hidden", "output"];

/// Serde visitor that produces a `ModelParamsBase<S, D, H>`.
///
/// It stores no data; `marker` only carries the type parameters. No trait
/// bounds are placed on the struct itself: the `Visitor` impl declares the
/// bounds it actually needs, so constructing the marker stays unconstrained.
struct ModelParamsBaseVisitor<S, D, H> {
    marker: PhantomData<(S, D, H)>,
}

26impl<'a, A, S, D, H> Visitor<'a> for ModelParamsBaseVisitor<S, D, H>
27where
28 A: Deserialize<'a>,
29 D: Dimension + Deserialize<'a>,
30 S: DataOwned<Elem = A>,
31 H: RawHidden<S, D> + Deserialize<'a>,
32 <D as Dimension>::Smaller: Deserialize<'a>,
33{
34 type Value = ModelParamsBase<S, D, H>;
35
36 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
37 formatter.write_str("The visitor is expecting to receive a `ModelParamsBase` object.")
38 }
39
40 fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
41 where
42 V: serde::de::SeqAccess<'a>,
43 {
44 let input = seq
45 .next_element()?
46 .ok_or_else(|| Error::invalid_length(1, &self))?;
47 let hidden = seq
48 .next_element()?
49 .ok_or_else(|| Error::invalid_length(2, &self))?;
50 let output = seq
51 .next_element()?
52 .ok_or_else(|| Error::invalid_length(3, &self))?;
53
54 Ok(ModelParamsBase {
55 input,
56 hidden,
57 output,
58 })
59 }
60}
61
62impl<'a, A, S, D, H> Deserialize<'a> for ModelParamsBase<S, D, H>
63where
64 A: Deserialize<'a>,
65 D: Dimension + Deserialize<'a>,
66 S: DataOwned<Elem = A>,
67 H: RawHidden<S, D> + Deserialize<'a>,
68 <D as Dimension>::Smaller: Deserialize<'a>,
69{
70 fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
71 where
72 De: Deserializer<'a>,
73 {
74 deserializer.deserialize_struct(
75 "ModelParamsBase",
76 &FIELDS,
77 ModelParamsBaseVisitor {
78 marker: PhantomData,
79 },
80 )
81 }
82}
83
84impl<A, S, D, H> Serialize for ModelParamsBase<S, D, H>
85where
86 A: Serialize,
87 D: Dimension + Serialize,
88 S: Data<Elem = A>,
89 H: RawHidden<S, D> + Serialize,
90 <D as Dimension>::Smaller: Serialize,
91{
92 fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
93 where
94 Se: Serializer,
95 {
96 let mut state = serializer.serialize_struct("ModelParamsBase", 3)?;
97 state.serialize_field("input", &self.input)?;
98 state.serialize_field("hidden", &self.hidden)?;
99 state.serialize_field("output", &self.output)?;
100 state.end()
101 }
102}