// matchmaker_partial/simple_de.rs

1use crate::errors::SimpleError;
2use serde::de::{
3    self, DeserializeSeed, Deserializer, EnumAccess, IntoDeserializer, MapAccess, SeqAccess,
4    VariantAccess, Visitor,
5};
6use serde::forward_to_deserialize_any;
7
/// How much of its input a (sub-)deserializer may consume.
///
/// `None` keeps reading until the input runs out, `Some(Ok(len))` bounds a tuple to
/// `len` elements, and `Some(Err(field_names))` marks a struct whose keys must be one of
/// `field_names`.
type Consuming = Option<Result<usize, &'static [&'static str]>>;

/// Deserializer over a borrowed slice of string tokens.
///
/// Tokens are read left to right; `start` is the index of the next unread token.
#[derive(Debug)]
pub struct SimpleDeserializer<'de> {
    /// Tokens visible to this (sub-)deserializer.
    input: &'de [String],
    /// Index into `input` of the next token to read.
    start: usize,
    /// Consumption bound for the compound value currently being read.
    consuming: Consuming,
}
15
16pub fn deserialize<'de, T>(input: &'de [String]) -> Result<T, SimpleError>
17where
18    T: de::Deserialize<'de>,
19{
20    let mut de = SimpleDeserializer::from_slice(input);
21    let value = T::deserialize(&mut de)?;
22
23    if de.start != input.len() {
24        return Err(SimpleError::TrailingTokens { index: de.start });
25    }
26
27    Ok(value)
28}
29
30impl<'de> SimpleDeserializer<'de> {
31    pub fn from_slice(input: &'de [String]) -> Self {
32        Self {
33            input,
34            start: 0,
35            consuming: None,
36        }
37    }
38
39    fn expect_single(&self) -> Result<&'de str, SimpleError> {
40        self.input
41            .get(self.start)
42            .map(|s| s.as_str())
43            .ok_or(SimpleError::ExpectedSingle)
44    }
45
46    fn with_sub<T, F>(
47        &mut self,
48        f: F,
49        consuming: impl Into<Option<Result<usize, &'static [&'static str]>>>,
50    ) -> Result<T, SimpleError>
51    where
52        F: FnOnce(&mut Self) -> Result<T, SimpleError>,
53    {
54        let mut sub = Self {
55            input: &self.input[self.start..],
56            start: 0,
57            consuming: None,
58        };
59        sub.consuming = consuming.into();
60        let ret = f(&mut sub)?;
61        self.start += sub.start;
62        Ok(ret)
63    }
64}
65
/// Expands to a `deserialize_*` method for one primitive number type.
///
/// The generated method parses the next token with `str::parse::<$ty>()`, consumes it,
/// and forwards the number to `visitor.$visit`. A token that does not parse yields
/// `SimpleError::InvalidType` with `$expect` as the expected description, and the token
/// is left unread. Exhausted input yields the error from `expect_single`.
macro_rules! impl_number {
    ($name:ident, $ty:ty, $visit:ident, $expect:literal) => {
        fn $name<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>,
        {
            let token = self.expect_single()?;
            match token.parse::<$ty>() {
                Ok(number) => {
                    self.start += 1;
                    visitor.$visit(number)
                }
                Err(_) => Err(SimpleError::InvalidType {
                    expected: $expect,
                    found: token.to_string(),
                }),
            }
        }
    };
}
82
83impl<'de> Deserializer<'de> for &mut SimpleDeserializer<'de> {
84    type Error = SimpleError;
85
86    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
87    where
88        V: Visitor<'de>,
89    {
90        if self.start >= self.input.len() {
91            return visitor.visit_unit();
92        }
93
94        let s = &self.input[self.start];
95        let val = if s == "true" {
96            visitor.visit_bool(true)?
97        } else if s == "false" {
98            visitor.visit_bool(false)?
99        } else if s.is_empty() || s == "()" {
100            visitor.visit_unit()?
101        } else if let Ok(i) = s.parse::<i64>() {
102            visitor.visit_i64(i)?
103        } else if let Ok(f) = s.parse::<f64>() {
104            visitor.visit_f64(f)?
105        } else {
106            visitor.visit_str(s)?
107        };
108        self.start += 1;
109        Ok(val)
110    }
111
112    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
113    where
114        V: Visitor<'de>,
115    {
116        let s = self.expect_single()?;
117        let val = match s {
118            "true" | "" => visitor.visit_bool(true)?,
119            "false" => visitor.visit_bool(false)?,
120            _ => {
121                return Err(SimpleError::InvalidType {
122                    expected: "a boolean",
123                    found: s.to_string(),
124                });
125            }
126        };
127        self.start += 1;
128        Ok(val)
129    }
130
131    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
132    where
133        V: Visitor<'de>,
134    {
135        let s = self.expect_single()?;
136        let mut chars = s.chars();
137        let c = chars.next().ok_or(SimpleError::InvalidType {
138            expected: "a char",
139            found: s.to_string(),
140        })?;
141        if chars.next().is_some() {
142            return Err(SimpleError::InvalidType {
143                expected: "a single character",
144                found: s.to_string(),
145            });
146        }
147        self.start += 1;
148        visitor.visit_char(c)
149    }
150
151    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
152    where
153        V: Visitor<'de>,
154    {
155        let val = visitor.visit_str(self.expect_single()?)?;
156        self.start += 1;
157        Ok(val)
158    }
159
160    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
161    where
162        V: Visitor<'de>,
163    {
164        let val = visitor.visit_string(self.expect_single()?.to_string())?;
165        self.start += 1;
166        Ok(val)
167    }
168
169    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
170    where
171        V: Visitor<'de>,
172    {
173        let s = self.expect_single()?;
174        if s.is_empty() || s == "()" {
175            self.start += 1;
176            visitor.visit_unit()
177        } else {
178            Err(SimpleError::InvalidType {
179                expected: "unit",
180                found: s.to_string(),
181            })
182        }
183    }
184
185    fn deserialize_unit_struct<V>(
186        self,
187        _name: &'static str,
188        visitor: V,
189    ) -> Result<V::Value, Self::Error>
190    where
191        V: Visitor<'de>,
192    {
193        self.deserialize_unit(visitor)
194    }
195
196    fn deserialize_newtype_struct<V>(
197        self,
198        _name: &'static str,
199        visitor: V,
200    ) -> Result<V::Value, Self::Error>
201    where
202        V: Visitor<'de>,
203    {
204        visitor.visit_newtype_struct(self)
205    }
206
207    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
208    where
209        V: Visitor<'de>,
210    {
211        if self.start >= self.input.len() {
212            visitor.visit_none()
213        } else if self.input[self.start] == "null" {
214            self.start += 1;
215            visitor.visit_none()
216        } else {
217            visitor.visit_some(self)
218        }
219    }
220
221    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
222    where
223        V: Visitor<'de>,
224    {
225        self.with_sub(|s| visitor.visit_seq(s), None)
226    }
227
228    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
229    where
230        V: Visitor<'de>,
231    {
232        self.with_sub(|s| visitor.visit_seq(s), Ok(len))
233    }
234
235    fn deserialize_tuple_struct<V>(
236        self,
237        _name: &'static str,
238        len: usize,
239        visitor: V,
240    ) -> Result<V::Value, Self::Error>
241    where
242        V: Visitor<'de>,
243    {
244        self.deserialize_tuple(len, visitor)
245    }
246
247    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
248    where
249        V: Visitor<'de>,
250    {
251        self.with_sub(|s| visitor.visit_map(s), None)
252    }
253
254    fn deserialize_struct<V>(
255        self,
256        _name: &'static str,
257        fields: &'static [&'static str],
258        visitor: V,
259    ) -> Result<V::Value, Self::Error>
260    where
261        V: Visitor<'de>,
262    {
263        self.with_sub(|s| visitor.visit_map(s), Err(fields))
264    }
265
266    fn deserialize_enum<V>(
267        self,
268        _name: &'static str,
269        variants: &'static [&'static str],
270        visitor: V,
271    ) -> Result<V::Value, Self::Error>
272    where
273        V: Visitor<'de>,
274    {
275        // we don't actually use the passed in variants
276        self.with_sub(|s| visitor.visit_enum(s), Err(variants))
277    }
278
279    impl_number!(deserialize_i8, i8, visit_i8, "an i8");
280    impl_number!(deserialize_i16, i16, visit_i16, "an i16");
281    impl_number!(deserialize_i32, i32, visit_i32, "an i32");
282    impl_number!(deserialize_i64, i64, visit_i64, "an i64");
283    impl_number!(deserialize_u8, u8, visit_u8, "a u8");
284    impl_number!(deserialize_u16, u16, visit_u16, "a u16");
285    impl_number!(deserialize_u32, u32, visit_u32, "a u32");
286    impl_number!(deserialize_u64, u64, visit_u64, "a u64");
287    impl_number!(deserialize_f32, f32, visit_f32, "an f32");
288    impl_number!(deserialize_f64, f64, visit_f64, "an f64");
289
290    forward_to_deserialize_any! { bytes byte_buf identifier ignored_any }
291}

// === Implement SeqAccess, MapAccess, EnumAccess, VariantAccess ===

295impl<'de> SeqAccess<'de> for &mut SimpleDeserializer<'de> {
296    type Error = SimpleError;
297
298    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
299    where
300        T: DeserializeSeed<'de>,
301    {
302        if let Some(Ok(len)) = self.consuming
303            && len == 0
304        {
305            return Ok(None);
306        }
307
308        if self.start >= self.input.len() {
309            return Ok(None);
310        }
311
312        let val = self.with_sub(|s| seed.deserialize(s), None)?;
313
314        if let Some(Ok(ref mut len)) = self.consuming {
315            *len -= 1;
316        }
317
318        Ok(Some(val))
319    }
320}
321
322impl<'de> MapAccess<'de> for &mut SimpleDeserializer<'de> {
323    type Error = SimpleError;
324
325    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
326    where
327        K: DeserializeSeed<'de>,
328    {
329        if self.start >= self.input.len() {
330            return Ok(None);
331        }
332
333        let key = if let Some(Err(fields)) = self.consuming {
334            let key = &self.input[self.start];
335            if !fields.contains(&key.as_str()) {
336                return Ok(None);
337            } else {
338                self.start += 1;
339                seed.deserialize(key.clone().into_deserializer())?
340            }
341        } else {
342            self.with_sub(|s| seed.deserialize(s), None)?
343        };
344
345        Ok(Some(key))
346    }
347
348    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
349    where
350        V: DeserializeSeed<'de>,
351    {
352        let val = self.with_sub(|s| seed.deserialize(s), None)?;
353        Ok(val)
354    }
355}
356
357impl<'de> EnumAccess<'de> for &mut SimpleDeserializer<'de> {
358    type Error = SimpleError;
359    type Variant = Self;
360
361    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
362    where
363        V: DeserializeSeed<'de>,
364    {
365        let val = self.with_sub(|s| seed.deserialize(s), None)?;
366        Ok((val, self))
367    }
368}
369
370impl<'de> VariantAccess<'de> for &mut SimpleDeserializer<'de> {
371    type Error = SimpleError;
372
373    fn unit_variant(self) -> Result<(), Self::Error> {
374        Ok(())
375    }
376
377    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
378    where
379        T: DeserializeSeed<'de>,
380    {
381        self.with_sub(|s| seed.deserialize(s), None)
382    }
383
384    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
385    where
386        V: Visitor<'de>,
387    {
388        self.deserialize_tuple(len, visitor)
389    }
390
391    fn struct_variant<V>(
392        self,
393        fields: &'static [&'static str],
394        visitor: V,
395    ) -> Result<V::Value, Self::Error>
396    where
397        V: Visitor<'de>,
398    {
399        self.deserialize_struct("", fields, visitor)
400    }
401}