1use crate::error::ParseError;
2use crate::types::Dialect;
3use sqlparser::ast::Statement;
4use sqlparser::dialect::PostgreSqlDialect;
5use sqlparser::parser::Parser;
6
7pub struct ParseSqlOutput {
9 pub statements: Vec<Statement>,
10 pub parser_fallback_used: bool,
11}
12
13pub fn parse_sql_with_dialect(sql: &str, dialect: Dialect) -> Result<Vec<Statement>, ParseError> {
15 parse_sql_with_dialect_output(sql, dialect).map(|output| output.statements)
16}
17
18pub fn parse_sql_with_dialect_output(
20 sql: &str,
21 dialect: Dialect,
22) -> Result<ParseSqlOutput, ParseError> {
23 let sqlparser_dialect = dialect.to_sqlparser_dialect();
24 match Parser::parse_sql(sqlparser_dialect.as_ref(), sql) {
25 Ok(statements) => Ok(ParseSqlOutput {
26 statements,
27 parser_fallback_used: false,
28 }),
29 Err(primary_err) => {
30 if let Some(sanitized_sql) = sanitize_escaped_identifiers_for_dialect(sql, dialect) {
31 if let Ok(statements) =
32 Parser::parse_sql(sqlparser_dialect.as_ref(), &sanitized_sql)
33 {
34 return Ok(ParseSqlOutput {
35 statements,
36 parser_fallback_used: true,
37 });
38 }
39 }
40
41 if let Some(sanitized_sql) = sanitize_trailing_comma_before_from(sql) {
42 if let Ok(statements) =
43 Parser::parse_sql(sqlparser_dialect.as_ref(), &sanitized_sql)
44 {
45 return Ok(ParseSqlOutput {
46 statements,
47 parser_fallback_used: true,
48 });
49 }
50 }
51
52 if matches!(dialect, Dialect::Ansi) {
53 if let Some(sanitized_sql) = sanitize_ansi_national_literal_spacing(sql) {
54 if let Ok(statements) =
55 Parser::parse_sql(sqlparser_dialect.as_ref(), &sanitized_sql)
56 {
57 return Ok(ParseSqlOutput {
58 statements,
59 parser_fallback_used: true,
60 });
61 }
62 }
63 }
64
65 if matches!(dialect, Dialect::Bigquery) {
66 if let Some(sanitized_sql) = sanitize_bigquery_raw_double_quoted_literals(sql) {
67 if let Ok(statements) =
68 Parser::parse_sql(sqlparser_dialect.as_ref(), &sanitized_sql)
69 {
70 return Ok(ParseSqlOutput {
71 statements,
72 parser_fallback_used: true,
73 });
74 }
75 }
76 }
77
78 if matches!(dialect, Dialect::Generic) && looks_like_postgres_syntax(sql) {
81 let postgres = PostgreSqlDialect {};
82 if let Ok(statements) = Parser::parse_sql(&postgres, sql) {
83 return Ok(ParseSqlOutput {
84 statements,
85 parser_fallback_used: true,
86 });
87 }
88 }
89 Err(primary_err.into())
90 }
91 }
92}
93
/// Cheap textual heuristic: does `sql` contain a Postgres-style operator (`::` casts,
/// `->`/`->>` JSON access, or the `?`, `?|`, `?&` key-existence operators)?
///
/// Only gates an extra parse attempt, so false positives cost one more parse.
fn looks_like_postgres_syntax(sql: &str) -> bool {
    const POSTGRES_MARKERS: [&str; 8] = ["::", "->", "?|", "?&", " ? ", " ?\n", "? '", "?\t"];
    POSTGRES_MARKERS.iter().any(|&marker| sql.contains(marker))
}
104
105fn sanitize_escaped_identifiers_for_dialect(sql: &str, dialect: Dialect) -> Option<String> {
106 let delimiters: &[u8] = match dialect {
107 Dialect::Bigquery => b"`",
108 Dialect::Clickhouse => b"`\"",
109 _ => return None,
110 };
111
112 if !sql.as_bytes().contains(&b'\\') {
113 return None;
114 }
115
116 let mut rewritten = rewrite_escaped_quoted_identifiers(sql, delimiters);
117
118 if matches!(dialect, Dialect::Clickhouse) {
119 rewritten = remove_trailing_comma_before_from(&rewritten);
120 }
121
122 (rewritten != sql).then_some(rewritten)
123}
124
125fn sanitize_trailing_comma_before_from(sql: &str) -> Option<String> {
126 let rewritten = remove_trailing_comma_before_from(sql);
127 (rewritten != sql).then_some(rewritten)
128}
129
/// Appends the whole character starting at byte offset `*i` of `sql` to `out` and advances
/// `*i` past it (1-4 bytes). Does nothing once `*i` reaches the end of `sql`.
///
/// `*i` must lie on a character boundary; slicing panics otherwise.
fn push_current_char(sql: &str, i: &mut usize, out: &mut String) {
    let Some(ch) = sql[*i..].chars().next() else {
        return;
    };
    out.push(ch);
    *i += ch.len_utf8();
}
136
/// Removes the whitespace between a standalone `N`/`n` and a following `'`, turning
/// `N 'text'` into the national string literal `N'text'`.
///
/// An `N` directly after an identifier byte (ASCII letter, digit, `_`, `$`) is part of a
/// word (e.g. `IN 'x'`) and is left alone. Single-quoted strings (with `''` escapes),
/// `"…"`, `` `…` `` and `[…]` quoted identifiers, `--` line comments and `/* */` block
/// comments are copied through untouched. Returns `Some(rewritten)` only if at least one
/// gap was removed.
fn sanitize_ansi_national_literal_spacing(sql: &str) -> Option<String> {
    /// Lexical region the scanner is currently inside.
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Region {
        Code,
        SingleQuoted,
        DoubleQuoted,
        Backticked,
        Bracketed,
        LineComment,
        BlockComment,
    }

    /// True for bytes that can end an identifier; an `N` right after one is not a prefix.
    fn continues_identifier(byte: u8) -> bool {
        matches!(byte, b'_' | b'$') || byte.is_ascii_alphanumeric()
    }

    /// Appends the character starting at byte `*pos` and advances `*pos` past it.
    fn copy_char(sql: &str, pos: &mut usize, out: &mut String) {
        if let Some(ch) = sql[*pos..].chars().next() {
            out.push(ch);
            *pos += ch.len_utf8();
        }
    }

    let src = sql.as_bytes();
    let mut rewritten = String::with_capacity(sql.len());
    let mut region = Region::Code;
    let mut pos = 0usize;
    let mut modified = false;

    while pos < src.len() {
        let cur = src[pos];
        let peek = src.get(pos + 1).copied();

        match region {
            Region::Code => {
                // Opening delimiter of a quoted token or comment, and its width in bytes.
                let opener = match (cur, peek) {
                    (b'\'', _) => Some((Region::SingleQuoted, 1)),
                    (b'"', _) => Some((Region::DoubleQuoted, 1)),
                    (b'`', _) => Some((Region::Backticked, 1)),
                    (b'[', _) => Some((Region::Bracketed, 1)),
                    (b'-', Some(b'-')) => Some((Region::LineComment, 2)),
                    (b'/', Some(b'*')) => Some((Region::BlockComment, 2)),
                    _ => None,
                };
                if let Some((entered, width)) = opener {
                    rewritten.push_str(&sql[pos..pos + width]);
                    pos += width;
                    region = entered;
                    continue;
                }

                let standalone_prefix = matches!(cur, b'N' | b'n')
                    && !(pos > 0 && continues_identifier(src[pos - 1]));
                if standalone_prefix {
                    let gap = src[pos + 1..]
                        .iter()
                        .take_while(|byte| byte.is_ascii_whitespace())
                        .count();
                    let quote_at = pos + 1 + gap;
                    if gap > 0 && src.get(quote_at) == Some(&b'\'') {
                        // Keep the prefix, drop the gap; the quote is handled next iteration.
                        rewritten.push(char::from(cur));
                        pos = quote_at;
                        modified = true;
                        continue;
                    }
                }

                copy_char(sql, &mut pos, &mut rewritten);
            }
            Region::SingleQuoted => {
                copy_char(sql, &mut pos, &mut rewritten);
                if cur == b'\'' {
                    if peek == Some(b'\'') {
                        // `''` is an escaped quote: emit it and stay inside the literal.
                        rewritten.push('\'');
                        pos += 1;
                    } else {
                        region = Region::Code;
                    }
                }
            }
            Region::DoubleQuoted | Region::Backticked | Region::Bracketed => {
                let closer = match region {
                    Region::DoubleQuoted => b'"',
                    Region::Backticked => b'`',
                    _ => b']',
                };
                copy_char(sql, &mut pos, &mut rewritten);
                if cur == closer {
                    region = Region::Code;
                }
            }
            Region::LineComment => {
                copy_char(sql, &mut pos, &mut rewritten);
                if matches!(cur, b'\n' | b'\r') {
                    region = Region::Code;
                }
            }
            Region::BlockComment => {
                copy_char(sql, &mut pos, &mut rewritten);
                if cur == b'*' && peek == Some(b'/') {
                    rewritten.push('/');
                    pos += 1;
                    region = Region::Code;
                }
            }
        }
    }

    modified.then_some(rewritten)
}
273
/// Rewrites BigQuery raw double-quoted literals (`r"…"`, `br"…"`, `rb"…"`, any case) into
/// the single-quoted raw form (`r'…'`), keeping the prefix as written.
///
/// Inside the body, a bare `'` is doubled. A backslash and the character after it are
/// copied unchanged, because a backslash keeps that character (e.g. `"` or another `\`)
/// from closing the literal.
///
/// The scan tracks lexical context, so these are copied verbatim and never rewritten:
/// - other string literals (single, double or triple quoted, backslash escapes honoured),
/// - backtick identifiers,
/// - `--`, `#` and `/* */` comments,
/// - a prefix that is only the tail of a longer word (`x_r"…"`),
/// - triple-quoted raw literals (`r"""…"""`).
///
/// An unterminated raw literal copies the rest of the input unchanged. Returns `Some` only
/// when at least one literal was rewritten.
fn sanitize_bigquery_raw_double_quoted_literals(sql: &str) -> Option<String> {
    /// Bytes belonging to an unquoted word. Non-ASCII bytes count too, so a word never
    /// ends inside a multi-byte character and every slice stays on a char boundary.
    fn is_word_byte(byte: u8) -> bool {
        byte.is_ascii_alphanumeric() || byte == b'_' || !byte.is_ascii()
    }

    /// Whether three consecutive `quote` bytes start at `at`.
    fn is_triple(bytes: &[u8], at: usize, quote: u8) -> bool {
        bytes.get(at..at + 3) == Some(&[quote; 3][..])
    }

    /// Byte index just past the closing quote(s) of the quoted token opening at `open`,
    /// or `None` if it never closes. A backslash always escapes the following byte.
    fn quoted_token_end(bytes: &[u8], open: usize) -> Option<usize> {
        let quote = bytes[open];
        let width = if quote != b'`' && is_triple(bytes, open, quote) { 3 } else { 1 };
        let mut j = open + width;
        while j < bytes.len() {
            if bytes[j] == b'\\' {
                j += 2;
            } else if bytes[j] == quote && (width == 1 || is_triple(bytes, j, quote)) {
                return Some(j + width);
            } else {
                j += 1;
            }
        }
        None
    }

    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut out = String::with_capacity(len);
    let mut i = 0usize;
    let mut changed = false;

    while i < len {
        let b = bytes[i];
        let next = bytes.get(i + 1).copied();

        // Comments are copied verbatim; `--` and `#` run to the end of the line.
        if b == b'#' || (b == b'-' && next == Some(b'-')) {
            let end = bytes[i..]
                .iter()
                .position(|&c| c == b'\n' || c == b'\r')
                .map_or(len, |offset| i + offset);
            out.push_str(&sql[i..end]);
            i = end;
            continue;
        }
        if b == b'/' && next == Some(b'*') {
            let end = sql[i + 2..]
                .find("*/")
                .map_or(len, |offset| i + 2 + offset + 2);
            out.push_str(&sql[i..end]);
            i = end;
            continue;
        }

        // Non-raw strings and backtick identifiers are copied verbatim.
        if matches!(b, b'\'' | b'"' | b'`') {
            let end = quoted_token_end(bytes, i).unwrap_or(len);
            out.push_str(&sql[i..end]);
            i = end;
            continue;
        }

        if is_word_byte(b) {
            let start = i;
            while i < len && is_word_byte(bytes[i]) {
                i += 1;
            }
            let word = &sql[start..i];
            let is_raw_prefix = ["r", "br", "rb"]
                .iter()
                .any(|prefix| word.eq_ignore_ascii_case(prefix));
            let opens_double_quote = bytes.get(i) == Some(&b'"') && !is_triple(bytes, i, b'"');

            if !(is_raw_prefix && opens_double_quote) {
                out.push_str(word);
                continue;
            }

            let Some(end) = quoted_token_end(bytes, i) else {
                // Unterminated literal: keep the remainder untouched.
                out.push_str(&sql[start..]);
                break;
            };

            out.push_str(word);
            out.push('\'');
            let mut body = sql[i + 1..end - 1].chars();
            while let Some(ch) = body.next() {
                match ch {
                    // Escape pairs are already inert inside a raw literal; copy them as-is.
                    '\\' => {
                        out.push('\\');
                        if let Some(escaped) = body.next() {
                            out.push(escaped);
                        }
                    }
                    '\'' => out.push_str("''"),
                    other => out.push(other),
                }
            }
            out.push('\'');
            changed = true;
            i = end;
            continue;
        }

        // Anything else is a single ASCII byte (punctuation or whitespace).
        out.push(char::from(b));
        i += 1;
    }

    changed.then_some(out)
}
340
341fn rewrite_escaped_quoted_identifiers(sql: &str, delimiters: &[u8]) -> String {
342 let bytes = sql.as_bytes();
343 let mut out = String::with_capacity(sql.len());
344 let mut i = 0usize;
345 let len = bytes.len();
346
347 while i < len {
348 if bytes[i] == b'\'' {
349 let start = i;
350 i += 1;
351 while i < len {
352 if bytes[i] == b'\'' {
353 if i + 1 < len && bytes[i + 1] == b'\'' {
354 i += 2;
355 } else {
356 i += 1;
357 break;
358 }
359 } else {
360 i += 1;
361 }
362 }
363 out.push_str(&sql[start..i]);
364 continue;
365 }
366
367 if bytes[i] == b'-' && i + 1 < len && bytes[i + 1] == b'-' {
368 let start = i;
369 i += 2;
370 while i < len && bytes[i] != b'\n' && bytes[i] != b'\r' {
371 i += 1;
372 }
373 out.push_str(&sql[start..i]);
374 continue;
375 }
376
377 if bytes[i] == b'/' && i + 1 < len && bytes[i + 1] == b'*' {
378 let start = i;
379 i += 2;
380 while i + 1 < len {
381 if bytes[i] == b'*' && bytes[i + 1] == b'/' {
382 i += 2;
383 break;
384 }
385 i += 1;
386 }
387 out.push_str(&sql[start..i.min(len)]);
388 continue;
389 }
390
391 if delimiters.contains(&bytes[i]) {
392 let delimiter = bytes[i];
393 let start = i;
394 i += 1;
395 let mut content = String::new();
396 let mut had_escape = false;
397 let mut closed = false;
398
399 while i < len {
400 if bytes[i] == b'\\' && i + 1 < len && bytes[i + 1] == delimiter {
401 had_escape = true;
402 content.push('_');
403 i += 2;
404 continue;
405 }
406
407 if bytes[i] == delimiter {
408 if i + 1 < len && bytes[i + 1] == delimiter {
409 had_escape = true;
410 content.push('_');
411 i += 2;
412 continue;
413 }
414 i += 1;
415 closed = true;
416 break;
417 }
418
419 push_current_char(sql, &mut i, &mut content);
420 }
421
422 if !closed {
423 out.push_str(&sql[start..len]);
424 break;
425 }
426
427 if had_escape {
428 let normalized = normalize_identifier_content(&content);
429 out.push(delimiter as char);
430 out.push_str(&normalized);
431 out.push(delimiter as char);
432 } else {
433 out.push_str(&sql[start..i]);
434 }
435 continue;
436 }
437
438 push_current_char(sql, &mut i, &mut out);
439 }
440
441 out
442}
443
/// Turns raw quoted-identifier contents into a safe identifier: ASCII letters and digits
/// are lowercased, and every other character (including non-ASCII) becomes `_`.
/// If nothing but underscores remains (or the input is empty), returns
/// `"escaped_identifier"` instead.
fn normalize_identifier_content(content: &str) -> String {
    let normalized: String = content
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() {
                ch.to_ascii_lowercase()
            } else {
                '_'
            }
        })
        .collect();

    if normalized.chars().any(|ch| ch != '_') {
        normalized
    } else {
        "escaped_identifier".to_string()
    }
}
460
461fn remove_trailing_comma_before_from(sql: &str) -> String {
462 let bytes = sql.as_bytes();
463 let mut out = String::with_capacity(sql.len());
464 let mut i = 0usize;
465 let len = bytes.len();
466
467 while i < len {
468 if bytes[i] == b',' {
469 let mut j = i + 1;
470 while j < len && matches!(bytes[j], b' ' | b'\t' | b'\n' | b'\r') {
471 j += 1;
472 }
473
474 if j + 4 <= len
475 && bytes[j..j + 4].eq_ignore_ascii_case(b"FROM")
476 && (j + 4 == len || !bytes[j + 4].is_ascii_alphanumeric())
477 {
478 i += 1;
479 continue;
480 }
481 }
482
483 push_current_char(sql, &mut i, &mut out);
484 }
485
486 out
487}
488
489pub fn parse_sql(sql: &str) -> Result<Vec<Statement>, ParseError> {
491 parse_sql_with_dialect(sql, Dialect::Generic)
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` with `dialect` and asserts that exactly one statement came back
    /// through a fallback path.
    fn assert_single_statement_via_fallback(sql: &str, dialect: Dialect) {
        let output = parse_sql_with_dialect_output(sql, dialect).expect("parse");
        assert!(output.parser_fallback_used);
        assert_eq!(output.statements.len(), 1);
    }

    #[test]
    fn test_parse_valid_select() {
        let statements = parse_sql("SELECT * FROM users").unwrap();
        assert_eq!(statements.len(), 1);
    }

    #[test]
    fn test_parse_invalid_sql() {
        assert!(parse_sql("SELECT * FROM").is_err());
    }

    #[test]
    fn test_parse_multiple_statements() {
        let statements = parse_sql("SELECT * FROM users; SELECT * FROM orders;").unwrap();
        assert_eq!(statements.len(), 2);
    }

    #[test]
    fn test_parse_with_postgres_dialect() {
        let sql = "SELECT * FROM users WHERE name ILIKE '%test%'";
        assert!(parse_sql_with_dialect(sql, Dialect::Postgres).is_ok());
    }

    #[test]
    fn test_parse_with_snowflake_dialect() {
        let sql = "SELECT * FROM db.schema.table";
        assert!(parse_sql_with_dialect(sql, Dialect::Snowflake).is_ok());
    }

    #[test]
    fn test_parse_with_bigquery_dialect() {
        let sql = "SELECT * FROM `project.dataset.table`";
        assert!(parse_sql_with_dialect(sql, Dialect::Bigquery).is_ok());
    }

    #[test]
    fn test_parse_cte() {
        let sql = r#"
        WITH active_users AS (
            SELECT * FROM users WHERE active = true
        )
        SELECT * FROM active_users
        "#;
        assert!(parse_sql(sql).is_ok());
    }

    #[test]
    fn test_parse_insert_select() {
        let sql = "INSERT INTO archive SELECT * FROM users WHERE deleted = true";
        assert!(parse_sql(sql).is_ok());
    }

    #[test]
    fn test_parse_create_table_as() {
        let sql = "CREATE TABLE users_backup AS SELECT * FROM users";
        assert!(parse_sql(sql).is_ok());
    }

    #[test]
    fn test_parse_union() {
        let sql = "SELECT id FROM users UNION ALL SELECT id FROM admins";
        assert!(parse_sql(sql).is_ok());
    }

    #[test]
    fn test_parse_generic_falls_back_for_postgres_json_operator() {
        let sql = "SELECT usage_metadata ? 'pipeline_id' FROM ledger.usage_line_item";
        assert!(parse_sql(sql).is_ok());
    }

    #[test]
    fn test_parse_generic_falls_back_for_postgres_cast_operator() {
        let sql = "SELECT workspace_id::text FROM ledger.usage_line_item";
        assert!(parse_sql(sql).is_ok());
    }

    #[test]
    fn test_parse_output_marks_parser_fallback_usage() {
        let generic = sqlparser::dialect::GenericDialect {};
        let candidates = [
            "SELECT usage_metadata ? 'pipeline_id' FROM ledger.usage_line_item",
            "SELECT workspace_id::text FROM ledger.usage_line_item",
            "SELECT payload->>'id' FROM ledger.usage_line_item",
        ];
        // Only a query the generic parser rejects can exercise the Postgres fallback.
        let sql = candidates
            .into_iter()
            .find(|candidate| Parser::parse_sql(&generic, candidate).is_err())
            .expect("expected at least one postgres-only candidate to fail in generic parser");

        assert_single_statement_via_fallback(sql, Dialect::Generic);
    }

    #[test]
    fn test_parse_output_bigquery_escaped_identifier_fallback_usage() {
        assert_single_statement_via_fallback(
            "SELECT `\\`a`.col1 FROM tab1 as `\\`A`",
            Dialect::Bigquery,
        );
    }

    #[test]
    fn test_parse_output_clickhouse_escaped_identifier_fallback_usage() {
        assert_single_statement_via_fallback(
            "SELECT \"\\\"`a`\"\"\".col1,\nFROM tab1 as `\"\\`a``\"`",
            Dialect::Clickhouse,
        );
    }

    #[test]
    fn test_parse_output_trailing_comma_before_from_fallback_usage() {
        assert_single_statement_via_fallback(
            "SELECT widget.id,\nwidget.name,\nFROM widget",
            Dialect::Ansi,
        );
    }

    #[test]
    fn test_remove_trailing_comma_before_from_preserves_utf8() {
        let rewritten = remove_trailing_comma_before_from("SELECT café,\nFROM résumé");
        assert_eq!(rewritten, "SELECT café\nFROM résumé");
    }

    #[test]
    fn test_sanitize_escaped_identifiers_preserves_utf8() {
        let rewritten =
            sanitize_escaped_identifiers_for_dialect("SELECT naïve, `\\`id` FROM café", Dialect::Bigquery)
                .expect("rewrite");
        assert_eq!(rewritten, "SELECT naïve, `_id` FROM café");
    }

    #[test]
    fn test_parse_output_ansi_national_literal_spacing_fallback_usage() {
        assert_single_statement_via_fallback("SELECT a + N 'b' + N 'c' FROM tbl;", Dialect::Ansi);
    }

    #[test]
    fn test_parse_output_bigquery_raw_double_quoted_literal_fallback_usage() {
        assert_single_statement_via_fallback(
            r#"SELECT r'Tricky "quote', r"Not-so-tricky \"quote""#,
            Dialect::Bigquery,
        );
    }

    #[test]
    fn test_parse_output_without_fallback() {
        let output = parse_sql_with_dialect_output("SELECT 1", Dialect::Generic).expect("parse");
        assert!(!output.parser_fallback_used);
        assert_eq!(output.statements.len(), 1);
    }
}