Skip to main content

libconfig_rs/serde/
deserialize.rs

1use super::error::Error;
2use crate::Value;
3use serde::{
4    Deserialize,
5    de::{self, DeserializeSeed, EnumAccess, MapAccess, SeqAccess, VariantAccess, Visitor},
6};
7use std::{collections::VecDeque, marker::PhantomData, str::FromStr};
8
// The structure of this deserializer follows the serde data-format guide and example crate:
// https://serde.rs/data-format.html
// https://github.com/serde-rs/example-format
12
/// One element of the flat token stream that `flatten` builds from a parsed `Value`.
///
/// Scalars map to a single token. A container is announced by a count token
/// (`SeqCount` for `Value::Array`, `MapCount` for `Value::Object`) and is followed by
/// the tokens of its children, so the deserializer can consume the stream front to back.
#[derive(Clone, Debug)]
enum Token {
    /// Boolean scalar.
    Bool(bool),
    /// Integer scalar.
    Int(i64),
    /// Floating-point scalar.
    Float(f64),
    /// String scalar; also used for the setting names inside a `MapCount` container.
    String(String),
    /// Number of elements in the sequence that follows.
    SeqCount(usize),
    /// Number of key/value entries in the map that follows.
    MapCount(usize),
}

impl Token {
    /// Unwraps a `Bool`; any other token is handed back unchanged in `Err`.
    fn into_bool(self) -> Result<bool, Token> {
        if let Token::Bool(flag) = self {
            Ok(flag)
        } else {
            Err(self)
        }
    }

    /// Unwraps an `Int`; any other token is handed back unchanged in `Err`.
    fn into_int(self) -> Result<i64, Token> {
        if let Token::Int(number) = self {
            Ok(number)
        } else {
            Err(self)
        }
    }

    /// Unwraps a `Float`; any other token is handed back unchanged in `Err`.
    fn into_float(self) -> Result<f64, Token> {
        if let Token::Float(number) = self {
            Ok(number)
        } else {
            Err(self)
        }
    }

    /// Unwraps a `String`; any other token is handed back unchanged in `Err`.
    fn into_string(self) -> Result<String, Token> {
        if let Token::String(text) = self {
            Ok(text)
        } else {
            Err(self)
        }
    }

    /// Unwraps either count token (`SeqCount` or `MapCount`) without distinguishing
    /// between them; any other token is handed back unchanged in `Err`.
    fn into_count(self) -> Result<usize, Token> {
        if let Token::SeqCount(len) | Token::MapCount(len) = self {
            Ok(len)
        } else {
            Err(self)
        }
    }
}
59
60fn flatten(res: &mut VecDeque<Token>, value: Value) {
61    match value {
62        Value::Bool(b) => {
63            res.push_back(Token::Bool(b));
64        }
65        Value::Int(i) => {
66            res.push_back(Token::Int(i));
67        }
68        Value::Float(f) => {
69            res.push_back(Token::Float(f));
70        }
71        Value::String(s) => {
72            res.push_back(Token::String(s));
73        }
74        Value::Array(a, _) => {
75            res.push_back(Token::SeqCount(a.len()));
76            for v in a {
77                flatten(res, v)
78            }
79        }
80        Value::Object(o) => {
81            res.push_back(Token::MapCount(o.len()));
82            for (k, v) in o {
83                res.push_back(Token::String(k));
84                flatten(res, v)
85            }
86        }
87    }
88}
89
90pub struct Deserializer<'de> {
91    tokens: VecDeque<Token>,
92    phantom: PhantomData<&'de str>,
93}
94
95pub fn from_str<'a, T>(s: &'a str) -> Result<T, Error>
96where
97    T: Deserialize<'a>,
98{
99    let value = crate::Value::from_str(s).map_err(|e| Error::Message(format!("{e:?}")))?;
100
101    let mut tokens = VecDeque::new();
102
103    flatten(&mut tokens, value);
104
105    let mut deserializer = Deserializer::<'a> {
106        tokens,
107        phantom: PhantomData,
108    };
109
110    T::deserialize(&mut deserializer)
111}
112
113impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
114    type Error = Error;
115
116    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
117    where
118        V: Visitor<'de>,
119    {
120        let token = self
121            .tokens
122            .front()
123            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
124
125        match token {
126            Token::Bool(_) => self.deserialize_bool(visitor),
127            Token::Int(_) => self.deserialize_i64(visitor),
128            Token::Float(_) => self.deserialize_f64(visitor),
129            Token::String(_) => self.deserialize_string(visitor),
130            Token::SeqCount(_) => self.deserialize_seq(visitor),
131            Token::MapCount(_) => self.deserialize_map(visitor),
132        }
133    }
134
135    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
136    where
137        V: Visitor<'de>,
138    {
139        let token = self
140            .tokens
141            .pop_front()
142            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
143
144        visitor.visit_bool(
145            token
146                .into_bool()
147                .map_err(|t| Error::Message(format!("{t:?} is not a bool")))?,
148        )
149    }
150
151    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
152    where
153        V: Visitor<'de>,
154    {
155        let token = self
156            .tokens
157            .pop_front()
158            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
159
160        visitor.visit_i8(
161            token
162                .into_int()
163                .map_err(|t| Error::Message(format!("{t:?} is not a integer")))? as i8,
164        )
165    }
166
167    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
168    where
169        V: Visitor<'de>,
170    {
171        let token = self
172            .tokens
173            .pop_front()
174            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
175
176        visitor.visit_i16(
177            token
178                .into_int()
179                .map_err(|t| Error::Message(format!("{t:?} is not a integer")))? as i16,
180        )
181    }
182
183    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
184    where
185        V: Visitor<'de>,
186    {
187        let token = self
188            .tokens
189            .pop_front()
190            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
191
192        visitor.visit_i32(
193            token
194                .into_int()
195                .map_err(|t| Error::Message(format!("{t:?} is not a integer")))? as i32,
196        )
197    }
198
199    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
200    where
201        V: Visitor<'de>,
202    {
203        let token = self
204            .tokens
205            .pop_front()
206            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
207
208        visitor.visit_i64(
209            token
210                .into_int()
211                .map_err(|t| Error::Message(format!("{t:?} is not a integer")))?,
212        )
213    }
214
215    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
216    where
217        V: Visitor<'de>,
218    {
219        let token = self
220            .tokens
221            .pop_front()
222            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
223
224        visitor.visit_u8(
225            token
226                .into_int()
227                .map_err(|t| Error::Message(format!("{t:?} is not a integer")))? as u8,
228        )
229    }
230
231    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
232    where
233        V: Visitor<'de>,
234    {
235        let token = self
236            .tokens
237            .pop_front()
238            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
239
240        visitor.visit_u16(
241            token
242                .into_int()
243                .map_err(|t| Error::Message(format!("{t:?} is not a integer")))? as u16,
244        )
245    }
246
247    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
248    where
249        V: Visitor<'de>,
250    {
251        let token = self
252            .tokens
253            .pop_front()
254            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
255
256        visitor.visit_u32(
257            token
258                .into_int()
259                .map_err(|t| Error::Message(format!("{t:?} is not a integer")))? as u32,
260        )
261    }
262
263    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
264    where
265        V: Visitor<'de>,
266    {
267        let token = self
268            .tokens
269            .pop_front()
270            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
271
272        visitor.visit_u64(
273            token
274                .into_int()
275                .map_err(|t| Error::Message(format!("{t:?} is not a integer")))? as u64,
276        )
277    }
278
279    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
280    where
281        V: Visitor<'de>,
282    {
283        let token = self
284            .tokens
285            .pop_front()
286            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
287
288        visitor.visit_f32(
289            token
290                .into_float()
291                .map_err(|t| Error::Message(format!("{t:?} is not a float")))? as f32,
292        )
293    }
294
295    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
296    where
297        V: Visitor<'de>,
298    {
299        let token = self
300            .tokens
301            .pop_front()
302            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
303
304        visitor.visit_f64(
305            token
306                .into_float()
307                .map_err(|t| Error::Message(format!("{t:?} is not a float")))?,
308        )
309    }
310
311    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
312    where
313        V: Visitor<'de>,
314    {
315        let token = self
316            .tokens
317            .pop_front()
318            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
319
320        visitor.visit_char(
321            token
322                .into_string()
323                .map_err(|t| Error::Message(format!("{t:?} is not a char")))?
324                .chars()
325                .next()
326                .ok_or_else(|| Error::Message("String is empty".into()))?,
327        )
328    }
329
330    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
331    where
332        V: Visitor<'de>,
333    {
334        let token = self
335            .tokens
336            .pop_front()
337            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
338
339        visitor.visit_str(
340            token
341                .into_string()
342                .map_err(|t| Error::Message(format!("{t:?} is not a str")))?
343                .as_str(),
344        )
345    }
346
347    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
348    where
349        V: Visitor<'de>,
350    {
351        let token = self
352            .tokens
353            .pop_front()
354            .ok_or_else(|| Error::Message("Reached end of input!".into()))?;
355
356        visitor.visit_string(
357            token
358                .into_string()
359                .map_err(|t| Error::Message(format!("{t:?} is not a str")))?,
360        )
361    }
362
363    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
364    where
365        V: Visitor<'de>,
366    {
367        unimplemented!("")
368    }
369
370    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
371    where
372        V: Visitor<'de>,
373    {
374        unimplemented!("")
375    }
376
377    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
378    where
379        V: Visitor<'de>,
380    {
381        let len = self
382            .tokens
383            .pop_front()
384            .ok_or_else(|| Error::Message("Reached end of input!".into()))?
385            .into_count()
386            .map_err(|t| Error::Message(format!("{t:?} is not a count")))?;
387
388        if len == 0 {
389            visitor.visit_none()
390        } else {
391            visitor.visit_some(self)
392        }
393    }
394
395    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
396    where
397        V: Visitor<'de>,
398    {
399        let len = self
400            .tokens
401            .pop_front()
402            .ok_or_else(|| Error::Message("Reached end of input!".into()))?
403            .into_count()
404            .map_err(|t| Error::Message(format!("{t:?} is not a count")))?;
405
406        if len == 0 {
407            visitor.visit_unit()
408        } else {
409            Err(Error::Message("Expected empty list".into()))
410        }
411    }
412
413    fn deserialize_unit_struct<V>(
414        self,
415        _name: &'static str,
416        visitor: V,
417    ) -> Result<V::Value, Self::Error>
418    where
419        V: Visitor<'de>,
420    {
421        self.deserialize_unit(visitor)
422    }
423
424    fn deserialize_newtype_struct<V>(
425        self,
426        _name: &'static str,
427        visitor: V,
428    ) -> Result<V::Value, Self::Error>
429    where
430        V: Visitor<'de>,
431    {
432        let len = self
433            .tokens
434            .pop_front()
435            .ok_or_else(|| Error::Message("Reached end of input!".into()))?
436            .into_count()
437            .map_err(|t| Error::Message(format!("{t:?} is not a count")))?;
438
439        if len != 1 {
440            return Err(Error::Message(format!(
441                "Expected 1 field in struct got {len}"
442            )));
443        }
444
445        visitor.visit_newtype_struct(self)
446    }
447
448    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
449    where
450        V: Visitor<'de>,
451    {
452        let count = self
453            .tokens
454            .pop_front()
455            .ok_or_else(|| Error::Message("Reached end of input!".into()))?
456            .into_count()
457            .map_err(|t| Error::Message(format!("Expected field count, got {t:?}")))?;
458
459        visitor.visit_seq(SeqAccessor {
460            de: self,
461            remaining: count,
462        })
463    }
464
465    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
466    where
467        V: Visitor<'de>,
468    {
469        self.deserialize_seq(visitor)
470    }
471
472    fn deserialize_tuple_struct<V>(
473        self,
474        _name: &'static str,
475        _len: usize,
476        visitor: V,
477    ) -> Result<V::Value, Self::Error>
478    where
479        V: Visitor<'de>,
480    {
481        self.deserialize_seq(visitor)
482    }
483
484    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
485    where
486        V: Visitor<'de>,
487    {
488        let count = self
489            .tokens
490            .pop_front()
491            .ok_or_else(|| Error::Message("Reached end of input!".into()))?
492            .into_count()
493            .map_err(|t| Error::Message(format!("Expected field count, got {t:?}")))?;
494
495        visitor.visit_map(MapAccessor {
496            de: self,
497            remaining: count,
498        })
499    }
500
501    fn deserialize_struct<V>(
502        self,
503        _name: &'static str,
504        _fields: &'static [&'static str],
505        visitor: V,
506    ) -> Result<V::Value, Self::Error>
507    where
508        V: Visitor<'de>,
509    {
510        let count = self
511            .tokens
512            .pop_front()
513            .ok_or_else(|| Error::Message("Reached end of input!".into()))?
514            .into_count()
515            .map_err(|t| Error::Message(format!("Expected field count, got {t:?}")))?;
516
517        visitor.visit_map(StructAccessor {
518            de: self,
519            remaining: count,
520        })
521    }
522
523    fn deserialize_enum<V>(
524        self,
525        _name: &'static str,
526        _variants: &'static [&'static str],
527        visitor: V,
528    ) -> Result<V::Value, Self::Error>
529    where
530        V: Visitor<'de>,
531    {
532        if let Some(Token::SeqCount(_) | Token::MapCount(_)) = self.tokens.front() {
533            self.tokens.pop_front();
534        };
535
536        visitor.visit_enum(Enum::new(self))
537    }
538
539    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
540    where
541        V: Visitor<'de>,
542    {
543        let token = self
544            .tokens
545            .pop_front()
546            .ok_or_else(|| Error::Message("Reached end of input!".into()))?
547            .into_string()
548            .map_err(|t| Error::Message(format!("{t:?} is not an identifier")))?;
549
550        visitor.visit_str(token.as_str())
551    }
552
553    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
554    where
555        V: Visitor<'de>,
556    {
557        self.deserialize_any(visitor)
558    }
559}
560
561struct SeqAccessor<'a, 'de: 'a> {
562    de: &'a mut Deserializer<'de>,
563    remaining: usize,
564}
565
566impl<'de, 'a> SeqAccess<'de> for SeqAccessor<'a, 'de> {
567    type Error = Error;
568
569    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
570    where
571        T: DeserializeSeed<'de>,
572    {
573        if self.remaining > 0 {
574            self.remaining -= 1;
575            seed.deserialize(&mut *self.de).map(Some)
576        } else {
577            Ok(None)
578        }
579    }
580}
581
582struct StructAccessor<'a, 'de: 'a> {
583    de: &'a mut Deserializer<'de>,
584    remaining: usize,
585}
586
587impl<'de, 'a> MapAccess<'de> for StructAccessor<'a, 'de> {
588    type Error = Error;
589
590    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
591    where
592        K: DeserializeSeed<'de>,
593    {
594        if self.remaining > 0 {
595            self.remaining -= 1;
596            seed.deserialize(&mut *self.de).map(Some)
597        } else {
598            Ok(None)
599        }
600    }
601
602    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
603    where
604        V: DeserializeSeed<'de>,
605    {
606        seed.deserialize(&mut *self.de)
607    }
608}
609
610struct MapAccessor<'a, 'de: 'a> {
611    de: &'a mut Deserializer<'de>,
612    remaining: usize,
613}
614
615impl<'de, 'a> MapAccess<'de> for MapAccessor<'a, 'de> {
616    type Error = Error;
617
618    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
619    where
620        K: DeserializeSeed<'de>,
621    {
622        if self.remaining > 0 {
623            self.remaining -= 1;
624            // Consume inner tuple SeqCount when map entries are encoded as tuples
625            if let Some(Token::SeqCount(_)) = self.de.tokens.front() {
626                self.de.tokens.pop_front();
627            }
628            seed.deserialize(&mut *self.de).map(Some)
629        } else {
630            Ok(None)
631        }
632    }
633
634    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
635    where
636        V: DeserializeSeed<'de>,
637    {
638        seed.deserialize(&mut *self.de)
639    }
640}
641
642struct Enum<'a, 'de: 'a> {
643    de: &'a mut Deserializer<'de>,
644}
645
646impl<'a, 'de> Enum<'a, 'de> {
647    fn new(de: &'a mut Deserializer<'de>) -> Self {
648        Enum { de }
649    }
650}
651
652impl<'de, 'a> EnumAccess<'de> for Enum<'a, 'de> {
653    type Error = Error;
654    type Variant = Self;
655
656    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
657    where
658        V: DeserializeSeed<'de>,
659    {
660        Ok((seed.deserialize(&mut *self.de)?, self))
661    }
662}
663
664impl<'de, 'a> VariantAccess<'de> for Enum<'a, 'de> {
665    type Error = Error;
666
667    fn unit_variant(self) -> Result<(), Self::Error> {
668        Ok(())
669    }
670
671    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
672    where
673        T: DeserializeSeed<'de>,
674    {
675        self.de.tokens.pop_front();
676        seed.deserialize(self.de)
677    }
678
679    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
680    where
681        V: Visitor<'de>,
682    {
683        de::Deserializer::deserialize_seq(self.de, visitor)
684    }
685
686    fn struct_variant<V>(
687        self,
688        _fields: &'static [&'static str],
689        visitor: V,
690    ) -> Result<V::Value, Self::Error>
691    where
692        V: Visitor<'de>,
693    {
694        de::Deserializer::deserialize_struct(self.de, "", &[], visitor)
695    }
696}