nodedb_sql/parser/preprocess/
/// One contiguous region of a SQL string, classified by lexical kind.
///
/// Every variant borrows the raw source text of its region, delimiters
/// included, so a segment is just a tagged `&str` and is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlSegment<'a> {
    /// Plain SQL text that is outside any string, quoted identifier or comment.
    Text(&'a str),
    /// A single-quoted string literal, including its quotes and any `E`/`e`
    /// escape-string prefix.
    SingleQuotedString(&'a str),
    /// A double-quoted identifier, including its quotes.
    QuotedIdent(&'a str),
    /// A `--` line comment, including the terminating newline when present.
    LineComment(&'a str),
    /// A `/* ... */` block comment (possibly nested), including its delimiters.
    BlockComment(&'a str),
}
32
33pub fn segments(sql: &str) -> Vec<SqlSegment<'_>> {
38 let mut out = Vec::new();
39 let bytes = sql.as_bytes();
40 let len = bytes.len();
41 let mut i = 0;
42 let mut text_start = 0;
43
44 macro_rules! flush_text {
45 () => {
46 if text_start < i {
47 out.push(SqlSegment::Text(&sql[text_start..i]));
48 }
49 };
50 }
51
52 while i < len {
53 let is_escape_prefix =
56 (bytes[i] == b'E' || bytes[i] == b'e') && i + 1 < len && bytes[i + 1] == b'\'';
57
58 if bytes[i] == b'\'' || is_escape_prefix {
59 flush_text!();
60 let start = i;
61 if is_escape_prefix {
62 i += 1; }
64 i += 1; let escape = is_escape_prefix;
66 while i < len {
67 match bytes[i] {
68 b'\\' if escape => {
69 i += 2;
71 }
72 b'\'' => {
73 i += 1;
74 if i < len && bytes[i] == b'\'' {
76 i += 1;
77 } else {
78 break;
79 }
80 }
81 _ => i += 1,
82 }
83 }
84 out.push(SqlSegment::SingleQuotedString(&sql[start..i]));
85 text_start = i;
86 continue;
87 }
88
89 if bytes[i] == b'"' {
91 flush_text!();
92 let start = i;
93 i += 1; while i < len {
95 match bytes[i] {
96 b'"' => {
97 i += 1;
98 if i < len && bytes[i] == b'"' {
100 i += 1;
101 } else {
102 break;
103 }
104 }
105 _ => i += 1,
106 }
107 }
108 out.push(SqlSegment::QuotedIdent(&sql[start..i]));
109 text_start = i;
110 continue;
111 }
112
113 if bytes[i] == b'-' && i + 1 < len && bytes[i + 1] == b'-' {
115 flush_text!();
116 let start = i;
117 while i < len && bytes[i] != b'\n' {
118 i += 1;
119 }
120 if i < len && bytes[i] == b'\n' {
122 i += 1;
123 }
124 out.push(SqlSegment::LineComment(&sql[start..i]));
125 text_start = i;
126 continue;
127 }
128
129 if bytes[i] == b'/' && i + 1 < len && bytes[i + 1] == b'*' {
131 flush_text!();
132 let start = i;
133 i += 2; let mut depth: usize = 1;
135 while i < len && depth > 0 {
136 if bytes[i] == b'/' && i + 1 < len && bytes[i + 1] == b'*' {
137 depth += 1;
138 i += 2;
139 } else if bytes[i] == b'*' && i + 1 < len && bytes[i + 1] == b'/' {
140 depth -= 1;
141 i += 2;
142 } else {
143 i += 1;
144 }
145 }
146 out.push(SqlSegment::BlockComment(&sql[start..i]));
147 text_start = i;
148 continue;
149 }
150
151 i += 1;
153 }
154
155 if text_start < len {
157 out.push(SqlSegment::Text(&sql[text_start..]));
158 }
159
160 out
161}
162
163pub fn first_sql_word(sql: &str) -> Option<&str> {
169 for seg in segments(sql) {
170 if let SqlSegment::Text(t) = seg {
171 let trimmed = t.trim_start();
172 if trimmed.is_empty() {
173 continue;
174 }
175 let end = trimmed
176 .find(|c: char| c.is_ascii_whitespace() || c == '(' || c == ';')
177 .unwrap_or(trimmed.len());
178 if end > 0 {
179 return Some(&trimmed[..end]);
180 }
181 }
182 }
183 None
184}
185
186pub fn second_sql_word(sql: &str) -> Option<&str> {
192 let mut found_first = false;
193 for seg in segments(sql) {
194 if let SqlSegment::Text(t) = seg {
195 let mut remaining = t;
196 loop {
197 let trimmed = remaining.trim_start();
198 if trimmed.is_empty() {
199 break;
200 }
201 let end = trimmed
202 .find(|c: char| c.is_ascii_whitespace() || c == '(' || c == ';')
203 .unwrap_or(trimmed.len());
204 if end == 0 {
205 break;
206 }
207 if !found_first {
208 found_first = true;
209 remaining = &trimmed[end..];
211 } else {
212 return Some(&trimmed[..end]);
213 }
214 }
215 }
216 }
217 None
218}
219
220pub fn has_operator_outside_literals(sql: &str, op: &str) -> bool {
224 for seg in segments(sql) {
225 if let SqlSegment::Text(t) = seg
226 && t.contains(op)
227 {
228 return true;
229 }
230 }
231 false
232}
233
234pub fn find_operator_positions(sql: &str, op: &str) -> Vec<usize> {
237 let mut positions = Vec::new();
238 for seg in segments(sql) {
239 if let SqlSegment::Text(t) = seg {
240 let base = t.as_ptr() as usize - sql.as_ptr() as usize;
242 let mut search_from = 0;
243 while let Some(rel) = t[search_from..].find(op) {
244 let abs = base + search_from + rel;
245 positions.push(abs);
246 search_from += rel + op.len();
247 }
248 }
249 }
250 positions
251}
252
253pub fn has_brace_outside_literals(sql: &str) -> bool {
255 has_operator_outside_literals(sql, "{")
256}
257
258pub fn keyword_position_outside_literals(sql: &str, kw: &str) -> Option<usize> {
266 let kw_upper = kw.to_uppercase();
267 for seg in segments(sql) {
268 if let SqlSegment::Text(t) = seg {
269 let base = t.as_ptr() as usize - sql.as_ptr() as usize;
270 let upper = t.to_uppercase();
271 let mut search_from = 0;
272 while search_from < upper.len() {
273 let Some(rel) = upper[search_from..].find(&kw_upper) else {
274 break;
275 };
276 let abs_rel = search_from + rel;
277 let before_ok = abs_rel == 0
279 || !t[..abs_rel]
280 .chars()
281 .next_back()
282 .map(|c| c.is_alphanumeric() || c == '_')
283 .unwrap_or(false);
284 let after_start = abs_rel + kw.len();
285 let after_ok = after_start >= t.len()
286 || !t[after_start..]
287 .chars()
288 .next()
289 .map(|c| c.is_alphanumeric() || c == '_')
290 .unwrap_or(false);
291 if before_ok && after_ok {
292 return Some(base + abs_rel);
293 }
294 search_from = abs_rel + 1;
295 }
296 }
297 }
298 None
299}
300
/// Unit tests for the segment lexer and the helpers built on top of it.
#[cfg(test)]
mod tests {
    use super::*;

    // ── segments ────────────────────────────────────────────────────────────

    /// Input with no quotes or comments comes back as one `Text` segment.
    #[test]
    fn plain_text_is_single_segment() {
        let segs = segments("SELECT 1");
        assert_eq!(segs, vec![SqlSegment::Text("SELECT 1")]);
    }

    /// Operator-like characters inside a string literal stay in the literal.
    #[test]
    fn single_quoted_string_opaque() {
        let segs = segments("SELECT '<->'");
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::SingleQuotedString("'<->'"),
            ]
        );
    }

    /// Operator-like characters inside a quoted identifier stay in it.
    #[test]
    fn quoted_ident_opaque() {
        let segs = segments(r#"SELECT "col_<->""#);
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::QuotedIdent(r#""col_<->""#),
            ]
        );
    }

    /// A `--` comment swallows its content; text after the newline is `Text`.
    #[test]
    fn line_comment_opaque() {
        let segs = segments("SELECT col -- has <-> in comment\nFROM t");
        assert!(
            segs.iter()
                .any(|s| matches!(s, SqlSegment::LineComment(c) if c.contains("<->")))
        );
        assert!(
            segs.iter()
                .any(|s| matches!(s, SqlSegment::Text(t) if t.contains("FROM")))
        );
        // The operator must not leak into any plain-text segment.
        for s in &segs {
            if let SqlSegment::Text(t) = s {
                assert!(!t.contains("<->"), "unexpected <-> in Text: {t}");
            }
        }
    }

    /// A `/* */` comment swallows its content.
    #[test]
    fn block_comment_opaque() {
        let segs = segments("SELECT /* <-> */ x");
        assert!(
            segs.iter()
                .any(|s| matches!(s, SqlSegment::BlockComment(c) if c.contains("<->")))
        );
        for s in &segs {
            if let SqlSegment::Text(t) = s {
                assert!(!t.contains("<->"), "unexpected <-> in Text: {t}");
            }
        }
    }

    /// The inner `*/` of a nested comment does not end the outer comment.
    #[test]
    fn nested_block_comment() {
        let segs = segments("SELECT /* /* nested */ <-> */ x");
        assert!(
            segs.iter()
                .any(|s| matches!(s, SqlSegment::BlockComment(c) if c.contains("<->")))
        );
        for s in &segs {
            if let SqlSegment::Text(t) = s {
                assert!(!t.contains("<->"), "nested <-> leaked into Text: {t}");
            }
        }
    }

    /// A doubled `''` is an escaped quote, not the end of the string.
    #[test]
    fn doubled_quote_escape_in_string() {
        let segs = segments("SELECT 'it''s'");
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::SingleQuotedString("'it''s'"),
            ]
        );
    }

    /// The `E` prefix belongs to the string segment, backslash escape included.
    #[test]
    fn escape_string_prefix() {
        let segs = segments("SELECT E'foo\\nbar'");
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::SingleQuotedString("E'foo\\nbar'"),
            ]
        );
    }

    // ── first_sql_word ──────────────────────────────────────────────────────

    #[test]
    fn first_word_simple() {
        assert_eq!(first_sql_word("SELECT 1"), Some("SELECT"));
    }

    /// Words inside a leading line comment are not considered.
    #[test]
    fn first_word_skips_line_comment() {
        assert_eq!(first_sql_word("-- INSERT INTO t\nSELECT 1"), Some("SELECT"));
    }

    /// Words inside a leading block comment are not considered.
    #[test]
    fn first_word_skips_block_comment() {
        assert_eq!(
            first_sql_word("/* hint */ INSERT INTO t VALUES (1)"),
            Some("INSERT")
        );
    }

    #[test]
    fn first_word_upsert_with_comment() {
        assert_eq!(
            first_sql_word("/* hint */ UPSERT INTO t { name: 'a' }"),
            Some("UPSERT")
        );
    }

    /// Whitespace-only input has no first word.
    #[test]
    fn first_word_empty() {
        assert_eq!(first_sql_word(" "), None);
    }

    // ── has_operator_outside_literals ───────────────────────────────────────

    #[test]
    fn operator_in_plain_text() {
        assert!(has_operator_outside_literals("a <-> b", "<->"));
    }

    #[test]
    fn operator_in_string_not_detected() {
        assert!(!has_operator_outside_literals("SELECT '<->'", "<->"));
    }

    #[test]
    fn operator_in_line_comment_not_detected() {
        assert!(!has_operator_outside_literals(
            "SELECT col -- has <-> in comment\nFROM t",
            "<->"
        ));
    }

    #[test]
    fn operator_in_block_comment_not_detected() {
        assert!(!has_operator_outside_literals("SELECT /* <-> */ x", "<->"));
    }

    #[test]
    fn operator_in_quoted_ident_not_detected() {
        assert!(!has_operator_outside_literals(r#"SELECT "col_<->""#, "<->"));
    }

    // ── has_brace_outside_literals ──────────────────────────────────────────

    #[test]
    fn brace_in_plain_text() {
        assert!(has_brace_outside_literals("func({ foo })"));
    }

    #[test]
    fn brace_in_string_not_detected() {
        assert!(!has_brace_outside_literals("func('{ foo }')"));
    }

    /// Braces that only appear inside string operands are not reported.
    #[test]
    fn brace_concat_expr_not_detected() {
        assert!(!has_brace_outside_literals("'{' || x || '}'"));
    }

    // ── keyword_position_outside_literals ───────────────────────────────────

    #[test]
    fn keyword_found_in_plain_text() {
        let sql = "SELECT * FROM t FOR SYSTEM_TIME AS OF 100";
        assert!(keyword_position_outside_literals(sql, "FOR SYSTEM_TIME").is_some());
    }

    #[test]
    fn keyword_in_string_not_found() {
        let sql = "SELECT * FROM t WHERE name = 'FOR SYSTEM_TIME'";
        assert!(keyword_position_outside_literals(sql, "FOR SYSTEM_TIME").is_none());
    }

    /// The returned offset indexes `sql` directly at the start of the keyword.
    #[test]
    fn keyword_position_correct() {
        let sql = "SELECT x FOR SYSTEM_TIME AS OF 100";
        let pos = keyword_position_outside_literals(sql, "FOR SYSTEM_TIME").unwrap();
        let found = &sql[pos..pos + "FOR SYSTEM_TIME".len()];
        assert_eq!(found.to_uppercase(), "FOR SYSTEM_TIME");
    }
}