use super::{Pass, PassResult};
use crate::ast::{Document, NodeId, NodeKind};
use std::collections::{HashMap, HashSet};
pub struct CleanupIds;
impl Pass for CleanupIds {
fn name(&self) -> &'static str {
"cleanupIds"
}
fn run(&self, doc: &mut Document) -> PassResult {
let ids = doc.traverse();
let mut declarations: HashMap<String, NodeId> = HashMap::new();
for &id in &ids {
if let NodeKind::Element(ref elem) = doc.node(id).kind
&& let Some(id_val) = elem.attr("id")
{
declarations.insert(id_val.to_string(), id);
}
}
if declarations.is_empty() {
return PassResult::Unchanged;
}
let declared_ids: HashSet<String> = declarations.keys().cloned().collect();
let mut referenced: HashSet<String> = HashSet::new();
for &id in &ids {
let node = doc.node(id);
match &node.kind {
NodeKind::Element(elem) => {
for attr in &elem.attributes {
for r in extract_url_refs(&attr.value) {
if declared_ids.contains(&r) {
referenced.insert(r);
}
}
if attr.name == "href"
&& let Some(r) = extract_href_ref(&attr.value)
&& declared_ids.contains(r)
{
referenced.insert(r.to_string());
}
if attr.name == "begin" || attr.name == "end" {
for r in extract_smil_refs(&attr.value, &declared_ids) {
referenced.insert(r);
}
}
}
}
NodeKind::Text(text) | NodeKind::CData(text) => {
if let Some(parent_id) = node.parent
&& let NodeKind::Element(ref parent_elem) = doc.node(parent_id).kind
&& parent_elem.name == "style"
{
for r in extract_css_id_refs(text, &declared_ids) {
referenced.insert(r);
}
}
}
_ => {}
}
}
let mut referenced_ids: Vec<(&String, &NodeId)> = declarations
.iter()
.filter(|(name, _)| referenced.contains(name.as_str()))
.collect();
referenced_ids.sort_by_key(|(_, node_id)| node_id.0);
let mut generator = IdGenerator::new();
let mut rename_map: HashMap<String, String> = HashMap::new();
for (old_name, _) in &referenced_ids {
let new_name = generator.next_id();
if new_name != **old_name {
rename_map.insert((*old_name).clone(), new_name);
}
}
let unreferenced_count = declarations.len() - referenced.len();
if unreferenced_count == 0 && rename_map.is_empty() {
return PassResult::Unchanged;
}
let mut changed = false;
for &id in &ids {
let node = doc.node_mut(id);
match &mut node.kind {
NodeKind::Element(elem) => {
if let Some(id_attr_pos) = elem
.attributes
.iter()
.position(|a| a.name == "id" && a.prefix.is_none())
{
let id_val = elem.attributes[id_attr_pos].value.clone();
if !referenced.contains(&id_val) {
elem.attributes.remove(id_attr_pos);
changed = true;
} else if let Some(new_name) = rename_map.get(&id_val) {
elem.attributes[id_attr_pos].value = new_name.clone();
changed = true;
}
}
for attr in &mut elem.attributes {
if attr.value.contains("url(#") {
let new_val = replace_url_refs(&attr.value, &rename_map);
if new_val != attr.value {
attr.value = new_val;
changed = true;
}
}
if attr.name == "href" && attr.value.starts_with('#') {
let old_ref = &attr.value[1..];
if let Some(new_name) = rename_map.get(old_ref) {
attr.value = format!("#{new_name}");
changed = true;
}
}
if attr.name == "begin" || attr.name == "end" {
let new_val = replace_smil_refs(&attr.value, &rename_map);
if new_val != attr.value {
attr.value = new_val;
changed = true;
}
}
}
}
NodeKind::Text(text) | NodeKind::CData(text) => {
if let Some(parent_id) = node.parent {
let _ = parent_id; let new_text = replace_css_ids(text, &rename_map, &declared_ids);
if new_text != *text {
*text = new_text;
changed = true;
}
}
}
_ => {}
}
}
if changed {
PassResult::Changed
} else {
PassResult::Unchanged
}
}
}
fn extract_url_refs(value: &str) -> Vec<String> {
let mut refs = Vec::new();
let mut search = value;
while let Some(start) = search.find("url(#") {
let rest = &search[start + 5..];
if let Some(end) = rest.find(')') {
let id = rest[..end]
.trim()
.trim_matches(|c: char| c == '\'' || c == '"');
if !id.is_empty() {
refs.push(id.to_string());
}
}
search = &search[start + 5..];
}
refs
}
fn extract_href_ref(value: &str) -> Option<&str> {
let id = value.strip_prefix('#')?;
if id.is_empty() { None } else { Some(id) }
}
fn extract_smil_refs(value: &str, known_ids: &HashSet<String>) -> Vec<String> {
let mut refs = Vec::new();
for part in value.split(';') {
let trimmed = part.trim();
if let Some(dot_pos) = trimmed.find('.') {
let candidate = &trimmed[..dot_pos];
if known_ids.contains(candidate) {
refs.push(candidate.to_string());
}
}
}
refs
}
fn extract_css_id_refs(text: &str, known_ids: &HashSet<String>) -> Vec<String> {
let mut refs = Vec::new();
let bytes = text.as_bytes();
let mut i = 0;
while i < bytes.len() {
if bytes[i] == b'#' {
let start = i + 1;
let mut end = start;
while end < bytes.len()
&& (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_' || bytes[end] == b'-')
{
end += 1;
}
if end > start {
let candidate = &text[start..end];
if known_ids.contains(candidate) {
refs.push(candidate.to_string());
}
}
i = end;
} else {
i += 1;
}
}
refs
}
fn replace_url_refs(value: &str, rename_map: &HashMap<String, String>) -> String {
let mut result = String::with_capacity(value.len());
let mut search = value;
while let Some(start) = search.find("url(#") {
result.push_str(&search[..start + 5]);
let rest = &search[start + 5..];
if let Some(end) = rest.find(')') {
let raw_id = &rest[..end];
let id = raw_id.trim().trim_matches(|c: char| c == '\'' || c == '"');
if let Some(new_name) = rename_map.get(id) {
result.push_str(new_name);
} else {
result.push_str(raw_id);
}
search = &rest[end..];
} else {
result.push_str(rest);
return result;
}
}
result.push_str(search);
result
}
fn replace_smil_refs(value: &str, rename_map: &HashMap<String, String>) -> String {
let parts: Vec<&str> = value.split(';').collect();
let mut new_parts = Vec::new();
let mut any_changed = false;
for part in parts {
let trimmed = part.trim();
if let Some(dot_pos) = trimmed.find('.') {
let candidate = &trimmed[..dot_pos];
if let Some(new_name) = rename_map.get(candidate) {
new_parts.push(format!("{new_name}{}", &trimmed[dot_pos..]));
any_changed = true;
continue;
}
}
new_parts.push(trimmed.to_string());
}
if any_changed {
new_parts.join(";")
} else {
value.to_string()
}
}
fn replace_css_ids(
text: &str,
rename_map: &HashMap<String, String>,
declared_ids: &HashSet<String>,
) -> String {
let bytes = text.as_bytes();
let mut result = String::with_capacity(text.len());
let mut i = 0;
while i < bytes.len() {
if bytes[i] == b'#' {
let start = i + 1;
let mut end = start;
while end < bytes.len()
&& (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_' || bytes[end] == b'-')
{
end += 1;
}
if end > start {
let candidate = &text[start..end];
if let Some(new_name) = rename_map.get(candidate) {
result.push('#');
result.push_str(new_name);
i = end;
continue;
} else if declared_ids.contains(candidate) {
result.push_str(&text[i..end]);
i = end;
continue;
}
}
result.push('#');
i = start;
} else {
result.push(text[i..].chars().next().unwrap());
i += text[i..].chars().next().unwrap().len_utf8();
}
}
result
}
struct IdGenerator {
counter: usize,
}
impl IdGenerator {
fn new() -> Self {
Self { counter: 0 }
}
fn next_id(&mut self) -> String {
let name = encode(self.counter);
self.counter += 1;
name
}
}
fn encode(mut n: usize) -> String {
let mut result = Vec::new();
loop {
result.push(b'a' + (n % 26) as u8);
n /= 26;
if n == 0 {
break;
}
n -= 1;
}
result.reverse();
String::from_utf8(result).unwrap()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::parse;
use crate::serializer::serialize;
fn run_pass(input: &str) -> (PassResult, String) {
let mut doc = parse(input).unwrap();
let result = CleanupIds.run(&mut doc);
(result, serialize(&doc))
}
#[test]
fn encode_sequence() {
assert_eq!(encode(0), "a");
assert_eq!(encode(1), "b");
assert_eq!(encode(25), "z");
assert_eq!(encode(26), "aa");
assert_eq!(encode(27), "ab");
assert_eq!(encode(51), "az");
assert_eq!(encode(52), "ba");
assert_eq!(encode(701), "zz");
assert_eq!(encode(702), "aaa");
}
#[test]
fn removes_unreferenced_id() {
let input = "<svg xmlns=\"http://www.w3.org/2000/svg\"><rect id=\"unused\" width=\"10\" height=\"10\"/></svg>";
let (result, output) = run_pass(input);
assert_eq!(result, PassResult::Changed);
assert!(!output.contains("id="));
assert!(!output.contains("unused"));
}
#[test]
fn no_change_without_ids() {
let input =
"<svg xmlns=\"http://www.w3.org/2000/svg\"><rect width=\"10\" height=\"10\"/></svg>";
let (result, _) = run_pass(input);
assert_eq!(result, PassResult::Unchanged);
}
#[test]
fn shortens_url_referenced_id() {
let input = "<svg xmlns=\"http://www.w3.org/2000/svg\"><defs><clipPath id=\"clip0_7441_19649\"><rect width=\"10\" height=\"10\"/></clipPath></defs><rect clip-path=\"url(#clip0_7441_19649)\"/></svg>";
let (result, output) = run_pass(input);
assert_eq!(result, PassResult::Changed);
assert!(output.contains("id=\"a\""));
assert!(output.contains("url(#a)"));
assert!(!output.contains("clip0_7441_19649"));
}
#[test]
fn preserves_and_shortens_multiple_ids() {
let input = "<svg xmlns=\"http://www.w3.org/2000/svg\"><defs><clipPath id=\"clip1\"><rect width=\"10\" height=\"10\"/></clipPath><clipPath id=\"clip2\"><rect width=\"20\" height=\"20\"/></clipPath><clipPath id=\"unused\"><rect width=\"5\" height=\"5\"/></clipPath></defs><rect clip-path=\"url(#clip1)\"/><rect clip-path=\"url(#clip2)\"/></svg>";
let (result, output) = run_pass(input);
assert_eq!(result, PassResult::Changed);
assert!(output.contains("id=\"a\""));
assert!(output.contains("id=\"b\""));
assert!(output.contains("url(#a)"));
assert!(output.contains("url(#b)"));
assert!(!output.contains("unused"));
}
#[test]
fn shortens_href_reference() {
let input = "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\"><defs><symbol id=\"icon\" viewBox=\"0 0 24 24\"><path d=\"M0 0\"/></symbol></defs><use xlink:href=\"#icon\"/></svg>";
let (result, output) = run_pass(input);
assert_eq!(result, PassResult::Changed);
assert!(output.contains("id=\"a\""));
assert!(output.contains("href=\"#a\""));
assert!(!output.contains("icon"));
}
#[test]
fn shortens_plain_href() {
let input = "<svg xmlns=\"http://www.w3.org/2000/svg\"><defs><symbol id=\"longname\" viewBox=\"0 0 24 24\"><path d=\"M0 0\"/></symbol></defs><use href=\"#longname\"/></svg>";
let (result, output) = run_pass(input);
assert_eq!(result, PassResult::Changed);
assert!(output.contains("id=\"a\""));
assert!(output.contains("href=\"#a\""));
}
#[test]
fn shortens_smil_timing_reference() {
let input = "<svg xmlns=\"http://www.w3.org/2000/svg\"><rect id=\"target\" width=\"10\" height=\"10\"/><animate begin=\"target.click\" attributeName=\"opacity\" to=\"0\"/></svg>";
let (result, output) = run_pass(input);
assert_eq!(result, PassResult::Changed);
assert!(output.contains("id=\"a\""));
assert!(output.contains("begin=\"a.click\""));
}
#[test]
fn shortens_css_id_reference() {
let input = "<svg xmlns=\"http://www.w3.org/2000/svg\"><style>#myid{opacity:.5}</style><rect id=\"myid\" width=\"10\" height=\"10\"/></svg>";
let (result, output) = run_pass(input);
assert_eq!(result, PassResult::Changed);
assert!(output.contains("id=\"a\""));
assert!(output.contains("#a{"));
}
#[test]
fn already_short_id_no_change() {
let input = "<svg xmlns=\"http://www.w3.org/2000/svg\"><defs><clipPath id=\"a\"><rect width=\"10\" height=\"10\"/></clipPath></defs><rect clip-path=\"url(#a)\"/></svg>";
let (result, _) = run_pass(input);
assert_eq!(result, PassResult::Unchanged);
}
#[test]
fn multiple_url_refs_in_single_attr() {
let input = "<svg xmlns=\"http://www.w3.org/2000/svg\"><defs><linearGradient id=\"grad1\"><stop offset=\"0\" stop-color=\"red\"/></linearGradient><linearGradient id=\"grad2\"><stop offset=\"0\" stop-color=\"blue\"/></linearGradient></defs><rect fill=\"url(#grad1)\" stroke=\"url(#grad2)\"/></svg>";
let (result, output) = run_pass(input);
assert_eq!(result, PassResult::Changed);
assert!(output.contains("url(#a)"));
assert!(output.contains("url(#b)"));
}
#[test]
fn full_optimizer_convergence() {
let input = "<svg xmlns=\"http://www.w3.org/2000/svg\"><defs><clipPath id=\"clip0_long_name\"><rect width=\"10\" height=\"10\"/></clipPath></defs><rect clip-path=\"url(#clip0_long_name)\"/></svg>";
let result1 = crate::optimize(input).unwrap();
let result2 = crate::optimize(&result1.data).unwrap();
assert_eq!(result1.data, result2.data, "should converge");
}
}