1use rigsql_core::Segment;
2use rigsql_lexer::{Lexer, LexerConfig, LexerError};
3use thiserror::Error;
4
5use crate::context::{ParseContext, ParseDiagnostic};
6#[cfg(test)]
7use crate::grammar::TsqlGrammar;
8use crate::grammar::{AnsiGrammar, Grammar};
9
10#[derive(Debug, Error)]
11pub enum ParseError {
12 #[error("Lexer error: {0}")]
13 Lexer(#[from] LexerError),
14}
15
16pub struct ParseResult {
19 pub tree: Segment,
22 pub diagnostics: Vec<ParseDiagnostic>,
25}
26
27pub struct Parser {
29 lexer_config: LexerConfig,
30 grammar: Box<dyn Grammar>,
31}
32
33impl Parser {
34 pub fn new(lexer_config: LexerConfig, grammar: Box<dyn Grammar>) -> Self {
35 Self {
36 lexer_config,
37 grammar,
38 }
39 }
40
41 pub fn parse(&self, source: &str) -> Result<Segment, ParseError> {
43 self.parse_with_diagnostics(source).map(|r| r.tree)
44 }
45
46 pub fn parse_with_diagnostics(&self, source: &str) -> Result<ParseResult, ParseError> {
49 let mut lexer = Lexer::new(source, self.lexer_config.clone());
50 let tokens = lexer.tokenize()?;
51 let mut ctx = ParseContext::new(&tokens, source);
52 let tree = self.grammar.parse_file(&mut ctx);
53 let diagnostics = ctx.take_diagnostics();
54 Ok(ParseResult { tree, diagnostics })
55 }
56}
57
58impl Default for Parser {
59 fn default() -> Self {
60 Self::new(LexerConfig::ansi(), Box::new(AnsiGrammar))
61 }
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use rigsql_core::SegmentType;

    // ── Helpers ─────────────────────────────────────────────────────────

    /// Parses `sql` with the default (ANSI) parser; panics on a lexer error.
    fn parse(sql: &str) -> Segment {
        Parser::default().parse(sql).unwrap()
    }

    /// Parses `sql` with the T-SQL lexer configuration and grammar; panics
    /// on a lexer error.
    fn parse_tsql(sql: &str) -> Segment {
        Parser::new(LexerConfig::tsql(), Box::new(TsqlGrammar))
            .parse(sql)
            .unwrap()
    }

    /// Asserts that `seg` has type `expected`, reporting the raw text on failure.
    fn assert_type(seg: &Segment, expected: SegmentType) {
        assert_eq!(
            seg.segment_type(),
            expected,
            "Expected {:?} but got {:?} for raw: {:?}",
            expected,
            seg.segment_type(),
            seg.raw()
        );
    }

    /// Returns the first segment of type `ty` found by a depth-first,
    /// pre-order search that starts at (and includes) `seg`.
    ///
    /// The search descends through `children()` so the returned reference
    /// borrows directly from `seg`; no raw pointers or `unsafe` are needed.
    ///
    /// NOTE(review): assumes `children()` reaches every node that `walk`
    /// visits (and is empty for leaves) — confirm against `rigsql_core`.
    fn find_type(seg: &Segment, ty: SegmentType) -> Option<&Segment> {
        // The inner helper takes `ty` by reference so the search does not
        // require `SegmentType: Copy`.
        fn search<'a>(seg: &'a Segment, ty: &SegmentType) -> Option<&'a Segment> {
            if seg.segment_type() == *ty {
                return Some(seg);
            }
            for child in seg.children() {
                if let Some(hit) = search(child, ty) {
                    return Some(hit);
                }
            }
            None
        }

        search(seg, &ty)
    }

    /// Asserts that no segment visited by `walk` is `Unparsable`, listing
    /// the raw text of every offender on failure.
    fn assert_no_unparsable(seg: &Segment) {
        let mut unparsable = Vec::new();
        seg.walk(&mut |s| {
            if s.segment_type() == SegmentType::Unparsable {
                unparsable.push(s.raw());
            }
        });
        assert!(
            unparsable.is_empty(),
            "Found Unparsable segments: {:?}",
            unparsable
        );
    }

    // ── ANSI grammar ────────────────────────────────────────────────────

    #[test]
    fn test_simple_select() {
        let cst = parse("SELECT 1");
        assert_type(&cst, SegmentType::File);
        let stmt = &cst.children()[0];
        assert_type(stmt, SegmentType::Statement);
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_select_from_where() {
        let cst = parse("SELECT name FROM users WHERE id = 1");
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
        assert!(find_type(&cst, SegmentType::FromClause).is_some());
        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
    }

    #[test]
    fn test_join() {
        let cst = parse("SELECT a.id FROM a INNER JOIN b ON a.id = b.id");
        assert!(find_type(&cst, SegmentType::JoinClause).is_some());
        assert!(find_type(&cst, SegmentType::OnClause).is_some());
    }

    #[test]
    fn test_group_by_having_order_by() {
        let cst = parse(
            "SELECT dept, COUNT(*) FROM emp GROUP BY dept HAVING COUNT(*) > 5 ORDER BY dept ASC",
        );
        assert!(find_type(&cst, SegmentType::GroupByClause).is_some());
        assert!(find_type(&cst, SegmentType::HavingClause).is_some());
        assert!(find_type(&cst, SegmentType::OrderByClause).is_some());
    }

    #[test]
    fn test_insert_values() {
        let cst = parse("INSERT INTO users (name, email) VALUES ('Alice', 'a@b.com')");
        assert!(find_type(&cst, SegmentType::InsertStatement).is_some());
        assert!(find_type(&cst, SegmentType::ValuesClause).is_some());
    }

    #[test]
    fn test_update_set_where() {
        let cst = parse("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert!(find_type(&cst, SegmentType::UpdateStatement).is_some());
        assert!(find_type(&cst, SegmentType::SetClause).is_some());
        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
    }

    #[test]
    fn test_delete() {
        let cst = parse("DELETE FROM users WHERE id = 1");
        assert!(find_type(&cst, SegmentType::DeleteStatement).is_some());
    }

    #[test]
    fn test_create_table() {
        let cst = parse("CREATE TABLE users (id INT, name VARCHAR(100))");
        assert!(find_type(&cst, SegmentType::CreateTableStatement).is_some());
    }

    #[test]
    fn test_with_cte() {
        let cst =
            parse("WITH active AS (SELECT * FROM users WHERE active = TRUE) SELECT * FROM active");
        assert!(find_type(&cst, SegmentType::WithClause).is_some());
        assert!(find_type(&cst, SegmentType::CteDefinition).is_some());
    }

    #[test]
    fn test_case_expression() {
        let cst = parse("SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t");
        assert!(find_type(&cst, SegmentType::CaseExpression).is_some());
        assert!(find_type(&cst, SegmentType::WhenClause).is_some());
        assert!(find_type(&cst, SegmentType::ElseClause).is_some());
    }

    #[test]
    fn test_subquery() {
        let cst = parse("SELECT * FROM (SELECT 1) AS sub");
        assert!(find_type(&cst, SegmentType::Subquery).is_some());
    }

    #[test]
    fn test_function_call() {
        let cst = parse("SELECT COUNT(*) FROM users");
        assert!(find_type(&cst, SegmentType::FunctionCall).is_some());
    }

    #[test]
    fn test_roundtrip() {
        let sql = "SELECT a, b FROM t WHERE x = 1 ORDER BY a;";
        let cst = parse(sql);
        assert_eq!(
            cst.raw(),
            sql,
            "CST roundtrip must preserve source text exactly"
        );
    }

    #[test]
    fn test_multiple_statements() {
        let sql = "SELECT 1; SELECT 2;";
        let cst = parse(sql);
        let stmts: Vec<_> = cst
            .children()
            .iter()
            .filter(|s| s.segment_type() == SegmentType::Statement)
            .collect();
        assert_eq!(stmts.len(), 2);
    }

    #[test]
    fn test_roundtrip_complex() {
        let sql = "WITH cte AS (\n SELECT id, name\n FROM users\n WHERE active = TRUE\n)\nSELECT cte.id, cte.name\nFROM cte\nINNER JOIN orders ON cte.id = orders.user_id\nWHERE orders.total > 100\nORDER BY cte.name ASC\nLIMIT 10;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
    }

    // ── T-SQL grammar ───────────────────────────────────────────────────

    #[test]
    fn test_tsql_declare_variable() {
        let cst = parse_tsql("DECLARE @id INT;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
        assert_eq!(cst.raw(), "DECLARE @id INT;");
    }

    #[test]
    fn test_tsql_declare_with_default() {
        let cst = parse_tsql("DECLARE @name VARCHAR(100) = 'test';");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
    }

    #[test]
    fn test_tsql_declare_multiple() {
        let cst = parse_tsql("DECLARE @a INT, @b VARCHAR(50);");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
        assert_eq!(cst.raw(), "DECLARE @a INT, @b VARCHAR(50);");
    }

    #[test]
    fn test_tsql_declare_table_variable() {
        let cst = parse_tsql("DECLARE @t TABLE (id INT, name VARCHAR(100));");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
    }

    #[test]
    fn test_tsql_declare_cursor() {
        let cst = parse_tsql("DECLARE cur CURSOR FOR SELECT id FROM users;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
        assert!(find_type(&cst, SegmentType::SelectStatement).is_some());
    }

    #[test]
    fn test_tsql_set_variable() {
        let cst = parse_tsql("SET @id = 42;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::SetVariableStatement).is_some());
        assert_eq!(cst.raw(), "SET @id = 42;");
    }

    #[test]
    fn test_tsql_set_option() {
        let cst = parse_tsql("SET NOCOUNT ON;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::SetVariableStatement).is_some());
    }

    #[test]
    fn test_tsql_if_else() {
        let sql = "IF @x > 0\n SELECT 1;\nELSE\n SELECT 2;";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::IfStatement).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_if_begin_end() {
        let sql = "IF @x > 0\nBEGIN\n SELECT 1;\n SELECT 2;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::IfStatement).is_some());
        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
    }

    #[test]
    fn test_tsql_begin_end() {
        let sql = "BEGIN\n SELECT 1;\n SELECT 2;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
    }

    #[test]
    fn test_tsql_while() {
        let sql = "WHILE @i < 10\nBEGIN\n SET @i = @i + 1;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::WhileStatement).is_some());
        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
    }

    #[test]
    fn test_tsql_try_catch() {
        let sql = "BEGIN TRY\n SELECT 1;\nEND TRY\nBEGIN CATCH\n SELECT 2;\nEND CATCH";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TryCatchBlock).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_exec_simple() {
        let cst = parse_tsql("EXEC sp_help;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
    }

    #[test]
    fn test_tsql_exec_with_params() {
        let cst = parse_tsql("EXEC dbo.usp_GetUser @id = 1, @name = 'test';");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
    }

    #[test]
    fn test_tsql_execute_keyword() {
        let cst = parse_tsql("EXECUTE sp_help;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
    }

    #[test]
    fn test_tsql_return() {
        let cst = parse_tsql("RETURN 0;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ReturnStatement).is_some());
    }

    #[test]
    fn test_tsql_return_no_value() {
        let cst = parse_tsql("RETURN;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ReturnStatement).is_some());
    }

    #[test]
    fn test_tsql_print() {
        let cst = parse_tsql("PRINT 'hello';");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::PrintStatement).is_some());
    }

    #[test]
    fn test_tsql_throw() {
        let cst = parse_tsql("THROW 50000, 'Error occurred', 1;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ThrowStatement).is_some());
    }

    #[test]
    fn test_tsql_throw_rethrow() {
        let cst = parse_tsql("THROW;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ThrowStatement).is_some());
    }

    #[test]
    fn test_tsql_raiserror() {
        let cst = parse_tsql("RAISERROR('Error', 16, 1);");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::RaiserrorStatement).is_some());
    }

    #[test]
    fn test_tsql_raiserror_with_nowait() {
        let cst = parse_tsql("RAISERROR('Error', 16, 1) WITH NOWAIT;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::RaiserrorStatement).is_some());
    }

    #[test]
    fn test_tsql_go() {
        let cst = parse_tsql("SELECT 1;\nGO");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::GoStatement).is_some());
    }

    #[test]
    fn test_tsql_go_with_count() {
        let cst = parse_tsql("GO 5");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::GoStatement).is_some());
    }

    #[test]
    fn test_tsql_simple_statements() {
        let cst = parse_tsql("USE master;");
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), "USE master;");
    }

    #[test]
    fn test_tsql_roundtrip_complex() {
        let sql = "SET NOCOUNT ON;\nDECLARE @id INT = 1;\nIF @id > 0\nBEGIN\n SELECT @id;\n PRINT 'done';\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_nested_begin_end() {
        let sql = "BEGIN\n BEGIN\n SELECT 1;\n END\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_if_else_begin_end() {
        let sql = "IF @x = 1\nBEGIN\n SELECT 1;\nEND\nELSE\nBEGIN\n SELECT 2;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::IfStatement).is_some());
    }

    #[test]
    fn test_tsql_try_catch_with_throw() {
        let sql = "BEGIN TRY\n SELECT 1;\nEND TRY\nBEGIN CATCH\n THROW;\nEND CATCH";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TryCatchBlock).is_some());
        assert!(find_type(&cst, SegmentType::ThrowStatement).is_some());
    }

    #[test]
    fn test_tsql_case_inside_begin_end() {
        let sql = "BEGIN\n SELECT CASE WHEN @x > 0 THEN 'pos' ELSE 'neg' END;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
        assert!(find_type(&cst, SegmentType::CaseExpression).is_some());
    }

    #[test]
    fn test_tsql_exec_retval() {
        let cst = parse_tsql("EXEC @result = dbo.usp_Calculate;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
    }

    #[test]
    fn test_tsql_multiple_set_options() {
        let sql = "SET ANSI_NULLS ON;\nSET QUOTED_IDENTIFIER ON;";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), sql);
    }

    // ── Error recovery ──────────────────────────────────────────────────

    /// Counts the `Unparsable` segments visited by `walk` starting at `seg`.
    fn count_unparsable(seg: &Segment) -> usize {
        let mut count = 0;
        seg.walk(&mut |s| {
            if s.segment_type() == SegmentType::Unparsable {
                count += 1;
            }
        });
        count
    }

    #[test]
    fn test_error_recovery_garbage_then_valid() {
        let sql = "XYZZY FOOBAR; SELECT 1;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql, "roundtrip must preserve source");
        // The leading garbage is expected to collapse into one Unparsable segment.
        assert_eq!(count_unparsable(&cst), 1);
        // The statement after the garbage must still be parsed.
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_error_recovery_garbage_between_statements() {
        let sql = "SELECT 1; NOTAKEYWORD 123 'abc'; SELECT 2;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        assert_eq!(count_unparsable(&cst), 1);
        // Both valid statements around the garbage must survive.
        let stmts: Vec<_> = cst
            .children()
            .iter()
            .filter(|s| s.segment_type() == SegmentType::Statement)
            .collect();
        assert_eq!(stmts.len(), 2);
    }

    #[test]
    fn test_error_recovery_garbage_at_end() {
        let sql = "SELECT 1; XYZZY";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        assert_eq!(count_unparsable(&cst), 1);
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_error_recovery_skips_to_statement_keyword() {
        // No semicolon separates the garbage from the SELECT.
        let sql = "XYZZY SELECT 1;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        assert_eq!(count_unparsable(&cst), 1);
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_error_recovery_diagnostics() {
        let parser = Parser::default();
        let result = parser.parse_with_diagnostics("XYZZY; SELECT 1;").unwrap();
        assert!(!result.diagnostics.is_empty());
        assert!(result.diagnostics[0].message.contains("Unparsable"));
        assert_eq!(result.diagnostics[0].offset, 0);
        // The tree must still contain the valid statement.
        assert!(find_type(&result.tree, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_error_recovery_diagnostics_offset_mid_file() {
        let parser = Parser::default();
        let result = parser
            .parse_with_diagnostics("SELECT 1; BADTOKEN;")
            .unwrap();
        assert_eq!(result.diagnostics.len(), 1);
        // `BADTOKEN` starts at byte offset 10 in the input above.
        assert_eq!(result.diagnostics[0].offset, 10);
    }

    #[test]
    fn test_error_recovery_all_garbage() {
        let sql = "NOTAKEYWORD 123 'hello'";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        assert!(count_unparsable(&cst) >= 1);
    }

    #[test]
    fn test_error_recovery_preserves_valid_statements() {
        let sql = "INSERT INTO t VALUES (1); BADTOKEN; DELETE FROM t WHERE id = 1;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        assert!(find_type(&cst, SegmentType::InsertStatement).is_some());
        assert!(find_type(&cst, SegmentType::DeleteStatement).is_some());
        assert_eq!(count_unparsable(&cst), 1);
    }
}