1use std::fmt;
23
/// Reasons an escaped string can be rejected by [`decode_bytes_escape`].
///
/// Every variant carries `pos`, a zero-based byte offset into the input
/// string that locates the problem.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum EscapeError {
    /// The input ends with a backslash that has nothing after it.
    TrailingBackslash {
        /// Offset of the dangling backslash.
        pos: usize,
    },
    /// A backslash is followed by a byte that does not select a known escape.
    UnknownEscape {
        /// Offset of the byte that follows the backslash.
        pos: usize,
        /// That byte as a character when it is ASCII; the decoder reports
        /// `'?'` for any non-ASCII byte.
        ch: char,
    },
    /// A `\x` escape is cut off before both hex digits are present.
    ShortHex {
        /// Offset of the backslash that starts the truncated escape.
        pos: usize,
    },
    /// A `\x` escape contains a byte that is not a hexadecimal digit.
    BadHex {
        /// Offset of the offending digit.
        pos: usize,
    },
}
38
39impl fmt::Display for EscapeError {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 EscapeError::TrailingBackslash { pos } => {
43 write!(f, "trailing backslash at byte {pos}")
44 }
45 EscapeError::UnknownEscape { pos, ch } => {
46 write!(f, "unknown escape '\\{ch}' at byte {pos}")
47 }
48 EscapeError::ShortHex { pos } => {
49 write!(
50 f,
51 "truncated \\xNN escape at byte {pos} (need two hex digits)"
52 )
53 }
54 EscapeError::BadHex { pos } => {
55 write!(f, "invalid hex digit at byte {pos}")
56 }
57 }
58 }
59}
60
// Marks `EscapeError` as a standard error type. The body is empty, so every
// trait method keeps its default and the message comes from the `Display` impl.
impl std::error::Error for EscapeError {}
62
63pub fn decode_bytes_escape(s: &str) -> Result<Vec<u8>, EscapeError> {
74 let bytes = s.as_bytes();
75 let mut out = Vec::with_capacity(bytes.len());
76 let mut i = 0;
77 while i < bytes.len() {
78 let b = bytes[i];
79 if b != b'\\' {
80 out.push(b);
81 i += 1;
82 continue;
83 }
84 if i + 1 >= bytes.len() {
86 return Err(EscapeError::TrailingBackslash { pos: i });
87 }
88 let next = bytes[i + 1];
89 match next {
90 b'r' => {
91 out.push(0x0d);
92 i += 2;
93 }
94 b'n' => {
95 out.push(0x0a);
96 i += 2;
97 }
98 b't' => {
99 out.push(0x09);
100 i += 2;
101 }
102 b'0' => {
103 out.push(0x00);
104 i += 2;
105 }
106 b'\\' => {
107 out.push(b'\\');
108 i += 2;
109 }
110 b'x' => {
111 if i + 3 >= bytes.len() {
113 return Err(EscapeError::ShortHex { pos: i });
114 }
115 let h1 = bytes[i + 2];
116 let h2 = bytes[i + 3];
117 let v1 = hex_digit(h1).ok_or(EscapeError::BadHex { pos: i + 2 })?;
118 let v2 = hex_digit(h2).ok_or(EscapeError::BadHex { pos: i + 3 })?;
119 out.push((v1 << 4) | v2);
120 i += 4;
121 }
122 _ => {
123 let ch = if next < 0x80 { next as char } else { '?' };
125 return Err(EscapeError::UnknownEscape { pos: i + 1, ch });
126 }
127 }
128 }
129 Ok(out)
130}
131
/// Returns the numeric value (0–15) of an ASCII hexadecimal digit, accepting
/// both upper- and lower-case letters, or `None` for any other byte.
fn hex_digit(b: u8) -> Option<u8> {
    // `to_digit(16)` only ever yields 0..=15, so narrowing to `u8` is lossless.
    // Bytes >= 0x80 map to U+0080..=U+00FF, none of which are hex digits.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
140
#[cfg(test)]
mod tests {
    use super::*;

    // Each case compares the complete `Result`, so a failure prints either
    // the decoded bytes or the unexpected error directly.

    #[test]
    fn decodes_carriage_return() {
        let decoded = decode_bytes_escape(r"reboot\r");
        assert_eq!(decoded, Ok(b"reboot\r".to_vec()));
    }

    #[test]
    fn decodes_repeated_hex_escapes() {
        let decoded = decode_bytes_escape(r"\x03\x03\x03");
        assert_eq!(decoded, Ok(vec![0x03; 3]));
    }

    #[test]
    fn decodes_simple_escape_set() {
        let decoded = decode_bytes_escape(r"\n\t\0\\");
        assert_eq!(decoded, Ok(vec![0x0a, 0x09, 0x00, b'\\']));
    }

    #[test]
    fn bad_hex_reports_position() {
        // Offsets: a=0 b=1 c=2 \=3 x=4 Z=5
        let outcome = decode_bytes_escape(r"abc\xZ1");
        assert_eq!(outcome, Err(EscapeError::BadHex { pos: 5 }));
    }

    #[test]
    fn trailing_backslash_reports_position() {
        let outcome = decode_bytes_escape("abc\\");
        assert_eq!(outcome, Err(EscapeError::TrailingBackslash { pos: 3 }));
    }

    #[test]
    fn short_hex_reports_position() {
        let outcome = decode_bytes_escape(r"\x4");
        assert_eq!(outcome, Err(EscapeError::ShortHex { pos: 0 }));
    }

    #[test]
    fn unknown_escape_reports_position_and_char() {
        let outcome = decode_bytes_escape(r"\q");
        assert_eq!(outcome, Err(EscapeError::UnknownEscape { pos: 1, ch: 'q' }));
    }

    #[test]
    fn utf8_passes_through_byte_for_byte() {
        let decoded = decode_bytes_escape("é");
        assert_eq!(decoded, Ok(vec![0xc3, 0xa9]));
    }

    #[test]
    fn non_ascii_after_backslash_maps_to_question_mark() {
        let outcome = decode_bytes_escape(r"\é");
        assert_eq!(outcome, Err(EscapeError::UnknownEscape { pos: 1, ch: '?' }));
    }

    #[test]
    fn escape_error_displays_and_implements_error_trait() {
        let error = EscapeError::TrailingBackslash { pos: 3 };
        assert_eq!(error.to_string(), "trailing backslash at byte 3");
        // Compile-time checks: usable as a trait object, and cloneable.
        let _: &dyn std::error::Error = &error;
        let _cloned = error.clone();
    }

    #[test]
    fn upper_and_lower_hex_digits_accepted() {
        let cases: [(&str, u8); 3] = [(r"\xAB", 0xab), (r"\xab", 0xab), (r"\xFf", 0xff)];
        for &(input, expected) in &cases {
            assert_eq!(
                decode_bytes_escape(input),
                Ok(vec![expected]),
                "input: {}",
                input
            );
        }
    }
}