/// Reports whether `select` contains a lone `:` outside of quoted text.
///
/// A colon counts only when neither the byte before it nor the byte after it
/// is also `:`, so `::` sequences are ignored. Parenthesis depth is not
/// consulted, so a lone colon nested inside parentheses counts as well.
///
/// NOTE(review): presumably a lone colon marks the "relation compatibility"
/// select syntax — confirm against the callers.
pub(super) fn select_clause_uses_relation_compatibility(select: &str) -> bool {
    let is_lone_colon =
        |_: usize, byte: u8, before: Option<u8>, after: Option<u8>, _: usize| -> bool {
            if byte != b':' {
                return false;
            }
            // Reject the colon when either neighbour is also a colon.
            !(before == Some(b':') || after == Some(b':'))
        };

    scan_sql_until_top_level_boundary(select, 0, is_lone_colon).is_some()
}
/// Splits `input` at every top-level occurrence of `keyword`.
///
/// Occurrences are located with `find_top_level_keyword`, so matches inside
/// quotes or parentheses, or without word boundaries on both sides, do not
/// split. Each piece is trimmed and returned as an owned `String`; empty
/// pieces are kept, so the result always has at least one element.
pub(super) fn split_top_level_keyword(input: &str, keyword: &str) -> Vec<String> {
    let mut pieces = Vec::new();
    let mut consumed = 0usize;
    loop {
        // Each search restarts on the unconsumed tail of the input.
        let tail = &input[consumed..];
        match find_top_level_keyword(tail, keyword, 0) {
            Some(offset) => {
                pieces.push(tail[..offset].trim().to_string());
                consumed += offset + keyword.len();
            }
            None => {
                pieces.push(tail.trim().to_string());
                return pieces;
            }
        }
    }
}
/// Splits `input` at commas that sit outside quotes and outside parentheses.
///
/// Each piece is trimmed; pieces that are empty after trimming are dropped,
/// so the result may be empty.
pub(super) fn split_top_level_commas(input: &str) -> Vec<String> {
    let is_top_level_comma =
        |_: usize, byte: u8, _: Option<u8>, _: Option<u8>, depth: usize| byte == b',' && depth == 0;

    let mut segments = Vec::new();
    let mut cursor = 0usize;
    loop {
        let comma = scan_sql_until_top_level_boundary(input, cursor, is_top_level_comma);
        // Without a further comma the final segment runs to the end of input.
        let end = comma.unwrap_or(input.len());
        let segment = input[cursor..end].trim();
        if !segment.is_empty() {
            segments.push(segment.to_string());
        }
        match comma {
            // Resume right after the comma (a single ASCII byte).
            Some(at) => cursor = at + 1,
            None => return segments,
        }
    }
}
/// Reports whether `input` begins with `keyword` as a whole word.
///
/// Shorthand for `starts_with_keyword_at` anchored at byte offset zero.
pub(super) fn starts_with_keyword(input: &str, keyword: &str) -> bool {
    const INPUT_START: usize = 0;
    starts_with_keyword_at(input, INPUT_START, keyword)
}
/// Reports whether `keyword` occurs in `input` at byte offset `start` as a
/// whole word.
///
/// The text is compared ASCII-case-insensitively. The byte before `start` and
/// the byte after the keyword must each be a keyword boundary (see
/// `is_keyword_boundary`); the start and end of `input` count as boundaries.
pub(super) fn starts_with_keyword_at(input: &str, start: usize, keyword: &str) -> bool {
    // Check the text first: only once it matches is `start + keyword.len()`
    // known to lie within `input`.
    slice_eq_ignore_ascii_case(input, start, keyword) && {
        let raw = input.as_bytes();
        let before = if start == 0 {
            None
        } else {
            raw.get(start - 1).copied()
        };
        let after = raw.get(start + keyword.len()).copied();
        is_keyword_boundary(before) && is_keyword_boundary(after)
    }
}
/// Reports whether any entry of `keywords` occurs in `input` at byte offset
/// `start` as a whole word.
///
/// Keywords are tried in slice order and the search stops at the first hit.
/// An empty `keywords` slice yields `false`.
pub(super) fn starts_with_any_keyword_at(input: &str, start: usize, keywords: &[&str]) -> bool {
    for &candidate in keywords {
        if starts_with_keyword_at(input, start, candidate) {
            return true;
        }
    }
    false
}
/// Reports whether the `needle.len()` bytes of `input` starting at byte
/// offset `start` equal `needle`, ignoring ASCII case.
///
/// Returns `false` when the range runs past the end of `input`, does not fall
/// on UTF-8 character boundaries, or when `start + needle.len()` would
/// overflow `usize`. An empty `needle` matches at any valid boundary.
pub(super) fn slice_eq_ignore_ascii_case(input: &str, start: usize, needle: &str) -> bool {
    start
        // Checked addition: an absurdly large `start` must yield `false`
        // rather than an arithmetic-overflow panic.
        .checked_add(needle.len())
        .and_then(|end| input.get(start..end))
        .map_or(false, |candidate| candidate.eq_ignore_ascii_case(needle))
}
/// Reports whether `byte` can sit next to a keyword without being part of it.
///
/// ASCII letters, ASCII digits and `_` are word bytes and yield `false`.
/// Every other byte, and `None` (no neighbouring byte), yields `true`.
pub(super) fn is_keyword_boundary(byte: Option<u8>) -> bool {
    match byte {
        None => true,
        Some(b'_') => false,
        Some(value) => !value.is_ascii_alphanumeric(),
    }
}
/// Finds the first whole-word, top-level occurrence of `keyword` in `input`,
/// scanning from byte offset `start`.
///
/// "Top-level" means outside quotes and at parenthesis depth zero, as tracked
/// by `scan_sql_until_top_level_boundary` from `start` onward. The comparison
/// ignores ASCII case, and the bytes on either side of the match must be
/// keyword boundaries (see `is_keyword_boundary`).
///
/// Returns the byte offset of the match within `input`. Returns `None` when
/// there is no match or when `keyword` is empty.
pub(super) fn find_top_level_keyword(input: &str, keyword: &str, start: usize) -> Option<usize> {
    // An empty keyword has no first byte to anchor on; report "not found"
    // instead of panicking on an out-of-bounds index.
    let keyword_first = *keyword.as_bytes().first()?;
    scan_sql_until_top_level_boundary(
        input,
        start,
        |index, current, previous, _next, paren_depth| {
            // Cheap rejection on depth and first byte before comparing text.
            if paren_depth != 0 || !current.eq_ignore_ascii_case(&keyword_first) {
                return false;
            }
            if !slice_eq_ignore_ascii_case(input, index, keyword) {
                return false;
            }
            let after = input.as_bytes().get(index + keyword.len()).copied();
            is_keyword_boundary(previous) && is_keyword_boundary(after)
        },
    )
}
/// Finds which of `keywords` has the earliest top-level match in `input`.
///
/// Every keyword is searched from the start of `input` with
/// `find_top_level_keyword`. Returns the smallest match offset with the
/// keyword that produced it. When several keywords match at the same offset,
/// the one listed first in `keywords` wins. Returns `None` if none match.
pub(super) fn find_earliest_top_level_keyword<'a>(
    input: &str,
    keywords: &'a [&'a str],
) -> Option<(usize, &'a str)> {
    let mut earliest: Option<(usize, &'a str)> = None;
    for &candidate in keywords {
        if let Some(position) = find_top_level_keyword(input, candidate, 0) {
            // Strict comparison keeps the first-listed keyword on ties.
            let improves = earliest.map_or(true, |(best, _)| position < best);
            if improves {
                earliest = Some((position, candidate));
            }
        }
    }
    earliest
}
/// Walks the bytes of `input` from offset `start` and returns the offset of
/// the first byte for which `predicate` returns `true`.
///
/// The scanner tracks `'...'` and `"..."` quoting, where a doubled quote
/// character inside a quoted run is an escaped quote, and `(`/`)` nesting.
/// Bytes inside quotes, the quote characters themselves, and the parenthesis
/// characters are never offered to `predicate`.
///
/// `predicate` receives `(index, current, previous, next, paren_depth)`:
/// `previous` and `next` are the raw neighbouring bytes (`None` at the ends of
/// `input`), and `paren_depth` is the nesting depth counted from `start`. The
/// depth is passed along rather than enforced, so a predicate that wants
/// top-level matches only must check `paren_depth == 0` itself. Unmatched `)`
/// bytes leave the depth at zero.
///
/// Returns `None` if the predicate never fires, including when `start` is at
/// or past the end of `input`.
pub(super) fn scan_sql_until_top_level_boundary(
    input: &str,
    start: usize,
    mut predicate: impl FnMut(usize, u8, Option<u8>, Option<u8>, usize) -> bool,
) -> Option<usize> {
    let bytes = input.as_bytes();
    // The quote byte that opened the quoted run currently being skipped.
    let mut open_quote: Option<u8> = None;
    let mut depth = 0usize;
    let mut cursor = start;

    while let Some(&current) = bytes.get(cursor) {
        let next = bytes.get(cursor + 1).copied();

        if let Some(quote) = open_quote {
            if current == quote {
                if next == Some(quote) {
                    // Doubled quote: an escaped quote, still inside the run.
                    cursor += 2;
                    continue;
                }
                open_quote = None;
            }
            cursor += 1;
            continue;
        }

        match current {
            b'\'' | b'"' => open_quote = Some(current),
            b'(' => depth += 1,
            b')' => depth = depth.saturating_sub(1),
            _ => {
                let previous = if cursor == 0 {
                    None
                } else {
                    Some(bytes[cursor - 1])
                };
                if predicate(cursor, current, previous, next, depth) {
                    return Some(cursor);
                }
            }
        }
        cursor += 1;
    }
    None
}
/// Scans `input` from byte offset `open_index` and returns the offset of the
/// `)` that brings the parenthesis depth back to zero.
///
/// Depth starts at zero at `open_index`, so the scan is expected to begin on
/// (or before) the opening `(`. Parentheses inside `'...'` or `"..."` are
/// ignored, and a doubled quote character inside a quoted run is treated as
/// an escaped quote.
///
/// Because the depth saturates at zero, a `)` met before any `(` is also
/// returned as the match.
///
/// Returns `None` when no such `)` exists or `open_index` is past the end.
pub(super) fn find_matching_closing_paren(input: &str, open_index: usize) -> Option<usize> {
    let bytes = input.as_bytes();
    // The quote byte that opened the quoted run currently being skipped.
    let mut open_quote: Option<u8> = None;
    let mut nesting = 0usize;
    let mut cursor = open_index;

    while let Some(&current) = bytes.get(cursor) {
        if let Some(quote) = open_quote {
            if current == quote {
                if bytes.get(cursor + 1) == Some(&quote) {
                    // Doubled quote: an escaped quote, still inside the run.
                    cursor += 2;
                    continue;
                }
                open_quote = None;
            }
            cursor += 1;
            continue;
        }

        match current {
            b'\'' | b'"' => open_quote = Some(current),
            b'(' => nesting += 1,
            b')' => {
                nesting = nesting.saturating_sub(1);
                if nesting == 0 {
                    return Some(cursor);
                }
            }
            _ => {}
        }
        cursor += 1;
    }
    None
}
/// Advances `*index` past any run of ASCII whitespace bytes in `input`
/// starting at `*index`.
///
/// `*index` is left unchanged when it already points at a non-whitespace byte
/// or lies at or beyond the end of `input`.
pub(super) fn skip_ascii_whitespace(input: &str, index: &mut usize) {
    let run_length = input.as_bytes().get(*index..).map_or(0, |tail| {
        tail.iter()
            .take_while(|byte| byte.is_ascii_whitespace())
            .count()
    });
    *index += run_length;
}
/// Finds the first occurrence of the byte `target` in `input` that lies
/// outside quotes and at parenthesis depth zero.
///
/// `'...'` and `"..."` runs are skipped, with a doubled quote character inside
/// a run treated as an escaped quote. The target test runs before the quote
/// and parenthesis bookkeeping, so `target` may itself be a quote or
/// parenthesis byte: an opening quote or `(` at depth zero matches, as does a
/// `)` met at depth zero.
///
/// Returns the byte offset of the match, or `None` if there is none.
pub(super) fn find_top_level_char(input: &str, target: u8) -> Option<usize> {
    let bytes = input.as_bytes();
    // The quote byte that opened the quoted run currently being skipped.
    let mut open_quote: Option<u8> = None;
    let mut depth = 0usize;
    let mut cursor = 0usize;

    while let Some(&current) = bytes.get(cursor) {
        if let Some(quote) = open_quote {
            if current == quote {
                if bytes.get(cursor + 1) == Some(&quote) {
                    // Doubled quote: an escaped quote, still inside the run.
                    cursor += 2;
                    continue;
                }
                open_quote = None;
            }
            cursor += 1;
            continue;
        }

        if depth == 0 && current == target {
            return Some(cursor);
        }

        match current {
            b'\'' | b'"' => open_quote = Some(current),
            b'(' => depth += 1,
            b')' => depth = depth.saturating_sub(1),
            _ => {}
        }
        cursor += 1;
    }
    None
}