1use std;
2use std::borrow::Cow;
3use std::iter::Peekable;
4use serde::de::{self, Visitor};
5use trackable::error::ErrorKindExt;
6use url;
7
8use {Error, ErrorKind, Result};
9
/// Which half of a `key=value` query pair `UrlQueryDeserializer::next_str`
/// hands out next.
#[derive(Debug, PartialEq, Eq)]
enum Phase<'a> {
    /// The next string to yield is the key of a fresh pair.
    Key,
    /// The key has already been yielded; this is the value still waiting to be read.
    Value(Cow<'a, str>),
}

impl<'a> Phase<'a> {
    /// Moves the current phase out to the caller, leaving `Phase::Key` behind.
    pub fn take(&mut self) -> Self {
        let mut current = Phase::Key;
        std::mem::swap(self, &mut current);
        current
    }
}
20
21pub struct UrlQueryDeserializer<'de> {
23 in_map: bool,
24 phase: Phase<'de>,
25 query: Peekable<url::form_urlencoded::Parse<'de>>,
26}
27impl<'de> UrlQueryDeserializer<'de> {
28 pub fn new(query: url::form_urlencoded::Parse<'de>) -> Self {
30 UrlQueryDeserializer {
31 in_map: false,
32 phase: Phase::Key,
33 query: query.peekable(),
34 }
35 }
36
37 fn is_end_of_query(&mut self) -> bool {
38 self.query.peek().is_none()
39 }
40 fn next_str(&mut self) -> Result<Cow<'de, str>> {
41 match self.phase.take() {
42 Phase::Key => {
43 let (k, v) = track!(self.query.next().ok_or_else(|| ErrorKind::Invalid.error()))?;
44 self.phase = Phase::Value(v);
45 Ok(k)
46 }
47 Phase::Value(v) => Ok(v),
48 }
49 }
50}
51impl<'de, 'a> de::Deserializer<'de> for &'a mut UrlQueryDeserializer<'de> {
52 type Error = Error;
53 fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
54 where
55 V: Visitor<'de>,
56 {
57 track_panic!(ErrorKind::Other, "unreachable");
58 }
59
60 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
61 where
62 V: Visitor<'de>,
63 {
64 let v = track!(self.next_str())?;
65 let v = track!(parse_cow_str(v))?;
66 track!(visitor.visit_bool(v))
67 }
68
69 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
70 where
71 V: Visitor<'de>,
72 {
73 let v = track!(self.next_str())?;
74 let v = track!(parse_cow_str(v))?;
75 track!(visitor.visit_i8(v))
76 }
77
78 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
79 where
80 V: Visitor<'de>,
81 {
82 let v = track!(self.next_str())?;
83 let v = track!(parse_cow_str(v))?;
84 track!(visitor.visit_i16(v))
85 }
86
87 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
88 where
89 V: Visitor<'de>,
90 {
91 let v = track!(self.next_str())?;
92 let v = track!(parse_cow_str(v))?;
93 track!(visitor.visit_i32(v))
94 }
95
96 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
97 where
98 V: Visitor<'de>,
99 {
100 let v = track!(self.next_str())?;
101 let v = track!(parse_cow_str(v))?;
102 track!(visitor.visit_i64(v))
103 }
104
105 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
106 where
107 V: Visitor<'de>,
108 {
109 let v = track!(self.next_str())?;
110 let v = track!(parse_cow_str(v))?;
111 track!(visitor.visit_u8(v))
112 }
113
114 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
115 where
116 V: Visitor<'de>,
117 {
118 let v = track!(self.next_str())?;
119 let v = track!(parse_cow_str(v))?;
120 track!(visitor.visit_u16(v))
121 }
122
123 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
124 where
125 V: Visitor<'de>,
126 {
127 let v = track!(self.next_str())?;
128 let v = track!(parse_cow_str(v))?;
129 track!(visitor.visit_u32(v))
130 }
131
132 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
133 where
134 V: Visitor<'de>,
135 {
136 let v = track!(self.next_str())?;
137 let v = track!(parse_cow_str(v))?;
138 track!(visitor.visit_u64(v))
139 }
140
141 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
142 where
143 V: Visitor<'de>,
144 {
145 let v = track!(self.next_str())?;
146 let v = track!(parse_cow_str(v))?;
147 track!(visitor.visit_f32(v))
148 }
149
150 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
151 where
152 V: Visitor<'de>,
153 {
154 let v = track!(self.next_str())?;
155 let v = track!(parse_cow_str(v))?;
156 track!(visitor.visit_f64(v))
157 }
158
159 fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value>
160 where
161 V: Visitor<'de>,
162 {
163 track_panic!(ErrorKind::Invalid, "Unsupported");
164 }
165
166 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
167 where
168 V: Visitor<'de>,
169 {
170 let v = track!(self.next_str())?;
171 match v {
172 Cow::Borrowed(s) => track!(visitor.visit_borrowed_str(s)),
173 Cow::Owned(s) => track!(visitor.visit_string(s)),
174 }
175 }
176
177 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
178 where
179 V: Visitor<'de>,
180 {
181 track!(self.deserialize_str(visitor))
182 }
183
184 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
185 where
186 V: Visitor<'de>,
187 {
188 let v = track!(self.next_str())?;
189 match v {
190 Cow::Borrowed(s) => track!(visitor.visit_borrowed_bytes(s.as_bytes())),
191 Cow::Owned(s) => track!(visitor.visit_byte_buf(s.into_bytes())),
192 }
193 }
194
195 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
196 where
197 V: Visitor<'de>,
198 {
199 track!(self.deserialize_bytes(visitor))
200 }
201
202 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
203 where
204 V: Visitor<'de>,
205 {
206 track!(visitor.visit_some(self))
207 }
208
209 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
210 where
211 V: Visitor<'de>,
212 {
213 let v = track!(self.next_str())?;
214 track_assert!(v.is_empty(), ErrorKind::Invalid);
215 track!(visitor.visit_unit())
216 }
217
218 fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
219 where
220 V: Visitor<'de>,
221 {
222 track!(self.deserialize_unit(visitor))
223 }
224
225 fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
226 where
227 V: Visitor<'de>,
228 {
229 visitor.visit_newtype_struct(self)
230 }
231
232 fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value>
233 where
234 V: Visitor<'de>,
235 {
236 track_panic!(ErrorKind::Invalid);
237 }
238
239 fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value>
240 where
241 V: Visitor<'de>,
242 {
243 track_panic!(ErrorKind::Invalid);
244 }
245
246 fn deserialize_tuple_struct<V>(
247 self,
248 _name: &'static str,
249 _len: usize,
250 _visitor: V,
251 ) -> Result<V::Value>
252 where
253 V: Visitor<'de>,
254 {
255 track_panic!(ErrorKind::Invalid);
256 }
257
258 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
259 where
260 V: Visitor<'de>,
261 {
262 track_assert!(!self.in_map, ErrorKind::Invalid);
263 self.in_map = true;
264 track!(visitor.visit_map(self))
265 }
266
267 fn deserialize_struct<V>(
268 self,
269 _name: &'static str,
270 _fields: &'static [&'static str],
271 visitor: V,
272 ) -> Result<V::Value>
273 where
274 V: Visitor<'de>,
275 {
276 track!(self.deserialize_map(visitor))
277 }
278
279 fn deserialize_enum<V>(
280 self,
281 _name: &'static str,
282 _variants: &'static [&'static str],
283 _visitor: V,
284 ) -> Result<V::Value>
285 where
286 V: Visitor<'de>,
287 {
288 track_panic!(ErrorKind::Invalid);
289 }
290
291 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
292 where
293 V: Visitor<'de>,
294 {
295 track!(self.deserialize_str(visitor))
296 }
297
298 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
299 where
300 V: Visitor<'de>,
301 {
302 track!(visitor.visit_unit()) }
304}
305impl<'de, 'a> de::MapAccess<'de> for &'a mut UrlQueryDeserializer<'de> {
306 type Error = Error;
307 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
308 where
309 K: de::DeserializeSeed<'de>,
310 {
311 if self.is_end_of_query() {
312 Ok(None)
313 } else {
314 let v = track!(seed.deserialize(&mut **self))?;
315 Ok(Some(v))
316 }
317 }
318 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
319 where
320 V: de::DeserializeSeed<'de>,
321 {
322 let v = track!(seed.deserialize(&mut **self))?;
323 Ok(v)
324 }
325}
326
327fn parse_cow_str<T: std::str::FromStr>(s: Cow<str>) -> Result<T>
328where
329 Error: From<T::Err>,
330{
331 let v = track!(s.parse().map_err(Error::from), "s={:?}", s)?;
332 Ok(v)
333}
334
#[cfg(test)]
mod test {
    use super::*;
    use serde::Deserialize;
    use url::Url;

    /// A struct with an optional and a required field deserializes from a query,
    /// with `+` in a value decoded as a space.
    #[test]
    fn struct_works() {
        #[derive(Deserialize)]
        struct Params {
            foo: Option<usize>,
            bar: String,
        }

        // The optional field is absent from the query.
        let without_foo = Url::parse("http://localhost/?bar=baz+qux").unwrap();
        let mut de = UrlQueryDeserializer::new(without_foo.query_pairs());
        let parsed = track_try_unwrap!(Params::deserialize(&mut de));
        assert_eq!(parsed.foo, None);
        assert_eq!(parsed.bar, "baz qux");

        // Both fields are present in the query.
        let with_foo = Url::parse("http://localhost/?foo=10&bar=baz+qux").unwrap();
        let mut de = UrlQueryDeserializer::new(with_foo.query_pairs());
        let parsed = track_try_unwrap!(Params::deserialize(&mut de));
        assert_eq!(parsed.foo, Some(10));
        assert_eq!(parsed.bar, "baz qux");
    }
}