use core::fmt::{self, Write as _};
use core::marker::PhantomData;
use crate::format::eq_str_display;
use crate::parser::char::{is_ascii_unreserved, is_unreserved, is_utf8_byte_continue};
use crate::parser::str::{find_split_hole, take_first_char};
use crate::parser::trusted::take_xdigits2;
use crate::spec::Spec;
/// Checks `s` against its own percent-encoding-normalized rendering.
///
/// The string is wrapped in [`PctCaseNormalized`] for the spec `S`, and that
/// wrapper's `Display` output is compared with `s` itself through
/// `eq_str_display`.
///
/// NOTE(review): `eq_str_display` is defined elsewhere; judging by its name it
/// reports whether the string equals the `Display` output, which would make
/// this return `true` exactly when `s` is already normalized — confirm.
pub(crate) fn is_pct_case_normalized<S: Spec>(s: &str) -> bool {
    let normalized = PctCaseNormalized::<S>::new(s);
    eq_str_display(s, &normalized)
}
/// Decodes a 2-, 3-, or 4-byte UTF-8 sequence into a character.
///
/// Only the payload bits of each byte are used: the marker bits of the leading
/// byte and of the continuation bytes are masked away without being validated,
/// so the caller is trusted to pass a well-shaped sequence.
///
/// # Errors
///
/// Returns `Err(())` when the decoded value is smaller than the minimum that
/// requires `bytes.len()` bytes (an overlong encoding), or when it is not a
/// valid `char` (a surrogate, or a value above U+10FFFF).
///
/// # Panics
///
/// Panics if `bytes.len()` is not 2, 3, or 4.
fn into_char_trusted(bytes: &[u8]) -> Result<char, ()> {
    /// Payload bits of a continuation byte (`10xx_xxxx`).
    const TAIL_PAYLOAD: u8 = 0b_0011_1111;

    // For each supported length: the payload mask of the leading byte, and the
    // smallest code point that genuinely needs that many bytes.
    let (lead_payload, floor): (u8, u32) = match bytes.len() {
        2 => (0b_0001_1111, 0x80),
        3 => (0b_0000_1111, 0x800),
        4 => (0b_0000_0111, 0x1_0000),
        len => {
            unreachable!("expected 2, 3, or 4 bytes for a character, but got {len} as the length")
        }
    };

    // Start from the leading byte's payload and append six bits per
    // continuation byte.
    let code_point = bytes[1..]
        .iter()
        .fold(u32::from(bytes[0] & lead_payload), |acc, &tail| {
            (acc << 6) | u32::from(tail & TAIL_PAYLOAD)
        });

    if code_point < floor {
        // The value would have fit in a shorter sequence.
        return Err(());
    }
    // `from_u32` rejects surrogates and values beyond U+10FFFF.
    char::from_u32(code_point).ok_or(())
}
/// Wrapper whose `Display` implementation writes the wrapped string with its
/// percent-encoded triplets normalized according to the spec `S`.
///
/// The wrapper only borrows the string; nothing is computed until it is
/// formatted.
#[derive(Debug, Clone, Copy)]
pub(crate) struct PctCaseNormalized<'a, S> {
/// String to be written in normalized form.
// NOTE(review): the field is named `path`, but nothing visible here restricts
// the content to a path component — confirm against callers.
path: &'a str,
/// Marker tying the value to the spec `S` without storing an `S`.
_spec: PhantomData<fn() -> S>,
}
impl<'a, S: Spec> PctCaseNormalized<'a, S> {
    /// Creates a wrapper that borrows `source`.
    ///
    /// Construction does no work on the string; the normalized form is
    /// produced only when the returned value is formatted with `Display`.
    #[inline]
    #[must_use]
    pub(crate) fn new(source: &'a str) -> Self {
        let _spec = PhantomData;
        Self {
            path: source,
            _spec,
        }
    }
}
/// Writes the wrapped string with its percent-encoded triplets normalized.
///
/// * A triplet that decodes to an ASCII unreserved character is written as
///   that character.
/// * A run of triplets that decodes to one complete multi-byte UTF-8 character
///   is written as that character when `is_unreserved::<S>` accepts it.
/// * Every other triplet stays percent-encoded, with its hex digits written in
///   uppercase.
/// * Text outside of triplets is copied unchanged.
impl<S: Spec> fmt::Display for PctCaseNormalized<'_, S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// Part of the source that has not been consumed yet.
let mut rest = self.path;
'outer_loop: while !rest.is_empty() {
// Look for the next `%`. Without one, the remainder contains no triplet and
// is written as is.
let (prefix, after_percent) = match find_split_hole(rest, b'%') {
Some(v) => v,
None => return f.write_str(rest),
};
// Text in front of the triplet needs no normalization.
f.write_str(prefix)?;
// Decode the byte that the triplet stands for.
// NOTE(review): `take_xdigits2` lives in `parser::trusted`; presumably it
// assumes two hex digits follow the `%` and does not validate — confirm.
let (first_decoded, after_first_triplet) = take_xdigits2(after_percent);
rest = after_first_triplet;
// Classify the byte as a UTF-8 leading byte to learn how many bytes the
// character needs. Bytes that settle the outcome alone are handled in place.
let expected_char_len = match first_decoded {
0x00..=0x7F => {
// ASCII: decode when unreserved, otherwise re-emit the triplet in uppercase.
debug_assert!(first_decoded.is_ascii());
if is_ascii_unreserved(first_decoded) {
f.write_char(char::from(first_decoded))?;
} else {
write!(f, "%{:02X}", first_decoded)?;
}
continue 'outer_loop;
}
// Leading bytes of 2-, 3-, and 4-byte sequences.
0xC2..=0xDF => 2,
0xE0..=0xEF => 3,
0xF0..=0xF4 => 4,
// Bytes that cannot start a sequence: continuation bytes (0x80..=0xBF),
// leading bytes that only form overlong encodings (0xC0, 0xC1), and leading
// bytes for values above U+10FFFF (0xF5..=0xFF). Keep them encoded.
0x80..=0xC1 | 0xF5..=0xFF => {
write!(f, "%{:02X}", first_decoded)?;
continue 'outer_loop;
}
};
// Buffer for the bytes of one character, starting with the leading byte.
let c_buf = &mut [first_decoded, 0, 0, 0][..expected_char_len];
// Collect the continuation bytes. In iteration `i` the slots `0..=i` hold the
// bytes read so far (the leading byte plus `i` continuation bytes), so
// `c_buf[..=i]` is exactly what has to be flushed when the sequence breaks.
for (i, buf_dest) in c_buf[1..].iter_mut().enumerate() {
match take_first_char(rest) {
Some(('%', after_percent)) => {
let (byte, after_triplet) = take_xdigits2(after_percent);
if !is_utf8_byte_continue(byte) {
// Not a continuation byte: flush what was read so far as triplets.
// `rest` is intentionally left untouched so that this triplet is
// examined again by the outer loop, where it may start a new character.
c_buf[..=i]
.iter()
.try_for_each(|b| write!(f, "%{:02X}", b))?;
continue 'outer_loop;
}
*buf_dest = byte;
rest = after_triplet;
}
Some((c, after_percent)) => {
// A literal character interrupts the sequence: flush the bytes read so
// far as triplets, then write the character itself and move past it.
// (Despite its name, `after_percent` here is the text after `c`.)
c_buf[..=i]
.iter()
.try_for_each(|b| write!(f, "%{:02X}", b))?;
f.write_char(c)?;
rest = after_percent;
continue 'outer_loop;
}
None => {
// The input ended in the middle of the sequence: flush and finish.
c_buf[..=i]
.iter()
.try_for_each(|b| write!(f, "%{:02X}", b))?;
break 'outer_loop;
}
}
}
// All expected bytes are present; try to assemble them into a character.
match into_char_trusted(&c_buf[..expected_char_len]) {
Ok(decoded_c) => {
// Decode only characters that are unreserved under the spec `S`;
// anything else stays percent-encoded.
if is_unreserved::<S>(decoded_c) {
f.write_char(decoded_c)?;
} else {
c_buf[0..expected_char_len]
.iter()
.try_for_each(|b| write!(f, "%{:02X}", b))?;
}
}
Err(_) => {
// The bytes do not form a valid character (overlong encoding, surrogate,
// or a value above U+10FFFF). Every non-first byte is a continuation byte,
// so none of them can start a new sequence; the whole buffer can therefore
// be written as triplets without re-examining any of them.
debug_assert!(
c_buf[1..expected_char_len]
.iter()
.copied()
.all(is_utf8_byte_continue),
"all non-first bytes have been confirmed to be UTF-8 continue bytes"
);
// Position `rest` after every triplet held in the buffer. The first triplet
// is already stripped from `after_first_triplet`, and each remaining one
// occupies three bytes (`%` plus two hex digits).
rest = &after_first_triplet[((expected_char_len - 1) * 3)..];
c_buf[0..expected_char_len]
.iter()
.try_for_each(|b| write!(f, "%{:02X}", b))?;
}
}
}
Ok(())
}
}
/// Wrapper whose `Display` implementation writes an ASCII-only host in
/// normalized form.
///
/// * Characters outside of percent-encoded triplets are ASCII-lowercased.
/// * A triplet that decodes to an ASCII unreserved character is written as
///   that character, lowercased.
/// * Any other triplet stays percent-encoded, with uppercase hex digits.
#[derive(Debug, Clone, Copy)]
pub(crate) struct NormalizedAsciiOnlyHost<'a> {
    /// Text to normalize. Every percent-encoded triplet in it must decode to
    /// an ASCII byte, otherwise formatting panics.
    // NOTE(review): the name suggests the text may also carry a port — confirm
    // against callers.
    host_port: &'a str,
}

impl<'a> NormalizedAsciiOnlyHost<'a> {
    /// Creates a wrapper that borrows `host_port`.
    ///
    /// Nothing is checked or computed until the value is formatted.
    #[inline]
    #[must_use]
    pub(crate) fn new(host_port: &'a str) -> Self {
        Self { host_port }
    }
}

impl fmt::Display for NormalizedAsciiOnlyHost<'_> {
    /// Writes the normalized host.
    ///
    /// # Panics
    ///
    /// Panics if a percent-encoded triplet decodes to a non-ASCII byte.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Writes `text` with every ASCII uppercase letter lowercased.
        fn write_lowercased(f: &mut fmt::Formatter<'_>, text: &str) -> fmt::Result {
            for ch in text.chars() {
                f.write_char(ch.to_ascii_lowercase())?;
            }
            Ok(())
        }

        let mut remaining = self.host_port;
        while !remaining.is_empty() {
            // Split around the next `%`; when there is none, the remainder is
            // plain text and finishes the output.
            let (literal, hex_and_tail) =
                if let Some(parts) = find_split_hole(remaining, b'%') {
                    parts
                } else {
                    return write_lowercased(f, remaining);
                };
            write_lowercased(f, literal)?;

            // Decode the triplet and step over it.
            let (octet, tail) = take_xdigits2(hex_and_tail);
            remaining = tail;

            assert!(
                octet.is_ascii(),
                "this function requires ASCII-only host as an argument"
            );

            if !is_ascii_unreserved(octet) {
                // Must stay encoded; only the case of the hex digits changes.
                write!(f, "%{:02X}", octet)?;
            } else {
                f.write_char(char::from(octet.to_ascii_lowercase()))?;
            }
        }

        Ok(())
    }
}
#[cfg(test)]
#[cfg(feature = "alloc")]
mod tests {
    use super::*;

    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::string::ToString;

    use crate::spec::{IriSpec, UriSpec};

    /// Asserts the normalized form of `input` under both specs: `uri` is the
    /// expected output for `UriSpec`, `iri` the expected output for `IriSpec`.
    #[track_caller]
    fn assert_normalized(input: &str, uri: &str, iri: &str) {
        assert_eq!(PctCaseNormalized::<UriSpec>::new(input).to_string(), uri);
        assert_eq!(PctCaseNormalized::<IriSpec>::new(input).to_string(), iri);
    }

    /// Bytes that never form valid UTF-8 stay encoded; only the hex case changes.
    #[test]
    fn invalid_utf8() {
        assert_normalized("%80%cc%cc%cc", "%80%CC%CC%CC", "%80%CC%CC%CC");
    }

    /// A non-ASCII character is decoded only under the IRI spec.
    #[test]
    fn iri_unreserved() {
        assert_normalized("%ce%b1", "%CE%B1", "\u{03B1}");
    }

    /// A valid sequence surrounded by stray bytes is decoded on its own.
    #[test]
    fn iri_middle_decode() {
        assert_normalized("%ce%ce%b1%b1", "%CE%CE%B1%B1", "%CE\u{03B1}%B1");
    }

    /// A reserved ASCII character stays encoded.
    #[test]
    fn ascii_reserved() {
        assert_normalized("%3f", "%3F", "%3F");
    }

    /// ASCII characters that are not unreserved stay encoded.
    #[test]
    fn ascii_forbidden() {
        assert_normalized("%3c%3e", "%3C%3E", "%3C%3E");
    }

    /// An unreserved ASCII character is decoded under both specs.
    #[test]
    fn ascii_unreserved() {
        assert_normalized("%7ea", "~a", "~a");
    }
}