Skip to main content

matchmaker_partial/
simple_de.rs

1use crate::errors::SimpleError;
2use serde::de::{
3    self, DeserializeSeed, Deserializer, EnumAccess, IntoDeserializer, MapAccess, SeqAccess,
4    VariantAccess, Visitor,
5};
6use serde::forward_to_deserialize_any;
7
/// A serde deserializer that reads values from a borrowed slice of string tokens.
///
/// Tokens are consumed left to right; `start` is the index of the next token
/// that has not been consumed yet.
#[derive(Debug)]
pub struct SimpleDeserializer<'de> {
    /// Tokens visible to this (sub-)deserializer.
    input: &'de [String],
    /// Index into `input` of the next unconsumed token.
    start: usize,
    /// Context set by the enclosing container:
    /// `Some(Ok(len))` while deserializing a tuple (`len` elements still expected),
    /// `Some(Err(names))` for a struct/enum/single element (`names` are the accepted keys),
    /// `None` when no container context applies.
    consuming: Option<Result<usize, &'static [&'static str]>>,

    /// Configurable: a token that, where an `Option` is expected, is consumed and
    /// read as `None`.
    pub option_hatch: Option<&'de str>,
}
18
19pub fn deserialize<'de, T>(input: &'de [String]) -> Result<T, SimpleError>
20where
21    T: de::Deserialize<'de> + std::fmt::Debug,
22{
23    let mut de = SimpleDeserializer::from_slice(input);
24    let value = T::deserialize(&mut de)?;
25
26    if de.start != input.len() {
27        return Err(SimpleError::TrailingTokens { index: de.start });
28    }
29
30    Ok(value)
31}
32
33impl<'de> SimpleDeserializer<'de> {
34    pub fn from_slice(input: &'de [String]) -> Self {
35        Self {
36            input,
37            start: 0,
38            consuming: None,
39            option_hatch: None,
40        }
41    }
42
43    fn expect_single(&self) -> Result<&'de str, SimpleError> {
44        self.input
45            .get(self.start)
46            .map(|s| s.as_str())
47            .ok_or(SimpleError::ExpectedSingle)
48    }
49
50    fn with_sub<T, F>(
51        &mut self,
52        f: F,
53        consuming: impl Into<Option<Result<usize, &'static [&'static str]>>>,
54    ) -> Result<T, SimpleError>
55    where
56        F: FnOnce(&mut Self) -> Result<T, SimpleError>,
57    {
58        let mut sub = Self {
59            input: &self.input[self.start..],
60            start: 0,
61            consuming: None,
62            option_hatch: self.option_hatch,
63        };
64        sub.consuming = consuming.into();
65        let ret = f(&mut sub)?;
66        self.start += sub.start;
67        Ok(ret)
68    }
69}
70
/// Generates a `deserialize_*` method for a primitive number type.
///
/// The generated method parses the current token as `$ty`. On success it
/// consumes the token and calls `visitor.$visit`. On a parse failure it returns
/// `SimpleError::InvalidType` naming `$expect` and leaves the token unconsumed.
macro_rules! impl_number {
    ($name:ident, $ty:ty, $visit:ident, $expect:literal) => {
        fn $name<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>,
        {
            let token = self.expect_single()?;
            match token.parse::<$ty>() {
                Ok(number) => {
                    self.start += 1;
                    visitor.$visit(number)
                }
                Err(_) => Err(SimpleError::InvalidType {
                    expected: $expect,
                    found: token.to_string(),
                }),
            }
        }
    };
}
87
88impl<'de> Deserializer<'de> for &mut SimpleDeserializer<'de> {
89    type Error = SimpleError;
90
91    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
92    where
93        V: Visitor<'de>,
94    {
95        let remaining = self.input.len() - self.start;
96
97        let no_sequences = match self.consuming {
98            Some(Err(_fields)) => true,
99            _ => false,
100        };
101
102        if remaining > 1 && !no_sequences {
103            return self.deserialize_seq(visitor);
104        }
105
106        if remaining == 0 {
107            return self.deserialize_seq(visitor);
108        }
109
110        let s = &self.input[self.start];
111        let val = if s == "true" {
112            visitor.visit_bool(true)?
113        } else if s == "false" {
114            visitor.visit_bool(false)?
115        } else if s.is_empty() || s == "()" {
116            visitor.visit_unit()?
117        } else if let Ok(i) = s.parse::<i64>() {
118            visitor.visit_i64(i)?
119        } else if let Ok(f) = s.parse::<f64>() {
120            visitor.visit_f64(f)?
121        } else {
122            visitor.visit_str(s)?
123        };
124
125        self.start += 1;
126        Ok(val)
127    }
128
129    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
130    where
131        V: Visitor<'de>,
132    {
133        let s = self.expect_single()?;
134        //  else {
135        //     return visitor.visit_bool(true); // note that, like Option, this runs the risk of infinite loop
136        // };
137        let val = match s {
138            "true" | "" => visitor.visit_bool(true)?,
139            "false" => visitor.visit_bool(false)?,
140            _ => {
141                return Err(SimpleError::InvalidType {
142                    expected: "a boolean",
143                    found: s.to_string(),
144                });
145            }
146        };
147        self.start += 1;
148        Ok(val)
149    }
150
151    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
152    where
153        V: Visitor<'de>,
154    {
155        let s = self.expect_single()?;
156        let mut chars = s.chars();
157        let c = chars.next().ok_or(SimpleError::InvalidType {
158            expected: "a char",
159            found: s.to_string(),
160        })?;
161        if chars.next().is_some() {
162            return Err(SimpleError::InvalidType {
163                expected: "a single character",
164                found: s.to_string(),
165            });
166        }
167        self.start += 1;
168        visitor.visit_char(c)
169    }
170
171    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
172    where
173        V: Visitor<'de>,
174    {
175        let val = visitor.visit_str(self.expect_single()?)?;
176        self.start += 1;
177        Ok(val)
178    }
179
180    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
181    where
182        V: Visitor<'de>,
183    {
184        let val = visitor.visit_string(self.expect_single()?.to_string())?;
185        self.start += 1;
186        Ok(val)
187    }
188
189    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
190    where
191        V: Visitor<'de>,
192    {
193        self.deserialize_str(visitor)
194    }
195
196    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
197    where
198        V: Visitor<'de>,
199    {
200        let s = self.expect_single()?;
201        if s.is_empty() || s == "()" {
202            self.start += 1;
203            visitor.visit_unit()
204        } else {
205            Err(SimpleError::InvalidType {
206                expected: "unit",
207                found: s.to_string(),
208            })
209        }
210    }
211
212    fn deserialize_unit_struct<V>(
213        self,
214        _name: &'static str,
215        visitor: V,
216    ) -> Result<V::Value, Self::Error>
217    where
218        V: Visitor<'de>,
219    {
220        self.deserialize_unit(visitor)
221    }
222
223    fn deserialize_newtype_struct<V>(
224        self,
225        _name: &'static str,
226        visitor: V,
227    ) -> Result<V::Value, Self::Error>
228    where
229        V: Visitor<'de>,
230    {
231        visitor.visit_newtype_struct(self)
232    }
233
234    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
235    where
236        V: Visitor<'de>,
237    {
238        if self.start >= self.input.len() {
239            visitor.visit_none()
240        } else if self.option_hatch == Some(&self.input[self.start]) {
241            self.start += 1;
242            visitor.visit_none()
243        } else {
244            visitor.visit_some(self)
245        }
246    }
247
248    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
249    where
250        V: Visitor<'de>,
251    {
252        self.with_sub(|s| visitor.visit_seq(s), None)
253    }
254
255    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
256    where
257        V: Visitor<'de>,
258    {
259        self.with_sub(|s| visitor.visit_seq(s), Ok(len))
260    }
261
262    fn deserialize_tuple_struct<V>(
263        self,
264        _name: &'static str,
265        len: usize,
266        visitor: V,
267    ) -> Result<V::Value, Self::Error>
268    where
269        V: Visitor<'de>,
270    {
271        self.deserialize_tuple(len, visitor)
272    }
273
274    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
275    where
276        V: Visitor<'de>,
277    {
278        self.with_sub(|s| visitor.visit_map(s), None)
279    }
280
281    fn deserialize_struct<V>(
282        self,
283        _name: &'static str,
284        fields: &'static [&'static str],
285        visitor: V,
286    ) -> Result<V::Value, Self::Error>
287    where
288        V: Visitor<'de>,
289    {
290        self.with_sub(|s| visitor.visit_map(s), Err(fields))
291    }
292
293    fn deserialize_enum<V>(
294        self,
295        _name: &'static str,
296        variants: &'static [&'static str],
297        visitor: V,
298    ) -> Result<V::Value, Self::Error>
299    where
300        V: Visitor<'de>,
301    {
302        // we don't actually use the passed in variants
303        self.with_sub(|s| visitor.visit_enum(s), Err(variants))
304    }
305
306    impl_number!(deserialize_i8, i8, visit_i8, "an i8");
307    impl_number!(deserialize_i16, i16, visit_i16, "an i16");
308    impl_number!(deserialize_i32, i32, visit_i32, "an i32");
309    impl_number!(deserialize_i64, i64, visit_i64, "an i64");
310    impl_number!(deserialize_u8, u8, visit_u8, "a u8");
311    impl_number!(deserialize_u16, u16, visit_u16, "a u16");
312    impl_number!(deserialize_u32, u32, visit_u32, "a u32");
313    impl_number!(deserialize_u64, u64, visit_u64, "a u64");
314    impl_number!(deserialize_f32, f32, visit_f32, "an f32");
315    impl_number!(deserialize_f64, f64, visit_f64, "an f64");
316
317    forward_to_deserialize_any! { bytes byte_buf ignored_any }
318}
319
320// === Implement SeqAccess, MapAccess, EnumAccess, VariantAccess ===
321
322impl<'de> SeqAccess<'de> for &mut SimpleDeserializer<'de> {
323    type Error = SimpleError;
324
325    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
326    where
327        T: DeserializeSeed<'de>,
328    {
329        if let Some(Ok(len)) = self.consuming
330            && len == 0
331        {
332            return Ok(None);
333        }
334
335        if self.start >= self.input.len() {
336            return Ok(None);
337        }
338
339        // prevent deserialize_any from deserializing sequences
340        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
341
342        if let Some(Ok(ref mut len)) = self.consuming {
343            *len -= 1;
344        }
345
346        Ok(Some(val))
347    }
348}
349
350impl<'de> MapAccess<'de> for &mut SimpleDeserializer<'de> {
351    type Error = SimpleError;
352
353    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
354    where
355        K: DeserializeSeed<'de>,
356    {
357        if self.start >= self.input.len() {
358            return Ok(None);
359        }
360
361        let key = if let Some(Err(fields)) = self.consuming {
362            let key = &self.input[self.start];
363            if !fields.contains(&key.as_str()) {
364                return Ok(None);
365            } else {
366                self.start += 1;
367                seed.deserialize(key.clone().into_deserializer())?
368            }
369        } else {
370            self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?
371        };
372
373        Ok(Some(key))
374    }
375
376    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
377    where
378        V: DeserializeSeed<'de>,
379    {
380        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
381        Ok(val)
382    }
383}
384
385impl<'de> EnumAccess<'de> for &mut SimpleDeserializer<'de> {
386    type Error = SimpleError;
387    type Variant = Self;
388
389    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
390    where
391        V: DeserializeSeed<'de>,
392    {
393        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
394        Ok((val, self))
395    }
396}
397
398impl<'de> VariantAccess<'de> for &mut SimpleDeserializer<'de> {
399    type Error = SimpleError;
400
401    fn unit_variant(self) -> Result<(), Self::Error> {
402        Ok(())
403    }
404
405    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
406    where
407        T: DeserializeSeed<'de>,
408    {
409        self.with_sub(|s| seed.deserialize(s), None)
410    }
411
412    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
413    where
414        V: Visitor<'de>,
415    {
416        self.deserialize_tuple(len, visitor)
417    }
418
419    fn struct_variant<V>(
420        self,
421        fields: &'static [&'static str],
422        visitor: V,
423    ) -> Result<V::Value, Self::Error>
424    where
425        V: Visitor<'de>,
426    {
427        self.deserialize_struct("", fields, visitor)
428    }
429}