use std::{
fmt::{Debug, Formatter},
sync::Arc,
collections::HashMap,
};
use ego_tree::{NodeId, NodeRef};
use scraper::{ElementRef, Html, Node, Selector};
use crate::xpath::{parse_xpath, NodeAccessor};
use crate::errors::{Result, PackageError};
/// Result of an XPath-style query: either a matched element, or a string produced
/// by a `text()` / attribute accessor.
#[derive(Debug)]
pub enum XPathResult {
    /// A matched element.
    Node(HtmlNode),
    /// A string value (text content or attribute value).
    String(String),
}

impl XPathResult {
    /// Borrows the string payload; `None` when this result is a node.
    pub fn as_string(&self) -> Option<&str> {
        if let Self::String(value) = self {
            Some(value.as_str())
        } else {
            None
        }
    }

    /// Borrows the node payload; `None` when this result is a string.
    pub fn as_node(&self) -> Option<&HtmlNode> {
        if let Self::Node(found) = self {
            Some(found)
        } else {
            None
        }
    }

    /// Consumes the result, yielding the string payload if there is one.
    pub fn into_string(self) -> Option<String> {
        if let Self::String(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Consumes the result, yielding the node payload if there is one.
    pub fn into_node(self) -> Option<HtmlNode> {
        if let Self::Node(found) = self {
            Some(found)
        } else {
            None
        }
    }
}
/// A parsed HTML document (or fragment) together with the markup it was parsed from.
///
/// The parsed tree lives behind an `Arc`, so cloning a document is cheap and every
/// `HtmlNode` handed out by it shares the same tree.
#[derive(Clone)]
pub struct HtmlDocument {
// Markup passed to the constructor; returned verbatim by `raw()`.
raw: String,
// Parsed tree, shared (via `Arc` clones) with every `HtmlNode` created from this document.
dom: Arc<Html>,
// True when `dom` was built with `Html::parse_fragment` instead of `Html::parse_document`;
// decides which node `root()` returns.
is_fragment: bool,
}
// SAFETY: NOTE(review): these impls assert that the shared `Arc<Html>` may cross threads.
// `scraper::Html` is presumably not `Send`/`Sync` by itself (its strings are stored in
// non-atomically reference-counted `tendril` buffers unless scraper's `atomic` feature is
// enabled), so soundness relies on no code path cloning or dropping those buffers from
// several threads at once. Nothing in the type system enforces this — confirm, or enable
// scraper's `atomic` feature and remove these impls.
unsafe impl Send for HtmlDocument {}
unsafe impl Sync for HtmlDocument {}
impl HtmlDocument {
    /// Wraps an already-parsed tree.
    ///
    /// * `raw` - the markup `dom` was parsed from (returned by [`HtmlDocument::raw`]).
    /// * `dom` - the parsed tree; moved into an `Arc` so nodes can share it.
    /// * `is_fragment` - whether `dom` was produced by `Html::parse_fragment`;
    ///   [`HtmlDocument::root`] uses it to choose the root node.
    pub fn new(raw: String, dom: Html, is_fragment: bool) -> Self {
        Self {
            raw,
            dom: Arc::new(dom),
            is_fragment,
        }
    }

    /// Parses `html` as a full document if it contains the substring `"html"`
    /// (case-sensitive), and as a fragment otherwise.
    // NOTE(review): the substring test also fires on text or attribute values such as
    // `href="page.html"` and misses upper-case `<HTML>`; callers that need a specific
    // mode can parse the markup themselves and use `HtmlDocument::new`.
    pub fn from_str(html: String) -> Self {
        let is_fragment = !html.contains("html");
        let dom = if is_fragment {
            Html::parse_fragment(&html)
        } else {
            Html::parse_document(&html)
        };
        Self::new(html, dom, is_fragment)
    }

    /// Returns the markup this document was built from.
    pub fn raw(&self) -> &str {
        &self.raw
    }

    /// Returns the node that all document-level queries start from.
    ///
    /// For a full document this is the root `<html>` element. For a fragment it is
    /// the first *element* child of the wrapper element produced by
    /// `Html::parse_fragment`. Leading text or comment nodes (for example a leading
    /// newline or whitespace) are skipped, because `HtmlNode` methods panic on
    /// non-element nodes. If the fragment contains no element at all, the wrapper
    /// element itself is returned instead of panicking.
    pub fn root(&self) -> HtmlNode {
        let root_element = self.dom.root_element();
        let root_id = if self.is_fragment {
            root_element
                .children()
                .find(|child| child.value().is_element())
                .map_or(root_element.id(), |child| child.id())
        } else {
            root_element.id()
        };
        HtmlNode::new(self.dom.clone(), root_id)
    }

    // The query methods below all delegate to `root()`. For fragments this means only
    // the subtree of the first top-level element is searched.

    /// Returns every element under [`HtmlDocument::root`] matching the CSS `selector`.
    pub fn find_all(&self, selector: &str) -> Result<Vec<HtmlNode>> {
        self.root().find_all(selector)
    }

    /// Evaluates an XPath-style expression relative to [`HtmlDocument::root`].
    pub fn find_all_xpath(&self, xpath: &str) -> Result<Vec<XPathResult>> {
        self.root().find_all_xpath(xpath)
    }

    /// Returns the first element under [`HtmlDocument::root`] matching `selector`.
    pub fn find(&self, selector: &str) -> Result<Option<HtmlNode>> {
        self.root().find(selector)
    }

    /// Returns the first result of the XPath-style expression, if any.
    pub fn find_xpath(&self, xpath: &str) -> Result<Option<XPathResult>> {
        self.root().find_xpath(xpath)
    }

    /// Returns the `n`-th (0-based) element matching `selector`, if any.
    pub fn find_nth(&self, selector: &str, n: usize) -> Result<Option<HtmlNode>> {
        self.root().find_nth(selector, n)
    }

    /// Returns the `n`-th (0-based) result of the XPath-style expression, if any.
    pub fn find_nth_xpath(&self, xpath: &str, n: usize) -> Result<Option<XPathResult>> {
        self.root().find_nth_xpath(xpath, n)
    }

    /// Returns the element children of [`HtmlDocument::root`].
    pub fn children(&self) -> Vec<HtmlNode> {
        self.root().children()
    }
}
impl Debug for HtmlDocument {
    /// Renders a short marker such as `<HtmlDocument is_fragment=true>`; the
    /// (potentially large) markup is deliberately not printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let flag = if self.is_fragment { "true" } else { "false" };
        f.write_str("<HtmlDocument is_fragment=")?;
        f.write_str(flag)?;
        f.write_str(">")
    }
}
/// A lightweight handle to one node of a parsed tree.
///
/// Holds the shared tree plus the node's id, so cloning only bumps the `Arc`
/// reference count. Most methods expect the node to be an element and panic
/// with "element not found" otherwise.
#[derive(Clone)]
pub struct HtmlNode {
// Tree this node belongs to, shared with the owning `HtmlDocument` and sibling handles.
dom: Arc<Html>,
// Id of the node inside `dom.tree`.
node: NodeId,
}
// SAFETY: NOTE(review): same concern as for `HtmlDocument`. `scraper::Html` is presumably
// not `Send`/`Sync` unless scraper's `atomic` feature is enabled (its strings use
// non-atomically reference-counted `tendril` buffers), so these impls rely on no
// concurrent cloning/dropping of those buffers. Unverified — confirm, or enable the
// feature and remove these impls.
unsafe impl Send for HtmlNode {}
unsafe impl Sync for HtmlNode {}
impl HtmlNode {
pub fn new(dom: Arc<Html>, node: NodeId) -> Self {
Self { dom, node }
}
fn node(&self) -> Option<NodeRef<'_, Node>> {
self.dom
.tree
.get(self.node)
}
fn element(&self) -> Option<ElementRef<'_>> {
ElementRef::wrap(self.node()?)
}
pub fn tag_name(&self) -> &str {
self.element()
.expect("element not found")
.value()
.name()
}
pub fn find_all(&self, selector: &str) -> Result<Vec<HtmlNode>> {
Ok(
self.element()
.expect("element not found")
.select(
&Selector::parse(selector)
.map_err(|e| PackageError::SelectorParseError(e.to_string()))?
)
.map(|element| HtmlNode::new(self.dom.clone(), element.id()))
.collect()
)
}
pub fn find_all_xpath(&self, xpath: &str) -> Result<Vec<XPathResult>> {
fn resolve_accessor(node: &HtmlNode, accessor: &NodeAccessor) -> Option<XPathResult> {
match accessor {
NodeAccessor::Text { recursive } => {
Some(XPathResult::String(if *recursive {
node.inner_text()
} else {
node.text()
}))
},
NodeAccessor::Attribute(name) => {
Some(XPathResult::String(node
.get_attribute(name.as_str())?
.to_string()
))
},
_ => Some(XPathResult::Node(node.clone())),
}
}
match parse_xpath(xpath) {
Some((selector, accessor)) => {
match selector.as_str() {
"" => Ok(
resolve_accessor(self, &accessor)
.into_iter()
.collect()
),
_ => Ok(
self.find_all(&selector)?
.into_iter()
.filter_map(|node| resolve_accessor(&node, &accessor))
.collect()
),
}
},
None => Ok(Vec::new()),
}
}
pub fn find(&self, selector: &str) -> Result<Option<HtmlNode>> {
Ok(
self.find_all(selector)?
.into_iter()
.next()
)
}
pub fn find_xpath(&self, xpath: &str) -> Result<Option<XPathResult>> {
Ok(
self.find_all_xpath(xpath)?
.into_iter()
.next()
)
}
pub fn find_nth(&self, selector: &str, n: usize) -> Result<Option<HtmlNode>> {
Ok(
self.find_all(selector)?
.into_iter()
.nth(n)
)
}
pub fn find_nth_xpath(&self, xpath: &str, n: usize) -> Result<Option<XPathResult>> {
Ok(
self.find_all_xpath(xpath)?
.into_iter()
.nth(n)
)
}
pub fn attributes(&self) -> HashMap<&str, Option<&str>> {
self.element()
.expect("element not found")
.value()
.attrs()
.map(|(k, v)| (k, Some(v)))
.collect()
}
pub fn get_attribute(&self, name: &str) -> Option<&str> {
self.element()
.expect("element not found")
.value()
.attr(name)
}
pub fn text(&self) -> String {
self.element()
.expect("element not found")
.text()
.next()
.unwrap_or("")
.to_string()
}
pub fn inner_text(&self) -> String {
self.element()
.expect("element not found")
.text()
.collect::<Vec<_>>()
.join("")
}
pub fn inner_html(&self) -> String {
self.element()
.expect("element not found")
.inner_html()
}
pub fn outer_html(&self) -> String {
self.element()
.expect("element not found")
.html()
}
pub fn children(&self) -> Vec<HtmlNode> {
self.element()
.expect("element not found")
.children()
.filter_map(|child| child
.value()
.is_element()
.then(|| HtmlNode::new(self.dom.clone(), child.id()))
)
.collect()
}
}
impl Debug for HtmlNode {
    /// Renders the node like an opening tag, e.g. `<h1 id="title">`.
    ///
    /// Attributes are read straight from the element rather than through
    /// `attributes()`. That method builds a fresh `HashMap` with a randomized hasher
    /// on every call, so iterating it made the attribute order vary between Debug
    /// calls. A handle that is not an element is rendered with its node id instead
    /// of panicking.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self.element() {
            Some(element) => {
                let data = element.value();
                write!(f, "<{}", data.name())?;
                for (key, value) in data.attrs() {
                    write!(f, " {}=\"{}\"", key, value)?;
                }
                write!(f, ">")
            }
            None => write!(f, "<HtmlNode id={:?}>", self.node),
        }
    }
}
// Unit tests for HtmlDocument / HtmlNode: CSS-selector queries, XPath-style queries,
// and text/HTML extraction, each using a small inline document.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_html_document_from_str() {
let html = "<html><head></head><body><h1>Hello, world!</h1></body></html>";
let doc = HtmlDocument::from_str(html.to_string());
assert_eq!(doc.raw, html);
assert_eq!(doc.root().tag_name(), "html");
}
// Input without "html" is parsed as a fragment; root() is then the <h1> itself.
#[test]
fn test_html_document_fragment_from_str() {
let html = "<h1>Hello, world!</h1>";
let doc = HtmlDocument::from_str(html.to_string());
assert_eq!(doc.root().tag_name(), "h1");
assert_eq!(doc.root().children().len(), 0);
assert_eq!(doc.root().inner_text(), "Hello, world!".to_string());
}
#[test]
fn test_html_document_query_selector() {
let html = r#"<html><head></head><body><h1 id="test">Hello, world!</h1><h1>Not hello world</h1></body></html>"#;
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find("h1#test").unwrap().unwrap();
assert_eq!(h1.inner_text(), "Hello, world!".to_string());
}
// text() returns only the leading text, not the text nested in <span>; see inner_text().
#[test]
fn test_html_node_text() {
let html = "<html><head></head><body><h1>Hello, <span>world!</span></h1></body></html>";
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find("h1").unwrap().unwrap();
assert_eq!(h1.text(), "Hello, ".to_string());
}
// root() of a full document is <html>, so inner_text() gathers all text in the page.
#[test]
fn test_html_node_inner_text() {
let html = "<html><head></head><body><h1>Hello, <span>world!</span></h1></body></html>";
let doc = HtmlDocument::from_str(html.to_string());
let root = doc.root();
assert_eq!(root.inner_text(), "Hello, world!".to_string());
}
#[test]
fn test_html_node_inner_html() {
let html = "<html><head></head><body><h1>Hello,<span>world!</span></h1></body></html>";
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find("h1").unwrap().unwrap();
assert_eq!(h1.inner_html(), "Hello,<span>world!</span>".to_string());
}
#[test]
fn test_html_node_outer_html() {
let html = "<html><head></head><body><h1>Hello, world!</h1></body></html>";
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find("h1").unwrap().unwrap();
assert_eq!(h1.outer_html(), "<h1>Hello, world!</h1>".to_string());
}
#[test]
fn test_html_node_get_attribute() {
let html = r#"<html><head></head><body><h1 id="title">Hello, world!</h1></body></html>"#;
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find("h1").unwrap().unwrap();
assert_eq!(h1.get_attribute("id"), Some("title"));
}
#[test]
fn test_html_node_children() {
let html = "<html><head></head><body><h1>Hello, world!</h1><p>Paragraph</p></body></html>";
let doc = HtmlDocument::from_str(html.to_string());
let body = doc.find("body").unwrap().unwrap();
let children = body.children();
assert_eq!(children.len(), 2);
assert_eq!(children[0].inner_text(), "Hello, world!".to_string());
assert_eq!(children[1].inner_text(), "Paragraph".to_string());
}
#[test]
fn test_html_document_find_xpath() {
let html = r#"<html><head></head><body><h1 id="title">Hello, world!</h1></body></html>"#;
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find_xpath("//h1[@id='title']")
.unwrap()
.unwrap()
.into_node()
.unwrap();
assert_eq!(h1.inner_text(), "Hello, world!".to_string());
}
#[test]
fn test_html_document_find_xpath_attribute() {
let html = r#"<html><head></head><body><h1 id="title">Hello, world!</h1></body></html>"#;
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find_xpath("//h1/@id")
.unwrap()
.unwrap()
.into_string()
.unwrap();
assert_eq!(h1, "title".to_string());
}
#[test]
fn test_html_document_find_xpath_text() {
let html = r#"<html><head></head><body><h1 id="title">Hello, <span>world!</span></h1></body></html>"#;
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find_xpath("//h1/text()")
.unwrap()
.unwrap()
.into_string()
.unwrap();
assert_eq!(h1, "Hello, ".to_string());
}
// `//text()` is the recursive variant: it includes the text nested in <span>.
#[test]
fn test_html_document_find_xpath_inner_text() {
let html = r#"<html><head></head><body><h1 id="title">Hello, <span>world!</span></h1></body></html>"#;
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find_xpath("//h1//text()")
.unwrap()
.unwrap()
.into_string()
.unwrap();
assert_eq!(h1, "Hello, world!".to_string());
}
// Selecting a missing attribute yields Ok(None) rather than an error.
#[test]
fn test_html_document_find_xpath_bad() {
let html = r#"<html><head></head><body><h1 id="title">Hello, world!</h1></body></html>"#;
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find_xpath("//h1[@id='title']/@src").unwrap();
assert!(h1.is_none());
}
#[test]
fn test_html_document_find_nth() {
let html = "<html><head></head><body><h1>Hello, world!</h1><h1>Not hello world</h1></body></html>";
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find_nth("h1", 1).unwrap().unwrap();
assert_eq!(h1.inner_text(), "Not hello world".to_string());
}
// Same selection as above, but the position is expressed in the CSS selector itself.
#[test]
fn test_html_document_find_nth_selector() {
let html = "<html><head></head><body><h1>Hello, world!</h1><h1>Not hello world</h1></body></html>";
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find("h1:nth-of-type(2)").unwrap().unwrap();
assert_eq!(h1.inner_text(), "Not hello world".to_string());
}
// XPath positional predicate: `[2]` is expected to select the second <h1>.
#[test]
fn test_html_document_find_xpath_nth() {
let html = r#"<html><head></head><body><h1>Hello, world!</h1><h1>Not hello world</h1></body></html>"#;
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find_xpath("//h1[2]")
.unwrap()
.unwrap()
.into_node()
.unwrap();
assert_eq!(h1.inner_text(), "Not hello world".to_string());
}
#[test]
fn test_html_node_find_all() {
let html = "<html><head></head><body><h1>Hello, world!</h1><h1>Not hello world</h1></body></html>";
let doc = HtmlDocument::from_str(html.to_string());
let body = doc.find("body").unwrap().unwrap();
let h1s = body.find_all("h1").unwrap();
assert_eq!(h1s.len(), 2);
assert_eq!(h1s[0].inner_text(), "Hello, world!".to_string());
assert_eq!(h1s[1].inner_text(), "Not hello world".to_string());
}
#[test]
fn test_html_node_find_all_xpath() {
let html = r#"<html><head></head><body><h1>Hello, world!</h1><h1>Not hello world</h1></body></html>"#;
let doc = HtmlDocument::from_str(html.to_string());
let body = doc.find("body").unwrap().unwrap();
let h1s = body.find_all_xpath("//h1").unwrap();
assert_eq!(h1s.len(), 2);
assert_eq!(h1s[0].as_node().unwrap().inner_text(), "Hello, world!".to_string());
assert_eq!(h1s[1].as_node().unwrap().inner_text(), "Not hello world".to_string());
}
#[test]
fn test_html_node_find_nth() {
let html = "<html><head></head><body><h1>Hello, world!</h1><h1>Not hello world</h1></body></html>";
let doc = HtmlDocument::from_str(html.to_string());
let body = doc.find("body").unwrap().unwrap();
let h1 = body.find_nth("h1", 1).unwrap().unwrap();
assert_eq!(h1.inner_text(), "Not hello world".to_string());
}
#[test]
fn test_html_node_find_nth_xpath() {
let html = r#"<html><head></head><body><h1>Hello, world!</h1><h1>Not hello world</h1></body></html>"#;
let doc = HtmlDocument::from_str(html.to_string());
let body = doc.find("body").unwrap().unwrap();
let h1 = body.find_nth_xpath("//h1", 1).unwrap().unwrap().into_node().unwrap();
assert_eq!(h1.inner_text(), "Not hello world".to_string());
}
// Relative XPath (leading ".") evaluated from the <body> node.
#[test]
fn test_html_node_relative_find_xpath() {
let html = r#"<html><head></head><body><h1>Hello, world!</h1><h1>Not hello world</h1></body></html>"#;
let doc = HtmlDocument::from_str(html.to_string());
let body = doc.find("body").unwrap().unwrap();
let h1 = body.find_xpath(".//h1[2]")
.unwrap()
.unwrap()
.into_node()
.unwrap();
assert_eq!(h1.inner_text(), "Not hello world".to_string());
}
// An accessor with no element path ("./text()") applies to the node it is called on.
#[test]
fn test_html_node_relative_find_xpath_text_accessor() {
let html = r#"<html><head></head><body><h1>Hello, world!</h1><h1>Not hello world</h1></body></html>"#;
let doc = HtmlDocument::from_str(html.to_string());
let h1 = doc.find_xpath("//h1[1]").unwrap().unwrap().into_node().unwrap();
let text = h1.find_xpath("./text()")
.unwrap()
.unwrap()
.into_string()
.unwrap();
assert_eq!(text, "Hello, world!".to_string());
}
}