// config/de/impl.rs

use super::Error;
use crate::{Configuration, Section};
use serde::{
    de::{
        self,
        value::{MapDeserializer, SeqDeserializer},
        IntoDeserializer, Visitor,
    },
    Deserialize,
};
use std::{fmt::Display, rc::Rc, str::FromStr, vec::IntoIter};

13impl de::Error for Error {
14    #[inline]
15    fn custom<T: Display>(message: T) -> Self {
16        Self::Custom(message.to_string())
17    }
18
19    #[inline]
20    fn missing_field(field: &'static str) -> Self {
21        Self::MissingValue(field)
22    }
23}
24
// Expands to `Deserializer` methods that parse the section's string value into the requested primitive
// and hand the parsed value to serde's matching primitive deserializer. A parse failure is reported with
// both the offending value and the key that provided it.
macro_rules! forward_parsed_values {
    ($($ty:ident => $method:ident,)*) => {
        $(
            fn $method<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
                let raw = self.0.value();
                let parsed = match raw.parse::<$ty>() {
                    Ok(parsed) => parsed,
                    Err(e) => {
                        return Err(de::Error::custom(format_args!(
                            "{e} while parsing value '{}' provided by {}",
                            raw,
                            self.0.key()
                        )))
                    }
                };
                parsed.into_deserializer().$method(visitor)
            }
        )*
    };
}

38// configuration is a key/value pair mapping of String: String or String: Vec<String>; however,
39// we need a surrogate type to implement the deserialization on to underlying primitives
40struct Key<'a>(Rc<Section<'a>>);
41
42struct Val<'a>(Rc<Section<'a>>);
43
44impl<'de> IntoDeserializer<'de, Error> for Key<'de> {
45    type Deserializer = Self;
46
47    #[inline]
48    fn into_deserializer(self) -> Self::Deserializer {
49        self
50    }
51}
52
53impl<'de> de::Deserializer<'de> for Key<'de> {
54    type Error = Error;
55
56    #[inline]
57    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
58        self.0.key().to_owned().into_deserializer().deserialize_any(visitor)
59    }
60
61    #[inline]
62    fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
63        visitor.visit_str(self.0.key())
64    }
65
66    #[inline]
67    fn deserialize_newtype_struct<V: Visitor<'de>>(
68        self,
69        _name: &'static str,
70        visitor: V,
71    ) -> Result<V::Value, Self::Error> {
72        visitor.visit_newtype_struct(self)
73    }
74
75    serde::forward_to_deserialize_any! {
76        char str string unit seq option
77        bytes byte_buf map unit_struct tuple_struct
78        tuple ignored_any enum
79        struct bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64
80    }
81}
82
83impl<'de> IntoDeserializer<'de, Error> for Val<'de> {
84    type Deserializer = Self;
85
86    #[inline]
87    fn into_deserializer(self) -> Self::Deserializer {
88        self
89    }
90}
91
92impl<'de> de::Deserializer<'de> for Val<'de> {
93    type Error = Error;
94
95    #[inline]
96    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
97        self.0.value().into_deserializer().deserialize_any(visitor)
98    }
99
100    // parse each numeric key exactly once, then sort by index. this is required to ensure the zero-based ordering
101    // of the sequence entries (e.g. array) are retained.
102    fn deserialize_seq<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
103        let mut indexed: Vec<_> = self
104            .0
105            .sections()
106            .into_iter()
107            .filter_map(|s| s.key().parse::<usize>().ok().map(|i| (i, s)))
108            .collect();
109
110        indexed.sort_by_key(|(i, _)| *i);
111
112        let values = indexed.into_iter().map(|(_, s)| Val(Rc::new(s)));
113
114        SeqDeserializer::new(values).deserialize_seq(visitor)
115    }
116
117    fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
118        let values = self.0.sections().into_iter().map(|section| {
119            let section = Rc::new(section);
120            (Key(Rc::clone(&section)), Val(section))
121        });
122
123        MapDeserializer::new(values).deserialize_map(visitor)
124    }
125
126    #[inline]
127    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
128        visitor.visit_some(self)
129    }
130
131    forward_parsed_values! {
132        bool => deserialize_bool,
133        u8 => deserialize_u8,
134        u16 => deserialize_u16,
135        u32 => deserialize_u32,
136        u64 => deserialize_u64,
137        i8 => deserialize_i8,
138        i16 => deserialize_i16,
139        i32 => deserialize_i32,
140        i64 => deserialize_i64,
141        f32 => deserialize_f32,
142        f64 => deserialize_f64,
143    }
144
145    #[inline]
146    fn deserialize_newtype_struct<V: Visitor<'de>>(
147        self,
148        _name: &'static str,
149        visitor: V,
150    ) -> Result<V::Value, Self::Error> {
151        visitor.visit_newtype_struct(self)
152    }
153
154    #[inline]
155    fn deserialize_struct<V: Visitor<'de>>(
156        self,
157        name: &'static str,
158        fields: &'static [&'static str],
159        visitor: V,
160    ) -> Result<V::Value, Self::Error> {
161        Deserializer::from_ref(self.0).deserialize_struct(name, fields, visitor)
162    }
163
164    #[inline]
165    fn deserialize_enum<V: Visitor<'de>>(
166        self,
167        _name: &'static str,
168        _variants: &'static [&'static str],
169        visitor: V,
170    ) -> Result<V::Value, Self::Error> {
171        let value = self.0.value();
172
173        if !value.is_empty() {
174            // unit variant: the value itself is the variant name (e.g. "First")
175            return visitor.visit_enum(value.into_deserializer());
176        }
177
178        // non-scalar variant: a subsection key is the variant name
179        // and its value/children are the variant data (e.g. Second: "test")
180        let sections = self.0.sections();
181
182        if let Some(section) = sections.into_iter().next() {
183            visitor.visit_enum(EnumDeserializer(section))
184        } else {
185            visitor.visit_enum(value.into_deserializer())
186        }
187    }
188
189    serde::forward_to_deserialize_any! {
190        char str string unit
191        bytes byte_buf unit_struct tuple_struct
192        identifier tuple ignored_any
193    }
194}
195
196struct ConfigValues<'a>(IntoIter<Section<'a>>);
197
198struct EnumDeserializer<'a>(Section<'a>);
199
200impl<'de> de::EnumAccess<'de> for EnumDeserializer<'de> {
201    type Error = Error;
202    type Variant = Self;
203
204    fn variant_seed<V: de::DeserializeSeed<'de>>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> {
205        let variant = self.0.key().to_owned();
206        let val = seed.deserialize(variant.into_deserializer())?;
207        Ok((val, self))
208    }
209}
210
211impl<'de> de::VariantAccess<'de> for EnumDeserializer<'de> {
212    type Error = Error;
213
214    #[inline]
215    fn unit_variant(self) -> Result<(), Self::Error> {
216        Ok(())
217    }
218
219    #[inline]
220    fn newtype_variant_seed<T: de::DeserializeSeed<'de>>(self, seed: T) -> Result<T::Value, Self::Error> {
221        seed.deserialize(Val(Rc::new(self.0)))
222    }
223
224    #[inline]
225    fn tuple_variant<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error> {
226        de::Deserializer::deserialize_seq(Val(Rc::new(self.0)), visitor)
227    }
228
229    #[inline]
230    fn struct_variant<V: Visitor<'de>>(
231        self,
232        fields: &'static [&'static str],
233        visitor: V,
234    ) -> Result<V::Value, Self::Error> {
235        de::Deserializer::deserialize_struct(Deserializer(self.0.sections().into_iter()), "", fields, visitor)
236    }
237}
238
239impl<'a> Iterator for ConfigValues<'a> {
240    type Item = (Key<'a>, Val<'a>);
241
242    fn next(&mut self) -> Option<Self::Item> {
243        self.0.next().map(|section| {
244            let section = Rc::new(section);
245            (Key(Rc::clone(&section)), Val(section))
246        })
247    }
248}
249
250struct Deserializer<'de>(IntoIter<Section<'de>>);
251
/// Returns `true` when a configuration key names the given struct field.
///
/// Characters are compared ASCII case-insensitively, and underscores are ignored on *both* sides. This
/// allows all of the following keys to match the Rust field `magic_numbers`:
/// - PascalCase: `"MagicNumbers"`
/// - snake_case: `"magic_numbers"`
/// - SCREAMING_SNAKE_CASE: `"MAGIC_NUMBERS"`
fn fields_match(config_key: &str, field: &str) -> bool {
    // underscores must be dropped from the key as well as the field; otherwise a key spelled exactly
    // like its snake_case field ("magic_numbers") would compare '_' against a letter and never match
    fn normalized(s: &str) -> impl Iterator<Item = char> + '_ {
        s.chars().filter(|&c| c != '_').map(|c| c.to_ascii_lowercase())
    }

    normalized(config_key).eq(normalized(field))
}

267struct FieldMappingAccess<'de> {
268    sections: IntoIter<Section<'de>>,
269    fields: &'static [&'static str],
270    pending_value: Option<Rc<Section<'de>>>,
271}
272
273impl<'de> de::MapAccess<'de> for FieldMappingAccess<'de> {
274    type Error = Error;
275
276    fn next_key_seed<K: de::DeserializeSeed<'de>>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error> {
277        for section in self.sections.by_ref() {
278            let config_key = section.key();
279
280            if let Some(&field) = self.fields.iter().find(|f| fields_match(config_key, f)) {
281                self.pending_value = Some(Rc::new(section));
282                return seed.deserialize(field.into_deserializer()).map(Some);
283            }
284        }
285
286        Ok(None)
287    }
288
289    fn next_value_seed<V: de::DeserializeSeed<'de>>(&mut self, seed: V) -> Result<V::Value, Self::Error> {
290        let section = self
291            .pending_value
292            .take()
293            .expect("next_value_seed called before next_key_seed");
294        seed.deserialize(Val(section))
295    }
296}
297
298impl<'de> From<&'de Configuration> for Deserializer<'de> {
299    #[inline]
300    fn from(config: &'de Configuration) -> Self {
301        Self(config.sections().into_iter())
302    }
303}
304
305impl<'de> Deserializer<'de> {
306    fn from_ref(section: Rc<Section<'de>>) -> Self {
307        match Rc::try_unwrap(section) {
308            Ok(section) => Self::from(section),
309            Err(section) => Self((*section).sections().into_iter()),
310        }
311    }
312}
313
314impl<'de> From<Section<'de>> for Deserializer<'de> {
315    #[inline]
316    fn from(section: Section<'de>) -> Self {
317        Self(section.sections().into_iter())
318    }
319}
320
321impl<'de> From<Vec<Section<'de>>> for Deserializer<'de> {
322    #[inline]
323    fn from(sections: Vec<Section<'de>>) -> Self {
324        Self(sections.into_iter())
325    }
326}
327
328impl<'de> de::Deserializer<'de> for Deserializer<'de> {
329    type Error = Error;
330
331    #[inline]
332    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
333        self.deserialize_map(visitor)
334    }
335
336    #[inline]
337    fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
338        visitor.visit_map(MapDeserializer::new(ConfigValues(self.0)))
339    }
340
341    #[inline]
342    fn deserialize_struct<V: Visitor<'de>>(
343        self,
344        _name: &'static str,
345        fields: &'static [&'static str],
346        visitor: V,
347    ) -> Result<V::Value, Self::Error> {
348        visitor.visit_map(FieldMappingAccess {
349            sections: self.0,
350            fields,
351            pending_value: None,
352        })
353    }
354
355    serde::forward_to_deserialize_any! {
356        bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit seq
357        bytes byte_buf unit_struct tuple_struct
358        identifier tuple ignored_any option newtype_struct enum
359    }
360}
361
362/// Deserializes a data structure from the specified configuration sections.
363///
364/// # Arguments
365///
366/// * `configuration` - The configuration [sections](Section) to deserialize
367#[inline]
368pub fn from<'a, T: Deserialize<'a>>(configuration: impl Into<Vec<Section<'a>>>) -> Result<T, Error> {
369    T::deserialize(Deserializer::from(configuration.into()))
370}
371
372/// Deserializes the specified configuration to an existing data structure.
373///
374/// # Arguments
375///
376/// * `configuration` - The configuration [sections](Section) to bind to the data
377/// * `data` - The data to bind the configuration to
378#[inline]
379pub fn bind<'a, T: Deserialize<'a>>(sections: impl Into<Vec<Section<'a>>>, data: &mut T) -> Result<(), Error> {
380    T::deserialize_in_place(Deserializer::from(sections.into()), data)
381}