use std::borrow::Cow;
use percent_encoding::{AsciiSet, CONTROLS, percent_decode_str, utf8_percent_encode};
use crate::error::NixUriError;
/// Characters that are percent-encoded in query-parameter values.
///
/// Builds on [`CONTROLS`] (C0 controls and DEL) and adds:
/// - space;
/// - the general delimiters `#`, `[` and `]`;
/// - the RFC 3986 sub-delimiters;
/// - `%` itself;
/// - the remaining characters that are never valid bare in a URI.
///
/// Unreserved characters and `/`, `?`, `:`, `@` are not in the set and are emitted
/// literally. Non-ASCII input is always escaped by `utf8_percent_encode`, regardless of
/// this set.
pub(crate) const QUERY_VALUE: &AsciiSet = &CONTROLS
    // whitespace
    .add(b' ')
    // general delimiters that RFC 3986 does not permit inside a query component
    .add(b'#')
    .add(b'[')
    .add(b']')
    // sub-delimiters: allowed in queries by RFC 3986 but escaped here, which keeps
    // e.g. `&` and `=` inside a value from being read as separators
    .add(b'!')
    .add(b'$')
    .add(b'&')
    .add(b'\'')
    .add(b'(')
    .add(b')')
    .add(b'*')
    .add(b'+')
    .add(b',')
    .add(b';')
    .add(b'=')
    // the escape character itself, so encoded output decodes unambiguously
    .add(b'%')
    // characters that are never valid unescaped anywhere in a URI
    .add(b'"')
    .add(b'<')
    .add(b'>')
    .add(b'\\')
    .add(b'^')
    .add(b'`')
    .add(b'{')
    .add(b'|')
    .add(b'}');
/// Characters that are percent-encoded in fragments.
///
/// This is everything in [`QUERY_VALUE`] plus `/`, `?`, `:` and `@`, which the
/// query set leaves literal.
pub(crate) const FRAGMENT: &AsciiSet = &QUERY_VALUE
    .add(b'/')
    .add(b'?')
    .add(b':')
    .add(b'@');
/// Percent-encodes `s` for use as a query-parameter value.
///
/// Every character in [`QUERY_VALUE`] and every non-ASCII byte is escaped as `%XX`.
/// When nothing needs escaping, the result borrows `s`.
pub(crate) fn encode_query(s: &str) -> Cow<'_, str> {
    Cow::from(utf8_percent_encode(s, QUERY_VALUE))
}
/// Percent-encodes `s` for use as a URI fragment.
///
/// Uses [`FRAGMENT`], the query set plus `/`, `?`, `:` and `@`. Non-ASCII bytes are
/// always escaped. When nothing needs escaping, the result borrows `s`.
pub(crate) fn encode_fragment(s: &str) -> Cow<'_, str> {
    Cow::from(utf8_percent_encode(s, FRAGMENT))
}
/// Percent-encodes a single path segment so it can be joined with `/` separators.
///
/// `%` is escaped to `%25` and `/` to `%2F`. Escaping `%` as well is what makes the
/// encoding reversible with `decode_percent`. Without it, a segment that already
/// contains text such as `%2F` would decode to a different string, and a lone `%`
/// would fail to decode at all.
///
/// All other characters pass through unchanged. NOTE(review): other URI delimiters
/// such as `#` and `?` are not escaped here; confirm callers never need them escaped.
///
/// Returns `s` borrowed when it contains neither `%` nor `/`, avoiding an allocation.
pub(crate) fn encode_path_segment(s: &str) -> Cow<'_, str> {
    // Both escaped characters are ASCII, so counting bytes is exact and lets us size
    // the output buffer once (each escape grows the output by two bytes).
    let escapes = s.bytes().filter(|&b| matches!(b, b'%' | b'/')).count();
    if escapes == 0 {
        return Cow::Borrowed(s);
    }
    let mut out = String::with_capacity(s.len() + 2 * escapes);
    for c in s.chars() {
        match c {
            '%' => out.push_str("%25"),
            '/' => out.push_str("%2F"),
            _ => out.push(c),
        }
    }
    Cow::Owned(out)
}
/// Decodes `%XX` escapes in `s` and returns the result as UTF-8 text.
///
/// Every `%` must be followed by two ASCII hex digits (either case). The input is
/// checked up front because `percent_decode_str` on its own passes malformed escapes
/// through unchanged instead of reporting them.
///
/// # Errors
///
/// Returns [`NixUriError::InvalidUrl`] if any of the following holds:
/// - an escape is truncated;
/// - an escape contains a non-hex digit;
/// - the decoded bytes are not valid UTF-8.
pub(crate) fn decode_percent(s: &str) -> Result<Cow<'_, str>, NixUriError> {
    let raw = s.as_bytes();
    let mut pos = 0;
    // Jump from one `%` to the next; `%` is ASCII, so it can never be part of a
    // multi-byte UTF-8 sequence.
    while let Some(offset) = raw[pos..].iter().position(|&b| b == b'%') {
        let start = pos + offset;
        let well_formed = matches!(
            raw.get(start + 1..start + 3),
            Some([hi, lo]) if hi.is_ascii_hexdigit() && lo.is_ascii_hexdigit()
        );
        if !well_formed {
            return Err(NixUriError::InvalidUrl(format!(
                "invalid percent-encoding in '{s}'"
            )));
        }
        // A valid escape ends at start + 3 <= raw.len(), so the next slice is in bounds.
        pos = start + 3;
    }
    percent_decode_str(s)
        .decode_utf8()
        .map_err(|_| NixUriError::InvalidUrl(format!("invalid utf-8 percent-encoding in '{s}'")))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn space_is_encoded_in_query_value() {
        assert_eq!(encode_query("foo bar"), "foo%20bar");
    }

    #[test]
    fn slash_is_kept_in_query_value() {
        assert_eq!(encode_query("foo/bar"), "foo/bar");
    }

    // `&` and `=` must not leak into a query string as separators.
    #[test]
    fn query_separators_are_encoded_in_query_value() {
        assert_eq!(encode_query("a&b=c"), "a%26b%3Dc");
    }

    #[test]
    fn slash_is_encoded_in_fragment() {
        assert_eq!(encode_fragment("foo/bar"), "foo%2Fbar");
    }

    #[test]
    fn fragment_round_trip() {
        let encoded = encode_fragment("a/b?c#d");
        assert_eq!(encoded, "a%2Fb%3Fc%23d");
        assert_eq!(decode_percent(&encoded).unwrap(), "a/b?c#d");
    }

    #[test]
    fn non_ascii_is_encoded() {
        assert_eq!(encode_query("fÖö"), "f%C3%96%C3%B6");
        assert_eq!(encode_fragment("fÖö"), "f%C3%96%C3%B6");
    }

    // Segments without a slash should not allocate.
    #[test]
    fn path_segment_without_slash_is_borrowed() {
        assert!(matches!(encode_path_segment("main"), Cow::Borrowed("main")));
    }

    #[test]
    fn path_segment_slash_round_trip() {
        let encoded = encode_path_segment("feature/foo");
        assert_eq!(encoded, "feature%2Ffoo");
        assert_eq!(decode_percent(&encoded).unwrap(), "feature/foo");
    }

    #[test]
    fn decode_round_trip() {
        let encoded = encode_query("foo bar/baz");
        assert_eq!(decode_percent(&encoded).unwrap(), "foo bar/baz");
    }

    #[test]
    fn decode_rejects_truncated() {
        assert!(decode_percent("foo%2").is_err());
        assert!(decode_percent("foo%").is_err());
    }

    #[test]
    fn decode_rejects_non_hex() {
        assert!(decode_percent("foo%XY").is_err());
        assert!(decode_percent("foo%2Z").is_err());
        // A `%` directly followed by another escape is still malformed.
        assert!(decode_percent("%%20").is_err());
    }

    // Well-formed escapes whose bytes are not valid UTF-8 must be rejected.
    #[test]
    fn decode_rejects_invalid_utf8() {
        assert!(decode_percent("%FF").is_err());
        assert!(decode_percent("%C3").is_err());
    }

    #[test]
    fn decode_passes_valid_full() {
        assert_eq!(decode_percent("foo%20bar").unwrap(), "foo bar");
    }

    #[test]
    fn decode_accepts_lowercase_hex() {
        assert_eq!(decode_percent("foo%2fbar").unwrap(), "foo/bar");
    }

    #[test]
    fn decode_empty_is_empty() {
        assert_eq!(decode_percent("").unwrap(), "");
    }
}