use crate::error::PctErrorKind;
/// Which percent-decoding dialect to apply.
///
/// Both dialects decode `%XX` escapes; they differ only in how a literal `+`
/// is treated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// Only `%XX` escapes are decoded; a literal `+` is kept as `+`.
    Path,
    /// `%XX` escapes are decoded and a literal `+` becomes a space
    /// (the `+`-as-space convention of HTML form encoding).
    Form,
}
/// Percent-decodes `input` according to `mode`, avoiding a copy when possible.
///
/// If `input` contains nothing that `mode` would rewrite (no `%`, and in
/// [`Mode::Form`] also no `+`), `input` itself is returned and `out` is left
/// untouched; in that case the size of `out` does not matter. Otherwise the
/// decoded bytes are written to the front of `out` via [`decode_into`] and
/// that written prefix of `out` is returned.
///
/// # Errors
///
/// Only possible when decoding is actually needed:
/// * [`PctErrorKind::InvalidEscape`] if a `%` is not followed by two hex digits.
/// * [`PctErrorKind::BufferTooSmall`] if `out` cannot hold the decoded bytes
///   (see [`decoded_len`] for the exact size required).
pub fn decode<'a>(
    input: &'a [u8],
    mode: Mode,
    out: &'a mut [u8],
) -> Result<&'a [u8], PctErrorKind> {
    let triggers_decode = |b: u8| b == b'%' || (b == b'+' && mode == Mode::Form);
    if input.iter().copied().any(triggers_decode) {
        let len = decode_into(input, mode, out)?;
        Ok(&out[..len])
    } else {
        // Nothing to rewrite: hand back the caller's bytes without copying.
        Ok(input)
    }
}
/// Returns the exact number of bytes that decoding `input` would produce.
///
/// Use this to size the `out` buffer passed to [`decode_into`] or [`decode`].
/// Each `%XX` escape counts as one byte, and so does every other input byte.
///
/// # Errors
///
/// Returns [`PctErrorKind::InvalidEscape`] if a `%` is not followed by two hex
/// digits (including a `%` within the last two bytes of `input`).
pub fn decoded_len(input: &[u8], mode: Mode) -> Result<usize, PctErrorKind> {
    // `+` decodes to exactly one byte in either mode (a space in `Form`,
    // itself in `Path`), so the length never depends on `mode`. The parameter
    // is kept so the signature mirrors `decode_into`.
    let _ = mode;
    let mut i = 0;
    let mut n = 0;
    while i < input.len() {
        i += if input[i] == b'%' {
            // Same validation as `decode_into`, so a successful length
            // implies a successful decode into a buffer of that size.
            match (input.get(i + 1), input.get(i + 2)) {
                (Some(&hi), Some(&lo)) if hex_pair(hi, lo).is_some() => 3,
                _ => return Err(PctErrorKind::InvalidEscape),
            }
        } else {
            1
        };
        n += 1;
    }
    Ok(n)
}
/// Percent-decodes `input` into the front of `out` and returns the number of bytes written.
///
/// Unlike [`decode`], this always copies, even when `input` contains no
/// escapes. The mapping is:
/// * `%XX` (hex digits in either case) becomes the byte `0xXX`;
/// * in [`Mode::Form`], a literal `+` becomes a space;
/// * every other byte is copied unchanged.
///
/// A byte produced by an escape is never re-interpreted, so `%2B` yields `+`
/// in both modes.
///
/// # Errors
///
/// * [`PctErrorKind::InvalidEscape`] if a `%` is not followed by two hex digits.
/// * [`PctErrorKind::BufferTooSmall`] if `out` is shorter than the decoded
///   output; [`decoded_len`] gives the exact size needed.
///
/// On error, `out` may already contain a partially decoded prefix.
pub fn decode_into(input: &[u8], mode: Mode, out: &mut [u8]) -> Result<usize, PctErrorKind> {
    let plus_is_space = matches!(mode, Mode::Form);
    let mut cursor = 0;
    let mut written = 0;
    while let Some(&byte) = input.get(cursor) {
        let (value, consumed) = match byte {
            b'%' => {
                let escaped = match (input.get(cursor + 1), input.get(cursor + 2)) {
                    (Some(&hi), Some(&lo)) => hex_pair(hi, lo),
                    // Fewer than two bytes follow the `%`: truncated escape.
                    _ => None,
                };
                (escaped.ok_or(PctErrorKind::InvalidEscape)?, 3)
            }
            b'+' if plus_is_space => (b' ', 1),
            literal => (literal, 1),
        };
        cursor += consumed;
        // Room is checked only after the escape has been validated, so a
        // malformed escape is reported as InvalidEscape even when `out` is full.
        let slot = out.get_mut(written).ok_or(PctErrorKind::BufferTooSmall)?;
        *slot = value;
        written += 1;
    }
    Ok(written)
}
/// Returns the value (0–15) of a single ASCII hex digit, in either case, or
/// `None` if `b` is not a hex digit.
#[inline]
const fn hex_digit(b: u8) -> Option<u8> {
    let value = match b {
        b'0'..=b'9' => b - b'0',
        // Subtracting (`a` - 10) maps `a`..=`f` straight onto 10..=15.
        b'a'..=b'f' => b - (b'a' - 10),
        b'A'..=b'F' => b - (b'A' - 10),
        _ => return None,
    };
    Some(value)
}
/// Combines two ASCII hex digits (high nibble first) into one byte, or returns
/// `None` if either is not a hex digit.
#[inline]
const fn hex_pair(hi: u8, lo: u8) -> Option<u8> {
    if let (Some(high), Some(low)) = (hex_digit(hi), hex_digit(lo)) {
        // Both nibbles are <= 15, so this is at most 255 and cannot overflow.
        Some(high * 16 + low)
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the percent decoder: zero-copy fast path, both modes,
    //! malformed escapes, buffer sizing, and `decoded_len` agreement.
    use super::*;

    #[test]
    fn path_no_escapes_is_zero_copy() {
        let input = b"/users/foo";
        let mut out = [0xAAu8; 16];
        let got = decode(input, Mode::Path, &mut out).unwrap();
        assert_eq!(got, input);
        assert!(std::ptr::eq(got.as_ptr(), input.as_ptr()));
        // The zero-copy path must leave the scratch buffer untouched.
        assert!(out.iter().all(|&b| b == 0xAA));
    }

    #[test]
    fn path_plus_only_is_zero_copy() {
        // `+` is literal in Path mode, so on its own it must not trigger a copy.
        let input = b"a+b";
        let mut out = [0xAAu8; 16];
        let got = decode(input, Mode::Path, &mut out).unwrap();
        assert_eq!(got, input);
        assert!(std::ptr::eq(got.as_ptr(), input.as_ptr()));
        assert!(out.iter().all(|&b| b == 0xAA));
    }

    #[test]
    fn path_decodes_space() {
        let mut out = [0u8; 16];
        let got = decode(b"foo%20bar", Mode::Path, &mut out).unwrap();
        assert_eq!(got, b"foo bar");
    }

    #[test]
    fn path_leaves_plus_alone() {
        let mut out = [0u8; 16];
        let got = decode(b"a+b%20c", Mode::Path, &mut out).unwrap();
        assert_eq!(got, b"a+b c");
    }

    #[test]
    fn form_decodes_plus_to_space() {
        let mut out = [0u8; 16];
        let got = decode(b"a+b", Mode::Form, &mut out).unwrap();
        assert_eq!(got, b"a b");
    }

    #[test]
    fn form_decodes_plus_and_percent() {
        let mut out = [0u8; 16];
        let got = decode(b"hello+world%21", Mode::Form, &mut out).unwrap();
        assert_eq!(got, b"hello world!");
    }

    #[test]
    fn form_escaped_plus_is_not_reinterpreted() {
        // `%2B` decodes to a literal `+`, which must not then become a space.
        let mut out = [0u8; 16];
        let got = decode(b"a%2Bb+c", Mode::Form, &mut out).unwrap();
        assert_eq!(got, b"a+b c");
    }

    #[test]
    fn form_no_plus_no_percent_is_zero_copy() {
        let input = b"plain";
        let mut out = [0xAAu8; 16];
        let got = decode(input, Mode::Form, &mut out).unwrap();
        assert_eq!(got, input);
        assert!(std::ptr::eq(got.as_ptr(), input.as_ptr()));
        // Mirror the Path-mode check: the scratch buffer must be untouched.
        assert!(out.iter().all(|&b| b == 0xAA));
    }

    #[test]
    fn clean_input_ignores_out_capacity() {
        // The zero-copy path never writes to `out`, so its size is irrelevant.
        let mut out = [0u8; 0];
        assert_eq!(decode(b"plain", Mode::Path, &mut out).unwrap(), b"plain");
        assert_eq!(decode(b"plain", Mode::Form, &mut out).unwrap(), b"plain");
    }

    #[test]
    fn upper_and_lower_hex() {
        let mut out = [0u8; 16];
        assert_eq!(decode(b"%2f", Mode::Path, &mut out).unwrap(), b"/");
        let mut out = [0u8; 16];
        assert_eq!(decode(b"%2F", Mode::Path, &mut out).unwrap(), b"/");
    }

    #[test]
    fn truncated_escape_at_end() {
        let mut out = [0u8; 16];
        assert_eq!(
            decode(b"foo%2", Mode::Path, &mut out),
            Err(PctErrorKind::InvalidEscape)
        );
        let mut out = [0u8; 16];
        assert_eq!(
            decode(b"foo%", Mode::Path, &mut out),
            Err(PctErrorKind::InvalidEscape)
        );
    }

    #[test]
    fn non_hex_escape() {
        let mut out = [0u8; 16];
        assert_eq!(
            decode(b"foo%zz", Mode::Path, &mut out),
            Err(PctErrorKind::InvalidEscape)
        );
    }

    #[test]
    fn buffer_too_small() {
        let mut out = [0u8; 3];
        assert_eq!(
            decode(b"foo%20bar", Mode::Path, &mut out),
            Err(PctErrorKind::BufferTooSmall)
        );
    }

    #[test]
    fn buffer_exactly_fits() {
        // "foo bar" is 7 bytes; a 7-byte buffer must be enough.
        let mut out = [0u8; 7];
        assert_eq!(
            decode(b"foo%20bar", Mode::Path, &mut out).unwrap(),
            b"foo bar"
        );
    }

    #[test]
    fn empty_input() {
        let mut out = [0u8; 4];
        assert_eq!(decode(b"", Mode::Path, &mut out).unwrap(), b"");
        assert_eq!(decode(b"", Mode::Form, &mut out).unwrap(), b"");
    }

    #[test]
    fn null_byte_via_escape() {
        let mut out = [0u8; 4];
        assert_eq!(decode(b"%00", Mode::Path, &mut out).unwrap(), b"\0");
    }

    #[test]
    fn decoded_len_matches_decode() {
        for (input, mode) in [
            (&b"foo%20bar"[..], Mode::Path),
            (b"a+b", Mode::Form),
            (b"plain", Mode::Path),
            (b"%2F%2Fetc", Mode::Path),
            (b"", Mode::Path),
            (b"hello+world%21", Mode::Form),
            (b"a+b%20c", Mode::Path),
            (b"a%2Bb+c", Mode::Form),
        ] {
            let n = decoded_len(input, mode).unwrap();
            let mut out = [0u8; 32];
            let got = decode(input, mode, &mut out).unwrap();
            assert_eq!(got.len(), n, "input={input:?} mode={mode:?}");
        }
    }

    #[test]
    fn decoded_len_ignores_mode() {
        for input in [&b"a+b"[..], b"x%2By+", b"plain"] {
            assert_eq!(
                decoded_len(input, Mode::Path),
                decoded_len(input, Mode::Form),
                "input={input:?}"
            );
        }
    }

    #[test]
    fn decoded_len_invalid_escape() {
        assert_eq!(
            decoded_len(b"foo%2", Mode::Path),
            Err(PctErrorKind::InvalidEscape)
        );
        assert_eq!(
            decoded_len(b"foo%zz", Mode::Path),
            Err(PctErrorKind::InvalidEscape)
        );
    }

    #[test]
    fn decode_into_writes_even_when_clean() {
        let mut out = [0u8; 8];
        let n = decode_into(b"abc", Mode::Path, &mut out).unwrap();
        assert_eq!(n, 3);
        assert_eq!(&out[..n], b"abc");
    }
}