/// A lexical slice of a SQL statement, borrowed from the original input.
///
/// Every variant carries the exact source text of the segment, delimiters
/// included. The type only holds a `&str`, so it is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlSegment<'a> {
    /// Plain SQL outside of any string literal, quoted identifier, or comment.
    Text(&'a str),
    /// A single-quoted string literal, including its quotes and any `E`/`e`
    /// escape-string prefix.
    SingleQuotedString(&'a str),
    /// A double-quoted identifier, including its quotes.
    QuotedIdent(&'a str),
    /// A `--` comment, including the terminating newline when one is present.
    LineComment(&'a str),
    /// A `/* ... */` comment, including its delimiters and any nested comments.
    BlockComment(&'a str),
}
/// Splits `sql` into consecutive lexical segments.
///
/// The returned segments are contiguous, non-empty slices of `sql` in source
/// order, so concatenating them reproduces the input. Recognised constructs:
///
/// * `'...'` string literals, where `''` is an embedded quote;
/// * `E'...'` / `e'...'` escape strings, where a backslash additionally
///   escapes the following byte;
/// * `"..."` quoted identifiers, where `""` is an embedded quote;
/// * `-- ...` line comments (the trailing newline belongs to the comment);
/// * `/* ... */` block comments, which may nest.
///
/// Everything else (including dollar-quoted bodies, which are not recognised)
/// is reported as [`SqlSegment::Text`]. Unterminated literals and comments
/// extend to the end of the input instead of being rejected.
pub fn segments(sql: &str) -> Vec<SqlSegment<'_>> {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut out = Vec::new();
    let mut i = 0;
    let mut text_start = 0;

    while i < len {
        // Work out where the construct starting at `i` ends; plain text just
        // advances by one byte and is flushed lazily.
        let end = match &bytes[i..] {
            [b'\'', ..] => quoted_end(bytes, i + 1, b'\'', false),
            // `E` only introduces an escape string when it starts a token.
            // In `date'2020-01-01'` the `e` belongs to the type name.
            [b'E' | b'e', b'\'', ..] if token_may_start_at(bytes, i) => {
                quoted_end(bytes, i + 2, b'\'', true)
            }
            [b'"', ..] => quoted_end(bytes, i + 1, b'"', false),
            [b'-', b'-', ..] => line_comment_end(bytes, i + 2),
            [b'/', b'*', ..] => block_comment_end(bytes, i + 2),
            _ => {
                i += 1;
                continue;
            }
        };

        if text_start < i {
            out.push(SqlSegment::Text(&sql[text_start..i]));
        }
        // `i` sits on an ASCII byte and `end` is either just past an ASCII
        // delimiter or equal to `len`, so both are valid char boundaries.
        let lexeme = &sql[i..end];
        out.push(match bytes[i] {
            b'\'' | b'E' | b'e' => SqlSegment::SingleQuotedString(lexeme),
            b'"' => SqlSegment::QuotedIdent(lexeme),
            b'-' => SqlSegment::LineComment(lexeme),
            _ => SqlSegment::BlockComment(lexeme),
        });
        i = end;
        text_start = end;
    }

    if text_start < len {
        out.push(SqlSegment::Text(&sql[text_start..]));
    }
    out
}

/// Returns `true` when the byte before `pos` cannot be part of an identifier,
/// i.e. a new token may begin at `pos`.
fn token_may_start_at(bytes: &[u8], pos: usize) -> bool {
    match pos.checked_sub(1).map(|prev| bytes[prev]) {
        None => true,
        // Non-ASCII bytes are treated as identifier characters.
        Some(prev) => !(prev.is_ascii_alphanumeric() || prev == b'_' || !prev.is_ascii()),
    }
}

/// Returns the index just past the closing `quote` of a literal whose body
/// starts at `pos`, or `bytes.len()` when the literal is unterminated.
fn quoted_end(bytes: &[u8], mut pos: usize, quote: u8, backslash_escapes: bool) -> usize {
    let len = bytes.len();
    while pos < len {
        let byte = bytes[pos];
        if backslash_escapes && byte == b'\\' {
            // Skip the escaped byte too, but never run past the input: a
            // trailing backslash must not push the end index out of bounds.
            pos = (pos + 2).min(len);
        } else if byte == quote {
            pos += 1;
            if pos < len && bytes[pos] == quote {
                // Doubled quote: an embedded quote character, keep scanning.
                pos += 1;
            } else {
                return pos;
            }
        } else {
            pos += 1;
        }
    }
    len
}

/// Returns the index just past the newline ending a line comment whose body
/// starts at `from`, or `bytes.len()` when the comment runs to end of input.
fn line_comment_end(bytes: &[u8], from: usize) -> usize {
    bytes[from..]
        .iter()
        .position(|&byte| byte == b'\n')
        .map_or(bytes.len(), |offset| from + offset + 1)
}

/// Returns the index just past the `*/` that closes a (possibly nested) block
/// comment whose body starts at `pos`, or `bytes.len()` when it never closes.
fn block_comment_end(bytes: &[u8], mut pos: usize) -> usize {
    let mut depth: usize = 1;
    while depth > 0 && pos < bytes.len() {
        match &bytes[pos..] {
            [b'/', b'*', ..] => {
                depth += 1;
                pos += 2;
            }
            [b'*', b'/', ..] => {
                depth -= 1;
                pos += 2;
            }
            _ => pos += 1,
        }
    }
    pos
}
/// Returns the first word of `sql` that lies outside literals and comments.
///
/// Only [`SqlSegment::Text`] segments are inspected. Within a segment the
/// word is whatever follows the leading whitespace, up to the first ASCII
/// whitespace character, `(` or `;`. A segment that is blank, or whose first
/// non-blank character is already `(` or `;`, contributes nothing and the
/// search moves on to the next text segment. Returns `None` when no segment
/// yields a word.
pub fn first_sql_word(sql: &str) -> Option<&str> {
    segments(sql)
        .into_iter()
        .filter_map(|segment| match segment {
            SqlSegment::Text(text) => Some(text.trim_start()),
            _ => None,
        })
        .find_map(|candidate| {
            let word_len = candidate
                .find(|c: char| c.is_ascii_whitespace() || c == '(' || c == ';')
                .unwrap_or(candidate.len());
            // A zero-length word means the segment was blank or began with a
            // delimiter; skip it.
            (word_len > 0).then(|| &candidate[..word_len])
        })
}
/// Returns the second word of `sql` that lies outside literals and comments.
///
/// Words are delimited by ASCII whitespace, `(` or `;` and are collected from
/// [`SqlSegment::Text`] segments only, so the first and second word may be
/// separated by a comment or a literal. Scanning of a text segment stops as
/// soon as a `(` or `;` is reached; the search then resumes with the next
/// text segment. Returns `None` when fewer than two words are found.
pub fn second_sql_word(sql: &str) -> Option<&str> {
    fn ends_word(c: char) -> bool {
        c.is_ascii_whitespace() || c == '(' || c == ';'
    }

    let mut seen_first = false;
    for segment in segments(sql) {
        let SqlSegment::Text(text) = segment else {
            continue;
        };
        let mut cursor = text.trim_start();
        while !cursor.is_empty() {
            let word_len = cursor.find(ends_word).unwrap_or(cursor.len());
            if word_len == 0 {
                // The segment continues with `(` or `;`: give up on it.
                break;
            }
            let (word, rest) = cursor.split_at(word_len);
            if seen_first {
                return Some(word);
            }
            seen_first = true;
            cursor = rest.trim_start();
        }
    }
    None
}
/// Reports whether `op` occurs in `sql` outside of string literals, quoted
/// identifiers and comments.
///
/// Only [`SqlSegment::Text`] segments are searched, using a plain substring
/// match with no token-boundary check.
pub fn has_operator_outside_literals(sql: &str, op: &str) -> bool {
    segments(sql)
        .into_iter()
        .any(|segment| matches!(segment, SqlSegment::Text(text) if text.contains(op)))
}
/// Returns the byte offsets within `sql` of every occurrence of `op` that
/// lies outside of string literals, quoted identifiers and comments.
///
/// Matches are non-overlapping and reported in ascending order. An empty `op`
/// matches at every character boundary of every text segment, as
/// [`str::match_indices`] does; the previous hand-rolled search never
/// advanced on an empty pattern and looped forever.
pub fn find_operator_positions(sql: &str, op: &str) -> Vec<usize> {
    let mut positions = Vec::new();
    for seg in segments(sql) {
        if let SqlSegment::Text(text) = seg {
            // `text` is a sub-slice of `sql`, so the distance between the two
            // start pointers is the segment's byte offset in `sql`.
            let base = text.as_ptr() as usize - sql.as_ptr() as usize;
            positions.extend(text.match_indices(op).map(|(rel, _)| base + rel));
        }
    }
    positions
}
/// Reports whether `sql` contains an opening brace (`{`) outside of string
/// literals, quoted identifiers and comments.
///
/// Thin wrapper around [`has_operator_outside_literals`].
pub fn has_brace_outside_literals(sql: &str) -> bool {
    const OPEN_BRACE: &str = "{";
    has_operator_outside_literals(sql, OPEN_BRACE)
}
/// Returns the byte offset in `sql` of the first case-insensitive, whole-word
/// occurrence of `kw` outside of string literals, quoted identifiers and
/// comments.
///
/// "Whole word" means the match is neither preceded nor followed by an
/// alphanumeric character or `_` within its text segment. `kw` may contain
/// spaces (e.g. `"FOR SYSTEM_TIME"`); they are matched literally.
///
/// Matching is done directly on the original text, so the returned offset is
/// always a valid character boundary of `sql`. Searching an uppercased copy
/// instead would break on characters whose uppercase form has a different
/// byte length (such as `ı` or `ﬁ`).
pub fn keyword_position_outside_literals(sql: &str, kw: &str) -> Option<usize> {
    fn is_word_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let kw_upper = kw.to_uppercase();
    for seg in segments(sql) {
        let SqlSegment::Text(text) = seg else {
            continue;
        };
        // `text` is a sub-slice of `sql`; the pointer distance is its offset.
        let base = text.as_ptr() as usize - sql.as_ptr() as usize;
        for (start, _) in text.char_indices() {
            let Some(matched_len) = uppercase_prefix_len(&text[start..], &kw_upper) else {
                continue;
            };
            let glued_before = text[..start].chars().next_back().is_some_and(is_word_char);
            let glued_after = text[start + matched_len..]
                .chars()
                .next()
                .is_some_and(is_word_char);
            if !glued_before && !glued_after {
                return Some(base + start);
            }
        }
    }
    None
}

/// Returns the byte length of the shortest prefix of `hay` whose uppercase
/// form equals `needle_upper`, or `None` when `hay` does not start with it.
fn uppercase_prefix_len(hay: &str, needle_upper: &str) -> Option<usize> {
    let mut wanted = needle_upper.chars().peekable();
    if wanted.peek().is_none() {
        return Some(0);
    }
    for (idx, ch) in hay.char_indices() {
        // One source character may uppercase to several characters; all of
        // them have to line up with the needle.
        for upper in ch.to_uppercase() {
            if wanted.next() != Some(upper) {
                return None;
            }
        }
        if wanted.peek().is_none() {
            return Some(idx + ch.len_utf8());
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- segments -------------------------------------------------------

    #[test]
    fn plain_text_is_single_segment() {
        let segs = segments("SELECT 1");
        assert_eq!(segs, vec![SqlSegment::Text("SELECT 1")]);
    }

    #[test]
    fn single_quoted_string_opaque() {
        let segs = segments("SELECT '<->'");
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::SingleQuotedString("'<->'"),
            ]
        );
    }

    #[test]
    fn quoted_ident_opaque() {
        let segs = segments(r#"SELECT "col_<->""#);
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::QuotedIdent(r#""col_<->""#),
            ]
        );
    }

    #[test]
    fn line_comment_opaque() {
        let segs = segments("SELECT col -- has <-> in comment\nFROM t");
        assert!(
            segs.iter()
                .any(|s| matches!(s, SqlSegment::LineComment(c) if c.contains("<->")))
        );
        assert!(
            segs.iter()
                .any(|s| matches!(s, SqlSegment::Text(t) if t.contains("FROM")))
        );
        for s in &segs {
            if let SqlSegment::Text(t) = s {
                assert!(!t.contains("<->"), "unexpected <-> in Text: {t}");
            }
        }
    }

    #[test]
    fn block_comment_opaque() {
        let segs = segments("SELECT /* <-> */ x");
        assert!(
            segs.iter()
                .any(|s| matches!(s, SqlSegment::BlockComment(c) if c.contains("<->")))
        );
        for s in &segs {
            if let SqlSegment::Text(t) = s {
                assert!(!t.contains("<->"), "unexpected <-> in Text: {t}");
            }
        }
    }

    #[test]
    fn nested_block_comment() {
        let segs = segments("SELECT /* /* nested */ <-> */ x");
        assert!(
            segs.iter()
                .any(|s| matches!(s, SqlSegment::BlockComment(c) if c.contains("<->")))
        );
        for s in &segs {
            if let SqlSegment::Text(t) = s {
                assert!(!t.contains("<->"), "nested <-> leaked into Text: {t}");
            }
        }
    }

    #[test]
    fn doubled_quote_escape_in_string() {
        let segs = segments("SELECT 'it''s'");
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::SingleQuotedString("'it''s'"),
            ]
        );
    }

    #[test]
    fn escape_string_prefix() {
        let segs = segments("SELECT E'foo\\nbar'");
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::SingleQuotedString("E'foo\\nbar'"),
            ]
        );
    }

    #[test]
    fn escaped_quote_in_escape_string() {
        // The backslash-escaped quote must not terminate the literal.
        let segs = segments("SELECT E'it\\'s' <-> x");
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::SingleQuotedString("E'it\\'s'"),
                SqlSegment::Text(" <-> x"),
            ]
        );
    }

    // ---- first_sql_word / second_sql_word --------------------------------

    #[test]
    fn first_word_simple() {
        assert_eq!(first_sql_word("SELECT 1"), Some("SELECT"));
    }

    #[test]
    fn first_word_skips_line_comment() {
        assert_eq!(first_sql_word("-- INSERT INTO t\nSELECT 1"), Some("SELECT"));
    }

    #[test]
    fn first_word_skips_block_comment() {
        assert_eq!(
            first_sql_word("/* hint */ INSERT INTO t VALUES (1)"),
            Some("INSERT")
        );
    }

    #[test]
    fn first_word_upsert_with_comment() {
        assert_eq!(
            first_sql_word("/* hint */ UPSERT INTO t { name: 'a' }"),
            Some("UPSERT")
        );
    }

    #[test]
    fn first_word_empty() {
        assert_eq!(first_sql_word(" "), None);
    }

    #[test]
    fn second_word_simple() {
        assert_eq!(second_sql_word("INSERT INTO t"), Some("INTO"));
    }

    #[test]
    fn second_word_spans_comments() {
        assert_eq!(
            second_sql_word("/* hint */ CREATE /* x */ TABLE t"),
            Some("TABLE")
        );
    }

    #[test]
    fn second_word_missing() {
        assert_eq!(second_sql_word("SELECT"), None);
    }

    // ---- operator helpers -------------------------------------------------

    #[test]
    fn operator_in_plain_text() {
        assert!(has_operator_outside_literals("a <-> b", "<->"));
    }

    #[test]
    fn operator_in_string_not_detected() {
        assert!(!has_operator_outside_literals("SELECT '<->'", "<->"));
    }

    #[test]
    fn operator_in_line_comment_not_detected() {
        assert!(!has_operator_outside_literals(
            "SELECT col -- has <-> in comment\nFROM t",
            "<->"
        ));
    }

    #[test]
    fn operator_in_block_comment_not_detected() {
        assert!(!has_operator_outside_literals("SELECT /* <-> */ x", "<->"));
    }

    #[test]
    fn operator_in_quoted_ident_not_detected() {
        assert!(!has_operator_outside_literals(r#"SELECT "col_<->""#, "<->"));
    }

    #[test]
    fn operator_positions_skip_literals() {
        assert_eq!(
            find_operator_positions("a <-> b '<->' c <-> d", "<->"),
            vec![2, 16]
        );
    }

    #[test]
    fn operator_positions_empty_when_only_in_literal() {
        assert!(find_operator_positions("SELECT '<->'", "<->").is_empty());
    }

    #[test]
    fn brace_in_plain_text() {
        assert!(has_brace_outside_literals("func({ foo })"));
    }

    #[test]
    fn brace_in_string_not_detected() {
        assert!(!has_brace_outside_literals("func('{ foo }')"));
    }

    #[test]
    fn brace_concat_expr_not_detected() {
        assert!(!has_brace_outside_literals("'{' || x || '}'"));
    }

    // ---- keyword_position_outside_literals --------------------------------

    #[test]
    fn keyword_found_in_plain_text() {
        let sql = "SELECT * FROM t FOR SYSTEM_TIME AS OF 100";
        assert!(keyword_position_outside_literals(sql, "FOR SYSTEM_TIME").is_some());
    }

    #[test]
    fn keyword_in_string_not_found() {
        let sql = "SELECT * FROM t WHERE name = 'FOR SYSTEM_TIME'";
        assert!(keyword_position_outside_literals(sql, "FOR SYSTEM_TIME").is_none());
    }

    #[test]
    fn keyword_position_correct() {
        let sql = "SELECT x FOR SYSTEM_TIME AS OF 100";
        let pos = keyword_position_outside_literals(sql, "FOR SYSTEM_TIME").unwrap();
        let found = &sql[pos..pos + "FOR SYSTEM_TIME".len()];
        assert_eq!(found.to_uppercase(), "FOR SYSTEM_TIME");
    }

    #[test]
    fn keyword_match_is_case_insensitive() {
        let sql = "select x for system_time as of 1";
        assert_eq!(
            keyword_position_outside_literals(sql, "FOR SYSTEM_TIME"),
            Some(9)
        );
    }

    #[test]
    fn keyword_requires_word_boundaries() {
        assert_eq!(
            keyword_position_outside_literals("SELECT format FROM t", "FOR"),
            None
        );
    }
}