// matchmaker_partial/simple_de.rs

use serde::{
    de::{self, DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor},
    forward_to_deserialize_any,
};

use crate::SimpleError;

/// A `serde` [`Deserializer`] over a borrowed slice of raw string tokens.
///
/// How the tokens are interpreted depends on the type being requested:
/// scalar types (numbers, `bool`, `char`, strings, unit) read the **last**
/// token, sequences and tuples consume every token in order, maps and structs
/// consume the tokens as alternating key/value pairs, and enums take the first
/// token as the variant name and the remaining tokens as its payload.
#[derive(Debug, Clone, Copy)]
pub struct SimpleDeserializer<'de> {
    /// The tokens this deserializer reads from.
    input: &'de [String],
}

12impl<'de> SimpleDeserializer<'de> {
13    pub fn from_slice(input: &'de [String]) -> Self {
14        Self { input }
15    }
16
17    fn expect_single(&self) -> Result<&'de str, SimpleError> {
18        if self.input.is_empty() {
19            return Err(SimpleError::ExpectedSingle);
20        }
21        Ok(&self.input[self.input.len() - 1])
22    }
23}
24
/// Generates a `deserialize_*` method for a numeric primitive.
///
/// The generated method takes the scalar token (via `expect_single`), parses
/// it into `$ty` with [`str::parse`], and passes the result to
/// `visitor.$visit`. A token that fails to parse is reported as
/// `SimpleError::InvalidType`, with `$expect` as the human-readable name of
/// the expected type.
macro_rules! impl_number {
    ($name:ident, $ty:ty, $visit:ident, $expect:literal) => {
        fn $name<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>,
        {
            let token = self.expect_single()?;
            match token.parse::<$ty>() {
                Ok(number) => visitor.$visit(number),
                Err(_) => Err(SimpleError::InvalidType {
                    expected: $expect,
                    found: token.to_string(),
                }),
            }
        }
    };
}

41impl<'de> Deserializer<'de> for SimpleDeserializer<'de> {
42    type Error = SimpleError;
43
44    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
45    where
46        V: Visitor<'de>,
47    {
48        match self.input.len() {
49            0 => visitor.visit_unit(),
50            1 => {
51                let s = &self.input[self.input.len() - 1];
52
53                if s == "true" {
54                    return visitor.visit_bool(true);
55                }
56                if s == "false" {
57                    return visitor.visit_bool(false);
58                }
59                if s.is_empty() || s == "()" {
60                    return visitor.visit_unit();
61                }
62                if let Ok(i) = s.parse::<i64>() {
63                    return visitor.visit_i64(i);
64                }
65                if let Ok(f) = s.parse::<f64>() {
66                    return visitor.visit_f64(f);
67                }
68
69                visitor.visit_str(s)
70            }
71            _ => self.deserialize_seq(visitor),
72        }
73    }
74
75    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
76    where
77        V: Visitor<'de>,
78    {
79        let s = self.expect_single()?;
80        match s {
81            "true" => visitor.visit_bool(true),
82            "false" => visitor.visit_bool(false),
83            _ => Err(SimpleError::InvalidType {
84                expected: "a boolean",
85                found: s.to_string(),
86            }),
87        }
88    }
89
90    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
91    where
92        V: Visitor<'de>,
93    {
94        let s = self.expect_single()?;
95        let mut chars = s.chars();
96        let c = chars.next().ok_or_else(|| SimpleError::InvalidType {
97            expected: "a char",
98            found: s.to_string(),
99        })?;
100        if chars.next().is_some() {
101            return Err(SimpleError::InvalidType {
102                expected: "a single character",
103                found: s.to_string(),
104            });
105        }
106        visitor.visit_char(c)
107    }
108
109    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
110    where
111        V: Visitor<'de>,
112    {
113        visitor.visit_str(self.expect_single()?)
114    }
115
116    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
117    where
118        V: Visitor<'de>,
119    {
120        visitor.visit_string(self.expect_single()?.to_string())
121    }
122
123    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
124    where
125        V: Visitor<'de>,
126    {
127        let s = self.expect_single()?;
128        if s.is_empty() || s == "()" {
129            visitor.visit_unit()
130        } else {
131            Err(SimpleError::InvalidType {
132                expected: "unit",
133                found: s.to_string(),
134            })
135        }
136    }
137
138    fn deserialize_unit_struct<V>(
139        self,
140        _name: &'static str,
141        visitor: V,
142    ) -> Result<V::Value, Self::Error>
143    where
144        V: Visitor<'de>,
145    {
146        self.deserialize_unit(visitor)
147    }
148
149    fn deserialize_newtype_struct<V>(
150        self,
151        _name: &'static str,
152        visitor: V,
153    ) -> Result<V::Value, Self::Error>
154    where
155        V: Visitor<'de>,
156    {
157        visitor.visit_newtype_struct(self)
158    }
159
160    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
161    where
162        V: Visitor<'de>,
163    {
164        if self.input.is_empty() || (self.input.len() == 1 && self.input[0] == "null") {
165            visitor.visit_none()
166        } else {
167            visitor.visit_some(self)
168        }
169    }
170
171    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
172    where
173        V: Visitor<'de>,
174    {
175        visitor.visit_seq(SimpleSeqAccess {
176            iter: self.input.iter(),
177        })
178    }
179
180    impl_number!(deserialize_i8, i8, visit_i8, "an i8");
181    impl_number!(deserialize_i16, i16, visit_i16, "an i16");
182    impl_number!(deserialize_i32, i32, visit_i32, "an i32");
183    impl_number!(deserialize_i64, i64, visit_i64, "an i64");
184
185    impl_number!(deserialize_u8, u8, visit_u8, "a u8");
186    impl_number!(deserialize_u16, u16, visit_u16, "a u16");
187    impl_number!(deserialize_u32, u32, visit_u32, "a u32");
188    impl_number!(deserialize_u64, u64, visit_u64, "a u64");
189
190    impl_number!(deserialize_f32, f32, visit_f32, "an f32");
191    impl_number!(deserialize_f64, f64, visit_f64, "an f64");
192
193    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
194    where
195        V: Visitor<'de>,
196    {
197        if self.input.len() != len {
198            return Err(SimpleError::InvalidType {
199                expected: "tuple of specified length",
200                found: format!("{} elements", self.input.len()),
201            });
202        }
203        self.deserialize_seq(visitor)
204    }
205
206    fn deserialize_tuple_struct<V>(
207        self,
208        _name: &'static str,
209        len: usize,
210        visitor: V,
211    ) -> Result<V::Value, Self::Error>
212    where
213        V: Visitor<'de>,
214    {
215        self.deserialize_tuple(len, visitor)
216    }
217
218    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
219    where
220        V: Visitor<'de>,
221    {
222        visitor.visit_map(SimpleMapAccess {
223            iter: self.input.iter(),
224        })
225    }
226
227    fn deserialize_struct<V>(
228        self,
229        _name: &'static str,
230        _fields: &'static [&'static str],
231        visitor: V,
232    ) -> Result<V::Value, Self::Error>
233    where
234        V: Visitor<'de>,
235    {
236        self.deserialize_map(visitor)
237    }
238
239    fn deserialize_enum<V>(
240        self,
241        _name: &'static str,
242        _variants: &'static [&'static str],
243        visitor: V,
244    ) -> Result<V::Value, Self::Error>
245    where
246        V: Visitor<'de>,
247    {
248        if self.input.is_empty() {
249            return Err(SimpleError::InvalidType {
250                expected: "enum variant",
251                found: "empty input".to_string(),
252            });
253        }
254
255        let variant = &self.input[0..1];
256        let rest = &self.input[1..];
257        visitor.visit_enum(SimpleEnumAccess { variant, rest })
258    }
259
260    forward_to_deserialize_any! {
261        bytes byte_buf identifier ignored_any
262    }
263}
264
265struct SimpleSeqAccess<'de> {
266    iter: std::slice::Iter<'de, String>,
267}
268
269impl<'de> SeqAccess<'de> for SimpleSeqAccess<'de> {
270    type Error = SimpleError;
271
272    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
273    where
274        T: DeserializeSeed<'de>,
275    {
276        match self.iter.next() {
277            Some(value) => {
278                let de = SimpleDeserializer {
279                    input: std::slice::from_ref(value),
280                };
281                seed.deserialize(de).map(Some)
282            }
283            None => Ok(None),
284        }
285    }
286}
287
288struct SimpleMapAccess<'de> {
289    iter: std::slice::Iter<'de, String>,
290}
291
292impl<'de> MapAccess<'de> for SimpleMapAccess<'de> {
293    type Error = SimpleError;
294
295    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
296    where
297        K: DeserializeSeed<'de>,
298    {
299        match self.iter.next() {
300            Some(key) => {
301                let de = SimpleDeserializer {
302                    input: std::slice::from_ref(key),
303                };
304                seed.deserialize(de).map(Some)
305            }
306            None => Ok(None),
307        }
308    }
309
310    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
311    where
312        V: DeserializeSeed<'de>,
313    {
314        match self.iter.next() {
315            Some(value) => {
316                let de = SimpleDeserializer {
317                    input: std::slice::from_ref(value),
318                };
319                seed.deserialize(de)
320            }
321            None => Err(SimpleError::ExpectedSingle),
322        }
323    }
324}
325
326struct SimpleEnumAccess<'de> {
327    variant: &'de [String],
328    rest: &'de [String],
329}
330
331impl<'de> de::EnumAccess<'de> for SimpleEnumAccess<'de> {
332    type Error = SimpleError;
333    type Variant = SimpleDeserializer<'de>;
334
335    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
336    where
337        V: DeserializeSeed<'de>,
338    {
339        let val = seed.deserialize(SimpleDeserializer::from_slice(&self.variant))?;
340        Ok((val, SimpleDeserializer::from_slice(self.rest)))
341    }
342}
343
344impl<'de> de::VariantAccess<'de> for SimpleDeserializer<'de> {
345    type Error = SimpleError;
346
347    fn unit_variant(self) -> Result<(), Self::Error> {
348        if !self.input.is_empty() {
349            return Err(SimpleError::InvalidType {
350                expected: "unit variant",
351                found: format!("{} elements", self.input.len()),
352            });
353        }
354        Ok(())
355    }
356
357    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
358    where
359        T: DeserializeSeed<'de>,
360    {
361        seed.deserialize(self)
362    }
363
364    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
365    where
366        V: Visitor<'de>,
367    {
368        self.deserialize_tuple(len, visitor)
369    }
370
371    fn struct_variant<V>(
372        self,
373        fields: &'static [&'static str],
374        visitor: V,
375    ) -> Result<V::Value, Self::Error>
376    where
377        V: Visitor<'de>,
378    {
379        self.deserialize_struct("", fields, visitor)
380    }
381}
382
383pub fn deserialize<'de, T>(input: &'de [String]) -> Result<T, SimpleError>
384where
385    T: de::Deserialize<'de>,
386{
387    let de = SimpleDeserializer::from_slice(input);
388    T::deserialize(de)
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use serde::de::DeserializeOwned;
    use std::collections::HashMap;

    /// Deserializes `input` as `T`, panicking if deserialization fails.
    fn de<T: DeserializeOwned>(input: &[&str]) -> T {
        let data: Vec<String> = input.iter().map(|s| s.to_string()).collect();
        let de = SimpleDeserializer::from_slice(&data);
        T::deserialize(de).unwrap()
    }

    /// Asserts that deserializing `input` as `T` fails.
    fn de_err<T: DeserializeOwned>(input: &[&str]) {
        let data: Vec<String> = input.iter().map(|s| s.to_string()).collect();
        let de = SimpleDeserializer::from_slice(&data);
        assert!(T::deserialize(de).is_err());
    }

    #[test]
    fn primitives_last() {
        assert_eq!(de::<i32>(&["1", "2"]), 2);
        assert!(de::<bool>(&["false", "true"]));
        assert_eq!(de::<String>(&["first", "second"]), "second");
    }

    #[test]
    fn bool_ok() {
        assert!(de::<bool>(&["true"]));
        assert!(!de::<bool>(&["false"]));
    }

    #[test]
    fn bool_err() {
        de_err::<bool>(&["not_bool"]);
    }

    #[test]
    fn integers() {
        assert_eq!(de::<i32>(&["42"]), 42);
        assert_eq!(de::<i8>(&["-5"]), -5);
        assert_eq!(de::<u16>(&["10"]), 10);
    }

    #[test]
    fn number_parse_errors() {
        de_err::<u8>(&["256"]);
        de_err::<u32>(&["-1"]);
        de_err::<i32>(&["1.5"]);
        de_err::<f64>(&["abc"]);
    }

    #[test]
    fn floats() {
        assert_eq!(de::<f32>(&["1.5"]), 1.5);
        assert_eq!(de::<f64>(&["2.25"]), 2.25);
    }

    #[test]
    fn char_ok() {
        assert_eq!(de::<char>(&["a"]), 'a');
    }

    #[test]
    fn char_err() {
        de_err::<char>(&["ab"]);
        de_err::<char>(&[""]);
    }

    #[test]
    fn string_ok() {
        assert_eq!(de::<String>(&["hello"]), "hello");
    }

    #[test]
    fn unit_ok() {
        // `de` unwraps, so these panic if unit deserialization fails.
        de::<()>(&[""]);
        de::<()>(&["()"]);
    }

    #[test]
    fn unit_err() {
        de_err::<()>(&["not_unit"]);
    }

    #[test]
    fn option_none() {
        assert_eq!(de::<Option<i32>>(&[]), None);
        assert_eq!(de::<Option<i32>>(&["null"]), None);
    }

    #[test]
    fn option_some() {
        assert_eq!(de::<Option<i32>>(&["5"]), Some(5));
    }

    #[test]
    fn vec_of_ints() {
        let v: Vec<i32> = de(&["1", "2", "3"]);
        assert_eq!(v, vec![1, 2, 3]);
    }

    #[test]
    fn vec_of_strings() {
        let v: Vec<String> = de(&["a", "b", "c"]);
        assert_eq!(v, vec!["a", "b", "c"]);
    }

    #[test]
    fn vec_of_options() {
        let v: Vec<Option<i32>> = de(&["1", "null", "3"]);
        assert_eq!(v, vec![Some(1), None, Some(3)]);
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct Newtype(i32);

    #[test]
    fn newtype_struct() {
        let n: Newtype = de(&["99"]);
        assert_eq!(n, Newtype(99));
    }

    // Scalars need at least one token; several tokens are fine (last wins,
    // see `primitives_last`).
    #[test]
    fn error_on_empty_input() {
        de_err::<i32>(&[]);
        de_err::<bool>(&[]);
        de_err::<String>(&[]);
    }

    #[test]
    fn tuple_length_mismatch() {
        de_err::<(i32, i32)>(&["1"]);
        de_err::<(i32, i32)>(&["1", "2", "3"]);
    }

    #[test]
    fn struct_map() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct S {
            a: i32,
            b: String,
        }

        let s: S = de(&["a", "10", "b", "hello"]);
        assert_eq!(
            s,
            S {
                a: 10,
                b: "hello".to_string()
            }
        );
    }

    #[test]
    fn hashmap_ok() {
        let m: HashMap<String, i32> = de(&["x", "1", "y", "2"]);
        let mut expected = HashMap::new();
        expected.insert("x".to_string(), 1);
        expected.insert("y".to_string(), 2);
        assert_eq!(m, expected);
    }

    #[test]
    fn map_missing_value() {
        de_err::<HashMap<String, i32>>(&["x"]);
    }

    #[test]
    fn enum_errors() {
        #[derive(Debug, Deserialize)]
        enum Light {
            On,
            Off,
        }

        de_err::<Light>(&[]);
        de_err::<Light>(&["On", "extra"]);
        de_err::<Light>(&["Dim"]);
    }

    #[test]
    fn deserialize_struct_tuple_enum() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct MyStruct {
            a: i32,
            b: String,
        }

        #[derive(Debug, Deserialize, PartialEq)]
        struct MyTupleStruct(i32, String);

        #[derive(Debug, Deserialize, PartialEq)]
        enum MyEnum {
            Unit,
            Newtype(i32),
            Tuple(i32, i32),
            Struct { x: i32, y: i32 },
        }

        // Struct
        let s: MyStruct = de(&["a", "42", "b", "hello"]);
        assert_eq!(
            s,
            MyStruct {
                a: 42,
                b: "hello".to_string()
            }
        );

        // Tuple struct
        let t: MyTupleStruct = de(&["7", "world"]);
        assert_eq!(t, MyTupleStruct(7, "world".to_string()));

        // Enum unit
        let e: MyEnum = de(&["Unit"]);
        assert_eq!(e, MyEnum::Unit);

        // Enum newtype
        let e: MyEnum = de(&["Newtype", "123"]);
        assert_eq!(e, MyEnum::Newtype(123));

        // Enum tuple
        let e: MyEnum = de(&["Tuple", "1", "2"]);
        assert_eq!(e, MyEnum::Tuple(1, 2));

        // Enum struct
        let e: MyEnum = de(&["Struct", "y", "20", "x", "10"]);
        assert_eq!(e, MyEnum::Struct { x: 10, y: 20 });
    }
}