use crate::velesql::error::{ParseError, ParseErrorKind};
/// Maximum nesting depth (open brackets plus consecutive `NOT` prefixes)
/// accepted by the pre-scan; a deeper input is rejected with a
/// `ComplexityLimit` error.
const MAX_NESTING_DEPTH: usize = 64;
/// Lexical context of the nesting-depth scanner.
#[derive(Clone, Copy, PartialEq, Eq)]
enum ScanState {
/// Outside any quoted region; the only state in which brackets, `--`
/// comments and words are recognized.
Normal,
/// Inside a `'...'` region; a doubled `''` does not end it.
SingleQuote,
/// Inside a `"..."` region; a doubled `""` does not end it.
DoubleQuote,
/// Inside a `` `...` `` region; the next backtick ends it.
Backtick,
}
/// Running nesting-depth tracker for [`check_nesting_depth`].
///
/// The effective depth at any point is `brackets + prefix_run`: the levels held
/// by currently open brackets plus the current run of consecutive `NOT`
/// keywords that have not yet been followed by another word or a bracket.
struct Depth {
    /// Sum of the levels held by every currently open bracket.
    brackets: usize,
    /// Number of consecutive `NOT` keywords seen since the last other word
    /// or opening bracket.
    prefix_run: usize,
    /// Levels held by each open bracket, innermost last. An opening bracket
    /// holds `prefix_run + 1` levels: itself plus the `NOT` prefixes that
    /// apply to it. Invariant: `brackets == frames.iter().sum()`.
    frames: Vec<usize>,
    /// Highest effective depth observed so far.
    max: usize,
}

impl Depth {
    /// Creates a tracker at depth zero.
    fn new() -> Self {
        Self {
            brackets: 0,
            prefix_run: 0,
            frames: Vec::new(),
            max: 0,
        }
    }

    /// Raises `max` to the current effective depth if it is higher.
    fn observe(&mut self) {
        let effective = self.brackets + self.prefix_run;
        if effective > self.max {
            self.max = effective;
        }
    }
}

/// Cheap structural guard run before parsing.
///
/// # Errors
///
/// Returns a `ComplexityLimit` [`ParseError`] when `input` is longer than
/// `max_query_length` bytes, or when its effective nesting depth (brackets plus
/// pending `NOT` prefixes) exceeds [`MAX_NESTING_DEPTH`].
pub(super) fn prescan(input: &str, max_query_length: usize) -> Result<(), ParseError> {
    // Check the length first so the depth scan never walks an oversized input.
    if input.len() > max_query_length {
        return Err(length_error(input, max_query_length));
    }
    check_nesting_depth(input)
}

/// Scans `input` and fails as soon as the effective nesting depth exceeds
/// [`MAX_NESTING_DEPTH`].
///
/// The reported position is the byte offset of the token (bracket or `NOT`)
/// that first pushed the depth over the limit. Stopping there also keeps the
/// bracket-frame stack at no more than `MAX_NESTING_DEPTH + 1` entries.
fn check_nesting_depth(input: &str) -> Result<(), ParseError> {
    let bytes = input.as_bytes();
    let mut state = ScanState::Normal;
    let mut depth = Depth::new();
    let mut i = 0;
    while i < bytes.len() {
        let token_start = i;
        // `step` always consumes at least one byte, so the loop terminates.
        i += step(bytes, i, &mut state, &mut depth);
        if depth.max > MAX_NESTING_DEPTH {
            return Err(depth_error(bytes, token_start));
        }
    }
    Ok(())
}

/// Advances the scanner by one unit starting at byte `i` and returns the
/// number of bytes consumed, which is always at least 1.
fn step(bytes: &[u8], i: usize, state: &mut ScanState, depth: &mut Depth) -> usize {
    match *state {
        ScanState::Normal => step_normal(bytes, i, state, depth),
        ScanState::SingleQuote => step_quoted(bytes, i, bytes[i], b'\'', state),
        ScanState::DoubleQuote => step_quoted(bytes, i, bytes[i], b'"', state),
        ScanState::Backtick => {
            if bytes[i] == b'`' {
                *state = ScanState::Normal;
            }
            1
        }
    }
}

/// Handles one unit outside any quoted region: `--` comments, brackets,
/// quote openers and word tokens. Returns the number of bytes consumed.
fn step_normal(bytes: &[u8], i: usize, state: &mut ScanState, depth: &mut Depth) -> usize {
    let b = bytes[i];
    if starts_line_comment(bytes, i) {
        return skip_line_comment(bytes, i);
    }
    match b {
        b'(' | b'[' => open_bracket(depth),
        b')' | b']' => close_bracket(depth),
        b'\'' => *state = ScanState::SingleQuote,
        b'"' => *state = ScanState::DoubleQuote,
        b'`' => *state = ScanState::Backtick,
        _ if is_word_byte(b) => return word_token(bytes, i, depth),
        _ => {}
    }
    1
}

/// Returns `true` when a `--` line comment starts at byte `i`.
fn starts_line_comment(bytes: &[u8], i: usize) -> bool {
    bytes[i] == b'-' && bytes.get(i + 1) == Some(&b'-')
}

/// Opens a bracket, folding the pending `NOT` prefixes into the levels it holds
/// so they are released again by the matching closer.
fn open_bracket(depth: &mut Depth) {
    let held = depth.prefix_run + 1;
    depth.brackets += held;
    depth.frames.push(held);
    depth.prefix_run = 0;
    depth.observe();
}

/// Closes the innermost open bracket, releasing every level its opener added
/// (including the `NOT` prefixes applied to it). Unmatched closers are ignored.
fn close_bracket(depth: &mut Depth) {
    if let Some(held) = depth.frames.pop() {
        depth.brackets = depth.brackets.saturating_sub(held);
    }
}
/// Consumes the run of word bytes starting at `i` and returns its length.
///
/// A `NOT` keyword extends the pending prefix run (possibly raising the
/// observed maximum depth); any other word ends that run.
fn word_token(bytes: &[u8], i: usize, depth: &mut Depth) -> usize {
    let len = bytes[i..].iter().take_while(|&&c| is_word_byte(c)).count();
    if !is_not_keyword(&bytes[i..i + len]) {
        depth.prefix_run = 0;
        return len;
    }
    depth.prefix_run += 1;
    depth.observe();
    len
}
/// Returns the length of the `--` comment starting at `i`, up to but not
/// including the terminating newline (or to the end of input).
fn skip_line_comment(bytes: &[u8], i: usize) -> usize {
    let rest = bytes.get(i + 2..).unwrap_or(&[]);
    let body_len = rest.iter().position(|&c| c == b'\n').unwrap_or(rest.len());
    2 + body_len
}
/// Returns `true` for ASCII letters, ASCII digits and `_`.
fn is_word_byte(b: u8) -> bool {
    matches!(b, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_')
}
/// Returns `true` when `token` is exactly the keyword `NOT`, ignoring ASCII case.
fn is_not_keyword(token: &[u8]) -> bool {
    token.eq_ignore_ascii_case(b"NOT")
}
fn step_quoted(bytes: &[u8], i: usize, b: u8, quote: u8, state: &mut ScanState) -> usize {
if b != quote {
return 1;
}
if bytes.get(i + 1) == Some("e) {
return 2;
}
*state = ScanState::Normal;
1
}
/// Builds the `ComplexityLimit` error for an input longer than `max` bytes.
///
/// The fragment is the first 128 characters of the input; `max` is passed as
/// the error position.
fn length_error(input: &str, max: usize) -> ParseError {
    let actual = input.len();
    let preview: String = input.chars().take(128).collect();
    let message = format!("Query length exceeded: max={max}, actual={actual}");
    ParseError::new(ParseErrorKind::ComplexityLimit, max, preview, message)
}
/// Builds the `ComplexityLimit` error reported when nesting exceeds
/// [`MAX_NESTING_DEPTH`].
///
/// `position` is reported unchanged. The fragment is a window of up to 32 bytes
/// before `position` plus the byte at `position`. It is widened to whole UTF-8
/// characters so a multi-byte character at either edge is not turned into
/// U+FFFD. A `position` past the end is clamped to the last byte, and an empty
/// input yields an empty fragment instead of panicking.
fn depth_error(bytes: &[u8], position: usize) -> ParseError {
    let fragment = if bytes.is_empty() {
        String::new()
    } else {
        // UTF-8 continuation bytes have the bit pattern 0b10xx_xxxx.
        let is_continuation = |c: u8| c & 0xC0 == 0x80;
        let mut end = position.min(bytes.len() - 1);
        // Derive `start` from the clamped `end` so that `start <= end` always holds.
        let mut start = end.saturating_sub(32);
        while start > 0 && is_continuation(bytes[start]) {
            start -= 1;
        }
        while end + 1 < bytes.len() && is_continuation(bytes[end + 1]) {
            end += 1;
        }
        String::from_utf8_lossy(&bytes[start..=end]).into_owned()
    };
    ParseError::new(
        ParseErrorKind::ComplexityLimit,
        position,
        fragment,
        format!("Query nesting too deep: max={MAX_NESTING_DEPTH} levels of recursion"),
    )
}