use std::borrow::Cow;
use std::convert::{Infallible, TryFrom};
use std::error::Error;
use std::fmt::{self, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::str;
use crate::utility::{
get_percent_encoded_value, normalize_string, percent_encoded_equality, percent_encoded_hash,
UNRESERVED_CHAR_MAP,
};
/// Lookup table for the bytes that may appear literally in a query.
///
/// Indexed by byte value: an allowed byte maps to itself, every other byte maps to `0`.
/// The allowed set is the RFC 3986 `query` alphabet:
/// - ASCII letters and digits, and `-` `.` `_` `~`;
/// - the sub-delimiters `!` `$` `&` `'` `(` `)` `*` `+` `,` `;` `=`;
/// - `:` `@` `/` `?`.
///
/// It also contains `%`, which `parse_query` handles as the start of a percent-encoding.
/// `#` and every byte from 0x80 upward map to `0`, so any byte accepted through this table is
/// ASCII.
#[rustfmt::skip]
const QUERY_CHAR_MAP: [u8; 256] = [
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x00-0x0F: control bytes
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x10-0x1F: control bytes
    0, b'!', 0, 0, b'$', b'%', b'&',b'\'', b'(', b')', b'*', b'+', b',', b'-', b'.', b'/', // 0x20-0x2F: no space, '"', '#'
    b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b':', b';', 0, b'=', 0, b'?', // 0x30-0x3F: no '<', '>'
    b'@', b'A', b'B', b'C', b'D', b'E', b'F', b'G', b'H', b'I', b'J', b'K', b'L', b'M', b'N', b'O', // 0x40-0x4F
    b'P', b'Q', b'R', b'S', b'T', b'U', b'V', b'W', b'X', b'Y', b'Z', 0, 0, 0, 0, b'_', // 0x50-0x5F: no '[', '\\', ']', '^'
    0, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', b'm', b'n', b'o', // 0x60-0x6F: no '`'
    b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', b'z', 0, 0, 0, b'~', 0, // 0x70-0x7F: no '{', '|', '}', DEL
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x80-0xFF: non-ASCII, never allowed literally
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
];
/// The query component of a URI, as accepted by `parse_query`.
///
/// The text is either borrowed from the parsed input or owned. It is stored as parsed, with
/// percent-encodings left encoded, until `normalize` rewrites it. The `PartialEq` and `Hash`
/// impls in this module look only at the text; the `normalized` flag takes no part in them.
///
/// NOTE(review): with the `serde` feature, the derived `Deserialize` builds a `Query` straight
/// from the input, including the `normalized` flag, without `parse_query`'s validation.
/// `normalize` later passes that text to the `unsafe` `normalize_string`. Confirm this is sound.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Query<'query> {
    /// Whether `query` is already in normalized form.
    normalized: bool,
    /// The query text, borrowed from the parsed input or owned.
    query: Cow<'query, str>,
}
impl Query<'_> {
    /// Returns a `Query` that borrows this query's text.
    ///
    /// Works whether this query is itself borrowed or owned. The normalization flag is copied.
    pub fn as_borrowed(&self) -> Query<'_> {
        Query {
            normalized: self.normalized,
            query: Cow::Borrowed(&*self.query),
        }
    }

    /// Returns the query text as currently stored; no percent-decoding is performed.
    pub fn as_str(&self) -> &str {
        &self.query
    }

    /// Converts this query into one that owns its text, copying the text if it was borrowed.
    pub fn into_owned(self) -> Query<'static> {
        Query {
            normalized: self.normalized,
            query: Cow::Owned(self.query.into_owned()),
        }
    }

    /// Returns whether the query is known to be in normalized form.
    ///
    /// This is `true` after `normalize` has run, or when `parse_query` found nothing that
    /// normalization would change.
    pub fn is_normalized(&self) -> bool {
        self.normalized
    }

    /// Normalizes the query text in place, unless it is already normalized.
    ///
    /// The work is done by `normalize_string`; it is expected, for example, to turn `%ff` into
    /// `%FF` and `%41` into `A`. A borrowed query is first copied into an owned `String`
    /// (`Cow::to_mut`), so the original input is never modified.
    pub fn normalize(&mut self) {
        if !self.normalized {
            // SAFETY: `normalize_string` is an `unsafe` helper from `crate::utility` whose contract
            // is not visible here.
            // NOTE(review): presumably it requires text already validated by `parse_query`, i.e.
            // ASCII query bytes and well-formed percent-encodings. That holds for values built by
            // `parse_query`. Confirm it also holds for values produced by the serde
            // `Deserialize` derive, which skips that validation.
            unsafe { normalize_string(self.query.to_mut(), true) };
            self.normalized = true;
        }
    }
}
/// Exposes the raw bytes of the query text, exactly as stored.
impl AsRef<[u8]> for Query<'_> {
    fn as_ref(&self) -> &[u8] {
        let text: &str = &self.query;
        text.as_bytes()
    }
}
/// Exposes the query text as a string slice, exactly as stored.
impl AsRef<str> for Query<'_> {
    fn as_ref(&self) -> &str {
        &*self.query
    }
}
/// Lets a `Query` be used wherever a `&str` of its text is expected.
impl Deref for Query<'_> {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        self.query.as_ref()
    }
}
/// Writes the query text exactly as stored, without decoding.
///
/// Output goes through `Formatter::write_str`, so width, fill and precision flags are not
/// applied.
impl Display for Query<'_> {
    fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
        formatter.write_str(self.query.as_ref())
    }
}
// Marker impl: equality between queries is the `PartialEq` impl in this module, which compares
// the query text with `percent_encoded_equality`.
// NOTE(review): this relies on `percent_encoded_equality` being reflexive, symmetric and
// transitive.
impl Eq for Query<'_> {}
/// Converts a `Query` into its query text, exactly as stored.
///
/// This takes the underlying `Cow` directly instead of formatting through `Display`. An owned
/// query hands over its existing `String` without copying; a borrowed query is copied once into
/// a new `String`.
impl<'query> From<Query<'query>> for String {
    fn from(value: Query<'query>) -> Self {
        value.query.into_owned()
    }
}
/// Hashes only the query text, via `percent_encoded_hash`; the `normalized` flag is not hashed.
///
/// It passes the same `true` third argument that the `PartialEq` impls pass to
/// `percent_encoded_equality`. Presumably this keeps hashing consistent with equality.
impl Hash for Query<'_> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let text_bytes = self.query.as_bytes();
        percent_encoded_hash(text_bytes, state, true);
    }
}
/// Two queries are equal when `percent_encoded_equality` considers their text equal.
///
/// The `normalized` flag is ignored.
impl PartialEq for Query<'_> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.query.as_bytes(), other.query.as_bytes());
        percent_encoded_equality(lhs, rhs, true)
    }
}
// Comparisons between a `Query` and raw bytes, in both directions and for both `[u8]` and
// `&[u8]`. Each forwards to `percent_encoded_equality`, passing the two sides in the same
// left/right order as the operands of `==`.

impl PartialEq<[u8]> for Query<'_> {
    fn eq(&self, other: &[u8]) -> bool {
        let own = self.query.as_bytes();
        percent_encoded_equality(own, other, true)
    }
}

impl<'query> PartialEq<Query<'query>> for [u8] {
    fn eq(&self, other: &Query<'query>) -> bool {
        let theirs = other.query.as_bytes();
        percent_encoded_equality(self, theirs, true)
    }
}

impl<'a> PartialEq<&'a [u8]> for Query<'_> {
    fn eq(&self, other: &&'a [u8]) -> bool {
        let own = self.query.as_bytes();
        percent_encoded_equality(own, *other, true)
    }
}

impl<'a, 'query> PartialEq<Query<'query>> for &'a [u8] {
    fn eq(&self, other: &Query<'query>) -> bool {
        let theirs = other.query.as_bytes();
        percent_encoded_equality(*self, theirs, true)
    }
}
// Comparisons between a `Query` and string slices, in both directions and for both `str` and
// `&str`. Each compares the UTF-8 bytes with `percent_encoded_equality`, passing the two sides
// in the same left/right order as the operands of `==`.

impl PartialEq<str> for Query<'_> {
    fn eq(&self, other: &str) -> bool {
        let (own, theirs) = (self.query.as_bytes(), other.as_bytes());
        percent_encoded_equality(own, theirs, true)
    }
}

impl<'query> PartialEq<Query<'query>> for str {
    fn eq(&self, other: &Query<'query>) -> bool {
        let (own, theirs) = (self.as_bytes(), other.query.as_bytes());
        percent_encoded_equality(own, theirs, true)
    }
}

impl<'a> PartialEq<&'a str> for Query<'_> {
    fn eq(&self, other: &&'a str) -> bool {
        let (own, theirs) = (self.query.as_bytes(), (*other).as_bytes());
        percent_encoded_equality(own, theirs, true)
    }
}

impl<'a, 'query> PartialEq<Query<'query>> for &'a str {
    fn eq(&self, other: &Query<'query>) -> bool {
        let (own, theirs) = ((*self).as_bytes(), other.query.as_bytes());
        percent_encoded_equality(own, theirs, true)
    }
}
/// Parses an entire byte slice as a query, borrowing from it.
///
/// # Errors
///
/// Returns any error reported by `parse_query`. Returns `QueryError::InvalidCharacter` when
/// `parse_query` stops before the end of the input (it stops early only at a `#`).
impl<'query> TryFrom<&'query [u8]> for Query<'query> {
    type Error = QueryError;

    fn try_from(value: &'query [u8]) -> Result<Self, Self::Error> {
        match parse_query(value)? {
            (query, remainder) if remainder.is_empty() => Ok(query),
            _ => Err(QueryError::InvalidCharacter),
        }
    }
}
impl<'query> TryFrom<&'query str> for Query<'query> {
type Error = QueryError;
fn try_from(value: &'query str) -> Result<Self, Self::Error> {
Query::try_from(value.as_bytes())
}
}
/// Errors that can occur while parsing a query.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum QueryError {
    /// A byte is not allowed in a query, or parsing stopped at a `#` before the end of the
    /// input.
    InvalidCharacter,
    /// A `%` was not followed by a valid two-byte percent-encoding. This includes input that
    /// ends before two bytes are available.
    InvalidPercentEncoding,
}

impl Display for QueryError {
    fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
        use self::QueryError::*;

        match self {
            InvalidCharacter => write!(formatter, "invalid query character"),
            InvalidPercentEncoding => write!(formatter, "invalid query percent encoding"),
        }
    }
}

impl Error for QueryError {}

/// Allows `?` on infallible conversions inside functions that return `QueryError`.
impl From<Infallible> for QueryError {
    fn from(value: Infallible) -> Self {
        // `Infallible` has no values, so this can never run. The empty match proves that to the
        // compiler instead of inventing an arbitrary error variant.
        match value {}
    }
}
/// Parses a query from the start of `value`.
///
/// Bytes are consumed until the end of the input or the first `#`, whichever comes first. Each
/// consumed byte must either be allowed by `QUERY_CHAR_MAP` or be a `%` that begins a
/// percent-encoding. The two bytes after a `%` must be accepted by `get_percent_encoded_value`.
///
/// Returns the parsed query, borrowing from `value`, together with the unconsumed remainder.
/// The remainder is either empty or starts with `#`. The query is marked as not normalized if
/// any percent-encoding is not reported as uppercase, or if it encodes a byte marked in
/// `UNRESERVED_CHAR_MAP`.
///
/// # Errors
///
/// - `QueryError::InvalidCharacter` for a byte that is neither allowed nor `#`.
/// - `QueryError::InvalidPercentEncoding` when `get_percent_encoded_value` rejects the bytes
///   after a `%`, including when fewer than two remain.
pub(crate) fn parse_query(value: &[u8]) -> Result<(Query<'_>, &[u8]), QueryError> {
    let mut bytes = value.iter().copied();
    let mut end_index = 0;
    let mut normalized = true;

    while let Some(byte) = bytes.next() {
        match QUERY_CHAR_MAP[byte as usize] {
            0 if byte == b'#' => break,
            0 => return Err(QueryError::InvalidCharacter),
            b'%' => {
                // Arguments are evaluated left to right, so this passes the first and then the
                // second byte after the `%`. `None` is passed if the input ends early.
                let (hex_value, uppercase) = get_percent_encoded_value(bytes.next(), bytes.next())
                    .map_err(|_| QueryError::InvalidPercentEncoding)?;

                if !uppercase || UNRESERVED_CHAR_MAP[hex_value as usize] != 0 {
                    normalized = false;
                }

                end_index += 3;
            }
            _ => end_index += 1,
        }
    }

    let (value, rest) = value.split_at(end_index);
    // SAFETY: every byte in `value` was accepted by the loop above.
    // - Bytes accepted through `QUERY_CHAR_MAP` are ASCII, because its entries for
    //   0x80..=0xFF are all 0.
    // - The two bytes after each `%` were accepted by `get_percent_encoded_value`, which
    //   presumably admits only ASCII hex digits (NOTE(review): confirm).
    // An all-ASCII byte slice is valid UTF-8.
    let text = unsafe { str::from_utf8_unchecked(value) };
    let query = Query {
        normalized,
        query: Cow::Borrowed(text),
    };

    Ok((query, rest))
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_query_normalize() {
        fn test_case(value: &str, expected: &str) {
            let mut query = Query::try_from(value).unwrap();
            query.normalize();
            assert!(query.is_normalized());
            assert_eq!(query, expected);
        }

        test_case("", "");
        test_case("%ff", "%FF");
        test_case("%41", "A");
    }

    #[test]
    fn test_query_is_normalized() {
        // Text without percent-encodings has nothing to normalize.
        assert!(Query::try_from("").unwrap().is_normalized());
        assert!(Query::try_from("query").unwrap().is_normalized());
        // An encoded unreserved character (`%41` is `A`) needs decoding, so it is not normalized.
        assert!(!Query::try_from("%41").unwrap().is_normalized());
    }

    #[test]
    fn test_query_parse() {
        use self::QueryError::*;

        assert_eq!(Query::try_from("").unwrap(), "");
        assert_eq!(Query::try_from("query").unwrap(), "query");
        assert_eq!(Query::try_from("qUeRy").unwrap(), "qUeRy");
        assert_eq!(Query::try_from("%ff%ff%ff%41").unwrap(), "%ff%ff%ff%41");
        assert_eq!(Query::try_from(" "), Err(InvalidCharacter));
        assert_eq!(Query::try_from("#"), Err(InvalidCharacter));
        assert_eq!(Query::try_from("%"), Err(InvalidPercentEncoding));
        assert_eq!(Query::try_from("%f"), Err(InvalidPercentEncoding));
        assert_eq!(Query::try_from("%zz"), Err(InvalidPercentEncoding));
    }

    #[test]
    fn test_query_parse_rest() {
        // Parsing stops at `#` and hands back everything from the `#` onward.
        let (query, rest) = parse_query(b"a=b#fragment").unwrap();
        assert_eq!(query, "a=b");
        assert_eq!(rest, &b"#fragment"[..]);

        let (query, rest) = parse_query(b"a=b").unwrap();
        assert_eq!(query, "a=b");
        assert!(rest.is_empty());
    }

    #[test]
    fn test_query_conversions() {
        let query = Query::try_from("a=b&c=d").unwrap();
        assert_eq!(query.as_str(), "a=b&c=d");
        assert_eq!(query.to_string(), "a=b&c=d");
        assert_eq!(query.as_borrowed(), query);
        assert_eq!(String::from(query.clone().into_owned()), "a=b&c=d");
    }
}