Skip to main content

matchmaker_partial/
simple_de.rs

1use crate::errors::SimpleError;
2use serde::de::{
3    self, DeserializeSeed, Deserializer, EnumAccess, IntoDeserializer, MapAccess, SeqAccess,
4    VariantAccess, Visitor,
5};
6use serde::forward_to_deserialize_any;
7
/// Deserializer that reads values from a flat list of string tokens.
///
/// Nested values are handled by spawning child deserializers over the
/// not-yet-consumed tail of `input` (see `with_sub`).
#[derive(Debug)]
pub struct SimpleDeserializer<'de> {
    /// Tokens visible to this (sub-)deserializer.
    input: &'de [String],
    /// Index of the next token to consume within `input`.
    start: usize,
    /// What the current context is consuming:
    /// - `Some(Ok(len))`: a tuple that still accepts `len` more elements;
    /// - `Some(Err(names))`: a struct (its field names), an enum (its variant names), or a
    ///   single element/key/value (empty list); in these contexts `deserialize_any` does
    ///   not start a sequence;
    /// - `None`: unrestricted.
    consuming: Option<Result<usize, &'static [&'static str]>>,
}
15
16pub fn deserialize<'de, T>(input: &'de [String]) -> Result<T, SimpleError>
17where
18    T: de::Deserialize<'de> + std::fmt::Debug,
19{
20    let mut de = SimpleDeserializer::from_slice(input);
21    let value = T::deserialize(&mut de)?;
22
23    if de.start != input.len() {
24        return Err(SimpleError::TrailingTokens { index: de.start });
25    }
26
27    Ok(value)
28}
29
30impl<'de> SimpleDeserializer<'de> {
31    pub fn from_slice(input: &'de [String]) -> Self {
32        Self {
33            input,
34            start: 0,
35            consuming: None,
36        }
37    }
38
39    fn expect_single(&self) -> Result<&'de str, SimpleError> {
40        self.input
41            .get(self.start)
42            .map(|s| s.as_str())
43            .ok_or(SimpleError::ExpectedSingle)
44    }
45
46    fn with_sub<T, F>(
47        &mut self,
48        f: F,
49        consuming: impl Into<Option<Result<usize, &'static [&'static str]>>>,
50    ) -> Result<T, SimpleError>
51    where
52        F: FnOnce(&mut Self) -> Result<T, SimpleError>,
53    {
54        let mut sub = Self {
55            input: &self.input[self.start..],
56            start: 0,
57            consuming: None,
58        };
59        sub.consuming = consuming.into();
60        let ret = f(&mut sub)?;
61        self.start += sub.start;
62        Ok(ret)
63    }
64}
65
/// Generates a `deserialize_<num>` method that parses the current token as `$ty`,
/// consumes it, and hands the number to `visitor.$visit`.
///
/// A token that fails to parse yields `SimpleError::InvalidType` with `$expect` as the
/// expected description, and the cursor is not advanced.
macro_rules! impl_number {
    ($name:ident, $ty:ty, $visit:ident, $expect:literal) => {
        fn $name<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>,
        {
            let token = self.expect_single()?;
            let number = match token.parse::<$ty>() {
                Ok(n) => n,
                Err(_) => {
                    return Err(SimpleError::InvalidType {
                        expected: $expect,
                        found: token.to_string(),
                    });
                }
            };
            self.start += 1;
            visitor.$visit(number)
        }
    };
}
82
83impl<'de> Deserializer<'de> for &mut SimpleDeserializer<'de> {
84    type Error = SimpleError;
85
86    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
87    where
88        V: Visitor<'de>,
89    {
90        dbg!(&self);
91        let remaining = self.input.len() - self.start;
92
93        let no_sequences = match self.consuming {
94            Some(Err(_fields)) => true,
95            _ => false,
96        };
97
98        if remaining > 1 && !no_sequences {
99            return self.deserialize_seq(visitor);
100        }
101
102        if remaining == 0 {
103            return self.deserialize_seq(visitor);
104        }
105
106        let s = &self.input[self.start];
107        let val = if s == "true" {
108            visitor.visit_bool(true)?
109        } else if s == "false" {
110            visitor.visit_bool(false)?
111        } else if s.is_empty() || s == "()" {
112            visitor.visit_unit()?
113        } else if let Ok(i) = s.parse::<i64>() {
114            visitor.visit_i64(i)?
115        } else if let Ok(f) = s.parse::<f64>() {
116            visitor.visit_f64(f)?
117        } else {
118            visitor.visit_str(s)?
119        };
120
121        self.start += 1;
122        Ok(val)
123    }
124
125    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
126    where
127        V: Visitor<'de>,
128    {
129        let s = self.expect_single()?;
130        //  else {
131        //     return visitor.visit_bool(true); // note that, like Option, this runs the risk of infinite loop
132        // };
133        let val = match s {
134            "true" | "" => visitor.visit_bool(true)?,
135            "false" => visitor.visit_bool(false)?,
136            _ => {
137                return Err(SimpleError::InvalidType {
138                    expected: "a boolean",
139                    found: s.to_string(),
140                });
141            }
142        };
143        self.start += 1;
144        Ok(val)
145    }
146
147    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
148    where
149        V: Visitor<'de>,
150    {
151        let s = self.expect_single()?;
152        let mut chars = s.chars();
153        let c = chars.next().ok_or(SimpleError::InvalidType {
154            expected: "a char",
155            found: s.to_string(),
156        })?;
157        if chars.next().is_some() {
158            return Err(SimpleError::InvalidType {
159                expected: "a single character",
160                found: s.to_string(),
161            });
162        }
163        self.start += 1;
164        visitor.visit_char(c)
165    }
166
167    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
168    where
169        V: Visitor<'de>,
170    {
171        let val = visitor.visit_str(self.expect_single()?)?;
172        self.start += 1;
173        Ok(val)
174    }
175
176    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
177    where
178        V: Visitor<'de>,
179    {
180        let val = visitor.visit_string(self.expect_single()?.to_string())?;
181        self.start += 1;
182        Ok(val)
183    }
184
185    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
186    where
187        V: Visitor<'de>,
188    {
189        self.deserialize_str(visitor)
190    }
191
192    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
193    where
194        V: Visitor<'de>,
195    {
196        let s = self.expect_single()?;
197        if s.is_empty() || s == "()" {
198            self.start += 1;
199            visitor.visit_unit()
200        } else {
201            Err(SimpleError::InvalidType {
202                expected: "unit",
203                found: s.to_string(),
204            })
205        }
206    }
207
208    fn deserialize_unit_struct<V>(
209        self,
210        _name: &'static str,
211        visitor: V,
212    ) -> Result<V::Value, Self::Error>
213    where
214        V: Visitor<'de>,
215    {
216        self.deserialize_unit(visitor)
217    }
218
219    fn deserialize_newtype_struct<V>(
220        self,
221        _name: &'static str,
222        visitor: V,
223    ) -> Result<V::Value, Self::Error>
224    where
225        V: Visitor<'de>,
226    {
227        visitor.visit_newtype_struct(self)
228    }
229
230    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
231    where
232        V: Visitor<'de>,
233    {
234        if self.start >= self.input.len() {
235            visitor.visit_none()
236        } else if self.input[self.start] == "null" {
237            self.start += 1;
238            visitor.visit_none()
239        } else {
240            visitor.visit_some(self)
241        }
242    }
243
244    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
245    where
246        V: Visitor<'de>,
247    {
248        self.with_sub(|s| visitor.visit_seq(s), None)
249    }
250
251    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
252    where
253        V: Visitor<'de>,
254    {
255        self.with_sub(|s| visitor.visit_seq(s), Ok(len))
256    }
257
258    fn deserialize_tuple_struct<V>(
259        self,
260        _name: &'static str,
261        len: usize,
262        visitor: V,
263    ) -> Result<V::Value, Self::Error>
264    where
265        V: Visitor<'de>,
266    {
267        self.deserialize_tuple(len, visitor)
268    }
269
270    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
271    where
272        V: Visitor<'de>,
273    {
274        self.with_sub(|s| visitor.visit_map(s), None)
275    }
276
277    fn deserialize_struct<V>(
278        self,
279        _name: &'static str,
280        fields: &'static [&'static str],
281        visitor: V,
282    ) -> Result<V::Value, Self::Error>
283    where
284        V: Visitor<'de>,
285    {
286        self.with_sub(|s| visitor.visit_map(s), Err(fields))
287    }
288
289    fn deserialize_enum<V>(
290        self,
291        _name: &'static str,
292        variants: &'static [&'static str],
293        visitor: V,
294    ) -> Result<V::Value, Self::Error>
295    where
296        V: Visitor<'de>,
297    {
298        // we don't actually use the passed in variants
299        self.with_sub(|s| visitor.visit_enum(s), Err(variants))
300    }
301
302    impl_number!(deserialize_i8, i8, visit_i8, "an i8");
303    impl_number!(deserialize_i16, i16, visit_i16, "an i16");
304    impl_number!(deserialize_i32, i32, visit_i32, "an i32");
305    impl_number!(deserialize_i64, i64, visit_i64, "an i64");
306    impl_number!(deserialize_u8, u8, visit_u8, "a u8");
307    impl_number!(deserialize_u16, u16, visit_u16, "a u16");
308    impl_number!(deserialize_u32, u32, visit_u32, "a u32");
309    impl_number!(deserialize_u64, u64, visit_u64, "a u64");
310    impl_number!(deserialize_f32, f32, visit_f32, "an f32");
311    impl_number!(deserialize_f64, f64, visit_f64, "an f64");
312
313    forward_to_deserialize_any! { bytes byte_buf ignored_any }
314}
315
316// === Implement SeqAccess, MapAccess, EnumAccess, VariantAccess ===
317
318impl<'de> SeqAccess<'de> for &mut SimpleDeserializer<'de> {
319    type Error = SimpleError;
320
321    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
322    where
323        T: DeserializeSeed<'de>,
324    {
325        if let Some(Ok(len)) = self.consuming
326            && len == 0
327        {
328            return Ok(None);
329        }
330
331        if self.start >= self.input.len() {
332            return Ok(None);
333        }
334
335        // prevent deserialize_any from deserializing sequences
336        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
337
338        if let Some(Ok(ref mut len)) = self.consuming {
339            *len -= 1;
340        }
341
342        Ok(Some(val))
343    }
344}
345
346impl<'de> MapAccess<'de> for &mut SimpleDeserializer<'de> {
347    type Error = SimpleError;
348
349    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
350    where
351        K: DeserializeSeed<'de>,
352    {
353        if self.start >= self.input.len() {
354            return Ok(None);
355        }
356
357        let key = if let Some(Err(fields)) = self.consuming {
358            let key = &self.input[self.start];
359            if !fields.contains(&key.as_str()) {
360                return Ok(None);
361            } else {
362                self.start += 1;
363                seed.deserialize(key.clone().into_deserializer())?
364            }
365        } else {
366            self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?
367        };
368
369        Ok(Some(key))
370    }
371
372    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
373    where
374        V: DeserializeSeed<'de>,
375    {
376        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
377        Ok(val)
378    }
379}
380
381impl<'de> EnumAccess<'de> for &mut SimpleDeserializer<'de> {
382    type Error = SimpleError;
383    type Variant = Self;
384
385    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
386    where
387        V: DeserializeSeed<'de>,
388    {
389        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
390        Ok((val, self))
391    }
392}
393
394impl<'de> VariantAccess<'de> for &mut SimpleDeserializer<'de> {
395    type Error = SimpleError;
396
397    fn unit_variant(self) -> Result<(), Self::Error> {
398        Ok(())
399    }
400
401    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
402    where
403        T: DeserializeSeed<'de>,
404    {
405        self.with_sub(|s| seed.deserialize(s), None)
406    }
407
408    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
409    where
410        V: Visitor<'de>,
411    {
412        self.deserialize_tuple(len, visitor)
413    }
414
415    fn struct_variant<V>(
416        self,
417        fields: &'static [&'static str],
418        visitor: V,
419    ) -> Result<V::Value, Self::Error>
420    where
421        V: Visitor<'de>,
422    {
423        self.deserialize_struct("", fields, visitor)
424    }
425}