Skip to main content

matchmaker_partial/
simple_de.rs

1use crate::errors::SimpleError;
2use serde::de::{
3    self, DeserializeSeed, Deserializer, EnumAccess, IntoDeserializer, MapAccess, SeqAccess,
4    VariantAccess, Visitor,
5};
6use serde::forward_to_deserialize_any;
7
/// Cursor-based deserializer over a slice of string tokens.
///
/// Values are read left to right from `input`; `start` marks the next
/// unread token. Nested values are read through short-lived child
/// deserializers (see `with_sub`) whose progress is added back to the parent.
#[derive(Debug)]
pub struct SimpleDeserializer<'de> {
    /// The tokens visible to this deserializer.
    input: &'de [String],
    /// Index into `input` of the next token to read.
    start: usize,
    /// What kind of compound value is currently being read, if any:
    /// `Some(Ok(len))` for a tuple with `len` elements still expected,
    /// `Some(Err(field_names))` for a struct with the given field names,
    /// `None` otherwise.
    consuming: Option<Result<usize, &'static [&'static str]>>,
}
15
16pub fn deserialize<'de, T>(input: &'de [String]) -> Result<T, SimpleError>
17where
18    T: de::Deserialize<'de>,
19{
20    let mut de = SimpleDeserializer::from_slice(input);
21    let value = T::deserialize(&mut de)?;
22
23    if de.start != input.len() {
24        return Err(SimpleError::TrailingTokens { index: de.start });
25    }
26
27    Ok(value)
28}
29
30impl<'de> SimpleDeserializer<'de> {
31    pub fn from_slice(input: &'de [String]) -> Self {
32        Self {
33            input,
34            start: 0,
35            consuming: None,
36        }
37    }
38
39    fn expect_single(&self) -> Result<&'de str, SimpleError> {
40        self.input
41            .get(self.start)
42            .map(|s| s.as_str())
43            .ok_or(SimpleError::ExpectedSingle)
44    }
45
46    fn with_sub<T, F>(
47        &mut self,
48        f: F,
49        consuming: impl Into<Option<Result<usize, &'static [&'static str]>>>,
50    ) -> Result<T, SimpleError>
51    where
52        F: FnOnce(&mut Self) -> Result<T, SimpleError>,
53    {
54        let mut sub = Self {
55            input: &self.input[self.start..],
56            start: 0,
57            consuming: None,
58        };
59        sub.consuming = consuming.into();
60        let ret = f(&mut sub)?;
61        self.start += sub.start;
62        Ok(ret)
63    }
64}
65
// Generates a `Deserializer` method for one numeric type.
//
// The generated method reads the token under the cursor, parses it as `$ty`,
// and hands the result to the visitor's `$visit` method. The cursor advances
// only when parsing succeeds; otherwise `SimpleError::InvalidType` is returned
// with `$expect` as the expected description and the raw token as `found`.
macro_rules! impl_number {
    ($name:ident, $ty:ty, $visit:ident, $expect:literal) => {
        fn $name<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>,
        {
            let token = self.expect_single()?;
            let Ok(number) = token.parse::<$ty>() else {
                return Err(SimpleError::InvalidType {
                    expected: $expect,
                    found: token.to_string(),
                });
            };
            self.start += 1;
            visitor.$visit(number)
        }
    };
}
82
83impl<'de> Deserializer<'de> for &mut SimpleDeserializer<'de> {
84    type Error = SimpleError;
85
86    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
87    where
88        V: Visitor<'de>,
89    {
90        if self.start >= self.input.len() {
91            return visitor.visit_unit();
92        }
93
94        let s = &self.input[self.start];
95        let val = if s == "true" {
96            visitor.visit_bool(true)?
97        } else if s == "false" {
98            visitor.visit_bool(false)?
99        } else if s.is_empty() || s == "()" {
100            visitor.visit_unit()?
101        } else if let Ok(i) = s.parse::<i64>() {
102            visitor.visit_i64(i)?
103        } else if let Ok(f) = s.parse::<f64>() {
104            visitor.visit_f64(f)?
105        } else {
106            visitor.visit_str(s)?
107        };
108        self.start += 1;
109        Ok(val)
110    }
111
112    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
113    where
114        V: Visitor<'de>,
115    {
116        let Ok(s) = self.expect_single() else {
117            return visitor.visit_bool(true); // note that, like Option, this runs the risk of infinite loop
118        };
119        let val = match s {
120            "true" | "" => visitor.visit_bool(true)?,
121            "false" => visitor.visit_bool(false)?,
122            _ => {
123                return Err(SimpleError::InvalidType {
124                    expected: "a boolean",
125                    found: s.to_string(),
126                });
127            }
128        };
129        self.start += 1;
130        Ok(val)
131    }
132
133    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
134    where
135        V: Visitor<'de>,
136    {
137        let s = self.expect_single()?;
138        let mut chars = s.chars();
139        let c = chars.next().ok_or(SimpleError::InvalidType {
140            expected: "a char",
141            found: s.to_string(),
142        })?;
143        if chars.next().is_some() {
144            return Err(SimpleError::InvalidType {
145                expected: "a single character",
146                found: s.to_string(),
147            });
148        }
149        self.start += 1;
150        visitor.visit_char(c)
151    }
152
153    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
154    where
155        V: Visitor<'de>,
156    {
157        let val = visitor.visit_str(self.expect_single()?)?;
158        self.start += 1;
159        Ok(val)
160    }
161
162    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
163    where
164        V: Visitor<'de>,
165    {
166        let val = visitor.visit_string(self.expect_single()?.to_string())?;
167        self.start += 1;
168        Ok(val)
169    }
170
171    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
172    where
173        V: Visitor<'de>,
174    {
175        let s = self.expect_single()?;
176        if s.is_empty() || s == "()" {
177            self.start += 1;
178            visitor.visit_unit()
179        } else {
180            Err(SimpleError::InvalidType {
181                expected: "unit",
182                found: s.to_string(),
183            })
184        }
185    }
186
187    fn deserialize_unit_struct<V>(
188        self,
189        _name: &'static str,
190        visitor: V,
191    ) -> Result<V::Value, Self::Error>
192    where
193        V: Visitor<'de>,
194    {
195        self.deserialize_unit(visitor)
196    }
197
198    fn deserialize_newtype_struct<V>(
199        self,
200        _name: &'static str,
201        visitor: V,
202    ) -> Result<V::Value, Self::Error>
203    where
204        V: Visitor<'de>,
205    {
206        visitor.visit_newtype_struct(self)
207    }
208
209    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
210    where
211        V: Visitor<'de>,
212    {
213        if self.start >= self.input.len() {
214            visitor.visit_none()
215        } else if self.input[self.start] == "null" {
216            self.start += 1;
217            visitor.visit_none()
218        } else {
219            visitor.visit_some(self)
220        }
221    }
222
223    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
224    where
225        V: Visitor<'de>,
226    {
227        self.with_sub(|s| visitor.visit_seq(s), None)
228    }
229
230    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
231    where
232        V: Visitor<'de>,
233    {
234        self.with_sub(|s| visitor.visit_seq(s), Ok(len))
235    }
236
237    fn deserialize_tuple_struct<V>(
238        self,
239        _name: &'static str,
240        len: usize,
241        visitor: V,
242    ) -> Result<V::Value, Self::Error>
243    where
244        V: Visitor<'de>,
245    {
246        self.deserialize_tuple(len, visitor)
247    }
248
249    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
250    where
251        V: Visitor<'de>,
252    {
253        self.with_sub(|s| visitor.visit_map(s), None)
254    }
255
256    fn deserialize_struct<V>(
257        self,
258        _name: &'static str,
259        fields: &'static [&'static str],
260        visitor: V,
261    ) -> Result<V::Value, Self::Error>
262    where
263        V: Visitor<'de>,
264    {
265        self.with_sub(|s| visitor.visit_map(s), Err(fields))
266    }
267
268    fn deserialize_enum<V>(
269        self,
270        _name: &'static str,
271        variants: &'static [&'static str],
272        visitor: V,
273    ) -> Result<V::Value, Self::Error>
274    where
275        V: Visitor<'de>,
276    {
277        // we don't actually use the passed in variants
278        self.with_sub(|s| visitor.visit_enum(s), Err(variants))
279    }
280
281    impl_number!(deserialize_i8, i8, visit_i8, "an i8");
282    impl_number!(deserialize_i16, i16, visit_i16, "an i16");
283    impl_number!(deserialize_i32, i32, visit_i32, "an i32");
284    impl_number!(deserialize_i64, i64, visit_i64, "an i64");
285    impl_number!(deserialize_u8, u8, visit_u8, "a u8");
286    impl_number!(deserialize_u16, u16, visit_u16, "a u16");
287    impl_number!(deserialize_u32, u32, visit_u32, "a u32");
288    impl_number!(deserialize_u64, u64, visit_u64, "a u64");
289    impl_number!(deserialize_f32, f32, visit_f32, "an f32");
290    impl_number!(deserialize_f64, f64, visit_f64, "an f64");
291
292    forward_to_deserialize_any! { bytes byte_buf identifier ignored_any }
293}
294
295// === Implement SeqAccess, MapAccess, EnumAccess, VariantAccess ===
296
297impl<'de> SeqAccess<'de> for &mut SimpleDeserializer<'de> {
298    type Error = SimpleError;
299
300    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
301    where
302        T: DeserializeSeed<'de>,
303    {
304        if let Some(Ok(len)) = self.consuming
305            && len == 0
306        {
307            return Ok(None);
308        }
309
310        if self.start >= self.input.len() {
311            return Ok(None);
312        }
313
314        let val = self.with_sub(|s| seed.deserialize(s), None)?;
315
316        if let Some(Ok(ref mut len)) = self.consuming {
317            *len -= 1;
318        }
319
320        Ok(Some(val))
321    }
322}
323
324impl<'de> MapAccess<'de> for &mut SimpleDeserializer<'de> {
325    type Error = SimpleError;
326
327    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
328    where
329        K: DeserializeSeed<'de>,
330    {
331        if self.start >= self.input.len() {
332            return Ok(None);
333        }
334
335        let key = if let Some(Err(fields)) = self.consuming {
336            let key = &self.input[self.start];
337            if !fields.contains(&key.as_str()) {
338                return Ok(None);
339            } else {
340                self.start += 1;
341                seed.deserialize(key.clone().into_deserializer())?
342            }
343        } else {
344            self.with_sub(|s| seed.deserialize(s), None)?
345        };
346
347        Ok(Some(key))
348    }
349
350    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
351    where
352        V: DeserializeSeed<'de>,
353    {
354        let val = self.with_sub(|s| seed.deserialize(s), None)?;
355        Ok(val)
356    }
357}
358
359impl<'de> EnumAccess<'de> for &mut SimpleDeserializer<'de> {
360    type Error = SimpleError;
361    type Variant = Self;
362
363    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
364    where
365        V: DeserializeSeed<'de>,
366    {
367        let val = self.with_sub(|s| seed.deserialize(s), None)?;
368        Ok((val, self))
369    }
370}
371
372impl<'de> VariantAccess<'de> for &mut SimpleDeserializer<'de> {
373    type Error = SimpleError;
374
375    fn unit_variant(self) -> Result<(), Self::Error> {
376        Ok(())
377    }
378
379    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
380    where
381        T: DeserializeSeed<'de>,
382    {
383        self.with_sub(|s| seed.deserialize(s), None)
384    }
385
386    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
387    where
388        V: Visitor<'de>,
389    {
390        self.deserialize_tuple(len, visitor)
391    }
392
393    fn struct_variant<V>(
394        self,
395        fields: &'static [&'static str],
396        visitor: V,
397    ) -> Result<V::Value, Self::Error>
398    where
399        V: Visitor<'de>,
400    {
401        self.deserialize_struct("", fields, visitor)
402    }
403}