use url::Url;
use std::collections::BTreeMap;
/// Configuration for canonicalizing [`Url`] values so that equivalent
/// addresses can be compared or de-duplicated.
///
/// Every field is a switch (or data for a switch) consulted by
/// `UrlNormalizer::normalize`; the stock settings come from the `Default`
/// implementation.
#[derive(Debug, Clone)]
pub struct UrlNormalizer {
    /// When set, the `#fragment` part is dropped.
    remove_fragments: bool,
    /// When set, query parameters are ordered by key.
    sort_query_params: bool,
    /// When set, query parameters named in `tracking_params` are dropped.
    remove_tracking_params: bool,
    /// Names of query parameters considered tracking noise; compared
    /// ASCII case-insensitively against the keys found in the URL.
    tracking_params: Vec<String>,
    /// When set, `http` URLs are rewritten to `https`.
    force_https: bool,
    /// When set, a leading `www.` label is stripped from the host.
    remove_www: bool,
    /// When set, a single trailing `/` is stripped from any path longer than `/`.
    remove_trailing_slash: bool,
    /// When set, an explicit port equal to the scheme default
    /// (80 for `http`, 443 for `https`) is dropped.
    remove_default_port: bool,
}
impl Default for UrlNormalizer {
    /// Builds the stock configuration.
    ///
    /// Fragment removal, query sorting, tracking-parameter removal and
    /// default-port removal are switched on; the https upgrade, `www.`
    /// stripping and trailing-slash stripping are switched off.
    fn default() -> Self {
        // Parameter names dropped from the query when tracking removal is on.
        const DEFAULT_TRACKING_PARAMS: [&str; 9] = [
            "utm_source",
            "utm_medium",
            "utm_campaign",
            "utm_term",
            "utm_content",
            "fbclid",
            "gclid",
            "ref",
            "_ga",
        ];

        let tracking_params: Vec<String> = DEFAULT_TRACKING_PARAMS
            .iter()
            .map(|name| (*name).to_owned())
            .collect();

        Self {
            tracking_params,
            // Enabled by default.
            remove_fragments: true,
            sort_query_params: true,
            remove_tracking_params: true,
            remove_default_port: true,
            // Disabled by default.
            force_https: false,
            remove_www: false,
            remove_trailing_slash: false,
        }
    }
}
impl UrlNormalizer {
    /// Creates a normalizer with the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a normalized copy of `url`; the input itself is never modified.
    ///
    /// Each step runs only when its configuration flag is set:
    ///
    /// 1. drop the fragment;
    /// 2. rewrite the `http` scheme to `https`;
    /// 3. drop an explicit port equal to the default of the (possibly
    ///    rewritten) scheme — 80 for `http`, 443 for `https`;
    /// 4. strip a leading `www.` from the host;
    /// 5. strip one trailing `/` from any path longer than `/`;
    /// 6. drop tracking parameters and/or order the query by key.
    ///
    /// Setter failures reported by the `url` crate (for example a host that
    /// would become invalid) are ignored on purpose: normalization is
    /// best-effort and leaves that component as it was.
    ///
    /// When step 6 rewrites the query it is re-serialized as
    /// `application/x-www-form-urlencoded`, so reserved characters inside
    /// keys and values stay escaped and repeated keys are all preserved.
    pub fn normalize(&self, url: &Url) -> Url {
        let mut normalized = url.clone();

        if self.remove_fragments {
            normalized.set_fragment(None);
        }

        // Upgrade the scheme before looking at the port so that the port is
        // judged against the scheme the result will actually carry.
        if self.force_https && normalized.scheme() == "http" {
            let _ = normalized.set_scheme("https");
        }

        if self.remove_default_port {
            // `None` means "no known default": an explicit port on such a
            // scheme is always kept (including port 0).
            let default_port: Option<u16> = match normalized.scheme() {
                "http" => Some(80),
                "https" => Some(443),
                _ => None,
            };
            if default_port.is_some() && normalized.port() == default_port {
                let _ = normalized.set_port(None);
            }
        }

        if self.remove_www {
            let bare_host = normalized
                .host_str()
                .and_then(|host| host.strip_prefix("www."))
                .map(str::to_owned);
            if let Some(host) = bare_host {
                let _ = normalized.set_host(Some(&host));
            }
        }

        if self.remove_trailing_slash {
            let path = normalized.path();
            // The root path "/" is left alone.
            if path.len() > 1 {
                if let Some(trimmed) = path.strip_suffix('/') {
                    let trimmed = trimmed.to_owned();
                    normalized.set_path(&trimmed);
                }
            }
        }

        if self.sort_query_params || self.remove_tracking_params {
            self.normalize_query(&mut normalized);
        }

        normalized
    }

    /// Resolves `relative` against `base` and normalizes the result.
    ///
    /// Returns `None` when `base.join(relative)` fails.
    pub fn resolve(&self, base: &Url, relative: &str) -> Option<Url> {
        base.join(relative).ok().map(|joined| self.normalize(&joined))
    }

    /// Reports whether `a` and `b` are identical after normalization.
    pub fn are_equal(&self, a: &Url, b: &Url) -> bool {
        self.normalize(a) == self.normalize(b)
    }

    /// Applies tracking-parameter removal and key ordering to `url`'s query.
    ///
    /// A URL without any query pairs is left untouched. If every pair is
    /// removed the query component is dropped entirely.
    fn normalize_query(&self, url: &mut Url) {
        let pairs: Vec<(String, String)> = url
            .query_pairs()
            .map(|(key, value)| (key.into_owned(), value.into_owned()))
            .collect();
        if pairs.is_empty() {
            return;
        }

        let kept: Vec<(String, String)> = pairs
            .into_iter()
            .filter(|(key, _)| !(self.remove_tracking_params && self.is_tracking_param(key)))
            .collect();
        if kept.is_empty() {
            url.set_query(None);
            return;
        }

        // The serializer percent-encodes keys and values, which a plain
        // `key=value` join of the decoded pairs would not do.
        let mut query = url.query_pairs_mut();
        query.clear();
        if self.sort_query_params {
            // Group values under their key: the map iterates keys in sorted
            // order while each Vec keeps repeated values in original order.
            let mut by_key: BTreeMap<&str, Vec<&str>> = BTreeMap::new();
            for (key, value) in &kept {
                by_key.entry(key.as_str()).or_default().push(value.as_str());
            }
            for (key, values) in by_key {
                for value in values {
                    query.append_pair(key, value);
                }
            }
        } else {
            query.extend_pairs(&kept);
        }
    }

    /// Reports whether `key` matches a configured tracking parameter name,
    /// ignoring ASCII case.
    fn is_tracking_param(&self, key: &str) -> bool {
        self.tracking_params
            .iter()
            .any(|name| name.eq_ignore_ascii_case(key))
    }
}
/// Returns the host of `url` as an owned string, exactly as reported by
/// `Url::host_str`, or `None` when the URL has no host.
pub fn extract_domain(url: &Url) -> Option<String> {
    let host = url.host_str()?;
    Some(host.to_owned())
}
/// Returns the last two labels of the URL's host name (for example
/// `example.com` for `www.shop.example.com`), or `None` when the URL has no
/// host.
///
/// * IP-address hosts (IPv4 dotted form or bracketed IPv6) are returned
///   unchanged, since their dots do not separate domain labels.
/// * A single trailing root dot (`example.com.`) is ignored and not part of
///   the result.
/// * Hosts with one or two labels are returned as they are.
///
/// Known limitation: this is a purely positional heuristic and does not
/// consult a public-suffix list, so `shop.example.co.uk` yields `co.uk`.
pub fn extract_base_domain(url: &Url) -> Option<String> {
    url.host_str().map(|host| {
        // IP literals are not domain names; slicing them would produce
        // nonsense such as "1.1" for 192.168.1.1.
        let is_ip_literal =
            host.starts_with('[') || host.parse::<std::net::Ipv4Addr>().is_ok();
        if is_ip_literal {
            return host.to_string();
        }

        // Drop the fully-qualified root dot so it is not counted as an
        // (empty) label.
        let name = host.strip_suffix('.').unwrap_or(host);

        // Walk labels from the right: the first is the top-level label, the
        // second the one directly below it.
        let mut labels = name.rsplitn(3, '.');
        match (labels.next(), labels.next()) {
            (Some(top), Some(second)) => format!("{}.{}", second, top),
            _ => name.to_string(),
        }
    })
}
/// Reports whether `a` and `b` have exactly the same host string.
///
/// Two URLs that both lack a host compare as equal; a URL with a host never
/// matches one without.
pub fn is_same_domain(a: &Url, b: &Url) -> bool {
    match (a.host_str(), b.host_str()) {
        (Some(left), Some(right)) => left == right,
        (None, None) => true,
        _ => false,
    }
}
/// Reports whether `a` and `b` share the same base domain, as computed by
/// [`extract_base_domain`].
///
/// Two URLs that both lack a host compare as equal.
pub fn is_same_base_domain(a: &Url, b: &Url) -> bool {
    let base_of_a = extract_base_domain(a);
    let base_of_b = extract_base_domain(b);
    base_of_a == base_of_b
}