Skip to main content

rigsql_parser/
parser.rs

1use rigsql_core::Segment;
2use rigsql_lexer::{Lexer, LexerConfig, LexerError};
3use thiserror::Error;
4
5use crate::context::{ParseContext, ParseDiagnostic};
6#[cfg(test)]
7use crate::grammar::TsqlGrammar;
8use crate::grammar::{AnsiGrammar, Grammar};
9
10#[derive(Debug, Error)]
11pub enum ParseError {
12    #[error("Lexer error: {0}")]
13    Lexer(#[from] LexerError),
14}
15
16/// Result of parsing: a CST (always produced) plus any diagnostics
17/// collected during error-recovery passes.
18pub struct ParseResult {
19    /// The concrete syntax tree.  Always present — unparsable regions
20    /// are wrapped in `SegmentType::Unparsable` nodes.
21    pub tree: Segment,
22    /// Diagnostics emitted by the parser when it encountered
23    /// unrecognised tokens and had to skip ahead.
24    pub diagnostics: Vec<ParseDiagnostic>,
25}
26
27/// High-level SQL parser: source text → CST.
28pub struct Parser {
29    lexer_config: LexerConfig,
30    grammar: Box<dyn Grammar>,
31}
32
33impl Parser {
34    pub fn new(lexer_config: LexerConfig, grammar: Box<dyn Grammar>) -> Self {
35        Self {
36            lexer_config,
37            grammar,
38        }
39    }
40
41    /// Parse SQL source into a CST rooted at a File segment.
42    pub fn parse(&self, source: &str) -> Result<Segment, ParseError> {
43        self.parse_with_diagnostics(source).map(|r| r.tree)
44    }
45
46    /// Parse SQL source, returning both the CST and any diagnostics
47    /// produced during error recovery.
48    pub fn parse_with_diagnostics(&self, source: &str) -> Result<ParseResult, ParseError> {
49        let mut lexer = Lexer::new(source, self.lexer_config.clone());
50        let tokens = lexer.tokenize()?;
51        let mut ctx = ParseContext::new(&tokens, source);
52        let tree = self.grammar.parse_file(&mut ctx);
53        let diagnostics = ctx.take_diagnostics();
54        Ok(ParseResult { tree, diagnostics })
55    }
56}
57
58impl Default for Parser {
59    fn default() -> Self {
60        Self::new(LexerConfig::ansi(), Box::new(AnsiGrammar))
61    }
62}
63
64#[cfg(test)]
65mod tests {
66    use super::*;
67    use rigsql_core::SegmentType;
68
69    fn parse(sql: &str) -> Segment {
70        Parser::default().parse(sql).unwrap()
71    }
72
73    fn parse_tsql(sql: &str) -> Segment {
74        Parser::new(LexerConfig::tsql(), Box::new(TsqlGrammar))
75            .parse(sql)
76            .unwrap()
77    }
78
79    fn assert_type(seg: &Segment, expected: SegmentType) {
80        assert_eq!(
81            seg.segment_type(),
82            expected,
83            "Expected {:?} but got {:?} for raw: {:?}",
84            expected,
85            seg.segment_type(),
86            seg.raw()
87        );
88    }
89
90    fn find_type(seg: &Segment, ty: SegmentType) -> Option<&Segment> {
91        let mut result = None;
92        seg.walk(&mut |s| {
93            if result.is_none() && s.segment_type() == ty {
94                result = Some(s as *const Segment);
95            }
96        });
97        result.map(|p| unsafe { &*p })
98    }
99
100    fn assert_no_unparsable(seg: &Segment) {
101        let mut unparsable = Vec::new();
102        seg.walk(&mut |s| {
103            if s.segment_type() == SegmentType::Unparsable {
104                unparsable.push(s.raw());
105            }
106        });
107        assert!(
108            unparsable.is_empty(),
109            "Found Unparsable segments: {:?}",
110            unparsable
111        );
112    }
113
114    #[test]
115    fn test_simple_select() {
116        let cst = parse("SELECT 1");
117        assert_type(&cst, SegmentType::File);
118        let stmt = &cst.children()[0];
119        assert_type(stmt, SegmentType::Statement);
120        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
121    }
122
123    #[test]
124    fn test_select_from_where() {
125        let cst = parse("SELECT name FROM users WHERE id = 1");
126        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
127        assert!(find_type(&cst, SegmentType::FromClause).is_some());
128        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
129    }
130
131    #[test]
132    fn test_join() {
133        let cst = parse("SELECT a.id FROM a INNER JOIN b ON a.id = b.id");
134        assert!(find_type(&cst, SegmentType::JoinClause).is_some());
135        assert!(find_type(&cst, SegmentType::OnClause).is_some());
136    }
137
138    #[test]
139    fn test_group_by_having_order_by() {
140        let cst = parse(
141            "SELECT dept, COUNT(*) FROM emp GROUP BY dept HAVING COUNT(*) > 5 ORDER BY dept ASC",
142        );
143        assert!(find_type(&cst, SegmentType::GroupByClause).is_some());
144        assert!(find_type(&cst, SegmentType::HavingClause).is_some());
145        assert!(find_type(&cst, SegmentType::OrderByClause).is_some());
146    }
147
148    #[test]
149    fn test_insert_values() {
150        let cst = parse("INSERT INTO users (name, email) VALUES ('Alice', 'a@b.com')");
151        assert!(find_type(&cst, SegmentType::InsertStatement).is_some());
152        assert!(find_type(&cst, SegmentType::ValuesClause).is_some());
153    }
154
155    #[test]
156    fn test_update_set_where() {
157        let cst = parse("UPDATE users SET name = 'Bob' WHERE id = 1");
158        assert!(find_type(&cst, SegmentType::UpdateStatement).is_some());
159        assert!(find_type(&cst, SegmentType::SetClause).is_some());
160        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
161    }
162
163    #[test]
164    fn test_delete() {
165        let cst = parse("DELETE FROM users WHERE id = 1");
166        assert!(find_type(&cst, SegmentType::DeleteStatement).is_some());
167    }
168
169    #[test]
170    fn test_create_table() {
171        let cst = parse("CREATE TABLE users (id INT, name VARCHAR(100))");
172        assert!(find_type(&cst, SegmentType::CreateTableStatement).is_some());
173    }
174
175    #[test]
176    fn test_with_cte() {
177        let cst =
178            parse("WITH active AS (SELECT * FROM users WHERE active = TRUE) SELECT * FROM active");
179        assert!(find_type(&cst, SegmentType::WithClause).is_some());
180        assert!(find_type(&cst, SegmentType::CteDefinition).is_some());
181    }
182
183    #[test]
184    fn test_case_expression() {
185        let cst = parse("SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t");
186        assert!(find_type(&cst, SegmentType::CaseExpression).is_some());
187        assert!(find_type(&cst, SegmentType::WhenClause).is_some());
188        assert!(find_type(&cst, SegmentType::ElseClause).is_some());
189    }
190
191    #[test]
192    fn test_subquery() {
193        let cst = parse("SELECT * FROM (SELECT 1) AS sub");
194        assert!(find_type(&cst, SegmentType::Subquery).is_some());
195    }
196
197    #[test]
198    fn test_function_call() {
199        let cst = parse("SELECT COUNT(*) FROM users");
200        assert!(find_type(&cst, SegmentType::FunctionCall).is_some());
201    }
202
203    #[test]
204    fn test_roundtrip() {
205        let sql = "SELECT a, b FROM t WHERE x = 1 ORDER BY a;";
206        let cst = parse(sql);
207        assert_eq!(
208            cst.raw(),
209            sql,
210            "CST roundtrip must preserve source text exactly"
211        );
212    }
213
214    #[test]
215    fn test_multiple_statements() {
216        let sql = "SELECT 1; SELECT 2;";
217        let cst = parse(sql);
218        let stmts: Vec<_> = cst
219            .children()
220            .iter()
221            .filter(|s| s.segment_type() == SegmentType::Statement)
222            .collect();
223        assert_eq!(stmts.len(), 2);
224    }
225
226    #[test]
227    fn test_roundtrip_complex() {
228        let sql = "WITH cte AS (\n  SELECT id, name\n  FROM users\n  WHERE active = TRUE\n)\nSELECT cte.id, cte.name\nFROM cte\nINNER JOIN orders ON cte.id = orders.user_id\nWHERE orders.total > 100\nORDER BY cte.name ASC\nLIMIT 10;";
229        let cst = parse(sql);
230        assert_eq!(cst.raw(), sql);
231    }
232
233    // ── TSQL Tests ──────────────────────────────────────────────
234
235    #[test]
236    fn test_tsql_declare_variable() {
237        let cst = parse_tsql("DECLARE @id INT;");
238        assert_no_unparsable(&cst);
239        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
240        assert_eq!(cst.raw(), "DECLARE @id INT;");
241    }
242
243    #[test]
244    fn test_tsql_declare_with_default() {
245        let cst = parse_tsql("DECLARE @name VARCHAR(100) = 'test';");
246        assert_no_unparsable(&cst);
247        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
248    }
249
250    #[test]
251    fn test_tsql_declare_multiple() {
252        let cst = parse_tsql("DECLARE @a INT, @b VARCHAR(50);");
253        assert_no_unparsable(&cst);
254        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
255        assert_eq!(cst.raw(), "DECLARE @a INT, @b VARCHAR(50);");
256    }
257
258    #[test]
259    fn test_tsql_declare_table_variable() {
260        let cst = parse_tsql("DECLARE @t TABLE (id INT, name VARCHAR(100));");
261        assert_no_unparsable(&cst);
262        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
263    }
264
265    #[test]
266    fn test_tsql_declare_cursor() {
267        let cst = parse_tsql("DECLARE cur CURSOR FOR SELECT id FROM users;");
268        assert_no_unparsable(&cst);
269        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
270        assert!(find_type(&cst, SegmentType::SelectStatement).is_some());
271    }
272
273    #[test]
274    fn test_tsql_set_variable() {
275        let cst = parse_tsql("SET @id = 42;");
276        assert_no_unparsable(&cst);
277        assert!(find_type(&cst, SegmentType::SetVariableStatement).is_some());
278        assert_eq!(cst.raw(), "SET @id = 42;");
279    }
280
281    #[test]
282    fn test_tsql_set_option() {
283        let cst = parse_tsql("SET NOCOUNT ON;");
284        assert_no_unparsable(&cst);
285        assert!(find_type(&cst, SegmentType::SetVariableStatement).is_some());
286    }
287
288    #[test]
289    fn test_tsql_if_else() {
290        let sql = "IF @x > 0\n    SELECT 1;\nELSE\n    SELECT 2;";
291        let cst = parse_tsql(sql);
292        assert_no_unparsable(&cst);
293        assert!(find_type(&cst, SegmentType::IfStatement).is_some());
294        assert_eq!(cst.raw(), sql);
295    }
296
297    #[test]
298    fn test_tsql_if_begin_end() {
299        let sql = "IF @x > 0\nBEGIN\n    SELECT 1;\n    SELECT 2;\nEND";
300        let cst = parse_tsql(sql);
301        assert_no_unparsable(&cst);
302        assert!(find_type(&cst, SegmentType::IfStatement).is_some());
303        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
304    }
305
306    #[test]
307    fn test_tsql_begin_end() {
308        let sql = "BEGIN\n    SELECT 1;\n    SELECT 2;\nEND";
309        let cst = parse_tsql(sql);
310        assert_no_unparsable(&cst);
311        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
312    }
313
314    #[test]
315    fn test_tsql_while() {
316        let sql = "WHILE @i < 10\nBEGIN\n    SET @i = @i + 1;\nEND";
317        let cst = parse_tsql(sql);
318        assert_no_unparsable(&cst);
319        assert!(find_type(&cst, SegmentType::WhileStatement).is_some());
320        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
321    }
322
323    #[test]
324    fn test_tsql_try_catch() {
325        let sql = "BEGIN TRY\n    SELECT 1;\nEND TRY\nBEGIN CATCH\n    SELECT 2;\nEND CATCH";
326        let cst = parse_tsql(sql);
327        assert_no_unparsable(&cst);
328        assert!(find_type(&cst, SegmentType::TryCatchBlock).is_some());
329        assert_eq!(cst.raw(), sql);
330    }
331
332    #[test]
333    fn test_tsql_exec_simple() {
334        let cst = parse_tsql("EXEC sp_help;");
335        assert_no_unparsable(&cst);
336        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
337    }
338
339    #[test]
340    fn test_tsql_exec_with_params() {
341        let cst = parse_tsql("EXEC dbo.usp_GetUser @id = 1, @name = 'test';");
342        assert_no_unparsable(&cst);
343        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
344    }
345
346    #[test]
347    fn test_tsql_execute_keyword() {
348        let cst = parse_tsql("EXECUTE sp_help;");
349        assert_no_unparsable(&cst);
350        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
351    }
352
353    #[test]
354    fn test_tsql_return() {
355        let cst = parse_tsql("RETURN 0;");
356        assert_no_unparsable(&cst);
357        assert!(find_type(&cst, SegmentType::ReturnStatement).is_some());
358    }
359
360    #[test]
361    fn test_tsql_return_no_value() {
362        let cst = parse_tsql("RETURN;");
363        assert_no_unparsable(&cst);
364        assert!(find_type(&cst, SegmentType::ReturnStatement).is_some());
365    }
366
367    #[test]
368    fn test_tsql_print() {
369        let cst = parse_tsql("PRINT 'hello';");
370        assert_no_unparsable(&cst);
371        assert!(find_type(&cst, SegmentType::PrintStatement).is_some());
372    }
373
374    #[test]
375    fn test_tsql_throw() {
376        let cst = parse_tsql("THROW 50000, 'Error occurred', 1;");
377        assert_no_unparsable(&cst);
378        assert!(find_type(&cst, SegmentType::ThrowStatement).is_some());
379    }
380
381    #[test]
382    fn test_tsql_throw_rethrow() {
383        let cst = parse_tsql("THROW;");
384        assert_no_unparsable(&cst);
385        assert!(find_type(&cst, SegmentType::ThrowStatement).is_some());
386    }
387
388    #[test]
389    fn test_tsql_raiserror() {
390        let cst = parse_tsql("RAISERROR('Error', 16, 1);");
391        assert_no_unparsable(&cst);
392        assert!(find_type(&cst, SegmentType::RaiserrorStatement).is_some());
393    }
394
395    #[test]
396    fn test_tsql_raiserror_with_nowait() {
397        let cst = parse_tsql("RAISERROR('Error', 16, 1) WITH NOWAIT;");
398        assert_no_unparsable(&cst);
399        assert!(find_type(&cst, SegmentType::RaiserrorStatement).is_some());
400    }
401
402    #[test]
403    fn test_tsql_go() {
404        let cst = parse_tsql("SELECT 1;\nGO");
405        assert_no_unparsable(&cst);
406        assert!(find_type(&cst, SegmentType::GoStatement).is_some());
407    }
408
409    #[test]
410    fn test_tsql_go_with_count() {
411        let cst = parse_tsql("GO 5");
412        assert_no_unparsable(&cst);
413        assert!(find_type(&cst, SegmentType::GoStatement).is_some());
414    }
415
416    #[test]
417    fn test_tsql_simple_statements() {
418        let cst = parse_tsql("USE master;");
419        assert_no_unparsable(&cst);
420        assert_eq!(cst.raw(), "USE master;");
421    }
422
423    #[test]
424    fn test_tsql_roundtrip_complex() {
425        let sql = "SET NOCOUNT ON;\nDECLARE @id INT = 1;\nIF @id > 0\nBEGIN\n    SELECT @id;\n    PRINT 'done';\nEND";
426        let cst = parse_tsql(sql);
427        assert_no_unparsable(&cst);
428        assert_eq!(cst.raw(), sql);
429    }
430
431    #[test]
432    fn test_tsql_nested_begin_end() {
433        let sql = "BEGIN\n    BEGIN\n        SELECT 1;\n    END\nEND";
434        let cst = parse_tsql(sql);
435        assert_no_unparsable(&cst);
436        assert_eq!(cst.raw(), sql);
437    }
438
439    #[test]
440    fn test_tsql_if_else_begin_end() {
441        let sql = "IF @x = 1\nBEGIN\n    SELECT 1;\nEND\nELSE\nBEGIN\n    SELECT 2;\nEND";
442        let cst = parse_tsql(sql);
443        assert_no_unparsable(&cst);
444        assert!(find_type(&cst, SegmentType::IfStatement).is_some());
445    }
446
447    #[test]
448    fn test_tsql_try_catch_with_throw() {
449        let sql = "BEGIN TRY\n    SELECT 1;\nEND TRY\nBEGIN CATCH\n    THROW;\nEND CATCH";
450        let cst = parse_tsql(sql);
451        assert_no_unparsable(&cst);
452        assert!(find_type(&cst, SegmentType::TryCatchBlock).is_some());
453        assert!(find_type(&cst, SegmentType::ThrowStatement).is_some());
454    }
455
456    #[test]
457    fn test_tsql_case_inside_begin_end() {
458        let sql = "BEGIN\n    SELECT CASE WHEN @x > 0 THEN 'pos' ELSE 'neg' END;\nEND";
459        let cst = parse_tsql(sql);
460        assert_no_unparsable(&cst);
461        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
462        assert!(find_type(&cst, SegmentType::CaseExpression).is_some());
463    }
464
465    #[test]
466    fn test_tsql_exec_retval() {
467        let cst = parse_tsql("EXEC @result = dbo.usp_Calculate;");
468        assert_no_unparsable(&cst);
469        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
470    }
471
472    #[test]
473    fn test_tsql_multiple_set_options() {
474        let sql = "SET ANSI_NULLS ON;\nSET QUOTED_IDENTIFIER ON;";
475        let cst = parse_tsql(sql);
476        assert_no_unparsable(&cst);
477        assert_eq!(cst.raw(), sql);
478    }
479
480    // ── Error Recovery Tests ──────────────────────────────────────
481
482    fn count_unparsable(seg: &Segment) -> usize {
483        let mut count = 0;
484        seg.walk(&mut |s| {
485            if s.segment_type() == SegmentType::Unparsable {
486                count += 1;
487            }
488        });
489        count
490    }
491
492    #[test]
493    fn test_error_recovery_garbage_then_valid() {
494        // Garbage tokens followed by a valid statement
495        let sql = "XYZZY FOOBAR; SELECT 1;";
496        let cst = parse(sql);
497        assert_eq!(cst.raw(), sql, "roundtrip must preserve source");
498        // The garbage should be in one Unparsable node
499        assert_eq!(count_unparsable(&cst), 1);
500        // The valid SELECT should still parse
501        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
502    }
503
504    #[test]
505    fn test_error_recovery_garbage_between_statements() {
506        // Valid, garbage, valid
507        let sql = "SELECT 1; NOTAKEYWORD 123 'abc'; SELECT 2;";
508        let cst = parse(sql);
509        assert_eq!(cst.raw(), sql);
510        assert_eq!(count_unparsable(&cst), 1);
511        let stmts: Vec<_> = cst
512            .children()
513            .iter()
514            .filter(|s| s.segment_type() == SegmentType::Statement)
515            .collect();
516        assert_eq!(stmts.len(), 2);
517    }
518
519    #[test]
520    fn test_error_recovery_garbage_at_end() {
521        let sql = "SELECT 1; XYZZY";
522        let cst = parse(sql);
523        assert_eq!(cst.raw(), sql);
524        assert_eq!(count_unparsable(&cst), 1);
525        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
526    }
527
528    #[test]
529    fn test_error_recovery_skips_to_statement_keyword() {
530        // Garbage followed directly by SELECT (no semicolon separator)
531        let sql = "XYZZY SELECT 1;";
532        let cst = parse(sql);
533        assert_eq!(cst.raw(), sql);
534        assert_eq!(count_unparsable(&cst), 1);
535        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
536    }
537
538    #[test]
539    fn test_error_recovery_diagnostics() {
540        let parser = Parser::default();
541        let result = parser.parse_with_diagnostics("XYZZY; SELECT 1;").unwrap();
542        assert!(!result.diagnostics.is_empty());
543        assert!(result.diagnostics[0].message.contains("Unparsable"));
544        // Offset should point to the start of the unparsable region (byte 0 = 'X')
545        assert_eq!(result.diagnostics[0].offset, 0);
546        // CST still produced
547        assert!(find_type(&result.tree, SegmentType::SelectClause).is_some());
548    }
549
550    #[test]
551    fn test_error_recovery_diagnostics_offset_mid_file() {
552        let parser = Parser::default();
553        // "SELECT 1; " = 10 bytes, then garbage starts
554        let result = parser
555            .parse_with_diagnostics("SELECT 1; BADTOKEN;")
556            .unwrap();
557        assert_eq!(result.diagnostics.len(), 1);
558        // Offset should point to 'B' in BADTOKEN, not to ';' or beyond
559        assert_eq!(result.diagnostics[0].offset, 10);
560    }
561
562    #[test]
563    fn test_error_recovery_all_garbage() {
564        let sql = "NOTAKEYWORD 123 'hello'";
565        let cst = parse(sql);
566        assert_eq!(cst.raw(), sql);
567        // Everything should be unparsable but still present
568        assert!(count_unparsable(&cst) >= 1);
569    }
570
571    #[test]
572    fn test_error_recovery_preserves_valid_statements() {
573        // Multiple valid statements with garbage in the middle
574        let sql = "INSERT INTO t VALUES (1); BADTOKEN; DELETE FROM t WHERE id = 1;";
575        let cst = parse(sql);
576        assert_eq!(cst.raw(), sql);
577        assert!(find_type(&cst, SegmentType::InsertStatement).is_some());
578        assert!(find_type(&cst, SegmentType::DeleteStatement).is_some());
579        assert_eq!(count_unparsable(&cst), 1);
580    }
581}