use std::collections::HashSet;
use crate::error::ParseError;
use crate::escape::{decode_entities, is_name_char, is_name_start};
use crate::types::{SourcePosition, SourceSpan};
/// Attribute count (including the one about to be added) at which duplicate
/// detection in `record_seen_attr` switches from scanning the attribute list
/// to a `HashSet` of the keys seen so far.
const ATTR_DUP_SET_THRESHOLD: usize = 16;
/// Narrows a byte offset to the `u32` stored in `SourcePosition::offset`.
///
/// # Panics
/// Panics when `n` does not fit in a `u32`. `tokenize` debug-asserts that the
/// input length fits, so a failure here means that precondition was violated.
#[inline]
fn offset_u32(n: usize) -> u32 {
    let narrowed = u32::try_from(n);
    narrowed.expect("offset within MAX_INPUT_BYTES — checked at parse entry")
}
/// A tag-level token emitted by `tokenize`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Token {
/// Opening tag of a pair, e.g. `<name key="value">`.
Open {
/// Tag name as written in the source.
name: String,
/// Attributes as `(key, value)` pairs in source order; each value has been
/// passed through `decode_entities`.
attrs: Vec<(String, String)>,
/// Position range from the tag's `<` to just past its `>`.
span: SourceSpan,
/// Byte offset immediately after the tag's terminating `>`.
body_start: usize,
},
/// Closing tag, e.g. `</name>`.
Close {
/// Tag name as written in the source.
name: String,
/// Position range from the tag's `<` to just past its `>`.
span: SourceSpan,
/// Byte offset of the closing tag's `<`.
body_end: usize,
},
/// Self-closing tag, e.g. `<name key="value"/>`.
SelfClose {
/// Tag name as written in the source.
name: String,
/// Attributes as `(key, value)` pairs in source order; each value has been
/// passed through `decode_entities`.
attrs: Vec<(String, String)>,
/// Position range from the tag's `<` to just past its `>`.
span: SourceSpan,
},
}
/// Result of tokenizing one input string.
#[derive(Debug, Clone, Default)]
pub(crate) struct TokenStream {
/// Tag tokens in the order they appear in the input.
pub tokens: Vec<Token>,
/// Byte ranges of skipped markup: each whole `<!-- ... -->` comment, plus the
/// `<![CDATA[` opener and the `]]>` closer of each CDATA section (the CDATA
/// content itself is not included).
pub trivia: Vec<core::ops::Range<usize>>,
}
pub(crate) fn tokenize(input: &str) -> Result<TokenStream, ParseError> {
debug_assert!(
u32::try_from(input.len()).is_ok(),
"tokenize() requires input bounded by MAX_INPUT_BYTES"
);
let bytes = input.as_bytes();
let mut tokens = Vec::new();
let mut trivia: Vec<core::ops::Range<usize>> = Vec::new();
let mut i = 0;
let mut line: u32 = 1;
while i < bytes.len() {
if bytes[i] == b'<' {
if let Some((new_i, new_line)) = try_skip_comment(bytes, i, line)? {
trivia.push(i..new_i);
i = new_i;
line = new_line;
continue;
}
if let Some((new_i, new_line)) = try_skip_cdata(bytes, i, line)? {
let open_end = i + 9; let close_start = new_i - 3; trivia.push(i..open_end);
trivia.push(close_start..new_i);
i = new_i;
line = new_line;
continue;
}
if looks_like_tag_start(bytes, i + 1) {
let (token, new_i, new_line) = parse_tag(input, i, line)?;
tokens.push(token);
i = new_i;
line = new_line;
continue;
}
}
if bytes[i] == b'\n' {
line = line.saturating_add(1);
}
i += 1;
}
Ok(TokenStream { tokens, trivia })
}
/// Reports whether the bytes at index `i` (the position just after a `<`)
/// begin a tag: either a name-start byte, or `/` followed by a name-start byte.
/// Returns `false` when `i` is past the end of `bytes`.
fn looks_like_tag_start(bytes: &[u8], i: usize) -> bool {
    match bytes.get(i) {
        None => false,
        // Closing tag: the name begins one byte further on.
        Some(&b'/') => bytes.get(i + 1).map_or(false, |&next| is_name_start(next)),
        Some(&first) => is_name_start(first),
    }
}
/// Parses the tag whose `<` sits at byte `start`, dispatching to the closing-tag
/// or opening-tag parser depending on whether the next byte is `/`.
///
/// Returns the token, the byte offset just past the tag, and the line number
/// at that offset.
fn parse_tag(
    input: &str,
    start: usize,
    start_line: u32,
) -> Result<(Token, usize, u32), ParseError> {
    let span_start = SourcePosition {
        line: start_line,
        offset: offset_u32(start),
    };
    match input.as_bytes().get(start + 1) {
        Some(&b'/') => parse_end_tag(input, start, start_line, span_start),
        _ => parse_opening_tag(input, start, start_line, span_start),
    }
}
/// Parses a closing tag `</name>` whose `<` sits at byte `start`.
///
/// Whitespace (including newlines) is allowed between the name and `>`.
/// The resulting `Token::Close` records `start` as its `body_end`.
///
/// # Errors
/// `ParseError::MalformedTag` (reported at `start_line`) when the name is not
/// followed by `>`.
fn parse_end_tag(
    input: &str,
    start: usize,
    start_line: u32,
    span_start: SourcePosition,
) -> Result<(Token, usize, u32), ParseError> {
    let bytes = input.as_bytes();
    // Step over the leading "</".
    let mut cursor = start + 2;
    let mut line = start_line;

    let name = read_tag_name(input, &mut cursor);
    skip_ws(bytes, &mut cursor, &mut line);

    if bytes.get(cursor) != Some(&b'>') {
        return Err(ParseError::MalformedTag {
            reason: format!("expected '>' to close </{name}>"),
            line: start_line,
        });
    }
    cursor += 1;

    let span = SourceSpan {
        start: span_start,
        end: SourcePosition {
            line,
            offset: offset_u32(cursor),
        },
    };
    let token = Token::Close {
        name,
        span,
        body_end: start,
    };
    Ok((token, cursor, line))
}
/// How an opening tag's attribute list was terminated.
enum TagForm {
/// Terminated by `>`: the tag opens a pair.
Pair,
/// Terminated by `/>`: the tag is self-closing.
SelfClosing,
}
/// Outcome of `parse_attribute_list`.
struct AttributeList {
/// Parsed `(key, value)` pairs in source order.
attrs: Vec<(String, String)>,
/// Whether the tag ended with `>` or `/>`.
form: TagForm,
/// Byte offset just past the terminating `>`.
end: usize,
/// Line number at the terminating `>`.
line: u32,
}
/// Parses an opening or self-closing tag whose `<` sits at byte `start`.
///
/// Reads the tag name, then the attribute list up to `>` or `/>`, and builds
/// a `Token::Open` (with `body_start` just past the `>`) or a
/// `Token::SelfClose` accordingly.
///
/// # Errors
/// Propagates any `ParseError` from `parse_attribute_list`.
fn parse_opening_tag(
    input: &str,
    start: usize,
    start_line: u32,
    span_start: SourcePosition,
) -> Result<(Token, usize, u32), ParseError> {
    // Step over the leading "<".
    let mut cursor = start + 1;
    let name = read_tag_name(input, &mut cursor);

    let list = parse_attribute_list(input, &name, start_line, cursor, start_line)?;
    let (end, line) = (list.end, list.line);

    let span_end = SourcePosition {
        line,
        offset: offset_u32(end),
    };
    let span = SourceSpan {
        start: span_start,
        end: span_end,
    };

    let token = match list.form {
        TagForm::SelfClosing => Token::SelfClose {
            name,
            attrs: list.attrs,
            span,
        },
        TagForm::Pair => Token::Open {
            name,
            attrs: list.attrs,
            span,
            body_start: end,
        },
    };
    Ok((token, end, line))
}
/// Reads the run of name bytes starting at `*i`, advances `*i` past it, and
/// returns the run as an owned string (empty when no name byte is present).
fn read_tag_name(input: &str, i: &mut usize) -> String {
    let name_start = *i;
    let name_len = input
        .as_bytes()
        .iter()
        .skip(name_start)
        .take_while(|&&b| is_name_char(b))
        .count();
    *i = name_start + name_len;
    input[name_start..*i].to_string()
}
/// Parses zero or more `key="value"` attributes starting at byte `start`,
/// up to and including the `>` or `/>` that ends the tag.
///
/// Whitespace before each attribute and before the terminator is skipped.
///
/// # Errors
/// * `ParseError::MalformedTag` (at `tag_start_line`) when the input ends
///   before the tag is terminated, or when `/` is not followed by `>`.
/// * `ParseError::DuplicateAttr` (at the line where the repeated attribute
///   starts) when a key occurs twice.
/// * Any error from `parse_attribute`.
fn parse_attribute_list(
    input: &str,
    tag_name: &str,
    tag_start_line: u32,
    start: usize,
    start_line: u32,
) -> Result<AttributeList, ParseError> {
    let bytes = input.as_bytes();
    let mut cursor = start;
    let mut line = start_line;
    let mut attrs: Vec<(String, String)> = Vec::new();
    // Stays `None` until `record_seen_attr` decides a set is worthwhile.
    let mut seen_set: Option<HashSet<String>> = None;

    loop {
        skip_ws(bytes, &mut cursor, &mut line);

        let form = match bytes.get(cursor) {
            None => {
                return Err(ParseError::MalformedTag {
                    reason: format!("<{tag_name}> not terminated"),
                    line: tag_start_line,
                });
            }
            Some(&b'>') => TagForm::Pair,
            Some(&b'/') => {
                cursor += 1;
                if bytes.get(cursor) != Some(&b'>') {
                    return Err(ParseError::MalformedTag {
                        reason: format!("expected '>' after '/' in <{tag_name}/>"),
                        line: tag_start_line,
                    });
                }
                TagForm::SelfClosing
            }
            Some(_) => {
                let parsed = parse_attribute(input, tag_name, cursor, line)?;
                if seen_attribute(&attrs, seen_set.as_ref(), &parsed.key) {
                    // `line` is still the line on which this attribute began.
                    return Err(ParseError::DuplicateAttr {
                        tag: tag_name.to_string(),
                        attr: parsed.key,
                        line,
                    });
                }
                record_seen_attr(&attrs, &parsed.key, &mut seen_set);
                attrs.push((parsed.key, parsed.value));
                cursor = parsed.end;
                line = parsed.line;
                continue;
            }
        };

        // `cursor` rests on the terminating '>' for both tag forms.
        return Ok(AttributeList {
            attrs,
            form,
            end: cursor + 1,
            line,
        });
    }
}
/// Reports whether `key` has already been recorded.
///
/// When a `seen` set is supplied it alone is consulted; otherwise the keys in
/// `attrs` are scanned linearly.
fn seen_attribute(attrs: &[(String, String)], seen: Option<&HashSet<String>>, key: &str) -> bool {
    match seen {
        Some(known) => known.contains(key),
        None => attrs.iter().any(|(existing, _)| existing == key),
    }
}
/// Records `next_key` for duplicate detection; call before pushing it to `attrs`.
///
/// If a set already exists, the key is inserted into it. Otherwise nothing is
/// stored until `attrs.len() + 1` reaches `ATTR_DUP_SET_THRESHOLD`, at which
/// point a set is built from every key in `attrs` plus `next_key`.
fn record_seen_attr(
    attrs: &[(String, String)],
    next_key: &str,
    seen: &mut Option<HashSet<String>>,
) {
    match seen.as_mut() {
        Some(existing) => {
            existing.insert(next_key.to_string());
        }
        None => {
            let total = attrs.len() + 1;
            if total >= ATTR_DUP_SET_THRESHOLD {
                let mut fresh: HashSet<String> = HashSet::with_capacity(total);
                fresh.extend(attrs.iter().map(|(k, _)| k.clone()));
                fresh.insert(next_key.to_string());
                *seen = Some(fresh);
            }
        }
    }
}
/// Outcome of `parse_attribute`.
struct ParsedAttribute {
/// Attribute name as written in the source.
key: String,
/// Attribute value after `decode_entities`, without the surrounding quotes.
value: String,
/// Byte offset just past the closing `"`.
end: usize,
/// Line number at the closing `"` (newlines inside the value are counted).
line: u32,
}
/// Parses one `key="value"` attribute beginning at byte `start`.
///
/// The key must begin with a name-start byte and be followed immediately by
/// `=` and a double-quoted value. The value may span lines; it is passed
/// through `decode_entities` before being returned.
///
/// # Errors
/// `ParseError::MalformedAttribute`, reported at `start_line`, when the key
/// does not start with a name-start byte, `=` or the opening `"` is missing,
/// or the value is never closed.
fn parse_attribute(
    input: &str,
    tag_name: &str,
    start: usize,
    start_line: u32,
) -> Result<ParsedAttribute, ParseError> {
    let bytes = input.as_bytes();
    // Every failure in this function is reported against the attribute's first line.
    let malformed = |reason: String| ParseError::MalformedAttribute {
        tag: tag_name.to_string(),
        reason,
        line: start_line,
    };

    let starts_name = match bytes.get(start) {
        Some(&b) => is_name_start(b),
        None => false,
    };
    if !starts_name {
        return Err(malformed(format!(
            "unexpected character {:?} at start of attribute name",
            next_char_at(input, start).unwrap_or('\0')
        )));
    }

    // Key: the name-start byte plus any following name bytes.
    let tail_len = bytes
        .iter()
        .skip(start + 1)
        .take_while(|&&b| is_name_char(b))
        .count();
    let key_end = start + 1 + tail_len;
    let key = input[start..key_end].to_string();

    if bytes.get(key_end) != Some(&b'=') {
        return Err(malformed(format!("expected '=' after attribute {key}")));
    }
    let quote_at = key_end + 1;
    if bytes.get(quote_at) != Some(&b'"') {
        return Err(malformed(format!("expected '\"' to open value of {key}")));
    }

    // Value: everything up to the next double quote, counting newlines on the way.
    let value_start = quote_at + 1;
    let mut cursor = value_start;
    let mut line = start_line;
    loop {
        match bytes.get(cursor) {
            None => return Err(malformed(format!("unterminated value of {key}"))),
            Some(&b'"') => break,
            Some(&b'\n') => line = line.saturating_add(1),
            Some(_) => {}
        }
        cursor += 1;
    }

    let value = decode_entities(&input[value_start..cursor]).into_owned();
    Ok(ParsedAttribute {
        key,
        value,
        // Step past the closing quote.
        end: cursor + 1,
        line,
    })
}
/// Skips a `<!-- ... -->` comment if one begins at byte `start`.
///
/// Returns `Ok(None)` when the bytes at `start` are not `<!--`; otherwise
/// `Ok(Some((end, line)))` with the offset just past `-->` and the line
/// number there.
///
/// # Errors
/// `ParseError::MalformedTag` (at `start_line`) when no `-->` follows.
fn try_skip_comment(
    bytes: &[u8],
    start: usize,
    start_line: u32,
) -> Result<Option<(usize, u32)>, ParseError> {
    const OPENER: &[u8] = b"<!--";
    if !bytes[start..].starts_with(OPENER) {
        return Ok(None);
    }
    match scan_to_terminator(bytes, start + OPENER.len(), start_line, b"-->") {
        Some(found) => Ok(Some(found)),
        None => Err(ParseError::MalformedTag {
            reason: "unterminated <!-- comment".to_string(),
            line: start_line,
        }),
    }
}
/// Skips a `<![CDATA[ ... ]]>` section if one begins at byte `start`.
///
/// Returns `Ok(None)` when the bytes at `start` are not `<![CDATA[`;
/// otherwise `Ok(Some((end, line)))` with the offset just past `]]>` and the
/// line number there.
///
/// # Errors
/// `ParseError::MalformedTag` (at `start_line`) when no `]]>` follows.
fn try_skip_cdata(
    bytes: &[u8],
    start: usize,
    start_line: u32,
) -> Result<Option<(usize, u32)>, ParseError> {
    const OPENER: &[u8] = b"<![CDATA[";
    if !bytes[start..].starts_with(OPENER) {
        return Ok(None);
    }
    match scan_to_terminator(bytes, start + OPENER.len(), start_line, b"]]>") {
        Some(found) => Ok(Some(found)),
        None => Err(ParseError::MalformedTag {
            reason: "unterminated <![CDATA[ section".to_string(),
            line: start_line,
        }),
    }
}
/// Searches `bytes` from index `from` for the first occurrence of `terminator`.
///
/// Returns the offset just past the terminator together with the line number
/// at that point (`start_line` plus the newlines stepped over before the
/// match, saturating), or `None` when the terminator does not occur.
fn scan_to_terminator(
    bytes: &[u8],
    from: usize,
    start_line: u32,
    terminator: &[u8],
) -> Option<(usize, u32)> {
    let mut line = start_line;
    let mut pos = from;
    loop {
        let window_end = pos + terminator.len();
        // `None` here means fewer than `terminator.len()` bytes remain.
        let window = bytes.get(pos..window_end)?;
        if window == terminator {
            return Some((window_end, line));
        }
        if bytes[pos] == b'\n' {
            line = line.saturating_add(1);
        }
        pos += 1;
    }
}
/// Returns the character that begins at byte index `i`, or `None` when `i` is
/// past the end, at the end, or not on a character boundary.
fn next_char_at(input: &str, i: usize) -> Option<char> {
    let tail = input.get(i..)?;
    tail.chars().next()
}
/// Advances `*i` past consecutive ASCII whitespace bytes, incrementing `*line`
/// (saturating) for every `\n` stepped over. Leaves both untouched when `*i`
/// is already at a non-whitespace byte or beyond the end of `bytes`.
fn skip_ws(bytes: &[u8], i: &mut usize, line: &mut u32) {
    let whitespace_run = bytes
        .iter()
        .skip(*i)
        .take_while(|b| b.is_ascii_whitespace());
    for &b in whitespace_run {
        if b == b'\n' {
            *line = line.saturating_add(1);
        }
        *i += 1;
    }
}