sciforge 0.0.3

A comprehensive scientific computing library in pure Rust with zero dependencies
Documentation
use super::entity::validate_entity;
use super::error::{HtmlError, HtmlErrorKind};
use super::lexer::Cursor;
use super::value::HtmlValue;

const VOID_ELEMENTS: &[&str] = &[
    "area", "base", "br", "col", "embed", "hr", "img", "input", "link", "meta", "param", "source",
    "track", "wbr",
];

const RAW_TEXT_ELEMENTS: &[&str] = &["script", "style"];

pub const DEFAULT_MAX_HTML_DEPTH: usize = 128;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct HtmlLimits {
    pub max_depth: usize,
    pub max_node_count: usize,
    pub max_attribute_count: usize,
    pub max_attribute_value_len: usize,
}

pub const DEFAULT_HTML_LIMITS: HtmlLimits = HtmlLimits {
    max_depth: DEFAULT_MAX_HTML_DEPTH,
    max_node_count: 256 * 1024,
    max_attribute_count: 256,
    max_attribute_value_len: 64 * 1024,
};

pub struct HtmlParser<'a> {
    cursor: Cursor<'a>,
    limits: HtmlLimits,
    nodes_seen: usize,
}

impl<'a> HtmlParser<'a> {
    pub const fn new(bytes: &'a [u8]) -> Self {
        Self {
            cursor: Cursor::new(bytes),
            limits: DEFAULT_HTML_LIMITS,
            nodes_seen: 0,
        }
    }

    pub const fn with_limits(mut self, limits: HtmlLimits) -> Self {
        self.limits = limits;
        self
    }

    pub const fn with_max_depth(mut self, max_depth: usize) -> Self {
        self.limits.max_depth = max_depth;
        self
    }

    pub fn parse(mut self) -> Result<HtmlValue<'a>, HtmlError> {
        self.parse_nodes(0)?;
        Ok(HtmlValue::Document)
    }

    pub fn validate(mut self) -> Result<(), HtmlError> {
        self.parse_nodes(0)?;
        Ok(())
    }

    fn tick_node(&mut self) -> Result<(), HtmlError> {
        self.nodes_seen = self.nodes_seen.saturating_add(1);
        if self.nodes_seen > self.limits.max_node_count {
            return Err(HtmlError::new(
                HtmlErrorKind::MaxNodeCountExceeded,
                self.cursor.position(),
            ));
        }
        Ok(())
    }

    fn parse_nodes(&mut self, depth: usize) -> Result<(), HtmlError> {
        if depth > self.limits.max_depth {
            return Err(HtmlError::new(
                HtmlErrorKind::MaxDepthExceeded,
                self.cursor.position(),
            ));
        }

        while !self.cursor.is_eof() {
            if self.cursor.peek() == Some(b'<') {
                if self.cursor.starts_with(b"</") {
                    return Ok(());
                }
                if self.cursor.starts_with(b"<!--") {
                    self.parse_comment()?;
                } else if self.cursor.starts_with(b"<!") {
                    self.parse_doctype()?;
                } else {
                    self.parse_element(depth)?;
                }
            } else {
                self.parse_text()?;
            }
        }

        Ok(())
    }

    fn parse_comment(&mut self) -> Result<(), HtmlError> {
        self.tick_node()?;
        let start = self.cursor.position();
        self.cursor.advance(4);

        loop {
            if self.cursor.is_eof() {
                return Err(HtmlError::new(HtmlErrorKind::UnterminatedComment, start));
            }
            if self.cursor.starts_with(b"-->") {
                self.cursor.advance(3);
                return Ok(());
            }
            self.cursor.advance(1);
        }
    }

    fn parse_doctype(&mut self) -> Result<(), HtmlError> {
        self.tick_node()?;
        let start = self.cursor.position();
        self.cursor.advance(2);

        loop {
            if self.cursor.is_eof() {
                return Err(HtmlError::new(HtmlErrorKind::UnterminatedDoctype, start));
            }
            if self.cursor.peek() == Some(b'>') {
                self.cursor.advance(1);
                return Ok(());
            }
            self.cursor.advance(1);
        }
    }

    fn parse_text(&mut self) -> Result<(), HtmlError> {
        self.tick_node()?;
        while let Some(b) = self.cursor.peek() {
            match b {
                b'<' => break,
                b'&' => validate_entity(&mut self.cursor)?,
                _ => self.cursor.advance(1),
            }
        }
        Ok(())
    }

    fn parse_element(&mut self, depth: usize) -> Result<(), HtmlError> {
        self.tick_node()?;
        let tag_start = self.cursor.position();
        self.cursor.advance(1);

        let tag_name = self.cursor.read_tag_name()?;

        self.parse_attributes()?;
        self.cursor.skip_ws();

        let self_closing = self.cursor.peek() == Some(b'/');
        if self_closing {
            self.cursor.advance(1);
        }

        if self.cursor.peek() != Some(b'>') {
            return Err(HtmlError::new(HtmlErrorKind::UnterminatedTag, tag_start));
        }
        self.cursor.advance(1);

        if self_closing || is_void_element(tag_name) {
            return Ok(());
        }

        if is_raw_text_element(tag_name) {
            return self.skip_raw_text(tag_name, tag_start);
        }

        self.parse_nodes(depth + 1)?;
        self.parse_closing_tag(tag_name, tag_start)
    }

    fn parse_attributes(&mut self) -> Result<(), HtmlError> {
        let mut count = 0usize;

        loop {
            self.cursor.skip_ws();
            match self.cursor.peek() {
                Some(b'>') | Some(b'/') | None => return Ok(()),
                _ => {}
            }

            self.parse_attribute()?;
            count = count.saturating_add(1);
            if count > self.limits.max_attribute_count {
                return Err(HtmlError::new(
                    HtmlErrorKind::MaxAttributeCountExceeded,
                    self.cursor.position(),
                ));
            }
        }
    }

    fn parse_attribute(&mut self) -> Result<(), HtmlError> {
        self.cursor.read_while(|b| {
            b.is_ascii_alphanumeric() || b == b'-' || b == b'_' || b == b':' || b == b'.'
        });

        self.cursor.skip_ws();

        if self.cursor.peek() != Some(b'=') {
            return Ok(());
        }
        self.cursor.advance(1);
        self.cursor.skip_ws();

        match self.cursor.peek() {
            Some(b'"') => self.parse_quoted_value(b'"'),
            Some(b'\'') => self.parse_quoted_value(b'\''),
            _ => {
                self.cursor.read_while(|b| {
                    !matches!(
                        b,
                        b' ' | b'\t' | b'\n' | b'\r' | b'>' | b'/' | b'"' | b'\'' | b'='
                    )
                });
                Ok(())
            }
        }
    }

    fn parse_quoted_value(&mut self, quote: u8) -> Result<(), HtmlError> {
        let start = self.cursor.position();
        self.cursor.advance(1);
        let content_start = self.cursor.position();

        loop {
            match self.cursor.peek() {
                None => return Err(HtmlError::new(HtmlErrorKind::UnterminatedAttribute, start)),
                Some(b) if b == quote => {
                    let len = self.cursor.position() - content_start;
                    if len > self.limits.max_attribute_value_len {
                        return Err(HtmlError::new(
                            HtmlErrorKind::MaxAttributeValueLengthExceeded,
                            content_start,
                        ));
                    }
                    self.cursor.advance(1);
                    return Ok(());
                }
                Some(b'&') => validate_entity(&mut self.cursor)?,
                _ => self.cursor.advance(1),
            }
        }
    }

    fn parse_closing_tag(&mut self, expected: &str, open_offset: usize) -> Result<(), HtmlError> {
        if !self.cursor.starts_with(b"</") {
            return Err(HtmlError::new(
                HtmlErrorKind::MismatchedClosingTag,
                self.cursor.position(),
            ));
        }
        self.cursor.advance(2);

        let close_name = self.cursor.read_tag_name()?;

        if !eq_ignore_ascii_case(close_name, expected) {
            return Err(HtmlError::new(
                HtmlErrorKind::MismatchedClosingTag,
                open_offset,
            ));
        }

        self.cursor.skip_ws();

        if self.cursor.peek() != Some(b'>') {
            return Err(HtmlError::new(
                HtmlErrorKind::UnterminatedTag,
                self.cursor.position(),
            ));
        }
        self.cursor.advance(1);
        Ok(())
    }

    fn skip_raw_text(&mut self, tag_name: &str, open_offset: usize) -> Result<(), HtmlError> {
        loop {
            if self.cursor.is_eof() {
                return Err(HtmlError::new(HtmlErrorKind::UnterminatedTag, open_offset));
            }
            if self.cursor.starts_with(b"</") {
                let saved = self.cursor.position();
                self.cursor.advance(2);
                if let Ok(name) = self.cursor.read_tag_name()
                    && eq_ignore_ascii_case(name, tag_name)
                {
                    self.cursor.skip_ws();
                    if self.cursor.peek() == Some(b'>') {
                        self.cursor.advance(1);
                        return Ok(());
                    }
                }
                self.cursor.advance(0);
                if self.cursor.position() == saved + 2 {
                    self.cursor.advance(1);
                }
                continue;
            }
            self.cursor.advance(1);
        }
    }
}

fn is_void_element(name: &str) -> bool {
    VOID_ELEMENTS.iter().any(|&v| eq_ignore_ascii_case(v, name))
}

fn is_raw_text_element(name: &str) -> bool {
    RAW_TEXT_ELEMENTS
        .iter()
        .any(|&v| eq_ignore_ascii_case(v, name))
}

fn eq_ignore_ascii_case(a: &str, b: &str) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.bytes()
        .zip(b.bytes())
        .all(|(x, y)| x.eq_ignore_ascii_case(&y))
}

pub fn parse_html(bytes: &[u8]) -> Result<HtmlValue<'_>, HtmlError> {
    HtmlParser::new(bytes).parse()
}

pub fn parse_html_with_max_depth(
    bytes: &[u8],
    max_depth: usize,
) -> Result<HtmlValue<'_>, HtmlError> {
    HtmlParser::new(bytes).with_max_depth(max_depth).parse()
}

pub fn parse_html_with_limits(
    bytes: &[u8],
    limits: HtmlLimits,
) -> Result<HtmlValue<'_>, HtmlError> {
    HtmlParser::new(bytes).with_limits(limits).parse()
}

pub fn validate_html(bytes: &[u8]) -> Result<(), HtmlError> {
    HtmlParser::new(bytes).validate()
}