use std::net::Ipv4Addr;
/// Upper bound on the number of decode passes `decode_until_stable` runs
/// while looking for a fixed point.
const MAX_DECODE_ITERATIONS: usize = 16;
/// Longest input, in bytes, that `canonicalize` accepts; longer inputs yield `None`.
const MAX_URL_LENGTH: usize = 8192;
/// Longest raw (still encoded) host, in bytes, that `canonicalize` accepts.
const MAX_HOST_LENGTH: usize = 255;
/// Uppercase hexadecimal digits indexed by nibble value; used by
/// `push_percent_encoded` when emitting `%XX` escapes.
const HEX_CHARS: &[u8; 16] = b"0123456789ABCDEF";
/// Reports whether `c` is one of the three characters removed from URLs
/// before parsing: horizontal tab, carriage return, or line feed.
fn is_control_char(c: char) -> bool {
    matches!(c, '\t' | '\r' | '\n')
}
/// Returns a copy of `s` with every character for which `is_control_char`
/// is true removed; all other characters keep their order.
fn strip_control_chars(s: &str) -> String {
    let mut cleaned = String::with_capacity(s.len());
    for ch in s.chars() {
        if !is_control_char(ch) {
            cleaned.push(ch);
        }
    }
    cleaned
}
/// Canonicalizes `url` into the form `scheme://host/path[?query]`.
///
/// Steps, in order:
/// 1. Reject empty input and input longer than `MAX_URL_LENGTH` bytes.
/// 2. Remove tab/CR/LF characters and trim surrounding whitespace.
/// 3. Prepend `http://` when the input carries no scheme.
/// 4. Split into scheme, host, path and query with `parse_url` (userinfo,
///    port and fragment are dropped there).
/// 5. Host: percent-decode until stable, remove leading/trailing dots and
///    collapse dot runs, normalize IPv4 notations, convert non-ASCII names
///    with IDNA, percent-escape, and lowercase.
/// 6. Path: resolve `.`/`..` and empty segments, percent-decode until stable,
///    then percent-escape; a trailing slash in the input path is preserved.
/// 7. The query string is appended verbatim after `?`.
///
/// Returns `None` when the input is empty or too long, when `parse_url`
/// rejects it, or when the raw host is empty or longer than
/// `MAX_HOST_LENGTH` bytes.
pub fn canonicalize(url: &str) -> Option<String> {
    if url.is_empty() || url.len() > MAX_URL_LENGTH {
        return None;
    }
    let raw = strip_control_chars(url);
    let raw = raw.trim();
    if raw.is_empty() {
        return None;
    }

    // A "://" only separates a scheme when nothing before it has already
    // started a path, query or fragment. Otherwise it belongs to the rest of
    // the URL (e.g. `example.com/r?u=http://x`) and the input is scheme-less.
    let has_scheme = raw.find("://").map_or(false, |idx| {
        !raw[..idx].contains(|c: char| c == '/' || c == '?' || c == '#')
    });
    let url_with_scheme = if has_scheme {
        raw.to_string()
    } else {
        format!("http://{}", raw)
    };
    let url_with_scheme = escape_spaces_in_authority(&url_with_scheme);

    let (scheme, raw_host, path, query) = parse_url(&url_with_scheme)?;
    if raw_host.is_empty() || raw_host.len() > MAX_HOST_LENGTH {
        return None;
    }

    let host_decoded = decode_until_stable(&raw_host);
    // Clean up dots BEFORE attempting IP parsing so that hosts such as
    // `0x7f.1.` are still recognized as IPv4 addresses.
    let host_dots_cleaned = squeeze_char(&host_decoded, '.');
    let host_ip_normalized = normalize_ip_address(host_dots_cleaned.trim_matches('.'));
    let host_ascii = if host_ip_normalized.bytes().any(|b| !b.is_ascii()) {
        // Best effort: fall back to the unconverted host when IDNA rejects it.
        idna::domain_to_ascii(&host_ip_normalized).unwrap_or_else(|_| host_ip_normalized.clone())
    } else {
        host_ip_normalized
    };
    let normalized_host = webrisk_uri_escape(&host_ascii);
    // Repeated after the IDNA step as well; the cleanup is idempotent.
    let normalized_host = squeeze_char(&normalized_host, '.');
    let normalized_host = normalized_host
        .trim_start_matches('.')
        .trim_end_matches('.')
        .to_lowercase();

    let path = if path.is_empty() {
        "/".to_string()
    } else {
        path
    };
    // `normalize_dots_in_paths` drops empty segments, which loses a trailing
    // slash; remember it so it can be restored afterwards.
    let path_ends_with_slash = path.ends_with('/');
    let dots_normalized = normalize_dots_in_paths(&path);
    let mut normalized_path = normalize_component_encoding(&dots_normalized);
    if path_ends_with_slash && !normalized_path.ends_with('/') {
        normalized_path.push('/');
    }

    let normalized_query = match query {
        Some(q) => format!("?{}", q),
        None => String::new(),
    };
    Some(format!(
        "{}://{}{}{}",
        scheme, normalized_host, normalized_path, normalized_query
    ))
}
/// Replaces every space inside the authority component of `url` with `%20`.
///
/// The authority starts right after the first `://` and ends at the first
/// `/`, `?` or `#` (or at the end of the string). Everything outside the
/// authority is copied unchanged, so spaces in a path, query or fragment are
/// left alone even when no `/` follows the host. A `url` without `://` is
/// returned as-is.
fn escape_spaces_in_authority(url: &str) -> String {
    let authority_start = match url.find("://") {
        Some(scheme_end) => scheme_end + 3,
        None => return url.to_string(),
    };
    let (prefix, after_scheme) = url.split_at(authority_start);
    let authority_end = after_scheme
        .find(|c: char| c == '/' || c == '?' || c == '#')
        .unwrap_or(after_scheme.len());
    let (authority, rest) = after_scheme.split_at(authority_end);
    format!("{}{}{}", prefix, authority.replace(' ', "%20"), rest)
}
/// Splits `url` into `(scheme, host, path, query)`.
///
/// * The scheme is the text before the first `://`. It must start with an
///   ASCII letter and otherwise contain only ASCII alphanumerics, `+`, `-`
///   or `.`; if it does not, `None` is returned.
/// * The fragment (everything from the first `#` after the scheme) is dropped.
/// * The authority ends at the first `/` or `?`, so a query or fragment that
///   directly follows the host is not mistaken for part of the host.
/// * The host is the authority with userinfo (up to the last `@`) and the
///   port (after the last `:`) removed. A host starting with `[` is kept up
///   to and including the first `]`.
/// * The path defaults to `/` when absent.
/// * The query is the text after the first `?`, without the `?` itself;
///   `Some("")` is returned for a bare trailing `?`.
fn parse_url(url: &str) -> Option<(String, String, String, Option<String>)> {
    let scheme_end = url.find("://")?;
    let scheme = &url[..scheme_end];
    let mut scheme_bytes = scheme.bytes();
    // `next()?` rejects an empty scheme.
    if !scheme_bytes.next()?.is_ascii_alphabetic() {
        return None;
    }
    if !scheme_bytes.all(|b| b.is_ascii_alphanumeric() || b == b'+' || b == b'-' || b == b'.') {
        return None;
    }

    let after_scheme = &url[scheme_end + 3..];
    // Strip the fragment first so a `#` can never end up inside the host.
    let without_fragment = match after_scheme.find('#') {
        Some(idx) => &after_scheme[..idx],
        None => after_scheme,
    };
    let authority_end = without_fragment
        .find(|c: char| c == '/' || c == '?')
        .unwrap_or(without_fragment.len());
    let (authority, path_and_query) = without_fragment.split_at(authority_end);

    let (path, query) = match path_and_query.find('?') {
        Some(q_idx) => (
            &path_and_query[..q_idx],
            Some(path_and_query[q_idx + 1..].to_string()),
        ),
        None => (path_and_query, None),
    };
    let path = if path.is_empty() { "/" } else { path };

    let host_and_port = match authority.rfind('@') {
        Some(idx) => &authority[idx + 1..],
        None => authority,
    };
    let host = if host_and_port.starts_with('[') {
        // Bracketed (IPv6-style) host: colons inside the brackets are not a port.
        match host_and_port.find(']') {
            Some(idx) => &host_and_port[..=idx],
            None => host_and_port,
        }
    } else {
        match host_and_port.rfind(':') {
            Some(idx) => &host_and_port[..idx],
            None => host_and_port,
        }
    };

    Some((
        scheme.to_string(),
        host.to_string(),
        path.to_string(),
        query,
    ))
}
fn normalize_ip_address(host: &str) -> String {
let parts: Vec<&str> = host.split('.').collect();
if parts.len() == 1 {
if let Some(n) = parse_ip_part(parts[0]) {
if n <= 0xFFFF_FFFF {
return Ipv4Addr::from(n as u32).to_string();
}
}
return host.to_string();
}
if parts.len() < 2 || parts.len() > 4 {
return host.to_string();
}
let nums: Option<Vec<u64>> = parts.iter().map(|p| parse_ip_part(p)).collect();
let nums = match nums {
Some(n) => n,
None => return host.to_string(),
};
let ip_int: u64 = match nums.len() {
2 => {
if nums[0] > 0xFF || nums[1] > 0xFF_FFFF {
return host.to_string();
}
(nums[0] << 24) | (nums[1] & 0xFF_FFFF)
}
3 => {
if nums[0] > 0xFF || nums[1] > 0xFF || nums[2] > 0xFFFF {
return host.to_string();
}
(nums[0] << 24) | ((nums[1] & 0xFF) << 16) | (nums[2] & 0xFFFF)
}
4 => {
if nums.iter().any(|&n| n > 0xFF) {
return host.to_string();
}
(nums[0] << 24) | ((nums[1] & 0xFF) << 16) | ((nums[2] & 0xFF) << 8) | (nums[3] & 0xFF)
}
_ => return host.to_string(),
};
if ip_int <= 0xFFFF_FFFF {
Ipv4Addr::from(ip_int as u32).to_string()
} else {
host.to_string()
}
}
/// Parses one dot-separated component of a numeric IPv4 host.
///
/// * `0x`/`0X` prefix: hexadecimal.
/// * A leading `0` followed only by digits `0`-`7`: octal.
/// * Anything else: decimal.
///
/// Returns `None` for an empty component, a bare `0x`, any character that is
/// not a digit of the selected radix, or a value that overflows `u64`.
/// A sign prefix is rejected explicitly: the standard integer parsers accept
/// a leading `+`, which must not make a host look like an IP address.
fn parse_ip_part(s: &str) -> Option<u64> {
    let (digits, radix) = if let Some(hex) = s.strip_prefix("0x").or_else(|| s.strip_prefix("0X")) {
        (hex, 16)
    } else if s.len() > 1
        && s.starts_with('0')
        && s.bytes().skip(1).all(|b| (b'0'..=b'7').contains(&b))
    {
        (s, 8)
    } else {
        (s, 10)
    };
    if digits.is_empty() || !digits.bytes().all(|b| char::from(b).is_digit(radix)) {
        return None;
    }
    u64::from_str_radix(digits, radix).ok()
}
fn custom_decode_uri_component(input: &str) -> String {
let bytes = input.as_bytes();
let mut result: Vec<u8> = Vec::with_capacity(bytes.len());
let mut i = 0;
while i < bytes.len() {
if bytes[i] == b'%'
&& i + 2 < bytes.len()
&& bytes[i + 1].is_ascii_hexdigit()
&& bytes[i + 2].is_ascii_hexdigit()
{
let mut encoded_bytes = Vec::new();
while i < bytes.len()
&& bytes[i] == b'%'
&& i + 2 < bytes.len()
&& bytes[i + 1].is_ascii_hexdigit()
&& bytes[i + 2].is_ascii_hexdigit()
{
encoded_bytes.push(hex_pair_to_byte(bytes[i + 1], bytes[i + 2]));
i += 3;
}
let decoded = String::from_utf8_lossy(&encoded_bytes);
result.extend_from_slice(decoded.as_bytes());
} else {
result.push(bytes[i]);
i += 1;
}
}
String::from_utf8_lossy(&result).into_owned()
}
/// Repeatedly percent-decodes `input` (removing tab/CR/LF after each pass)
/// until a pass changes nothing, or until `MAX_DECODE_ITERATIONS` passes
/// have run, and returns the last value reached.
fn decode_until_stable(input: &str) -> String {
    let mut current = String::from(input);
    let mut passes_left = MAX_DECODE_ITERATIONS;
    while passes_left > 0 {
        passes_left -= 1;
        let next = strip_control_chars(&custom_decode_uri_component(&current));
        if next == current {
            return current;
        }
        current = next;
    }
    current
}
/// Brings the percent-encoding of a URL component into canonical form:
/// fully decode it with `decode_until_stable`, then escape the result once
/// with `webrisk_uri_escape`.
fn normalize_component_encoding(input: &str) -> String {
    let fully_decoded = decode_until_stable(input);
    webrisk_uri_escape(&fully_decoded)
}
/// Resolves `.` and `..` segments in `path` and drops empty segments.
///
/// A `..` removes the previously kept segment, or nothing when there is
/// none. The result always starts with `/` and never ends with `/` unless
/// it is exactly `/`.
fn normalize_dots_in_paths(path: &str) -> String {
    let mut kept: Vec<&str> = Vec::new();
    for segment in path.split('/') {
        match segment {
            "" | "." => {}
            ".." => {
                kept.pop();
            }
            other => kept.push(other),
        }
    }

    if kept.is_empty() {
        return String::from("/");
    }
    let mut normalized = String::with_capacity(path.len() + 1);
    for segment in &kept {
        normalized.push('/');
        normalized.push_str(segment);
    }
    normalized
}
/// Percent-escapes `s` one byte at a time.
///
/// * An existing `%XX` escape (two hex digits) is kept, with its hex digits
///   uppercased.
/// * Any other byte for which `should_escape` is true is written as `%XX`.
/// * Every remaining byte is copied through as a single character.
fn webrisk_uri_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut rest = s.as_bytes();
    while let Some((&first, tail)) = rest.split_first() {
        match (first, tail) {
            (b'%', [hi, lo, ..]) if hi.is_ascii_hexdigit() && lo.is_ascii_hexdigit() => {
                out.push('%');
                out.push(char::from(*hi).to_ascii_uppercase());
                out.push(char::from(*lo).to_ascii_uppercase());
                rest = &tail[2..];
            }
            _ => {
                if should_escape(first) {
                    push_percent_encoded(&mut out, first);
                } else {
                    out.push(char::from(first));
                }
                rest = tail;
            }
        }
    }
    out
}
/// Appends `byte` to `out` as a three-character escape: `%` followed by the
/// high and low nibble looked up in `HEX_CHARS`.
fn push_percent_encoded(out: &mut String, byte: u8) {
    let (high, low) = (usize::from(byte >> 4), usize::from(byte & 0x0F));
    out.push('%');
    out.push(char::from(HEX_CHARS[high]));
    out.push(char::from(HEX_CHARS[low]));
}
/// Reports whether byte `b` must be percent-escaped in canonical output.
///
/// Escaped bytes are: everything up to and including the space (`<= 32`),
/// everything from DEL upwards (`>= 127`), `%`, and `#`. DEL (0x7F) is
/// included because it is a control character like those below 32.
fn should_escape(b: u8) -> bool {
    b <= 32 || b >= 127 || b == b'%' || b == b'#'
}
/// Combines two ASCII hex digits into one byte, `h1` being the high nibble
/// and `h2` the low nibble (each converted with `hex_val`).
fn hex_pair_to_byte(h1: u8, h2: u8) -> u8 {
    let high = hex_val(h1);
    let low = hex_val(h2);
    // Same as `high * 16 + low`; written with bit operations so the two
    // nibbles are merged without arithmetic carry.
    low | (high << 4)
}
/// Returns the numeric value (0-15) of the ASCII hex digit `b`, accepting
/// both letter cases. Any byte that is not a hex digit yields 0.
fn hex_val(b: u8) -> u8 {
    char::from(b)
        .to_digit(16)
        .map_or(0, |digit| digit as u8)
}
/// Collapses every run of consecutive `c` characters in `s` into a single
/// `c`; all other characters are copied unchanged.
fn squeeze_char(s: &str, c: char) -> String {
    let mut squeezed = String::with_capacity(s.len());
    let mut previous: Option<char> = None;
    for current in s.chars() {
        let repeats_target = current == c && previous == Some(c);
        if !repeats_target {
            squeezed.push(current);
        }
        previous = Some(current);
    }
    squeezed
}
/// Unit tests for the private helpers of the canonicalizer.
#[cfg(test)]
mod tests {
    use super::*;

    /// A single decimal integer is expanded to a dotted quad.
    #[test]
    fn test_normalize_ip_single_int() {
        let normalized = normalize_ip_address("3279880203");
        assert_eq!(normalized, "195.127.0.11");
    }

    /// A regular host name is returned untouched.
    #[test]
    fn test_normalize_ip_passthrough() {
        let host = "example.com";
        assert_eq!(normalize_ip_address(host), host);
    }

    /// An out-of-range final octet prevents IP normalization.
    #[test]
    fn test_normalize_ip_rejects_oversize_octet() {
        let host = "1.2.3.256";
        assert_eq!(normalize_ip_address(host), host);
    }

    /// `.`/`..` resolution and empty-segment removal.
    #[test]
    fn test_normalize_dots_in_paths() {
        let cases = [
            ("/a/b/../c", "/a/c"),
            ("/a/./b", "/a/b"),
            ("/blah/..", "/"),
            ("//twoslashes", "/twoslashes"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(normalize_dots_in_paths(input), expected);
        }
    }

    /// Plain ASCII escapes decode to their characters.
    #[test]
    fn test_custom_decode_basic() {
        let decoded = custom_decode_uri_component("%41%42%43");
        assert_eq!(decoded, "ABC");
    }

    /// An already valid uppercase escape is left as-is.
    #[test]
    fn test_webrisk_uri_escape_preserves_valid() {
        let input = "abc%20def";
        assert_eq!(webrisk_uri_escape(input), input);
    }

    /// Lowercase hex digits in an existing escape are uppercased.
    #[test]
    fn test_webrisk_uri_escape_uppercases_hex() {
        let escaped = webrisk_uri_escape("abc%2fdef");
        assert_eq!(escaped, "abc%2Fdef");
    }

    /// Runs of the target character collapse to one.
    #[test]
    fn test_squeeze_char() {
        let squeezed = squeeze_char("a...b...c", '.');
        assert_eq!(squeezed, "a.b.c");
    }

    /// Tab, CR and LF are removed; other characters stay.
    #[test]
    fn test_strip_control_chars() {
        let stripped = strip_control_chars("foo\tbar\rbaz\n2");
        assert_eq!(stripped, "foobarbaz2");
    }
}