use crate::{SourceSpan, error::ParserError};
/// Maximum combined nesting of `(`, `{` and `[` brackets accepted by this module's `validate`.
pub(crate) const MAX_NESTING_DEPTH: u32 = 64;
/// Maximum nesting of `[` brackets alone accepted by this module's `validate`.
pub(crate) const MAX_LIST_NESTING_DEPTH: u32 = 32;
/// Budget for bracket depth plus pending `+`/`-` and `NOT` runs plus the `CASE` counter,
/// as summed by this module's `exceeds_recursion_budget`.
pub(crate) const MAX_RECURSION_DEPTH: u32 = 256;
/// Cheap byte-level pre-scan that rejects input whose nesting or operator chains
/// would exceed fixed complexity limits, before the real parser runs.
///
/// String literals (`'…'`, `"…"`, `` `…` `` and their `@`-prefixed no-escape forms)
/// and `//` / `/* */` comments are skipped. For everything else the scan tracks:
/// - `depth`: open `(`, `{` and `[` brackets, limited by [`MAX_NESTING_DEPTH`];
/// - `list_depth`: open `[` brackets, limited by [`MAX_LIST_NESTING_DEPTH`];
/// - `sign_run` / `not_run`: runs of `+`/`-` bytes and `NOT` keywords that have not
///   been broken by a word, literal, closing bracket or other punctuation
///   (whitespace, comments and opening brackets do not break a run);
/// - `case_depth`: a counter bumped by every `CASE` keyword.
///
/// `depth + sign_run + not_run + case_depth` must stay within [`MAX_RECURSION_DEPTH`].
/// Bracket balance and other syntax are not checked here.
///
/// # Errors
/// - `ParserError::NestingLimitExceeded` when `depth` exceeds `MAX_NESTING_DEPTH`.
/// - `ParserError::ComplexityLimitExceeded` when `list_depth` exceeds
///   `MAX_LIST_NESTING_DEPTH`, or the combined budget exceeds `MAX_RECURSION_DEPTH`.
///
/// Each error carries a one-byte span at the offending byte offset.
pub(super) fn validate(source: &str) -> Result<(), ParserError> {
let bytes = source.as_bytes();
// Offset of the last occurrence of each quote byte anywhere in `source`. The
// skip_* helpers use it so that a `\` before that final quote does not escape it
// (the quote closes the literal instead of leaving it unterminated).
let last_single_quote = bytes.iter().rposition(|byte| *byte == b'\'');
let last_double_quote = bytes.iter().rposition(|byte| *byte == b'"');
let last_backtick = bytes.iter().rposition(|byte| *byte == b'`');
let mut index = 0;
// Open `(`, `{` and `[` brackets.
let mut depth = 0_u32;
// Open `[` brackets only.
let mut list_depth = 0_u32;
// Current unbroken run of `+` / `-` bytes.
let mut sign_run = 0_u32;
// Current unbroken run of `NOT` keywords (outside identifier positions).
let mut not_run = 0_u32;
// Accumulated `CASE` cost; never decremented (see NOTE in the word branch).
let mut case_depth = 0_u32;
// Byte identifying the last significant token (whitespace and comments are
// skipped); words are recorded as the sentinel `b'w'`.
let mut prev_sig_byte: Option<u8> = None;
// Whether the previous word was an `AS` / `YIELD` keyword, which puts the
// following word in identifier position.
let mut prev_word = PrevWord::Other;
while index < bytes.len() {
match bytes[index] {
// `@'…'`, `@"…"` and `` @`…` `` literals: no escape handling, they end at the
// next matching delimiter. The returned index points at that delimiter (or
// past the end) and the shared `index += 1` below steps over it.
b'@' if next_is(bytes, index, b'\'') => {
sign_run = 0;
not_run = 0;
prev_word = PrevWord::Other;
prev_sig_byte = Some(b'\'');
index = skip_no_escape_quoted(bytes, index + 2, b'\'');
}
b'@' if next_is(bytes, index, b'"') => {
sign_run = 0;
not_run = 0;
prev_word = PrevWord::Other;
prev_sig_byte = Some(b'"');
index = skip_no_escape_quoted(bytes, index + 2, b'"');
}
b'@' if next_is(bytes, index, b'`') => {
sign_run = 0;
not_run = 0;
prev_word = PrevWord::Other;
prev_sig_byte = Some(b'`');
index = skip_no_escape_quoted(bytes, index + 2, b'`');
}
// Ordinary quoted literals (backslash and doubled-delimiter escapes). As above,
// `index` ends on the closing delimiter and the loop's `index += 1` skips it.
b'\'' => {
sign_run = 0;
not_run = 0;
prev_word = PrevWord::Other;
prev_sig_byte = Some(b'\'');
index = skip_single_quoted(bytes, index + 1, last_single_quote);
}
b'"' => {
sign_run = 0;
not_run = 0;
prev_word = PrevWord::Other;
prev_sig_byte = Some(b'"');
index = skip_double_quoted(bytes, index + 1, last_double_quote);
}
b'`' => {
sign_run = 0;
not_run = 0;
prev_word = PrevWord::Other;
prev_sig_byte = Some(b'`');
index = skip_backtick_quoted(bytes, index + 1, last_backtick);
}
// Comments leave every counter and `prev_*` value untouched, so they are
// transparent to runs such as `-/* c */-`.
b'/' if next_is(bytes, index, b'/') => {
// Stops on the `\n` (or at end of input); the newline is then handled as whitespace.
index = skip_line_comment(bytes, index + 2);
continue;
}
b'/' if next_is(bytes, index, b'*') => {
// `skip_block_comment` returns the index of the closing `/` (or the input
// length when unterminated); step past it explicitly.
index = skip_block_comment(bytes, index + 2);
index += 1;
continue;
}
// Opening brackets do not reset `sign_run`/`not_run`, so chains such as
// `-(-(` keep accumulating alongside `depth`.
b'(' | b'{' => {
depth += 1;
prev_word = PrevWord::Other;
prev_sig_byte = Some(bytes[index]);
if depth > MAX_NESTING_DEPTH {
return Err(ParserError::NestingLimitExceeded {
limit: MAX_NESTING_DEPTH,
span: point_span(index),
});
}
if exceeds_recursion_budget(depth, sign_run, not_run, case_depth) {
return Err(ParserError::ComplexityLimitExceeded {
limit: MAX_RECURSION_DEPTH,
span: point_span(index),
});
}
}
// `[` counts toward both `list_depth` and the shared `depth`. The list limit is
// checked first, so it wins when both would be exceeded.
b'[' => {
list_depth += 1;
prev_word = PrevWord::Other;
prev_sig_byte = Some(b'[');
if list_depth > MAX_LIST_NESTING_DEPTH {
return Err(ParserError::ComplexityLimitExceeded {
limit: MAX_LIST_NESTING_DEPTH,
span: point_span(index),
});
}
depth += 1;
if depth > MAX_NESTING_DEPTH {
return Err(ParserError::NestingLimitExceeded {
limit: MAX_NESTING_DEPTH,
span: point_span(index),
});
}
if exceeds_recursion_budget(depth, sign_run, not_run, case_depth) {
return Err(ParserError::ComplexityLimitExceeded {
limit: MAX_RECURSION_DEPTH,
span: point_span(index),
});
}
}
// Closers end any pending run; `saturating_sub` keeps unbalanced closers
// from underflowing the counters.
b')' | b'}' => {
depth = depth.saturating_sub(1);
sign_run = 0;
not_run = 0;
prev_word = PrevWord::Other;
prev_sig_byte = Some(bytes[index]);
}
b']' => {
depth = depth.saturating_sub(1);
list_depth = list_depth.saturating_sub(1);
sign_run = 0;
not_run = 0;
prev_word = PrevWord::Other;
prev_sig_byte = Some(b']');
}
b'+' | b'-' => {
sign_run += 1;
prev_word = PrevWord::Other;
prev_sig_byte = Some(bytes[index]);
if exceeds_recursion_budget(depth, sign_run, not_run, case_depth) {
return Err(ParserError::ComplexityLimitExceeded {
limit: MAX_RECURSION_DEPTH,
span: point_span(index),
});
}
}
// Only these four bytes are transparent whitespace. Any other separator (e.g.
// form feed, or a non-ASCII space) reaches a resetting arm below.
// NOTE(review): confirm this matches the lexer's whitespace set; otherwise
// sign/NOT runs separated by such characters are undercounted.
b' ' | b'\t' | b'\r' | b'\n' => {}
byte if is_word_byte_start(byte) => {
let word_end = scan_word_chars(source, index);
if word_end == index {
// A non-ASCII character that is not alphanumeric (e.g. a symbol): treat
// it like other punctuation and advance by its full UTF-8 length.
sign_run = 0;
not_run = 0;
prev_word = PrevWord::Other;
prev_sig_byte = Some(byte);
index += char_len_at(source, index);
continue;
}
// A keyword spelling is treated as a plain name after `.` or `$`, after
// `AS` / `YIELD`, or when the next significant byte is `:`.
let in_ident_pos = matches!(prev_sig_byte, Some(b'.') | Some(b'$'))
|| matches!(prev_word, PrevWord::As | PrevWord::Yield)
|| next_sig_is_colon(bytes, word_end);
let class = classify_word(&source[index..word_end]);
match class {
WordClass::Not if !in_ident_pos => {
not_run += 1;
if exceeds_recursion_budget(depth, sign_run, not_run, case_depth) {
return Err(ParserError::ComplexityLimitExceeded {
limit: MAX_RECURSION_DEPTH,
span: point_span(index),
});
}
}
// `CASE` folds the pending sign/NOT runs into `case_depth` (plus 1 for
// itself) and clears them.
// NOTE(review): `WordClass::End` falls through to the `_` arm, so
// `case_depth` is never decremented. Every `CASE` in the input counts
// against the budget, nested or not. A decrement on `END` would need to
// be robust to `end` appearing as a label or name, or nested CASEs could
// slip past the check. Confirm which behaviour is intended.
WordClass::Case if !in_ident_pos => {
case_depth = case_depth
.saturating_add(1)
.saturating_add(sign_run)
.saturating_add(not_run);
sign_run = 0;
not_run = 0;
if exceeds_recursion_budget(depth, sign_run, not_run, case_depth) {
return Err(ParserError::ComplexityLimitExceeded {
limit: MAX_RECURSION_DEPTH,
span: point_span(index),
});
}
}
_ => {
sign_run = 0;
not_run = 0;
}
}
// `AS` / `YIELD` only put the next word in identifier position when they
// were themselves used as keywords.
prev_word = if in_ident_pos {
PrevWord::Other
} else {
match class {
WordClass::As => PrevWord::As,
WordClass::Yield => PrevWord::Yield,
_ => PrevWord::Other,
}
};
prev_sig_byte = Some(b'w');
index = word_end;
continue;
}
// Any other byte (digits, operators, punctuation) ends pending runs.
other => {
sign_run = 0;
not_run = 0;
prev_word = PrevWord::Other;
prev_sig_byte = Some(other);
}
}
index += 1;
}
Ok(())
}
/// Returns `true` once the combined counters exceed [`MAX_RECURSION_DEPTH`].
///
/// The four counters are summed with saturating addition, so an overflow clamps
/// the total at `u32::MAX` (which still counts as exceeded) instead of wrapping.
fn exceeds_recursion_budget(depth: u32, sign_run: u32, not_run: u32, case_depth: u32) -> bool {
    let total = [sign_run, not_run, case_depth]
        .iter()
        .fold(depth, |acc, &term| acc.saturating_add(term));
    total > MAX_RECURSION_DEPTH
}
/// Reports whether the byte right after `index` equals `expected`.
///
/// Returns `false` when `index + 1` is past the end of `bytes`.
fn next_is(bytes: &[u8], index: usize, expected: u8) -> bool {
    matches!(bytes.get(index + 1), Some(&found) if found == expected)
}
/// Keyword categories recognised by `classify_word` (ASCII case-insensitive)
/// for the pre-parse complexity scan.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WordClass {
    /// The `NOT` keyword.
    Not,
    /// The `CASE` keyword.
    Case,
    /// The `END` keyword.
    End,
    /// The `AS` keyword.
    As,
    /// The `YIELD` keyword.
    Yield,
    /// Any other word.
    Other,
}
/// Records whether the previous word was an `AS` or `YIELD` keyword, which puts
/// the following word in identifier position during the pre-parse scan.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PrevWord {
    /// The previous word was the `AS` keyword.
    As,
    /// The previous word was the `YIELD` keyword.
    Yield,
    /// Anything else.
    Other,
}
/// Whether `byte` may begin a word: an ASCII letter, `_`, or any non-ASCII
/// byte (value `0x80` or above).
fn is_word_byte_start(byte: u8) -> bool {
    matches!(byte, b'A'..=b'Z' | b'a'..=b'z' | b'_') || !byte.is_ascii()
}
/// UTF-8 byte length of the character starting at `index`, or `1` when
/// `index` is exactly the end of `source`.
///
/// # Panics
/// Panics if `index` is past the end of `source` or not on a character boundary.
fn char_len_at(source: &str, index: usize) -> usize {
    let rest = &source[index..];
    match rest.chars().next() {
        Some(ch) => ch.len_utf8(),
        None => 1,
    }
}
/// Returns the byte offset just past the run of word characters (Unicode
/// alphanumerics and `_`) beginning at `start`; returns `start` itself when
/// the character at `start` is not a word character.
///
/// # Panics
/// Panics if `start` is past the end of `source` or not on a character boundary.
fn scan_word_chars(source: &str, start: usize) -> usize {
    let tail = &source[start..];
    let run_len = tail
        .char_indices()
        .find(|&(_, ch)| !ch.is_alphanumeric() && ch != '_')
        .map_or(tail.len(), |(offset, _)| offset);
    start + run_len
}
/// Maps `word` to the keyword class the validator cares about, comparing
/// ASCII case-insensitively; any other word is `WordClass::Other`.
fn classify_word(word: &str) -> WordClass {
    // `eq_ignore_ascii_case` requires equal byte lengths, so dispatching on the
    // length first selects exactly the same keyword as a full comparison chain.
    match word.len() {
        2 if word.eq_ignore_ascii_case("AS") => WordClass::As,
        3 if word.eq_ignore_ascii_case("NOT") => WordClass::Not,
        3 if word.eq_ignore_ascii_case("END") => WordClass::End,
        4 if word.eq_ignore_ascii_case("CASE") => WordClass::Case,
        5 if word.eq_ignore_ascii_case("YIELD") => WordClass::Yield,
        _ => WordClass::Other,
    }
}
/// Reports whether the first significant byte at or after `from` is `:`.
///
/// Spaces, tabs, CR, LF and `//` / `/* */` comments are skipped. Any other
/// byte, or running off the end of input, yields `false`.
fn next_sig_is_colon(bytes: &[u8], from: usize) -> bool {
    let mut cursor = from;
    while let Some(&byte) = bytes.get(cursor) {
        cursor = match (byte, bytes.get(cursor + 1).copied()) {
            (b' ', _) | (b'\t', _) | (b'\r', _) | (b'\n', _) => cursor + 1,
            (b'/', Some(b'/')) => skip_line_comment(bytes, cursor + 2),
            (b'/', Some(b'*')) => skip_block_comment(bytes, cursor + 2) + 1,
            (b':', _) => return true,
            _ => return false,
        };
    }
    false
}
/// Skips a `'…'` literal whose opening quote precedes `index`.
///
/// See [`skip_escaped_quoted`] for the escape rules and the `last_quote` leniency.
/// Returns the index of the closing quote, or `bytes.len()` if unterminated.
fn skip_single_quoted(bytes: &[u8], index: usize, last_quote: Option<usize>) -> usize {
    skip_escaped_quoted(bytes, index, b'\'', last_quote)
}

/// Skips a `"…"` literal whose opening quote precedes `index`.
///
/// See [`skip_escaped_quoted`] for the escape rules and the `last_quote` leniency.
/// Returns the index of the closing quote, or `bytes.len()` if unterminated.
fn skip_double_quoted(bytes: &[u8], index: usize, last_quote: Option<usize>) -> usize {
    skip_escaped_quoted(bytes, index, b'"', last_quote)
}

/// Skips a literal that has no escape sequences: returns the index of the first
/// `delimiter` at or after `index`, or `bytes.len()` if there is none.
///
/// `validate` uses this for the `@`-prefixed literal forms.
fn skip_no_escape_quoted(bytes: &[u8], index: usize, delimiter: u8) -> usize {
    bytes
        .get(index..)
        .and_then(|rest| rest.iter().position(|&byte| byte == delimiter))
        .map_or(bytes.len(), |offset| index + offset)
}

/// Skips a backtick-delimited span whose opening backtick precedes `index`.
///
/// See [`skip_escaped_quoted`] for the escape rules and the `last_backtick` leniency.
/// Returns the index of the closing backtick, or `bytes.len()` if unterminated.
fn skip_backtick_quoted(bytes: &[u8], index: usize, last_backtick: Option<usize>) -> usize {
    skip_escaped_quoted(bytes, index, b'`', last_backtick)
}

/// Shared scanner for literals delimited by `delimiter` that support escapes,
/// starting just after the opening delimiter.
///
/// Inside the literal:
/// - `\` skips itself and the following byte, whatever it is;
/// - a doubled delimiter skips both bytes.
///
/// As a leniency, a `\` directly before the *last* `delimiter` byte of the whole
/// input (`last_delimiter`) does not escape it. That delimiter closes the literal
/// instead of leaving it unterminated.
///
/// Returns the index of the closing delimiter, or `bytes.len()` when the literal
/// runs to the end of input.
fn skip_escaped_quoted(
    bytes: &[u8],
    mut index: usize,
    delimiter: u8,
    last_delimiter: Option<usize>,
) -> usize {
    while let Some(&byte) = bytes.get(index) {
        let next = bytes.get(index + 1).copied();
        if byte == b'\\' {
            if next == Some(delimiter) && last_delimiter == Some(index + 1) {
                return index + 1;
            }
            index += 2;
        } else if byte == delimiter {
            if next != Some(delimiter) {
                return index;
            }
            index += 2;
        } else {
            index += 1;
        }
    }
    bytes.len()
}
/// Returns the index of the first `\n` at or after `index`, or `bytes.len()`
/// when the comment runs to the end of input.
fn skip_line_comment(bytes: &[u8], index: usize) -> usize {
    match bytes.get(index..) {
        Some(rest) => rest
            .iter()
            .position(|&byte| byte == b'\n')
            .map_or(bytes.len(), |offset| index + offset),
        None => bytes.len(),
    }
}
/// Returns the index of the `/` that ends the first `*/` found at or after
/// `index`, or `bytes.len()` when the comment is unterminated.
fn skip_block_comment(bytes: &[u8], index: usize) -> usize {
    let found = bytes
        .get(index..)
        .and_then(|rest| rest.windows(2).position(|pair| matches!(pair, [b'*', b'/'])));
    match found {
        Some(offset) => index + offset + 1,
        None => bytes.len(),
    }
}
/// Builds a one-byte [`SourceSpan`] at `offset`, clamping offsets that do not
/// fit in `u32` to `u32::MAX`.
fn point_span(offset: usize) -> SourceSpan {
    let start = match u32::try_from(offset) {
        Ok(start) => start,
        Err(_) => u32::MAX,
    };
    SourceSpan::new(start, 1)
}