1#[cfg(feature = "use-serde")]
2use serde::{Serialize, Deserialize, Serializer, Deserializer,
3 ser::SerializeStruct,
4 de, de::Visitor, de::SeqAccess, de::MapAccess};
5use std::fmt;
6
7use crate::var::Var;
8
9impl Serialize for Var {
10 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
11 where S: Serializer, {
12 let mut state = serializer.serialize_struct("Var", 1)?;
14 state.serialize_field("var", &*self.inner().borrow())?;
15 state.end()
16 }
17}
18
19impl<'de> Deserialize<'de> for Var {
20 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
21 where D: Deserializer<'de>, {
22
23 enum Field { Var }
24
25 impl<'de> Deserialize<'de> for Field {
26 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
27 where D: Deserializer<'de>, {
28 struct FieldVisitor;
29
30 impl<'de> Visitor<'de> for FieldVisitor {
31 type Value = Field;
32
33 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
34 formatter.write_str("var")
35 }
36
37 fn visit_str<E>(self, value: &str) -> Result<Field, E>
38 where E: de::Error, {
39 match value {
40 "var" => Ok(Field::Var),
41 _ => Err(de::Error::unknown_field(value, &FIELDS)),
42 }
43 }
44 }
45
46 deserializer.deserialize_identifier(FieldVisitor)
47 }
48 }
49
50 struct VarVisitor;
51
52 impl<'de> Visitor<'de> for VarVisitor {
53 type Value = Var;
54
55 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
56 formatter.write_str("struct Var")
57 }
58
59 fn visit_map<V>(self, mut map: V) -> Result<Var, V::Error>
60 where V: MapAccess<'de>, {
61 let mut var = None;
62 while let Some(key) = map.next_key()? {
63 match key {
64 Field::Var => {
65 if var.is_some() {
66 return Err(de::Error::duplicate_field("var"));
67 }
68 var = Some(map.next_value()?);
69 },
70 }
71 }
72 let var = var.ok_or_else(|| de::Error::missing_field("id"))?;
73 Ok(Var::set_inner(var))
74 }
75
76 fn visit_seq<V>(self, mut seq: V) -> Result<Var, V::Error>
77 where V: SeqAccess<'de>, {
78 let var = seq.next_element()?
79 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
80 Ok(Var::set_inner(var))
81 }
82 }
83
84 const FIELDS: [&str; 1] = ["var"];
85 deserializer.deserialize_struct("Duration", &FIELDS, VarVisitor)
86 }
87}
88
#[cfg(all(test, feature = "use-serde"))]
mod tests {
    use crate::var::Var;
    use rand::prelude::*;

    /// Round-trips a computed `Var` through `serde_pickle` and checks that the
    /// deserialized value compares equal to the original.
    #[test]
    fn test_serde_var_inner() {
        let n = 10;
        let mut rng = StdRng::seed_from_u64(671);
        let samples = Var::normal(&mut rng, &vec![n, 2], 0., 2.);

        // samples x weights + bias, built in the same order as a single
        // chained expression would evaluate it.
        let expected = {
            let weights = Var::new(&vec![2., 3.], &vec![2, 1]);
            let product = samples.matmul(&weights).unwrap();
            product + Var::new(&vec![1.], &vec![1])
        };

        let bytes = serde_pickle::to_vec(&expected, true).unwrap();
        let restored: Var = serde_pickle::from_slice(&bytes).unwrap();
        println!("{:?}", restored.dump_net());
        assert_eq!(expected, restored);
    }
}