use std::borrow::Borrow;
use std::hash;
use std::{
cmp::{Ord, Ordering, PartialOrd},
fmt::Display,
};
use std::{convert::TryFrom, fmt, io, str::FromStr};
/// Error returned when a value is not a valid percent-encoded string.
///
/// The rejected input is carried in the public field so the caller can recover it.
#[derive(Debug, Clone, thiserror::Error)]
#[error("invalid percent-encoded string")]
pub struct InvalidPctString<T>(pub T);

impl<T> InvalidPctString<T> {
    /// Transforms the carried input with `f`, keeping the error wrapper.
    pub fn map<U>(self, f: impl FnOnce(T) -> U) -> InvalidPctString<U> {
        let InvalidPctString(input) = self;
        InvalidPctString(f(input))
    }
}

impl<'a, T: ?Sized + ToOwned> InvalidPctString<&'a T> {
    /// Converts a borrowed payload into its owned form (e.g. `&str` into `String`).
    pub fn into_owned(self) -> InvalidPctString<T::Owned> {
        self.map(|borrowed: &'a T| T::to_owned(borrowed))
    }
}
/// Converts an ASCII hexadecimal digit (`0-9`, `A-F` or `a-f`) into its numeric value.
///
/// Any other byte is rejected with [`ByteError::InvalidByte`].
#[inline]
fn to_digit(b: u8) -> Result<u8, ByteError> {
    match b {
        b'0'..=b'9' => Ok(b - b'0'),
        b'A'..=b'F' => Ok(b - b'A' + 10),
        b'a'..=b'f' => Ok(b - b'a' + 10),
        _ => Err(ByteError::InvalidByte(b)),
    }
}

/// Iterator over the decoded bytes of a percent-encoded byte slice.
///
/// Each `%XX` triplet yields the byte it encodes; every other byte is yielded unchanged.
/// It is only meant to run over already-validated input: a malformed triplet makes it panic.
pub struct Bytes<'a>(std::slice::Iter<'a, u8>);

/// Failure while decoding an untrusted percent-encoded byte sequence.
#[derive(Debug, Clone)]
enum ByteError {
    /// A byte following `%` is not an ASCII hexadecimal digit.
    InvalidByte(u8),
    /// The input ended before both hexadecimal digits following `%` were read.
    IncompleteEncoding,
}

impl From<ByteError> for io::Error {
    fn from(e: ByteError) -> Self {
        io::Error::new(io::ErrorKind::InvalidData, e.to_string())
    }
}

impl Display for ByteError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The byte was rejected by `to_digit`: it is a bad hex digit, not a UTF-8 error.
            ByteError::InvalidByte(b) => {
                write!(f, "Invalid hexadecimal digit in percent-encoding: {:#x}", b)
            }
            ByteError::IncompleteEncoding => f.write_str("Incomplete percent-encoding segment"),
        }
    }
}

impl std::error::Error for ByteError {}

impl<'a> Bytes<'a> {
    /// Reads one hexadecimal digit of a `%XX` triplet.
    ///
    /// # Panics
    ///
    /// Panics if the input ends or the byte is not a hex digit, which cannot happen when the
    /// input was validated beforehand.
    fn next_hex_digit(&mut self) -> u8 {
        const INVARIANT: &str =
            "validated percent-encoding: `%` must be followed by two hexadecimal digits";
        let byte = self.0.next().copied().expect(INVARIANT);
        to_digit(byte).expect(INVARIANT)
    }
}

impl<'a> Iterator for Bytes<'a> {
    type Item = u8;

    fn next(&mut self) -> Option<u8> {
        let byte = self.0.next().copied()?;
        if byte != b'%' {
            return Some(byte);
        }
        let high = self.next_hex_digit();
        let low = self.next_hex_digit();
        Some((high << 4) | low)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.0.len();
        // Every yielded byte consumes either one input byte or a three-byte `%XX` triplet.
        ((remaining + 2) / 3, Some(remaining))
    }
}

// `slice::Iter` is fused, so `Bytes` keeps returning `None` once exhausted.
impl<'a> std::iter::FusedIterator for Bytes<'a> {}

/// Percent-decoding adapter over untrusted bytes: malformed `%XX` triplets are reported as
/// `io::Error`s of kind `InvalidData` instead of causing a panic.
///
/// The source is wrapped in [`std::iter::Fuse`] so that the `FusedIterator` impl below holds
/// even when `B` itself is not fused.
struct UntrustedBytes<B>(std::iter::Fuse<B>);

impl<B: Iterator> UntrustedBytes<B> {
    fn new(bytes: B) -> Self {
        Self(bytes.fuse())
    }
}

impl<B: Iterator<Item = u8>> UntrustedBytes<B> {
    /// Decodes the byte `next` that was just read; a `%` consumes the two following digits.
    fn try_next(&mut self, next: u8) -> io::Result<u8> {
        if next != b'%' {
            return Ok(next);
        }
        let high = self.next_hex_digit()?;
        let low = self.next_hex_digit()?;
        Ok((high << 4) | low)
    }

    /// Reads one hexadecimal digit of a `%XX` triplet, failing if the input ends first.
    fn next_hex_digit(&mut self) -> Result<u8, ByteError> {
        let byte = self.0.next().ok_or(ByteError::IncompleteEncoding)?;
        to_digit(byte)
    }
}

impl<B: Iterator<Item = u8>> Iterator for UntrustedBytes<B> {
    type Item = io::Result<u8>;

    fn next(&mut self) -> Option<io::Result<u8>> {
        let byte = self.0.next()?;
        Some(self.try_next(byte))
    }
}

impl<B: Iterator<Item = u8>> std::iter::FusedIterator for UntrustedBytes<B> {}
/// Iterator over the decoded characters of a percent-encoded string.
///
/// It is built over the decoded [`Bytes`] of an already-validated string, so UTF-8 decoding is
/// expected never to fail; a failure means that invariant was broken and causes a panic.
pub struct Chars<'a> {
    inner: utf8_decode::Decoder<Bytes<'a>>,
}

impl<'a> Chars<'a> {
    fn new(bytes: Bytes<'a>) -> Self {
        Self {
            inner: utf8_decode::Decoder::new(bytes),
        }
    }
}

impl<'a> Iterator for Chars<'a> {
    type Item = char;

    fn next(&mut self) -> Option<char> {
        self.inner.next().map(|decoded| {
            decoded.expect("validated percent-encoding: decoded bytes must be valid UTF-8")
        })
    }
}

impl<'a> std::iter::FusedIterator for Chars<'a> {}
/// Borrowed percent-encoded string slice.
///
/// Invariants, checked by [`PctStr::new`] and required from callers of
/// [`PctStr::new_unchecked`]:
/// - the raw bytes are valid UTF-8 (relied upon by [`PctStr::as_str`]);
/// - every `%` is followed by two hexadecimal digits;
/// - the decoded bytes are valid UTF-8 (relied upon by [`PctStr::chars`]).
///
/// `#[repr(transparent)]` guarantees `PctStr` has exactly the layout of `[u8]`, which is what
/// makes the reference cast in [`PctStr::new_unchecked`] sound.
#[repr(transparent)]
pub struct PctStr([u8]);

impl PctStr {
    /// Validates `input` and wraps it as a percent-encoded string slice.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidPctString`] holding `input` if its raw bytes are not valid UTF-8, a `%`
    /// is not followed by two hexadecimal digits, or the decoded bytes are not valid UTF-8.
    pub fn new<S: AsRef<[u8]> + ?Sized>(input: &S) -> Result<&PctStr, InvalidPctString<&S>> {
        let input_bytes = input.as_ref();
        // `as_str` exposes the raw bytes as `&str`, so they must be UTF-8 on their own;
        // `validate` only checks the *decoded* bytes (`b"%C3\xA9"` decodes fine but is not UTF-8).
        if std::str::from_utf8(input_bytes).is_ok() && Self::validate(input_bytes.iter().copied())
        {
            // SAFETY: both the raw and the decoded invariants were just checked.
            Ok(unsafe { Self::new_unchecked(input_bytes) })
        } else {
            Err(InvalidPctString(input))
        }
    }

    /// Wraps `input` without any validation.
    ///
    /// # Safety
    ///
    /// `input` must be valid UTF-8, every `%` in it must be followed by two hexadecimal digits,
    /// and its decoded bytes must be valid UTF-8. Other methods rely on this without checking.
    pub unsafe fn new_unchecked<S: AsRef<[u8]> + ?Sized>(input: &S) -> &PctStr {
        // SAFETY: `PctStr` is `#[repr(transparent)]` over `[u8]`, so the cast keeps the layout
        // and the slice length metadata; content validity is the caller's obligation.
        &*(input.as_ref() as *const [u8] as *const PctStr)
    }

    /// Checks that `input` is well-formed percent-encoding whose decoded bytes are valid UTF-8.
    ///
    /// This does NOT check that the raw bytes themselves are valid UTF-8, which
    /// [`PctStr::new_unchecked`] also requires; [`PctStr::new`] checks both.
    pub fn validate(input: impl Iterator<Item = u8>) -> bool {
        let decoded = UntrustedBytes::new(input);
        utf8_decode::UnsafeDecoder::new(decoded).all(|r| r.is_ok())
    }

    /// Number of decoded characters.
    ///
    /// Computed in linear time. This is not the encoded byte length, which is
    /// `as_bytes().len()`.
    #[inline]
    pub fn len(&self) -> usize {
        self.chars().count()
    }

    /// Returns `true` if the encoded string is empty (equivalently, it decodes to nothing).
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Raw, still percent-encoded bytes.
    #[inline]
    pub fn as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Raw, still percent-encoded text.
    #[inline]
    pub fn as_str(&self) -> &str {
        // SAFETY: the raw bytes are valid UTF-8 per the type invariant (checked by `new`,
        // promised by callers of `new_unchecked`).
        unsafe { std::str::from_utf8_unchecked(&self.0) }
    }

    /// Iterator over the decoded characters.
    #[inline]
    pub fn chars(&self) -> Chars<'_> {
        Chars::new(self.bytes())
    }

    /// Iterator over the decoded bytes.
    #[inline]
    pub fn bytes(&self) -> Bytes<'_> {
        Bytes(self.0.iter())
    }

    /// Decodes the string into a new `String`.
    pub fn decode(&self) -> String {
        // A `%XX` triplet decodes to one byte and every other byte maps to itself, so the
        // encoded length is an upper bound on the decoded length (no reallocation, no extra pass).
        let mut decoded = String::with_capacity(self.0.len());
        decoded.extend(self.chars());
        decoded
    }
}
/// Two `PctStr`s are equal when they decode to the same characters, so `%41` equals `A`.
impl PartialEq for PctStr {
    #[inline]
    fn eq(&self, other: &PctStr) -> bool {
        self.chars().eq(other.chars())
    }
}

impl Eq for PctStr {}

/// Compares the decoded characters of `self` with the characters of `other` taken literally
/// (not decoded): `%41` equals `"A"` but not `"%41"`.
impl PartialEq<str> for PctStr {
    #[inline]
    fn eq(&self, other: &str) -> bool {
        self.chars().eq(other.chars())
    }
}

/// Compares the decoded characters of both sides.
impl PartialEq<PctString> for PctStr {
    #[inline]
    fn eq(&self, other: &PctString) -> bool {
        self.chars().eq(other.chars())
    }
}
impl PartialOrd for PctStr {
    #[inline]
    fn partial_cmp(&self, other: &PctStr) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Lexicographic order of the decoded characters; a proper prefix sorts before the longer
/// string.
impl Ord for PctStr {
    fn cmp(&self, other: &PctStr) -> Ordering {
        self.chars().cmp(other.chars())
    }
}

impl PartialOrd<PctString> for PctStr {
    fn partial_cmp(&self, other: &PctString) -> Option<Ordering> {
        self.partial_cmp(other.as_pct_str())
    }
}
/// Hashes the decoded characters one by one, which keeps hashing consistent with
/// `PartialEq for PctStr`: values that decode alike (e.g. `%41` and `A`) hash alike.
impl hash::Hash for PctStr {
    #[inline]
    fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
        self.chars().for_each(|ch| hash::Hash::hash(&ch, hasher));
    }
}
/// Writes the raw, still percent-encoded text using `str`'s `Display`, so formatter flags
/// such as width and fill apply.
impl fmt::Display for PctStr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let raw: &str = self.as_str();
        <str as fmt::Display>::fmt(raw, f)
    }
}

/// Formats the raw, still percent-encoded text with `str`'s `Debug` output.
impl fmt::Debug for PctStr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let raw: &str = self.as_str();
        <str as fmt::Debug>::fmt(raw, f)
    }
}
/// Copies the raw bytes into an owned [`PctString`].
impl ToOwned for PctStr {
    type Owned = PctString;

    fn to_owned(&self) -> PctString {
        let bytes: Vec<u8> = self.as_bytes().to_vec();
        // SAFETY: the bytes are copied verbatim from a `PctStr`, so its invariants carry over.
        unsafe { PctString::new_unchecked(bytes) }
    }
}
impl Borrow<str> for PctStr {
fn borrow(&self) -> &str {
self.as_str()
}
}
impl AsRef<str> for PctStr {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl Borrow<[u8]> for PctStr {
fn borrow(&self) -> &[u8] {
self.as_bytes()
}
}
impl AsRef<[u8]> for PctStr {
fn as_ref(&self) -> &[u8] {
self.as_bytes()
}
}
/// Decides which characters [`PctString::encode`] must percent-encode.
///
/// Note that `PctString::encode` always encodes `%` regardless of the encoder's answer.
pub trait Encoder {
    /// Returns `true` if `c` must be written as `%XX` triplets.
    fn encode(&self, c: char) -> bool;
}

/// Any `Fn(char) -> bool` (closure or function) can be used directly as an [`Encoder`].
impl<F: Fn(char) -> bool> Encoder for F {
    fn encode(&self, c: char) -> bool {
        (*self)(c)
    }
}
/// Owned percent-encoded string.
///
/// It upholds the same invariants as the [`PctStr`] it dereferences to: the raw bytes are valid
/// UTF-8, every `%` is followed by two hexadecimal digits, and the decoded bytes are valid UTF-8.
pub struct PctString(Vec<u8>);

impl PctString {
    /// Validates `bytes` and wraps them as an owned percent-encoded string.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidPctString`] holding the rejected bytes if they are not valid UTF-8,
    /// contain a malformed `%XX` triplet, or decode to invalid UTF-8.
    pub fn new<B: Into<Vec<u8>>>(bytes: B) -> Result<Self, InvalidPctString<Vec<u8>>> {
        let bytes = bytes.into();
        // `as_str`/`into_string` expose the raw bytes as text, so they must be UTF-8 on their
        // own; `PctStr::validate` only checks the *decoded* bytes (`b"%C3\xA9"` passes it).
        if std::str::from_utf8(&bytes).is_ok() && PctStr::validate(bytes.iter().copied()) {
            Ok(Self(bytes))
        } else {
            Err(InvalidPctString(bytes))
        }
    }

    /// Validates `string` and wraps it as an owned percent-encoded string.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidPctString`] holding the original `string` if it is not valid
    /// percent-encoding.
    pub fn from_string(string: String) -> Result<Self, InvalidPctString<String>> {
        Self::new(string).map_err(|e| {
            e.map(|bytes| {
                // SAFETY: `bytes` is the untouched buffer of `string`, which is valid UTF-8.
                unsafe { String::from_utf8_unchecked(bytes) }
            })
        })
    }

    /// Wraps `bytes` without any validation.
    ///
    /// # Safety
    ///
    /// `bytes` must be valid UTF-8, every `%` in it must be followed by two hexadecimal digits,
    /// and the decoded bytes must be valid UTF-8. Other methods rely on this without checking.
    pub unsafe fn new_unchecked<B: Into<Vec<u8>>>(bytes: B) -> Self {
        Self(bytes.into())
    }

    /// Percent-encodes the characters of `src`.
    ///
    /// Every character for which `encoder` returns `true`, and every `%` (a literal `%` would
    /// otherwise be read back as the start of a triplet), is written as one `%XX` triplet per
    /// UTF-8 byte using uppercase hexadecimal digits. All other characters are copied unchanged.
    pub fn encode<E: Encoder>(src: impl Iterator<Item = char>, encoder: E) -> PctString {
        const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
        let mut encoded = String::with_capacity(src.size_hint().0);
        // Scratch space for the UTF-8 form of one char, reused across iterations.
        let mut utf8 = [0u8; 4];
        for c in src {
            if encoder.encode(c) || c == '%' {
                for &byte in c.encode_utf8(&mut utf8).as_bytes() {
                    encoded.push('%');
                    encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                    encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
                }
            } else {
                encoded.push(c);
            }
        }
        PctString(encoded.into_bytes())
    }

    /// Borrows this string as a [`PctStr`].
    #[inline]
    pub fn as_pct_str(&self) -> &PctStr {
        // SAFETY: `self.0` satisfies `PctStr`'s invariants, checked by `new` or promised by the
        // caller of `new_unchecked`.
        unsafe { PctStr::new_unchecked(&self.0) }
    }

    /// Returns the raw, still percent-encoded text as a `String` (no decoding happens).
    #[inline]
    pub fn into_string(self) -> String {
        // SAFETY: the raw bytes are valid UTF-8 per the type invariant.
        unsafe { String::from_utf8_unchecked(self.0) }
    }

    /// Returns the raw, still percent-encoded bytes.
    #[inline]
    pub fn into_bytes(self) -> Vec<u8> {
        self.0
    }
}
impl std::ops::Deref for PctString {
type Target = PctStr;
#[inline]
fn deref(&self) -> &PctStr {
self.as_pct_str()
}
}
impl Borrow<PctStr> for PctString {
fn borrow(&self) -> &PctStr {
self.as_pct_str()
}
}
impl AsRef<PctStr> for PctString {
fn as_ref(&self) -> &PctStr {
self.as_pct_str()
}
}
impl Borrow<str> for PctString {
fn borrow(&self) -> &str {
self.as_str()
}
}
impl AsRef<str> for PctString {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl Borrow<[u8]> for PctString {
fn borrow(&self) -> &[u8] {
self.as_bytes()
}
}
impl AsRef<[u8]> for PctString {
fn as_ref(&self) -> &[u8] {
self.as_bytes()
}
}
/// Two `PctString`s are equal when they decode to the same characters.
impl PartialEq for PctString {
    #[inline]
    fn eq(&self, other: &PctString) -> bool {
        self.chars().eq(other.chars())
    }
}

impl Eq for PctString {}

/// Compares the decoded characters of both sides.
impl PartialEq<PctStr> for PctString {
    #[inline]
    fn eq(&self, other: &PctStr) -> bool {
        self.chars().eq(other.chars())
    }
}

/// Compares the decoded characters of `self` with the characters of `other` taken literally.
impl PartialEq<&str> for PctString {
    #[inline]
    fn eq(&self, other: &&str) -> bool {
        self.chars().eq(other.chars())
    }
}

/// Compares the decoded characters of `self` with the characters of `other` taken literally.
impl PartialEq<str> for PctString {
    #[inline]
    fn eq(&self, other: &str) -> bool {
        self.chars().eq(other.chars())
    }
}
/// Orders by decoded characters, delegating to the [`PctStr`] ordering.
///
/// NOTE(review): `PctString` is `Eq` but has no `Ord` impl, so it cannot be sorted with
/// `sort()` or used as a `BTreeMap` key; adding one was deferred because it could change method
/// resolution of existing `pct_string.cmp(&pct_str)` calls.
impl PartialOrd for PctString {
    fn partial_cmp(&self, other: &PctString) -> Option<Ordering> {
        PartialOrd::partial_cmp(self.as_pct_str(), other.as_pct_str())
    }
}

impl PartialOrd<PctStr> for PctString {
    fn partial_cmp(&self, other: &PctStr) -> Option<Ordering> {
        (**self).partial_cmp(other)
    }
}
/// Hashes exactly like the borrowed [`PctStr`] (decoded characters), as required by
/// `Borrow<PctStr> for PctString`.
impl hash::Hash for PctString {
    #[inline]
    fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
        hash::Hash::hash(self.as_pct_str(), hasher)
    }
}
/// Writes the raw, still percent-encoded text using `str`'s `Display` (formatter flags apply).
impl fmt::Display for PctString {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let raw: &str = self.as_str();
        <str as fmt::Display>::fmt(raw, f)
    }
}

/// Formats the raw, still percent-encoded text with `str`'s `Debug` output.
impl fmt::Debug for PctString {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let raw: &str = self.as_str();
        <str as fmt::Debug>::fmt(raw, f)
    }
}
impl FromStr for PctString {
type Err = InvalidPctString<String>;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::from_string(s.to_string())
}
}
impl TryFrom<String> for PctString {
type Error = InvalidPctString<String>;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::from_string(value)
}
}
impl<'a> TryFrom<&'a str> for PctString {
type Error = InvalidPctString<String>;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
Self::from_string(value.to_owned())
}
}
impl<'a> TryFrom<&'a str> for &'a PctStr {
type Error = InvalidPctString<&'a str>;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
PctStr::new(value)
}
}
/// URI encoder that escapes the reserved characters.
///
/// It encodes every character that is not printable ASCII (controls, space, DEL and all
/// non-ASCII characters), the gen-delims and sub-delims listed in RFC 3986 section 2.2, and `%`.
///
/// NOTE(review): other printable ASCII characters that are not allowed unescaped in a URI
/// (`"`, `<`, `>`, `\`, `^`, `` ` ``, `{`, `|`, `}`) are left as is; confirm this is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct URIReserved;

impl Encoder for URIReserved {
    fn encode(&self, c: char) -> bool {
        const RESERVED: &str = "!#$%&'()*+,/:;=?@[]";
        !c.is_ascii_graphic() || RESERVED.contains(c)
    }
}
/// Context-dependent encoder for IRI components (RFC 3987).
///
/// ASCII letters and digits, `-`, `.`, `_`, `~`, `@`, the sub-delims and `ucschar` code points
/// are kept in every context. The variants differ as follows:
/// - `Segment`: `/` and `?` are encoded, `:` is kept;
/// - `SegmentNoColons`: like `Segment` but `:` is encoded as well;
/// - `Fragment`: `/`, `?` and `:` are kept;
/// - `Query`: like `Fragment`, and private-use (`iprivate`) code points are kept too.
///
/// Everything else (including `%`, space and other ASCII punctuation) is encoded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IriReserved {
    Segment,
    SegmentNoColons,
    Fragment,
    Query,
}

impl Encoder for IriReserved {
    fn encode(&self, c: char) -> bool {
        /// `ucschar` ranges of RFC 3987: allowed unencoded in every component.
        fn is_ucschar(cp: u32) -> bool {
            matches!(
                cp,
                0xA0..=0xD7FF
                    | 0xF900..=0xFDCF
                    | 0xFDF0..=0xFFEF
                    | 0x10000..=0x1FFFD
                    | 0x20000..=0x2FFFD
                    | 0x30000..=0x3FFFD
                    | 0x40000..=0x4FFFD
                    | 0x50000..=0x5FFFD
                    | 0x60000..=0x6FFFD
                    | 0x70000..=0x7FFFD
                    | 0x80000..=0x8FFFD
                    | 0x90000..=0x9FFFD
                    | 0xA0000..=0xAFFFD
                    | 0xB0000..=0xBFFFD
                    | 0xC0000..=0xCFFFD
                    | 0xD0000..=0xDFFFD
                    | 0xE1000..=0xEFFFD
            )
        }

        /// `iprivate` ranges of RFC 3987: only allowed unencoded in queries.
        fn is_iprivate(cp: u32) -> bool {
            matches!(cp, 0xE000..=0xF8FF | 0xF0000..=0xFFFFD | 0x100000..=0x10FFFD)
        }

        let context = *self;
        match c {
            'a'..='z' | 'A'..='Z' | '0'..='9' => false,
            '-' | '.' | '_' | '~' | '@' => false,
            '!' | '$' | '&' | '\'' | '(' | ')' | '*' | '+' | ',' | ';' | '=' => false,
            '/' | '?' => !matches!(context, IriReserved::Query | IriReserved::Fragment),
            ':' => context == IriReserved::SegmentNoColons,
            _ => {
                let cp = u32::from(c);
                if is_ucschar(cp) {
                    false
                } else if is_iprivate(cp) {
                    context != IriReserved::Query
                } else {
                    true
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::convert::TryInto;
    use super::*;

    #[test]
    fn iri_encode_cyrillic() {
        let encoder = IriReserved::Segment;
        let pct_string = PctString::encode("традиционное польское блюдо".chars(), encoder);
        assert_eq!(&pct_string, &"традиционное польское блюдо");
        assert_eq!(&pct_string.as_str(), &"традиционное%20польское%20блюдо");
    }

    #[test]
    fn iri_encode_segment() {
        let encoder = IriReserved::Segment;
        let pct_string = PctString::encode(
            "?test=традиционное польское блюдо&cjk=真正&private=\u{10FFFD}".chars(),
            encoder,
        );
        assert_eq!(
            &pct_string,
            &"?test=традиционное польское блюдо&cjk=真正&private=\u{10FFFD}"
        );
        assert_eq!(
            &pct_string.as_str(),
            &"%3Ftest=традиционное%20польское%20блюдо&cjk=真正&private=%F4%8F%BF%BD"
        );
    }

    #[test]
    fn iri_encode_segment_nocolon() {
        let encoder = IriReserved::SegmentNoColons;
        let pct_string = PctString::encode(
            "?test=традиционное польское блюдо&cjk=真正&private=\u{10FFFD}".chars(),
            encoder,
        );
        assert_eq!(
            &pct_string,
            &"?test=традиционное польское блюдо&cjk=真正&private=\u{10FFFD}"
        );
        assert_eq!(
            &pct_string.as_str(),
            &"%3Ftest=традиционное%20польское%20блюдо&cjk=真正&private=%F4%8F%BF%BD"
        );
    }

    #[test]
    fn iri_encode_fragment() {
        let encoder = IriReserved::Fragment;
        let pct_string = PctString::encode(
            "?test=традиционное польское блюдо&cjk=真正&private=\u{10FFFD}".chars(),
            encoder,
        );
        assert_eq!(
            &pct_string,
            &"?test=традиционное польское блюдо&cjk=真正&private=\u{10FFFD}"
        );
        assert_eq!(
            &pct_string.as_str(),
            &"?test=традиционное%20польское%20блюдо&cjk=真正&private=%F4%8F%BF%BD"
        );
    }

    #[test]
    fn iri_encode_query() {
        let encoder = IriReserved::Query;
        let pct_string = PctString::encode(
            "?test=традиционное польское блюдо&cjk=真正&private=\u{10FFFD}".chars(),
            encoder,
        );
        assert_eq!(
            &pct_string,
            &"?test=традиционное польское блюдо&cjk=真正&private=\u{10FFFD}"
        );
        assert_eq!(
            &pct_string.as_str(),
            &"?test=традиционное%20польское%20блюдо&cjk=真正&private=\u{10FFFD}"
        );
    }

    #[test]
    fn uri_encode_cyrillic() {
        let encoder = URIReserved;
        let pct_string = PctString::encode("традиционное польское блюдо\0".chars(), encoder);
        assert_eq!(&pct_string, &"традиционное польское блюдо\0");
        assert_eq!(&pct_string.as_str(), &"%D1%82%D1%80%D0%B0%D0%B4%D0%B8%D1%86%D0%B8%D0%BE%D0%BD%D0%BD%D0%BE%D0%B5%20%D0%BF%D0%BE%D0%BB%D1%8C%D1%81%D0%BA%D0%BE%D0%B5%20%D0%B1%D0%BB%D1%8E%D0%B4%D0%BE%00");
    }

    #[test]
    fn pct_encoding_invalid() {
        // Decodes to bytes that are not valid UTF-8.
        let s = "%FF%FE%20%4F";
        assert!(PctStr::new(s).is_err());
        // Truncated triplet at the end of the input.
        let s = "%36%A";
        assert!(PctStr::new(s).is_err());
        // `%` followed by a non-hexadecimal digit.
        let s = "%%32";
        assert!(PctStr::new(s).is_err());
        // Second digit of the triplet is not hexadecimal.
        let s = "%4G";
        assert!(PctStr::new(s).is_err());
        // Lone `%` with no digits at all.
        let s = "%";
        assert!(PctStr::new(s).is_err());
    }

    #[test]
    fn pct_encoding_valid() {
        let s = "%00%5C%F4%8F%BF%BD%69";
        assert!(PctStr::new(s).is_ok());
        let s = "No percent.";
        assert!(PctStr::new(s).is_ok());
        let s = "%e2%82%acwat";
        assert!(PctStr::new(s).is_ok());
    }

    #[test]
    fn try_from() {
        let s = "%00%5C%F4%8F%BF%BD%69";
        let _pcs = PctString::try_from(s).unwrap();
        let _pcs: &PctStr = s.try_into().unwrap();
    }

    #[test]
    fn encode_percent_always() {
        struct NoopEncoder;
        impl Encoder for NoopEncoder {
            fn encode(&self, _: char) -> bool {
                false
            }
        }
        let s = "%";
        let c = PctString::encode(s.chars(), NoopEncoder);
        assert_eq!(c.as_str(), "%25");
    }

    /// Equality, hashing and ordering all work on decoded characters, and owned/borrowed
    /// forms hash identically (required by `Borrow<PctStr> for PctString`).
    #[test]
    fn decoded_comparison_and_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        fn hash_of<T: Hash + ?Sized>(value: &T) -> u64 {
            let mut hasher = DefaultHasher::new();
            value.hash(&mut hasher);
            hasher.finish()
        }

        let encoded = PctStr::new("%41b").unwrap();
        let plain = PctStr::new("Ab").unwrap();
        assert_eq!(encoded, plain);
        assert_eq!(hash_of(encoded), hash_of(plain));
        assert_eq!(hash_of(&encoded.to_owned()), hash_of(encoded));
        assert_eq!(encoded.decode(), "Ab");
        assert_eq!(encoded.len(), 2);
        assert!(PctStr::new("a").unwrap() < PctStr::new("%62").unwrap());
        assert!(PctStr::new("ab").unwrap() > PctStr::new("a").unwrap());
    }
}