Skip to main content

matchmaker_partial/
simple_de.rs

1use crate::errors::SimpleError;
2use serde::de::{
3    self, DeserializeSeed, Deserializer, EnumAccess, IntoDeserializer, MapAccess, SeqAccess,
4    VariantAccess, Visitor,
5};
6use serde::forward_to_deserialize_any;
7
/// Deserializer over a borrowed slice of string tokens, where each `String`
/// in the slice is one token.
///
/// Nested values are read by child deserializers that view only the unread
/// tail of `input`; each child reports back how many tokens it consumed.
#[derive(Debug)]
pub struct SimpleDeserializer<'de> {
    /// Every token visible to this deserializer.
    input: &'de [String],
    /// Index of the next unread token in `input`.
    start: usize,
    /// What this deserializer is currently allowed to read:
    /// - `None`: unconstrained (a sequence keeps reading until input runs out);
    /// - `Some(Ok(n))`: a tuple with `n` elements still to read;
    /// - `Some(Err(names))`: a struct body whose keys must be one of `names`
    ///   (an empty list marks a single key / value / variant slot).
    consuming: Option<Result<usize, &'static [&'static str]>>,
}
15
16pub fn deserialize<'de, T>(input: &'de [String]) -> Result<T, SimpleError>
17where
18    T: de::Deserialize<'de>,
19{
20    let mut de = SimpleDeserializer::from_slice(input);
21    let value = T::deserialize(&mut de)?;
22
23    if de.start != input.len() {
24        return Err(SimpleError::TrailingTokens { index: de.start });
25    }
26
27    Ok(value)
28}
29
30impl<'de> SimpleDeserializer<'de> {
31    pub fn from_slice(input: &'de [String]) -> Self {
32        Self {
33            input,
34            start: 0,
35            consuming: None,
36        }
37    }
38
39    fn expect_single(&self) -> Result<&'de str, SimpleError> {
40        self.input
41            .get(self.start)
42            .map(|s| s.as_str())
43            .ok_or(SimpleError::ExpectedSingle)
44    }
45
46    fn with_sub<T, F>(
47        &mut self,
48        f: F,
49        consuming: impl Into<Option<Result<usize, &'static [&'static str]>>>,
50    ) -> Result<T, SimpleError>
51    where
52        F: FnOnce(&mut Self) -> Result<T, SimpleError>,
53    {
54        let mut sub = Self {
55            input: &self.input[self.start..],
56            start: 0,
57            consuming: None,
58        };
59        sub.consuming = consuming.into();
60        let ret = f(&mut sub)?;
61        self.start += sub.start;
62        Ok(ret)
63    }
64}
65
/// Generates a `deserialize_<number>` method that parses the current token with
/// `str::parse::<$ty>()`.
///
/// On success the cursor advances by one and the value is passed to
/// `visitor.$visit`. On failure the cursor is left untouched and an
/// `InvalidType` error is returned, naming `$expect` and the offending token.
macro_rules! impl_number {
    ($name:ident, $ty:ty, $visit:ident, $expect:literal) => {
        fn $name<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>,
        {
            let token = self.expect_single()?;
            match token.parse::<$ty>() {
                Ok(number) => {
                    self.start += 1;
                    visitor.$visit(number)
                }
                Err(_) => Err(SimpleError::InvalidType {
                    expected: $expect,
                    found: token.to_string(),
                }),
            }
        }
    };
}
82
83impl<'de> Deserializer<'de> for &mut SimpleDeserializer<'de> {
84    type Error = SimpleError;
85
86    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
87    where
88        V: Visitor<'de>,
89    {
90        let remaining = self.input.len() - self.start;
91
92        let is_key_context = match self.consuming {
93            Some(Err(_fields)) => true,
94            _ => false,
95        };
96
97        if remaining > 1 && !is_key_context {
98            return self.deserialize_seq(visitor);
99        }
100
101        let s = &self.input[self.start];
102        let val = if s == "true" {
103            visitor.visit_bool(true)?
104        } else if s == "false" {
105            visitor.visit_bool(false)?
106        } else if s.is_empty() || s == "()" {
107            visitor.visit_unit()?
108        } else if let Ok(i) = s.parse::<i64>() {
109            visitor.visit_i64(i)?
110        } else if let Ok(f) = s.parse::<f64>() {
111            visitor.visit_f64(f)?
112        } else {
113            visitor.visit_str(s)?
114        };
115
116        self.start += 1;
117        Ok(val)
118    }
119
120    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
121    where
122        V: Visitor<'de>,
123    {
124        let Ok(s) = self.expect_single() else {
125            return visitor.visit_bool(true); // note that, like Option, this runs the risk of infinite loop
126        };
127        let val = match s {
128            "true" | "" => visitor.visit_bool(true)?,
129            "false" => visitor.visit_bool(false)?,
130            _ => {
131                return Err(SimpleError::InvalidType {
132                    expected: "a boolean",
133                    found: s.to_string(),
134                });
135            }
136        };
137        self.start += 1;
138        Ok(val)
139    }
140
141    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
142    where
143        V: Visitor<'de>,
144    {
145        let s = self.expect_single()?;
146        let mut chars = s.chars();
147        let c = chars.next().ok_or(SimpleError::InvalidType {
148            expected: "a char",
149            found: s.to_string(),
150        })?;
151        if chars.next().is_some() {
152            return Err(SimpleError::InvalidType {
153                expected: "a single character",
154                found: s.to_string(),
155            });
156        }
157        self.start += 1;
158        visitor.visit_char(c)
159    }
160
161    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
162    where
163        V: Visitor<'de>,
164    {
165        let val = visitor.visit_str(self.expect_single()?)?;
166        self.start += 1;
167        Ok(val)
168    }
169
170    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
171    where
172        V: Visitor<'de>,
173    {
174        let val = visitor.visit_string(self.expect_single()?.to_string())?;
175        self.start += 1;
176        Ok(val)
177    }
178
179    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
180    where
181        V: Visitor<'de>,
182    {
183        self.deserialize_str(visitor)
184    }
185
186    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
187    where
188        V: Visitor<'de>,
189    {
190        let s = self.expect_single()?;
191        if s.is_empty() || s == "()" {
192            self.start += 1;
193            visitor.visit_unit()
194        } else {
195            Err(SimpleError::InvalidType {
196                expected: "unit",
197                found: s.to_string(),
198            })
199        }
200    }
201
202    fn deserialize_unit_struct<V>(
203        self,
204        _name: &'static str,
205        visitor: V,
206    ) -> Result<V::Value, Self::Error>
207    where
208        V: Visitor<'de>,
209    {
210        self.deserialize_unit(visitor)
211    }
212
213    fn deserialize_newtype_struct<V>(
214        self,
215        _name: &'static str,
216        visitor: V,
217    ) -> Result<V::Value, Self::Error>
218    where
219        V: Visitor<'de>,
220    {
221        visitor.visit_newtype_struct(self)
222    }
223
224    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
225    where
226        V: Visitor<'de>,
227    {
228        if self.start >= self.input.len() {
229            visitor.visit_none()
230        } else if self.input[self.start] == "null" {
231            self.start += 1;
232            visitor.visit_none()
233        } else {
234            visitor.visit_some(self)
235        }
236    }
237
238    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
239    where
240        V: Visitor<'de>,
241    {
242        self.with_sub(|s| visitor.visit_seq(s), None)
243    }
244
245    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
246    where
247        V: Visitor<'de>,
248    {
249        self.with_sub(|s| visitor.visit_seq(s), Ok(len))
250    }
251
252    fn deserialize_tuple_struct<V>(
253        self,
254        _name: &'static str,
255        len: usize,
256        visitor: V,
257    ) -> Result<V::Value, Self::Error>
258    where
259        V: Visitor<'de>,
260    {
261        self.deserialize_tuple(len, visitor)
262    }
263
264    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
265    where
266        V: Visitor<'de>,
267    {
268        self.with_sub(|s| visitor.visit_map(s), None)
269    }
270
271    fn deserialize_struct<V>(
272        self,
273        _name: &'static str,
274        fields: &'static [&'static str],
275        visitor: V,
276    ) -> Result<V::Value, Self::Error>
277    where
278        V: Visitor<'de>,
279    {
280        self.with_sub(|s| visitor.visit_map(s), Err(fields))
281    }
282
283    fn deserialize_enum<V>(
284        self,
285        _name: &'static str,
286        variants: &'static [&'static str],
287        visitor: V,
288    ) -> Result<V::Value, Self::Error>
289    where
290        V: Visitor<'de>,
291    {
292        // we don't actually use the passed in variants
293        self.with_sub(|s| visitor.visit_enum(s), Err(variants))
294    }
295
296    impl_number!(deserialize_i8, i8, visit_i8, "an i8");
297    impl_number!(deserialize_i16, i16, visit_i16, "an i16");
298    impl_number!(deserialize_i32, i32, visit_i32, "an i32");
299    impl_number!(deserialize_i64, i64, visit_i64, "an i64");
300    impl_number!(deserialize_u8, u8, visit_u8, "a u8");
301    impl_number!(deserialize_u16, u16, visit_u16, "a u16");
302    impl_number!(deserialize_u32, u32, visit_u32, "a u32");
303    impl_number!(deserialize_u64, u64, visit_u64, "a u64");
304    impl_number!(deserialize_f32, f32, visit_f32, "an f32");
305    impl_number!(deserialize_f64, f64, visit_f64, "an f64");
306
307    forward_to_deserialize_any! { bytes byte_buf ignored_any }
308}
309
310// === Implement SeqAccess, MapAccess, EnumAccess, VariantAccess ===
311
312impl<'de> SeqAccess<'de> for &mut SimpleDeserializer<'de> {
313    type Error = SimpleError;
314
315    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
316    where
317        T: DeserializeSeed<'de>,
318    {
319        if let Some(Ok(len)) = self.consuming
320            && len == 0
321        {
322            return Ok(None);
323        }
324
325        if self.start >= self.input.len() {
326            return Ok(None);
327        }
328
329        let val = self.with_sub(|s| seed.deserialize(s), None)?;
330
331        if let Some(Ok(ref mut len)) = self.consuming {
332            *len -= 1;
333        }
334
335        Ok(Some(val))
336    }
337}
338
339impl<'de> MapAccess<'de> for &mut SimpleDeserializer<'de> {
340    type Error = SimpleError;
341
342    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
343    where
344        K: DeserializeSeed<'de>,
345    {
346        if self.start >= self.input.len() {
347            return Ok(None);
348        }
349
350        let key = if let Some(Err(fields)) = self.consuming {
351            let key = &self.input[self.start];
352            if !fields.contains(&key.as_str()) {
353                return Ok(None);
354            } else {
355                self.start += 1;
356                seed.deserialize(key.clone().into_deserializer())?
357            }
358        } else {
359            self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?
360        };
361
362        Ok(Some(key))
363    }
364
365    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
366    where
367        V: DeserializeSeed<'de>,
368    {
369        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
370        Ok(val)
371    }
372}
373
374impl<'de> EnumAccess<'de> for &mut SimpleDeserializer<'de> {
375    type Error = SimpleError;
376    type Variant = Self;
377
378    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
379    where
380        V: DeserializeSeed<'de>,
381    {
382        let val = self.with_sub(|s| seed.deserialize(s), Err(&[][..]))?;
383        Ok((val, self))
384    }
385}
386
387impl<'de> VariantAccess<'de> for &mut SimpleDeserializer<'de> {
388    type Error = SimpleError;
389
390    fn unit_variant(self) -> Result<(), Self::Error> {
391        Ok(())
392    }
393
394    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
395    where
396        T: DeserializeSeed<'de>,
397    {
398        self.with_sub(|s| seed.deserialize(s), None)
399    }
400
401    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
402    where
403        V: Visitor<'de>,
404    {
405        self.deserialize_tuple(len, visitor)
406    }
407
408    fn struct_variant<V>(
409        self,
410        fields: &'static [&'static str],
411        visitor: V,
412    ) -> Result<V::Value, Self::Error>
413    where
414        V: Visitor<'de>,
415    {
416        self.deserialize_struct("", fields, visitor)
417    }
418}