stfu8/
decode.rs

1/* Copyright (c) 2018 Garrett Berg, vitiral@gmail.com
2 *
3 * Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4 * http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5 * http://opensource.org/licenses/MIT>, at your option. This file may not be
6 * copied, modified, or distributed except according to those terms.
7 */
8
9use std::char;
10use std::error::Error;
11use std::fmt;
12
/// The reason a decode operation failed.
///
/// The enum carries no data, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeErrorKind {
    /// A single unescaped backslash was found. Either the following character doesn't
    /// start a valid escape sequence or the backslash is at the end of the string.
    UnescapedSlash,
    /// The value from a `\x` or `\u` hexadecimal escape sequence is out of range for
    /// the decoder.
    InvalidValue,
    /// There are not enough characters after a `\x` or `\u` to build an escape sequence.
    ///
    /// (The variant name keeps its original spelling for API compatibility.)
    HexNumberToShort,
    /// The required characters after a `\x` or `\u` are not all valid hex digits.
    InvalidHexDigit,
}
25
26#[derive(Debug)]
27pub struct DecodeError {
28    pub kind: DecodeErrorKind,
29    pub index: usize,
30    pub(crate) mat: String,
31}
32
/// One decoded piece handed from the generic decoder to a concrete sink.
pub(crate) enum PushGeneric<'a> {
    /// A numeric value produced by an escape sequence. The sink decides whether
    /// the value is acceptable, so it may turn out to be invalid.
    ///
    /// `start` is the byte index in the original input of the backslash that
    /// began the escape; `val` is the decoded number.
    Value { start: usize, val: u32 },
    /// A run of text that is always valid and can be appended as is.
    String(&'a str),
}
39
40/// Decode generically
41pub(crate) fn decode_generic<F>(mut push_val: F, s: &str) -> Result<(), DecodeError>
42where
43    F: FnMut(PushGeneric) -> Result<(), DecodeError>,
44{
45    let mut string = s;
46    let mut offset = 0;
47
48    while let Some(byte_index) = string.find('\\') {
49        if byte_index > 0 {
50            push_val(PushGeneric::String(&string[..byte_index]))?;
51        }
52        // byte index of the backslash in the original string
53        let start_idx = offset + byte_index;
54        let rest = string.len() - byte_index;
55        if rest < 2 {
56            Err(DecodeError {
57                index: start_idx,
58                kind: DecodeErrorKind::UnescapedSlash,
59                mat: string[byte_index..].to_string(),
60            })?
61        }
62
63        // macro to create a PushGeneric::Value
64        macro_rules! pg_value {
65            ( $v:expr ) => {{
66                PushGeneric::Value {
67                    start: start_idx,
68                    val: $v as u32,
69                }
70            }};
71        }
72        let consumed_bytes = match &string.as_bytes()[byte_index + 1] {
73            b't' => {
74                push_val(pg_value!(b'\t'))?;
75                2
76            }
77            b'n' => {
78                push_val(pg_value!(b'\n'))?;
79                2
80            }
81            b'r' => {
82                push_val(pg_value!(b'\r'))?;
83                2
84            }
85            b'\\' => {
86                push_val(pg_value!(b'\\'))?;
87                2
88            }
89            b'x' => {
90                if rest < 4 {
91                    Err(DecodeError {
92                        index: start_idx,
93                        kind: DecodeErrorKind::HexNumberToShort,
94                        mat: string[byte_index..].to_string(),
95                    })?
96                }
97
98                match u32::from_str_radix(&string[(byte_index + 2)..(byte_index + 4)], 16) {
99                    Ok(x) => push_val(pg_value!(x)),
100                    Err(_) => Err(DecodeError {
101                        index: start_idx,
102                        kind: DecodeErrorKind::InvalidHexDigit,
103                        mat: s.to_string(),
104                    }),
105                }?;
106                4
107            }
108            b'u' => {
109                if rest < 8 {
110                    Err(DecodeError {
111                        index: start_idx,
112                        kind: DecodeErrorKind::HexNumberToShort,
113                        mat: string[byte_index..].to_string(),
114                    })?
115                }
116
117                let c32 = match u32::from_str_radix(&string[(byte_index + 2)..(byte_index + 8)], 16)
118                {
119                    Ok(x) => Ok(x),
120                    Err(_) => Err(DecodeError {
121                        index: start_idx,
122                        kind: DecodeErrorKind::InvalidHexDigit,
123                        mat: s.to_string(),
124                    }),
125                }?;
126
127                match char::from_u32(c32) {
128                    // It is a valid UTF code point. Always
129                    // decode it as such.
130                    Some(c) => push_val(PushGeneric::String(&c.to_string())),
131                    // It is not a valid code point. Still try
132                    // to record it's value "as is".
133                    None => push_val(pg_value!(c32)),
134                }?;
135                8
136            }
137            _ => Err(DecodeError {
138                index: start_idx,
139                kind: DecodeErrorKind::UnescapedSlash,
140                mat: string[byte_index..].to_string(),
141            })?,
142        };
143
144        string = &string[(byte_index + consumed_bytes)..];
145        offset += byte_index + consumed_bytes;
146    }
147    push_val(PushGeneric::String(string))?;
148    Ok(())
149}
150
151impl Error for DecodeError {
152    fn description(&self) -> &str {
153        match self.kind {
154            DecodeErrorKind::UnescapedSlash => r#"Found unmatched '\'. Use "\\" to escape slashes"#,
155            DecodeErrorKind::InvalidValue => r#"Escaped value is out of range of the decoder"#,
156            DecodeErrorKind::HexNumberToShort => r#"Not enough characters after "\x" or "\u""#,
157            DecodeErrorKind::InvalidHexDigit => r#"Invalid hex digit after "\x" or "\u""#,
158        }
159    }
160}
161
162impl fmt::Display for DecodeError {
163    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
164        write!(
165            f,
166            "{} when decoding {:?} [index={}]",
167            self.index, self, self.mat
168        )
169    }
170}
171
#[cfg(test)]
mod error_tests {
    use crate::{decode::PushGeneric, DecodeError, DecodeErrorKind};

    use super::decode_generic;

    /// Decodes `string`, which is expected to fail, and checks that the error
    /// reports the given index and kind.
    ///
    /// The sink simply collects bytes; values are truncated to `u8` because only
    /// the error is of interest here.
    fn do_error_test(string: &str, err_index: usize, err_kind: DecodeErrorKind) {
        let mut decoded: Vec<u8> = Vec::new();
        let sink = |piece: PushGeneric| -> Result<(), DecodeError> {
            match piece {
                PushGeneric::String(text) => decoded.extend_from_slice(text.as_bytes()),
                PushGeneric::Value { val, .. } => decoded.push(val as u8),
            }
            Ok(())
        };

        let outcome = decode_generic(sink, string);

        assert!(outcome.is_err());
        let failure = outcome.unwrap_err();
        assert_eq!(err_index, failure.index);
        assert_eq!(err_kind, failure.kind);
    }

    // A backslash followed by a character that starts no escape.
    #[test]
    fn test_error_unescaped_backslash() {
        do_error_test(r"foo\bar", 3, DecodeErrorKind::UnescapedSlash)
    }

    // Same, but after a valid escape: the index must account for it.
    #[test]
    fn test_error_unescaped_backslash_2() {
        do_error_test(r"foo\n\bar", 5, DecodeErrorKind::UnescapedSlash)
    }

    // A backslash as the very last character.
    #[test]
    fn test_error_unescaped_backslash_end() {
        do_error_test(r"foo\", 3, DecodeErrorKind::UnescapedSlash)
    }

    // A trailing backslash after a valid escape.
    #[test]
    fn test_error_unescaped_backslash_end_2() {
        do_error_test(r"foo\nbar\", 8, DecodeErrorKind::UnescapedSlash);
    }

    // `\x` with no digits at all.
    #[test]
    fn test_error_escape_no_digits() {
        do_error_test(r"foo\nbar\x", 8, DecodeErrorKind::HexNumberToShort);
    }

    // `\x` with one digit instead of two.
    #[test]
    fn test_error_short_x_escape() {
        do_error_test(r"foo\nbar\x1", 8, DecodeErrorKind::HexNumberToShort);
    }

    // `\u` with five digits instead of six.
    #[test]
    fn test_error_short_u_escape() {
        do_error_test(r"foo\nbar\u12345", 8, DecodeErrorKind::HexNumberToShort);
    }

    // `\x` followed by a character that is not a hex digit.
    #[test]
    fn test_error_invalid_hex_char() {
        do_error_test(r"foo\nbar\xax", 8, DecodeErrorKind::InvalidHexDigit);
    }
}