1use serde::de::{self, Visitor};
2use std;
3use std::iter::Peekable;
4use std::str::Split;
5use trackable::error::ErrorKindExt;
6use url::Url;
7
8use types::EntryPoint;
9use {Error, ErrorKind, Result};
10
/// Serde deserializer that extracts values from the path of a URL.
///
/// The URL's path segments are walked in lock-step with the segments of an
/// `EntryPoint`: segments whose `as_option()` is `Some(literal)` must equal the
/// URL segment exactly, while segments whose `as_option()` is `None` are
/// variables whose URL segment is handed to the visitor as a value.
#[derive(Debug)]
pub struct UrlPathDeserializer<'de> {
    // Set when `deserialize_seq` is entered; a second (nested) sequence is rejected.
    in_seq: bool,
    // Remaining URL path segments (from `Url::path_segments`), with one-item lookahead.
    segments: Peekable<Split<'de, char>>,
    // The pattern the URL path is matched against.
    entry_point: EntryPoint,
    // Index into `entry_point.segments()` of the next segment to match.
    index: usize,
    // Number of variable segments of `entry_point` not yet consumed.
    free_vars: usize,
}
20impl<'de> UrlPathDeserializer<'de> {
21    pub fn new(entry_point: EntryPoint, url: &'de Url) -> Result<Self> {
23        let segments = track!(url
24            .path_segments()
25            .ok_or_else(|| ErrorKind::Invalid.error(),))?;
26        Ok(UrlPathDeserializer {
27            in_seq: false,
28            segments: segments.peekable(),
29            entry_point,
30            index: 0,
31            free_vars: entry_point.var_count(),
32        })
33    }
34    fn is_end_of_segment(&mut self) -> bool {
35        self.segments.peek().is_none()
36    }
37    fn finish(&mut self) -> Result<()> {
38        for segment in &self.entry_point.segments()[self.index..] {
39            let expected = track!(segment
40                .as_option()
41                .ok_or_else(|| ErrorKind::Invalid.error(),))?;
42            let actual = track!(self
43                .segments
44                .next()
45                .ok_or_else(|| ErrorKind::Invalid.error(),))?;
46            track_assert_eq!(actual, expected, ErrorKind::Invalid);
47        }
48        Ok(())
49    }
50    fn next_value(&mut self) -> Result<&'de str> {
51        track_assert!(
52            self.index < self.entry_point.segments().len(),
53            ErrorKind::Invalid
54        );
55        track_assert!(!self.is_end_of_segment(), ErrorKind::Invalid);
56        let i = self.index;
57        self.index += 1;
58        if let Some(expected) = self.entry_point.segments()[i].as_option() {
59            let s = self.segments.next().unwrap();
60            track_assert_eq!(s, expected, ErrorKind::Invalid);
61            self.next_value()
62        } else {
63            self.free_vars -= 1;
64            let s = self.segments.next().unwrap();
65            Ok(s)
66        }
67    }
68}
69impl<'de, 'a> de::Deserializer<'de> for &'a mut UrlPathDeserializer<'de> {
70    type Error = Error;
71    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
72    where
73        V: Visitor<'de>,
74    {
75        track_panic!(ErrorKind::Other, "unreachable");
76    }
77
78    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
79    where
80        V: Visitor<'de>,
81    {
82        let v = track!(self.next_value())?;
83        let v = track!(parse_escaped_str(v))?;
84        track!(visitor.visit_bool(v))
85    }
86
87    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
88    where
89        V: Visitor<'de>,
90    {
91        let v = track!(self.next_value())?;
92        let v = track!(parse_escaped_str(v))?;
93        track!(visitor.visit_i8(v))
94    }
95
96    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
97    where
98        V: Visitor<'de>,
99    {
100        let v = track!(self.next_value())?;
101        let v = track!(parse_escaped_str(v))?;
102        track!(visitor.visit_i16(v))
103    }
104
105    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
106    where
107        V: Visitor<'de>,
108    {
109        let v = track!(self.next_value())?;
110        let v = track!(parse_escaped_str(v))?;
111        track!(visitor.visit_i32(v))
112    }
113
114    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
115    where
116        V: Visitor<'de>,
117    {
118        let v = track!(self.next_value())?;
119        let v = track!(parse_escaped_str(v))?;
120        track!(visitor.visit_i64(v))
121    }
122
123    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
124    where
125        V: Visitor<'de>,
126    {
127        let v = track!(self.next_value())?;
128        let v = track!(parse_escaped_str(v))?;
129        track!(visitor.visit_u8(v))
130    }
131
132    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
133    where
134        V: Visitor<'de>,
135    {
136        let v = track!(self.next_value())?;
137        let v = track!(parse_escaped_str(v))?;
138        track!(visitor.visit_u16(v))
139    }
140
141    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
142    where
143        V: Visitor<'de>,
144    {
145        let v = track!(self.next_value())?;
146        let v = track!(parse_escaped_str(v))?;
147        track!(visitor.visit_u32(v))
148    }
149
150    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
151    where
152        V: Visitor<'de>,
153    {
154        let v = track!(self.next_value())?;
155        let v = track!(parse_escaped_str(v))?;
156        track!(visitor.visit_u64(v))
157    }
158
159    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
160    where
161        V: Visitor<'de>,
162    {
163        let v = track!(self.next_value())?;
164        let v = track!(parse_escaped_str(v))?;
165        track!(visitor.visit_f32(v))
166    }
167
168    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
169    where
170        V: Visitor<'de>,
171    {
172        let v = track!(self.next_value())?;
173        let v = track!(parse_escaped_str(v))?;
174        track!(visitor.visit_f64(v))
175    }
176
177    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value>
178    where
179        V: Visitor<'de>,
180    {
181        track_panic!(ErrorKind::Invalid, "Unsupported");
182    }
183
184    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
185    where
186        V: Visitor<'de>,
187    {
188        let v = track!(self.next_value())?;
189        let s = track!(percent_decode(v))?;
190        track!(visitor.visit_string(s))
191    }
192
193    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
194    where
195        V: Visitor<'de>,
196    {
197        track!(self.deserialize_str(visitor))
198    }
199
200    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value>
201    where
202        V: Visitor<'de>,
203    {
204        track_panic!(ErrorKind::Invalid);
205    }
206
207    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
208    where
209        V: Visitor<'de>,
210    {
211        track!(self.deserialize_bytes(visitor))
212    }
213
214    fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value>
215    where
216        V: Visitor<'de>,
217    {
218        track_panic!(ErrorKind::Invalid);
219    }
220
221    fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value>
222    where
223        V: Visitor<'de>,
224    {
225        track_panic!(ErrorKind::Invalid);
226    }
227
228    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
229    where
230        V: Visitor<'de>,
231    {
232        track!(self.deserialize_unit(visitor))
233    }
234
235    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
236    where
237        V: Visitor<'de>,
238    {
239        visitor.visit_newtype_struct(self)
240    }
241
242    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
243    where
244        V: Visitor<'de>,
245    {
246        track_assert!(!self.in_seq, ErrorKind::Invalid);
247        self.in_seq = true;
248        track!(visitor.visit_seq(self))
249    }
250
251    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
252    where
253        V: Visitor<'de>,
254    {
255        track!(self.deserialize_seq(visitor))
256    }
257
258    fn deserialize_tuple_struct<V>(
259        self,
260        _name: &'static str,
261        _len: usize,
262        visitor: V,
263    ) -> Result<V::Value>
264    where
265        V: Visitor<'de>,
266    {
267        track!(self.deserialize_seq(visitor))
268    }
269
270    fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value>
271    where
272        V: Visitor<'de>,
273    {
274        track_panic!(ErrorKind::Invalid);
275    }
276
277    fn deserialize_struct<V>(
278        self,
279        _name: &'static str,
280        _fields: &'static [&'static str],
281        visitor: V,
282    ) -> Result<V::Value>
283    where
284        V: Visitor<'de>,
285    {
286        track!(self.deserialize_map(visitor))
287    }
288
289    fn deserialize_enum<V>(
290        self,
291        _name: &'static str,
292        _variants: &'static [&'static str],
293        _visitor: V,
294    ) -> Result<V::Value>
295    where
296        V: Visitor<'de>,
297    {
298        track_panic!(ErrorKind::Invalid);
299    }
300
301    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
302    where
303        V: Visitor<'de>,
304    {
305        track!(self.deserialize_str(visitor))
306    }
307
308    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
309    where
310        V: Visitor<'de>,
311    {
312        track!(visitor.visit_unit()) }
314}
315impl<'de, 'a> de::SeqAccess<'de> for &'a mut UrlPathDeserializer<'de> {
316    type Error = Error;
317    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
318    where
319        T: de::DeserializeSeed<'de>,
320    {
321        if self.free_vars == 0 {
322            track!(self.finish())?;
323            Ok(None)
324        } else {
325            let v = track!(seed.deserialize(&mut **self))?;
326            Ok(Some(v))
327        }
328    }
329}
330
331fn parse_escaped_str<T: std::str::FromStr>(s: &str) -> Result<T>
332where
333    Error: From<T::Err>,
334{
335    let s = percent_decode(s)?;
336    let v = track!(s.parse().map_err(Error::from), "s={:?}", s)?;
337    Ok(v)
338}
339
340fn percent_decode(s: &str) -> Result<String> {
341    let mut chars = s.chars();
342    let mut decoded = String::new();
343    while let Some(c) = chars.next() {
344        if c == '%' {
345            let h0 = track!(chars
346                .next()
347                .ok_or_else(|| ErrorKind::Invalid.cause("Invalid escaped char")))?;
348            let h1 = track!(chars
349                .next()
350                .ok_or_else(|| ErrorKind::Invalid.cause("Invalid escaped char")))?;
351            let code = track!(u32::from_str_radix(&format!("{h0}{h1}"), 16)
352                .map_err(|e| track!(ErrorKind::Invalid.cause(e))))?;
353            let c = track!(char::from_u32(code)
354                .ok_or_else(|| ErrorKind::Invalid.cause("Invalid escaped char")))?;
355            decoded.push(c);
356            continue;
357        }
358        decoded.push(c);
359    }
360    Ok(decoded)
361}
362
363#[cfg(test)]
364mod test {
365    use serde::Deserialize;
366    use url::Url;
367
368    use super::*;
369
370    #[test]
371    fn it_works() {
372        let entry_point = htrpc_entry_point!["foo", _, "baz", _];
373
374        #[derive(Deserialize)]
375        struct Args(String, usize);
376
377        let url = Url::parse("http://localhost/foo/hello%20world/baz/3").unwrap();
378        let mut deserializer = track_try_unwrap!(UrlPathDeserializer::new(entry_point, &url));
379        let Args(v0, v1) = track_try_unwrap!(Args::deserialize(&mut deserializer));
380        assert_eq!(v0, "hello world");
381        assert_eq!(v1, 3);
382    }
383}