openstep_plist/
de.rs

1use serde::{
2    de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor},
3    forward_to_deserialize_any,
4};
5use smol_str::SmolStr;
6
7use crate::{
8    error::{Error, Result},
9    Plist,
10};
11
12enum PathElement {
13    Key(SmolStr),
14    Index(usize),
15}
16
17pub struct Deserializer<'de> {
18    input: &'de Plist,
19    path: Vec<PathElement>,
20}
21
22impl<'de> Deserializer<'de> {
23    pub fn from_plist(input: &'de Plist) -> Self {
24        Deserializer {
25            input,
26            path: Vec::new(),
27        }
28    }
29
30    fn element(&self) -> &'de Plist {
31        let mut element = self.input;
32        for path_element in &self.path {
33            match path_element {
34                PathElement::Key(key) => {
35                    element = element.as_dict().unwrap().get(key).unwrap();
36                }
37                PathElement::Index(index) => {
38                    element = element.as_array().unwrap().get(*index).unwrap();
39                }
40            }
41        }
42        element
43    }
44}
45
46impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
47    type Error = Error;
48
49    // Look at the input data to decide what Serde data model type to
50    // deserialize as. Not all data formats are able to support this operation.
51    // Formats that support `deserialize_any` are known as self-describing.
52    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
53    where
54        V: Visitor<'de>,
55    {
56        match self.element() {
57            Plist::String(_) => self.deserialize_string(visitor),
58            Plist::Integer(_) => self.deserialize_i64(visitor),
59            Plist::Float(_) => self.deserialize_f64(visitor),
60            Plist::Dictionary(_) => self.deserialize_map(visitor),
61            Plist::Array(_) => self.deserialize_seq(visitor),
62            Plist::Data(_) => self.deserialize_byte_buf(visitor),
63        }
64    }
65
66    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
67    where
68        V: Visitor<'de>,
69    {
70        match self.element() {
71            Plist::Integer(i) => visitor.visit_bool(*i != 0),
72            _ => Err(Error::UnexpectedDataType {
73                expected: "integer",
74                found: self.element().name(),
75            }),
76        }
77    }
78
79    fn deserialize_option<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
80    where
81        V: Visitor<'de>,
82    {
83        visitor.visit_some(self)
84    }
85
86    forward_to_deserialize_any! {i8 i16 i32 u8 u16 u32 u64 f32 char str unit unit_struct}
87    forward_to_deserialize_any! {bytes}
88    forward_to_deserialize_any! {tuple tuple_struct struct identifier ignored_any}
89
90    fn deserialize_enum<V>(
91        self,
92        _name: &'static str,
93        _variants: &'static [&'static str],
94        visitor: V,
95    ) -> Result<V::Value>
96    where
97        V: Visitor<'de>,
98    {
99        match self.element() {
100            Plist::String(s) => visitor
101                .visit_enum(de::value::StringDeserializer::new(s.clone()).into_deserializer()),
102            _ => Err(Error::UnexpectedDataType {
103                expected: "string",
104                found: self.element().name(),
105            }),
106        }
107    }
108
109    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
110    where
111        V: Visitor<'de>,
112    {
113        match &self.element() {
114            Plist::Integer(i) => visitor.visit_i64(*i),
115            _ => Err(Error::UnexpectedDataType {
116                expected: "integer",
117                found: self.element().name(),
118            }),
119        }
120    }
121
122    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
123    where
124        V: Visitor<'de>,
125    {
126        match self.element() {
127            Plist::Float(f) => visitor.visit_f64(*f),
128            _ => Err(Error::UnexpectedDataType {
129                expected: "float",
130                found: self.element().name(),
131            }),
132        }
133    }
134
135    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
136    where
137        V: Visitor<'de>,
138    {
139        match &self.element() {
140            Plist::String(s) => visitor.visit_borrowed_str(s),
141            _ => Err(Error::UnexpectedDataType {
142                expected: "string",
143                found: self.element().name(),
144            }),
145        }
146    }
147
148    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
149    where
150        V: Visitor<'de>,
151    {
152        match self.element() {
153            Plist::Data(data) => {
154                // Convert the data to a byte buffer
155                visitor.visit_byte_buf(data.clone())
156            }
157            _ => Err(Error::UnexpectedDataType {
158                expected: "data",
159                found: self.element().name(),
160            }),
161        }
162    }
163
164    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
165    where
166        V: Visitor<'de>,
167    {
168        visitor.visit_newtype_struct(self)
169    }
170
171    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
172    where
173        V: Visitor<'de>,
174    {
175        match self.element() {
176            Plist::Array(_) => visitor.visit_seq(ArrayDeserializer::new(self)),
177            _ => Err(Error::UnexpectedDataType {
178                expected: "array",
179                found: self.element().name(),
180            }),
181        }
182    }
183
184    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
185    where
186        V: Visitor<'de>,
187    {
188        match self.element() {
189            Plist::Dictionary(_) => visitor.visit_map(DictDeserializer::new(self)),
190            _ => Err(Error::UnexpectedDataType {
191                expected: "dictionary",
192                found: self.element().name(),
193            }),
194        }
195    }
196}
197
198struct ArrayDeserializer<'a, 'de: 'a> {
199    de: &'a mut Deserializer<'de>,
200    index: usize,
201    len: usize,
202}
203
204impl<'a, 'de> ArrayDeserializer<'a, 'de> {
205    fn new(de: &'a mut Deserializer<'de>) -> Self {
206        let len = de.element().as_array().unwrap().len();
207        ArrayDeserializer { de, index: 0, len }
208    }
209}
210
211impl<'de> SeqAccess<'de> for ArrayDeserializer<'_, 'de> {
212    type Error = Error;
213
214    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
215    where
216        T: DeserializeSeed<'de>,
217    {
218        if self.index == self.len {
219            return Ok(None);
220        }
221        self.de.path.push(PathElement::Index(self.index));
222        let result = seed.deserialize(&mut *self.de).map(Some);
223        self.de.path.pop();
224        self.index += 1;
225        result
226    }
227}
228
229struct DictDeserializer<'a, 'de: 'a> {
230    de: &'a mut Deserializer<'de>,
231    index: usize,
232    keys: Vec<&'a SmolStr>,
233}
234
235impl<'a, 'de> DictDeserializer<'a, 'de> {
236    fn new(de: &'a mut Deserializer<'de>) -> Self {
237        let keys = de.element().as_dict().unwrap().keys().collect();
238        DictDeserializer { de, index: 0, keys }
239    }
240}
241
242impl<'de> MapAccess<'de> for DictDeserializer<'_, 'de> {
243    type Error = Error;
244
245    fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
246    where
247        T: DeserializeSeed<'de>,
248    {
249        if self.index == self.keys.len() {
250            return Ok(None);
251        }
252        let key = self.keys[self.index].clone();
253        self.de.path.push(PathElement::Key(key.clone()));
254        let key_deserializer = serde::de::value::StringDeserializer::new(key.to_string());
255        seed.deserialize(key_deserializer).map(Some)
256    }
257
258    fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
259    where
260        T: DeserializeSeed<'de>,
261    {
262        let result = seed.deserialize(&mut *self.de);
263        self.de.path.pop();
264        self.index += 1;
265        result
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Deserializes `plist` into `T`, panicking on any error.
    fn decode<'de, T: Deserialize<'de>>(plist: &'de Plist) -> T {
        let mut plist_de = Deserializer::from_plist(plist);
        T::deserialize(&mut plist_de).unwrap()
    }

    #[test]
    fn test_basic() {
        let plist = Plist::String("hello".to_string());
        let greeting: String = decode(&plist);
        assert_eq!(greeting, "hello");
    }

    #[test]
    fn simple_seq() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo(Vec<i64>);

        let plist = Plist::Array((1..=3).map(Plist::Integer).collect());
        let Foo(numbers): Foo = decode(&plist);
        assert_eq!(numbers, vec![1, 2, 3]);
    }

    #[test]
    fn simple_struct() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Foo {
            b: i64,
            a: i64,
        }

        // Dictionary order ("a" first) deliberately differs from field order.
        let plist = Plist::Dictionary(
            [("a", 2), ("b", 1)]
                .iter()
                .map(|&(key, n)| (SmolStr::new(key), Plist::Integer(n)))
                .collect(),
        );
        let decoded: Foo = decode(&plist);
        assert_eq!(decoded, Foo { a: 2, b: 1 });
    }

    #[test]
    fn nested_struct() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Foo {
            a: i64,
            b: Bar,
            s: String,
        }
        #[derive(Deserialize, PartialEq, Debug)]
        struct Bar {
            c: i64,
        }

        let inner = Plist::Dictionary(
            vec![(SmolStr::new("c"), Plist::Integer(2))]
                .into_iter()
                .collect(),
        );
        let plist = Plist::Dictionary(
            vec![
                (SmolStr::new("s"), Plist::String("hello".to_string())),
                (SmolStr::new("a"), Plist::Integer(1)),
                (SmolStr::new("b"), inner),
            ]
            .into_iter()
            .collect(),
        );
        let decoded: Foo = decode(&plist);
        let expected = Foo {
            a: 1,
            b: Bar { c: 2 },
            s: "hello".to_string(),
        };
        assert_eq!(decoded, expected);
    }

    #[test]
    fn nested_everything() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Foo {
            a: i64,
            b: Vec<Bar>,
            #[serde(default)]
            s: Option<String>,
        }
        #[derive(Deserialize, PartialEq, Debug)]
        struct Bar {
            c: i64,
            d: Vec<String>,
        }

        let source = r#"
        {
            a = 1;
            b = (
                {
                    c = 2;
                    d = ("hello", "world");
                },
                {
                    c = 3;
                    d = ("foo", "bar");
                }
            );
        }
        "#;
        let plist: Plist = Plist::parse(source).unwrap();
        let decoded: Foo = decode(&plist);

        let words = |pair: [&str; 2]| pair.iter().map(|w| w.to_string()).collect::<Vec<_>>();
        let expected = Foo {
            a: 1,
            b: vec![
                Bar {
                    c: 2,
                    d: words(["hello", "world"]),
                },
                Bar {
                    c: 3,
                    d: words(["foo", "bar"]),
                },
            ],
            s: None,
        };
        assert_eq!(decoded, expected);
    }
}