poem/web/path/de.rs

1use std::fmt::{self, Display};
2
3use serde::{
4    Deserializer,
5    de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor},
6    forward_to_deserialize_any,
7};
8
/// Error produced when URL path parameters cannot be deserialized into the
/// requested type.
///
/// The wrapped string is the complete human-readable message; it is what the
/// `Display` implementation writes out.
#[derive(PartialEq, Eq, Debug)]
pub(crate) struct PathDeserializerError(pub(crate) String);
12
13impl de::Error for PathDeserializerError {
14    #[inline]
15    fn custom<T: Display>(msg: T) -> Self {
16        PathDeserializerError(msg.to_string())
17    }
18}
19
20impl std::error::Error for PathDeserializerError {
21    #[inline]
22    fn description(&self) -> &str {
23        "path deserializer error"
24    }
25}
26
27impl fmt::Display for PathDeserializerError {
28    #[inline]
29    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
30        match self {
31            PathDeserializerError(msg) => write!(f, "{msg}"),
32        }
33    }
34}
35
// Generates a `Deserializer` method that always fails with the message
// `unsupported type: <label>`. It must be expanded inside an
// `impl<'de> Deserializer<'de>` whose `Error` type is `PathDeserializerError`.
macro_rules! unsupported_type {
    ($method:ident, $label:literal) => {
        fn $method<V>(self, _: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>,
        {
            let message = concat!("unsupported type: ", $label);
            Err(PathDeserializerError::custom(message))
        }
    };
}
49
// Generates a `Deserializer` method for `PathDeserializer` that:
// - requires exactly one path parameter;
// - parses that parameter's value with `str::parse`;
// - names `$label` in the error message when parsing fails.
macro_rules! parse_single_value {
    ($method:ident, $visit:ident, $label:literal) => {
        fn $method<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>,
        {
            // Only a one-element slice is acceptable; anything else reports its length.
            let raw = match self.url_params {
                [(_, raw)] => raw,
                params => {
                    return Err(PathDeserializerError::custom(format!(
                        "wrong number of parameters: {} expected 1",
                        params.len()
                    )));
                }
            };

            // The target type is inferred from the visitor method's argument.
            let parsed = raw.parse().map_err(|_| {
                PathDeserializerError::custom(format!(
                    "can not parse `{:?}` to a `{}`",
                    raw.as_str(),
                    $label
                ))
            })?;
            visitor.$visit(parsed)
        }
    };
}
77
/// Top-level deserializer over the parameters captured from a route path.
///
/// Each entry is a `(key, value)` pair. Map/struct deserialization uses the
/// key as the field name; scalar and sequence deserialization only look at
/// the values.
pub(crate) struct PathDeserializer<'de> {
    url_params: &'de [(String, String)],
}

impl<'de> PathDeserializer<'de> {
    /// Wraps the borrowed parameter slice; nothing is copied.
    #[inline]
    pub(crate) fn new(url_params: &'de [(String, String)]) -> Self {
        Self { url_params }
    }
}
88
89impl<'de> Deserializer<'de> for PathDeserializer<'de> {
90    type Error = PathDeserializerError;
91
92    unsupported_type!(deserialize_any, "'any'");
93    unsupported_type!(deserialize_bytes, "bytes");
94    unsupported_type!(deserialize_option, "Option<T>");
95    unsupported_type!(deserialize_identifier, "identifier");
96    unsupported_type!(deserialize_ignored_any, "ignored_any");
97
98    parse_single_value!(deserialize_bool, visit_bool, "bool");
99    parse_single_value!(deserialize_i8, visit_i8, "i8");
100    parse_single_value!(deserialize_i16, visit_i16, "i16");
101    parse_single_value!(deserialize_i32, visit_i32, "i32");
102    parse_single_value!(deserialize_i64, visit_i64, "i64");
103    parse_single_value!(deserialize_u8, visit_u8, "u8");
104    parse_single_value!(deserialize_u16, visit_u16, "u16");
105    parse_single_value!(deserialize_u32, visit_u32, "u32");
106    parse_single_value!(deserialize_u64, visit_u64, "u64");
107    parse_single_value!(deserialize_f32, visit_f32, "f32");
108    parse_single_value!(deserialize_f64, visit_f64, "f64");
109    parse_single_value!(deserialize_string, visit_string, "String");
110    parse_single_value!(deserialize_byte_buf, visit_string, "String");
111    parse_single_value!(deserialize_char, visit_char, "char");
112
113    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
114    where
115        V: Visitor<'de>,
116    {
117        if self.url_params.len() != 1 {
118            return Err(PathDeserializerError::custom(format!(
119                "wrong number of parameters: {} expected 1",
120                self.url_params.len()
121            )));
122        }
123        visitor.visit_str(&self.url_params[0].1)
124    }
125
126    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
127    where
128        V: Visitor<'de>,
129    {
130        visitor.visit_unit()
131    }
132
133    fn deserialize_unit_struct<V>(
134        self,
135        _name: &'static str,
136        visitor: V,
137    ) -> Result<V::Value, Self::Error>
138    where
139        V: Visitor<'de>,
140    {
141        visitor.visit_unit()
142    }
143
144    fn deserialize_newtype_struct<V>(
145        self,
146        _name: &'static str,
147        visitor: V,
148    ) -> Result<V::Value, Self::Error>
149    where
150        V: Visitor<'de>,
151    {
152        visitor.visit_newtype_struct(self)
153    }
154
155    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
156    where
157        V: Visitor<'de>,
158    {
159        visitor.visit_seq(SeqDeserializer {
160            params: self.url_params,
161        })
162    }
163
164    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
165    where
166        V: Visitor<'de>,
167    {
168        if self.url_params.len() < len {
169            return Err(PathDeserializerError::custom(
170                format!(
171                    "wrong number of parameters: {} expected {}",
172                    self.url_params.len(),
173                    len
174                )
175                .as_str(),
176            ));
177        }
178        visitor.visit_seq(SeqDeserializer {
179            params: self.url_params,
180        })
181    }
182
183    fn deserialize_tuple_struct<V>(
184        self,
185        _name: &'static str,
186        len: usize,
187        visitor: V,
188    ) -> Result<V::Value, Self::Error>
189    where
190        V: Visitor<'de>,
191    {
192        if self.url_params.len() < len {
193            return Err(PathDeserializerError::custom(
194                format!(
195                    "wrong number of parameters: {} expected {}",
196                    self.url_params.len(),
197                    len
198                )
199                .as_str(),
200            ));
201        }
202        visitor.visit_seq(SeqDeserializer {
203            params: self.url_params,
204        })
205    }
206
207    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
208    where
209        V: Visitor<'de>,
210    {
211        visitor.visit_map(MapDeserializer {
212            params: self.url_params,
213            value: None,
214        })
215    }
216
217    fn deserialize_struct<V>(
218        self,
219        _name: &'static str,
220        _fields: &'static [&'static str],
221        visitor: V,
222    ) -> Result<V::Value, Self::Error>
223    where
224        V: Visitor<'de>,
225    {
226        self.deserialize_map(visitor)
227    }
228
229    fn deserialize_enum<V>(
230        self,
231        _name: &'static str,
232        _variants: &'static [&'static str],
233        visitor: V,
234    ) -> Result<V::Value, Self::Error>
235    where
236        V: Visitor<'de>,
237    {
238        if self.url_params.len() != 1 {
239            return Err(PathDeserializerError::custom(format!(
240                "wrong number of parameters: {} expected 1",
241                self.url_params.len()
242            )));
243        }
244
245        visitor.visit_enum(EnumDeserializer {
246            value: &self.url_params[0].1,
247        })
248    }
249}
250
251struct MapDeserializer<'de> {
252    params: &'de [(String, String)],
253    value: Option<&'de str>,
254}
255
256impl<'de> MapAccess<'de> for MapDeserializer<'de> {
257    type Error = PathDeserializerError;
258
259    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
260    where
261        K: DeserializeSeed<'de>,
262    {
263        match self.params.split_first() {
264            Some(((key, value), tail)) => {
265                self.value = Some(value);
266                self.params = tail;
267                seed.deserialize(KeyDeserializer { key }).map(Some)
268            }
269            None => Ok(None),
270        }
271    }
272
273    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
274    where
275        V: DeserializeSeed<'de>,
276    {
277        match self.value.take() {
278            Some(value) => seed.deserialize(ValueDeserializer { value }),
279            None => Err(serde::de::Error::custom("value is missing")),
280        }
281    }
282}
283
284struct KeyDeserializer<'de> {
285    key: &'de str,
286}
287
288macro_rules! parse_key {
289    ($trait_fn:ident) => {
290        fn $trait_fn<V>(self, visitor: V) -> Result<V::Value, Self::Error>
291        where
292            V: Visitor<'de>,
293        {
294            visitor.visit_str(self.key)
295        }
296    };
297}
298
299impl<'de> Deserializer<'de> for KeyDeserializer<'de> {
300    type Error = PathDeserializerError;
301
302    parse_key!(deserialize_identifier);
303    parse_key!(deserialize_str);
304    parse_key!(deserialize_string);
305
306    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
307    where
308        V: Visitor<'de>,
309    {
310        Err(PathDeserializerError::custom("Unexpected"))
311    }
312
313    forward_to_deserialize_any! {
314        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char bytes
315        byte_buf option unit unit_struct seq tuple
316        tuple_struct map newtype_struct struct enum ignored_any
317    }
318}
319
// Generates a `Deserializer` method for `ValueDeserializer` that parses the
// raw value with `str::parse`. On failure the error message names `$ty`.
macro_rules! parse_value {
    ($trait_fn:ident, $visit_fn:ident, $ty:literal) => {
        fn $trait_fn<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>,
        {
            let raw = self.value;
            // The target type is inferred from the visitor method's argument.
            match raw.parse() {
                Ok(parsed) => visitor.$visit_fn(parsed),
                Err(_) => Err(PathDeserializerError::custom(format!(
                    "can not parse `{:?}` to a `{}`",
                    raw, $ty
                ))),
            }
        }
    };
}
336
337struct ValueDeserializer<'de> {
338    value: &'de str,
339}
340
341impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
342    type Error = PathDeserializerError;
343
344    unsupported_type!(deserialize_any, "any");
345    unsupported_type!(deserialize_seq, "seq");
346    unsupported_type!(deserialize_map, "map");
347    unsupported_type!(deserialize_identifier, "identifier");
348
349    parse_value!(deserialize_bool, visit_bool, "bool");
350    parse_value!(deserialize_i8, visit_i8, "i8");
351    parse_value!(deserialize_i16, visit_i16, "i16");
352    parse_value!(deserialize_i32, visit_i32, "i16");
353    parse_value!(deserialize_i64, visit_i64, "i64");
354    parse_value!(deserialize_u8, visit_u8, "u8");
355    parse_value!(deserialize_u16, visit_u16, "u16");
356    parse_value!(deserialize_u32, visit_u32, "u32");
357    parse_value!(deserialize_u64, visit_u64, "u64");
358    parse_value!(deserialize_f32, visit_f32, "f32");
359    parse_value!(deserialize_f64, visit_f64, "f64");
360    parse_value!(deserialize_string, visit_string, "String");
361    parse_value!(deserialize_byte_buf, visit_string, "String");
362    parse_value!(deserialize_char, visit_char, "char");
363
364    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
365    where
366        V: Visitor<'de>,
367    {
368        visitor.visit_borrowed_str(self.value)
369    }
370
371    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
372    where
373        V: Visitor<'de>,
374    {
375        visitor.visit_borrowed_bytes(self.value.as_bytes())
376    }
377
378    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
379    where
380        V: Visitor<'de>,
381    {
382        visitor.visit_some(self)
383    }
384
385    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
386    where
387        V: Visitor<'de>,
388    {
389        visitor.visit_unit()
390    }
391
392    fn deserialize_unit_struct<V>(
393        self,
394        _name: &'static str,
395        visitor: V,
396    ) -> Result<V::Value, Self::Error>
397    where
398        V: Visitor<'de>,
399    {
400        visitor.visit_unit()
401    }
402
403    fn deserialize_newtype_struct<V>(
404        self,
405        _name: &'static str,
406        visitor: V,
407    ) -> Result<V::Value, Self::Error>
408    where
409        V: Visitor<'de>,
410    {
411        visitor.visit_newtype_struct(self)
412    }
413
414    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
415    where
416        V: Visitor<'de>,
417    {
418        Err(PathDeserializerError::custom("unsupported type: tuple"))
419    }
420
421    fn deserialize_tuple_struct<V>(
422        self,
423        _name: &'static str,
424        _len: usize,
425        _visitor: V,
426    ) -> Result<V::Value, Self::Error>
427    where
428        V: Visitor<'de>,
429    {
430        Err(PathDeserializerError::custom(
431            "unsupported type: tuple struct",
432        ))
433    }
434
435    fn deserialize_struct<V>(
436        self,
437        _name: &'static str,
438        _fields: &'static [&'static str],
439        _visitor: V,
440    ) -> Result<V::Value, Self::Error>
441    where
442        V: Visitor<'de>,
443    {
444        Err(PathDeserializerError::custom("unsupported type: struct"))
445    }
446
447    fn deserialize_enum<V>(
448        self,
449        _name: &'static str,
450        _variants: &'static [&'static str],
451        visitor: V,
452    ) -> Result<V::Value, Self::Error>
453    where
454        V: Visitor<'de>,
455    {
456        visitor.visit_enum(EnumDeserializer { value: self.value })
457    }
458
459    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
460    where
461        V: Visitor<'de>,
462    {
463        visitor.visit_unit()
464    }
465}
466
467struct EnumDeserializer<'de> {
468    value: &'de str,
469}
470
471impl<'de> EnumAccess<'de> for EnumDeserializer<'de> {
472    type Error = PathDeserializerError;
473    type Variant = UnitVariant;
474
475    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
476    where
477        V: de::DeserializeSeed<'de>,
478    {
479        Ok((
480            seed.deserialize(KeyDeserializer { key: self.value })?,
481            UnitVariant,
482        ))
483    }
484}
485
486struct UnitVariant;
487
488impl<'de> VariantAccess<'de> for UnitVariant {
489    type Error = PathDeserializerError;
490
491    fn unit_variant(self) -> Result<(), Self::Error> {
492        Ok(())
493    }
494
495    fn newtype_variant_seed<T>(self, _seed: T) -> Result<T::Value, Self::Error>
496    where
497        T: DeserializeSeed<'de>,
498    {
499        Err(PathDeserializerError::custom("not supported"))
500    }
501
502    fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
503    where
504        V: Visitor<'de>,
505    {
506        Err(PathDeserializerError::custom("not supported"))
507    }
508
509    fn struct_variant<V>(
510        self,
511        _fields: &'static [&'static str],
512        _visitor: V,
513    ) -> Result<V::Value, Self::Error>
514    where
515        V: Visitor<'de>,
516    {
517        Err(PathDeserializerError::custom("not supported"))
518    }
519}
520
521struct SeqDeserializer<'de> {
522    params: &'de [(String, String)],
523}
524
525impl<'de> SeqAccess<'de> for SeqDeserializer<'de> {
526    type Error = PathDeserializerError;
527
528    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
529    where
530        T: DeserializeSeed<'de>,
531    {
532        match self.params.split_first() {
533            Some(((_, value), tail)) => {
534                self.params = tail;
535                Ok(Some(seed.deserialize(ValueDeserializer { value })?))
536            }
537            None => Ok(None),
538        }
539    }
540}
541
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use std::collections::HashMap;

    use serde::Deserialize;

    use super::*;
    use crate::route::PathParams;

    #[derive(Debug, Deserialize, Eq, PartialEq)]
    enum MyEnum {
        A,
        B,
        #[serde(rename = "c")]
        C,
    }

    #[derive(Debug, Deserialize, Eq, PartialEq)]
    struct Struct {
        c: String,
        b: bool,
        a: i32,
    }

    /// Builds path parameters from `(key, value)` pairs.
    fn create_url_params<I, K, V>(values: I) -> PathParams
    where
        I: IntoIterator<Item = (K, V)>,
        K: Into<String>,
        V: Into<String>,
    {
        values
            .into_iter()
            .map(|(k, v)| (k.into(), v.into()))
            .collect()
    }

    /// Deserializes `$ty` from a single parameter and compares with `$value`.
    macro_rules! check_single_value {
        ($ty:ty, $value_str:literal, $value:expr) => {
            #[allow(clippy::bool_assert_comparison)]
            {
                let url_params = create_url_params(vec![("value", $value_str)]);
                let deserializer = PathDeserializer::new(&url_params);
                assert_eq!(<$ty>::deserialize(deserializer).unwrap(), $value);
            }
        };
    }

    #[test]
    fn test_parse_single_value() {
        check_single_value!(bool, "true", true);
        check_single_value!(bool, "false", false);
        check_single_value!(i8, "-123", -123);
        check_single_value!(i16, "-123", -123);
        check_single_value!(i32, "-123", -123);
        check_single_value!(i64, "-123", -123);
        check_single_value!(u8, "123", 123);
        check_single_value!(u16, "123", 123);
        check_single_value!(u32, "123", 123);
        check_single_value!(u64, "123", 123);
        check_single_value!(f32, "123", 123.0);
        check_single_value!(f64, "123", 123.0);
        check_single_value!(String, "abc", "abc");
        check_single_value!(char, "a", 'a');

        let url_params = create_url_params(vec![("a", "B")]);
        assert_eq!(
            MyEnum::deserialize(PathDeserializer::new(&url_params)).unwrap(),
            MyEnum::B
        );

        let url_params = create_url_params(vec![("a", "1"), ("b", "2")]);
        assert_eq!(
            i32::deserialize(PathDeserializer::new(&url_params)).unwrap_err(),
            PathDeserializerError::custom("wrong number of parameters: 2 expected 1".to_string())
        );
    }

    #[test]
    fn test_parse_seq() {
        let url_params = create_url_params(vec![("a", "1"), ("b", "true"), ("c", "abc")]);
        assert_eq!(
            <(i32, bool, String)>::deserialize(PathDeserializer::new(&url_params)).unwrap(),
            (1, true, "abc".to_string())
        );

        #[derive(Debug, Deserialize, Eq, PartialEq)]
        struct TupleStruct(i32, bool, String);
        assert_eq!(
            TupleStruct::deserialize(PathDeserializer::new(&url_params)).unwrap(),
            TupleStruct(1, true, "abc".to_string())
        );

        let url_params = create_url_params(vec![("a", "1"), ("b", "2"), ("c", "3")]);
        assert_eq!(
            <Vec<i32>>::deserialize(PathDeserializer::new(&url_params)).unwrap(),
            vec![1, 2, 3]
        );

        let url_params = create_url_params(vec![("a", "c"), ("a", "B")]);
        assert_eq!(
            <Vec<MyEnum>>::deserialize(PathDeserializer::new(&url_params)).unwrap(),
            vec![MyEnum::C, MyEnum::B]
        );
    }

    #[test]
    fn test_parse_struct() {
        let url_params = create_url_params(vec![("a", "1"), ("b", "true"), ("c", "abc")]);
        assert_eq!(
            Struct::deserialize(PathDeserializer::new(&url_params)).unwrap(),
            Struct {
                c: "abc".to_string(),
                b: true,
                a: 1,
            }
        );
    }

    #[test]
    fn test_parse_map() {
        let url_params = create_url_params(vec![("a", "1"), ("b", "true"), ("c", "abc")]);
        assert_eq!(
            <HashMap<String, String>>::deserialize(PathDeserializer::new(&url_params)).unwrap(),
            [("a", "1"), ("b", "true"), ("c", "abc")]
                .iter()
                .map(|(key, value)| ((*key).to_string(), (*value).to_string()))
                .collect()
        );
    }

    /// Error paths of the top-level deserializer: unparsable value, too few
    /// parameters for a tuple, and an unsupported target type.
    #[test]
    fn test_parse_errors() {
        let url_params = create_url_params(vec![("a", "abc")]);
        assert_eq!(
            i32::deserialize(PathDeserializer::new(&url_params)).unwrap_err(),
            PathDeserializerError::custom("can not parse `\"abc\"` to a `i32`")
        );

        let url_params = create_url_params(vec![("a", "1")]);
        assert_eq!(
            <(i32, bool, String)>::deserialize(PathDeserializer::new(&url_params)).unwrap_err(),
            PathDeserializerError::custom("wrong number of parameters: 1 expected 3")
        );

        assert_eq!(
            <Option<i32>>::deserialize(PathDeserializer::new(&url_params)).unwrap_err(),
            PathDeserializerError::custom("unsupported type: Option<T>")
        );
    }
}