Skip to main content

flowscope_core/parser/
mod.rs

1use crate::error::ParseError;
2use crate::types::Dialect;
3use sqlparser::ast::Statement;
4use sqlparser::dialect::PostgreSqlDialect;
5use sqlparser::parser::Parser;
6
7/// Result of parsing SQL with fallback metadata.
8pub struct ParseSqlOutput {
9    pub statements: Vec<Statement>,
10    pub parser_fallback_used: bool,
11}
12
13/// Parse SQL using the specified dialect
14pub fn parse_sql_with_dialect(sql: &str, dialect: Dialect) -> Result<Vec<Statement>, ParseError> {
15    parse_sql_with_dialect_output(sql, dialect).map(|output| output.statements)
16}
17
18/// Parse SQL using the specified dialect and report whether parser fallback was used.
19pub fn parse_sql_with_dialect_output(
20    sql: &str,
21    dialect: Dialect,
22) -> Result<ParseSqlOutput, ParseError> {
23    let sqlparser_dialect = dialect.to_sqlparser_dialect();
24    match Parser::parse_sql(sqlparser_dialect.as_ref(), sql) {
25        Ok(statements) => Ok(ParseSqlOutput {
26            statements,
27            parser_fallback_used: false,
28        }),
29        Err(primary_err) => {
30            if let Some(sanitized_sql) = sanitize_escaped_identifiers_for_dialect(sql, dialect) {
31                if let Ok(statements) =
32                    Parser::parse_sql(sqlparser_dialect.as_ref(), &sanitized_sql)
33                {
34                    return Ok(ParseSqlOutput {
35                        statements,
36                        parser_fallback_used: true,
37                    });
38                }
39            }
40
41            if let Some(sanitized_sql) = sanitize_trailing_comma_before_from(sql) {
42                if let Ok(statements) =
43                    Parser::parse_sql(sqlparser_dialect.as_ref(), &sanitized_sql)
44                {
45                    return Ok(ParseSqlOutput {
46                        statements,
47                        parser_fallback_used: true,
48                    });
49                }
50            }
51
52            if matches!(dialect, Dialect::Ansi) {
53                if let Some(sanitized_sql) = sanitize_ansi_national_literal_spacing(sql) {
54                    if let Ok(statements) =
55                        Parser::parse_sql(sqlparser_dialect.as_ref(), &sanitized_sql)
56                    {
57                        return Ok(ParseSqlOutput {
58                            statements,
59                            parser_fallback_used: true,
60                        });
61                    }
62                }
63            }
64
65            if matches!(dialect, Dialect::Bigquery) {
66                if let Some(sanitized_sql) = sanitize_bigquery_raw_double_quoted_literals(sql) {
67                    if let Ok(statements) =
68                        Parser::parse_sql(sqlparser_dialect.as_ref(), &sanitized_sql)
69                    {
70                        return Ok(ParseSqlOutput {
71                            statements,
72                            parser_fallback_used: true,
73                        });
74                    }
75                }
76            }
77
78            // Parity fallback: Generic dialect frequently fails on Postgres-specific
79            // operators (`?`, `->>`, `::`) commonly used in warehouse SQL.
80            if matches!(dialect, Dialect::Generic) && looks_like_postgres_syntax(sql) {
81                let postgres = PostgreSqlDialect {};
82                if let Ok(statements) = Parser::parse_sql(&postgres, sql) {
83                    return Ok(ParseSqlOutput {
84                        statements,
85                        parser_fallback_used: true,
86                    });
87                }
88            }
89            Err(primary_err.into())
90        }
91    }
92}
93
/// Cheap textual heuristic: does `sql` contain an operator spelling that is
/// typical of PostgreSQL (`::`, `->`, `?|`, `?&`, or a `?` followed by
/// whitespace / a string literal)?
///
/// This is a plain substring search; it does not skip string literals or
/// comments.
fn looks_like_postgres_syntax(sql: &str) -> bool {
    const POSTGRES_MARKERS: [&str; 8] = ["::", "->", "?|", "?&", " ? ", " ?\n", "? '", "?\t"];

    POSTGRES_MARKERS
        .iter()
        .any(|marker| sql.contains(*marker))
}
104
105fn sanitize_escaped_identifiers_for_dialect(sql: &str, dialect: Dialect) -> Option<String> {
106    let delimiters: &[u8] = match dialect {
107        Dialect::Bigquery => b"`",
108        Dialect::Clickhouse => b"`\"",
109        _ => return None,
110    };
111
112    if !sql.as_bytes().contains(&b'\\') {
113        return None;
114    }
115
116    let mut rewritten = rewrite_escaped_quoted_identifiers(sql, delimiters);
117
118    if matches!(dialect, Dialect::Clickhouse) {
119        rewritten = remove_trailing_comma_before_from(&rewritten);
120    }
121
122    (rewritten != sql).then_some(rewritten)
123}
124
125fn sanitize_trailing_comma_before_from(sql: &str) -> Option<String> {
126    let rewritten = remove_trailing_comma_before_from(sql);
127    (rewritten != sql).then_some(rewritten)
128}
129
/// Appends the character that starts at byte offset `*i` of `sql` to `out`
/// and advances `*i` past it (by the character's UTF-8 length).
///
/// Does nothing when `*i` is at the end of `sql`. `*i` must lie on a character
/// boundary within `sql`; otherwise the string slice panics.
fn push_current_char(sql: &str, i: &mut usize, out: &mut String) {
    let Some(ch) = sql[*i..].chars().next() else {
        return;
    };
    *i += ch.len_utf8();
    out.push(ch);
}
136
137fn sanitize_ansi_national_literal_spacing(sql: &str) -> Option<String> {
138    #[derive(Clone, Copy, PartialEq, Eq)]
139    enum ScanMode {
140        Outside,
141        SingleQuote,
142        DoubleQuote,
143        BacktickQuote,
144        BracketQuote,
145        LineComment,
146        BlockComment,
147    }
148
149    fn identifier_tail(byte: u8) -> bool {
150        byte.is_ascii_alphanumeric() || matches!(byte, b'_' | b'$')
151    }
152
153    let bytes = sql.as_bytes();
154    let mut out = String::with_capacity(sql.len());
155    let mut mode = ScanMode::Outside;
156    let mut i = 0usize;
157    let mut changed = false;
158
159    while i < bytes.len() {
160        let b = bytes[i];
161        let next = bytes.get(i + 1).copied();
162
163        match mode {
164            ScanMode::Outside => {
165                if b == b'\'' {
166                    mode = ScanMode::SingleQuote;
167                    out.push('\'');
168                    i += 1;
169                    continue;
170                }
171                if b == b'"' {
172                    mode = ScanMode::DoubleQuote;
173                    out.push('"');
174                    i += 1;
175                    continue;
176                }
177                if b == b'`' {
178                    mode = ScanMode::BacktickQuote;
179                    out.push('`');
180                    i += 1;
181                    continue;
182                }
183                if b == b'[' {
184                    mode = ScanMode::BracketQuote;
185                    out.push('[');
186                    i += 1;
187                    continue;
188                }
189                if b == b'-' && next == Some(b'-') {
190                    mode = ScanMode::LineComment;
191                    out.push('-');
192                    out.push('-');
193                    i += 2;
194                    continue;
195                }
196                if b == b'/' && next == Some(b'*') {
197                    mode = ScanMode::BlockComment;
198                    out.push('/');
199                    out.push('*');
200                    i += 2;
201                    continue;
202                }
203
204                if matches!(b, b'N' | b'n') {
205                    let prev = i.checked_sub(1).and_then(|idx| bytes.get(idx).copied());
206                    if !prev.is_some_and(identifier_tail) {
207                        let mut j = i + 1;
208                        while j < bytes.len() && bytes[j].is_ascii_whitespace() {
209                            j += 1;
210                        }
211                        if j > i + 1 && bytes.get(j).copied() == Some(b'\'') {
212                            out.push(b as char);
213                            i += 1;
214                            while i < j {
215                                changed = true;
216                                i += 1;
217                            }
218                            continue;
219                        }
220                    }
221                }
222
223                push_current_char(sql, &mut i, &mut out);
224            }
225            ScanMode::SingleQuote => {
226                push_current_char(sql, &mut i, &mut out);
227                if b == b'\'' {
228                    if next == Some(b'\'') {
229                        out.push('\'');
230                        i += 1;
231                    } else {
232                        mode = ScanMode::Outside;
233                    }
234                }
235            }
236            ScanMode::DoubleQuote => {
237                push_current_char(sql, &mut i, &mut out);
238                if b == b'"' {
239                    mode = ScanMode::Outside;
240                }
241            }
242            ScanMode::BacktickQuote => {
243                push_current_char(sql, &mut i, &mut out);
244                if b == b'`' {
245                    mode = ScanMode::Outside;
246                }
247            }
248            ScanMode::BracketQuote => {
249                push_current_char(sql, &mut i, &mut out);
250                if b == b']' {
251                    mode = ScanMode::Outside;
252                }
253            }
254            ScanMode::LineComment => {
255                push_current_char(sql, &mut i, &mut out);
256                if b == b'\n' || b == b'\r' {
257                    mode = ScanMode::Outside;
258                }
259            }
260            ScanMode::BlockComment => {
261                push_current_char(sql, &mut i, &mut out);
262                if b == b'*' && next == Some(b'/') {
263                    out.push('/');
264                    i += 1;
265                    mode = ScanMode::Outside;
266                }
267            }
268        }
269    }
270
271    changed.then_some(out)
272}
273
274fn sanitize_bigquery_raw_double_quoted_literals(sql: &str) -> Option<String> {
275    let bytes = sql.as_bytes();
276    let mut out = String::with_capacity(sql.len());
277    let mut i = 0usize;
278    let mut changed = false;
279
280    while i < bytes.len() {
281        let start = i;
282        while i < bytes.len() && bytes[i].is_ascii_alphabetic() {
283            i += 1;
284        }
285
286        let prefix = &sql[start..i];
287        let is_raw_prefix = prefix.eq_ignore_ascii_case("r")
288            || prefix.eq_ignore_ascii_case("br")
289            || prefix.eq_ignore_ascii_case("rb");
290
291        if !is_raw_prefix || i >= bytes.len() || bytes[i] != b'"' {
292            if start < i {
293                out.push_str(prefix);
294            } else if i < bytes.len() {
295                push_current_char(sql, &mut i, &mut out);
296            }
297            continue;
298        }
299
300        let quote_start = i;
301        i += 1;
302        let mut body = String::new();
303        let mut closed = false;
304        while i < bytes.len() {
305            if bytes[i] == b'\\' && i + 1 < bytes.len() && bytes[i + 1] == b'"' {
306                body.push('\\');
307                body.push('"');
308                i += 2;
309                continue;
310            }
311            if bytes[i] == b'"' {
312                closed = true;
313                i += 1;
314                break;
315            }
316            push_current_char(sql, &mut i, &mut body);
317        }
318
319        if !closed {
320            out.push_str(&sql[start..quote_start]);
321            out.push('"');
322            out.push_str(&body);
323            break;
324        }
325
326        changed = true;
327        out.push_str(prefix);
328        out.push('\'');
329        for ch in body.chars() {
330            if ch == '\'' {
331                out.push('\'');
332            }
333            out.push(ch);
334        }
335        out.push('\'');
336    }
337
338    changed.then_some(out)
339}
340
341fn rewrite_escaped_quoted_identifiers(sql: &str, delimiters: &[u8]) -> String {
342    let bytes = sql.as_bytes();
343    let mut out = String::with_capacity(sql.len());
344    let mut i = 0usize;
345    let len = bytes.len();
346
347    while i < len {
348        if bytes[i] == b'\'' {
349            let start = i;
350            i += 1;
351            while i < len {
352                if bytes[i] == b'\'' {
353                    if i + 1 < len && bytes[i + 1] == b'\'' {
354                        i += 2;
355                    } else {
356                        i += 1;
357                        break;
358                    }
359                } else {
360                    i += 1;
361                }
362            }
363            out.push_str(&sql[start..i]);
364            continue;
365        }
366
367        if bytes[i] == b'-' && i + 1 < len && bytes[i + 1] == b'-' {
368            let start = i;
369            i += 2;
370            while i < len && bytes[i] != b'\n' && bytes[i] != b'\r' {
371                i += 1;
372            }
373            out.push_str(&sql[start..i]);
374            continue;
375        }
376
377        if bytes[i] == b'/' && i + 1 < len && bytes[i + 1] == b'*' {
378            let start = i;
379            i += 2;
380            while i + 1 < len {
381                if bytes[i] == b'*' && bytes[i + 1] == b'/' {
382                    i += 2;
383                    break;
384                }
385                i += 1;
386            }
387            out.push_str(&sql[start..i.min(len)]);
388            continue;
389        }
390
391        if delimiters.contains(&bytes[i]) {
392            let delimiter = bytes[i];
393            let start = i;
394            i += 1;
395            let mut content = String::new();
396            let mut had_escape = false;
397            let mut closed = false;
398
399            while i < len {
400                if bytes[i] == b'\\' && i + 1 < len && bytes[i + 1] == delimiter {
401                    had_escape = true;
402                    content.push('_');
403                    i += 2;
404                    continue;
405                }
406
407                if bytes[i] == delimiter {
408                    if i + 1 < len && bytes[i + 1] == delimiter {
409                        had_escape = true;
410                        content.push('_');
411                        i += 2;
412                        continue;
413                    }
414                    i += 1;
415                    closed = true;
416                    break;
417                }
418
419                push_current_char(sql, &mut i, &mut content);
420            }
421
422            if !closed {
423                out.push_str(&sql[start..len]);
424                break;
425            }
426
427            if had_escape {
428                let normalized = normalize_identifier_content(&content);
429                out.push(delimiter as char);
430                out.push_str(&normalized);
431                out.push(delimiter as char);
432            } else {
433                out.push_str(&sql[start..i]);
434            }
435            continue;
436        }
437
438        push_current_char(sql, &mut i, &mut out);
439    }
440
441    out
442}
443
/// Turns arbitrary identifier content into a plain lowercase identifier.
///
/// ASCII letters and digits are lowercased and kept; every other character
/// (including non-ASCII ones) becomes `_`. If the result is empty or consists
/// solely of underscores, the placeholder `escaped_identifier` is returned
/// instead.
fn normalize_identifier_content(content: &str) -> String {
    let normalized: String = content
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() {
                ch.to_ascii_lowercase()
            } else {
                '_'
            }
        })
        .collect();

    // `all` is vacuously true for the empty string, which covers that case too.
    if normalized.chars().all(|ch| ch == '_') {
        String::from("escaped_identifier")
    } else {
        normalized
    }
}
460
/// Removes every comma that is followed only by whitespace (space, tab, CR,
/// LF) and then the keyword `FROM` (case-insensitive).
///
/// `FROM` only counts as the keyword when it is not continued by an
/// identifier byte (ASCII alphanumeric, `_` or `$`), so the comma in
/// `a, from_date` or `a, fromage` is preserved. The whitespace after a
/// removed comma is kept.
///
/// NOTE(review): this is a purely textual scan; commas inside string literals
/// or comments that happen to precede `FROM` are removed as well.
fn remove_trailing_comma_before_from(sql: &str) -> String {
    fn continues_identifier(byte: u8) -> bool {
        byte.is_ascii_alphanumeric() || matches!(byte, b'_' | b'$')
    }

    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len());
    // Start of the input that has not been copied to `out` yet.
    let mut copied_to = 0usize;

    for (comma_at, _) in sql.match_indices(',') {
        let after_whitespace = bytes[comma_at + 1..]
            .iter()
            .position(|byte| !matches!(byte, b' ' | b'\t' | b'\n' | b'\r'))
            .map_or(bytes.len(), |offset| comma_at + 1 + offset);
        let rest = &bytes[after_whitespace..];

        let precedes_from_keyword = rest.len() >= 4
            && rest[..4].eq_ignore_ascii_case(b"FROM")
            && rest.get(4).map_or(true, |&byte| !continues_identifier(byte));

        if precedes_from_keyword {
            // Copy everything up to the comma, then skip the comma itself.
            out.push_str(&sql[copied_to..comma_at]);
            copied_to = comma_at + 1;
        }
    }

    out.push_str(&sql[copied_to..]);
    out
}
488
489/// Parse SQL using the generic dialect (legacy compatibility)
490pub fn parse_sql(sql: &str) -> Result<Vec<Statement>, ParseError> {
491    parse_sql_with_dialect(sql, Dialect::Generic)
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `sql` parses with the generic dialect.
    fn assert_generic_parses(sql: &str) {
        assert!(parse_sql(sql).is_ok());
    }

    /// Asserts that `sql` parses with the given dialect.
    fn assert_dialect_parses(sql: &str, dialect: Dialect) {
        assert!(parse_sql_with_dialect(sql, dialect).is_ok());
    }

    /// Asserts that `sql` parses to exactly one statement and that the
    /// fallback flag matches `expect_fallback`.
    fn assert_single_statement(sql: &str, dialect: Dialect, expect_fallback: bool) {
        let output = parse_sql_with_dialect_output(sql, dialect).expect("parse");
        assert!(output.parser_fallback_used == expect_fallback);
        assert_eq!(output.statements.len(), 1);
    }

    #[test]
    fn test_parse_valid_select() {
        let parsed = parse_sql("SELECT * FROM users");
        assert!(parsed.is_ok());
        assert_eq!(parsed.unwrap().len(), 1);
    }

    #[test]
    fn test_parse_invalid_sql() {
        assert!(parse_sql("SELECT * FROM").is_err());
    }

    #[test]
    fn test_parse_multiple_statements() {
        let parsed = parse_sql("SELECT * FROM users; SELECT * FROM orders;");
        assert!(parsed.is_ok());
        assert_eq!(parsed.unwrap().len(), 2);
    }

    #[test]
    fn test_parse_with_postgres_dialect() {
        assert_dialect_parses(
            "SELECT * FROM users WHERE name ILIKE '%test%'",
            Dialect::Postgres,
        );
    }

    #[test]
    fn test_parse_with_snowflake_dialect() {
        assert_dialect_parses("SELECT * FROM db.schema.table", Dialect::Snowflake);
    }

    #[test]
    fn test_parse_with_bigquery_dialect() {
        assert_dialect_parses("SELECT * FROM `project.dataset.table`", Dialect::Bigquery);
    }

    #[test]
    fn test_parse_cte() {
        assert_generic_parses(
            r#"
            WITH active_users AS (
                SELECT * FROM users WHERE active = true
            )
            SELECT * FROM active_users
        "#,
        );
    }

    #[test]
    fn test_parse_insert_select() {
        assert_generic_parses("INSERT INTO archive SELECT * FROM users WHERE deleted = true");
    }

    #[test]
    fn test_parse_create_table_as() {
        assert_generic_parses("CREATE TABLE users_backup AS SELECT * FROM users");
    }

    #[test]
    fn test_parse_union() {
        assert_generic_parses("SELECT id FROM users UNION ALL SELECT id FROM admins");
    }

    #[test]
    fn test_parse_generic_falls_back_for_postgres_json_operator() {
        assert_generic_parses("SELECT usage_metadata ? 'pipeline_id' FROM ledger.usage_line_item");
    }

    #[test]
    fn test_parse_generic_falls_back_for_postgres_cast_operator() {
        assert_generic_parses("SELECT workspace_id::text FROM ledger.usage_line_item");
    }

    #[test]
    fn test_parse_output_marks_parser_fallback_usage() {
        // Pick the first candidate the plain generic parser rejects, so the
        // test keeps exercising the fallback even if upstream starts accepting
        // some of these spellings.
        let candidates = [
            "SELECT usage_metadata ? 'pipeline_id' FROM ledger.usage_line_item",
            "SELECT workspace_id::text FROM ledger.usage_line_item",
            "SELECT payload->>'id' FROM ledger.usage_line_item",
        ];
        let generic = sqlparser::dialect::GenericDialect {};
        let sql = candidates
            .iter()
            .copied()
            .find(|candidate| Parser::parse_sql(&generic, candidate).is_err())
            .expect("expected at least one postgres-only candidate to fail in generic parser");

        assert_single_statement(sql, Dialect::Generic, true);
    }

    #[test]
    fn test_parse_output_bigquery_escaped_identifier_fallback_usage() {
        assert_single_statement(
            "SELECT `\\`a`.col1 FROM tab1 as `\\`A`",
            Dialect::Bigquery,
            true,
        );
    }

    #[test]
    fn test_parse_output_clickhouse_escaped_identifier_fallback_usage() {
        assert_single_statement(
            "SELECT \"\\\"`a`\"\"\".col1,\nFROM tab1 as `\"\\`a``\"`",
            Dialect::Clickhouse,
            true,
        );
    }

    #[test]
    fn test_parse_output_trailing_comma_before_from_fallback_usage() {
        assert_single_statement(
            "SELECT widget.id,\nwidget.name,\nFROM widget",
            Dialect::Ansi,
            true,
        );
    }

    #[test]
    fn test_remove_trailing_comma_before_from_preserves_utf8() {
        assert_eq!(
            remove_trailing_comma_before_from("SELECT café,\nFROM résumé"),
            "SELECT café\nFROM résumé"
        );
    }

    #[test]
    fn test_sanitize_escaped_identifiers_preserves_utf8() {
        let rewritten = sanitize_escaped_identifiers_for_dialect(
            "SELECT naïve, `\\`id` FROM café",
            Dialect::Bigquery,
        )
        .expect("rewrite");
        assert_eq!(rewritten, "SELECT naïve, `_id` FROM café");
    }

    #[test]
    fn test_parse_output_ansi_national_literal_spacing_fallback_usage() {
        assert_single_statement("SELECT a + N 'b' + N 'c' FROM tbl;", Dialect::Ansi, true);
    }

    #[test]
    fn test_parse_output_bigquery_raw_double_quoted_literal_fallback_usage() {
        assert_single_statement(
            r#"SELECT r'Tricky "quote', r"Not-so-tricky \"quote""#,
            Dialect::Bigquery,
            true,
        );
    }

    #[test]
    fn test_parse_output_without_fallback() {
        assert_single_statement("SELECT 1", Dialect::Generic, false);
    }
}