Skip to main content

matchmaker_partial/
simple_de.rs

1use crate::errors::SimpleError;
2use serde::de::{
3    self, DeserializeSeed, Deserializer, EnumAccess, IntoDeserializer, MapAccess, SeqAccess,
4    VariantAccess, Visitor,
5};
6use serde::forward_to_deserialize_any;
7
/// Cursor over a slice of string tokens that serde values are deserialized from.
///
/// The deserializer never copies the tokens: it keeps a borrowed slice plus the index of the
/// next unread token, and nested values are read through short-lived sub-deserializers created
/// over the unread tail of the slice.
#[derive(Debug)]
pub struct SimpleDeserializer<'de> {
    /// The tokens being deserialized.
    input: &'de [String],
    /// Index into `input` of the next token that has not been consumed yet.
    start: usize,
    // Ok(len) for tuple, Err(field_names) for struct
    //
    // `None` means no constraint. `Some(Ok(n))` is counted down by the `SeqAccess` impl and
    // ends the sequence once it reaches zero. `Some(Err(names))` makes the `MapAccess` impl stop
    // at the first token that is not one of `names`, and (even with an empty list) prevents
    // `deserialize_any` from interpreting several remaining tokens as a sequence.
    consuming: Option<Result<usize, &'static [&'static str]>>,
}
15
16pub fn deserialize<'de, T>(input: &'de [String]) -> Result<T, SimpleError>
17where
18    T: de::Deserialize<'de> + std::fmt::Debug,
19{
20    let mut de = SimpleDeserializer::from_slice(input);
21    let value = T::deserialize(&mut de)?;
22
23    if de.start != input.len() {
24        return Err(SimpleError::TrailingTokens { index: de.start });
25    }
26
27    Ok(value)
28}
29
30impl<'de> SimpleDeserializer<'de> {
31    pub fn from_slice(input: &'de [String]) -> Self {
32        Self {
33            input,
34            start: 0,
35            consuming: None,
36        }
37    }
38
39    fn expect_single(&self) -> Result<&'de str, SimpleError> {
40        self.input
41            .get(self.start)
42            .map(|s| s.as_str())
43            .ok_or(SimpleError::ExpectedSingle)
44    }
45
46    fn with_sub<T, F>(
47        &mut self,
48        f: F,
49        consuming: impl Into<Option<Result<usize, &'static [&'static str]>>>,
50    ) -> Result<T, SimpleError>
51    where
52        F: FnOnce(&mut Self) -> Result<T, SimpleError>,
53    {
54        let mut sub = Self {
55            input: &self.input[self.start..],
56            start: 0,
57            consuming: None,
58        };
59        sub.consuming = consuming.into();
60        let ret = f(&mut sub)?;
61        self.start += sub.start;
62        Ok(ret)
63    }
64}
65
/// Generates a `Deserializer` method named `$name` that parses the current token as `$ty`
/// and passes it to the visitor's `$visit` method.
///
/// A missing token yields the error from `expect_single`; a token that does not parse yields
/// `SimpleError::InvalidType` with `$expect` as the expected description. The token is only
/// consumed when parsing succeeded.
macro_rules! impl_number {
    ($name:ident, $ty:ty, $visit:ident, $expect:literal) => {
        fn $name<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>,
        {
            let token = self.expect_single()?;
            match token.parse::<$ty>() {
                Ok(number) => {
                    self.start += 1;
                    visitor.$visit(number)
                }
                Err(_) => Err(SimpleError::InvalidType {
                    expected: $expect,
                    found: token.to_string(),
                }),
            }
        }
    };
}
82
83impl<'de> Deserializer<'de> for &mut SimpleDeserializer<'de> {
84    type Error = SimpleError;
85
86    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
87    where
88        V: Visitor<'de>,
89    {
90        let remaining = self.input.len() - self.start;
91
92        let no_sequences = match self.consuming {
93            Some(Err(_fields)) => true,
94            _ => false,
95        };
96
97        if remaining > 1 && !no_sequences {
98            return self.deserialize_seq(visitor);
99        }
100
101        if remaining == 0 {
102            return self.deserialize_seq(visitor);
103        }
104
105        let s = &self.input[self.start];
106        let val = if s == "true" {
107            visitor.visit_bool(true)?
108        } else if s == "false" {
109            visitor.visit_bool(false)?
110        } else if s.is_empty() || s == "()" {
111            visitor.visit_unit()?
112        } else if let Ok(i) = s.parse::<i64>() {
113            visitor.visit_i64(i)?
114        } else if let Ok(f) = s.parse::<f64>() {
115            visitor.visit_f64(f)?
116        } else {
117            visitor.visit_str(s)?
118        };
119
120        self.start += 1;
121        Ok(val)
122    }
123
124    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
125    where
126        V: Visitor<'de>,
127    {
128        let Ok(s) = self.expect_single() else {
129            return visitor.visit_bool(true); // note that, like Option, this runs the risk of infinite loop
130        };
131        let val = match s {
132            "true" | "" => visitor.visit_bool(true)?,
133            "false" => visitor.visit_bool(false)?,
134            _ => {
135                return Err(SimpleError::InvalidType {
136                    expected: "a boolean",
137                    found: s.to_string(),
138                });
139            }
140        };
141        self.start += 1;
142        Ok(val)
143    }
144
145    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
146    where
147        V: Visitor<'de>,
148    {
149        let s = self.expect_single()?;
150        let mut chars = s.chars();
151        let c = chars.next().ok_or(SimpleError::InvalidType {
152            expected: "a char",
153            found: s.to_string(),
154        })?;
155        if chars.next().is_some() {
156            return Err(SimpleError::InvalidType {
157                expected: "a single character",
158                found: s.to_string(),
159            });
160        }
161        self.start += 1;
162        visitor.visit_char(c)
163    }
164
165    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
166    where
167        V: Visitor<'de>,
168    {
169        let val = visitor.visit_str(self.expect_single()?)?;
170        self.start += 1;
171        Ok(val)
172    }
173
174    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
175    where
176        V: Visitor<'de>,
177    {
178        let val = visitor.visit_string(self.expect_single()?.to_string())?;
179        self.start += 1;
180        Ok(val)
181    }
182
183    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
184    where
185        V: Visitor<'de>,
186    {
187        self.deserialize_str(visitor)
188    }
189
190    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
191    where
192        V: Visitor<'de>,
193    {
194        let s = self.expect_single()?;
195        if s.is_empty() || s == "()" {
196            self.start += 1;
197            visitor.visit_unit()
198        } else {
199            Err(SimpleError::InvalidType {
200                expected: "unit",
201                found: s.to_string(),
202            })
203        }
204    }
205
206    fn deserialize_unit_struct<V>(
207        self,
208        _name: &'static str,
209        visitor: V,
210    ) -> Result<V::Value, Self::Error>
211    where
212        V: Visitor<'de>,
213    {
214        self.deserialize_unit(visitor)
215    }
216
217    fn deserialize_newtype_struct<V>(
218        self,
219        _name: &'static str,
220        visitor: V,
221    ) -> Result<V::Value, Self::Error>
222    where
223        V: Visitor<'de>,
224    {
225        visitor.visit_newtype_struct(self)
226    }
227
228    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
229    where
230        V: Visitor<'de>,
231    {
232        if self.start >= self.input.len() {
233            visitor.visit_none()
234        } else if self.input[self.start] == "null" {
235            self.start += 1;
236            visitor.visit_none()
237        } else {
238            visitor.visit_some(self)
239        }
240    }
241
242    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
243    where
244        V: Visitor<'de>,
245    {
246        self.with_sub(|s| visitor.visit_seq(s), None)
247    }
248
249    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
250    where
251        V: Visitor<'de>,
252    {
253        self.with_sub(|s| visitor.visit_seq(s), Ok(len))
254    }
255
256    fn deserialize_tuple_struct<V>(
257        self,
258        _name: &'static str,
259        len: usize,
260        visitor: V,
261    ) -> Result<V::Value, Self::Error>
262    where
263        V: Visitor<'de>,
264    {
265        self.deserialize_tuple(len, visitor)
266    }
267
268    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
269    where
270        V: Visitor<'de>,
271    {
272        self.with_sub(|s| visitor.visit_map(s), None)
273    }
274
275    fn deserialize_struct<V>(
276        self,
277        _name: &'static str,
278        fields: &'static [&'static str],
279        visitor: V,
280    ) -> Result<V::Value, Self::Error>
281    where
282        V: Visitor<'de>,
283    {
284        self.with_sub(|s| visitor.visit_map(s), Err(fields))
285    }
286
287    fn deserialize_enum<V>(
288        self,
289        _name: &'static str,
290        variants: &'static [&'static str],
291        visitor: V,
292    ) -> Result<V::Value, Self::Error>
293    where
294        V: Visitor<'de>,
295    {
296        // we don't actually use the passed in variants
297        self.with_sub(|s| visitor.visit_enum(s), Err(variants))
298    }
299
300    impl_number!(deserialize_i8, i8, visit_i8, "an i8");
301    impl_number!(deserialize_i16, i16, visit_i16, "an i16");
302    impl_number!(deserialize_i32, i32, visit_i32, "an i32");
303    impl_number!(deserialize_i64, i64, visit_i64, "an i64");
304    impl_number!(deserialize_u8, u8, visit_u8, "a u8");
305    impl_number!(deserialize_u16, u16, visit_u16, "a u16");
306    impl_number!(deserialize_u32, u32, visit_u32, "a u32");
307    impl_number!(deserialize_u64, u64, visit_u64, "a u64");
308    impl_number!(deserialize_f32, f32, visit_f32, "an f32");
309    impl_number!(deserialize_f64, f64, visit_f64, "an f64");
310
311    forward_to_deserialize_any! { bytes byte_buf ignored_any }
312}
313
314// === Implement SeqAccess, MapAccess, EnumAccess, VariantAccess ===
315
316impl<'de> SeqAccess<'de> for &mut SimpleDeserializer<'de> {
317    type Error = SimpleError;
318
319    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
320    where
321        T: DeserializeSeed<'de>,
322    {
323        if let Some(Ok(len)) = self.consuming
324            && len == 0
325        {
326            return Ok(None);
327        }
328
329        if self.start >= self.input.len() {
330            return Ok(None);
331        }
332
333        // prevent deserialize_any from deserializing sequences
334        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
335
336        if let Some(Ok(ref mut len)) = self.consuming {
337            *len -= 1;
338        }
339
340        Ok(Some(val))
341    }
342}
343
344impl<'de> MapAccess<'de> for &mut SimpleDeserializer<'de> {
345    type Error = SimpleError;
346
347    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
348    where
349        K: DeserializeSeed<'de>,
350    {
351        if self.start >= self.input.len() {
352            return Ok(None);
353        }
354
355        let key = if let Some(Err(fields)) = self.consuming {
356            let key = &self.input[self.start];
357            if !fields.contains(&key.as_str()) {
358                return Ok(None);
359            } else {
360                self.start += 1;
361                seed.deserialize(key.clone().into_deserializer())?
362            }
363        } else {
364            self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?
365        };
366
367        Ok(Some(key))
368    }
369
370    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
371    where
372        V: DeserializeSeed<'de>,
373    {
374        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
375        Ok(val)
376    }
377}
378
379impl<'de> EnumAccess<'de> for &mut SimpleDeserializer<'de> {
380    type Error = SimpleError;
381    type Variant = Self;
382
383    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
384    where
385        V: DeserializeSeed<'de>,
386    {
387        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
388        Ok((val, self))
389    }
390}
391
392impl<'de> VariantAccess<'de> for &mut SimpleDeserializer<'de> {
393    type Error = SimpleError;
394
395    fn unit_variant(self) -> Result<(), Self::Error> {
396        Ok(())
397    }
398
399    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
400    where
401        T: DeserializeSeed<'de>,
402    {
403        self.with_sub(|s| seed.deserialize(s), None)
404    }
405
406    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
407    where
408        V: Visitor<'de>,
409    {
410        self.deserialize_tuple(len, visitor)
411    }
412
413    fn struct_variant<V>(
414        self,
415        fields: &'static [&'static str],
416        visitor: V,
417    ) -> Result<V::Value, Self::Error>
418    where
419        V: Visitor<'de>,
420    {
421        self.deserialize_struct("", fields, visitor)
422    }
423}