1use std::sync::LazyLock;
2
3use rigsql_core::{NodeSegment, Segment, SegmentType, TokenKind};
4
5use crate::context::ParseContext;
6
7use super::ansi::ANSI_STATEMENT_KEYWORDS;
8use super::{
9 any_token_segment, eat_trivia_segments, parse_comma_separated, parse_statement_list,
10 token_segment, Grammar,
11};
12
/// The T-SQL (Microsoft SQL Server) dialect grammar.
///
/// Layers T-SQL procedural statements (`DECLARE`, `SET @var`, `IF ... ELSE`,
/// `BEGIN ... END`, `BEGIN TRY ... BEGIN CATCH`, `WHILE`, `EXEC`, `GO`, ...)
/// on top of the ANSI grammar, which handles every other statement.
pub struct TsqlGrammar;
15
/// Statement-leading keywords that T-SQL adds to the ANSI keyword list.
///
/// Listed alphabetically. Each entry corresponds to a branch of
/// `TsqlGrammar::dispatch_statement`.
const TSQL_EXTRA_KEYWORDS: &[&str] = &[
    "BEGIN", "DECLARE", "EXEC", "EXECUTE",
    "GO", "IF", "PRINT", "RAISERROR",
    "RETURN", "SET", "THROW", "WHILE",
];
31
32static TSQL_STATEMENT_KEYWORDS: LazyLock<Vec<&'static str>> = LazyLock::new(|| {
34 let mut kws: Vec<&str> = ANSI_STATEMENT_KEYWORDS
35 .iter()
36 .chain(TSQL_EXTRA_KEYWORDS.iter())
37 .copied()
38 .collect();
39 kws.sort_unstable();
40 kws.dedup();
41 kws
42});
43
44impl Grammar for TsqlGrammar {
45 fn statement_keywords(&self) -> &[&str] {
46 &TSQL_STATEMENT_KEYWORDS
47 }
48
49 fn dispatch_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
50 if ctx.peek_keyword("DECLARE") {
52 self.parse_declare_statement(ctx)
53 } else if ctx.peek_keyword("SET") {
54 self.parse_set_variable_statement(ctx)
55 } else if ctx.peek_keyword("IF") {
56 self.parse_if_statement(ctx)
57 } else if ctx.peek_keyword("BEGIN") {
58 self.parse_begin_block(ctx)
59 } else if ctx.peek_keyword("WHILE") {
60 self.parse_while_statement(ctx)
61 } else if ctx.peek_keyword("EXEC") || ctx.peek_keyword("EXECUTE") {
62 self.parse_exec_statement(ctx)
63 } else if ctx.peek_keyword("RETURN") {
64 self.parse_return_statement(ctx)
65 } else if ctx.peek_keyword("PRINT") {
66 self.parse_print_statement(ctx)
67 } else if ctx.peek_keyword("THROW") {
68 self.parse_throw_statement(ctx)
69 } else if ctx.peek_keyword("RAISERROR") {
70 self.parse_raiserror_statement(ctx)
71 } else if ctx.peek_keyword("GO") {
72 self.parse_go_statement(ctx)
73 } else {
74 self.dispatch_ansi_statement(ctx)
76 }
77 }
78
79 fn consume_until_end(&self, ctx: &mut ParseContext, children: &mut Vec<Segment>) {
82 let mut paren_depth = 0u32;
83 let mut begin_depth = 0u32;
84 let mut case_depth = 0u32;
85 while !ctx.at_eof() {
86 match ctx.peek_kind() {
87 Some(TokenKind::Semicolon) if paren_depth == 0 && begin_depth == 0 => break,
88 Some(TokenKind::LParen) => {
89 paren_depth += 1;
90 let token = ctx.advance().unwrap();
91 children.push(any_token_segment(token));
92 }
93 Some(TokenKind::RParen) => {
94 paren_depth = paren_depth.saturating_sub(1);
95 let token = ctx.advance().unwrap();
96 children.push(any_token_segment(token));
97 }
98 _ => {
99 let t = ctx.peek().unwrap();
100 if t.kind == TokenKind::Word {
101 if t.text.eq_ignore_ascii_case("BEGIN") {
102 begin_depth += 1;
103 let token = ctx.advance().unwrap();
104 children.push(any_token_segment(token));
105 continue;
106 } else if t.text.eq_ignore_ascii_case("CASE") {
107 case_depth += 1;
108 let token = ctx.advance().unwrap();
109 children.push(any_token_segment(token));
110 continue;
111 } else if t.text.eq_ignore_ascii_case("END") {
112 if case_depth > 0 {
113 case_depth -= 1;
114 let token = ctx.advance().unwrap();
115 children.push(any_token_segment(token));
116 continue;
117 }
118 if begin_depth > 0 {
119 begin_depth -= 1;
120 let token = ctx.advance().unwrap();
121 children.push(any_token_segment(token));
122 if begin_depth == 0 && paren_depth == 0 {
123 break;
124 }
125 continue;
126 }
127 } else if t.text.eq_ignore_ascii_case("GO")
128 && paren_depth == 0
129 && begin_depth == 0
130 {
131 break;
132 }
133 }
134 let token = ctx.advance().unwrap();
135 children.push(any_token_segment(token));
136 }
137 }
138 }
139 }
140}
141
142impl TsqlGrammar {
145 fn parse_declare_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
147 let mut children = Vec::new();
148 let kw = ctx.eat_keyword("DECLARE")?;
149 children.push(token_segment(kw, SegmentType::Keyword));
150 children.extend(eat_trivia_segments(ctx));
151
152 self.parse_declare_variable(ctx, &mut children);
154
155 loop {
157 let save = ctx.save();
158 let trivia = eat_trivia_segments(ctx);
159 if let Some(comma) = ctx.eat_kind(TokenKind::Comma) {
160 children.extend(trivia);
161 children.push(token_segment(comma, SegmentType::Comma));
162 children.extend(eat_trivia_segments(ctx));
163 self.parse_declare_variable(ctx, &mut children);
164 } else {
165 ctx.restore(save);
166 break;
167 }
168 }
169
170 Some(Segment::Node(NodeSegment::new(
171 SegmentType::DeclareStatement,
172 children,
173 )))
174 }
175
176 fn parse_declare_variable(&self, ctx: &mut ParseContext, children: &mut Vec<Segment>) {
178 if ctx.peek_kind() == Some(TokenKind::AtSign) {
180 let at = ctx.advance().unwrap();
181 children.push(token_segment(at, SegmentType::Identifier));
182 children.extend(eat_trivia_segments(ctx));
183 } else if ctx.peek_kind() == Some(TokenKind::Word) {
184 let save = ctx.save();
186 let name = ctx.advance().unwrap();
187 let trivia = eat_trivia_segments(ctx);
188 if ctx.peek_keyword("CURSOR") {
189 children.push(token_segment(name, SegmentType::Identifier));
190 children.extend(trivia);
191 } else {
192 ctx.restore(save);
193 }
194 }
195
196 if ctx.peek_keyword("AS") {
198 let as_kw = ctx.advance().unwrap();
199 children.push(token_segment(as_kw, SegmentType::Keyword));
200 children.extend(eat_trivia_segments(ctx));
201 }
202
203 if ctx.peek_keyword("CURSOR") {
207 let cursor_kw = ctx.advance().unwrap();
208 children.push(token_segment(cursor_kw, SegmentType::Keyword));
209 children.extend(eat_trivia_segments(ctx));
210
211 while !ctx.at_eof() && !ctx.peek_keyword("FOR") {
213 if ctx.peek_kind() == Some(TokenKind::Semicolon) {
214 break;
215 }
216 if ctx.peek_kind() == Some(TokenKind::Word) {
217 let opt = ctx.advance().unwrap();
218 children.push(token_segment(opt, SegmentType::Keyword));
219 children.extend(eat_trivia_segments(ctx));
220 } else {
221 break;
222 }
223 }
224
225 if ctx.peek_keyword("FOR") {
227 let for_kw = ctx.advance().unwrap();
228 children.push(token_segment(for_kw, SegmentType::Keyword));
229 children.extend(eat_trivia_segments(ctx));
230 if let Some(sel) = self.parse_select_statement(ctx) {
231 children.push(sel);
232 }
233 }
234 return;
235 }
236
237 if ctx.peek_keyword("TABLE") {
239 let table_kw = ctx.advance().unwrap();
240 children.push(token_segment(table_kw, SegmentType::Keyword));
241 children.extend(eat_trivia_segments(ctx));
242 if ctx.peek_kind() == Some(TokenKind::LParen) {
243 if let Some(defs) = self.parse_paren_block(ctx) {
244 children.push(defs);
245 }
246 }
247 return;
248 }
249
250 if let Some(dt) = self.parse_data_type(ctx) {
252 children.push(dt);
253 children.extend(eat_trivia_segments(ctx));
254 }
255
256 if ctx.peek_kind() == Some(TokenKind::Eq) {
258 let eq = ctx.advance().unwrap();
259 children.push(token_segment(eq, SegmentType::ComparisonOperator));
260 children.extend(eat_trivia_segments(ctx));
261 if let Some(expr) = self.parse_expression(ctx) {
262 children.push(expr);
263 }
264 }
265 }
266
267 fn parse_set_variable_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
269 let save = ctx.save();
270 let mut children = Vec::new();
271 let kw = ctx.eat_keyword("SET")?;
272 children.push(token_segment(kw, SegmentType::Keyword));
273 children.extend(eat_trivia_segments(ctx));
274
275 if ctx.peek_kind() == Some(TokenKind::AtSign) {
277 let at = ctx.advance().unwrap();
278 children.push(token_segment(at, SegmentType::Identifier));
279 children.extend(eat_trivia_segments(ctx));
280
281 if let Some(kind) = ctx.peek_kind() {
283 if matches!(
284 kind,
285 TokenKind::Eq
286 | TokenKind::Plus
287 | TokenKind::Minus
288 | TokenKind::Star
289 | TokenKind::Slash
290 ) {
291 let op = ctx.advance().unwrap();
292 children.push(token_segment(op, SegmentType::Operator));
293 if ctx.peek_kind() == Some(TokenKind::Eq) {
295 let eq = ctx.advance().unwrap();
296 children.push(token_segment(eq, SegmentType::Operator));
297 }
298 children.extend(eat_trivia_segments(ctx));
299 if let Some(expr) = self.parse_expression(ctx) {
300 children.push(expr);
301 }
302 }
303 }
304
305 return Some(Segment::Node(NodeSegment::new(
306 SegmentType::SetVariableStatement,
307 children,
308 )));
309 }
310
311 if ctx.peek_kind() == Some(TokenKind::Word) {
313 self.consume_until_statement_end(ctx, &mut children);
314 return Some(Segment::Node(NodeSegment::new(
315 SegmentType::SetVariableStatement,
316 children,
317 )));
318 }
319
320 ctx.restore(save);
321 None
322 }
323
324 fn parse_if_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
326 let mut children = Vec::new();
327 let kw = ctx.eat_keyword("IF")?;
328 children.push(token_segment(kw, SegmentType::Keyword));
329 children.extend(eat_trivia_segments(ctx));
330
331 if let Some(cond) = self.parse_expression(ctx) {
333 children.push(cond);
334 }
335 children.extend(eat_trivia_segments(ctx));
336
337 if let Some(stmt) = self.parse_statement(ctx) {
339 children.push(stmt);
340 }
341
342 children.extend(eat_trivia_segments(ctx));
344 if ctx.peek_keyword("ELSE") {
345 let else_kw = ctx.advance().unwrap();
346 children.push(token_segment(else_kw, SegmentType::Keyword));
347 children.extend(eat_trivia_segments(ctx));
348
349 if let Some(stmt) = self.parse_statement(ctx) {
350 children.push(stmt);
351 }
352 }
353
354 Some(Segment::Node(NodeSegment::new(
355 SegmentType::IfStatement,
356 children,
357 )))
358 }
359
360 fn parse_begin_block(&self, ctx: &mut ParseContext) -> Option<Segment> {
362 if ctx.peek_keywords(&["BEGIN", "TRY"]) {
364 return self.parse_try_catch_block(ctx);
365 }
366
367 let mut children = Vec::new();
368 let begin_kw = ctx.eat_keyword("BEGIN")?;
369 children.push(token_segment(begin_kw, SegmentType::Keyword));
370
371 parse_statement_list(self, ctx, &mut children, |c| c.peek_keyword("END"));
372
373 children.extend(eat_trivia_segments(ctx));
375 if let Some(end_kw) = ctx.eat_keyword("END") {
376 children.push(token_segment(end_kw, SegmentType::Keyword));
377 }
378
379 Some(Segment::Node(NodeSegment::new(
380 SegmentType::BeginEndBlock,
381 children,
382 )))
383 }
384
385 fn parse_try_catch_block(&self, ctx: &mut ParseContext) -> Option<Segment> {
387 let mut children = Vec::new();
388
389 let begin_kw = ctx.eat_keyword("BEGIN")?;
391 children.push(token_segment(begin_kw, SegmentType::Keyword));
392 children.extend(eat_trivia_segments(ctx));
393 let try_kw = ctx.eat_keyword("TRY")?;
394 children.push(token_segment(try_kw, SegmentType::Keyword));
395
396 parse_statement_list(self, ctx, &mut children, |c| {
397 c.peek_keywords(&["END", "TRY"])
398 });
399
400 children.extend(eat_trivia_segments(ctx));
402 if let Some(end_kw) = ctx.eat_keyword("END") {
403 children.push(token_segment(end_kw, SegmentType::Keyword));
404 children.extend(eat_trivia_segments(ctx));
405 }
406 if let Some(try_kw) = ctx.eat_keyword("TRY") {
407 children.push(token_segment(try_kw, SegmentType::Keyword));
408 }
409
410 children.extend(eat_trivia_segments(ctx));
412 if let Some(begin_kw) = ctx.eat_keyword("BEGIN") {
413 children.push(token_segment(begin_kw, SegmentType::Keyword));
414 children.extend(eat_trivia_segments(ctx));
415 if let Some(catch_kw) = ctx.eat_keyword("CATCH") {
416 children.push(token_segment(catch_kw, SegmentType::Keyword));
417 }
418
419 parse_statement_list(self, ctx, &mut children, |c| {
420 c.peek_keywords(&["END", "CATCH"])
421 });
422
423 children.extend(eat_trivia_segments(ctx));
425 if let Some(end_kw) = ctx.eat_keyword("END") {
426 children.push(token_segment(end_kw, SegmentType::Keyword));
427 children.extend(eat_trivia_segments(ctx));
428 }
429 if let Some(catch_kw) = ctx.eat_keyword("CATCH") {
430 children.push(token_segment(catch_kw, SegmentType::Keyword));
431 }
432 }
433
434 Some(Segment::Node(NodeSegment::new(
435 SegmentType::TryCatchBlock,
436 children,
437 )))
438 }
439
440 fn parse_while_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
442 let mut children = Vec::new();
443 let kw = ctx.eat_keyword("WHILE")?;
444 children.push(token_segment(kw, SegmentType::Keyword));
445 children.extend(eat_trivia_segments(ctx));
446
447 if let Some(cond) = self.parse_expression(ctx) {
449 children.push(cond);
450 }
451 children.extend(eat_trivia_segments(ctx));
452
453 if let Some(stmt) = self.parse_statement(ctx) {
455 children.push(stmt);
456 }
457
458 Some(Segment::Node(NodeSegment::new(
459 SegmentType::WhileStatement,
460 children,
461 )))
462 }
463
464 fn parse_exec_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
466 let mut children = Vec::new();
467 let kw = if ctx.peek_keyword("EXEC") {
469 ctx.eat_keyword("EXEC")
470 } else {
471 ctx.eat_keyword("EXECUTE")
472 };
473 let kw = kw?;
474 children.push(token_segment(kw, SegmentType::Keyword));
475 children.extend(eat_trivia_segments(ctx));
476
477 let save = ctx.save();
479 if ctx.peek_kind() == Some(TokenKind::AtSign) {
480 let at = ctx.advance().unwrap();
481 let trivia = eat_trivia_segments(ctx);
482 if ctx.peek_kind() == Some(TokenKind::Eq) {
483 children.push(token_segment(at, SegmentType::Identifier));
484 children.extend(trivia);
485 let eq = ctx.advance().unwrap();
486 children.push(token_segment(eq, SegmentType::Operator));
487 children.extend(eat_trivia_segments(ctx));
488 } else {
489 ctx.restore(save);
490 }
491 }
492
493 if let Some(name) = self.parse_qualified_name(ctx) {
495 children.push(name);
496 }
497 children.extend(eat_trivia_segments(ctx));
498
499 self.parse_exec_params(ctx, &mut children);
501
502 Some(Segment::Node(NodeSegment::new(
503 SegmentType::ExecStatement,
504 children,
505 )))
506 }
507
508 fn parse_exec_params(&self, ctx: &mut ParseContext, children: &mut Vec<Segment>) {
510 if ctx.at_eof()
511 || ctx.peek_kind() == Some(TokenKind::Semicolon)
512 || self.peek_statement_start(ctx)
513 {
514 return;
515 }
516 parse_comma_separated(ctx, children, |c| self.parse_expression(c));
517 }
518
519 fn parse_return_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
521 let mut children = Vec::new();
522 let kw = ctx.eat_keyword("RETURN")?;
523 children.push(token_segment(kw, SegmentType::Keyword));
524
525 let save = ctx.save();
527 let trivia = eat_trivia_segments(ctx);
528 if !ctx.at_eof()
529 && ctx.peek_kind() != Some(TokenKind::Semicolon)
530 && !self.peek_statement_start(ctx)
531 {
532 children.extend(trivia);
533 if let Some(expr) = self.parse_expression(ctx) {
534 children.push(expr);
535 }
536 } else {
537 ctx.restore(save);
538 }
539
540 Some(Segment::Node(NodeSegment::new(
541 SegmentType::ReturnStatement,
542 children,
543 )))
544 }
545
546 fn parse_print_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
548 let mut children = Vec::new();
549 let kw = ctx.eat_keyword("PRINT")?;
550 children.push(token_segment(kw, SegmentType::Keyword));
551 children.extend(eat_trivia_segments(ctx));
552
553 if let Some(expr) = self.parse_expression(ctx) {
554 children.push(expr);
555 }
556
557 Some(Segment::Node(NodeSegment::new(
558 SegmentType::PrintStatement,
559 children,
560 )))
561 }
562
563 fn parse_throw_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
565 let mut children = Vec::new();
566 let kw = ctx.eat_keyword("THROW")?;
567 children.push(token_segment(kw, SegmentType::Keyword));
568
569 let save = ctx.save();
571 let trivia = eat_trivia_segments(ctx);
572 if ctx.at_eof()
573 || ctx.peek_kind() == Some(TokenKind::Semicolon)
574 || self.peek_statement_start(ctx)
575 {
576 ctx.restore(save);
577 return Some(Segment::Node(NodeSegment::new(
578 SegmentType::ThrowStatement,
579 children,
580 )));
581 }
582
583 children.extend(trivia);
585 if let Some(expr) = self.parse_expression(ctx) {
586 children.push(expr);
587 }
588 for _ in 0..2 {
590 let save2 = ctx.save();
591 let trivia2 = eat_trivia_segments(ctx);
592 if let Some(comma) = ctx.eat_kind(TokenKind::Comma) {
593 children.extend(trivia2);
594 children.push(token_segment(comma, SegmentType::Comma));
595 children.extend(eat_trivia_segments(ctx));
596 if let Some(expr) = self.parse_expression(ctx) {
597 children.push(expr);
598 }
599 } else {
600 ctx.restore(save2);
601 break;
602 }
603 }
604
605 Some(Segment::Node(NodeSegment::new(
606 SegmentType::ThrowStatement,
607 children,
608 )))
609 }
610
611 fn parse_raiserror_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
613 let mut children = Vec::new();
614 let kw = ctx.eat_keyword("RAISERROR")?;
615 children.push(token_segment(kw, SegmentType::Keyword));
616 children.extend(eat_trivia_segments(ctx));
617
618 if ctx.peek_kind() == Some(TokenKind::LParen) {
620 if let Some(args) = self.parse_paren_block(ctx) {
621 children.push(args);
622 }
623 }
624
625 children.extend(eat_trivia_segments(ctx));
627 if ctx.peek_keyword("WITH") {
628 let with_kw = ctx.advance().unwrap();
629 children.push(token_segment(with_kw, SegmentType::Keyword));
630 children.extend(eat_trivia_segments(ctx));
631 while ctx.peek_kind() == Some(TokenKind::Word) {
633 let opt = ctx.advance().unwrap();
634 children.push(token_segment(opt, SegmentType::Keyword));
635 let save = ctx.save();
636 let trivia = eat_trivia_segments(ctx);
637 if ctx.peek_kind() == Some(TokenKind::Comma) {
638 children.extend(trivia);
639 let comma = ctx.advance().unwrap();
640 children.push(token_segment(comma, SegmentType::Comma));
641 children.extend(eat_trivia_segments(ctx));
642 } else {
643 ctx.restore(save);
644 break;
645 }
646 }
647 }
648
649 Some(Segment::Node(NodeSegment::new(
650 SegmentType::RaiserrorStatement,
651 children,
652 )))
653 }
654
655 fn parse_go_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
657 let mut children = Vec::new();
658 let kw = ctx.eat_keyword("GO")?;
659 children.push(token_segment(kw, SegmentType::Keyword));
660
661 let save = ctx.save();
663 let trivia = eat_trivia_segments(ctx);
664 if ctx.peek_kind() == Some(TokenKind::NumberLiteral) {
665 children.extend(trivia);
666 let num = ctx.advance().unwrap();
667 children.push(token_segment(num, SegmentType::NumericLiteral));
668 } else {
669 ctx.restore(save);
670 }
671
672 Some(Segment::Node(NodeSegment::new(
673 SegmentType::GoStatement,
674 children,
675 )))
676 }
677}