auto_diff/serde/
var_inner.rs1#[cfg(feature = "use-serde")]
2use serde::{Serialize, Deserialize, Serializer, Deserializer,
3 ser::SerializeStruct,
4 de, de::Visitor, de::SeqAccess, de::MapAccess};
5use std::fmt;
6
7use crate::var_inner::VarInner;
8
9impl Serialize for VarInner {
10 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
11 where S: Serializer, {
12 let mut state = serializer.serialize_struct("VarInner", 3)?;
14 state.serialize_field("id", &self.get_id())?;
15 state.serialize_field("need_grad", &self.get_need_grad())?;
16 state.serialize_field("net", &*self.get_net().borrow())?;
17 state.end()
18 }
19}
20
21impl<'de> Deserialize<'de> for VarInner {
22 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
23 where D: Deserializer<'de>, {
24
25 enum Field { Id, NeedGrad, Net }
26
27 impl<'de> Deserialize<'de> for Field {
28 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
29 where D: Deserializer<'de>, {
30 struct FieldVisitor;
31
32 impl<'de> Visitor<'de> for FieldVisitor {
33 type Value = Field;
34
35 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
36 formatter.write_str("id, need_grad, or net")
37 }
38
39 fn visit_str<E>(self, value: &str) -> Result<Field, E>
40 where E: de::Error, {
41 match value {
42 "id" => Ok(Field::Id),
43 "need_grad" => Ok(Field::NeedGrad),
44 "net" => Ok(Field::Net),
45 _ => Err(de::Error::unknown_field(value, &FIELDS)),
46 }
47 }
48 }
49
50 deserializer.deserialize_identifier(FieldVisitor)
51 }
52 }
53
54 struct VarInnerVisitor;
55
56 impl<'de> Visitor<'de> for VarInnerVisitor {
57 type Value = VarInner;
58
59 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
60 formatter.write_str("struct VarInner")
61 }
62
63 fn visit_map<V>(self, mut map: V) -> Result<VarInner, V::Error>
64 where V: MapAccess<'de>, {
65 let mut id = None;
66 let mut need_grad = None;
67 let mut net = None;
68 while let Some(key) = map.next_key()? {
69 match key {
70 Field::Id => {
71 if id.is_some() {
72 return Err(de::Error::duplicate_field("id"));
73 }
74 id = Some(map.next_value()?);
75 },
76 Field::NeedGrad => {
77 if need_grad.is_some() {
78 return Err(de::Error::duplicate_field("need_grad"));
79 }
80 need_grad = Some(map.next_value()?);
81 },
82 Field::Net => {
83 if net.is_some() {
84 return Err(de::Error::duplicate_field("net"));
85 }
86 net = Some(map.next_value()?);
87 }
88 }
89 }
90 let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
91 let need_grad = need_grad.ok_or_else(|| de::Error::missing_field("need_grad"))?;
92 let net = net.ok_or_else(|| de::Error::missing_field("net"))?;
93 Ok(VarInner::set_inner(id, need_grad, net))
94 }
95
96 fn visit_seq<V>(self, mut seq: V) -> Result<VarInner, V::Error>
97 where V: SeqAccess<'de>, {
98 let id = seq.next_element()?
99 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
100 let need_grad = seq.next_element()?
101 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
102 let net = seq.next_element()?
103 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
104 Ok(VarInner::set_inner(id, need_grad, net))
105 }
106 }
107
108 const FIELDS: [&str; 3] = ["id", "need_grad", "net"];
109 deserializer.deserialize_struct("Duration", &FIELDS, VarInnerVisitor)
110 }
111}
112
#[cfg(all(test, feature = "use-serde"))]
mod tests {
    use crate::var::Var;
    use crate::var_inner::VarInner;
    use rand::prelude::*;

    /// Builds a small computation, round-trips the resulting `VarInner`
    /// through pickle, and checks that `id` and `need_grad` are preserved.
    #[test]
    fn test_serde_var_inner() {
        let mut rng = StdRng::seed_from_u64(671);
        let n = 10;
        let data = Var::normal(&mut rng, &vec![n, 2], 0., 2.);
        let result = data.matmul(&Var::new(&vec![2., 3.], &vec![2, 1])).unwrap()
            + Var::new(&vec![1.], &vec![1]);

        let serialized = serde_pickle::to_vec(&*result.inner().borrow(), true).unwrap();
        let deserialized: VarInner = serde_pickle::from_slice(&serialized).unwrap();

        // Previously the test only printed the graph; pin the round trip so a
        // broken (de)serializer actually fails it. `==` only needs PartialEq.
        assert!(deserialized.get_id() == result.inner().borrow().get_id());
        assert!(deserialized.get_need_grad() == result.inner().borrow().get_need_grad());

        println!("{:?}", deserialized.dump_net());
    }
}