Skip to main content

matchmaker_partial/
simple_de.rs

1use crate::errors::SimpleError;
2use serde::de::{
3    self, DeserializeSeed, Deserializer, EnumAccess, IntoDeserializer, MapAccess, SeqAccess,
4    VariantAccess, Visitor,
5};
6use serde::forward_to_deserialize_any;
7
/// What the deserializer is currently in the middle of consuming, if anything.
///
/// * `Some(Ok(len))` — a tuple with `len` elements still to be read.
/// * `Some(Err(field_names))` — a struct whose keys must be one of `field_names`.
/// * `None` — no such constraint.
type ConsumeState = Option<Result<usize, &'static [&'static str]>>;

/// Cursor-based deserializer over a slice of string tokens.
#[derive(Debug)]
pub struct SimpleDeserializer<'de> {
    /// The tokens being deserialized.
    input: &'de [String],
    /// Index into `input` of the next token that has not been consumed yet.
    start: usize,
    /// Tuple/struct bookkeeping; see [`ConsumeState`].
    consuming: ConsumeState,
}
15
16pub fn deserialize<'de, T>(input: &'de [String]) -> Result<T, SimpleError>
17where
18    T: de::Deserialize<'de> + std::fmt::Debug,
19{
20    let mut de = SimpleDeserializer::from_slice(input);
21    let value = T::deserialize(&mut de)?;
22
23    if de.start != input.len() {
24        return Err(SimpleError::TrailingTokens { index: de.start });
25    }
26
27    Ok(value)
28}
29
30impl<'de> SimpleDeserializer<'de> {
31    pub fn from_slice(input: &'de [String]) -> Self {
32        Self {
33            input,
34            start: 0,
35            consuming: None,
36        }
37    }
38
39    fn expect_single(&self) -> Result<&'de str, SimpleError> {
40        self.input
41            .get(self.start)
42            .map(|s| s.as_str())
43            .ok_or(SimpleError::ExpectedSingle)
44    }
45
46    fn with_sub<T, F>(
47        &mut self,
48        f: F,
49        consuming: impl Into<Option<Result<usize, &'static [&'static str]>>>,
50    ) -> Result<T, SimpleError>
51    where
52        F: FnOnce(&mut Self) -> Result<T, SimpleError>,
53    {
54        let mut sub = Self {
55            input: &self.input[self.start..],
56            start: 0,
57            consuming: None,
58        };
59        sub.consuming = consuming.into();
60        let ret = f(&mut sub)?;
61        self.start += sub.start;
62        Ok(ret)
63    }
64}
65
/// Generates a `Deserializer` method for one numeric type.
///
/// * `$name`   — the trait method to define (e.g. `deserialize_i8`).
/// * `$ty`     — the numeric type the current token is parsed into.
/// * `$visit`  — the `Visitor` method that receives the parsed value.
/// * `$expect` — the description reported in `SimpleError::InvalidType`
///   when parsing fails.
///
/// The generated method consumes exactly one token on a successful parse.
macro_rules! impl_number {
    ($name:ident, $ty:ty, $visit:ident, $expect:literal) => {
        fn $name<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>,
        {
            let token = self.expect_single()?;
            match token.parse::<$ty>() {
                Ok(number) => {
                    self.start += 1;
                    visitor.$visit(number)
                }
                Err(_) => Err(SimpleError::InvalidType {
                    expected: $expect,
                    found: token.to_string(),
                }),
            }
        }
    };
}
82
83impl<'de> Deserializer<'de> for &mut SimpleDeserializer<'de> {
84    type Error = SimpleError;
85
86    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
87    where
88        V: Visitor<'de>,
89    {
90        let remaining = self.input.len() - self.start;
91
92        let no_sequences = match self.consuming {
93            Some(Err(_fields)) => true,
94            _ => false,
95        };
96
97        if remaining > 1 && !no_sequences {
98            return self.deserialize_seq(visitor);
99        }
100
101        if remaining == 0 {
102            return self.deserialize_seq(visitor);
103        }
104
105        let s = &self.input[self.start];
106        let val = if s == "true" {
107            visitor.visit_bool(true)?
108        } else if s == "false" {
109            visitor.visit_bool(false)?
110        } else if s.is_empty() || s == "()" {
111            visitor.visit_unit()?
112        } else if let Ok(i) = s.parse::<i64>() {
113            visitor.visit_i64(i)?
114        } else if let Ok(f) = s.parse::<f64>() {
115            visitor.visit_f64(f)?
116        } else {
117            visitor.visit_str(s)?
118        };
119
120        self.start += 1;
121        Ok(val)
122    }
123
124    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
125    where
126        V: Visitor<'de>,
127    {
128        let s = self.expect_single()?;
129        //  else {
130        //     return visitor.visit_bool(true); // note that, like Option, this runs the risk of infinite loop
131        // };
132        let val = match s {
133            "true" | "" => visitor.visit_bool(true)?,
134            "false" => visitor.visit_bool(false)?,
135            _ => {
136                return Err(SimpleError::InvalidType {
137                    expected: "a boolean",
138                    found: s.to_string(),
139                });
140            }
141        };
142        self.start += 1;
143        Ok(val)
144    }
145
146    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
147    where
148        V: Visitor<'de>,
149    {
150        let s = self.expect_single()?;
151        let mut chars = s.chars();
152        let c = chars.next().ok_or(SimpleError::InvalidType {
153            expected: "a char",
154            found: s.to_string(),
155        })?;
156        if chars.next().is_some() {
157            return Err(SimpleError::InvalidType {
158                expected: "a single character",
159                found: s.to_string(),
160            });
161        }
162        self.start += 1;
163        visitor.visit_char(c)
164    }
165
166    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
167    where
168        V: Visitor<'de>,
169    {
170        let val = visitor.visit_str(self.expect_single()?)?;
171        self.start += 1;
172        Ok(val)
173    }
174
175    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
176    where
177        V: Visitor<'de>,
178    {
179        let val = visitor.visit_string(self.expect_single()?.to_string())?;
180        self.start += 1;
181        Ok(val)
182    }
183
184    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
185    where
186        V: Visitor<'de>,
187    {
188        self.deserialize_str(visitor)
189    }
190
191    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
192    where
193        V: Visitor<'de>,
194    {
195        let s = self.expect_single()?;
196        if s.is_empty() || s == "()" {
197            self.start += 1;
198            visitor.visit_unit()
199        } else {
200            Err(SimpleError::InvalidType {
201                expected: "unit",
202                found: s.to_string(),
203            })
204        }
205    }
206
207    fn deserialize_unit_struct<V>(
208        self,
209        _name: &'static str,
210        visitor: V,
211    ) -> Result<V::Value, Self::Error>
212    where
213        V: Visitor<'de>,
214    {
215        self.deserialize_unit(visitor)
216    }
217
218    fn deserialize_newtype_struct<V>(
219        self,
220        _name: &'static str,
221        visitor: V,
222    ) -> Result<V::Value, Self::Error>
223    where
224        V: Visitor<'de>,
225    {
226        visitor.visit_newtype_struct(self)
227    }
228
229    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
230    where
231        V: Visitor<'de>,
232    {
233        if self.start >= self.input.len() {
234            visitor.visit_none()
235        } else if self.input[self.start] == "null" {
236            self.start += 1;
237            visitor.visit_none()
238        } else {
239            visitor.visit_some(self)
240        }
241    }
242
243    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
244    where
245        V: Visitor<'de>,
246    {
247        self.with_sub(|s| visitor.visit_seq(s), None)
248    }
249
250    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
251    where
252        V: Visitor<'de>,
253    {
254        self.with_sub(|s| visitor.visit_seq(s), Ok(len))
255    }
256
257    fn deserialize_tuple_struct<V>(
258        self,
259        _name: &'static str,
260        len: usize,
261        visitor: V,
262    ) -> Result<V::Value, Self::Error>
263    where
264        V: Visitor<'de>,
265    {
266        self.deserialize_tuple(len, visitor)
267    }
268
269    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
270    where
271        V: Visitor<'de>,
272    {
273        self.with_sub(|s| visitor.visit_map(s), None)
274    }
275
276    fn deserialize_struct<V>(
277        self,
278        _name: &'static str,
279        fields: &'static [&'static str],
280        visitor: V,
281    ) -> Result<V::Value, Self::Error>
282    where
283        V: Visitor<'de>,
284    {
285        self.with_sub(|s| visitor.visit_map(s), Err(fields))
286    }
287
288    fn deserialize_enum<V>(
289        self,
290        _name: &'static str,
291        variants: &'static [&'static str],
292        visitor: V,
293    ) -> Result<V::Value, Self::Error>
294    where
295        V: Visitor<'de>,
296    {
297        // we don't actually use the passed in variants
298        self.with_sub(|s| visitor.visit_enum(s), Err(variants))
299    }
300
301    impl_number!(deserialize_i8, i8, visit_i8, "an i8");
302    impl_number!(deserialize_i16, i16, visit_i16, "an i16");
303    impl_number!(deserialize_i32, i32, visit_i32, "an i32");
304    impl_number!(deserialize_i64, i64, visit_i64, "an i64");
305    impl_number!(deserialize_u8, u8, visit_u8, "a u8");
306    impl_number!(deserialize_u16, u16, visit_u16, "a u16");
307    impl_number!(deserialize_u32, u32, visit_u32, "a u32");
308    impl_number!(deserialize_u64, u64, visit_u64, "a u64");
309    impl_number!(deserialize_f32, f32, visit_f32, "an f32");
310    impl_number!(deserialize_f64, f64, visit_f64, "an f64");
311
312    forward_to_deserialize_any! { bytes byte_buf ignored_any }
313}
314
315// === Implement SeqAccess, MapAccess, EnumAccess, VariantAccess ===
316
317impl<'de> SeqAccess<'de> for &mut SimpleDeserializer<'de> {
318    type Error = SimpleError;
319
320    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
321    where
322        T: DeserializeSeed<'de>,
323    {
324        if let Some(Ok(len)) = self.consuming
325            && len == 0
326        {
327            return Ok(None);
328        }
329
330        if self.start >= self.input.len() {
331            return Ok(None);
332        }
333
334        // prevent deserialize_any from deserializing sequences
335        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
336
337        if let Some(Ok(ref mut len)) = self.consuming {
338            *len -= 1;
339        }
340
341        Ok(Some(val))
342    }
343}
344
345impl<'de> MapAccess<'de> for &mut SimpleDeserializer<'de> {
346    type Error = SimpleError;
347
348    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
349    where
350        K: DeserializeSeed<'de>,
351    {
352        if self.start >= self.input.len() {
353            return Ok(None);
354        }
355
356        let key = if let Some(Err(fields)) = self.consuming {
357            let key = &self.input[self.start];
358            if !fields.contains(&key.as_str()) {
359                return Ok(None);
360            } else {
361                self.start += 1;
362                seed.deserialize(key.clone().into_deserializer())?
363            }
364        } else {
365            self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?
366        };
367
368        Ok(Some(key))
369    }
370
371    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
372    where
373        V: DeserializeSeed<'de>,
374    {
375        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
376        Ok(val)
377    }
378}
379
380impl<'de> EnumAccess<'de> for &mut SimpleDeserializer<'de> {
381    type Error = SimpleError;
382    type Variant = Self;
383
384    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
385    where
386        V: DeserializeSeed<'de>,
387    {
388        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
389        Ok((val, self))
390    }
391}
392
393impl<'de> VariantAccess<'de> for &mut SimpleDeserializer<'de> {
394    type Error = SimpleError;
395
396    fn unit_variant(self) -> Result<(), Self::Error> {
397        Ok(())
398    }
399
400    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
401    where
402        T: DeserializeSeed<'de>,
403    {
404        self.with_sub(|s| seed.deserialize(s), None)
405    }
406
407    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
408    where
409        V: Visitor<'de>,
410    {
411        self.deserialize_tuple(len, visitor)
412    }
413
414    fn struct_variant<V>(
415        self,
416        fields: &'static [&'static str],
417        visitor: V,
418    ) -> Result<V::Value, Self::Error>
419    where
420        V: Visitor<'de>,
421    {
422        self.deserialize_struct("", fields, visitor)
423    }
424}