#![cfg_attr(not(feature = "alloc"), allow(clippy::redundant_pub_crate))]
#[cfg(feature = "alloc")]
use crate::Error;
#[cfg(feature = "alloc")]
use crate::tables::{HEX_VAL, PLAIN_ENC_NSS, PLAIN_ENC_RQF};
use crate::{
Result,
TriCow,
tables::{BYTE_CLASS, HEX, PLAIN_PARSE},
};
#[cfg(all(feature = "alloc", not(feature = "std")))]
use alloc::{string::String, vec::Vec};
/// Selects which component's rules the shared parse/validate/decode/encode
/// routines in this file apply.
///
/// The variant only changes how the `?` and `/` bytes are treated and, when
/// encoding, which plain-byte mask is used.
#[derive(Copy, Clone)]
enum PctEncoded {
/// Rules used by the `*_nss` entry points; a `?` is never accepted literally.
Nss,
/// Rules used by the `*_r_component` entry points; a `?` is accepted only
/// after the first byte and only when it is not followed by `=`.
RComponent,
/// Rules used by the `*_q_component` entry points; a `?` is accepted
/// anywhere except as the first byte.
QComponent,
/// Rules used by the `*_f_component` entry points; `?` and `/` are accepted
/// at any position.
FComponent,
}
/// Returns the index of the first byte at or after `i` whose `BYTE_CLASS`
/// entry lacks the `PLAIN_PARSE` bit, or `bytes.len()` when every remaining
/// byte has it.
///
/// An `i` that is already past the end of `bytes` is returned unchanged.
#[inline]
fn scan_plain_run(bytes: &[u8], mut i: usize) -> usize {
    let stops = |b: u8| BYTE_CLASS[b as usize] & PLAIN_PARSE == 0;

    // Wide phase: classify eight bytes per step and record each stopping
    // byte as one bit, so the first stop is the lowest set bit.
    while let Some(window) = bytes.get(i..i + 8) {
        let stop_bits = window
            .iter()
            .enumerate()
            .fold(0u32, |acc, (bit, &b)| acc | (u32::from(stops(b)) << bit));
        if stop_bits != 0 {
            return i + stop_bits.trailing_zeros() as usize;
        }
        i += 8;
    }

    // Tail phase: fewer than eight bytes remain, check them one by one.
    while let Some(&b) = bytes.get(i) {
        if stops(b) {
            break;
        }
        i += 1;
    }
    i
}
/// Reports whether `bytes[start..]` contains a `%` followed by two hex digits
/// of which at least one is a lowercase ASCII letter.
///
/// The scan covers everything up to the end of `bytes`; a well-formed triplet
/// is skipped as a unit, any other byte is skipped individually.
fn scan_needs_hex_upper(bytes: &[u8], start: usize) -> bool {
    let is_hex = |b: u8| BYTE_CLASS[b as usize] & HEX != 0;

    let mut pos = start;
    // A three-byte window exists exactly while `pos + 2 < bytes.len()`.
    while let Some(&[lead, hi, lo]) = bytes.get(pos..pos + 3) {
        if lead != b'%' || !is_hex(hi) || !is_hex(lo) {
            pos += 1;
            continue;
        }
        if hi.is_ascii_lowercase() || lo.is_ascii_lowercase() {
            return true;
        }
        pos += 3;
    }
    false
}
/// Consumes the longest prefix of `s[start..]` that is valid for `kind` and
/// returns the index just past it.
///
/// Accepted input is: bytes flagged `PLAIN_PARSE`, `%` followed by two hex
/// digits, and the `?` / `/` bytes where `kind` permits them. The first byte
/// that is not accepted ends the scan and its index is returned; if
/// everything is accepted the result is `s.len()`.
///
/// When a lowercase hex digit is present in a triplet anywhere from `start`
/// to the end of `s`, every triplet consumed here is passed to
/// `make_uppercase`.
///
/// # Errors
///
/// Propagates failures of `TriCow::ensure_owned` and `TriCow::make_uppercase`.
fn parse(s: &mut TriCow, start: usize, kind: PctEncoded) -> Result<usize> {
    let is_hex = |b: u8| BYTE_CLASS[b as usize] & HEX != 0;

    let needs_upper = scan_needs_hex_upper(s.as_bytes(), start);
    #[cfg(feature = "alloc")]
    if needs_upper && matches!(s, TriCow::Borrowed(_)) {
        s.ensure_owned()?;
    }

    let mut pos = start;
    loop {
        // Re-borrowed on every pass so no view of the buffer outlives a
        // `make_uppercase` call.
        let bytes = s.as_bytes();
        if pos >= bytes.len() {
            break;
        }
        pos = scan_plain_run(bytes, pos);
        if pos >= bytes.len() {
            break;
        }

        let past_start = pos != start;
        // Number of bytes accepted at `pos`; zero means the component ends here.
        let width = match bytes[pos] {
            b'?' => usize::from(match kind {
                PctEncoded::FComponent => true,
                PctEncoded::QComponent => past_start,
                PctEncoded::RComponent => past_start && bytes.get(pos + 1) != Some(&b'='),
                PctEncoded::Nss => false,
            }),
            b'/' => usize::from(past_start || matches!(kind, PctEncoded::FComponent)),
            b'%' => match bytes.get(pos + 1..pos + 3) {
                Some(&[hi, lo]) if is_hex(hi) && is_hex(lo) => 3,
                _ => 0,
            },
            _ => 0,
        };

        match width {
            0 => return Ok(pos),
            3 => {
                // Only a percent triplet has width three.
                if needs_upper {
                    s.make_uppercase(pos + 1..pos + 3)?;
                }
                pos += 3;
            }
            _ => pos += 1,
        }
    }
    Ok(s.len())
}
/// Runs `parse` on `s` from `start` with the `PctEncoded::Nss` rules and
/// returns its result unchanged.
pub(crate) fn parse_nss(s: &mut TriCow, start: usize) -> Result<usize> {
parse(s, start, PctEncoded::Nss)
}
/// Runs `parse` on `s` from `start` with the `PctEncoded::RComponent` rules
/// and returns its result unchanged.
pub(crate) fn parse_r_component(s: &mut TriCow, start: usize) -> Result<usize> {
parse(s, start, PctEncoded::RComponent)
}
/// Runs `parse` on `s` from `start` with the `PctEncoded::QComponent` rules
/// and returns its result unchanged.
pub(crate) fn parse_q_component(s: &mut TriCow, start: usize) -> Result<usize> {
parse(s, start, PctEncoded::QComponent)
}
/// Runs `parse` on `s` from `start` with the `PctEncoded::FComponent` rules
/// and returns its result unchanged.
pub(crate) fn parse_f_component(s: &mut TriCow, start: usize) -> Result<usize> {
parse(s, start, PctEncoded::FComponent)
}
/// Read-only counterpart of `parse`: measures the longest prefix of
/// `s[start..]` that is valid for `kind` without modifying anything.
///
/// Returns `(end, needs_norm)`. `end` is the index of the first byte that is
/// not accepted, or `s.len()` when the whole tail is accepted. `needs_norm`
/// is `true` when at least one accepted percent triplet contains a lowercase
/// hex digit.
fn validate(s: &str, start: usize, kind: PctEncoded) -> (usize, bool) {
    let bytes = s.as_bytes();
    let total = bytes.len();
    let is_hex = |b: u8| BYTE_CLASS[b as usize] & HEX != 0;

    let mut lowercase_seen = false;
    let mut pos = start;
    while pos < total {
        pos = scan_plain_run(bytes, pos);
        if pos >= total {
            break;
        }

        let past_start = pos != start;
        // Number of bytes accepted at `pos`; zero means the component ends here.
        let width = match bytes[pos] {
            b'?' => usize::from(match kind {
                PctEncoded::FComponent => true,
                PctEncoded::QComponent => past_start,
                PctEncoded::RComponent => past_start && bytes.get(pos + 1) != Some(&b'='),
                PctEncoded::Nss => false,
            }),
            b'/' => usize::from(past_start || matches!(kind, PctEncoded::FComponent)),
            b'%' => match bytes.get(pos + 1..pos + 3) {
                Some(&[hi, lo]) if is_hex(hi) && is_hex(lo) => {
                    lowercase_seen |= hi.is_ascii_lowercase() || lo.is_ascii_lowercase();
                    3
                }
                _ => 0,
            },
            _ => 0,
        };

        if width == 0 {
            return (pos, lowercase_seen);
        }
        pos += width;
    }
    (total, lowercase_seen)
}
/// Runs `validate` on `s` from index 0 with the `PctEncoded::Nss` rules.
///
/// Returns the accepted prefix length and whether any accepted percent
/// triplet contains a lowercase hex digit.
#[inline]
pub(crate) fn validate_nss(s: &str) -> (usize, bool) {
validate(s, 0, PctEncoded::Nss)
}
/// Runs `validate` on `s` from index 0 with the `PctEncoded::RComponent` rules.
///
/// Returns the accepted prefix length and whether any accepted percent
/// triplet contains a lowercase hex digit.
#[inline]
pub(crate) fn validate_r_component(s: &str) -> (usize, bool) {
validate(s, 0, PctEncoded::RComponent)
}
/// Runs `validate` on `s` from index 0 with the `PctEncoded::QComponent` rules.
///
/// Returns the accepted prefix length and whether any accepted percent
/// triplet contains a lowercase hex digit.
#[inline]
pub(crate) fn validate_q_component(s: &str) -> (usize, bool) {
validate(s, 0, PctEncoded::QComponent)
}
/// Runs `validate` on `s` from index 0 with the `PctEncoded::FComponent` rules.
///
/// Returns the accepted prefix length and whether any accepted percent
/// triplet contains a lowercase hex digit.
#[inline]
pub(crate) fn validate_f_component(s: &str) -> (usize, bool) {
validate(s, 0, PctEncoded::FComponent)
}
/// Passes the two bytes after every `%` found in `range` to
/// `TriCow::make_uppercase`.
///
/// The bytes after a `%` are not checked here, so `range` is presumably
/// expected to cover text already accepted by `validate`/`parse` — verify
/// against callers. A `%` with fewer than two following bytes inside `range`
/// is left alone.
///
/// # Errors
///
/// Propagates any failure of `TriCow::make_uppercase`.
pub(crate) fn normalize_range(s: &mut TriCow, range: core::ops::Range<usize>) -> Result<()> {
    let core::ops::Range { start, end } = range;
    let mut pos = start;
    while pos + 2 < end {
        let advance = if s.as_bytes()[pos] == b'%' {
            s.make_uppercase(pos + 1..pos + 3)?;
            3
        } else {
            1
        };
        pos += advance;
    }
    Ok(())
}
/// Iterator that percent-decodes a string one output byte at a time.
///
/// Each item is `Ok(byte)` for a decoded byte or a single `Err` at the first
/// invalid position, after which the iterator yields `None`.
#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub struct DecodeIter<'a> {
// Input being decoded.
bytes: &'a [u8],
// Index of the next input byte to examine.
i: usize,
// Component rules that decide how `?` and `/` are treated.
kind: PctEncoded,
// Error value reported at the first invalid position.
err: Error,
// Set once the error has been reported; later calls yield `None`.
done: bool,
}
#[cfg(feature = "alloc")]
impl<'a> DecodeIter<'a> {
/// Creates an iterator positioned at the start of `s` that applies the
/// rules of `kind` and reports `err` on invalid input.
const fn new(s: &'a str, kind: PctEncoded, err: Error) -> Self {
Self {
bytes: s.as_bytes(),
i: 0,
kind,
err,
done: false,
}
}
}
#[cfg(feature = "alloc")]
impl<'a> Iterator for DecodeIter<'a> {
    type Item = Result<u8>;

    /// Yields the next decoded byte.
    ///
    /// Bytes flagged `PLAIN_PARSE` and permitted `?` / `/` bytes are yielded
    /// as-is; a `%` followed by two hex digits yields the byte they denote.
    /// Anything else yields the stored error once, after which the iterator
    /// only returns `None`.
    fn next(&mut self) -> Option<Result<u8>> {
        if self.done {
            return None;
        }
        let pos = self.i;
        let byte = *self.bytes.get(pos)?;
        let past_first = pos != 0;

        // `Some((value, consumed))` on success, `None` for invalid input.
        let decoded: Option<(u8, usize)> = if BYTE_CLASS[byte as usize] & PLAIN_PARSE != 0 {
            Some((byte, 1))
        } else {
            match byte {
                b'?' => {
                    let literal_ok = match self.kind {
                        PctEncoded::FComponent => true,
                        PctEncoded::QComponent => past_first,
                        PctEncoded::RComponent => {
                            past_first && self.bytes.get(pos + 1) != Some(&b'=')
                        }
                        PctEncoded::Nss => false,
                    };
                    if literal_ok {
                        Some((byte, 1))
                    } else {
                        None
                    }
                }
                b'/' => {
                    if past_first || matches!(self.kind, PctEncoded::FComponent) {
                        Some((byte, 1))
                    } else {
                        None
                    }
                }
                b'%' => match self.bytes.get(pos + 1..pos + 3) {
                    Some(&[hi_digit, lo_digit]) => {
                        let hi = HEX_VAL[hi_digit as usize];
                        let lo = HEX_VAL[lo_digit as usize];
                        // 0xFF in the table marks a byte that is not a hex digit.
                        if hi == 0xFF || lo == 0xFF {
                            None
                        } else {
                            Some(((hi << 4) | lo, 3))
                        }
                    }
                    // Fewer than two bytes follow the `%`.
                    _ => None,
                },
                _ => None,
            }
        };

        match decoded {
            Some((value, consumed)) => {
                self.i = pos + consumed;
                Some(Ok(value))
            }
            None => {
                self.done = true;
                Some(Err(self.err))
            }
        }
    }
}
/// Percent-decodes `s` under the rules of `kind` into an owned string.
///
/// Returns `None` when `s` contains an invalid byte or sequence, or when the
/// decoded bytes are not valid UTF-8. The error value handed to the iterator
/// is a placeholder: it is discarded here and callers supply their own.
#[cfg(feature = "alloc")]
fn decode(s: &str, kind: PctEncoded) -> Option<String> {
    // Decoding never produces more bytes than it reads.
    let mut raw = Vec::with_capacity(s.len());
    DecodeIter::new(s, kind, Error::InvalidNss)
        .try_for_each(|item| item.map(|byte| raw.push(byte)))
        .ok()?;
    String::from_utf8(raw).ok()
}
/// Percent-decodes `s` using the NSS rules.
///
/// # Errors
///
/// Returns `Error::InvalidNss` when `s` contains a byte or sequence that is
/// not accepted, or when the decoded bytes are not valid UTF-8.
#[cfg(feature = "alloc")]
pub fn decode_nss(s: &str) -> Result<String> {
decode(s, PctEncoded::Nss).ok_or(Error::InvalidNss)
}
/// Percent-decodes `s` using the r-component rules.
///
/// # Errors
///
/// Returns `Error::InvalidRComponent` when `s` contains a byte or sequence
/// that is not accepted, or when the decoded bytes are not valid UTF-8.
#[cfg(feature = "alloc")]
pub fn decode_r_component(s: &str) -> Result<String> {
decode(s, PctEncoded::RComponent).ok_or(Error::InvalidRComponent)
}
/// Percent-decodes `s` using the q-component rules.
///
/// # Errors
///
/// Returns `Error::InvalidQComponent` when `s` contains a byte or sequence
/// that is not accepted, or when the decoded bytes are not valid UTF-8.
#[cfg(feature = "alloc")]
pub fn decode_q_component(s: &str) -> Result<String> {
decode(s, PctEncoded::QComponent).ok_or(Error::InvalidQComponent)
}
/// Percent-decodes `s` using the f-component rules.
///
/// # Errors
///
/// Returns `Error::InvalidFComponent` when `s` contains a byte or sequence
/// that is not accepted, or when the decoded bytes are not valid UTF-8.
#[cfg(feature = "alloc")]
pub fn decode_f_component(s: &str) -> Result<String> {
decode(s, PctEncoded::FComponent).ok_or(Error::InvalidFComponent)
}
/// Returns a lazy iterator over the percent-decoded bytes of `s`, using the
/// NSS rules.
///
/// Invalid input is reported as a single `Err(Error::InvalidNss)` item, after
/// which the iterator ends. Nothing is decoded until the iterator is
/// consumed, hence `#[must_use]`.
#[cfg(feature = "alloc")]
#[must_use]
pub const fn decode_nss_iter(s: &str) -> DecodeIter<'_> {
    DecodeIter::new(s, PctEncoded::Nss, Error::InvalidNss)
}
/// Returns a lazy iterator over the percent-decoded bytes of `s`, using the
/// r-component rules.
///
/// Invalid input is reported as a single `Err(Error::InvalidRComponent)`
/// item, after which the iterator ends. Nothing is decoded until the iterator
/// is consumed, hence `#[must_use]`.
#[cfg(feature = "alloc")]
#[must_use]
pub const fn decode_r_component_iter(s: &str) -> DecodeIter<'_> {
    DecodeIter::new(s, PctEncoded::RComponent, Error::InvalidRComponent)
}
/// Returns a lazy iterator over the percent-decoded bytes of `s`, using the
/// q-component rules.
///
/// Invalid input is reported as a single `Err(Error::InvalidQComponent)`
/// item, after which the iterator ends. Nothing is decoded until the iterator
/// is consumed, hence `#[must_use]`.
#[cfg(feature = "alloc")]
#[must_use]
pub const fn decode_q_component_iter(s: &str) -> DecodeIter<'_> {
    DecodeIter::new(s, PctEncoded::QComponent, Error::InvalidQComponent)
}
/// Returns a lazy iterator over the percent-decoded bytes of `s`, using the
/// f-component rules.
///
/// Invalid input is reported as a single `Err(Error::InvalidFComponent)`
/// item, after which the iterator ends. Nothing is decoded until the iterator
/// is consumed, hence `#[must_use]`.
#[cfg(feature = "alloc")]
#[must_use]
pub const fn decode_f_component_iter(s: &str) -> DecodeIter<'_> {
    DecodeIter::new(s, PctEncoded::FComponent, Error::InvalidFComponent)
}
#[cfg(feature = "alloc")]
/// Returns the two uppercase ASCII hex digits of `n`, most significant
/// nibble first.
const fn to_hex(n: u8) -> [u8; 2] {
    const DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    [DIGITS[(n >> 4) as usize], DIGITS[(n & 0x0F) as usize]]
}
/// Returns the index of the first byte at or after `i` whose `BYTE_CLASS`
/// entry shares no bit with `plain_mask`, or `bytes.len()` when every
/// remaining byte does.
///
/// An `i` that is already past the end of `bytes` is returned unchanged.
#[cfg(feature = "alloc")]
#[inline]
fn scan_enc_plain_run(bytes: &[u8], mut i: usize, plain_mask: u8) -> usize {
    let stops = |b: u8| BYTE_CLASS[b as usize] & plain_mask == 0;

    // Wide phase: classify eight bytes per step and record each stopping
    // byte as one bit, so the first stop is the lowest set bit.
    while let Some(window) = bytes.get(i..i + 8) {
        let stop_bits = window
            .iter()
            .enumerate()
            .fold(0u32, |acc, (bit, &b)| acc | (u32::from(stops(b)) << bit));
        if stop_bits != 0 {
            return i + stop_bits.trailing_zeros() as usize;
        }
        i += 8;
    }

    // Tail phase: fewer than eight bytes remain, check them one by one.
    while let Some(&b) = bytes.get(i) {
        if stops(b) {
            break;
        }
        i += 1;
    }
    i
}
/// Percent-encodes `s` for the component selected by `kind`.
///
/// Bytes whose `BYTE_CLASS` entry matches the component's plain mask are
/// copied through, as are `?` and `/` where `kind` permits them at that
/// position. Every other ASCII byte, and every byte of a non-ASCII UTF-8
/// sequence, is written as `%` followed by two uppercase hex digits.
///
/// The output is assembled with checked `str` slicing and `char` pushes
/// rather than `from_utf8_unchecked`, so its validity does not rest on the
/// contents of the classification tables.
#[cfg(feature = "alloc")]
fn encode(s: &str, kind: PctEncoded) -> String {
    /// Appends the `%XX` escape for `byte` to `out`.
    fn push_escaped(out: &mut String, byte: u8) {
        let [hi, lo] = to_hex(byte);
        out.push('%');
        out.push(char::from(hi));
        out.push(char::from(lo));
    }

    let bytes = s.as_bytes();
    let mut ret = String::with_capacity(bytes.len());
    let plain_mask = match kind {
        PctEncoded::Nss => PLAIN_ENC_NSS,
        PctEncoded::RComponent | PctEncoded::QComponent | PctEncoded::FComponent => PLAIN_ENC_RQF,
    };

    let mut i = 0;
    while i < bytes.len() {
        let run_end = scan_enc_plain_run(bytes, i, plain_mask);
        if run_end > i {
            // `i` only ever advances past whole characters, and a plain run
            // is expected to be ASCII, so both ends are char boundaries; the
            // checked slice would panic instead of emitting invalid UTF-8 if
            // that expectation were ever violated.
            ret.push_str(&s[i..run_end]);
            i = run_end;
            if i >= bytes.len() {
                break;
            }
        }

        let b = bytes[i];
        if b.is_ascii() {
            let allowed = BYTE_CLASS[b as usize] & plain_mask != 0
                || match b {
                    b'?' => match kind {
                        PctEncoded::FComponent => true,
                        PctEncoded::QComponent => i != 0,
                        // `?=` would read as a delimiter, so that `?` is escaped.
                        PctEncoded::RComponent => i != 0 && bytes.get(i + 1) != Some(&b'='),
                        PctEncoded::Nss => false,
                    },
                    b'/' => match kind {
                        PctEncoded::FComponent => true,
                        PctEncoded::RComponent | PctEncoded::QComponent => i != 0,
                        PctEncoded::Nss => false,
                    },
                    _ => false,
                };
            if allowed {
                ret.push(char::from(b));
            } else {
                push_escaped(&mut ret, b);
            }
            i += 1;
        } else {
            // Non-ASCII: escape the lead byte together with all of its
            // continuation bytes (`10xxxxxx`) so `i` lands on the next char.
            let start = i;
            i += 1;
            while i < bytes.len() && (bytes[i] & 0xC0) == 0x80 {
                i += 1;
            }
            for &byte in &bytes[start..i] {
                push_escaped(&mut ret, byte);
            }
        }
    }
    ret
}
/// Iterator that percent-encodes a string one output byte at a time.
///
/// An input byte that must be escaped produces three output bytes (`%` and
/// two hex digits); the digits are buffered between calls.
#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub struct EncodeIter<'a> {
// Input being encoded.
bytes: &'a [u8],
// Index of the next input byte to examine.
i: usize,
// Component rules that decide the plain mask and how `?` and `/` are treated.
kind: PctEncoded,
// Most recent `%XX` escape.
pending: [u8; 3],
// Number of valid bytes in `pending` (0 before any escape, 3 afterwards).
pending_len: u8,
// Index of the next byte of `pending` still to be yielded.
pending_pos: u8,
}
#[cfg(feature = "alloc")]
impl<'a> EncodeIter<'a> {
/// Creates an iterator positioned at the start of `s` that applies the
/// rules of `kind`, with no escape bytes buffered.
const fn new(s: &'a str, kind: PctEncoded) -> Self {
Self {
bytes: s.as_bytes(),
i: 0,
kind,
pending: [0; 3],
pending_len: 0,
pending_pos: 0,
}
}
}
#[cfg(feature = "alloc")]
impl<'a> Iterator for EncodeIter<'a> {
    type Item = u8;

    /// Yields the next byte of the encoded output.
    ///
    /// Buffered escape digits are drained first. Otherwise one input byte is
    /// consumed: it is yielded unchanged when it may appear literally, or
    /// replaced by `%` with its two hex digits buffered for the next calls.
    fn next(&mut self) -> Option<u8> {
        if self.pending_pos < self.pending_len {
            let slot = usize::from(self.pending_pos);
            self.pending_pos += 1;
            return Some(self.pending[slot]);
        }

        let pos = self.i;
        let byte = *self.bytes.get(pos)?;
        self.i = pos + 1;

        // Non-ASCII bytes are always escaped; ASCII bytes pass through when
        // the plain mask or the positional `?` / `/` rules allow it.
        let literal = byte.is_ascii() && {
            let plain_mask = match self.kind {
                PctEncoded::Nss => PLAIN_ENC_NSS,
                _ => PLAIN_ENC_RQF,
            };
            BYTE_CLASS[byte as usize] & plain_mask != 0
                || match (byte, self.kind) {
                    (b'?', PctEncoded::FComponent) | (b'/', PctEncoded::FComponent) => true,
                    (b'?', PctEncoded::QComponent)
                    | (b'/', PctEncoded::QComponent)
                    | (b'/', PctEncoded::RComponent) => pos != 0,
                    (b'?', PctEncoded::RComponent) => {
                        pos != 0 && self.bytes.get(pos + 1) != Some(&b'=')
                    }
                    _ => false,
                }
        };
        if literal {
            return Some(byte);
        }

        let [hi, lo] = to_hex(byte);
        self.pending = [b'%', hi, lo];
        self.pending_len = 3;
        // The `%` is returned right now, so draining resumes at the digits.
        self.pending_pos = 1;
        Some(b'%')
    }
}
/// Returns a lazy iterator over the percent-encoded bytes of `s`, using the
/// NSS rules.
#[cfg(feature = "alloc")]
#[must_use]
pub const fn encode_nss_iter(s: &str) -> EncodeIter<'_> {
EncodeIter::new(s, PctEncoded::Nss)
}
/// Returns a lazy iterator over the percent-encoded bytes of `s`, using the
/// r-component rules.
#[cfg(feature = "alloc")]
#[must_use]
pub const fn encode_r_component_iter(s: &str) -> EncodeIter<'_> {
EncodeIter::new(s, PctEncoded::RComponent)
}
/// Returns a lazy iterator over the percent-encoded bytes of `s`, using the
/// q-component rules.
#[cfg(feature = "alloc")]
#[must_use]
pub const fn encode_q_component_iter(s: &str) -> EncodeIter<'_> {
EncodeIter::new(s, PctEncoded::QComponent)
}
/// Returns a lazy iterator over the percent-encoded bytes of `s`, using the
/// f-component rules.
#[cfg(feature = "alloc")]
#[must_use]
pub const fn encode_f_component_iter(s: &str) -> EncodeIter<'_> {
EncodeIter::new(s, PctEncoded::FComponent)
}
/// Percent-encodes `s` using the NSS rules.
///
/// # Errors
///
/// Returns `Error::InvalidNss` when `s` is empty.
#[cfg(feature = "alloc")]
pub fn encode_nss(s: &str) -> Result<String> {
    match s {
        "" => Err(Error::InvalidNss),
        text => Ok(encode(text, PctEncoded::Nss)),
    }
}
/// Percent-encodes `s` using the r-component rules.
///
/// # Errors
///
/// Returns `Error::InvalidRComponent` when `s` is empty.
#[cfg(feature = "alloc")]
pub fn encode_r_component(s: &str) -> Result<String> {
    match s {
        "" => Err(Error::InvalidRComponent),
        text => Ok(encode(text, PctEncoded::RComponent)),
    }
}
/// Percent-encodes `s` using the q-component rules.
///
/// # Errors
///
/// Returns `Error::InvalidQComponent` when `s` is empty.
#[cfg(feature = "alloc")]
pub fn encode_q_component(s: &str) -> Result<String> {
    match s {
        "" => Err(Error::InvalidQComponent),
        text => Ok(encode(text, PctEncoded::QComponent)),
    }
}
/// Percent-encodes `s` using the f-component rules.
///
/// Unlike the other `encode_*` functions this one accepts an empty `s`; as
/// written it always returns `Ok`.
#[cfg(feature = "alloc")]
pub fn encode_f_component(s: &str) -> Result<String> {
Ok(encode(s, PctEncoded::FComponent))
}
// Checks the eight-bytes-per-step `scan_plain_run` against a plain
// byte-at-a-time reference implementation.
#[cfg(all(test, feature = "alloc"))]
#[allow(clippy::unwrap_used, clippy::panic, clippy::expect_used)]
mod swar_tests {
#[cfg(not(feature = "std"))]
use alloc::vec;
use super::{BYTE_CLASS, PLAIN_PARSE, scan_plain_run};
// Reference implementation: advance one byte at a time while the byte is
// flagged `PLAIN_PARSE`.
fn scan_plain_scalar(bytes: &[u8], mut i: usize) -> usize {
while i < bytes.len() && BYTE_CLASS[bytes[i] as usize] & PLAIN_PARSE != 0 {
i += 1;
}
i
}
// Compares both scanners on every (length, start) pair of a 1024-byte
// pseudo-random buffer.
#[test]
fn swar_matches_scalar_all_prefixes() {
let mut buf = [0u8; 1024];
// Fixed-seed linear congruential generator, so the buffer is the same on
// every run.
let mut x: u32 = 0x1234_5678;
for b in &mut buf {
x = x.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
*b = (x >> 16) as u8;
}
for len in 0..=buf.len() {
for start in 0..=len {
let a = scan_plain_run(&buf[..len], start);
let b = scan_plain_scalar(&buf[..len], start);
assert_eq!(a, b, "mismatch at len={len} start={start}");
}
}
}
// Checks an input of 33 `A` bytes (not a multiple of eight) and a single `#`
// placed at each position of a 20-byte input.
#[test]
fn swar_boundary_cases() {
let all_plain = vec![b'A'; 33];
assert_eq!(scan_plain_run(&all_plain, 0), 33);
for pos in 0..20 {
let mut v = vec![b'A'; 20];
v[pos] = b'#';
assert_eq!(scan_plain_run(&v, 0), pos);
}
}
}