use std::borrow::{Borrow, Cow};
use std::error::Error;
use std::fmt;
use std::fmt::{Debug, Display};
use std::net::IpAddr;
use std::num::{NonZeroU8, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU128, NonZeroUsize};
use pct_str::{PctString, UriReserved};
#[cfg(feature = "uuid")]
use uuid::Uuid;
/// Wrapper marking a value whose formatted output is meant to be usable in a URI as-is.
///
/// The field is private, so outside this module a `UriSafe` can only be obtained through the
/// module's `From` impls and `UriSafeString` constructors. Formatting is transparent: both
/// `Display` and `Debug` forward to the wrapped value (with the caller's formatter flags), so
/// `UriSafe(42u32)` prints as `42` either way.
#[derive(Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct UriSafe<T>(T);

impl<T> Display for UriSafe<T>
where
    T: Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let UriSafe(inner) = self;
        Display::fmt(inner, f)
    }
}

impl<T> Debug for UriSafe<T>
where
    T: Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let UriSafe(inner) = self;
        Debug::fmt(inner, f)
    }
}
macro_rules! impl_uri_safe_from {
($($t:ty),*) => {
$(
impl From<$t> for UriSafe<$t> {
fn from(value: $t) -> Self {
Self(value)
}
}
)*
};
}
impl_uri_safe_from!(
usize,
u8,
u16,
u32,
u64,
u128,
NonZeroU8,
NonZeroU16,
NonZeroU32,
NonZeroU64,
NonZeroU128,
NonZeroUsize,
IpAddr
);
#[cfg(feature = "uuid")]
impl_uri_safe_from!(Uuid);
pub type UriSafeString = UriSafe<Cow<'static, str>>;
/// Error returned by `UriSafeString::try_new` when the input is not already URI safe.
///
/// `invalid_char` is either a character outside the accepted alphabet, or the `%` that starts
/// an escape which is not followed by two hexadecimal digits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UriSafeError {
    /// The offending character (`'%'` for a malformed percent escape).
    pub invalid_char: char,
    /// Zero-based character index of `invalid_char` within the input.
    pub position: usize,
}

impl fmt::Display for UriSafeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "invalid character '{}' at position {} for URI safe string",
            self.invalid_char, self.position
        )
    }
}

impl Error for UriSafeError {}
impl UriSafeString {
    /// Percent-encodes `s` into a URI-safe string.
    ///
    /// Characters outside the RFC 3986 unreserved set (`A-Z a-z 0-9 - . _ ~`) are replaced
    /// by `%XX` escapes using `pct_str`'s `UriReserved::Any` encoder. Existing `%XX`
    /// sequences in `s` are not recognised as escapes (the `%` is expected to be escaped
    /// again), so use [`Self::try_new`] for text that is already encoded.
    #[must_use]
    pub fn encode(s: impl AsRef<str>) -> Self {
        // The encoder copies characters that need no escaping verbatim, so its output is the
        // final string in every case; copying `s` separately would be a second allocation of
        // the same bytes.
        let encoded = PctString::encode(s.as_ref().chars(), UriReserved::Any);
        Self(Cow::Owned(encoded.into_string()))
    }

    /// Like [`Self::encode`], but takes ownership of `s` and keeps its allocation when no
    /// character needs escaping.
    #[must_use]
    pub fn encode_owned(s: String) -> Self {
        let encoded = PctString::encode(s.chars(), UriReserved::Any);
        // Escaping only ever lengthens the text (each escaped byte becomes three), so an
        // unchanged length means nothing was escaped and `s` can be kept as-is.
        if encoded.as_str().len() == s.len() {
            Self(Cow::Owned(s))
        } else {
            Self(Cow::Owned(encoded.into_string()))
        }
    }

    /// Validates `raw` as an already URI-safe string, without re-encoding it.
    ///
    /// Accepted input consists only of unreserved ASCII characters (`A-Z a-z 0-9 - . _ ~`)
    /// and `%` followed by two hexadecimal digits (either case).
    ///
    /// # Errors
    ///
    /// Returns [`UriSafeError`] for the first character outside that alphabet. For a `%`
    /// that is not followed by two hex digits, the error reports the `%` and its index.
    pub fn try_new(raw: impl Into<String>) -> Result<Self, UriSafeError> {
        let raw = raw.into();
        let mut chars = raw.chars().enumerate();
        while let Some((position, c)) = chars.next() {
            if c == '%' {
                // Pull both escape digits from the same iterator so they are not examined
                // again as ordinary characters.
                for _ in 0..2 {
                    if !chars.next().is_some_and(|(_, d)| d.is_ascii_hexdigit()) {
                        return Err(UriSafeError {
                            invalid_char: '%',
                            position,
                        });
                    }
                }
                continue;
            }
            if !Self::is_unreserved_char(c) {
                return Err(UriSafeError {
                    invalid_char: c,
                    position,
                });
            }
        }
        Ok(Self(Cow::Owned(raw)))
    }

    /// Returns the URI-safe text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.borrow()
    }

    /// Wraps a string literal that is already URI safe, without allocating.
    ///
    /// Accepts the same alphabet as [`Self::try_new`]: unreserved ASCII characters and `%`
    /// followed by two hex digits. Because this is a `const fn`, it can build constants, and
    /// an invalid literal used in a `const` is then rejected at compile time.
    ///
    /// # Panics
    ///
    /// - `"any reserved characters need to be URL encoded"` for a byte outside the alphabet;
    /// - `"string contains invalid URL encoding character"` when a `%` is followed by a
    ///   non-hex byte;
    /// - `"string contains unfinished URL encoded character"` when the input ends inside an
    ///   escape.
    #[cfg_attr(test, mutants::skip)]
    #[inline]
    #[must_use]
    pub const fn from_static(s: &'static str) -> Self {
        let bytes = s.as_bytes();
        let mut i = 0;
        // `Some(n)` while inside a `%XX` escape, where `n` hex digits have been consumed.
        let mut hex_digits_seen: Option<u8> = None;
        while i < bytes.len() {
            let b = bytes[i];
            i += 1;
            if let Some(seen) = hex_digits_seen {
                assert!(b.is_ascii_hexdigit(), "string contains invalid URL encoding character");
                hex_digits_seen = if seen == 1 { None } else { Some(seen + 1) };
                // An escape digit is fully handled here; it never needs the reserved check.
                continue;
            }
            if b == b'%' {
                hex_digits_seen = Some(0);
                continue;
            }
            // `u8 as char` is lossless; bytes >= 0x80 map to non-ASCII chars and are rejected.
            assert!(
                Self::is_unreserved_char(b as char),
                "any reserved characters need to be URL encoded"
            );
        }
        assert!(hex_digits_seen.is_none(), "string contains unfinished URL encoded character");
        Self(Cow::Borrowed(s))
    }

    /// RFC 3986 "unreserved" characters: ASCII letters, digits, `-`, `_`, `~` and `.`.
    ///
    /// This is the single source of truth for both `try_new` and `from_static`.
    const fn is_unreserved_char(c: char) -> bool {
        c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '~' | '.')
    }
}
impl From<String> for UriSafeString {
fn from(s: String) -> Self {
Self::encode_owned(s)
}
}
impl<'a> From<&'a str> for UriSafeString {
fn from(s: &'a str) -> Self {
Self::encode(s)
}
}
impl AsRef<str> for UriSafeString {
fn as_ref(&self) -> &str {
self.as_str()
}
}
/// Deserializes from a string that must already be URI safe.
///
/// The input is validated with `UriSafeString::try_new` rather than re-encoded, so text
/// containing a reserved character such as `{` is rejected instead of being escaped.
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for UriSafeString {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let raw = <String as serde::Deserialize<'de>>::deserialize(deserializer)?;
        UriSafeString::try_new(raw).map_err(<D::Error as serde::de::Error>::custom)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// RFC 3986 reserved characters (gen-delims and sub-delims) plus `{` and `}`; none of
    /// these may appear unescaped in a URI-safe string.
    const RESERVED_CHARACTERS: &str = "{}/:?#[]@!$&'()*+,;=";

    /// Generates one `should_panic` test per character, checking that `from_static` rejects
    /// it when it appears unescaped.
    macro_rules! test_static_reserved_fail {
        ($(($index:ident, $char:expr)),* $(,)?) => {
            $(
                #[test]
                #[should_panic(expected = "any reserved characters need to be URL encoded")]
                fn $index() {
                    let _ = UriSafeString::from_static(concat!("hello", $char, "world"));
                }
            )*
        };
    }

    #[test]
    fn test_uri_safe_string_creation() {
        let safe = UriSafeString::encode("hello_world");
        assert_eq!(safe.as_ref(), "hello_world");
        for reserved in RESERVED_CHARACTERS.chars() {
            let encoded_str = UriSafeString::encode(format!("hello_{reserved}_world"));
            assert_eq!(encoded_str.to_string(), format!("hello_%{:02X}_world", reserved as u8));
        }
    }

    #[test]
    fn debug_delegates_to_inner() {
        let safe = UriSafeString::encode("hello");
        assert_eq!(format!("{safe:?}"), format!("{:?}", "hello"));
        let safe_num = UriSafe::from(42u32);
        assert_eq!(format!("{safe_num:?}"), "42");
    }

    #[test]
    fn test_uri_safe_string_from_static() {
        const SAFE: UriSafeString = UriSafeString::from_static("hello_world");
        assert_eq!(SAFE.as_str(), "hello_world");
    }

    #[test]
    fn test_from_string_valid() {
        let result = UriSafeString::from("valid_string_123".to_string());
        assert_eq!(result.as_str(), "valid_string_123");
    }

    #[test]
    fn encode_owned_no_encoding_reuses_string() {
        let original = "hello_world".to_string();
        let buffer = original.as_ptr();
        let safe = UriSafeString::encode_owned(original);
        assert_eq!(safe.as_str(), "hello_world");
        // The test name promises reuse, so check that the input allocation was kept.
        assert_eq!(safe.as_str().as_ptr(), buffer, "the input allocation should be reused");
    }

    #[test]
    fn encode_owned_encodes_reserved() {
        let safe = UriSafeString::encode_owned("hello{world}".to_string());
        assert_eq!(safe.as_str(), "hello%7Bworld%7D");
    }

    #[test]
    fn encode_output_is_accepted_by_try_new() {
        // Whatever `encode` produces must itself be valid pre-encoded input.
        for input in ["", "plain", "a b", "100%", "%zz", "%41", "h\u{e9}llo", RESERVED_CHARACTERS] {
            let encoded = UriSafeString::encode(input);
            let validated = UriSafeString::try_new(encoded.as_str())
                .unwrap_or_else(|e| panic!("encode({input:?}) produced invalid output: {e}"));
            assert_eq!(validated, encoded);
        }
    }

    #[test]
    fn test_raw_string_valid() {
        let result = UriSafeString::try_new("valid_string_123".to_string());
        assert_eq!(result.unwrap().as_str(), "valid_string_123");
    }

    #[test]
    fn try_new_accepts_valid_percent_encoded_sequence() {
        let result = UriSafeString::try_new("hello%3Dworld");
        assert!(result.is_ok(), "valid percent-encoded sequence should be accepted");
        assert_eq!(result.unwrap().as_str(), "hello%3Dworld");
    }

    #[test]
    fn percent_escapes_accept_lowercase_hex() {
        assert_eq!(UriSafeString::from_static("a%3db").as_str(), "a%3db");
        assert_eq!(UriSafeString::try_new("a%3db").unwrap().as_str(), "a%3db");
    }

    #[test]
    fn test_try_new_invalid_percent_encoding() {
        let result = UriSafeString::try_new("hello%3world".to_string());
        assert!(result.is_err(), "string with invalid percent encoding should be rejected");
        let err = result.unwrap_err();
        assert_eq!(err.invalid_char, '%', "error should indicate the '%' character as invalid");
        assert_eq!(err.position, 5, "error should indicate the position of the invalid '%' character");
    }

    #[test]
    fn try_new_rejects_truncated_escape_at_end() {
        for (input, position) in [("abc%", 3), ("abc%4", 3), ("%", 0)] {
            let err = UriSafeString::try_new(input).unwrap_err();
            assert_eq!(err.invalid_char, '%', "{input:?}");
            assert_eq!(err.position, position, "{input:?}");
        }
    }

    #[test]
    fn try_new_rejects_non_ascii() {
        let err = UriSafeString::try_new("caf\u{e9}").unwrap_err();
        assert_eq!(err.invalid_char, '\u{e9}');
        assert_eq!(err.position, 3);
    }

    #[test]
    fn uri_safe_error_display_contains_char_and_position() {
        let err = UriSafeError {
            invalid_char: '{',
            position: 5,
        };
        let msg = err.to_string();
        assert!(msg.contains('{'), "error message should contain the invalid character");
        assert!(msg.contains('5'), "error message should contain the position");
    }

    #[test]
    fn test_from_string_reserved() {
        let result = UriSafeString::from("reserved{string}".to_string());
        assert_eq!(result.as_str(), "reserved%7Bstring%7D");
    }

    #[test]
    fn test_raw_string_reserved() {
        let err = UriSafeString::try_new("invalid{string}".to_string()).unwrap_err();
        assert_eq!(err.invalid_char, '{');
        assert_eq!(err.position, 7);
    }

    #[test]
    fn test_from_str_valid() {
        let result = UriSafeString::from("valid_str_456");
        assert_eq!(result.as_str(), "valid_str_456");
    }

    #[test]
    fn test_from_str_reserved() {
        let result = UriSafeString::from("reserved{string}");
        assert_eq!(result.as_str(), "reserved%7Bstring%7D");
    }

    mod from_static_reserved_characters {
        use super::*;
        test_static_reserved_fail! {
            (curly_left, "{"),
            (curly_right, "}"),
            (slash, "/"),
            (colon, ":"),
            (question_mark, "?"),
            (hash, "#"),
            (square_left, "["),
            (square_right, "]"),
            (at, "@"),
            (exclamation_mark, "!"),
            (dollar, "$"),
            (ampersand, "&"),
            (apostrophe, "'"),
            (parentheses_left, "("),
            (parentheses_right, ")"),
            (asterisk, "*"),
            (plus, "+"),
            (comma, ","),
            (semicolon, ";"),
            (equal, "=")
        }
    }

    #[test]
    fn from_static_urlencoded() {
        let result = UriSafeString::from_static("hello%3Dworld");
        assert_eq!(result.as_str(), "hello%3Dworld");
    }

    #[test]
    #[should_panic(expected = "string contains unfinished URL encoded character")]
    fn from_static_urlencoded_short() {
        let _ = UriSafeString::from_static("hello%3");
    }

    #[test]
    #[should_panic(expected = "string contains invalid URL encoding character")]
    fn from_static_urlencoded_bad_char() {
        let _ = UriSafeString::from_static("hello%3-world");
    }

    #[cfg(feature = "serde")]
    mod serde_tests {
        use super::*;

        #[test]
        fn uri_safe_string_roundtrip() {
            let original = UriSafeString::encode("hello world");
            let json = serde_json::to_string(&original).unwrap();
            assert_eq!(json, r#""hello%20world""#);
            let deserialized: UriSafeString = serde_json::from_str(&json).unwrap();
            assert_eq!(original, deserialized);
        }

        #[test]
        fn uri_safe_string_deserialize_rejects_reserved() {
            serde_json::from_str::<UriSafeString>(r#""hello{world}""#).unwrap_err();
        }
    }
}