use super::{Pass, PassResult};
use crate::ast::{Document, NodeId, NodeKind};
/// Attributes whose effect is tied to the group as a whole (clipping, masking,
/// filtering). A `<g>` carrying any of these is never merged into its child.
const GROUP_ONLY_ATTRS: &[&str] = &["clip-path", "mask", "filter"];

/// Optimisation pass (`collapseGroups`) that removes redundant `<g>` wrappers,
/// either by hoisting the children of attribute-less groups or by merging a
/// group's attributes into its single element child.
#[derive(Debug, Clone, Copy, Default)]
pub struct CollapseGroups;
impl Pass for CollapseGroups {
    fn name(&self) -> &'static str {
        "collapseGroups"
    }

    /// Collapses redundant `<g>` elements in `doc`.
    ///
    /// Only live (not `removed`), unprefixed `<g>` elements that have a parent are
    /// considered. For each of them:
    ///
    /// * a group with no attributes is replaced by its children;
    /// * a group with exactly one child, which is an element, has its attributes
    ///   merged into that child (transforms are composed) and is then replaced by
    ///   it. This happens only when the group carries none of `GROUP_ONLY_ATTRS`
    ///   and `can_merge_attrs_with_transform` accepts the pair.
    ///
    /// Nodes are visited in the reverse of `doc.traverse()` order. Assuming that
    /// order is document pre-order (TODO confirm), inner groups are handled before
    /// the groups containing them, so nested wrappers can collapse in one run.
    ///
    /// Returns `PassResult::Changed` if at least one group was collapsed.
    fn run(&self, doc: &mut Document) -> PassResult {
        let mut changed = false;
        let mut ids = doc.traverse();
        ids.reverse();
        for id in ids {
            let node = doc.node(id);
            if node.removed {
                continue;
            }
            // A group can only be dissolved into its parent. Merging attributes
            // into the child of a parentless group would leave the group in place,
            // so its attributes would apply twice (e.g. a doubled transform).
            let Some(parent_id) = node.parent else {
                continue;
            };
            let NodeKind::Element(ref elem) = node.kind else {
                continue;
            };
            // Prefixed `g` elements may come from another namespace; skip them on
            // both collapse paths, not just the attribute-less one.
            if elem.name != "g" || elem.prefix.is_some() {
                continue;
            }

            if elem.attributes.is_empty() {
                hoist_children(doc, id, parent_id);
                doc.node_mut(id).removed = true;
                changed = true;
                continue;
            }

            let children: Vec<NodeId> = doc.children(id).collect();
            let &[child_id] = children.as_slice() else {
                continue;
            };
            let NodeKind::Element(ref child_elem) = doc.node(child_id).kind else {
                continue;
            };
            let has_group_only_attr = elem
                .attributes
                .iter()
                .any(|a| GROUP_ONLY_ATTRS.contains(&a.name.as_str()));
            if has_group_only_attr || !can_merge_attrs_with_transform(elem, child_elem) {
                continue;
            }
            merge_group_into_child(doc, id, child_id);
            changed = true;
        }
        if changed {
            PassResult::Changed
        } else {
            PassResult::Unchanged
        }
    }
}
/// Replaces `group_id` in `parent_id`'s child list with the group's own children,
/// keeping their order, and re-parents each of them onto `parent_id`.
///
/// The group's child list is emptied, so afterwards every hoisted node is listed
/// under exactly one parent. The group node itself stays in the arena; callers
/// are responsible for marking it `removed`.
///
/// If `group_id` is not listed among `parent_id`'s children, nothing is changed.
fn hoist_children(doc: &mut Document, group_id: NodeId, parent_id: NodeId) {
    // Locate the group before detaching anything, so an inconsistent tree is left
    // exactly as it was rather than losing the group's children.
    let Some(pos) = doc
        .node(parent_id)
        .children
        .iter()
        .position(|&c| c == group_id)
    else {
        return;
    };
    let group_children = std::mem::take(&mut doc.node_mut(group_id).children);
    for &child in &group_children {
        doc.node_mut(child).parent = Some(parent_id);
    }
    doc.node_mut(parent_id)
        .children
        .splice(pos..=pos, group_children);
}
/// Returns whether `group`'s attributes may be moved onto its only child `child`.
///
/// The merge is refused when:
///
/// * the group has attributes to move and the child has an unprefixed `id`. The
///   id may be the target of a `<use>` or a similar reference, and moving the
///   group's transform or presentation attributes onto it would change what such
///   a reference renders; or
/// * the child already has an attribute with the same name and prefix as one of
///   the group's. The exception is an unprefixed `transform`, which the caller
///   composes instead of overwriting.
///
/// Group-only attributes (`clip-path`, `mask`, `filter`) are not checked here.
fn can_merge_attrs_with_transform(
    group: &crate::ast::Element,
    child: &crate::ast::Element,
) -> bool {
    if group.attributes.is_empty() {
        // Nothing would be moved, so nothing can conflict.
        return true;
    }
    let child_has_id = child
        .attributes
        .iter()
        .any(|a| a.name == "id" && a.prefix.is_none());
    if child_has_id {
        return false;
    }
    group.attributes.iter().all(|g_attr| {
        let composable_transform = g_attr.name == "transform" && g_attr.prefix.is_none();
        composable_transform
            || !child
                .attributes
                .iter()
                .any(|a| a.name == g_attr.name && a.prefix == g_attr.prefix)
    })
}
/// Merges `group_id`'s attributes into its single element child `child_id`, then
/// replaces the group with that child in the group's parent and marks the group
/// `removed`.
///
/// If the group and the child both have an unprefixed `transform`, the two are
/// composed as `"<group> <child>"`, which preserves the original nesting order.
/// Every other group attribute is appended to the child unchanged. This function
/// does not de-duplicate: the caller must already have checked (e.g. with
/// `can_merge_attrs_with_transform`) that the child has no conflicting attribute.
///
/// Nothing is modified when the group has no parent or when either node is not
/// an element. Copying attributes onto the child while the group stays in the
/// tree would apply them twice, and removing the group without copying would
/// drop them.
fn merge_group_into_child(doc: &mut Document, group_id: NodeId, child_id: NodeId) {
    let Some(parent_id) = doc.node(group_id).parent else {
        return;
    };
    let group_attrs = if let NodeKind::Element(ref elem) = doc.node(group_id).kind {
        elem.attributes.clone()
    } else {
        return;
    };
    if let NodeKind::Element(ref mut child_elem) = doc.node_mut(child_id).kind {
        for attr in group_attrs {
            let child_tf = if attr.name == "transform" && attr.prefix.is_none() {
                child_elem
                    .attributes
                    .iter_mut()
                    .find(|a| a.name == "transform" && a.prefix.is_none())
            } else {
                None
            };
            match child_tf {
                // The group's transform is the outer one, so it comes first.
                Some(child_tf) => child_tf.value = format!("{} {}", attr.value, child_tf.value),
                None => child_elem.attributes.push(attr),
            }
        }
    } else {
        return;
    }
    hoist_children(doc, group_id, parent_id);
    doc.node_mut(group_id).removed = true;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse;
    use crate::serializer::serialize;

    #[test]
    fn collapses_attr_less_group() {
        let input =
            r#"<svg xmlns="http://www.w3.org/2000/svg"><g><rect/><circle r="5"/></g></svg>"#;
        let mut doc = parse(input).unwrap();
        assert_eq!(CollapseGroups.run(&mut doc), PassResult::Changed);
        let output = serialize(&doc);
        assert!(!output.contains("<g>"), "group should be removed: {output}");
        assert!(output.contains("<rect/>"));
        assert!(output.contains("<circle"));
    }

    #[test]
    fn merges_single_child_attrs() {
        let input =
            r#"<svg xmlns="http://www.w3.org/2000/svg"><g fill="red"><rect width="10"/></g></svg>"#;
        let mut doc = parse(input).unwrap();
        assert_eq!(CollapseGroups.run(&mut doc), PassResult::Changed);
        let output = serialize(&doc);
        assert!(!output.contains("<g"), "group should be removed: {output}");
        assert!(output.contains("fill=\"red\""));
        assert!(output.contains("width=\"10\""));
    }

    #[test]
    fn keeps_group_with_conflicting_attrs() {
        let input = r#"<svg xmlns="http://www.w3.org/2000/svg"><g fill="red"><rect fill="blue"/></g></svg>"#;
        let mut doc = parse(input).unwrap();
        assert_eq!(CollapseGroups.run(&mut doc), PassResult::Unchanged);
    }

    #[test]
    fn collapses_group_with_transform_to_child() {
        let input = r#"<svg xmlns="http://www.w3.org/2000/svg"><g transform="translate(10,10)"><rect/></g></svg>"#;
        let mut doc = parse(input).unwrap();
        assert_eq!(CollapseGroups.run(&mut doc), PassResult::Changed);
        let output = serialize(&doc);
        assert!(!output.contains("<g"), "group should be collapsed: {output}");
        assert!(
            output.contains("transform=\"translate(10,10)\""),
            "transform should be on rect: {output}"
        );
    }

    #[test]
    fn composes_transforms_during_collapse() {
        let input = r#"<svg xmlns="http://www.w3.org/2000/svg"><g transform="translate(10,10)"><rect transform="scale(2)"/></g></svg>"#;
        let mut doc = parse(input).unwrap();
        assert_eq!(CollapseGroups.run(&mut doc), PassResult::Changed);
        let output = serialize(&doc);
        assert!(!output.contains("<g"), "group should be collapsed: {output}");
        assert!(
            output.contains("translate(10,10) scale(2)"),
            "transforms should be composed: {output}"
        );
    }

    #[test]
    fn collapses_nested_groups() {
        let input = r#"<svg xmlns="http://www.w3.org/2000/svg"><g><g><rect/></g></g></svg>"#;
        let mut doc = parse(input).unwrap();
        assert_eq!(CollapseGroups.run(&mut doc), PassResult::Changed);
        // A second run must find nothing left to do.
        assert_eq!(CollapseGroups.run(&mut doc), PassResult::Unchanged);
        let output = serialize(&doc);
        assert!(!output.contains("<g"), "all groups should be removed: {output}");
        assert!(output.contains("<rect/>"));
    }

    #[test]
    fn moves_group_id_onto_single_child() {
        let input = r#"<svg xmlns="http://www.w3.org/2000/svg"><g id="layer1"><rect/></g></svg>"#;
        let mut doc = parse(input).unwrap();
        assert_eq!(CollapseGroups.run(&mut doc), PassResult::Changed);
        let output = serialize(&doc);
        assert!(!output.contains("<g"), "group should be collapsed: {output}");
        assert!(
            output.contains("id=\"layer1\""),
            "group id should move onto rect: {output}"
        );
    }

    #[test]
    fn clip_path_blocks_merge() {
        let input = r#"<svg xmlns="http://www.w3.org/2000/svg"><g clip-path="url(#clip1)"><rect/></g></svg>"#;
        let mut doc = parse(input).unwrap();
        assert_eq!(CollapseGroups.run(&mut doc), PassResult::Unchanged);
        let output = serialize(&doc);
        assert!(
            output.contains("<g"),
            "group with clip-path must be preserved: {output}"
        );
    }

    #[test]
    fn mask_blocks_merge() {
        let input =
            r#"<svg xmlns="http://www.w3.org/2000/svg"><g mask="url(#mask1)"><rect/></g></svg>"#;
        let mut doc = parse(input).unwrap();
        assert_eq!(CollapseGroups.run(&mut doc), PassResult::Unchanged);
        let output = serialize(&doc);
        assert!(
            output.contains("<g"),
            "group with mask must be preserved: {output}"
        );
    }

    #[test]
    fn filter_blocks_merge() {
        let input =
            r#"<svg xmlns="http://www.w3.org/2000/svg"><g filter="url(#blur)"><rect/></g></svg>"#;
        let mut doc = parse(input).unwrap();
        assert_eq!(CollapseGroups.run(&mut doc), PassResult::Unchanged);
        let output = serialize(&doc);
        assert!(
            output.contains("<g"),
            "group with filter must be preserved: {output}"
        );
    }

    #[test]
    fn opacity_single_child_merges() {
        let input = r#"<svg xmlns="http://www.w3.org/2000/svg"><g opacity="0.5"><rect/></g></svg>"#;
        let mut doc = parse(input).unwrap();
        assert_eq!(CollapseGroups.run(&mut doc), PassResult::Changed);
        let output = serialize(&doc);
        assert!(
            !output.contains("<g"),
            "single-child opacity group should collapse: {output}"
        );
        assert!(
            output.contains("opacity=\"0.5\""),
            "opacity should be on rect: {output}"
        );
    }

    #[test]
    fn opacity_multi_child_preserved() {
        let input = r#"<svg xmlns="http://www.w3.org/2000/svg"><g opacity="0.5"><rect fill="red"/><rect fill="blue"/></g></svg>"#;
        let mut doc = parse(input).unwrap();
        assert_eq!(CollapseGroups.run(&mut doc), PassResult::Unchanged);
        let output = serialize(&doc);
        assert!(
            output.contains("<g"),
            "multi-child opacity group must be preserved: {output}"
        );
    }

    #[test]
    fn inherited_attrs_cascade_correctly() {
        let input = r#"<svg xmlns="http://www.w3.org/2000/svg"><g fill="red"><g stroke="blue"><rect/></g><circle r="5"/></g></svg>"#;
        let mut doc = parse(input).unwrap();
        CollapseGroups.run(&mut doc);
        let output = serialize(&doc);
        assert!(
            output.contains("stroke=\"blue\""),
            "stroke should be on rect: {output}"
        );
        assert!(
            output.contains("fill=\"red\""),
            "fill should stay on outer group: {output}"
        );
        assert!(
            output.contains("<g"),
            "outer group must be preserved: {output}"
        );
    }
}