htrpc/deserializers/
url_query.rs

1use std;
2use std::borrow::Cow;
3use std::iter::Peekable;
4use serde::de::{self, Visitor};
5use trackable::error::ErrorKindExt;
6use url;
7
8use {Error, ErrorKind, Result};
9
/// Which half of the current `key=value` pair `next_str` will yield next.
#[derive(Debug, PartialEq, Eq)]
enum Phase<'a> {
    /// Nothing is pending: the next read pulls a fresh pair and yields its key.
    Key,
    /// The key of the current pair was already yielded; this is its value.
    Value(Cow<'a, str>),
}
impl<'a> Phase<'a> {
    /// Moves the current phase out, leaving `Phase::Key` in its place.
    pub fn take(&mut self) -> Self {
        ::std::mem::replace(self, Phase::Key)
    }
}
20
21/// `Deserializer` implementation for URL query string.
22pub struct UrlQueryDeserializer<'de> {
23    in_map: bool,
24    phase: Phase<'de>,
25    query: Peekable<url::form_urlencoded::Parse<'de>>,
26}
27impl<'de> UrlQueryDeserializer<'de> {
28    /// Makes a new `UrlQueryDeserializer` instance.
29    pub fn new(query: url::form_urlencoded::Parse<'de>) -> Self {
30        UrlQueryDeserializer {
31            in_map: false,
32            phase: Phase::Key,
33            query: query.peekable(),
34        }
35    }
36
37    fn is_end_of_query(&mut self) -> bool {
38        self.query.peek().is_none()
39    }
40    fn next_str(&mut self) -> Result<Cow<'de, str>> {
41        match self.phase.take() {
42            Phase::Key => {
43                let (k, v) = track!(self.query.next().ok_or_else(|| ErrorKind::Invalid.error()))?;
44                self.phase = Phase::Value(v);
45                Ok(k)
46            }
47            Phase::Value(v) => Ok(v),
48        }
49    }
50}
51impl<'de, 'a> de::Deserializer<'de> for &'a mut UrlQueryDeserializer<'de> {
52    type Error = Error;
53    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
54    where
55        V: Visitor<'de>,
56    {
57        track_panic!(ErrorKind::Other, "unreachable");
58    }
59
60    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
61    where
62        V: Visitor<'de>,
63    {
64        let v = track!(self.next_str())?;
65        let v = track!(parse_cow_str(v))?;
66        track!(visitor.visit_bool(v))
67    }
68
69    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
70    where
71        V: Visitor<'de>,
72    {
73        let v = track!(self.next_str())?;
74        let v = track!(parse_cow_str(v))?;
75        track!(visitor.visit_i8(v))
76    }
77
78    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
79    where
80        V: Visitor<'de>,
81    {
82        let v = track!(self.next_str())?;
83        let v = track!(parse_cow_str(v))?;
84        track!(visitor.visit_i16(v))
85    }
86
87    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
88    where
89        V: Visitor<'de>,
90    {
91        let v = track!(self.next_str())?;
92        let v = track!(parse_cow_str(v))?;
93        track!(visitor.visit_i32(v))
94    }
95
96    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
97    where
98        V: Visitor<'de>,
99    {
100        let v = track!(self.next_str())?;
101        let v = track!(parse_cow_str(v))?;
102        track!(visitor.visit_i64(v))
103    }
104
105    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
106    where
107        V: Visitor<'de>,
108    {
109        let v = track!(self.next_str())?;
110        let v = track!(parse_cow_str(v))?;
111        track!(visitor.visit_u8(v))
112    }
113
114    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
115    where
116        V: Visitor<'de>,
117    {
118        let v = track!(self.next_str())?;
119        let v = track!(parse_cow_str(v))?;
120        track!(visitor.visit_u16(v))
121    }
122
123    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
124    where
125        V: Visitor<'de>,
126    {
127        let v = track!(self.next_str())?;
128        let v = track!(parse_cow_str(v))?;
129        track!(visitor.visit_u32(v))
130    }
131
132    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
133    where
134        V: Visitor<'de>,
135    {
136        let v = track!(self.next_str())?;
137        let v = track!(parse_cow_str(v))?;
138        track!(visitor.visit_u64(v))
139    }
140
141    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
142    where
143        V: Visitor<'de>,
144    {
145        let v = track!(self.next_str())?;
146        let v = track!(parse_cow_str(v))?;
147        track!(visitor.visit_f32(v))
148    }
149
150    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
151    where
152        V: Visitor<'de>,
153    {
154        let v = track!(self.next_str())?;
155        let v = track!(parse_cow_str(v))?;
156        track!(visitor.visit_f64(v))
157    }
158
159    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value>
160    where
161        V: Visitor<'de>,
162    {
163        track_panic!(ErrorKind::Invalid, "Unsupported");
164    }
165
166    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
167    where
168        V: Visitor<'de>,
169    {
170        let v = track!(self.next_str())?;
171        match v {
172            Cow::Borrowed(s) => track!(visitor.visit_borrowed_str(s)),
173            Cow::Owned(s) => track!(visitor.visit_string(s)),
174        }
175    }
176
177    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
178    where
179        V: Visitor<'de>,
180    {
181        track!(self.deserialize_str(visitor))
182    }
183
184    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
185    where
186        V: Visitor<'de>,
187    {
188        let v = track!(self.next_str())?;
189        match v {
190            Cow::Borrowed(s) => track!(visitor.visit_borrowed_bytes(s.as_bytes())),
191            Cow::Owned(s) => track!(visitor.visit_byte_buf(s.into_bytes())),
192        }
193    }
194
195    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
196    where
197        V: Visitor<'de>,
198    {
199        track!(self.deserialize_bytes(visitor))
200    }
201
202    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
203    where
204        V: Visitor<'de>,
205    {
206        track!(visitor.visit_some(self))
207    }
208
209    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
210    where
211        V: Visitor<'de>,
212    {
213        let v = track!(self.next_str())?;
214        track_assert!(v.is_empty(), ErrorKind::Invalid);
215        track!(visitor.visit_unit())
216    }
217
218    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
219    where
220        V: Visitor<'de>,
221    {
222        track!(self.deserialize_unit(visitor))
223    }
224
225    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
226    where
227        V: Visitor<'de>,
228    {
229        visitor.visit_newtype_struct(self)
230    }
231
232    fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value>
233    where
234        V: Visitor<'de>,
235    {
236        track_panic!(ErrorKind::Invalid);
237    }
238
239    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value>
240    where
241        V: Visitor<'de>,
242    {
243        track_panic!(ErrorKind::Invalid);
244    }
245
246    fn deserialize_tuple_struct<V>(
247        self,
248        _name: &'static str,
249        _len: usize,
250        _visitor: V,
251    ) -> Result<V::Value>
252    where
253        V: Visitor<'de>,
254    {
255        track_panic!(ErrorKind::Invalid);
256    }
257
258    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
259    where
260        V: Visitor<'de>,
261    {
262        track_assert!(!self.in_map, ErrorKind::Invalid);
263        self.in_map = true;
264        track!(visitor.visit_map(self))
265    }
266
267    fn deserialize_struct<V>(
268        self,
269        _name: &'static str,
270        _fields: &'static [&'static str],
271        visitor: V,
272    ) -> Result<V::Value>
273    where
274        V: Visitor<'de>,
275    {
276        track!(self.deserialize_map(visitor))
277    }
278
279    fn deserialize_enum<V>(
280        self,
281        _name: &'static str,
282        _variants: &'static [&'static str],
283        _visitor: V,
284    ) -> Result<V::Value>
285    where
286        V: Visitor<'de>,
287    {
288        track_panic!(ErrorKind::Invalid);
289    }
290
291    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
292    where
293        V: Visitor<'de>,
294    {
295        track!(self.deserialize_str(visitor))
296    }
297
298    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
299    where
300        V: Visitor<'de>,
301    {
302        track!(visitor.visit_unit()) // NOTE: dummy visiting
303    }
304}
305impl<'de, 'a> de::MapAccess<'de> for &'a mut UrlQueryDeserializer<'de> {
306    type Error = Error;
307    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
308    where
309        K: de::DeserializeSeed<'de>,
310    {
311        if self.is_end_of_query() {
312            Ok(None)
313        } else {
314            let v = track!(seed.deserialize(&mut **self))?;
315            Ok(Some(v))
316        }
317    }
318    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
319    where
320        V: de::DeserializeSeed<'de>,
321    {
322        let v = track!(seed.deserialize(&mut **self))?;
323        Ok(v)
324    }
325}
326
327fn parse_cow_str<T: std::str::FromStr>(s: Cow<str>) -> Result<T>
328where
329    Error: From<T::Err>,
330{
331    let v = track!(s.parse().map_err(Error::from), "s={:?}", s)?;
332    Ok(v)
333}
334
#[cfg(test)]
mod test {
    use serde::Deserialize;
    use url::Url;
    use super::*;

    /// Target struct: one optional numeric field and one required string field.
    #[derive(Deserialize)]
    struct Params {
        foo: Option<usize>,
        bar: String,
    }

    /// Parses `url` and deserializes its query string into `Params`,
    /// panicking with the tracked error on failure.
    fn params_of(url: &str) -> Params {
        let url = Url::parse(url).unwrap();
        let mut deserializer = UrlQueryDeserializer::new(url.query_pairs());
        track_try_unwrap!(Params::deserialize(&mut deserializer))
    }

    #[test]
    fn struct_works() {
        // An absent optional field becomes `None`; `+` decodes to a space.
        let params = params_of("http://localhost/?bar=baz+qux");
        assert_eq!(params.foo, None);
        assert_eq!(params.bar, "baz qux");

        let params = params_of("http://localhost/?foo=10&bar=baz+qux");
        assert_eq!(params.foo, Some(10));
        assert_eq!(params.bar, "baz qux");
    }
}