// rigsql_parser/parser.rs

1use rigsql_core::Segment;
2use rigsql_lexer::{Lexer, LexerConfig, LexerError};
3use thiserror::Error;
4
5use crate::context::{ParseContext, ParseDiagnostic};
6use crate::grammar::{AnsiGrammar, Grammar};
7#[cfg(test)]
8use crate::grammar::{PostgresGrammar, TsqlGrammar};
9
10#[derive(Debug, Error)]
11pub enum ParseError {
12    #[error("Lexer error: {0}")]
13    Lexer(#[from] LexerError),
14}
15
16/// Result of parsing: a CST (always produced) plus any diagnostics
17/// collected during error-recovery passes.
18pub struct ParseResult {
19    /// The concrete syntax tree.  Always present — unparsable regions
20    /// are wrapped in `SegmentType::Unparsable` nodes.
21    pub tree: Segment,
22    /// Diagnostics emitted by the parser when it encountered
23    /// unrecognised tokens and had to skip ahead.
24    pub diagnostics: Vec<ParseDiagnostic>,
25}
26
27/// High-level SQL parser: source text → CST.
28pub struct Parser {
29    lexer_config: LexerConfig,
30    grammar: Box<dyn Grammar>,
31}
32
33impl Parser {
34    pub fn new(lexer_config: LexerConfig, grammar: Box<dyn Grammar>) -> Self {
35        Self {
36            lexer_config,
37            grammar,
38        }
39    }
40
41    /// Parse SQL source into a CST rooted at a File segment.
42    pub fn parse(&self, source: &str) -> Result<Segment, ParseError> {
43        self.parse_with_diagnostics(source).map(|r| r.tree)
44    }
45
46    /// Parse SQL source, returning both the CST and any diagnostics
47    /// produced during error recovery.
48    pub fn parse_with_diagnostics(&self, source: &str) -> Result<ParseResult, ParseError> {
49        let mut lexer = Lexer::new(source, self.lexer_config.clone());
50        let tokens = lexer.tokenize()?;
51        let mut ctx = ParseContext::new(&tokens, source);
52        let tree = self.grammar.parse_file(&mut ctx);
53        let diagnostics = ctx.take_diagnostics();
54        Ok(ParseResult { tree, diagnostics })
55    }
56}
57
58impl Default for Parser {
59    fn default() -> Self {
60        Self::new(LexerConfig::ansi(), Box::new(AnsiGrammar))
61    }
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use rigsql_core::SegmentType;

    // ── Helpers ─────────────────────────────────────────────────

    /// Parses `sql` with the default (ANSI) parser; panics if lexing fails.
    fn parse(sql: &str) -> Segment {
        Parser::default()
            .parse(sql)
            .expect("ANSI lexing should succeed")
    }

    /// Parses `sql` with the T-SQL lexer configuration and grammar.
    fn parse_tsql(sql: &str) -> Segment {
        Parser::new(LexerConfig::tsql(), Box::new(TsqlGrammar))
            .parse(sql)
            .expect("T-SQL lexing should succeed")
    }

    /// Parses `sql` with the PostgreSQL lexer configuration and grammar.
    fn parse_pg(sql: &str) -> Segment {
        Parser::new(LexerConfig::postgres(), Box::new(PostgresGrammar))
            .parse(sql)
            .expect("PostgreSQL lexing should succeed")
    }

    fn assert_type(seg: &Segment, expected: SegmentType) {
        assert_eq!(
            seg.segment_type(),
            expected,
            "Expected {:?} but got {:?} for raw: {:?}",
            expected,
            seg.segment_type(),
            seg.raw()
        );
    }

    /// Returns the first segment of type `ty` found by a depth-first,
    /// pre-order search starting at (and including) `seg`.
    ///
    /// Recursing over `children()` keeps the returned reference tied to the
    /// lifetime of `seg`. This avoids smuggling a closure-scoped reference out
    /// of `walk` through a raw pointer and an unchecked `unsafe` dereference.
    fn find_type(seg: &Segment, ty: SegmentType) -> Option<&Segment> {
        fn search<'a>(seg: &'a Segment, ty: &SegmentType) -> Option<&'a Segment> {
            if seg.segment_type() == *ty {
                return Some(seg);
            }
            seg.children().iter().find_map(|child| search(child, ty))
        }
        search(seg, &ty)
    }

    /// Counts every segment of type `ty` in the tree rooted at `seg`.
    fn count_type(seg: &Segment, ty: SegmentType) -> usize {
        let mut count = 0;
        seg.walk(&mut |s| {
            if s.segment_type() == ty {
                count += 1;
            }
        });
        count
    }

    /// Counts the direct children of `file` that are `Statement` segments.
    fn top_level_statements(file: &Segment) -> usize {
        file.children()
            .iter()
            .filter(|s| s.segment_type() == SegmentType::Statement)
            .count()
    }

    /// Fails the test, listing the raw text of every `Unparsable` segment.
    fn assert_no_unparsable(seg: &Segment) {
        let mut unparsable = Vec::new();
        seg.walk(&mut |s| {
            if s.segment_type() == SegmentType::Unparsable {
                unparsable.push(s.raw());
            }
        });
        assert!(
            unparsable.is_empty(),
            "Found Unparsable segments: {:?}",
            unparsable
        );
    }

    /// Number of `Unparsable` segments anywhere in the tree.
    fn count_unparsable(seg: &Segment) -> usize {
        count_type(seg, SegmentType::Unparsable)
    }

    // ── ANSI Tests ──────────────────────────────────────────────

    #[test]
    fn test_simple_select() {
        let cst = parse("SELECT 1");
        assert_type(&cst, SegmentType::File);
        let stmt = &cst.children()[0];
        assert_type(stmt, SegmentType::Statement);
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_select_from_where() {
        let cst = parse("SELECT name FROM users WHERE id = 1");
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
        assert!(find_type(&cst, SegmentType::FromClause).is_some());
        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
    }

    #[test]
    fn test_join() {
        let cst = parse("SELECT a.id FROM a INNER JOIN b ON a.id = b.id");
        assert!(find_type(&cst, SegmentType::JoinClause).is_some());
        assert!(find_type(&cst, SegmentType::OnClause).is_some());
    }

    #[test]
    fn test_group_by_having_order_by() {
        let cst = parse(
            "SELECT dept, COUNT(*) FROM emp GROUP BY dept HAVING COUNT(*) > 5 ORDER BY dept ASC",
        );
        assert!(find_type(&cst, SegmentType::GroupByClause).is_some());
        assert!(find_type(&cst, SegmentType::HavingClause).is_some());
        assert!(find_type(&cst, SegmentType::OrderByClause).is_some());
    }

    #[test]
    fn test_insert_values() {
        let cst = parse("INSERT INTO users (name, email) VALUES ('Alice', 'a@b.com')");
        assert!(find_type(&cst, SegmentType::InsertStatement).is_some());
        assert!(find_type(&cst, SegmentType::ValuesClause).is_some());
    }

    #[test]
    fn test_update_set_where() {
        let cst = parse("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert!(find_type(&cst, SegmentType::UpdateStatement).is_some());
        assert!(find_type(&cst, SegmentType::SetClause).is_some());
        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
    }

    #[test]
    fn test_delete() {
        let cst = parse("DELETE FROM users WHERE id = 1");
        assert!(find_type(&cst, SegmentType::DeleteStatement).is_some());
    }

    #[test]
    fn test_create_table() {
        let cst = parse("CREATE TABLE users (id INT, name VARCHAR(100))");
        assert!(find_type(&cst, SegmentType::CreateTableStatement).is_some());
    }

    #[test]
    fn test_with_cte() {
        let cst =
            parse("WITH active AS (SELECT * FROM users WHERE active = TRUE) SELECT * FROM active");
        assert!(find_type(&cst, SegmentType::WithClause).is_some());
        assert!(find_type(&cst, SegmentType::CteDefinition).is_some());
    }

    #[test]
    fn test_case_expression() {
        let cst = parse("SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t");
        assert!(find_type(&cst, SegmentType::CaseExpression).is_some());
        assert!(find_type(&cst, SegmentType::WhenClause).is_some());
        assert!(find_type(&cst, SegmentType::ElseClause).is_some());
    }

    #[test]
    fn test_subquery() {
        let cst = parse("SELECT * FROM (SELECT 1) AS sub");
        assert!(find_type(&cst, SegmentType::Subquery).is_some());
    }

    #[test]
    fn test_function_call() {
        let cst = parse("SELECT COUNT(*) FROM users");
        assert!(find_type(&cst, SegmentType::FunctionCall).is_some());
    }

    #[test]
    fn test_roundtrip() {
        let sql = "SELECT a, b FROM t WHERE x = 1 ORDER BY a;";
        let cst = parse(sql);
        assert_eq!(
            cst.raw(),
            sql,
            "CST roundtrip must preserve source text exactly"
        );
    }

    #[test]
    fn test_multiple_statements() {
        let sql = "SELECT 1; SELECT 2;";
        let cst = parse(sql);
        assert_eq!(top_level_statements(&cst), 2);
    }

    #[test]
    fn test_roundtrip_complex() {
        let sql = "WITH cte AS (\n  SELECT id, name\n  FROM users\n  WHERE active = TRUE\n)\nSELECT cte.id, cte.name\nFROM cte\nINNER JOIN orders ON cte.id = orders.user_id\nWHERE orders.total > 100\nORDER BY cte.name ASC\nLIMIT 10;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
    }

    // ── TSQL Tests ──────────────────────────────────────────────

    #[test]
    fn test_tsql_declare_variable() {
        let cst = parse_tsql("DECLARE @id INT;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
        assert_eq!(cst.raw(), "DECLARE @id INT;");
    }

    #[test]
    fn test_tsql_declare_with_default() {
        let cst = parse_tsql("DECLARE @name VARCHAR(100) = 'test';");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
    }

    #[test]
    fn test_tsql_declare_multiple() {
        let cst = parse_tsql("DECLARE @a INT, @b VARCHAR(50);");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
        assert_eq!(cst.raw(), "DECLARE @a INT, @b VARCHAR(50);");
    }

    #[test]
    fn test_tsql_declare_table_variable() {
        let cst = parse_tsql("DECLARE @t TABLE (id INT, name VARCHAR(100));");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
    }

    #[test]
    fn test_tsql_declare_cursor() {
        let cst = parse_tsql("DECLARE cur CURSOR FOR SELECT id FROM users;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
        assert!(find_type(&cst, SegmentType::SelectStatement).is_some());
    }

    #[test]
    fn test_tsql_set_variable() {
        let cst = parse_tsql("SET @id = 42;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::SetVariableStatement).is_some());
        assert_eq!(cst.raw(), "SET @id = 42;");
    }

    #[test]
    fn test_tsql_set_option() {
        let cst = parse_tsql("SET NOCOUNT ON;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::SetVariableStatement).is_some());
    }

    #[test]
    fn test_tsql_if_else() {
        let sql = "IF @x > 0\n    SELECT 1;\nELSE\n    SELECT 2;";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::IfStatement).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_if_begin_end() {
        let sql = "IF @x > 0\nBEGIN\n    SELECT 1;\n    SELECT 2;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::IfStatement).is_some());
        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
    }

    #[test]
    fn test_tsql_begin_end() {
        let sql = "BEGIN\n    SELECT 1;\n    SELECT 2;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
    }

    #[test]
    fn test_tsql_while() {
        let sql = "WHILE @i < 10\nBEGIN\n    SET @i = @i + 1;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::WhileStatement).is_some());
        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
    }

    #[test]
    fn test_tsql_try_catch() {
        let sql = "BEGIN TRY\n    SELECT 1;\nEND TRY\nBEGIN CATCH\n    SELECT 2;\nEND CATCH";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TryCatchBlock).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_exec_simple() {
        let cst = parse_tsql("EXEC sp_help;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
    }

    #[test]
    fn test_tsql_exec_with_params() {
        let cst = parse_tsql("EXEC dbo.usp_GetUser @id = 1, @name = 'test';");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
    }

    #[test]
    fn test_tsql_execute_keyword() {
        let cst = parse_tsql("EXECUTE sp_help;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
    }

    #[test]
    fn test_tsql_return() {
        let cst = parse_tsql("RETURN 0;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ReturnStatement).is_some());
    }

    #[test]
    fn test_tsql_return_no_value() {
        let cst = parse_tsql("RETURN;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ReturnStatement).is_some());
    }

    #[test]
    fn test_tsql_print() {
        let cst = parse_tsql("PRINT 'hello';");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::PrintStatement).is_some());
    }

    #[test]
    fn test_tsql_throw() {
        let cst = parse_tsql("THROW 50000, 'Error occurred', 1;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ThrowStatement).is_some());
    }

    #[test]
    fn test_tsql_throw_rethrow() {
        let cst = parse_tsql("THROW;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ThrowStatement).is_some());
    }

    #[test]
    fn test_tsql_raiserror() {
        let cst = parse_tsql("RAISERROR('Error', 16, 1);");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::RaiserrorStatement).is_some());
    }

    #[test]
    fn test_tsql_raiserror_with_nowait() {
        let cst = parse_tsql("RAISERROR('Error', 16, 1) WITH NOWAIT;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::RaiserrorStatement).is_some());
    }

    #[test]
    fn test_tsql_go() {
        let cst = parse_tsql("SELECT 1;\nGO");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::GoStatement).is_some());
    }

    #[test]
    fn test_tsql_go_with_count() {
        let cst = parse_tsql("GO 5");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::GoStatement).is_some());
    }

    #[test]
    fn test_tsql_simple_statements() {
        let cst = parse_tsql("USE master;");
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), "USE master;");
    }

    #[test]
    fn test_tsql_roundtrip_complex() {
        let sql = "SET NOCOUNT ON;\nDECLARE @id INT = 1;\nIF @id > 0\nBEGIN\n    SELECT @id;\n    PRINT 'done';\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_nested_begin_end() {
        let sql = "BEGIN\n    BEGIN\n        SELECT 1;\n    END\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_if_else_begin_end() {
        let sql = "IF @x = 1\nBEGIN\n    SELECT 1;\nEND\nELSE\nBEGIN\n    SELECT 2;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::IfStatement).is_some());
    }

    #[test]
    fn test_tsql_try_catch_with_throw() {
        let sql = "BEGIN TRY\n    SELECT 1;\nEND TRY\nBEGIN CATCH\n    THROW;\nEND CATCH";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TryCatchBlock).is_some());
        assert!(find_type(&cst, SegmentType::ThrowStatement).is_some());
    }

    #[test]
    fn test_tsql_case_inside_begin_end() {
        let sql = "BEGIN\n    SELECT CASE WHEN @x > 0 THEN 'pos' ELSE 'neg' END;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
        assert!(find_type(&cst, SegmentType::CaseExpression).is_some());
    }

    #[test]
    fn test_tsql_exec_retval() {
        let cst = parse_tsql("EXEC @result = dbo.usp_Calculate;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
    }

    #[test]
    fn test_tsql_multiple_set_options() {
        let sql = "SET ANSI_NULLS ON;\nSET QUOTED_IDENTIFIER ON;";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), sql);
    }

    // ── TSQL Table Hint Tests ────────────────────────────────────

    #[test]
    fn test_tsql_with_nolock() {
        let sql = "SELECT * FROM orders WITH(NOLOCK) WHERE id = 1";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TableHint).is_some());
        assert!(find_type(&cst, SegmentType::FromClause).is_some());
        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_with_nolock_alias() {
        let sql = "SELECT o.id FROM orders o WITH(NOLOCK)";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TableHint).is_some());
        assert!(find_type(&cst, SegmentType::AliasExpression).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_with_nolock_join() {
        let sql = "SELECT a.id FROM orders a WITH(NOLOCK) INNER JOIN items b WITH(READUNCOMMITTED) ON a.id = b.order_id";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        // One hint on each side of the join.
        assert_eq!(count_type(&cst, SegmentType::TableHint), 2);
        assert!(find_type(&cst, SegmentType::JoinClause).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_with_multiple_hints() {
        let sql = "SELECT * FROM orders WITH(NOLOCK, NOWAIT)";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TableHint).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_with_nolock_roundtrip() {
        let sql = "SELECT o.id, o.total\nFROM orders o WITH(NOLOCK)\nINNER JOIN customers c WITH(NOLOCK) ON o.customer_id = c.id\nWHERE c.active = 1\nORDER BY o.id";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), sql);
    }

    // ── Error Recovery Tests ──────────────────────────────────────

    #[test]
    fn test_error_recovery_garbage_then_valid() {
        // Garbage tokens followed by a valid statement
        let sql = "XYZZY FOOBAR; SELECT 1;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql, "roundtrip must preserve source");
        // The garbage should be in one Unparsable node
        assert_eq!(count_unparsable(&cst), 1);
        // The valid SELECT should still parse
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_error_recovery_garbage_between_statements() {
        // Valid, garbage, valid
        let sql = "SELECT 1; NOTAKEYWORD 123 'abc'; SELECT 2;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        assert_eq!(count_unparsable(&cst), 1);
        assert_eq!(top_level_statements(&cst), 2);
    }

    #[test]
    fn test_error_recovery_garbage_at_end() {
        let sql = "SELECT 1; XYZZY";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        assert_eq!(count_unparsable(&cst), 1);
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_error_recovery_skips_to_statement_keyword() {
        // Garbage followed directly by SELECT (no semicolon separator)
        let sql = "XYZZY SELECT 1;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        assert_eq!(count_unparsable(&cst), 1);
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_error_recovery_diagnostics() {
        let parser = Parser::default();
        let result = parser.parse_with_diagnostics("XYZZY; SELECT 1;").unwrap();
        assert!(!result.diagnostics.is_empty());
        assert!(result.diagnostics[0].message.contains("Unparsable"));
        // Offset should point to the start of the unparsable region (byte 0 = 'X')
        assert_eq!(result.diagnostics[0].offset, 0);
        // CST still produced
        assert!(find_type(&result.tree, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_error_recovery_diagnostics_offset_mid_file() {
        let parser = Parser::default();
        // "SELECT 1; " = 10 bytes, then garbage starts
        let result = parser
            .parse_with_diagnostics("SELECT 1; BADTOKEN;")
            .unwrap();
        assert_eq!(result.diagnostics.len(), 1);
        // Offset should point to 'B' in BADTOKEN, not to ';' or beyond
        assert_eq!(result.diagnostics[0].offset, 10);
    }

    #[test]
    fn test_error_recovery_all_garbage() {
        let sql = "NOTAKEYWORD 123 'hello'";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        // Everything should be unparsable but still present
        assert!(count_unparsable(&cst) >= 1);
    }

    #[test]
    fn test_error_recovery_preserves_valid_statements() {
        // Multiple valid statements with garbage in the middle
        let sql = "INSERT INTO t VALUES (1); BADTOKEN; DELETE FROM t WHERE id = 1;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        assert!(find_type(&cst, SegmentType::InsertStatement).is_some());
        assert!(find_type(&cst, SegmentType::DeleteStatement).is_some());
        assert_eq!(count_unparsable(&cst), 1);
    }

    // ── PostgreSQL tests ────────────────────────────────────────────

    #[test]
    fn test_pg_double_colon_cast() {
        let cst = parse_pg("SELECT col::int FROM t");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TypeCastExpression).is_some());
        assert_eq!(cst.raw(), "SELECT col::int FROM t");
    }

    #[test]
    fn test_pg_chained_cast() {
        let cst = parse_pg("SELECT '2024-01-01'::date::text FROM t");
        assert_no_unparsable(&cst);
        // Two nested TypeCastExpression
        assert_eq!(
            count_type(&cst, SegmentType::TypeCastExpression),
            2,
            "Expected 2 TypeCastExpression nodes for chained cast"
        );
    }

    #[test]
    fn test_pg_cast_with_precision() {
        let cst = parse_pg("SELECT col::numeric(10, 2) FROM t");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TypeCastExpression).is_some());
        assert!(find_type(&cst, SegmentType::DataType).is_some());
    }

    #[test]
    fn test_pg_array_subscript() {
        let cst = parse_pg("SELECT arr[1] FROM t");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ArrayAccessExpression).is_some());
    }

    #[test]
    fn test_pg_array_cast_chain() {
        let cst = parse_pg("SELECT arr[1]::text FROM t");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ArrayAccessExpression).is_some());
        assert!(find_type(&cst, SegmentType::TypeCastExpression).is_some());
    }

    #[test]
    fn test_pg_insert_returning() {
        let cst = parse_pg("INSERT INTO users (name) VALUES ('Alice') RETURNING id, name");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::InsertStatement).is_some());
        assert!(find_type(&cst, SegmentType::ReturningClause).is_some());
    }

    #[test]
    fn test_pg_update_returning() {
        let cst = parse_pg("UPDATE users SET name = 'Bob' WHERE id = 1 RETURNING *");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::UpdateStatement).is_some());
        assert!(find_type(&cst, SegmentType::ReturningClause).is_some());
    }

    #[test]
    fn test_pg_delete_returning() {
        let cst = parse_pg("DELETE FROM users WHERE id = 1 RETURNING id");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeleteStatement).is_some());
        assert!(find_type(&cst, SegmentType::ReturningClause).is_some());
    }

    #[test]
    fn test_pg_on_conflict_do_nothing() {
        let cst = parse_pg(
            "INSERT INTO users (id, name) VALUES (1, 'Alice') ON CONFLICT (id) DO NOTHING",
        );
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::OnConflictClause).is_some());
    }

    #[test]
    fn test_pg_on_conflict_do_update() {
        let cst = parse_pg(
            "INSERT INTO users (id, name) VALUES (1, 'Alice') \
             ON CONFLICT (id) DO UPDATE SET name = 'Alice'",
        );
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::OnConflictClause).is_some());
        assert!(find_type(&cst, SegmentType::SetClause).is_some());
    }

    #[test]
    fn test_pg_upsert_returning() {
        let cst = parse_pg(
            "INSERT INTO users (id, name) VALUES (1, 'Alice') \
             ON CONFLICT (id) DO UPDATE SET name = 'Alice' RETURNING *",
        );
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::OnConflictClause).is_some());
        assert!(find_type(&cst, SegmentType::ReturningClause).is_some());
    }

    #[test]
    fn test_pg_distinct_on() {
        let cst = parse_pg(
            "SELECT DISTINCT ON (dept) name, salary FROM employees ORDER BY dept, salary DESC",
        );
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
        assert!(find_type(&cst, SegmentType::OrderByClause).is_some());
    }

    #[test]
    fn test_pg_dollar_quoted_string() {
        let cst = parse_pg("SELECT $$hello world$$");
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), "SELECT $$hello world$$");
    }

    #[test]
    fn test_pg_ilike() {
        let cst = parse_pg("SELECT * FROM users WHERE name ILIKE '%alice%'");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::LikeExpression).is_some());
    }

    #[test]
    fn test_pg_roundtrip_complex() {
        let sql = "INSERT INTO orders (user_id, total) \
                   VALUES (1, 99.99) \
                   ON CONFLICT (user_id) DO UPDATE SET total = orders.total + 99.99 \
                   RETURNING id, total::numeric(10, 2)";
        let cst = parse_pg(sql);
        assert_eq!(
            cst.raw(),
            sql,
            "CST roundtrip must preserve source text exactly"
        );
        assert_no_unparsable(&cst);
    }
}