1use rigsql_core::Segment;
2use rigsql_lexer::{Lexer, LexerConfig, LexerError};
3use thiserror::Error;
4
5use crate::context::{ParseContext, ParseDiagnostic};
6use crate::grammar::{AnsiGrammar, Grammar};
7#[cfg(test)]
8use crate::grammar::{PostgresGrammar, TsqlGrammar};
9
/// Errors that abort parsing.
///
/// Lexing is the only fallible stage in [`Parser::parse_with_diagnostics`].
/// Problems found by the grammar do not surface here; they are reported
/// through [`ParseResult::diagnostics`] instead.
#[derive(Debug, Error)]
pub enum ParseError {
    /// The lexer rejected the input. `#[from]` lets `?` convert a
    /// [`LexerError`] into this variant automatically.
    #[error("Lexer error: {0}")]
    Lexer(#[from] LexerError),
}
15
/// Output of [`Parser::parse_with_diagnostics`]: the syntax tree plus any
/// diagnostics recorded while building it.
pub struct ParseResult {
    /// Root of the concrete syntax tree produced by the grammar's
    /// `parse_file`. NOTE(review): the tests in this module expect
    /// `tree.raw()` to reproduce the source text exactly, and the root to be a
    /// `SegmentType::File` segment for the ANSI grammar.
    pub tree: Segment,
    /// Diagnostics drained from the `ParseContext` after parsing finished,
    /// such as unparsable spans hit during error recovery. Each diagnostic
    /// exposes at least a `message` and a source `offset`.
    pub diagnostics: Vec<ParseDiagnostic>,
}
26
/// A reusable SQL parser that pairs a lexer configuration with a dialect
/// grammar.
///
/// Build one with [`Parser::new`], or use [`Parser::default`] for ANSI SQL.
pub struct Parser {
    /// Lexer settings. A fresh clone is handed to the lexer on every parse
    /// call, so the parser itself is never mutated.
    lexer_config: LexerConfig,
    /// Dialect grammar that turns the token stream into a syntax tree.
    grammar: Box<dyn Grammar>,
}
32
33impl Parser {
34 pub fn new(lexer_config: LexerConfig, grammar: Box<dyn Grammar>) -> Self {
35 Self {
36 lexer_config,
37 grammar,
38 }
39 }
40
41 pub fn parse(&self, source: &str) -> Result<Segment, ParseError> {
43 self.parse_with_diagnostics(source).map(|r| r.tree)
44 }
45
46 pub fn parse_with_diagnostics(&self, source: &str) -> Result<ParseResult, ParseError> {
49 let mut lexer = Lexer::new(source, self.lexer_config.clone());
50 let tokens = lexer.tokenize()?;
51 let mut ctx = ParseContext::new(&tokens, source);
52 let tree = self.grammar.parse_file(&mut ctx);
53 let diagnostics = ctx.take_diagnostics();
54 Ok(ParseResult { tree, diagnostics })
55 }
56}
57
58impl Default for Parser {
59 fn default() -> Self {
60 Self::new(LexerConfig::ansi(), Box::new(AnsiGrammar))
61 }
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use rigsql_core::SegmentType;

    // ── Parsing helpers ─────────────────────────────────────────────────

    /// Parses `sql` with the default (ANSI) parser, panicking on lexer errors.
    fn parse(sql: &str) -> Segment {
        Parser::default().parse(sql).unwrap()
    }

    /// Parses `sql` with the T-SQL lexer configuration and grammar.
    fn parse_tsql(sql: &str) -> Segment {
        Parser::new(LexerConfig::tsql(), Box::new(TsqlGrammar))
            .parse(sql)
            .unwrap()
    }

    /// Parses `sql` with the PostgreSQL lexer configuration and grammar.
    fn parse_pg(sql: &str) -> Segment {
        Parser::new(LexerConfig::postgres(), Box::new(PostgresGrammar))
            .parse(sql)
            .unwrap()
    }

    // ── Tree inspection helpers ─────────────────────────────────────────

    /// Asserts that `seg` has type `expected`, showing its raw text on failure.
    fn assert_type(seg: &Segment, expected: SegmentType) {
        assert_eq!(
            seg.segment_type(),
            expected,
            "Expected {:?} but got {:?} for raw: {:?}",
            expected,
            seg.segment_type(),
            seg.raw()
        );
    }

    /// Returns the first segment of type `ty` in a pre-order traversal of
    /// `seg`, including `seg` itself, or `None` if there is no such segment.
    ///
    /// An explicit stack over `children()` lets the result borrow straight
    /// from `seg`. No raw pointer has to escape a `walk` callback, and the
    /// search stops at the first match.
    fn find_type(seg: &Segment, ty: SegmentType) -> Option<&Segment> {
        let mut stack = vec![seg];
        while let Some(current) = stack.pop() {
            if current.segment_type() == ty {
                return Some(current);
            }
            // Push children in reverse so the leftmost child is visited next.
            stack.extend(current.children().iter().rev());
        }
        None
    }

    /// Counts every segment of type `ty` in the tree rooted at `seg`.
    fn count_type(seg: &Segment, ty: SegmentType) -> usize {
        let mut count = 0;
        seg.walk(&mut |s| {
            if s.segment_type() == ty {
                count += 1;
            }
        });
        count
    }

    /// Counts `Unparsable` segments produced by error recovery.
    fn count_unparsable(seg: &Segment) -> usize {
        count_type(seg, SegmentType::Unparsable)
    }

    /// Fails, listing their raw text, if any `Unparsable` segment is present.
    fn assert_no_unparsable(seg: &Segment) {
        let mut unparsable = Vec::new();
        seg.walk(&mut |s| {
            if s.segment_type() == SegmentType::Unparsable {
                unparsable.push(s.raw());
            }
        });
        assert!(
            unparsable.is_empty(),
            "Found Unparsable segments: {:?}",
            unparsable
        );
    }

    // ── ANSI ────────────────────────────────────────────────────────────

    #[test]
    fn test_simple_select() {
        let cst = parse("SELECT 1");
        assert_type(&cst, SegmentType::File);
        let stmt = &cst.children()[0];
        assert_type(stmt, SegmentType::Statement);
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_select_from_where() {
        let cst = parse("SELECT name FROM users WHERE id = 1");
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
        assert!(find_type(&cst, SegmentType::FromClause).is_some());
        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
    }

    #[test]
    fn test_join() {
        let cst = parse("SELECT a.id FROM a INNER JOIN b ON a.id = b.id");
        assert!(find_type(&cst, SegmentType::JoinClause).is_some());
        assert!(find_type(&cst, SegmentType::OnClause).is_some());
    }

    #[test]
    fn test_group_by_having_order_by() {
        let cst = parse(
            "SELECT dept, COUNT(*) FROM emp GROUP BY dept HAVING COUNT(*) > 5 ORDER BY dept ASC",
        );
        assert!(find_type(&cst, SegmentType::GroupByClause).is_some());
        assert!(find_type(&cst, SegmentType::HavingClause).is_some());
        assert!(find_type(&cst, SegmentType::OrderByClause).is_some());
    }

    #[test]
    fn test_insert_values() {
        let cst = parse("INSERT INTO users (name, email) VALUES ('Alice', 'a@b.com')");
        assert!(find_type(&cst, SegmentType::InsertStatement).is_some());
        assert!(find_type(&cst, SegmentType::ValuesClause).is_some());
    }

    #[test]
    fn test_update_set_where() {
        let cst = parse("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert!(find_type(&cst, SegmentType::UpdateStatement).is_some());
        assert!(find_type(&cst, SegmentType::SetClause).is_some());
        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
    }

    #[test]
    fn test_delete() {
        let cst = parse("DELETE FROM users WHERE id = 1");
        assert!(find_type(&cst, SegmentType::DeleteStatement).is_some());
    }

    #[test]
    fn test_create_table() {
        let cst = parse("CREATE TABLE users (id INT, name VARCHAR(100))");
        assert!(find_type(&cst, SegmentType::CreateTableStatement).is_some());
    }

    #[test]
    fn test_with_cte() {
        let cst =
            parse("WITH active AS (SELECT * FROM users WHERE active = TRUE) SELECT * FROM active");
        assert!(find_type(&cst, SegmentType::WithClause).is_some());
        assert!(find_type(&cst, SegmentType::CteDefinition).is_some());
    }

    #[test]
    fn test_case_expression() {
        let cst = parse("SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t");
        assert!(find_type(&cst, SegmentType::CaseExpression).is_some());
        assert!(find_type(&cst, SegmentType::WhenClause).is_some());
        assert!(find_type(&cst, SegmentType::ElseClause).is_some());
    }

    #[test]
    fn test_subquery() {
        let cst = parse("SELECT * FROM (SELECT 1) AS sub");
        assert!(find_type(&cst, SegmentType::Subquery).is_some());
    }

    #[test]
    fn test_function_call() {
        let cst = parse("SELECT COUNT(*) FROM users");
        assert!(find_type(&cst, SegmentType::FunctionCall).is_some());
    }

    #[test]
    fn test_roundtrip() {
        let sql = "SELECT a, b FROM t WHERE x = 1 ORDER BY a;";
        let cst = parse(sql);
        assert_eq!(
            cst.raw(),
            sql,
            "CST roundtrip must preserve source text exactly"
        );
    }

    #[test]
    fn test_multiple_statements() {
        let sql = "SELECT 1; SELECT 2;";
        let cst = parse(sql);
        let stmts: Vec<_> = cst
            .children()
            .iter()
            .filter(|s| s.segment_type() == SegmentType::Statement)
            .collect();
        assert_eq!(stmts.len(), 2);
    }

    #[test]
    fn test_roundtrip_complex() {
        let sql = "WITH cte AS (\n SELECT id, name\n FROM users\n WHERE active = TRUE\n)\nSELECT cte.id, cte.name\nFROM cte\nINNER JOIN orders ON cte.id = orders.user_id\nWHERE orders.total > 100\nORDER BY cte.name ASC\nLIMIT 10;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
    }

    // ── T-SQL procedural statements ─────────────────────────────────────

    #[test]
    fn test_tsql_declare_variable() {
        let cst = parse_tsql("DECLARE @id INT;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
        assert_eq!(cst.raw(), "DECLARE @id INT;");
    }

    #[test]
    fn test_tsql_declare_with_default() {
        let cst = parse_tsql("DECLARE @name VARCHAR(100) = 'test';");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
    }

    #[test]
    fn test_tsql_declare_multiple() {
        let cst = parse_tsql("DECLARE @a INT, @b VARCHAR(50);");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
        assert_eq!(cst.raw(), "DECLARE @a INT, @b VARCHAR(50);");
    }

    #[test]
    fn test_tsql_declare_table_variable() {
        let cst = parse_tsql("DECLARE @t TABLE (id INT, name VARCHAR(100));");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
    }

    #[test]
    fn test_tsql_declare_cursor() {
        let cst = parse_tsql("DECLARE cur CURSOR FOR SELECT id FROM users;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeclareStatement).is_some());
        // The cursor's query is parsed as a nested SELECT statement.
        assert!(find_type(&cst, SegmentType::SelectStatement).is_some());
    }

    #[test]
    fn test_tsql_set_variable() {
        let cst = parse_tsql("SET @id = 42;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::SetVariableStatement).is_some());
        assert_eq!(cst.raw(), "SET @id = 42;");
    }

    #[test]
    fn test_tsql_set_option() {
        // Session options share the SetVariableStatement node with variable assignment.
        let cst = parse_tsql("SET NOCOUNT ON;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::SetVariableStatement).is_some());
    }

    #[test]
    fn test_tsql_if_else() {
        let sql = "IF @x > 0\n SELECT 1;\nELSE\n SELECT 2;";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::IfStatement).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_if_begin_end() {
        let sql = "IF @x > 0\nBEGIN\n SELECT 1;\n SELECT 2;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::IfStatement).is_some());
        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
    }

    #[test]
    fn test_tsql_begin_end() {
        let sql = "BEGIN\n SELECT 1;\n SELECT 2;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
    }

    #[test]
    fn test_tsql_while() {
        let sql = "WHILE @i < 10\nBEGIN\n SET @i = @i + 1;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::WhileStatement).is_some());
        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
    }

    #[test]
    fn test_tsql_try_catch() {
        let sql = "BEGIN TRY\n SELECT 1;\nEND TRY\nBEGIN CATCH\n SELECT 2;\nEND CATCH";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TryCatchBlock).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_exec_simple() {
        let cst = parse_tsql("EXEC sp_help;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
    }

    #[test]
    fn test_tsql_exec_with_params() {
        let cst = parse_tsql("EXEC dbo.usp_GetUser @id = 1, @name = 'test';");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
    }

    #[test]
    fn test_tsql_execute_keyword() {
        let cst = parse_tsql("EXECUTE sp_help;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
    }

    #[test]
    fn test_tsql_return() {
        let cst = parse_tsql("RETURN 0;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ReturnStatement).is_some());
    }

    #[test]
    fn test_tsql_return_no_value() {
        let cst = parse_tsql("RETURN;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ReturnStatement).is_some());
    }

    #[test]
    fn test_tsql_print() {
        let cst = parse_tsql("PRINT 'hello';");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::PrintStatement).is_some());
    }

    #[test]
    fn test_tsql_throw() {
        let cst = parse_tsql("THROW 50000, 'Error occurred', 1;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ThrowStatement).is_some());
    }

    #[test]
    fn test_tsql_throw_rethrow() {
        let cst = parse_tsql("THROW;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ThrowStatement).is_some());
    }

    #[test]
    fn test_tsql_raiserror() {
        let cst = parse_tsql("RAISERROR('Error', 16, 1);");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::RaiserrorStatement).is_some());
    }

    #[test]
    fn test_tsql_raiserror_with_nowait() {
        let cst = parse_tsql("RAISERROR('Error', 16, 1) WITH NOWAIT;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::RaiserrorStatement).is_some());
    }

    #[test]
    fn test_tsql_go() {
        let cst = parse_tsql("SELECT 1;\nGO");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::GoStatement).is_some());
    }

    #[test]
    fn test_tsql_go_with_count() {
        let cst = parse_tsql("GO 5");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::GoStatement).is_some());
    }

    #[test]
    fn test_tsql_simple_statements() {
        let cst = parse_tsql("USE master;");
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), "USE master;");
    }

    #[test]
    fn test_tsql_roundtrip_complex() {
        let sql = "SET NOCOUNT ON;\nDECLARE @id INT = 1;\nIF @id > 0\nBEGIN\n SELECT @id;\n PRINT 'done';\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_nested_begin_end() {
        let sql = "BEGIN\n BEGIN\n SELECT 1;\n END\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_if_else_begin_end() {
        let sql = "IF @x = 1\nBEGIN\n SELECT 1;\nEND\nELSE\nBEGIN\n SELECT 2;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::IfStatement).is_some());
    }

    #[test]
    fn test_tsql_try_catch_with_throw() {
        let sql = "BEGIN TRY\n SELECT 1;\nEND TRY\nBEGIN CATCH\n THROW;\nEND CATCH";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TryCatchBlock).is_some());
        assert!(find_type(&cst, SegmentType::ThrowStatement).is_some());
    }

    #[test]
    fn test_tsql_case_inside_begin_end() {
        // CASE ... END must not be mistaken for the END closing the block.
        let sql = "BEGIN\n SELECT CASE WHEN @x > 0 THEN 'pos' ELSE 'neg' END;\nEND";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::BeginEndBlock).is_some());
        assert!(find_type(&cst, SegmentType::CaseExpression).is_some());
    }

    #[test]
    fn test_tsql_exec_retval() {
        let cst = parse_tsql("EXEC @result = dbo.usp_Calculate;");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ExecStatement).is_some());
    }

    #[test]
    fn test_tsql_multiple_set_options() {
        let sql = "SET ANSI_NULLS ON;\nSET QUOTED_IDENTIFIER ON;";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), sql);
    }

    // ── T-SQL table hints: WITH(NOLOCK) etc. ────────────────────────────

    #[test]
    fn test_tsql_with_nolock() {
        let sql = "SELECT * FROM orders WITH(NOLOCK) WHERE id = 1";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TableHint).is_some());
        assert!(find_type(&cst, SegmentType::FromClause).is_some());
        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_with_nolock_alias() {
        let sql = "SELECT o.id FROM orders o WITH(NOLOCK)";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TableHint).is_some());
        assert!(find_type(&cst, SegmentType::AliasExpression).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_with_nolock_join() {
        let sql = "SELECT a.id FROM orders a WITH(NOLOCK) INNER JOIN items b WITH(READUNCOMMITTED) ON a.id = b.order_id";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        // One hint on each side of the join.
        assert_eq!(count_type(&cst, SegmentType::TableHint), 2);
        assert!(find_type(&cst, SegmentType::JoinClause).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_with_multiple_hints() {
        let sql = "SELECT * FROM orders WITH(NOLOCK, NOWAIT)";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TableHint).is_some());
        assert_eq!(cst.raw(), sql);
    }

    #[test]
    fn test_tsql_with_nolock_roundtrip() {
        let sql = "SELECT o.id, o.total\nFROM orders o WITH(NOLOCK)\nINNER JOIN customers c WITH(NOLOCK) ON o.customer_id = c.id\nWHERE c.active = 1\nORDER BY o.id";
        let cst = parse_tsql(sql);
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), sql);
    }

    // ── Error recovery ──────────────────────────────────────────────────

    #[test]
    fn test_error_recovery_garbage_then_valid() {
        // Garbage up to the `;` becomes one Unparsable segment; parsing resumes after it.
        let sql = "XYZZY FOOBAR; SELECT 1;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql, "roundtrip must preserve source");
        assert_eq!(count_unparsable(&cst), 1);
        // The statement after the garbage is still parsed normally.
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_error_recovery_garbage_between_statements() {
        // Valid statements on both sides of the garbage must survive.
        let sql = "SELECT 1; NOTAKEYWORD 123 'abc'; SELECT 2;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        assert_eq!(count_unparsable(&cst), 1);
        let stmts: Vec<_> = cst
            .children()
            .iter()
            .filter(|s| s.segment_type() == SegmentType::Statement)
            .collect();
        assert_eq!(stmts.len(), 2);
    }

    #[test]
    fn test_error_recovery_garbage_at_end() {
        let sql = "SELECT 1; XYZZY";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        assert_eq!(count_unparsable(&cst), 1);
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_error_recovery_skips_to_statement_keyword() {
        // With no `;` separator, recovery must stop at the SELECT keyword.
        let sql = "XYZZY SELECT 1;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        assert_eq!(count_unparsable(&cst), 1);
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_error_recovery_diagnostics() {
        let parser = Parser::default();
        let result = parser.parse_with_diagnostics("XYZZY; SELECT 1;").unwrap();
        assert!(!result.diagnostics.is_empty());
        assert!(result.diagnostics[0].message.contains("Unparsable"));
        // The garbage starts at the very beginning of the input.
        assert_eq!(result.diagnostics[0].offset, 0);
        // A diagnostic must not prevent the rest of the file from parsing.
        assert!(find_type(&result.tree, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_error_recovery_diagnostics_offset_mid_file() {
        let parser = Parser::default();
        // `BADTOKEN` starts at offset 10, right after "SELECT 1; ".
        let result = parser
            .parse_with_diagnostics("SELECT 1; BADTOKEN;")
            .unwrap();
        assert_eq!(result.diagnostics.len(), 1);
        assert_eq!(result.diagnostics[0].offset, 10);
    }

    #[test]
    fn test_valid_sql_has_no_diagnostics() {
        let parser = Parser::default();
        let result = parser.parse_with_diagnostics("SELECT 1;").unwrap();
        assert!(
            result.diagnostics.is_empty(),
            "valid SQL should not produce diagnostics"
        );
    }

    #[test]
    fn test_error_recovery_all_garbage() {
        let sql = "NOTAKEYWORD 123 'hello'";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        // Grouping of the unrecognised tokens is not pinned down; require at least one.
        assert!(count_unparsable(&cst) >= 1);
    }

    #[test]
    fn test_error_recovery_preserves_valid_statements() {
        // Statements of different kinds around the garbage must both survive.
        let sql = "INSERT INTO t VALUES (1); BADTOKEN; DELETE FROM t WHERE id = 1;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
        assert!(find_type(&cst, SegmentType::InsertStatement).is_some());
        assert!(find_type(&cst, SegmentType::DeleteStatement).is_some());
        assert_eq!(count_unparsable(&cst), 1);
    }

    // ── PostgreSQL ──────────────────────────────────────────────────────

    #[test]
    fn test_pg_double_colon_cast() {
        let cst = parse_pg("SELECT col::int FROM t");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TypeCastExpression).is_some());
        assert_eq!(cst.raw(), "SELECT col::int FROM t");
    }

    #[test]
    fn test_pg_chained_cast() {
        let cst = parse_pg("SELECT '2024-01-01'::date::text FROM t");
        assert_no_unparsable(&cst);
        // Each `::` in the chain gets its own TypeCastExpression node.
        assert_eq!(
            count_type(&cst, SegmentType::TypeCastExpression),
            2,
            "Expected 2 TypeCastExpression nodes for chained cast"
        );
    }

    #[test]
    fn test_pg_cast_with_precision() {
        let cst = parse_pg("SELECT col::numeric(10, 2) FROM t");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::TypeCastExpression).is_some());
        assert!(find_type(&cst, SegmentType::DataType).is_some());
    }

    #[test]
    fn test_pg_array_subscript() {
        let cst = parse_pg("SELECT arr[1] FROM t");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ArrayAccessExpression).is_some());
    }

    #[test]
    fn test_pg_array_cast_chain() {
        let cst = parse_pg("SELECT arr[1]::text FROM t");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::ArrayAccessExpression).is_some());
        assert!(find_type(&cst, SegmentType::TypeCastExpression).is_some());
    }

    #[test]
    fn test_pg_insert_returning() {
        let cst = parse_pg("INSERT INTO users (name) VALUES ('Alice') RETURNING id, name");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::InsertStatement).is_some());
        assert!(find_type(&cst, SegmentType::ReturningClause).is_some());
    }

    #[test]
    fn test_pg_update_returning() {
        let cst = parse_pg("UPDATE users SET name = 'Bob' WHERE id = 1 RETURNING *");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::UpdateStatement).is_some());
        assert!(find_type(&cst, SegmentType::ReturningClause).is_some());
    }

    #[test]
    fn test_pg_delete_returning() {
        let cst = parse_pg("DELETE FROM users WHERE id = 1 RETURNING id");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::DeleteStatement).is_some());
        assert!(find_type(&cst, SegmentType::ReturningClause).is_some());
    }

    #[test]
    fn test_pg_on_conflict_do_nothing() {
        let cst = parse_pg(
            "INSERT INTO users (id, name) VALUES (1, 'Alice') ON CONFLICT (id) DO NOTHING",
        );
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::OnConflictClause).is_some());
    }

    #[test]
    fn test_pg_on_conflict_do_update() {
        let cst = parse_pg(
            "INSERT INTO users (id, name) VALUES (1, 'Alice') \
             ON CONFLICT (id) DO UPDATE SET name = 'Alice'",
        );
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::OnConflictClause).is_some());
        assert!(find_type(&cst, SegmentType::SetClause).is_some());
    }

    #[test]
    fn test_pg_upsert_returning() {
        let cst = parse_pg(
            "INSERT INTO users (id, name) VALUES (1, 'Alice') \
             ON CONFLICT (id) DO UPDATE SET name = 'Alice' RETURNING *",
        );
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::OnConflictClause).is_some());
        assert!(find_type(&cst, SegmentType::ReturningClause).is_some());
    }

    #[test]
    fn test_pg_distinct_on() {
        let cst = parse_pg(
            "SELECT DISTINCT ON (dept) name, salary FROM employees ORDER BY dept, salary DESC",
        );
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
        assert!(find_type(&cst, SegmentType::OrderByClause).is_some());
    }

    #[test]
    fn test_pg_dollar_quoted_string() {
        let cst = parse_pg("SELECT $$hello world$$");
        assert_no_unparsable(&cst);
        assert_eq!(cst.raw(), "SELECT $$hello world$$");
    }

    #[test]
    fn test_pg_ilike() {
        let cst = parse_pg("SELECT * FROM users WHERE name ILIKE '%alice%'");
        assert_no_unparsable(&cst);
        assert!(find_type(&cst, SegmentType::LikeExpression).is_some());
    }

    #[test]
    fn test_pg_roundtrip_complex() {
        let sql = "INSERT INTO orders (user_id, total) \
                   VALUES (1, 99.99) \
                   ON CONFLICT (user_id) DO UPDATE SET total = orders.total + 99.99 \
                   RETURNING id, total::numeric(10, 2)";
        let cst = parse_pg(sql);
        assert_eq!(
            cst.raw(),
            sql,
            "CST roundtrip must preserve source text exactly"
        );
        assert_no_unparsable(&cst);
    }
}