use html5ever::{
ParseOpts, parse_document,
serialize::{SerializeOpts, TraversalScope, serialize},
tendril::TendrilSink,
};
use markup5ever_rcdom::{Handle, NodeData, RcDom, SerializableHandle};
/// Element names that html5ever's serializer writes without an end tag.
///
/// `find_element_range_by_id` uses this list to stop at the start tag instead of
/// searching for a `</name>` that will never be emitted. Besides the current
/// HTML void elements it therefore also lists the legacy ones the HTML
/// fragment-serialization algorithm treats the same way (`basefont`, `bgsound`,
/// `frame`, `keygen`, `param`).
const VOID_ELEMENTS: &[&str] = &[
    "area", "base", "basefont", "bgsound", "br", "col", "embed", "frame", "hr", "img", "input",
    "keygen", "link", "meta", "param", "source", "track", "wbr",
];
/// A script found in a parsed document, in the form it should be run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScriptSource {
    /// JavaScript source text: either the text content of an inline `<script>`,
    /// or a statement generated by this module for a script it does not run.
    Inline(String),
    /// The trimmed, non-empty `src` attribute of an external script, exactly as
    /// written in the document (it is not resolved against any base URL here).
    External(String),
}
/// An HTML document parsed by html5ever into a reference-counted DOM.
///
/// Build one with [`parse`]. The tree itself stays private; callers use the
/// query and serialization methods instead.
pub struct Document {
    /// The DOM produced by the html5ever tree builder.
    dom: RcDom,
}
/// Parses `html` as a complete HTML document.
///
/// This never fails. Malformed markup is handled by html5ever's
/// error-recovering tree builder rather than rejected.
pub fn parse(html: &str) -> Document {
    // `TendrilSink::one` feeds the whole input and finishes the parse in a single
    // call. Unlike `from_utf8().read_from(..)`, it does not go through `io::Read`,
    // so there is no (impossible) I/O error to unwrap.
    let dom = parse_document(RcDom::default(), ParseOpts::default()).one(html);
    Document { dom }
}
impl Document {
    /// Returns `name` → `content` for every `<meta name="…" content="…">`.
    ///
    /// A `meta` element counts only when it has both attributes and a non-empty
    /// `name`. When several share a name, the last one in document order wins.
    pub fn collect_meta(&self) -> std::collections::HashMap<String, String> {
        let mut map = std::collections::HashMap::new();
        collect_meta_tags(&self.dom.document, &mut map);
        map
    }

    /// Returns the document's scripts in document order.
    ///
    /// Runnable scripts become [`ScriptSource::External`] (their `src`) or
    /// [`ScriptSource::Inline`] (their text). A non-JavaScript script with a
    /// `src` is reported as a generated [`ScriptSource::Inline`] statement.
    pub fn extract_scripts(&self) -> Vec<ScriptSource> {
        let mut out = Vec::new();
        collect_scripts(&self.dom.document, &mut out);
        out
    }

    /// Serializes the document back to HTML, optionally substituting new body
    /// markup and injecting extra markup before `</body>`.
    ///
    /// * `body_html` — ignored when empty. If its first start tag has an
    ///   `id="…"` that is also present in this document, that element (through
    ///   its matching end tag) is replaced by `body_html`. Otherwise everything
    ///   between the `<body …>` start tag and the last `</body>` is replaced. If
    ///   neither location is found, the serialized document is left unchanged.
    /// * `extra` — ignored when empty. It is inserted just before the last
    ///   `</body>` of the (already substituted) output, or appended at the end
    ///   when there is no `</body>`.
    pub fn serialize_with_body_and_injection(&self, body_html: &str, extra: &str) -> String {
        let mut bytes = Vec::new();
        // Writing into a `Vec<u8>` cannot produce an I/O error.
        serialize(
            &mut bytes,
            &SerializableHandle::from(self.dom.document.clone()),
            SerializeOpts {
                traversal_scope: TraversalScope::ChildrenOnly(None),
                ..Default::default()
            },
        )
        .expect("serialization failed");
        let mut html = String::from_utf8(bytes).expect("html5ever always outputs utf-8");

        if !body_html.is_empty() {
            let by_id =
                first_element_id(body_html).and_then(|id| find_element_range_by_id(&html, &id));
            if let Some(range) = by_id {
                html.replace_range(range, body_html);
            } else if let Some((start, end)) = body_content_range(&html) {
                html.replace_range(start..end, body_html);
            }
        }

        if !extra.is_empty() {
            match html.rfind("</body>") {
                Some(pos) => html.insert_str(pos, extra),
                None => html.push_str(extra),
            }
        }
        html
    }
}
/// Locates the byte span between the end of the first `<body …>` start tag and
/// the last `</body>` in `html`.
///
/// Returns `(content_start, content_end)`, or `None` when either marker is
/// missing or the closing tag comes before the end of the opening tag.
fn body_content_range(html: &str) -> Option<(usize, usize)> {
    let open_at = html.find("<body")?;
    let content_start = open_at + html[open_at..].find('>')? + 1;
    let content_end = html.rfind("</body>")?;
    (content_end >= content_start).then_some((content_start, content_end))
}
/// Returns the value of the `id="…"` attribute on the first start tag of `html`.
///
/// Only the leading tag is inspected, after skipping leading whitespace. It must
/// begin with `<` followed by an ASCII letter, so comments, doctypes, end tags
/// and plain text yield `None`. Only a double-quoted `id` attribute preceded by
/// whitespace counts, so `data-id="…"` or `grid="…"` are not mistaken for it.
/// An empty id yields `None`.
fn first_element_id(html: &str) -> Option<String> {
    // `strip_prefix` guarantees the next slice starts on a char boundary; the old
    // `&s[1..]` slicing panicked on a leading multi-byte char or a leading `>`.
    let after_lt = html.trim_start().strip_prefix('<')?;
    if !after_lt.starts_with(|c: char| c.is_ascii_alphabetic()) {
        return None;
    }
    let tag = &after_lt[..after_lt.find('>')?];
    const MARKER: &str = "id=\"";
    let value_start = tag
        .match_indices(MARKER)
        .map(|(pos, _)| pos)
        .find(|&pos| tag[..pos].ends_with(|c: char| c.is_ascii_whitespace()))?
        + MARKER.len();
    let value_len = tag[value_start..].find('"')?;
    let id = &tag[value_start..value_start + value_len];
    (!id.is_empty()).then(|| id.to_owned())
}
/// Finds the byte range of the first element carrying `id="{id}"` in serialized
/// HTML.
///
/// The range runs from the element's `<` through the end of its matching
/// `</name>`. For elements in `VOID_ELEMENTS`, or start tags ending in `/>`, it
/// covers the start tag alone. Returns `None` if no such attribute is found or
/// the element is never closed.
///
/// This is a textual scan that assumes serializer-style output: double-quoted
/// attributes and an explicit end tag for every non-void element. It does not
/// understand raw-text content, so `<name` / `</name>` sequences inside, e.g., a
/// nested `<script>` can skew the nesting count.
fn find_element_range_by_id(html: &str, id: &str) -> Option<std::ops::Range<usize>> {
    let needle = format!("id=\"{id}\"");
    // Accept only a real `id` attribute. It must be preceded by whitespace, so
    // `data-id="…"` does not match. It must also sit inside a start tag: no `>`
    // between the owning `<` and the attribute, and a tag name after the `<`.
    // This keeps text content that merely mentions `id="…"` from matching.
    let (tag_start, attr_end) = html.match_indices(needle.as_str()).find_map(|(pos, _)| {
        let before = &html[..pos];
        if !before.ends_with(|c: char| c.is_ascii_whitespace()) {
            return None;
        }
        let lt = before.rfind('<')?;
        let inside_start_tag = !before[lt..].contains('>')
            && html[lt + 1..].starts_with(|c: char| c.is_ascii_alphabetic());
        inside_start_tag.then_some((lt, pos + needle.len()))
    })?;

    let after_lt = &html[tag_start + 1..];
    let name_len = after_lt.find(|c: char| c.is_ascii_whitespace() || c == '>' || c == '/')?;
    let tag_name = after_lt[..name_len].to_ascii_lowercase();
    // Search for the start tag's `>` only after the matched attribute.
    let open_end = attr_end + html[attr_end..].find('>')? + 1;
    if VOID_ELEMENTS.contains(&tag_name.as_str()) || html[tag_start..open_end].ends_with("/>") {
        return Some(tag_start..open_end);
    }

    // Walk forward counting nested same-name elements until the matching close.
    let open_pat = format!("<{tag_name}");
    let close_pat = format!("</{tag_name}>");
    let mut depth: usize = 1;
    let mut pos = open_end;
    loop {
        let rest = &html[pos..];
        let next_close = pos + rest.find(&close_pat)?;
        match rest.find(&open_pat).map(|p| pos + p) {
            Some(open) if open < next_close => {
                // `<divx` is a different element; only a real `<div` start tag nests.
                let after = html.as_bytes().get(open + open_pat.len()).copied();
                if matches!(
                    after,
                    Some(b' ' | b'\t' | b'\n' | b'\r' | b'\x0c' | b'>' | b'/')
                ) {
                    depth += 1;
                }
                pos = open + open_pat.len();
            }
            _ => {
                depth -= 1;
                let close_end = next_close + close_pat.len();
                if depth == 0 {
                    return Some(tag_start..close_end);
                }
                pos = close_end;
            }
        }
    }
}
fn collect_scripts(handle: &Handle, out: &mut Vec<ScriptSource>) {
if let NodeData::Element {
ref name,
ref attrs,
..
} = handle.data
&& &name.local == "script"
{
let attrs = attrs.borrow();
let type_val = attrs
.iter()
.find(|a| &a.name.local == "type")
.map(|a| a.value.to_string());
if let Some(ref t) = type_val {
let t = t.trim().to_ascii_lowercase();
let executable = match t.as_str() {
""
| "text/javascript"
| "application/javascript"
| "module"
| "text/rocketscript" => true,
t => t.ends_with("-text/javascript") || t.ends_with("-application/javascript"),
};
if !executable {
let src = attrs
.iter()
.find(|a| &a.name.local == "src")
.map(|a| a.value.trim().to_string());
if let Some(src) = src.filter(|s| !s.is_empty()) {
let t_escaped = t.replace('\'', "\\'");
let s_escaped = src.replace('\\', "\\\\").replace('\'', "\\'");
out.push(ScriptSource::Inline(format!(
"_r_nonstandard_scripts.push({{type:'{t_escaped}',src:'{s_escaped}',\
getAttribute:function(n){{return n==='src'?this.src:n==='type'?this.type:null;}},\
innerHTML:''}});"
)));
}
return;
}
}
let src = attrs
.iter()
.find(|a| &a.name.local == "src")
.map(|a| a.value.trim().to_string());
if let Some(src) = src {
if !src.is_empty() {
out.push(ScriptSource::External(src));
}
} else {
let mut content = String::new();
for child in handle.children.borrow().iter() {
if let NodeData::Text { ref contents } = child.data {
content.push_str(&contents.borrow());
}
}
if !content.trim().is_empty() {
out.push(ScriptSource::Inline(content));
}
}
return;
}
for child in handle.children.borrow().iter() {
collect_scripts(child, out);
}
}
/// Walks the tree under `handle` depth-first, recording each
/// `<meta name="…" content="…">` into `map` keyed by `name`.
///
/// A `meta` element contributes only when both attributes exist and its `name`
/// is non-empty. A later element with the same name overwrites an earlier one.
fn collect_meta_tags(handle: &Handle, map: &mut std::collections::HashMap<String, String>) {
    if let NodeData::Element {
        ref name,
        ref attrs,
        ..
    } = handle.data
        && &name.local == "meta"
    {
        let attrs = attrs.borrow();
        let value_of = |key: &str| {
            attrs
                .iter()
                .find(|attr| &attr.name.local == key)
                .map(|attr| attr.value.to_string())
        };
        match (value_of("name"), value_of("content")) {
            (Some(meta_name), Some(content)) if !meta_name.is_empty() => {
                map.insert(meta_name, content);
            }
            _ => {}
        }
    }
    handle
        .children
        .borrow()
        .iter()
        .for_each(|child| collect_meta_tags(child, map));
}