default_from_serde/
lib.rs

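//! This crate provides a derive macro, `SerdeDefault`, that implements
//! `Default` for a type by reusing the type's `serde::Deserialize`
//! implementation. Field defaults therefore come from serde attributes such as
//! `#[serde(default)]` and `#[serde(default = "path")]`, as shown below.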
3//!
4//! # Usage
5//!
6//!
7//! ```
8//! use default_from_serde::SerdeDefault;
9//! # use serde_derive::Deserialize;
10//!
11//! #[derive(SerdeDefault, Deserialize)]
12//! pub struct ComplexTypewithDefault {
13//!     #[serde(default)]
14//!     pub a: i32,
15//!     #[serde(default = "default_b")]
16//!     pub b: String,
17//!     #[serde(default)]
18//!     pub c: Vec<i32>,
19//! }
20//!
21//! fn default_b() -> String {
22//!     "default".to_string()
23//! }
24//!
25//! fn main() {
26//!     let x = ComplexTypewithDefault::default();
27//!
28//!     assert_eq!(x.b, "default");
29//! }
//! ```
31
32#![cfg_attr(not(feature = "std"), no_std)]
33#![allow(clippy::box_collection)]
34
35use core::fmt;
36#[cfg(feature = "std")]
37use std::error;
38use std::fmt::Display;
39#[cfg(feature = "std")]
40use std::string::String;
41
42pub use derive_default_from_serde::SerdeDefault;
43use serde::de::{
44    DeserializeSeed, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, VariantAccess, Visitor,
45};
46
47use crate::number::Number;
48
49mod number;
50
// A stripped-down `?`: unwrap `Ok`, otherwise return the `Err` unchanged.
// Every error in this crate is already our own `Error`, so the `From::from`
// conversion that `?`/`try!` would insert is unnecessary; skipping it keeps
// the generated code smaller.
macro_rules! tri {
    ($result:expr $(,)?) => {
        match $result {
            core::result::Result::Err(error) => return core::result::Result::Err(error),
            core::result::Result::Ok(value) => value,
        }
    };
}
61
/// A stateless `serde::Deserializer` that reads no input and instead hands
/// "empty" values (empty maps/sequences, `None`, …) to whatever visitor asks.
///
/// NOTE(review): presumably this is what code generated by `SerdeDefault`
/// deserializes from — confirm against the derive crate.
#[derive(Clone, Copy, Debug, Default)]
pub struct DefaultDeserializer;
64
/// Result alias used throughout this crate; the error type defaults to `Error`.
pub type Result<T, E = Error> = core::result::Result<T, E>;

/// Error raised while producing default values; it carries only a message.
///
/// The message is boxed so the error is a single pointer wide (the crate
/// allows `clippy::box_collection` for this).
#[derive(Clone, Debug)]
pub struct Error(Box<String>);

impl Display for Error {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let message: &str = self.0.as_str();
        f.write_str(message)
    }
}
75
76impl serde::de::Error for Error {
77    fn custom<T>(msg: T) -> Self
78    where
79        T: Display,
80    {
81        Error(Box::new(msg.to_string()))
82    }
83}
84
85impl serde::de::StdError for Error {
86    #[cfg(feature = "std")]
87    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
88        None
89    }
90}
91
// Expands to a `Deserializer` method (named by the argument) that forwards the
// visitor to `Number`'s `deserialize_any`. Must be invoked inside an
// `impl<'de> serde::Deserializer<'de>` block because it refers to `'de`.
macro_rules! deserialize_number {
    ($fn_name:ident) => {
        fn $fn_name<V>(self, visitor: V) -> Result<V::Value, Error>
        where
            V: Visitor<'de>,
        {
            Number.deserialize_any(visitor)
        }
    };
}
102
103fn visit_array<'de, V>(visitor: V) -> Result<V::Value, Error>
104where
105    V: Visitor<'de>,
106{
107    let mut deserializer = SeqDeserializer;
108    let seq = tri!(visitor.visit_seq(&mut deserializer));
109
110    Ok(seq)
111}
112
113fn visit_object<'de, V>(visitor: V) -> Result<V::Value, Error>
114where
115    V: Visitor<'de>,
116{
117    let mut deserializer = MapDeserializer;
118    let map = tri!(visitor.visit_map(&mut deserializer));
119
120    Ok(map)
121}
122
123impl<'de> serde::Deserializer<'de> for DefaultDeserializer {
124    type Error = Error;
125
126    deserialize_number!(deserialize_i8);
127
128    deserialize_number!(deserialize_i16);
129
130    deserialize_number!(deserialize_i32);
131
132    deserialize_number!(deserialize_i64);
133
134    deserialize_number!(deserialize_i128);
135
136    deserialize_number!(deserialize_u8);
137
138    deserialize_number!(deserialize_u16);
139
140    deserialize_number!(deserialize_u32);
141
142    deserialize_number!(deserialize_u64);
143
144    deserialize_number!(deserialize_u128);
145
146    deserialize_number!(deserialize_f32);
147
148    deserialize_number!(deserialize_f64);
149
150    #[inline]
151    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
152    where
153        V: Visitor<'de>,
154    {
155        visit_array(visitor)
156    }
157
158    #[inline]
159    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Error>
160    where
161        V: Visitor<'de>,
162    {
163        visitor.visit_none()
164    }
165
166    #[inline]
167    fn deserialize_enum<V>(
168        self,
169        _name: &str,
170        _variants: &'static [&'static str],
171        visitor: V,
172    ) -> Result<V::Value, Error>
173    where
174        V: Visitor<'de>,
175    {
176        visitor.visit_enum(EnumDeserializer)
177    }
178
179    #[inline]
180    fn deserialize_newtype_struct<V>(
181        self,
182        name: &'static str,
183        visitor: V,
184    ) -> Result<V::Value, Error>
185    where
186        V: Visitor<'de>,
187    {
188        let _ = name;
189        visitor.visit_newtype_struct(self)
190    }
191
192    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Error>
193    where
194        V: Visitor<'de>,
195    {
196        visitor.visit_unit()
197    }
198
199    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Error>
200    where
201        V: Visitor<'de>,
202    {
203        self.deserialize_string(visitor)
204    }
205
206    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Error>
207    where
208        V: Visitor<'de>,
209    {
210        self.deserialize_string(visitor)
211    }
212
213    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Error>
214    where
215        V: Visitor<'de>,
216    {
217        visitor.visit_unit()
218    }
219
220    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Error>
221    where
222        V: Visitor<'de>,
223    {
224        self.deserialize_byte_buf(visitor)
225    }
226
227    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Error>
228    where
229        V: Visitor<'de>,
230    {
231        visit_array(visitor)
232    }
233
234    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Error>
235    where
236        V: Visitor<'de>,
237    {
238        visitor.visit_unit()
239    }
240
241    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value, Error>
242    where
243        V: Visitor<'de>,
244    {
245        self.deserialize_unit(visitor)
246    }
247
248    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Error>
249    where
250        V: Visitor<'de>,
251    {
252        visit_array(visitor)
253    }
254
255    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Error>
256    where
257        V: Visitor<'de>,
258    {
259        self.deserialize_seq(visitor)
260    }
261
262    fn deserialize_tuple_struct<V>(
263        self,
264        _name: &'static str,
265        _len: usize,
266        visitor: V,
267    ) -> Result<V::Value, Error>
268    where
269        V: Visitor<'de>,
270    {
271        self.deserialize_seq(visitor)
272    }
273
274    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Error>
275    where
276        V: Visitor<'de>,
277    {
278        visit_object(visitor)
279    }
280
281    fn deserialize_struct<V>(
282        self,
283        _name: &'static str,
284        fields: &'static [&'static str],
285        visitor: V,
286    ) -> Result<V::Value, Error>
287    where
288        V: Visitor<'de>,
289    {
290        if fields.is_empty() {
291            visit_object(visitor)
292        } else if fields.iter().any(|f| f.starts_with('0')) {
293            visit_array(visitor)
294        } else {
295            visit_object(visitor)
296        }
297    }
298
299    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Error>
300    where
301        V: Visitor<'de>,
302    {
303        self.deserialize_string(visitor)
304    }
305
306    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Error>
307    where
308        V: Visitor<'de>,
309    {
310        visitor.visit_unit()
311    }
312}
313
314struct EnumDeserializer;
315
316impl<'de> EnumAccess<'de> for EnumDeserializer {
317    type Error = Error;
318    type Variant = VariantDeserializer;
319
320    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, VariantDeserializer), Error>
321    where
322        V: DeserializeSeed<'de>,
323    {
324        let variant = DefaultDeserializer;
325        let visitor = VariantDeserializer;
326        seed.deserialize(variant).map(|v| (v, visitor))
327    }
328}
329
330impl<'de> IntoDeserializer<'de, Error> for DefaultDeserializer {
331    type Deserializer = Self;
332
333    fn into_deserializer(self) -> Self::Deserializer {
334        self
335    }
336}
337
338struct VariantDeserializer;
339
340impl<'de> VariantAccess<'de> for VariantDeserializer {
341    type Error = Error;
342
343    fn unit_variant(self) -> Result<(), Error> {
344        Ok(())
345    }
346
347    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Error>
348    where
349        T: DeserializeSeed<'de>,
350    {
351        seed.deserialize(DefaultDeserializer)
352    }
353
354    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Error>
355    where
356        V: Visitor<'de>,
357    {
358        visitor.visit_unit()
359    }
360
361    fn struct_variant<V>(
362        self,
363        _fields: &'static [&'static str],
364        visitor: V,
365    ) -> Result<V::Value, Error>
366    where
367        V: Visitor<'de>,
368    {
369        visit_object(visitor)
370    }
371}
372
373struct SeqDeserializer;
374
375impl<'de> SeqAccess<'de> for SeqDeserializer {
376    type Error = Error;
377
378    fn next_element_seed<T>(&mut self, _: T) -> Result<Option<T::Value>, Error>
379    where
380        T: DeserializeSeed<'de>,
381    {
382        Ok(None)
383    }
384
385    fn size_hint(&self) -> Option<usize> {
386        Some(0)
387    }
388}
389
390struct MapDeserializer;
391
392impl<'de> MapAccess<'de> for MapDeserializer {
393    type Error = Error;
394
395    fn next_key_seed<T>(&mut self, _: T) -> Result<Option<T::Value>, Error>
396    where
397        T: DeserializeSeed<'de>,
398    {
399        Ok(None)
400    }
401
402    fn next_value_seed<T>(&mut self, _: T) -> Result<T::Value, Error>
403    where
404        T: DeserializeSeed<'de>,
405    {
406        Err(serde::de::Error::custom("value is missing"))
407    }
408
409    fn size_hint(&self) -> Option<usize> {
410        Some(0)
411    }
412}