auto_diff/serde/
op.rs

1#[cfg(feature = "use-serde")]
2use serde::{Serialize, Deserialize, Serializer, Deserializer,
3	    ser::SerializeStruct,
4	    de, de::Visitor, de::SeqAccess, de::MapAccess};
5use crate::op::{Op, OpTrait};
6use std::fmt;
7use std::ops::Deref;
8
9
10
11impl Serialize for Box<dyn OpTrait> {
12    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
13    where S: Serializer, {
14        // 3 is the number of fields in the struct.
15        //let mut state = serializer.serialize_struct("OpTrait", 1)?;
16        //state.serialize_field("op_name", &self.get_name())?;
17        //state.end()
18        crate::op::serialize_box::<S>(&self, serializer)
19    }
20}
21
22impl Serialize for Op {
23    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
24    where S: Serializer, {
25        // 3 is the number of fields in the struct.
26        let mut state = serializer.serialize_struct("Op", 2)?;
27        state.serialize_field("op_name", &self.get_name())?;
28	state.serialize_field("op_obj", &self.inner().borrow().deref())?;
29        state.end()
30    }
31}
32
33impl<'de> Deserialize<'de> for Op {
34    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
35    where D: Deserializer<'de>, {
36
37	enum Field { OpName, OpObj }
38	
39        impl<'de> Deserialize<'de> for Field {
40            fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
41            where D: Deserializer<'de>, {
42                struct FieldVisitor;
43
44                impl<'de> Visitor<'de> for FieldVisitor {
45                    type Value = Field;
46
47                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
48                        formatter.write_str("op_name or op_obj")
49                    }
50
51                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
52                    where E: de::Error, {
53                        match value {
54                            "op_name" => Ok(Field::OpName),
55			    "op_obj" => Ok(Field::OpObj),
56                            _ => Err(de::Error::unknown_field(value, &FIELDS)),
57                        }
58                    }
59                }
60
61                deserializer.deserialize_identifier(FieldVisitor)
62            }
63        }
64	
65        struct OpVisitor;
66
67        impl<'de> Visitor<'de> for OpVisitor {
68            type Value = Op;
69
70            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
71                formatter.write_str("struct Op")
72            }
73
74	    fn visit_map<V>(self, mut map: V) -> Result<Op, V::Error>
75            where V: MapAccess<'de>, {
76		let mut op_name = None;
77                while let Some(key) = map.next_key()? {
78                    match key {
79                        Field::OpName => {
80                            if op_name.is_some() {
81                                return Err(de::Error::duplicate_field("op_name"));
82                            }
83                            op_name = Some(map.next_value()?);
84                        },
85			Field::OpObj => {
86                            //if op_obj.is_some() {
87                            //    return Err(de::Error::duplicate_field("op_obj"));
88                            //}
89                            //op_obj = Some(map.next_value()?);
90			    let op_name: String = op_name.ok_or_else(|| de::Error::missing_field("op_name"))?;
91
92                            return crate::op::deserialize_map(op_name, map);
93                        }
94                    }
95                }
96		Err(de::Error::missing_field("op_obj"))
97            }
98
99            fn visit_seq<V>(self, mut seq: V) -> Result<Op, V::Error>
100            where V: SeqAccess<'de>, {
101                let op_name: String = seq.next_element()?
102                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
103		return crate::op::deserialize_seq(op_name, seq);
104            }
105        }
106
107        const FIELDS: [&str; 2] = ["op_name", "op_obj"];
108        deserializer.deserialize_struct("Op", &FIELDS, OpVisitor)
109    }
110}
111
112
#[cfg(all(test, feature = "use-serde"))]
mod tests {
    use crate::op::linear::Linear;
    use super::*;
    use std::rc::Rc;
    use std::cell::RefCell;

    /// Round-trips a `Linear` operator wrapped in an `Op` through
    /// serde_pickle and checks that the operator name survives.
    #[test]
    fn test_serde_op() {
        let m1 = Linear::new(None, None, true);
        let m1 = Op::new(Rc::new(RefCell::new(Box::new(m1))));

        let serialized = serde_pickle::to_vec(&m1, true).unwrap();
        let deserialized: Op = serde_pickle::from_slice(&serialized).unwrap();

        // Comparing whole `Op` values would need `PartialEq`/`Debug` on `Op`;
        // the name written under "op_name" identifies the restored operator.
        assert_eq!(deserialized.get_name(), m1.get_name());
    }
}
130}