use std::collections::HashMap;
use pyo3::{
prelude::*,
exceptions::{PyValueError, PyRuntimeError},
types::PyType,
};
use crate::document::{HtmlDocument, HtmlNode, XPathResult};
use crate::markdown::{html_to_markdown_with_converter, HtmlToMarkdown};
use crate::errors::{Result, PackageError};
impl From<PackageError> for PyErr {
fn from(error: PackageError) -> Self {
match error {
PackageError::HTMLParseError(e) => PyValueError::new_err(e.to_string()),
PackageError::SelectorParseError(e) => PyValueError::new_err(e.to_string()),
PackageError::UnknownError(e) => PyRuntimeError::new_err(e.to_string()),
}
}
}
/// A single XPath query result in its Python-facing form: either a node or a
/// plain string. Conversion to a Python object is derived via `IntoPyObject`.
#[derive(IntoPyObject)]
pub enum PyXPathResult {
/// The result is a node, wrapped as a `PyHtmlNode`.
Node(PyHtmlNode),
/// The result is a string value.
String(String),
}
/// Converts a core `XPathResult` into the Python-facing `PyXPathResult`,
/// wrapping node results in `PyHtmlNode` and passing strings through as-is.
impl From<XPathResult> for PyXPathResult {
    fn from(result: XPathResult) -> Self {
        match result {
            XPathResult::String(text) => Self::String(text),
            XPathResult::Node(inner) => Self::Node(PyHtmlNode { inner }),
        }
    }
}
/// Python-facing wrapper around `HtmlDocument`, exported to Python under the
/// class name `HtmlDocument`.
#[pyclass(name = "HtmlDocument")]
#[derive(Clone)]
pub struct PyHtmlDocument {
/// The wrapped core document that every method delegates to.
inner: HtmlDocument,
}
#[pymethods]
impl PyHtmlDocument {
    /// Returns the `Debug` formatting of the wrapped document.
    fn __repr__(&self) -> String {
        format!("{:?}", self.inner)
    }

    /// Builds a document from `raw` via `HtmlDocument::from_str`.
    ///
    /// The work runs inside `py.allow_threads`, so the GIL is released while
    /// the document is being constructed.
    // The class object is required by `#[classmethod]` but not needed here;
    // the underscore prefix marks it as intentionally unused without
    // suppressing the lint for the whole function.
    #[classmethod]
    pub fn from_str(_cls: &Bound<'_, PyType>, py: Python<'_>, raw: String) -> Self {
        py.allow_threads(|| Self {
            inner: HtmlDocument::from_str(raw),
        })
    }

    /// The raw text of the document, as reported by the wrapped document.
    #[getter]
    pub fn raw(&self) -> &str {
        self.inner.raw()
    }

    /// The nodes returned by the wrapped document's `children()`, each wrapped
    /// as an `HtmlNode`.
    #[getter]
    pub fn children(&self) -> Vec<PyHtmlNode> {
        self.inner
            .children()
            .into_iter()
            .map(|inner| PyHtmlNode { inner })
            .collect()
    }

    /// The root node of the document, wrapped as an `HtmlNode`.
    #[getter]
    pub fn root(&self) -> PyHtmlNode {
        PyHtmlNode {
            inner: self.inner.root(),
        }
    }

    /// Returns all nodes matching `selector`.
    ///
    /// Any error from the underlying lookup is propagated to the caller.
    pub fn find_all(&self, selector: &str) -> Result<Vec<PyHtmlNode>> {
        let matches = self.inner.find_all(selector)?;
        Ok(matches
            .into_iter()
            .map(|inner| PyHtmlNode { inner })
            .collect())
    }

    /// Returns all results of evaluating `xpath`; each result is either a node
    /// or a string.
    ///
    /// Any error from the underlying lookup is propagated to the caller.
    pub fn find_all_xpath(&self, xpath: &str) -> Result<Vec<PyXPathResult>> {
        let matches = self.inner.find_all_xpath(xpath)?;
        Ok(matches.into_iter().map(PyXPathResult::from).collect())
    }

    /// Returns the node found for `selector`, or `None` when the underlying
    /// lookup yields nothing.
    ///
    /// Any error from the underlying lookup is propagated to the caller.
    pub fn find(&self, selector: &str) -> Result<Option<PyHtmlNode>> {
        let found = self.inner.find(selector)?;
        Ok(found.map(|inner| PyHtmlNode { inner }))
    }

    /// Returns the result found for `xpath`, or `None` when the underlying
    /// lookup yields nothing.
    ///
    /// Any error from the underlying lookup is propagated to the caller.
    pub fn find_xpath(&self, xpath: &str) -> Result<Option<PyXPathResult>> {
        let found = self.inner.find_xpath(xpath)?;
        Ok(found.map(PyXPathResult::from))
    }

    /// Returns the match for `selector` at position `n`, or `None` when the
    /// underlying lookup yields nothing. The indexing convention for `n` is
    /// whatever `HtmlDocument::find_nth` defines.
    ///
    /// Any error from the underlying lookup is propagated to the caller.
    pub fn find_nth(&self, selector: &str, n: usize) -> Result<Option<PyHtmlNode>> {
        let found = self.inner.find_nth(selector, n)?;
        Ok(found.map(|inner| PyHtmlNode { inner }))
    }

    /// Returns the result for `xpath` at position `n`, or `None` when the
    /// underlying lookup yields nothing. The indexing convention for `n` is
    /// whatever `HtmlDocument::find_nth_xpath` defines.
    ///
    /// Any error from the underlying lookup is propagated to the caller.
    pub fn find_nth_xpath(&self, xpath: &str, n: usize) -> Result<Option<PyXPathResult>> {
        let found = self.inner.find_nth_xpath(xpath, n)?;
        Ok(found.map(PyXPathResult::from))
    }
}
/// Python-facing wrapper around `HtmlNode`, exported to Python under the
/// class name `HtmlNode`.
#[pyclass(name = "HtmlNode")]
#[derive(Clone)]
pub struct PyHtmlNode {
/// The wrapped core node that every method delegates to.
inner: HtmlNode,
}
#[pymethods]
impl PyHtmlNode {
fn __repr__(&self) -> String {
format!("{:?}", self.inner)
}
#[getter]
fn text(&self) -> String {
self.inner.text()
}
#[getter]
fn inner_text(&self) -> String {
self.inner.inner_text()
}
#[getter]
fn inner_html(&self) -> String {
self.inner.inner_html()
}
#[getter]
fn outer_html(&self) -> String {
self.inner.outer_html()
}
#[getter]
fn tag_name(&self) -> String {
self.inner
.tag_name()
.to_string()
}
#[getter]
fn attributes(&self) -> HashMap<&str, Option<&str>> {
self.inner.attributes()
}
#[getter]
fn children(&self) -> Vec<PyHtmlNode> {
self.inner
.children()
.into_iter()
.map(|node| PyHtmlNode { inner: node })
.collect()
}
pub fn find_all(&self, selector: &str) -> Result<Vec<PyHtmlNode>> {
Ok(
self.inner
.find_all(selector)?
.into_iter()
.map(|node| PyHtmlNode { inner: node })
.collect()
)
}
pub fn find_all_xpath(&self, xpath: &str) -> Result<Vec<PyXPathResult>> {
Ok(
self.inner
.find_all_xpath(xpath)?
.into_iter()
.map(|result| result.into())
.collect()
)
}
pub fn find(&self, selector: &str) -> Result<Option<PyHtmlNode>> {
Ok(
self.inner
.find(selector)?
.map(|node| PyHtmlNode { inner: node })
)
}
pub fn find_xpath(&self, xpath: &str) -> Result<Option<PyXPathResult>> {
Ok(
self.inner
.find_xpath(xpath)?
.map(|result| result.into())
)
}
pub fn find_nth(&self, selector: &str, n: usize) -> Result<Option<PyHtmlNode>> {
Ok(
self.inner
.find_nth(selector, n)?
.map(|node| PyHtmlNode { inner: node })
)
}
pub fn find_nth_xpath(&self, xpath: &str, n: usize) -> Result<Option<PyXPathResult>> {
Ok(
self.inner
.find_nth_xpath(xpath, n)?
.map(|result| result.into())
)
}
fn get_attribute(&self, name: &str) -> Option<&str> {
self.inner.get_attribute(name)
}
}
/// Converts `html` to Markdown using a converter built from `skip_tags`.
///
/// `skip_tags` is forwarded to the converter builder's `skip_tags` setting and
/// defaults to `["script", "style"]`.
///
/// # Errors
///
/// A conversion failure is raised as `ValueError` whose message is the
/// underlying error's `to_string()` output.
#[pyfunction]
#[pyo3(signature = (html, skip_tags = vec!["script".to_string(), "style".to_string()]))]
fn html_to_markdown(html: String, skip_tags: Vec<String>) -> PyResult<String> {
    // The mapped `Result` already has the function's return type, so it is
    // returned directly instead of being unwrapped with `?` and re-wrapped in `Ok`.
    html_to_markdown_with_converter(
        html,
        HtmlToMarkdown::builder()
            .skip_tags(skip_tags.iter().map(String::as_str).collect())
            .build(),
    )
    .map_err(|e| PyValueError::new_err(e.to_string()))
}
/// Initializer for the `_pickaxe` extension module.
///
/// Registers the `HtmlDocument` and `HtmlNode` classes, the `html_to_markdown`
/// function, and a `__version__` attribute taken from `CARGO_PKG_VERSION` at
/// compile time.
#[pymodule]
#[pyo3(name = "_pickaxe")]
fn pickaxe(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyHtmlDocument>()?;
    m.add_class::<PyHtmlNode>()?;
    // Wrap against the module rather than the bare `Python` token so the
    // resulting function object is associated with its owning module.
    m.add_function(wrap_pyfunction!(html_to_markdown, m)?)?;
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
    Ok(())
}