1use serde::{
2 de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor},
3 forward_to_deserialize_any,
4};
5use smol_str::SmolStr;
6
7use crate::{
8 error::{Error, Result},
9 Plist,
10};
11
12enum PathElement {
13 Key(SmolStr),
14 Index(usize),
15}
16
17pub struct Deserializer<'de> {
18 input: &'de Plist,
19 path: Vec<PathElement>,
20}
21
22impl<'de> Deserializer<'de> {
23 pub fn from_plist(input: &'de Plist) -> Self {
24 Deserializer {
25 input,
26 path: Vec::new(),
27 }
28 }
29
30 fn element(&self) -> &'de Plist {
31 let mut element = self.input;
32 for path_element in &self.path {
33 match path_element {
34 PathElement::Key(key) => {
35 element = element.as_dict().unwrap().get(key).unwrap();
36 }
37 PathElement::Index(index) => {
38 element = element.as_array().unwrap().get(*index).unwrap();
39 }
40 }
41 }
42 element
43 }
44}
45
46impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
47 type Error = Error;
48
49 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
53 where
54 V: Visitor<'de>,
55 {
56 match self.element() {
57 Plist::String(_) => self.deserialize_string(visitor),
58 Plist::Integer(_) => self.deserialize_i64(visitor),
59 Plist::Float(_) => self.deserialize_f64(visitor),
60 Plist::Dictionary(_) => self.deserialize_map(visitor),
61 Plist::Array(_) => self.deserialize_seq(visitor),
62 Plist::Data(_) => self.deserialize_byte_buf(visitor),
63 }
64 }
65
66 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
67 where
68 V: Visitor<'de>,
69 {
70 match self.element() {
71 Plist::Integer(i) => visitor.visit_bool(*i != 0),
72 _ => Err(Error::UnexpectedDataType {
73 expected: "integer",
74 found: self.element().name(),
75 }),
76 }
77 }
78
79 fn deserialize_option<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
80 where
81 V: Visitor<'de>,
82 {
83 visitor.visit_some(self)
84 }
85
86 forward_to_deserialize_any! {i8 i16 i32 u8 u16 u32 u64 f32 char str unit unit_struct}
87 forward_to_deserialize_any! {bytes}
88 forward_to_deserialize_any! {tuple tuple_struct struct identifier ignored_any}
89
90 fn deserialize_enum<V>(
91 self,
92 _name: &'static str,
93 _variants: &'static [&'static str],
94 visitor: V,
95 ) -> Result<V::Value>
96 where
97 V: Visitor<'de>,
98 {
99 match self.element() {
100 Plist::String(s) => visitor
101 .visit_enum(de::value::StringDeserializer::new(s.clone()).into_deserializer()),
102 _ => Err(Error::UnexpectedDataType {
103 expected: "string",
104 found: self.element().name(),
105 }),
106 }
107 }
108
109 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
110 where
111 V: Visitor<'de>,
112 {
113 match &self.element() {
114 Plist::Integer(i) => visitor.visit_i64(*i),
115 _ => Err(Error::UnexpectedDataType {
116 expected: "integer",
117 found: self.element().name(),
118 }),
119 }
120 }
121
122 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
123 where
124 V: Visitor<'de>,
125 {
126 match self.element() {
127 Plist::Float(f) => visitor.visit_f64(*f),
128 _ => Err(Error::UnexpectedDataType {
129 expected: "float",
130 found: self.element().name(),
131 }),
132 }
133 }
134
135 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
136 where
137 V: Visitor<'de>,
138 {
139 match &self.element() {
140 Plist::String(s) => visitor.visit_borrowed_str(s),
141 _ => Err(Error::UnexpectedDataType {
142 expected: "string",
143 found: self.element().name(),
144 }),
145 }
146 }
147
148 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
149 where
150 V: Visitor<'de>,
151 {
152 match self.element() {
153 Plist::Data(data) => {
154 visitor.visit_byte_buf(data.clone())
156 }
157 _ => Err(Error::UnexpectedDataType {
158 expected: "data",
159 found: self.element().name(),
160 }),
161 }
162 }
163
164 fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
165 where
166 V: Visitor<'de>,
167 {
168 visitor.visit_newtype_struct(self)
169 }
170
171 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
172 where
173 V: Visitor<'de>,
174 {
175 match self.element() {
176 Plist::Array(_) => visitor.visit_seq(ArrayDeserializer::new(self)),
177 _ => Err(Error::UnexpectedDataType {
178 expected: "array",
179 found: self.element().name(),
180 }),
181 }
182 }
183
184 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
185 where
186 V: Visitor<'de>,
187 {
188 match self.element() {
189 Plist::Dictionary(_) => visitor.visit_map(DictDeserializer::new(self)),
190 _ => Err(Error::UnexpectedDataType {
191 expected: "dictionary",
192 found: self.element().name(),
193 }),
194 }
195 }
196}
197
198struct ArrayDeserializer<'a, 'de: 'a> {
199 de: &'a mut Deserializer<'de>,
200 index: usize,
201 len: usize,
202}
203
204impl<'a, 'de> ArrayDeserializer<'a, 'de> {
205 fn new(de: &'a mut Deserializer<'de>) -> Self {
206 let len = de.element().as_array().unwrap().len();
207 ArrayDeserializer { de, index: 0, len }
208 }
209}
210
211impl<'de> SeqAccess<'de> for ArrayDeserializer<'_, 'de> {
212 type Error = Error;
213
214 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
215 where
216 T: DeserializeSeed<'de>,
217 {
218 if self.index == self.len {
219 return Ok(None);
220 }
221 self.de.path.push(PathElement::Index(self.index));
222 let result = seed.deserialize(&mut *self.de).map(Some);
223 self.de.path.pop();
224 self.index += 1;
225 result
226 }
227}
228
229struct DictDeserializer<'a, 'de: 'a> {
230 de: &'a mut Deserializer<'de>,
231 index: usize,
232 keys: Vec<&'a SmolStr>,
233}
234
235impl<'a, 'de> DictDeserializer<'a, 'de> {
236 fn new(de: &'a mut Deserializer<'de>) -> Self {
237 let keys = de.element().as_dict().unwrap().keys().collect();
238 DictDeserializer { de, index: 0, keys }
239 }
240}
241
242impl<'de> MapAccess<'de> for DictDeserializer<'_, 'de> {
243 type Error = Error;
244
245 fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
246 where
247 T: DeserializeSeed<'de>,
248 {
249 if self.index == self.keys.len() {
250 return Ok(None);
251 }
252 let key = self.keys[self.index].clone();
253 self.de.path.push(PathElement::Key(key.clone()));
254 let key_deserializer = serde::de::value::StringDeserializer::new(key.to_string());
255 seed.deserialize(key_deserializer).map(Some)
256 }
257
258 fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
259 where
260 T: DeserializeSeed<'de>,
261 {
262 let result = seed.deserialize(&mut *self.de);
263 self.de.path.pop();
264 self.index += 1;
265 result
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Deserializes a `T` from `plist`, panicking if deserialization fails.
    fn decode<'de, T: Deserialize<'de>>(plist: &'de Plist) -> T {
        let mut deserializer = Deserializer::from_plist(plist);
        T::deserialize(&mut deserializer).unwrap()
    }

    /// Builds a `Plist::Dictionary` from `(key, value)` pairs.
    fn dict(entries: Vec<(&str, Plist)>) -> Plist {
        Plist::Dictionary(
            entries
                .into_iter()
                .map(|(key, value)| (SmolStr::new(key), value))
                .collect(),
        )
    }

    #[test]
    fn test_basic() {
        let plist = Plist::String("hello".to_string());
        let value: String = decode(&plist);
        assert_eq!(value, "hello");
    }

    #[test]
    fn simple_seq() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo(Vec<i64>);

        let plist = Plist::Array(vec![
            Plist::Integer(1),
            Plist::Integer(2),
            Plist::Integer(3),
        ]);

        let value: Foo = decode(&plist);
        assert_eq!(value.0, vec![1, 2, 3]);
    }

    #[test]
    fn simple_struct() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Foo {
            b: i64,
            a: i64,
        }

        let plist = dict(vec![("a", Plist::Integer(2)), ("b", Plist::Integer(1))]);

        let value: Foo = decode(&plist);
        assert_eq!(value, Foo { a: 2, b: 1 });
    }

    #[test]
    fn nested_struct() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Foo {
            a: i64,
            b: Bar,
            s: String,
        }
        #[derive(Deserialize, PartialEq, Debug)]
        struct Bar {
            c: i64,
        }

        let inner = dict(vec![("c", Plist::Integer(2))]);
        let plist = dict(vec![
            ("s", Plist::String("hello".to_string())),
            ("a", Plist::Integer(1)),
            ("b", inner),
        ]);

        let value: Foo = decode(&plist);
        let expected = Foo {
            a: 1,
            b: Bar { c: 2 },
            s: "hello".to_string(),
        };
        assert_eq!(value, expected);
    }

    #[test]
    fn nested_everything() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Foo {
            a: i64,
            b: Vec<Bar>,
            #[serde(default)]
            s: Option<String>,
        }
        #[derive(Deserialize, PartialEq, Debug)]
        struct Bar {
            c: i64,
            d: Vec<String>,
        }

        let plist_str = r#"
 {
 a = 1;
 b = (
 {
 c = 2;
 d = ("hello", "world");
 },
 {
 c = 3;
 d = ("foo", "bar");
 }
 );
 }
 "#;
        let plist: Plist = Plist::parse(plist_str).unwrap();

        let value: Foo = decode(&plist);
        let expected = Foo {
            a: 1,
            b: vec![
                Bar {
                    c: 2,
                    d: vec!["hello".to_string(), "world".to_string()],
                },
                Bar {
                    c: 3,
                    d: vec!["foo".to_string(), "bar".to_string()],
                },
            ],
            s: None,
        };
        assert_eq!(value, expected);
    }
}