corn/
de.rs

1use std::collections::VecDeque;
2
3use serde::de::{self, DeserializeSeed, EnumAccess, IntoDeserializer, VariantAccess, Visitor};
4
5use crate::error::{Error, Result};
6use crate::parse;
7use crate::Value;
8
9#[derive(Debug)]
10pub struct Deserializer<'de> {
11    value: Option<Value<'de>>,
12}
13
14impl<'de> Deserializer<'de> {
15    pub fn from_str(input: &'de str) -> Result<Self> {
16        let parsed = parse(input)?;
17
18        Ok(Self::from_value(parsed))
19    }
20
21    fn from_value(value: Value<'de>) -> Self {
22        Self { value: Some(value) }
23    }
24}
25
26/// Attempts to deserialize the config from a string slice.
27///
28/// # Errors
29///
30/// Will return a `DeserializationError` if the config is invalid.
31pub fn from_str<T>(s: &str) -> Result<T>
32where
33    T: de::DeserializeOwned,
34{
35    let mut deserializer = Deserializer::from_str(s)?;
36    T::deserialize(&mut deserializer)
37}
38
39/// Attempts to deserialize the config from a byte slice.
40///
41/// # Errors
42///
43/// Will return a `DeserializationError` if the config is invalid.
44pub fn from_slice<T>(bytes: &[u8]) -> Result<T>
45where
46    T: de::DeserializeOwned,
47{
48    match std::str::from_utf8(bytes) {
49        Ok(s) => from_str(s),
50        Err(e) => Err(Error::DeserializationError(e.to_string())),
51    }
52}
53
/// Moves the pending value out of a `Deserializer` (leaving `None` behind).
///
/// If the value has already been taken, this returns early from the enclosing
/// function with a `DeserializationError`.
macro_rules! get_value {
    ($self:ident) => {
        $self.value.take().ok_or_else(|| {
            Error::DeserializationError(String::from("Deserializer value unexpectedly `None`"))
        })?
    };
}
64
/// Builds an `Err(Error::DeserializationError(..))` for a type mismatch.
///
/// `$expected` is a human-readable description of the wanted type; `$got` is
/// rendered with its `Debug` implementation.
macro_rules! err_expected {
    ($expected:literal, $got:expr) => {{
        let message = format!("Expected {}, found '{:?}'", $expected, $got);
        Err(Error::DeserializationError(message))
    }};
}
73
/// Takes the pending value and matches it against the supplied arms.
///
/// Any value that none of the arms accept yields an
/// "Expected `$name`, found ..." error via `err_expected!`.
macro_rules! match_value {
    ($self:ident, $name:literal, $($pattern:pat => $body:expr)+) => {{
        let taken = get_value!($self);
        match taken {
            $($pattern => $body,)+
            _ => err_expected!($name, taken),
        }
    }};
}
83
84impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
85    type Error = Error;
86
87    fn deserialize_any<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
88    where
89        V: Visitor<'de>,
90    {
91        let value = get_value!(self);
92        match value {
93            Value::Object(_) => {
94                let map = Map::new(value);
95                visitor.visit_map(map)
96            }
97            Value::Array(_) => {
98                let seq = Seq::new(value);
99                visitor.visit_seq(seq)
100            }
101            Value::String(val) => visitor.visit_str(&val),
102            Value::Integer(val) => visitor.visit_i64(val),
103            Value::Float(val) => visitor.visit_f64(val),
104            Value::Boolean(val) => visitor.visit_bool(val),
105            Value::Null(_) => visitor.visit_unit(),
106        }
107    }
108
109    fn deserialize_bool<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
110    where
111        V: Visitor<'de>,
112    {
113        match_value!(self, "boolean", Value::Boolean(val) => visitor.visit_bool(val))
114    }
115
116    fn deserialize_i8<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
117    where
118        V: Visitor<'de>,
119    {
120        match_value!(self, "integer (i8)", Value::Integer(val) =>  visitor.visit_i8(val as i8))
121    }
122
123    fn deserialize_i16<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
124    where
125        V: Visitor<'de>,
126    {
127        match_value!(self, "integer (i16)", Value::Integer(val) =>  visitor.visit_i16(val as i16))
128    }
129
130    fn deserialize_i32<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
131    where
132        V: Visitor<'de>,
133    {
134        match_value!(self, "integer (i32)", Value::Integer(val) =>  visitor.visit_i32(val as i32))
135    }
136
137    fn deserialize_i64<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
138    where
139        V: Visitor<'de>,
140    {
141        match_value!(self, "integer (i64)", Value::Integer(val) =>  visitor.visit_i64(val))
142    }
143
144    fn deserialize_u8<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
145    where
146        V: Visitor<'de>,
147    {
148        match_value!(self, "integer (u8)", Value::Integer(val) =>  visitor.visit_u8(val as u8))
149    }
150
151    fn deserialize_u16<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
152    where
153        V: Visitor<'de>,
154    {
155        match_value!(self, "integer (u16)", Value::Integer(val) =>  visitor.visit_u16(val as u16))
156    }
157
158    fn deserialize_u32<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
159    where
160        V: Visitor<'de>,
161    {
162        match_value!(self, "integer (u32)", Value::Integer(val) =>  visitor.visit_u32(val as u32))
163    }
164
165    fn deserialize_u64<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
166    where
167        V: Visitor<'de>,
168    {
169        match_value!(self, "integer (u64)", Value::Integer(val) =>  visitor.visit_u64(val as u64))
170    }
171
172    fn deserialize_f32<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
173    where
174        V: Visitor<'de>,
175    {
176        match_value!(self, "float (f32)", Value::Float(val) =>  visitor.visit_f32(val as f32))
177    }
178
179    fn deserialize_f64<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
180    where
181        V: Visitor<'de>,
182    {
183        match_value!(self, "float (f64)", Value::Float(val) =>  visitor.visit_f64(val))
184    }
185
186    fn deserialize_char<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
187    where
188        V: Visitor<'de>,
189    {
190        let value = get_value!(self);
191        let char = match value {
192            Value::String(value) => value.chars().next(),
193            _ => return err_expected!("char", value),
194        };
195
196        match char {
197            Some(char) => visitor.visit_char(char),
198            None => err_expected!("char", "empty string"),
199        }
200    }
201
202    fn deserialize_str<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
203    where
204        V: Visitor<'de>,
205    {
206        match_value!(self, "string",
207            Value::String(val) => visitor.visit_str(&val)
208        )
209    }
210
211    fn deserialize_string<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
212    where
213        V: Visitor<'de>,
214    {
215        self.deserialize_str(visitor)
216    }
217
218    fn deserialize_bytes<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
219    where
220        V: Visitor<'de>,
221    {
222        match_value!(self, "bytes array",
223            Value::String(val) => visitor.visit_bytes(val.as_bytes())
224        )
225    }
226
227    fn deserialize_byte_buf<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
228    where
229        V: Visitor<'de>,
230    {
231        self.deserialize_bytes(visitor)
232    }
233
234    fn deserialize_option<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
235    where
236        V: Visitor<'de>,
237    {
238        let value = get_value!(self);
239        match value {
240            Value::Null(_) => visitor.visit_none(),
241            _ => visitor.visit_some(&mut Deserializer::from_value(value)),
242        }
243    }
244
245    fn deserialize_unit<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
246    where
247        V: Visitor<'de>,
248    {
249        visitor.visit_unit()
250    }
251
252    fn deserialize_unit_struct<V>(
253        self,
254        _name: &'static str,
255        visitor: V,
256    ) -> std::result::Result<V::Value, Self::Error>
257    where
258        V: Visitor<'de>,
259    {
260        self.deserialize_unit(visitor)
261    }
262
263    fn deserialize_newtype_struct<V>(
264        self,
265        _name: &'static str,
266        visitor: V,
267    ) -> std::result::Result<V::Value, Self::Error>
268    where
269        V: Visitor<'de>,
270    {
271        visitor.visit_newtype_struct(self)
272    }
273
274    fn deserialize_seq<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
275    where
276        V: Visitor<'de>,
277    {
278        let value = get_value!(self);
279        match value {
280            Value::Array(_) => visitor.visit_seq(Seq::new(value)),
281            _ => err_expected!("array", value),
282        }
283    }
284
285    fn deserialize_tuple<V>(
286        self,
287        _len: usize,
288        visitor: V,
289    ) -> std::result::Result<V::Value, Self::Error>
290    where
291        V: Visitor<'de>,
292    {
293        self.deserialize_seq(visitor)
294    }
295
296    fn deserialize_tuple_struct<V>(
297        self,
298        _name: &'static str,
299        _len: usize,
300        visitor: V,
301    ) -> std::result::Result<V::Value, Self::Error>
302    where
303        V: Visitor<'de>,
304    {
305        self.deserialize_seq(visitor)
306    }
307
308    fn deserialize_map<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
309    where
310        V: Visitor<'de>,
311    {
312        let value = get_value!(self);
313        match value {
314            Value::Object(_) => visitor.visit_map(Map::new(value)),
315            _ => err_expected!("object", value),
316        }
317    }
318
319    fn deserialize_struct<V>(
320        self,
321        _name: &'static str,
322        _fields: &'static [&'static str],
323        visitor: V,
324    ) -> std::result::Result<V::Value, Self::Error>
325    where
326        V: Visitor<'de>,
327    {
328        self.deserialize_map(visitor)
329    }
330
331    fn deserialize_enum<V>(
332        self,
333        _name: &'static str,
334        _variants: &'static [&'static str],
335        visitor: V,
336    ) -> std::result::Result<V::Value, Self::Error>
337    where
338        V: Visitor<'de>,
339    {
340        let value = get_value!(self);
341        match value {
342            Value::Object(_) => visitor.visit_enum(Enum::new(value)),
343            Value::String(val) => visitor.visit_enum(val.into_deserializer()),
344            _ => err_expected!("object or string (enum variant)", value),
345        }
346    }
347
348    fn deserialize_identifier<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
349    where
350        V: Visitor<'de>,
351    {
352        self.deserialize_str(visitor)
353    }
354
355    fn deserialize_ignored_any<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
356    where
357        V: Visitor<'de>,
358    {
359        self.deserialize_any(visitor)
360    }
361}
362
363struct Map<'de> {
364    values: VecDeque<Value<'de>>,
365}
366
367impl<'de> Map<'de> {
368    fn new(value: Value<'de>) -> Self {
369        match value {
370            Value::Object(values) => Self {
371                values: values
372                    .into_iter()
373                    .flat_map(|(key, value)| vec![Value::String(key), value])
374                    .collect(),
375            },
376            _ => unreachable!(),
377        }
378    }
379}
380
381impl<'de> de::MapAccess<'de> for Map<'de> {
382    type Error = Error;
383
384    fn next_key_seed<K>(&mut self, seed: K) -> std::result::Result<Option<K::Value>, Self::Error>
385    where
386        K: DeserializeSeed<'de>,
387    {
388        if let Some(value) = self.values.pop_front() {
389            seed.deserialize(&mut Deserializer::from_value(value))
390                .map(Some)
391        } else {
392            Ok(None)
393        }
394    }
395
396    fn next_value_seed<V>(&mut self, seed: V) -> std::result::Result<V::Value, Self::Error>
397    where
398        V: DeserializeSeed<'de>,
399    {
400        match self.values.pop_front() {
401            Some(value) => seed.deserialize(&mut Deserializer::from_value(value)),
402            None => Err(Error::DeserializationError(
403                "Expected value to exist".to_string(),
404            )),
405        }
406    }
407
408    fn size_hint(&self) -> Option<usize> {
409        Some(self.values.len() / 2)
410    }
411}
412
413struct Seq<'de> {
414    values: VecDeque<Value<'de>>,
415}
416
417impl<'de> Seq<'de> {
418    fn new(value: Value<'de>) -> Self {
419        match value {
420            Value::Array(values) => Self {
421                values: VecDeque::from(values),
422            },
423            _ => unreachable!(),
424        }
425    }
426}
427
428impl<'de> de::SeqAccess<'de> for Seq<'de> {
429    type Error = Error;
430
431    fn next_element_seed<T>(
432        &mut self,
433        seed: T,
434    ) -> std::result::Result<Option<T::Value>, Self::Error>
435    where
436        T: DeserializeSeed<'de>,
437    {
438        if let Some(value) = self.values.pop_front() {
439            seed.deserialize(&mut Deserializer::from_value(value))
440                .map(Some)
441        } else {
442            Ok(None)
443        }
444    }
445
446    fn size_hint(&self) -> Option<usize> {
447        Some(self.values.len())
448    }
449}
450
451struct Enum<'de> {
452    value: Value<'de>,
453}
454
455impl<'de> Enum<'de> {
456    fn new(value: Value<'de>) -> Self {
457        Self { value }
458    }
459}
460
461impl<'de> EnumAccess<'de> for Enum<'de> {
462    type Error = Error;
463    type Variant = Variant<'de>;
464
465    fn variant_seed<V>(self, seed: V) -> std::result::Result<(V::Value, Self::Variant), Self::Error>
466    where
467        V: DeserializeSeed<'de>,
468    {
469        match self.value {
470            Value::String(_) => {
471                let value = seed.deserialize(&mut Deserializer::from_value(self.value))?;
472                Ok((value, Variant::new(None)))
473            }
474            Value::Object(obj) => {
475                let first_pair = obj.into_iter().next();
476                if let Some(first_pair) = first_pair {
477                    let value = Value::String(first_pair.0);
478                    let tag = seed.deserialize(&mut Deserializer::from_value(value))?;
479                    Ok((tag, Variant::new(Some(first_pair.1))))
480                } else {
481                    Err(Error::DeserializationError(
482                        "Cannot deserialize empty object into enum".to_string(),
483                    ))
484                }
485            }
486            _ => unreachable!(),
487        }
488    }
489}
490
491struct Variant<'de> {
492    value: Option<Value<'de>>,
493}
494
495impl<'de> Variant<'de> {
496    fn new(value: Option<Value<'de>>) -> Self {
497        Self { value }
498    }
499}
500
501impl<'de> VariantAccess<'de> for Variant<'de> {
502    type Error = Error;
503
504    fn unit_variant(self) -> std::result::Result<(), Self::Error> {
505        Ok(())
506    }
507
508    fn newtype_variant_seed<T>(self, seed: T) -> std::result::Result<T::Value, Self::Error>
509    where
510        T: DeserializeSeed<'de>,
511    {
512        match self.value {
513            Some(value) => seed.deserialize(&mut Deserializer::from_value(value)),
514            None => Err(Error::DeserializationError(
515                "Expected value to exist".to_string(),
516            )),
517        }
518    }
519
520    fn tuple_variant<V>(self, _len: usize, visitor: V) -> std::result::Result<V::Value, Self::Error>
521    where
522        V: Visitor<'de>,
523    {
524        match self.value {
525            Some(value) if matches!(value, Value::Array(_)) => visitor.visit_seq(Seq::new(value)),
526            _ => unreachable!(),
527        }
528    }
529
530    fn struct_variant<V>(
531        self,
532        _fields: &'static [&'static str],
533        visitor: V,
534    ) -> std::result::Result<V::Value, Self::Error>
535    where
536        V: Visitor<'de>,
537    {
538        match self.value {
539            Some(value) if matches!(value, Value::Object(_)) => visitor.visit_map(Map::new(value)),
540            _ => unreachable!(),
541        }
542    }
543}