1use std::char;
10use std::error::Error;
11use std::fmt;
12
/// The reason a [`DecodeError`] was produced.
///
/// This is a plain field-less tag, so it is `Copy` and can be compared, hashed and matched on freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DecodeErrorKind {
    /// A `\` that ends the input or is followed by a character that is not a supported escape.
    UnescapedSlash,
    /// An escaped value is out of range for the decoder.
    ///
    /// NOTE(review): the escape decoder itself never raises this; presumably it is returned by
    /// the push callbacks that receive out-of-range values — confirm against callers.
    InvalidValue,
    /// Fewer bytes follow `\x` (needs 2) or `\u` (needs 6) than the escape requires.
    ///
    /// The spelling ("To" rather than "Too") is part of the public API and kept for compatibility.
    HexNumberToShort,
    /// The characters after `\x` or `\u` are not all hexadecimal digits.
    InvalidHexDigit,
}
25
/// Error produced while decoding an escaped string.
#[derive(Debug)]
pub struct DecodeError {
    /// Which kind of problem was found.
    pub kind: DecodeErrorKind,
    /// Position of the problem. For errors raised by `decode_generic` this is the byte offset
    /// of the offending `\` within the input being decoded.
    pub index: usize,
    /// Input text the error refers to; included in the `Display` output.
    pub(crate) mat: String,
}
32
/// One piece of decoded output, handed to the `push_val` callback of `decode_generic`.
pub(crate) enum PushGeneric<'a> {
    /// A single escaped value.
    ///
    /// `start` is the byte offset of the escape's `\` in the input. `val` is the decoded number:
    /// - the byte for `\t`, `\n`, `\r`, `\\` and `\xHH`;
    /// - the raw code point for a `\uHHHHHH` escape that is not a valid `char`.
    Value { start: usize, val: u32 },
    /// Text to pass through unchanged: a literal run of input or a decoded `\u` character.
    String(&'a str),
}
39
40pub(crate) fn decode_generic<F>(mut push_val: F, s: &str) -> Result<(), DecodeError>
42where
43 F: FnMut(PushGeneric) -> Result<(), DecodeError>,
44{
45 let mut string = s;
46 let mut offset = 0;
47
48 while let Some(byte_index) = string.find('\\') {
49 if byte_index > 0 {
50 push_val(PushGeneric::String(&string[..byte_index]))?;
51 }
52 let start_idx = offset + byte_index;
54 let rest = string.len() - byte_index;
55 if rest < 2 {
56 Err(DecodeError {
57 index: start_idx,
58 kind: DecodeErrorKind::UnescapedSlash,
59 mat: string[byte_index..].to_string(),
60 })?
61 }
62
63 macro_rules! pg_value {
65 ( $v:expr ) => {{
66 PushGeneric::Value {
67 start: start_idx,
68 val: $v as u32,
69 }
70 }};
71 }
72 let consumed_bytes = match &string.as_bytes()[byte_index + 1] {
73 b't' => {
74 push_val(pg_value!(b'\t'))?;
75 2
76 }
77 b'n' => {
78 push_val(pg_value!(b'\n'))?;
79 2
80 }
81 b'r' => {
82 push_val(pg_value!(b'\r'))?;
83 2
84 }
85 b'\\' => {
86 push_val(pg_value!(b'\\'))?;
87 2
88 }
89 b'x' => {
90 if rest < 4 {
91 Err(DecodeError {
92 index: start_idx,
93 kind: DecodeErrorKind::HexNumberToShort,
94 mat: string[byte_index..].to_string(),
95 })?
96 }
97
98 match u32::from_str_radix(&string[(byte_index + 2)..(byte_index + 4)], 16) {
99 Ok(x) => push_val(pg_value!(x)),
100 Err(_) => Err(DecodeError {
101 index: start_idx,
102 kind: DecodeErrorKind::InvalidHexDigit,
103 mat: s.to_string(),
104 }),
105 }?;
106 4
107 }
108 b'u' => {
109 if rest < 8 {
110 Err(DecodeError {
111 index: start_idx,
112 kind: DecodeErrorKind::HexNumberToShort,
113 mat: string[byte_index..].to_string(),
114 })?
115 }
116
117 let c32 = match u32::from_str_radix(&string[(byte_index + 2)..(byte_index + 8)], 16)
118 {
119 Ok(x) => Ok(x),
120 Err(_) => Err(DecodeError {
121 index: start_idx,
122 kind: DecodeErrorKind::InvalidHexDigit,
123 mat: s.to_string(),
124 }),
125 }?;
126
127 match char::from_u32(c32) {
128 Some(c) => push_val(PushGeneric::String(&c.to_string())),
131 None => push_val(pg_value!(c32)),
134 }?;
135 8
136 }
137 _ => Err(DecodeError {
138 index: start_idx,
139 kind: DecodeErrorKind::UnescapedSlash,
140 mat: string[byte_index..].to_string(),
141 })?,
142 };
143
144 string = &string[(byte_index + consumed_bytes)..];
145 offset += byte_index + consumed_bytes;
146 }
147 push_val(PushGeneric::String(string))?;
148 Ok(())
149}
150
151impl Error for DecodeError {
152 fn description(&self) -> &str {
153 match self.kind {
154 DecodeErrorKind::UnescapedSlash => r#"Found unmatched '\'. Use "\\" to escape slashes"#,
155 DecodeErrorKind::InvalidValue => r#"Escaped value is out of range of the decoder"#,
156 DecodeErrorKind::HexNumberToShort => r#"Not enough characters after "\x" or "\u""#,
157 DecodeErrorKind::InvalidHexDigit => r#"Invalid hex digit after "\x" or "\u""#,
158 }
159 }
160}
161
162impl fmt::Display for DecodeError {
163 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
164 write!(
165 f,
166 "{} when decoding {:?} [index={}]",
167 self.index, self, self.mat
168 )
169 }
170}
171
#[cfg(test)]
mod error_tests {
    use crate::{decode::PushGeneric, DecodeError, DecodeErrorKind};

    use super::decode_generic;

    /// Runs `decode_generic` over `input`, collecting the output bytes.
    ///
    /// Asserts that decoding fails with `expected_kind` at byte offset `expected_index`.
    fn assert_decode_fails(input: &str, expected_index: usize, expected_kind: DecodeErrorKind) {
        let mut sink = Vec::<u8>::new();
        let collect = |piece: PushGeneric| -> Result<(), DecodeError> {
            match piece {
                // Only the low byte matters for these error-path tests.
                PushGeneric::Value { val, .. } => sink.push(val as u8),
                PushGeneric::String(text) => sink.extend_from_slice(text.as_bytes()),
            }
            Ok(())
        };

        let err = decode_generic(collect, input).unwrap_err();
        assert_eq!(expected_index, err.index);
        assert_eq!(expected_kind, err.kind);
    }

    #[test]
    fn test_error_unescaped_backslash() {
        assert_decode_fails(r"foo\bar", 3, DecodeErrorKind::UnescapedSlash)
    }

    #[test]
    fn test_error_unescaped_backslash_2() {
        assert_decode_fails(r"foo\n\bar", 5, DecodeErrorKind::UnescapedSlash)
    }

    #[test]
    fn test_error_unescaped_backslash_end() {
        assert_decode_fails(r"foo\", 3, DecodeErrorKind::UnescapedSlash)
    }

    #[test]
    fn test_error_unescaped_backslash_end_2() {
        assert_decode_fails(r"foo\nbar\", 8, DecodeErrorKind::UnescapedSlash);
    }

    #[test]
    fn test_error_escape_no_digits() {
        assert_decode_fails(r"foo\nbar\x", 8, DecodeErrorKind::HexNumberToShort);
    }

    #[test]
    fn test_error_short_x_escape() {
        assert_decode_fails(r"foo\nbar\x1", 8, DecodeErrorKind::HexNumberToShort);
    }

    #[test]
    fn test_error_short_u_escape() {
        assert_decode_fails(r"foo\nbar\u12345", 8, DecodeErrorKind::HexNumberToShort);
    }

    #[test]
    fn test_error_invalid_hex_char() {
        assert_decode_fails(r"foo\nbar\xax", 8, DecodeErrorKind::InvalidHexDigit);
    }
}