auto_diff/serde/
var.rs

1#[cfg(feature = "use-serde")]
2use serde::{Serialize, Deserialize, Serializer, Deserializer,
3	    ser::SerializeStruct,
4	    de, de::Visitor, de::SeqAccess, de::MapAccess};
5use std::fmt;
6
7use crate::var::Var;
8
/// Serializes a [`Var`] as a struct named `"Var"` with a single field, `var`,
/// whose value is the data borrowed from `self.inner()`.
///
/// Gated on the `use-serde` feature, the same feature that brings the
/// `serde` names used here into scope.
#[cfg(feature = "use-serde")]
impl Serialize for Var {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Exactly one field (`var`) is written below; keep this count in sync
        // with the number of `serialize_field` calls.
        let mut state = serializer.serialize_struct("Var", 1)?;
        // Dereference the borrow guard so the inner value itself is serialized;
        // the temporary guard is released at the end of this statement.
        state.serialize_field("var", &*self.inner().borrow())?;
        state.end()
    }
}
18
/// Deserializes a [`Var`] from the shape written by its `Serialize` impl:
/// a struct named `"Var"` with a single field `var`, whose value is passed to
/// `Var::set_inner` to rebuild the `Var`.
///
/// Both map-shaped encodings (`{"var": ...}`) and sequence-shaped encodings
/// (`[...]`, first element is the inner value) are accepted.
///
/// # Errors
/// Returns the deserializer's error type on an unknown field, a duplicated
/// `var` field, a missing `var` field, or an empty sequence.
#[cfg(feature = "use-serde")]
impl<'de> Deserialize<'de> for Var {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Field names accepted on input; must match the field written by
        // `Serialize for Var`.
        const FIELDS: &[&str] = &["var"];

        // Identifier for the single serialized field.
        enum Field {
            Var,
        }

        impl<'de> Deserialize<'de> for Field {
            fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
            where
                D: Deserializer<'de>,
            {
                struct FieldVisitor;

                impl<'de> Visitor<'de> for FieldVisitor {
                    type Value = Field;

                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                        formatter.write_str("var")
                    }

                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
                    where
                        E: de::Error,
                    {
                        match value {
                            "var" => Ok(Field::Var),
                            _ => Err(de::Error::unknown_field(value, FIELDS)),
                        }
                    }
                }

                deserializer.deserialize_identifier(FieldVisitor)
            }
        }

        struct VarVisitor;

        impl<'de> Visitor<'de> for VarVisitor {
            type Value = Var;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("struct Var")
            }

            fn visit_map<V>(self, mut map: V) -> Result<Var, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut var = None;
                while let Some(key) = map.next_key()? {
                    match key {
                        Field::Var => {
                            if var.is_some() {
                                return Err(de::Error::duplicate_field("var"));
                            }
                            var = Some(map.next_value()?);
                        }
                    }
                }
                // Report the field that is actually expected (was "id").
                let var = var.ok_or_else(|| de::Error::missing_field("var"))?;
                Ok(Var::set_inner(var))
            }

            fn visit_seq<V>(self, mut seq: V) -> Result<Var, V::Error>
            where
                V: SeqAccess<'de>,
            {
                let var = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
                Ok(Var::set_inner(var))
            }
        }

        // The struct name must match the one passed to `serialize_struct` in
        // `Serialize for Var`. Formats that record struct names rely on this;
        // it was previously the copy-pasted "Duration".
        deserializer.deserialize_struct("Var", FIELDS, VarVisitor)
    }
}
88
#[cfg(all(test, feature = "use-serde"))]
mod tests {
    use crate::var::Var;
    use rand::prelude::*;

    /// Builds an affine expression `input.matmul(weights) + bias`, round-trips
    /// it through `serde_pickle`, and checks the restored `Var` equals the
    /// original.
    #[test]
    fn test_serde_var_inner() {
        let mut generator = StdRng::seed_from_u64(671);
        let rows = 10;
        let input = Var::normal(&mut generator, &vec![rows, 2], 0., 2.);

        let weights = Var::new(&vec![2., 3.], &vec![2, 1]);
        let product = input.matmul(&weights).unwrap();
        let bias = Var::new(&vec![1.], &vec![1]);
        let expected = product + bias;

        let bytes = serde_pickle::to_vec(&expected, true).unwrap();
        let restored: Var = serde_pickle::from_slice(&bytes).unwrap();
        println!("{:?}", restored.dump_net());
        assert_eq!(expected, restored);
    }
}