Skip to main content

nodedb_sql/parser/preprocess/
lex.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Shared SQL lexer for preprocess passes.
4//!
5//! The lexer classifies a SQL string into non-overlapping segments, allowing
6//! preprocess passes to scan only plain `Text` segments — string literals,
7//! quoted identifiers, and comments are passed through opaquely and never
8//! matched against patterns.
9//!
10//! Supported SQL dialect features:
11//! - Single-quoted strings (`'...'`) with `''` escape and `E'...'` with
12//!   backslash escapes.
13//! - Double-quoted identifiers (`"..."`).
14//! - Line comments (`-- ...` to end-of-line).
15//! - Block comments (`/* ... */`, nestable per PostgreSQL).
16
/// A classified segment of a SQL string.
///
/// Every variant borrows its text directly from the input handed to
/// [`segments`]; nothing is copied or rewritten. The type is `Copy` because
/// it is just a tagged string slice.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlSegment<'a> {
    /// Unquoted SQL text.
    Text(&'a str),
    /// A single-quoted string literal, including the surrounding quotes
    /// (and the `E`/`e` prefix of an escape string, if present).
    SingleQuotedString(&'a str),
    /// A double-quoted identifier, including the surrounding quotes.
    QuotedIdent(&'a str),
    /// A line comment starting with `--`, including the `--` prefix and
    /// trailing newline (if present).
    LineComment(&'a str),
    /// A block comment delimited by `/* ... */` (nestable).
    BlockComment(&'a str),
}

/// Return `true` if `b` may be part of an unquoted identifier or keyword.
///
/// Bytes `>= 0x80` belong to multi-byte UTF-8 characters and are treated as
/// identifier bytes.
fn is_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80
}

/// Scan a quoted run whose opening quote has already been consumed.
///
/// `i` is the index of the first byte after the opening `quote`. A doubled
/// quote is an escaped quote; when `backslash_escapes` is set, a backslash
/// also escapes the byte that follows it. Returns the index one past the
/// closing quote, or `bytes.len()` if the run is unterminated.
fn scan_quoted(bytes: &[u8], mut i: usize, quote: u8, backslash_escapes: bool) -> usize {
    let len = bytes.len();
    while i < len {
        let b = bytes[i];
        if backslash_escapes && b == b'\\' {
            // Skip the backslash and the escaped byte. Clamp so that a
            // trailing backslash in an unterminated string cannot move the
            // cursor past the end of the input.
            i = (i + 2).min(len);
        } else if b == quote {
            i += 1;
            if i < len && bytes[i] == quote {
                // Doubled quote: an escaped quote, the run continues.
                i += 1;
            } else {
                return i;
            }
        } else {
            i += 1;
        }
    }
    len
}

/// Return the index one past the newline that ends the line comment starting
/// at `start`, or `bytes.len()` if the comment runs to the end of the input.
fn scan_line_comment(bytes: &[u8], start: usize) -> usize {
    bytes[start..]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(bytes.len(), |newline| start + newline + 1)
}

/// Scan a (nestable) block comment whose opening `/*` has already been
/// consumed.
///
/// `i` is the index of the first byte after the opening `/*`. Returns the
/// index one past the matching `*/`, or `bytes.len()` if the comment is
/// unterminated.
fn scan_block_comment(bytes: &[u8], mut i: usize) -> usize {
    let len = bytes.len();
    let mut depth: usize = 1;
    while i < len && depth > 0 {
        match (bytes[i], bytes.get(i + 1).copied()) {
            (b'/', Some(b'*')) => {
                depth += 1;
                i += 2;
            }
            (b'*', Some(b'/')) => {
                depth -= 1;
                i += 2;
            }
            _ => i += 1,
        }
    }
    i
}

/// Segment a SQL string into classified [`SqlSegment`]s.
///
/// The entire input is covered exactly once (no bytes are skipped). Adjacent
/// `Text` bytes are collected into a single segment.
///
/// Unterminated strings, quoted identifiers, and block comments extend to the
/// end of the input; this function never panics on malformed SQL.
///
/// An `E`/`e` immediately followed by `'` starts an escape string only when
/// it is not the tail of an identifier or keyword: in `DATE'2020-01-01'` or
/// `LIKE'a\'` the letter stays in the `Text` segment and the literal is a
/// plain single-quoted string without backslash escapes.
pub fn segments(sql: &str) -> Vec<SqlSegment<'_>> {
    let mut out = Vec::new();
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    let mut text_start = 0;

    while i < len {
        let next = bytes.get(i + 1).copied();

        // Every non-text segment starts on an ASCII byte and ends either
        // right after an ASCII byte or at the end of the input, so all slice
        // bounds below fall on UTF-8 character boundaries.
        let (end, segment) = match (bytes[i], next) {
            // Plain single-quoted string.
            (b'\'', _) => {
                let end = scan_quoted(bytes, i + 1, b'\'', false);
                (end, SqlSegment::SingleQuotedString(&sql[i..end]))
            }
            // Escape string `E'...'`; the prefix must start a token.
            (b'E' | b'e', Some(b'\'')) if i == 0 || !is_ident_byte(bytes[i - 1]) => {
                let end = scan_quoted(bytes, i + 2, b'\'', true);
                (end, SqlSegment::SingleQuotedString(&sql[i..end]))
            }
            // Double-quoted identifier.
            (b'"', _) => {
                let end = scan_quoted(bytes, i + 1, b'"', false);
                (end, SqlSegment::QuotedIdent(&sql[i..end]))
            }
            // Line comment `-- ...`.
            (b'-', Some(b'-')) => {
                let end = scan_line_comment(bytes, i);
                (end, SqlSegment::LineComment(&sql[i..end]))
            }
            // Block comment `/* ... */`.
            (b'/', Some(b'*')) => {
                let end = scan_block_comment(bytes, i + 2);
                (end, SqlSegment::BlockComment(&sql[i..end]))
            }
            // Plain text byte: keep extending the pending text run.
            _ => {
                i += 1;
                continue;
            }
        };

        // Emit the text accumulated before this segment, then the segment.
        if text_start < i {
            out.push(SqlSegment::Text(&sql[text_start..i]));
        }
        out.push(segment);
        i = end;
        text_start = end;
    }

    // Flush any trailing text.
    if text_start < len {
        out.push(SqlSegment::Text(&sql[text_start..]));
    }

    out
}
162
163/// Return the first SQL keyword/word in `sql`, skipping leading whitespace,
164/// line comments, and block comments. Returns `None` if the input is empty
165/// or contains only whitespace/comments.
166///
167/// The returned slice is a sub-slice of `sql` in its original case.
168pub fn first_sql_word(sql: &str) -> Option<&str> {
169    for seg in segments(sql) {
170        if let SqlSegment::Text(t) = seg {
171            let trimmed = t.trim_start();
172            if trimmed.is_empty() {
173                continue;
174            }
175            let end = trimmed
176                .find(|c: char| c.is_ascii_whitespace() || c == '(' || c == ';')
177                .unwrap_or(trimmed.len());
178            if end > 0 {
179                return Some(&trimmed[..end]);
180            }
181        }
182    }
183    None
184}
185
186/// Return the second SQL keyword/word in `sql`, skipping leading whitespace,
187/// line comments, and block comments, then skipping the first word. Returns
188/// `None` if there is no second word.
189///
190/// The returned slice is a sub-slice of `sql` in its original case.
191pub fn second_sql_word(sql: &str) -> Option<&str> {
192    let mut found_first = false;
193    for seg in segments(sql) {
194        if let SqlSegment::Text(t) = seg {
195            let mut remaining = t;
196            loop {
197                let trimmed = remaining.trim_start();
198                if trimmed.is_empty() {
199                    break;
200                }
201                let end = trimmed
202                    .find(|c: char| c.is_ascii_whitespace() || c == '(' || c == ';')
203                    .unwrap_or(trimmed.len());
204                if end == 0 {
205                    break;
206                }
207                if !found_first {
208                    found_first = true;
209                    // advance past this word
210                    remaining = &trimmed[end..];
211                } else {
212                    return Some(&trimmed[..end]);
213                }
214            }
215        }
216    }
217    None
218}
219
220/// Return `true` if `op` appears verbatim inside any `Text` segment of `sql`.
221/// The comparison is byte-exact (case-sensitive). Occurrences inside string
222/// literals, quoted identifiers, or comments are ignored.
223pub fn has_operator_outside_literals(sql: &str, op: &str) -> bool {
224    for seg in segments(sql) {
225        if let SqlSegment::Text(t) = seg
226            && t.contains(op)
227        {
228            return true;
229        }
230    }
231    false
232}
233
234/// Return the byte positions (relative to the start of `sql`) of every
235/// occurrence of `op` that falls inside a `Text` segment.
236pub fn find_operator_positions(sql: &str, op: &str) -> Vec<usize> {
237    let mut positions = Vec::new();
238    for seg in segments(sql) {
239        if let SqlSegment::Text(t) = seg {
240            // Safety: `t` is a sub-slice of `sql`; pointer arithmetic is valid.
241            let base = t.as_ptr() as usize - sql.as_ptr() as usize;
242            let mut search_from = 0;
243            while let Some(rel) = t[search_from..].find(op) {
244                let abs = base + search_from + rel;
245                positions.push(abs);
246                search_from += rel + op.len();
247            }
248        }
249    }
250    positions
251}
252
253/// Return `true` if `{` appears inside any `Text` segment of `sql`.
254pub fn has_brace_outside_literals(sql: &str) -> bool {
255    has_operator_outside_literals(sql, "{")
256}
257
258/// Return the byte position (relative to `sql`) of the first case-insensitive
259/// occurrence of the keyword `kw` that falls inside a `Text` segment. Returns
260/// `None` if not found.
261///
262/// The match is word-boundary-aware: the character immediately before the
263/// match (if any) must not be alphanumeric or `_`, and the character
264/// immediately after the match (if any) must not be alphanumeric or `_`.
265pub fn keyword_position_outside_literals(sql: &str, kw: &str) -> Option<usize> {
266    let kw_upper = kw.to_uppercase();
267    for seg in segments(sql) {
268        if let SqlSegment::Text(t) = seg {
269            let base = t.as_ptr() as usize - sql.as_ptr() as usize;
270            let upper = t.to_uppercase();
271            let mut search_from = 0;
272            while search_from < upper.len() {
273                let Some(rel) = upper[search_from..].find(&kw_upper) else {
274                    break;
275                };
276                let abs_rel = search_from + rel;
277                // word-boundary check
278                let before_ok = abs_rel == 0
279                    || !t[..abs_rel]
280                        .chars()
281                        .next_back()
282                        .map(|c| c.is_alphanumeric() || c == '_')
283                        .unwrap_or(false);
284                let after_start = abs_rel + kw.len();
285                let after_ok = after_start >= t.len()
286                    || !t[after_start..]
287                        .chars()
288                        .next()
289                        .map(|c| c.is_alphanumeric() || c == '_')
290                        .unwrap_or(false);
291                if before_ok && after_ok {
292                    return Some(base + abs_rel);
293                }
294                search_from = abs_rel + 1;
295            }
296        }
297    }
298    None
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `needle` does not occur in any `Text` segment of `segs`.
    fn assert_absent_from_text(segs: &[SqlSegment<'_>], needle: &str) {
        for s in segs {
            if let SqlSegment::Text(t) = s {
                assert!(!t.contains(needle), "unexpected {needle} in Text: {t}");
            }
        }
    }

    // ── segments ────────────────────────────────────────────────────────────

    #[test]
    fn plain_text_is_single_segment() {
        let segs = segments("SELECT 1");
        assert_eq!(segs, vec![SqlSegment::Text("SELECT 1")]);
    }

    #[test]
    fn empty_input_has_no_segments() {
        assert!(segments("").is_empty());
    }

    #[test]
    fn single_quoted_string_opaque() {
        let segs = segments("SELECT '<->'");
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::SingleQuotedString("'<->'"),
            ]
        );
    }

    #[test]
    fn quoted_ident_opaque() {
        let segs = segments(r#"SELECT "col_<->""#);
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::QuotedIdent(r#""col_<->""#),
            ]
        );
    }

    #[test]
    fn line_comment_opaque() {
        let segs = segments("SELECT col -- has <-> in comment\nFROM t");
        // The segment after the newline belongs to a new Text segment.
        assert!(
            segs.iter()
                .any(|s| matches!(s, SqlSegment::LineComment(c) if c.contains("<->")))
        );
        assert!(
            segs.iter()
                .any(|s| matches!(s, SqlSegment::Text(t) if t.contains("FROM")))
        );
        // `<->` must not appear in any Text segment.
        assert_absent_from_text(&segs, "<->");
    }

    #[test]
    fn block_comment_opaque() {
        let segs = segments("SELECT /* <-> */ x");
        assert!(
            segs.iter()
                .any(|s| matches!(s, SqlSegment::BlockComment(c) if c.contains("<->")))
        );
        assert_absent_from_text(&segs, "<->");
    }

    #[test]
    fn nested_block_comment() {
        let segs = segments("SELECT /* /* nested */ <-> */ x");
        // The outer `/* ... */` includes everything including `<->`.
        assert!(
            segs.iter()
                .any(|s| matches!(s, SqlSegment::BlockComment(c) if c.contains("<->")))
        );
        assert_absent_from_text(&segs, "<->");
    }

    #[test]
    fn doubled_quote_escape_in_string() {
        let segs = segments("SELECT 'it''s'");
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::SingleQuotedString("'it''s'"),
            ]
        );
    }

    #[test]
    fn escape_string_prefix() {
        let segs = segments("SELECT E'foo\\nbar'");
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::SingleQuotedString("E'foo\\nbar'"),
            ]
        );
    }

    #[test]
    fn unterminated_string_runs_to_end_of_input() {
        let segs = segments("SELECT 'abc");
        assert_eq!(
            segs,
            vec![
                SqlSegment::Text("SELECT "),
                SqlSegment::SingleQuotedString("'abc"),
            ]
        );
    }

    #[test]
    fn segments_cover_input_exactly_once() {
        let sql = "SELECT 'a''b', \"c\" -- x\n/* y /* z */ */ FROM t";
        let rebuilt: String = segments(sql)
            .iter()
            .map(|s| match s {
                SqlSegment::Text(t)
                | SqlSegment::SingleQuotedString(t)
                | SqlSegment::QuotedIdent(t)
                | SqlSegment::LineComment(t)
                | SqlSegment::BlockComment(t) => *t,
            })
            .collect();
        assert_eq!(rebuilt, sql);
    }

    // ── first_sql_word ───────────────────────────────────────────────────────

    #[test]
    fn first_word_simple() {
        assert_eq!(first_sql_word("SELECT 1"), Some("SELECT"));
    }

    #[test]
    fn first_word_skips_line_comment() {
        assert_eq!(first_sql_word("-- INSERT INTO t\nSELECT 1"), Some("SELECT"));
    }

    #[test]
    fn first_word_skips_block_comment() {
        assert_eq!(
            first_sql_word("/* hint */ INSERT INTO t VALUES (1)"),
            Some("INSERT")
        );
    }

    #[test]
    fn first_word_upsert_with_comment() {
        assert_eq!(
            first_sql_word("/* hint */ UPSERT INTO t { name: 'a' }"),
            Some("UPSERT")
        );
    }

    #[test]
    fn first_word_empty() {
        assert_eq!(first_sql_word("   "), None);
    }

    // ── second_sql_word ──────────────────────────────────────────────────────

    #[test]
    fn second_word_simple() {
        assert_eq!(second_sql_word("CREATE TABLE t (id INT)"), Some("TABLE"));
    }

    #[test]
    fn second_word_skips_comments_between_words() {
        assert_eq!(
            second_sql_word("/* c */ DROP /* x */ INDEX i"),
            Some("INDEX")
        );
        assert_eq!(second_sql_word("-- note\nALTER -- why\nTABLE t"), Some("TABLE"));
    }

    #[test]
    fn second_word_missing() {
        assert_eq!(second_sql_word("SELECT"), None);
        assert_eq!(second_sql_word("   "), None);
        assert_eq!(second_sql_word(""), None);
    }

    // ── has_operator_outside_literals ───────────────────────────────────────

    #[test]
    fn operator_in_plain_text() {
        assert!(has_operator_outside_literals("a <-> b", "<->"));
    }

    #[test]
    fn operator_in_string_not_detected() {
        assert!(!has_operator_outside_literals("SELECT '<->'", "<->"));
    }

    #[test]
    fn operator_in_line_comment_not_detected() {
        assert!(!has_operator_outside_literals(
            "SELECT col -- has <-> in comment\nFROM t",
            "<->"
        ));
    }

    #[test]
    fn operator_in_block_comment_not_detected() {
        assert!(!has_operator_outside_literals("SELECT /* <-> */ x", "<->"));
    }

    #[test]
    fn operator_in_quoted_ident_not_detected() {
        assert!(!has_operator_outside_literals(r#"SELECT "col_<->""#, "<->"));
    }

    // ── find_operator_positions ─────────────────────────────────────────────

    #[test]
    fn operator_positions_skip_literals() {
        assert_eq!(
            find_operator_positions("a <-> b '<->' <-> c", "<->"),
            vec![2, 14]
        );
    }

    #[test]
    fn operator_positions_are_non_overlapping() {
        assert_eq!(find_operator_positions("aaa", "aa"), vec![0]);
    }

    #[test]
    fn operator_positions_none() {
        assert!(find_operator_positions("SELECT '<->'", "<->").is_empty());
    }

    // ── has_brace_outside_literals ──────────────────────────────────────────

    #[test]
    fn brace_in_plain_text() {
        assert!(has_brace_outside_literals("func({ foo })"));
    }

    #[test]
    fn brace_in_string_not_detected() {
        assert!(!has_brace_outside_literals("func('{ foo }')"));
    }

    #[test]
    fn brace_concat_expr_not_detected() {
        // `'{' || x || '}'` — braces only inside string literals
        assert!(!has_brace_outside_literals("'{' || x || '}'"));
    }

    // ── keyword_position_outside_literals ───────────────────────────────────

    #[test]
    fn keyword_found_in_plain_text() {
        let sql = "SELECT * FROM t FOR SYSTEM_TIME AS OF 100";
        assert!(keyword_position_outside_literals(sql, "FOR SYSTEM_TIME").is_some());
    }

    #[test]
    fn keyword_in_string_not_found() {
        let sql = "SELECT * FROM t WHERE name = 'FOR SYSTEM_TIME'";
        assert!(keyword_position_outside_literals(sql, "FOR SYSTEM_TIME").is_none());
    }

    #[test]
    fn keyword_position_correct() {
        let sql = "SELECT x FOR SYSTEM_TIME AS OF 100";
        let pos = keyword_position_outside_literals(sql, "FOR SYSTEM_TIME").unwrap();
        // verify the slice at that position matches the keyword (case-insensitively)
        let found = &sql[pos..pos + "FOR SYSTEM_TIME".len()];
        assert_eq!(found.to_uppercase(), "FOR SYSTEM_TIME");
    }

    #[test]
    fn keyword_match_is_case_insensitive() {
        let sql = "select x for system_time as of 1";
        assert_eq!(
            keyword_position_outside_literals(sql, "FOR SYSTEM_TIME"),
            Some(9)
        );
    }

    #[test]
    fn keyword_requires_word_boundaries() {
        // `format` contains `for`, but it is followed by a word character.
        assert_eq!(
            keyword_position_outside_literals("SELECT format FROM t", "FOR"),
            None
        );
    }
}
513}