use std::collections::HashMap;
use std::rc::Rc;
use ego_tree::NodeId;
use scraper::{ElementRef, Html, Selector};
use crate::locator::{self, StaticQuery};
use crate::{Error, Result};
/// A lightweight handle to one element of a parsed, immutable HTML document.
///
/// The document is shared through an `Rc`, so cloning a handle only bumps a
/// reference count and copies the node id; it also means handles are
/// single-threaded (`Rc` is not `Send`).
#[derive(Clone)]
pub struct StaticElement {
    /// Id of the referenced node inside `doc.tree`.
    id: NodeId,
    /// The shared parsed document this handle points into.
    doc: Rc<Html>,
}
impl StaticElement {
pub fn parse(html: &str) -> Result<Self> {
let doc = Rc::new(Html::parse_document(html));
Self::root(doc)
}
pub(crate) fn root(doc: Rc<Html>) -> Result<Self> {
let id = doc.root_element().id();
Ok(Self { doc, id })
}
fn element(&self) -> Result<ElementRef<'_>> {
let node = self
.doc
.tree
.get(self.id)
.ok_or_else(|| Error::StaleElement("静态节点不存在".into()))?;
ElementRef::wrap(node).ok_or_else(|| Error::StaleElement("静态节点不是元素".into()))
}
pub fn tag(&self) -> Result<String> {
Ok(self.element()?.value().name().to_ascii_lowercase())
}
pub fn text(&self) -> Result<String> {
Ok(self
.element()?
.text()
.collect::<String>()
.trim()
.to_string())
}
pub fn attr(&self, name: &str) -> Result<Option<String>> {
Ok(self.element()?.value().attr(name).map(|s| s.to_string()))
}
pub fn attrs(&self) -> Result<Vec<(String, String)>> {
Ok(self
.element()?
.value()
.attrs()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect())
}
pub fn html(&self) -> Result<String> {
Ok(self.element()?.html())
}
pub fn inner_html(&self) -> Result<String> {
Ok(self.element()?.inner_html())
}
pub fn ele(&self, selector: &str) -> Result<StaticElement> {
let q = locator::parse_static(selector);
let ids = collect_ids(self.element()?, &q)?;
match ids.into_iter().next() {
Some(id) => Ok(StaticElement {
doc: self.doc.clone(),
id,
}),
None => Err(Error::ElementNotFound(selector.to_string())),
}
}
pub fn eles(&self, selector: &str) -> Result<Vec<StaticElement>> {
let q = locator::parse_static(selector);
let ids = collect_ids(self.element()?, &q)?;
Ok(ids
.into_iter()
.map(|id| StaticElement {
doc: self.doc.clone(),
id,
})
.collect())
}
pub fn table(&self) -> Result<Vec<Vec<String>>> {
let el = self.element()?;
let table = find_table(el)?;
let tr = parse_sel("tr")?;
let cell = parse_sel("th, td")?;
let mut rows = Vec::new();
for row in table.select(&tr) {
let cells: Vec<String> = row
.select(&cell)
.map(|c| normalize_space(&c.text().collect::<String>()))
.collect();
if !cells.is_empty() {
rows.push(cells);
}
}
Ok(rows)
}
pub fn table_records(&self) -> Result<Vec<HashMap<String, String>>> {
let mut rows = self.table()?.into_iter();
let Some(headers) = rows.next() else {
return Ok(Vec::new());
};
let mut out = Vec::new();
for row in rows {
let mut rec = HashMap::new();
for (i, val) in row.into_iter().enumerate() {
let key = headers
.get(i)
.filter(|h| !h.is_empty())
.cloned()
.unwrap_or_else(|| format!("col{i}"));
rec.insert(key, val);
}
out.push(rec);
}
Ok(out)
}
}
fn find_table<'a>(el: ElementRef<'a>) -> Result<ElementRef<'a>> {
if el.value().name().eq_ignore_ascii_case("table") {
return Ok(el);
}
let sel = parse_sel("table")?;
el.select(&sel)
.next()
.ok_or_else(|| Error::ElementNotFound("table".into()))
}
fn parse_sel(s: &str) -> Result<Selector> {
Selector::parse(s).map_err(|e| Error::Other(format!("非法选择器 {s:?}: {e:?}")))
}
/// Evaluates a parsed static query under `root` and returns matching node ids.
///
/// - CSS queries are compiled and run through `scraper`.
/// - Attribute and text queries scan every descendant element of `root`.
///   For text queries, only an element's own (direct-child) text nodes are
///   joined and whitespace-normalized before the substring test.
/// - XPath queries are delegated to `crate::browser::xpath::eval`.
fn collect_ids(root: ElementRef<'_>, q: &StaticQuery) -> Result<Vec<NodeId>> {
    /// Keeps every descendant element of `root` accepted by `keep`, in the
    /// order the universal selector visits them.
    fn scan(root: ElementRef<'_>, keep: impl Fn(&ElementRef<'_>) -> bool) -> Result<Vec<NodeId>> {
        let every = universal()?;
        let hits: Vec<NodeId> = root.select(&every).filter(|e| keep(e)).map(|e| e.id()).collect();
        Ok(hits)
    }

    match q {
        StaticQuery::Css(sel) => {
            let selector = Selector::parse(sel)
                .map_err(|e| Error::Other(format!("非法 CSS 选择器 {sel:?}: {e:?}")))?;
            Ok(root.select(&selector).map(|e| e.id()).collect())
        }
        StaticQuery::AttrEq { name, value } => {
            scan(root, |e| e.value().attr(name) == Some(value.as_str()))
        }
        StaticQuery::AttrPresent(name) => scan(root, |e| e.value().attr(name).is_some()),
        StaticQuery::TextContains(t) => {
            let needle = normalize_space(t);
            scan(root, |e| {
                // Only the element's own text nodes, not its descendants' text.
                let own_text: String = e
                    .children()
                    .filter_map(|c| c.value().as_text().map(|txt| txt.text.as_ref()))
                    .collect();
                normalize_space(&own_text).contains(&needle)
            })
        }
        StaticQuery::Xpath(xp) => crate::browser::xpath::eval(*root, xp),
    }
}
fn universal() -> Result<Selector> {
Selector::parse("*").map_err(|_| Error::Other("内部错误:通配选择器解析失败".into()))
}
/// Collapses every run of Unicode whitespace in `s` to a single ASCII space.
/// Leading and trailing whitespace is dropped.
fn normalize_space(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    // `split_whitespace` never yields empty words, so an empty `out` means
    // no word has been written yet.
    for word in s.split_whitespace() {
        if !out.is_empty() {
            out.push(' ');
        }
        out.push_str(word);
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    const HTML: &str = r#"<html><body>
<div id="main" class="box wrap">
<a href="/a" class="link">首页</a>
<a href="/b" class="link">关于我们</a>
<span data-x="1">hello world</span>
</div>
</body></html>"#;

    const TABLE: &str = r#"<table>
<thead><tr><th>名称</th><th>价格</th></tr></thead>
<tbody>
<tr><td>苹果</td><td>3</td></tr>
<tr><td>香蕉</td><td>2</td></tr>
</tbody></table>"#;

    /// Parses a fixture document, failing the test on error.
    fn load(html: &str) -> StaticElement {
        StaticElement::parse(html).unwrap()
    }

    #[test]
    fn css_and_tag() {
        let root = load(HTML);
        let main_class = root.ele("#main").unwrap().attr("class").unwrap();
        assert_eq!(main_class.as_deref(), Some("box wrap"));
        assert_eq!(root.eles("tag:a").unwrap().len(), 2);
        let first_link = root.ele("css:a.link").unwrap();
        assert_eq!(first_link.attr("href").unwrap().as_deref(), Some("/a"));
    }

    #[test]
    fn attr_eq_and_present() {
        let root = load(HTML);
        assert_eq!(root.ele("@id:main").unwrap().tag().unwrap(), "div");
        let span = root.ele("@data-x:1").unwrap();
        assert_eq!(span.text().unwrap(), "hello world");
        assert_eq!(root.eles("@href").unwrap().len(), 2);
    }

    #[test]
    fn text_contains() {
        let root = load(HTML);
        let link = root.ele("text:关于").unwrap();
        assert_eq!(link.tag().unwrap(), "a");
        assert_eq!(link.attr("href").unwrap().as_deref(), Some("/b"));
    }

    #[test]
    fn nested_and_xpath() {
        let root = load(HTML);
        let main = root.ele("#main").unwrap();
        assert_eq!(main.eles("tag:a").unwrap().len(), 2);
        assert_eq!(root.eles("xpath://a").unwrap().len(), 2);
        let by_xpath = root.ele(r#"xpath://*[@id="main"]"#).unwrap();
        assert_eq!(by_xpath.tag().unwrap(), "div");
        assert!(root.ele("xpath://a/following-sibling::a").is_err());
    }

    #[test]
    fn table_rows_and_records() {
        let root = load(TABLE);

        let rows = root.table().unwrap();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec!["名称", "价格"]);
        assert_eq!(rows[1], vec!["苹果", "3"]);

        let recs = root.table_records().unwrap();
        assert_eq!(recs.len(), 2);
        assert_eq!(recs[0].get("名称").map(String::as_str), Some("苹果"));
        assert_eq!(recs[1].get("价格").map(String::as_str), Some("2"));
    }
}