use std::{
borrow::Cow,
collections::HashMap,
fmt::{self, Write},
};
/// Selects which quarter-circle arc [`PathData::arc`] draws.
///
/// `PathData::arc` maps each variant to an SVG sweep flag and a relative end
/// point of `(±radius, ±radius)`.
/// NOTE(review): the `<From>To<To>` names presumably describe the direction of
/// travel before and after the turn — confirm against the call sites.
#[derive(Debug, Clone, Copy)]
pub enum Arc {
EastToNorth,
EastToSouth,
NorthToEast,
NorthToWest,
SouthToEast,
SouthToWest,
WestToNorth,
WestToSouth,
}
/// Horizontal drawing direction: left-to-right (the default) or right-to-left.
///
/// [`PathData`] uses it to decide which way the arrowheads on horizontal
/// lines point.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum HDir {
    /// Left to right.
    #[default]
    LTR,
    /// Right to left.
    RTL,
}

impl HDir {
    /// Returns the opposite direction.
    #[must_use]
    pub fn invert(self) -> Self {
        match self {
            Self::RTL => Self::LTR,
            Self::LTR => Self::RTL,
        }
    }
}
/// Streaming XML/SVG writer: every method writes its markup straight into `out`.
pub struct Renderer<'a> {
/// Destination for all emitted markup.
out: &'a mut dyn fmt::Write,
}
/// An open start tag: `<name` has already been written, and attributes can be
/// added until [`StartTag::finish`] or [`StartTag::finish_empty`] closes it.
pub struct StartTag<'a, 'b> {
/// The renderer this tag writes into; it stays mutably borrowed while the tag is open.
renderer: &'a mut Renderer<'b>,
}
/// `fmt::Write` adapter that entity-escapes everything written through it
/// (via `write_escaped_minimal`) before forwarding it to `out`.
struct EscapingWriter<'a> {
/// The writer that receives the escaped text.
out: &'a mut dyn fmt::Write,
}
impl<'a> Renderer<'a> {
pub fn new(out: &'a mut dyn fmt::Write) -> Self {
Self { out }
}
pub fn start_element<'b>(&'b mut self, name: &str) -> Result<StartTag<'b, 'a>, fmt::Error> {
validate_tag_name(name)?;
self.out.write_char('<')?;
self.out.write_str(name)?;
Ok(StartTag { renderer: self })
}
pub fn end_element(&mut self, name: &str) -> fmt::Result {
validate_tag_name(name)?;
self.out.write_str("</")?;
self.out.write_str(name)?;
self.out.write_str(">\n")
}
pub fn write_text(&mut self, text: &str) -> fmt::Result {
let mut escaping = EscapingWriter { out: self.out };
escaping.write_str(text)
}
pub fn write_raw(&mut self, text: &str) -> fmt::Result {
self.out.write_str(text)
}
pub fn write_display(&mut self, display: impl fmt::Display) -> fmt::Result {
write!(self.out, "{display}")
}
pub fn path(&mut self, path: &PathData) -> fmt::Result {
let mut tag = self.start_element("path")?;
tag.attr("d", path)?;
tag.finish_empty()
}
pub fn path_with_class(&mut self, path: &PathData, class: &str) -> fmt::Result {
let mut tag = self.start_element("path")?;
tag.attr("d", path)?;
tag.attr("class", class)?;
tag.finish_empty()
}
pub fn text_element(
&mut self,
name: &str,
text: &str,
configure: impl FnOnce(&mut StartTag<'_, 'a>) -> fmt::Result,
) -> fmt::Result {
let mut tag = self.start_element(name)?;
configure(&mut tag)?;
tag.finish()?;
self.write_text(text)?;
self.end_element(name)
}
pub fn raw_text_element(
&mut self,
name: &str,
text: &str,
configure: impl FnOnce(&mut StartTag<'_, 'a>) -> fmt::Result,
) -> fmt::Result {
let mut tag = self.start_element(name)?;
configure(&mut tag)?;
tag.finish()?;
self.write_raw(text)?;
self.end_element(name)
}
}
impl StartTag<'_, '_> {
pub fn attr(&mut self, key: impl fmt::Display, value: impl fmt::Display) -> fmt::Result {
self.renderer.out.write_char(' ')?;
{
let mut escaping = EscapingWriter {
out: self.renderer.out,
};
write!(&mut escaping, "{key}")?;
}
self.renderer.out.write_str("=\"")?;
{
let mut escaping = EscapingWriter {
out: self.renderer.out,
};
write!(&mut escaping, "{value}")?;
}
self.renderer.out.write_char('"')
}
pub fn attr_hashmap(&mut self, attrs: &HashMap<String, String>) -> fmt::Result {
let mut attrs = attrs.iter().collect::<Vec<_>>();
attrs.sort_by_key(|(k, _)| *k);
for (key, value) in attrs {
self.attr(key, value)?;
}
Ok(())
}
pub fn finish(self) -> fmt::Result {
self.renderer.out.write_str(">\n")
}
pub fn finish_empty(self) -> fmt::Result {
self.renderer.out.write_str("/>\n")
}
}
impl fmt::Write for EscapingWriter<'_> {
    /// Forwards `s` to the wrapped writer with the minimal entity set applied.
    fn write_str(&mut self, s: &str) -> fmt::Result {
        let sink: &mut dyn fmt::Write = &mut *self.out;
        write_escaped_minimal(sink, s)
    }
}
/// Accepts `name` only if it is a non-empty ASCII XML name.
///
/// The first character must be a letter, `_` or `:`. Later characters may also
/// be digits, `-` or `.`.
///
/// # Errors
///
/// Returns `fmt::Error` for an empty name or any character outside those sets.
fn validate_tag_name(name: &str) -> fmt::Result {
    let is_name_start = |c: char| c.is_ascii_alphabetic() || c == '_' || c == ':';
    let is_name_char = |c: char| is_name_start(c) || c.is_ascii_digit() || c == '-' || c == '.';
    let mut chars = name.chars();
    match chars.next() {
        Some(first) if is_name_start(first) && chars.all(is_name_char) => Ok(()),
        _ => Err(fmt::Error),
    }
}
/// Builder for the `d` attribute of an SVG `<path>`.
pub struct PathData {
/// Accumulated path commands; each command is appended with a leading space.
text: String,
/// Decides which way the arrowheads drawn by `horizontal` point.
h_dir: HDir,
}
impl PathData {
#[must_use]
pub fn new(h_dir: HDir) -> Self {
Self {
text: String::new(),
h_dir,
}
}
#[must_use]
pub fn into_path(self) -> Element {
Element::new("path").set("d", &self.text)
}
#[must_use]
pub fn move_to(mut self, x: i64, y: i64) -> Self {
write!(self.text, " M {x} {y}").unwrap();
self
}
#[must_use]
pub fn move_rel(mut self, x: i64, y: i64) -> Self {
write!(self.text, " m {x} {y}").unwrap();
self
}
#[must_use]
pub fn line_rel(mut self, x: i64, y: i64) -> Self {
write!(self.text, " l {x} {y}").unwrap();
self
}
#[must_use]
pub fn horizontal(mut self, h: i64) -> Self {
write!(self.text, " h {h}").unwrap();
match (h > 50, h < -50, self.h_dir) {
(true, _, HDir::LTR) => self
.move_rel(-(h / 2 - 3), 0)
.line_rel(-5, -5)
.move_rel(0, 10)
.line_rel(5, -5)
.move_rel(h / 2 - 3, 0),
(true, _, HDir::RTL) => self
.move_rel(-(h / 2 + 3), 0)
.line_rel(5, -5)
.move_rel(0, 10)
.line_rel(-5, -5)
.move_rel(h / 2 + 3, 0),
(_, true, HDir::LTR) => self
.move_rel(-(h / 2 - 3), 0)
.line_rel(5, -5)
.move_rel(0, 10)
.line_rel(-5, -5)
.move_rel(h / 2 - 3, 0),
(_, true, HDir::RTL) => self
.move_rel(-(h / 2 + 3), 0)
.line_rel(-5, -5)
.move_rel(0, 10)
.line_rel(5, -5)
.move_rel(h / 2 + 3, 0),
(false, false, _) => self,
}
}
#[must_use]
pub fn vertical(mut self, h: i64) -> Self {
write!(self.text, " v {h}").unwrap();
if h > 50 {
self.move_rel(0, -(h / 2 - 3))
.line_rel(-5, -5)
.move_rel(10, 0)
.line_rel(-5, 5)
.move_rel(0, h / 2 - 3)
} else if h < -50 {
self.move_rel(0, -(h / 2 - 3))
.line_rel(-5, 5)
.move_rel(10, 0)
.line_rel(-5, -5)
.move_rel(0, h / 2 - 3)
} else {
self
}
}
#[must_use]
pub fn arc(mut self, radius: i64, kind: Arc) -> Self {
let (sweep, x, y) = match kind {
Arc::EastToNorth => (1, -radius, -radius),
Arc::EastToSouth => (0, -radius, radius),
Arc::NorthToEast => (0, radius, radius),
Arc::NorthToWest => (1, -radius, radius),
Arc::SouthToEast => (1, radius, -radius),
Arc::SouthToWest => (0, -radius, -radius),
Arc::WestToNorth => (0, radius, -radius),
Arc::WestToSouth => (1, radius, radius),
};
write!(self.text, " a {radius} {radius} 0 0 {sweep} {x} {y}").unwrap();
self
}
}
impl fmt::Display for PathData {
    /// Emits the accumulated path commands exactly as they were appended.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.text)
    }
}
/// A node of a retained XML element tree, serialized through its `Display` impl.
#[derive(Debug, Clone)]
pub struct Element {
/// Tag name; written verbatim when rendered.
name: String,
/// Attributes, stored unescaped; keys and values are entity-escaped when rendered.
attributes: HashMap<String, String>,
/// Text content, written verbatim when rendered (already escaped if set via `text`).
text: Option<String>,
/// Elements rendered inside this one, in insertion order.
children: Vec<Element>,
/// Elements rendered right after this one, at the same nesting level.
siblings: Vec<Element>,
}
impl Element {
    /// Creates an element named `name` with no attributes, text, children or siblings.
    pub fn new<T>(name: &T) -> Self
    where
        T: ToString + ?Sized,
    {
        Self {
            name: name.to_string(),
            attributes: HashMap::new(),
            text: None,
            children: Vec::new(),
            siblings: Vec::new(),
        }
    }

    /// Sets attribute `key` to `value`, replacing any earlier value for that key.
    ///
    /// Both are stored as given and entity-escaped only when the element is rendered.
    #[must_use]
    pub fn set<K, V>(mut self, key: &K, value: &V) -> Self
    where
        K: ToString + ?Sized,
        V: ToString + ?Sized,
    {
        let key = key.to_string();
        let value = value.to_string();
        self.attributes.insert(key, value);
        self
    }

    /// Sets every `(key, value)` pair from `iter`; later pairs win on duplicate keys.
    #[must_use]
    pub fn set_all(
        mut self,
        iter: impl IntoIterator<Item = (impl ToString, impl ToString)>,
    ) -> Self {
        for (key, value) in iter {
            self.attributes.insert(key.to_string(), value.to_string());
        }
        self
    }

    /// Sets the text content, entity-escaping it first.
    #[must_use]
    pub fn text(mut self, text: &str) -> Self {
        let escaped = encode_minimal(text);
        self.text = Some(escaped.into_owned());
        self
    }

    /// Sets the text content verbatim, without any escaping.
    #[must_use]
    pub fn raw_text<T>(mut self, text: &T) -> Self
    where
        T: ToString + ?Sized,
    {
        self.text = Some(text.to_string());
        self
    }

    /// Appends `e` as the last child (by-value builder form of [`Element::push`]).
    // Named `add` as a builder method; it is intentionally not `std::ops::Add`.
    #[allow(clippy::should_implement_trait)]
    #[must_use]
    pub fn add(mut self, e: Self) -> Self {
        self.push(e);
        self
    }

    /// Appends `e` as the last child, in place.
    pub fn push(&mut self, e: Self) -> &mut Self {
        self.children.push(e);
        self
    }

    /// Queues `e` to be rendered right after this element, at the same level.
    #[must_use]
    pub fn append(mut self, e: Self) -> Self {
        self.siblings.push(e);
        self
    }

    /// No-op unless the `visual-debug` feature is enabled.
    #[cfg(not(feature = "visual-debug"))]
    #[allow(unused_variables)]
    #[doc(hidden)]
    #[must_use]
    pub fn debug(self, name: &str, x: i64, y: i64, n: &dyn super::Node) -> Self {
        self
    }

    /// No-op unless the `visual-debug` feature is enabled.
    #[cfg(not(feature = "visual-debug"))]
    #[allow(unused_variables)]
    #[doc(hidden)]
    #[must_use]
    pub fn debug_with_geometry(
        self,
        name: &str,
        x: i64,
        y: i64,
        geo: &super::NodeGeometry,
    ) -> Self {
        self
    }

    /// Annotates the element with `n`'s geometry and appends a debug outline.
    #[cfg(feature = "visual-debug")]
    pub fn debug(self, name: &str, x: i64, y: i64, n: &dyn super::Node) -> Self {
        let (entry_height, height, width) = (n.entry_height(), n.height(), n.width());
        self.debug_overlay(name, x, y, entry_height, height, width)
    }

    /// Annotates the element with `geo` and appends a debug outline.
    #[cfg(feature = "visual-debug")]
    pub fn debug_with_geometry(
        self,
        name: &str,
        x: i64,
        y: i64,
        geo: &super::NodeGeometry,
    ) -> Self {
        self.debug_overlay(name, x, y, geo.entry_height, geo.height, geo.width)
    }

    /// Shared body of the `visual-debug` variants.
    ///
    /// Adds `railroad:*` attributes and a `<title>` child, then appends a
    /// `class="debug"` path tracing the node's bounding box and entry line.
    #[cfg(feature = "visual-debug")]
    fn debug_overlay(
        self,
        name: &str,
        x: i64,
        y: i64,
        entry_height: i64,
        height: i64,
        width: i64,
    ) -> Self {
        let outline = PathData::new(HDir::LTR)
            .move_to(x, y)
            .horizontal(width)
            .vertical(5)
            .move_rel(-width, -5)
            .vertical(height)
            .horizontal(5)
            .move_rel(-5, -height)
            .move_rel(0, entry_height)
            .horizontal(10);
        self.set("railroad:type", &name)
            .set("railroad:x", &x)
            .set("railroad:y", &y)
            .set("railroad:entry_height", &entry_height)
            .set("railroad:height", &height)
            .set("railroad:width", &width)
            .add(Element::new("title").text(name))
            .append(
                Element::new("path")
                    .set("d", &outline)
                    .set("class", "debug"),
            )
    }
}
impl ::std::fmt::Display for Element {
    /// Serializes the element, then its children inside it, then its siblings.
    ///
    /// Attributes are written in key order (so the output is deterministic),
    /// with keys and values entity-escaped. The element name and `text` are
    /// written verbatim. `text` is already escaped when it was set through
    /// `Element::text`, but not when set through `Element::raw_text`. An element
    /// with neither text nor children is written as self-closing.
    fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> Result<(), ::std::fmt::Error> {
        let has_body = self.text.is_some() || !self.children.is_empty();

        write!(f, "<{}", self.name)?;
        let mut attributes: Vec<_> = self.attributes.iter().collect();
        // Keys are unique, so an unstable sort gives the same order as a stable one.
        attributes.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        for (key, value) in attributes {
            write!(f, " {}=\"{}\"", encode_minimal(key), encode_minimal(value))?;
        }
        f.write_str(if has_body { ">\n" } else { "/>\n" })?;

        if let Some(text) = self.text.as_deref() {
            f.write_str(text)?;
        }
        for child in &self.children {
            write!(f, "{child}")?;
        }
        if has_body {
            writeln!(f, "</{}>", self.name)?;
        }

        self.siblings
            .iter()
            .try_for_each(|sibling| write!(f, "{sibling}"))
    }
}
/// Returns the replacement text for the five characters escaped by the
/// "minimal" encoders (`"`, `&`, `<`, `>` and `'`), or `None` for any other
/// character, which is emitted unchanged.
///
/// Every character matched here is ASCII, i.e. a single UTF-8 byte.
fn minimal_entity(c: char) -> Option<&'static str> {
match c {
'"' => Some("""),
'&' => Some("&"),
'<' => Some("<"),
'>' => Some(">"),
'\'' => Some("'"),
_ => None,
}
}
/// Streams `inp` into `f`.
///
/// Each character that [`minimal_entity`] recognizes is replaced by its
/// entity, and the runs of text between them are copied through unchanged.
///
/// # Errors
///
/// Propagates the first error returned by `f`.
fn write_escaped_minimal(f: &mut (impl fmt::Write + ?Sized), inp: &str) -> fmt::Result {
    let mut copied_up_to = 0;
    let specials = inp
        .char_indices()
        .filter_map(|(idx, c)| minimal_entity(c).map(|entity| (idx, c, entity)));
    for (idx, c, entity) in specials {
        // Flush the untouched text preceding the special character, then its entity.
        f.write_str(&inp[copied_up_to..idx])?;
        f.write_str(entity)?;
        copied_up_to = idx + c.len_utf8();
    }
    f.write_str(&inp[copied_up_to..])
}
#[must_use]
pub fn encode_minimal(inp: &str) -> Cow<'_, str> {
let mut buf = String::new();
let mut last_idx = 0;
for (idx, c) in inp.char_indices() {
if let Some(entity) = minimal_entity(c) {
buf.push_str(&inp[last_idx..idx]);
buf.push_str(entity);
last_idx = idx + 1;
}
}
if buf.is_empty() {
Cow::Borrowed(inp)
} else {
buf.push_str(&inp[last_idx..]);
Cow::Owned(buf)
}
}
/// Replacement table used by [`encode_attribute`], indexed by code point
/// (U+0000..=U+00FF, i.e. any `char` that fits in a `u8`).
///
/// `Some(entity)` means the character is replaced by `entity`; `None` means it
/// is copied through unchanged. Only the ASCII digits and letters map to `None`.
/// Every other character in this range is replaced, including ASCII space and
/// punctuation. Characters above U+00FF are never looked up here.
const ENTITIES: [Option<&'static str>; 256] = [
// 0x00..=0x1F: C0 control characters.
Some("�"),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some("	"),
Some("
"),
Some(""),
Some(""),
Some("
"),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
Some(""),
// 0x20..=0x2F: space and punctuation.
Some(" "),
Some("!"),
Some("""),
Some("#"),
Some("$"),
Some("%"),
Some("&"),
Some("'"),
Some("("),
Some(")"),
Some("*"),
Some("+"),
Some(","),
Some("-"),
Some("."),
Some("/"),
// 0x30..=0x39: digits '0'..='9' pass through.
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
// 0x3A..=0x40: punctuation.
Some(":"),
Some(";"),
Some("<"),
Some("="),
Some(">"),
Some("?"),
Some("@"),
// 0x41..=0x5A: 'A'..='Z' pass through.
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
// 0x5B..=0x60: punctuation.
Some("["),
Some("\"),
Some("]"),
Some("^"),
Some("_"),
Some("`"),
// 0x61..=0x7A: 'a'..='z' pass through.
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
// 0x7B..=0x7F: punctuation and DEL.
Some("{"),
Some("|"),
Some("}"),
Some("~"),
Some(""),
// 0x80..=0xFF: C1 controls and the Latin-1 Supplement.
Some("€"),
Some(""),
Some("‚"),
Some("ƒ"),
Some("„"),
Some("…"),
Some("†"),
Some("‡"),
Some("ˆ"),
Some("‰"),
Some("Š"),
Some("‹"),
Some("Œ"),
Some(""),
Some("Ž"),
Some(""),
Some(""),
Some("‘"),
Some("’"),
Some("“"),
Some("”"),
Some("•"),
Some("–"),
Some("—"),
Some("˜"),
Some("™"),
Some("š"),
Some("›"),
Some("œ"),
Some(""),
Some("ž"),
Some("Ÿ"),
Some(" "),
Some("¡"),
Some("¢"),
Some("£"),
Some("¤"),
Some("¥"),
Some("¦"),
Some("§"),
Some("¨"),
Some("©"),
Some("ª"),
Some("«"),
Some("¬"),
Some("­"),
Some("®"),
Some("¯"),
Some("°"),
Some("±"),
Some("²"),
Some("³"),
Some("´"),
Some("µ"),
Some("¶"),
Some("·"),
Some("¸"),
Some("¹"),
Some("º"),
Some("»"),
Some("¼"),
Some("½"),
Some("¾"),
Some("¿"),
Some("À"),
Some("Á"),
Some("Â"),
Some("Ã"),
Some("Ä"),
Some("Å"),
Some("Æ"),
Some("Ç"),
Some("È"),
Some("É"),
Some("Ê"),
Some("Ë"),
Some("Ì"),
Some("Í"),
Some("Î"),
Some("Ï"),
Some("Ð"),
Some("Ñ"),
Some("Ò"),
Some("Ó"),
Some("Ô"),
Some("Õ"),
Some("Ö"),
Some("×"),
Some("Ø"),
Some("Ù"),
Some("Ú"),
Some("Û"),
Some("Ü"),
Some("Ý"),
Some("Þ"),
Some("ß"),
Some("à"),
Some("á"),
Some("â"),
Some("ã"),
Some("ä"),
Some("å"),
Some("æ"),
Some("ç"),
Some("è"),
Some("é"),
Some("ê"),
Some("ë"),
Some("ì"),
Some("í"),
Some("î"),
Some("ï"),
Some("ð"),
Some("ñ"),
Some("ò"),
Some("ó"),
Some("ô"),
Some("õ"),
Some("ö"),
Some("÷"),
Some("ø"),
Some("ù"),
Some("ú"),
Some("û"),
Some("ü"),
Some("ý"),
Some("þ"),
Some("ÿ"),
];
#[must_use]
pub fn encode_attribute(inp: &str) -> Cow<'_, str> {
let mut buf = String::new();
let mut last_idx = 0;
for (idx, c) in inp.char_indices() {
if let Ok(b) = <char as TryInto<u8>>::try_into(c)
&& let Some(entity) = ENTITIES[b as usize]
{
let fragment = &inp[last_idx..idx];
buf.reserve(fragment.len() + entity.len());
buf.push_str(fragment);
buf.push_str(entity);
last_idx = idx + c.len_utf8();
}
}
if buf.is_empty() {
Cow::Borrowed(inp)
} else {
buf.push_str(&inp[last_idx..]);
Cow::Owned(buf)
}
}
#[cfg(test)]
mod tests {
use std::borrow::Cow;
// Each case pairs an input with its expected escaped form, or `None` when the
// input must come back unchanged AND borrowed (no allocation).
#[test]
fn encode_minimal() {
for (inp, expected) in [
("'a", Some("'a")),
("", None),
("'", Some("'")),
("a'", Some("a'")),
("hello world!", None),
("&", Some("&")),
("<br>", Some("<br>")),
(
"\"a\" is not \"b\"",
Some(""a" is not "b""),
),
] {
let result = super::encode_minimal(inp);
assert_eq!(result, expected.unwrap_or(inp));
// Borrowed exactly when nothing was escaped; Owned otherwise.
assert!(matches!(
(expected, result),
(None, Cow::Borrowed(_)) | (Some(_), Cow::Owned(_))
));
}
}
// Same shape as above for the table-driven attribute encoder: only ASCII
// alphanumerics and characters outside the table's range may pass unchanged.
#[test]
fn test_encode_attribute() {
let data = [
("", None),
("foobar", None),
("0 3px", Some("0 3px")),
("<img \"\"\">", Some("<img """>")),
("hej; hå", Some("hej; hå")),
("d-none m-0", Some("d-none m-0")),
(
"\"bread\" & 奶油",
Some(""bread" & 奶油"),
),
];
for &(input, expected) in data.iter() {
let actual = super::encode_attribute(input);
assert_eq!(&actual, expected.unwrap_or(input));
assert!(matches!(
(expected, actual),
(Some(_), Cow::Owned(_)) | (None, Cow::Borrowed(_))
));
}
}
// Strings that would break out of an attribute value or text node if they
// reached the output verbatim.
const PAYLOADS: &[&str] = &[
r#""><script>alert(1)</script>"#,
r#"' onload='alert(1)"#,
r#"<not-an-entity"#,
r#"</style><script>bad</script>"#,
r#"foo & bar"#,
r#"foo"bar"#,
];
// An attribute value must never appear unescaped in the rendered element.
#[test]
fn element_attribute_value_no_injection() {
for payload in PAYLOADS {
let svg = format!("{}", super::Element::new("g").set("data-x", *payload));
assert!(
!svg.contains(payload),
"raw payload appeared in SVG for value {payload:?}"
);
}
}
// An attribute key must never appear unescaped in the rendered element.
#[test]
fn element_attribute_key_no_injection() {
for payload in PAYLOADS {
let svg = format!("{}", super::Element::new("g").set(*payload, "value"));
assert!(
!svg.contains(payload),
"raw payload appeared in SVG for key {payload:?}"
);
}
}
// Text set through `Element::text` must never appear unescaped.
#[test]
fn element_text_no_injection() {
for payload in PAYLOADS {
let svg = format!("{}", super::Element::new("text").text(payload));
assert!(
!svg.contains(payload),
"raw payload appeared in SVG text for {payload:?}"
);
}
}
// Both opening and closing tags must refuse empty names, names starting with
// a digit, and names containing whitespace or markup.
#[test]
fn renderer_rejects_invalid_tag_names() {
for payload in [
"",
"path onclick=\"alert(1)\"",
"path><script>",
"9path",
"svg tag",
] {
let mut output = String::new();
let mut renderer = super::Renderer::new(&mut output);
assert!(renderer.start_element(payload).is_err());
assert!(renderer.end_element(payload).is_err());
}
}
}