1use std::fmt;
35use std::string::FromUtf8Error;
36
/// Error returned by [`decode`] when its input is not valid percent-encoded UTF-8.
///
/// All positions are byte offsets into the input string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
    /// A `%` was followed by fewer than two bytes; `pos` is the offset of the `%`.
    TruncatedSequence { pos: usize },
    /// A byte following `%` is not a hexadecimal digit; `ch` is the offending
    /// character and `pos` its offset.
    InvalidHexDigit { ch: char, pos: usize },
    /// The fully decoded byte sequence is not valid UTF-8.
    InvalidUtf8(FromUtf8Error),
}
47
48impl fmt::Display for DecodeError {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 match self {
51 DecodeError::TruncatedSequence { pos } => {
52 write!(f, "truncated percent-sequence at position {pos}")
53 }
54 DecodeError::InvalidHexDigit { ch, pos } => {
55 write!(f, "invalid hex digit {ch:?} at position {pos}")
56 }
57 DecodeError::InvalidUtf8(err) => {
58 write!(f, "decoded bytes are not valid UTF-8: {err}")
59 }
60 }
61 }
62}
63
64impl std::error::Error for DecodeError {
65 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
66 match self {
67 DecodeError::InvalidUtf8(err) => Some(err),
68 _ => None,
69 }
70 }
71}
72
73impl From<FromUtf8Error> for DecodeError {
74 fn from(err: FromUtf8Error) -> Self {
75 DecodeError::InvalidUtf8(err)
76 }
77}
78
79pub fn encode(input: &str) -> String {
92 encode_bytes(input.as_bytes())
93}
94
95pub fn encode_bytes(input: &[u8]) -> String {
107 let mut output = String::with_capacity(input.len());
108 for &byte in input {
109 if is_unreserved(byte) {
110 output.push(byte as char);
111 } else {
112 output.push('%');
113 output.push(hex_digit(byte >> 4));
114 output.push(hex_digit(byte & 0x0F));
115 }
116 }
117 output
118}
119
120pub fn decode(input: &str) -> Result<String, DecodeError> {
143 let bytes = input.as_bytes();
144 let mut output = Vec::with_capacity(bytes.len());
145 let mut i = 0;
146 while i < bytes.len() {
147 match bytes[i] {
148 b'+' => {
149 output.push(b' ');
150 i += 1;
151 }
152 b'%' => {
153 if i + 2 >= bytes.len() {
154 return Err(DecodeError::TruncatedSequence { pos: i });
155 }
156 let high = from_hex(bytes[i + 1]).ok_or(DecodeError::InvalidHexDigit {
157 ch: bytes[i + 1] as char,
158 pos: i + 1,
159 })?;
160 let low = from_hex(bytes[i + 2]).ok_or(DecodeError::InvalidHexDigit {
161 ch: bytes[i + 2] as char,
162 pos: i + 2,
163 })?;
164 output.push((high << 4) | low);
165 i += 3;
166 }
167 byte => {
168 output.push(byte);
169 i += 1;
170 }
171 }
172 }
173 Ok(String::from_utf8(output)?)
174}
175
176pub fn decode_lossy(input: &str) -> String {
190 let bytes = input.as_bytes();
191 let mut output = String::with_capacity(bytes.len());
192 let mut i = 0;
193 while i < bytes.len() {
194 match bytes[i] {
195 b'+' => {
196 output.push(' ');
197 i += 1;
198 }
199 b'%' if i + 2 < bytes.len() => {
200 if let (Some(high), Some(low)) = (from_hex(bytes[i + 1]), from_hex(bytes[i + 2])) {
201 let decoded_byte = (high << 4) | low;
202 output.push(decoded_byte as char);
203 i += 3;
204 } else {
205 output.push('%');
206 i += 1;
207 }
208 }
209 byte => {
210 output.push(byte as char);
211 i += 1;
212 }
213 }
214 }
215 output
216}
217
/// Returns `true` for bytes that may appear unescaped in encoded output: ASCII
/// letters, ASCII digits, `-`, `_`, `.` and `~` (the RFC 3986 "unreserved" set).
fn is_unreserved(byte: u8) -> bool {
    matches!(
        byte,
        b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~'
    )
}
221
/// Maps a nibble (`0..=15`) to its uppercase hexadecimal digit character.
///
/// # Panics
///
/// Panics if `nibble` is greater than 15. `encode_bytes` only passes
/// `byte >> 4` and `byte & 0x0F`, which are always in range.
fn hex_digit(nibble: u8) -> char {
    const DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    match DIGITS.get(usize::from(nibble)) {
        Some(&digit) => char::from(digit),
        None => unreachable!(),
    }
}
229
/// Parses one ASCII hexadecimal digit (either case) into its value `0..=15`.
///
/// Returns `None` for any byte that is not `0-9`, `a-f` or `A-F`.
fn from_hex(byte: u8) -> Option<u8> {
    // Pick the amount to subtract so that the first digit of each range maps to
    // its value: '0' -> 0, 'a'/'A' -> 10.
    let base = match byte {
        b'0'..=b'9' => b'0',
        b'a'..=b'f' => b'a' - 10,
        b'A'..=b'F' => b'A' - 10,
        _ => return None,
    };
    Some(byte - base)
}
238
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_unreserved_unchanged() {
        assert_eq!(encode("abcABC123-_.~"), "abcABC123-_.~");
    }

    #[test]
    fn encode_empty() {
        assert_eq!(encode(""), "");
    }

    #[test]
    fn encode_space() {
        assert_eq!(encode("hello world"), "hello%20world");
    }

    #[test]
    fn encode_special_chars() {
        assert_eq!(encode("a+b=c&d"), "a%2Bb%3Dc%26d");
    }

    #[test]
    fn encode_slash_and_query() {
        assert_eq!(encode("/path?q=1"), "%2Fpath%3Fq%3D1");
    }

    #[test]
    fn encode_multibyte_utf8() {
        assert_eq!(encode("é"), "%C3%A9");
    }

    #[test]
    fn encode_bytes_high() {
        assert_eq!(encode_bytes(b"\x00\x7F\xFF"), "%00%7F%FF");
    }

    #[test]
    fn decode_basic() {
        assert_eq!(decode("hello%20world").unwrap(), "hello world");
    }

    #[test]
    fn decode_empty() {
        assert_eq!(decode("").unwrap(), "");
    }

    #[test]
    fn decode_plus_as_space() {
        assert_eq!(decode("hello+world").unwrap(), "hello world");
    }

    #[test]
    fn decode_mixed_case_hex() {
        assert_eq!(decode("a%2bb").unwrap(), "a+b");
        assert_eq!(decode("a%2Bb").unwrap(), "a+b");
    }

    #[test]
    fn decode_multibyte_utf8() {
        assert_eq!(decode("%C3%A9").unwrap(), "é");
    }

    #[test]
    fn decode_error_bad_hex() {
        assert!(matches!(
            decode("bad%GG").unwrap_err(),
            DecodeError::InvalidHexDigit { ch: 'G', pos: 4 }
        ));
    }

    #[test]
    fn decode_error_bad_low_digit() {
        assert!(matches!(
            decode("%2G").unwrap_err(),
            DecodeError::InvalidHexDigit { ch: 'G', pos: 2 }
        ));
    }

    #[test]
    fn decode_error_truncated() {
        assert!(matches!(
            decode("bad%2").unwrap_err(),
            DecodeError::TruncatedSequence { pos: 3 }
        ));
    }

    #[test]
    fn decode_error_lone_trailing_percent() {
        assert!(matches!(
            decode("abc%").unwrap_err(),
            DecodeError::TruncatedSequence { pos: 3 }
        ));
    }

    #[test]
    fn decode_error_invalid_utf8() {
        assert!(matches!(
            decode("%FF").unwrap_err(),
            DecodeError::InvalidUtf8(_)
        ));
    }

    #[test]
    fn error_display_messages() {
        assert_eq!(
            DecodeError::TruncatedSequence { pos: 3 }.to_string(),
            "truncated percent-sequence at position 3"
        );
        assert_eq!(
            DecodeError::InvalidHexDigit { ch: 'G', pos: 4 }.to_string(),
            "invalid hex digit 'G' at position 4"
        );
        let msg = decode("%FF").unwrap_err().to_string();
        assert!(msg.starts_with("decoded bytes are not valid UTF-8: "));
    }

    #[test]
    fn error_source_only_for_utf8() {
        use std::error::Error;
        assert!(decode("%FF").unwrap_err().source().is_some());
        assert!(DecodeError::TruncatedSequence { pos: 0 }.source().is_none());
        assert!(DecodeError::InvalidHexDigit { ch: 'G', pos: 1 }
            .source()
            .is_none());
    }

    #[test]
    fn decode_lossy_passthrough() {
        assert_eq!(decode_lossy("bad%GGvalue"), "bad%GGvalue");
    }

    #[test]
    fn decode_lossy_keeps_truncated_escapes() {
        assert_eq!(decode_lossy("100%"), "100%");
        assert_eq!(decode_lossy("50%2"), "50%2");
    }

    #[test]
    fn decode_lossy_decodes_valid_ascii() {
        assert_eq!(decode_lossy("a+b%21"), "a b!");
    }

    #[test]
    fn roundtrip() {
        let original = "hello/world?q=rust & encoding!";
        assert_eq!(decode(&encode(original)).unwrap(), original);
    }

    #[test]
    fn roundtrip_non_ascii() {
        let original = "naïve café ☕ 100%";
        assert_eq!(decode(&encode(original)).unwrap(), original);
    }
}