use std::sync::LazyLock;
use rigsql_core::{NodeSegment, Segment, SegmentType, TokenKind};
use crate::context::ParseContext;
use super::ansi::ANSI_STATEMENT_KEYWORDS;
use super::{
any_token_segment, eat_trivia_segments, parse_comma_separated, parse_statement_list,
token_segment, Grammar,
};
/// Grammar for the T-SQL (Microsoft SQL Server) dialect.
///
/// Adds dedicated parsers for `DECLARE`, `SET`, `IF`/`ELSE`, `BEGIN … END`,
/// `BEGIN TRY … END TRY BEGIN CATCH … END CATCH`, `WHILE`, `EXEC`/`EXECUTE`, `RETURN`,
/// `PRINT`, `THROW`, `RAISERROR` and `GO`. All other statements are handed to the
/// ANSI dispatcher.
pub struct TsqlGrammar;
/// Statement-leading keywords that T-SQL adds on top of `ANSI_STATEMENT_KEYWORDS`.
///
/// Each entry has a matching branch in `TsqlGrammar::dispatch_statement`. The two lists
/// are merged (sorted, de-duplicated) into `TSQL_STATEMENT_KEYWORDS`, so overlap with the
/// ANSI list is harmless.
const TSQL_EXTRA_KEYWORDS: &[&str] = &[
    "BEGIN",
    "DECLARE",
    "EXEC",
    "EXECUTE",
    "GO",
    "IF",
    "PRINT",
    "RAISERROR",
    "RETURN",
    "SET",
    "THROW",
    "WHILE",
];
/// All keywords that may start a T-SQL statement.
///
/// This is the ANSI list plus `TSQL_EXTRA_KEYWORDS`, sorted with duplicates removed.
/// It is computed lazily, the first time it is accessed.
static TSQL_STATEMENT_KEYWORDS: LazyLock<Vec<&'static str>> = LazyLock::new(|| {
    let mut merged: Vec<&'static str> =
        Vec::with_capacity(ANSI_STATEMENT_KEYWORDS.len() + TSQL_EXTRA_KEYWORDS.len());
    merged.extend(ANSI_STATEMENT_KEYWORDS.iter().copied());
    merged.extend(TSQL_EXTRA_KEYWORDS.iter().copied());
    // `dedup` only removes adjacent repeats, so sorting first makes it remove all of them.
    merged.sort_unstable();
    merged.dedup();
    merged
});
/// T-SQL overrides of the shared `Grammar` hooks.
impl Grammar for TsqlGrammar {
    /// Returns the merged ANSI + T-SQL statement-starting keywords (sorted, de-duplicated).
    fn statement_keywords(&self) -> &[&str] {
        &TSQL_STATEMENT_KEYWORDS
    }

    /// Chooses a statement parser from the upcoming keyword.
    ///
    /// T-SQL specific statements are checked first. Anything else falls back to
    /// `dispatch_ansi_statement`. `EXEC` and `EXECUTE` share a single parser.
    fn dispatch_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
        if ctx.peek_keyword("DECLARE") {
            self.parse_declare_statement(ctx)
        } else if ctx.peek_keyword("SET") {
            self.parse_set_variable_statement(ctx)
        } else if ctx.peek_keyword("IF") {
            self.parse_if_statement(ctx)
        } else if ctx.peek_keyword("BEGIN") {
            self.parse_begin_block(ctx)
        } else if ctx.peek_keyword("WHILE") {
            self.parse_while_statement(ctx)
        } else if ctx.peek_keyword("EXEC") || ctx.peek_keyword("EXECUTE") {
            self.parse_exec_statement(ctx)
        } else if ctx.peek_keyword("RETURN") {
            self.parse_return_statement(ctx)
        } else if ctx.peek_keyword("PRINT") {
            self.parse_print_statement(ctx)
        } else if ctx.peek_keyword("THROW") {
            self.parse_throw_statement(ctx)
        } else if ctx.peek_keyword("RAISERROR") {
            self.parse_raiserror_statement(ctx)
        } else if ctx.peek_keyword("GO") {
            self.parse_go_statement(ctx)
        } else {
            self.dispatch_ansi_statement(ctx)
        }
    }

    /// Parses a table hint list of the form `WITH ( hint [, hint]* )` into a `TableHint` node.
    ///
    /// Each hint must be a single word token; an empty `WITH ()` is accepted. If `WITH` is
    /// not followed by `(`, or the list is not closed by `)` (for example a hint with
    /// arguments such as `INDEX(ix)`), the context is rewound and `None` is returned.
    fn parse_table_hint(&self, ctx: &mut ParseContext) -> Option<Segment> {
        if !ctx.peek_keyword("WITH") {
            return None;
        }
        let save = ctx.save();
        let with_kw = ctx.eat_keyword("WITH")?;
        // Held back until we know a `(` follows, i.e. this really is a hint list.
        let trivia_after_with = eat_trivia_segments(ctx);
        if ctx.peek_kind() != Some(TokenKind::LParen) {
            ctx.restore(save);
            return None;
        }
        // Cannot fail: the `LParen` was just peeked.
        let lparen = ctx.advance().unwrap();
        let mut children = Vec::new();
        children.push(token_segment(with_kw, SegmentType::Keyword));
        children.extend(trivia_after_with);
        children.push(token_segment(lparen, SegmentType::LParen));
        children.extend(eat_trivia_segments(ctx));
        let mut first = true;
        while !ctx.at_eof() && ctx.peek_kind() != Some(TokenKind::RParen) {
            // Every hint after the first must be preceded by a comma.
            if !first {
                if let Some(comma) = ctx.eat_kind(TokenKind::Comma) {
                    children.push(token_segment(comma, SegmentType::Comma));
                    children.extend(eat_trivia_segments(ctx));
                } else {
                    break;
                }
            }
            first = false;
            if ctx.peek_kind() == Some(TokenKind::Word) {
                let hint = ctx.advance().unwrap();
                children.push(token_segment(hint, SegmentType::Keyword));
                children.extend(eat_trivia_segments(ctx));
            } else {
                break;
            }
        }
        // Anything other than a closing `)` here means the list was not understood.
        if let Some(rparen) = ctx.eat_kind(TokenKind::RParen) {
            children.push(token_segment(rparen, SegmentType::RParen));
        } else {
            ctx.restore(save);
            return None;
        }
        Some(Segment::Node(NodeSegment::new(
            SegmentType::TableHint,
            children,
        )))
    }

    /// Appends raw tokens to `children` until the end of the current statement.
    ///
    /// Scanning stops:
    /// - before a `;` when not inside parentheses or a `BEGIN` block,
    /// - before a `GO` word under the same conditions,
    /// - right after the `END` that closes the outermost `BEGIN` (with parentheses balanced),
    /// - at end of input.
    ///
    /// `CASE` … `END` pairs are counted separately, so a `CASE`'s `END` never closes a
    /// `BEGIN`. Keyword matching is ASCII case-insensitive.
    // NOTE(review): `BEGIN TRAN` / `BEGIN TRANSACTION` would raise `begin_depth` with no
    // matching `END`, suppressing the `;`/`GO` stops. `END TRY` / `END CATCH` stop right
    // after `END`, leaving `TRY`/`CATCH` unconsumed. Confirm whether callers can hit these.
    fn consume_until_end(&self, ctx: &mut ParseContext, children: &mut Vec<Segment>) {
        let mut paren_depth = 0u32;
        let mut begin_depth = 0u32;
        let mut case_depth = 0u32;
        while !ctx.at_eof() {
            match ctx.peek_kind() {
                Some(TokenKind::Semicolon) if paren_depth == 0 && begin_depth == 0 => break,
                Some(TokenKind::LParen) => {
                    paren_depth += 1;
                    let token = ctx.advance().unwrap();
                    children.push(any_token_segment(token));
                }
                Some(TokenKind::RParen) => {
                    // Saturating so an unbalanced `)` cannot underflow the counter.
                    paren_depth = paren_depth.saturating_sub(1);
                    let token = ctx.advance().unwrap();
                    children.push(any_token_segment(token));
                }
                // Also reached for a `;` nested in parens/BEGIN, which is consumed like any
                // other token.
                _ => {
                    let t = ctx.peek().unwrap();
                    if t.kind == TokenKind::Word {
                        if t.text.eq_ignore_ascii_case("BEGIN") {
                            begin_depth += 1;
                            let token = ctx.advance().unwrap();
                            children.push(any_token_segment(token));
                            continue;
                        } else if t.text.eq_ignore_ascii_case("CASE") {
                            case_depth += 1;
                            let token = ctx.advance().unwrap();
                            children.push(any_token_segment(token));
                            continue;
                        } else if t.text.eq_ignore_ascii_case("END") {
                            // An open CASE claims the END before any BEGIN does.
                            if case_depth > 0 {
                                case_depth -= 1;
                                let token = ctx.advance().unwrap();
                                children.push(any_token_segment(token));
                                continue;
                            }
                            if begin_depth > 0 {
                                begin_depth -= 1;
                                let token = ctx.advance().unwrap();
                                children.push(any_token_segment(token));
                                // The outermost BEGIN block is closed: the END is consumed
                                // and scanning ends.
                                if begin_depth == 0 && paren_depth == 0 {
                                    break;
                                }
                                continue;
                            }
                            // An unmatched END falls through and is consumed as a plain token.
                        } else if t.text.eq_ignore_ascii_case("GO")
                            && paren_depth == 0
                            && begin_depth == 0
                        {
                            break;
                        }
                    }
                    let token = ctx.advance().unwrap();
                    children.push(any_token_segment(token));
                }
            }
        }
    }
}
impl TsqlGrammar {
    /// Parses `DECLARE <item> [, <item>]*` into a `DeclareStatement` node.
    ///
    /// Each item is handled by `parse_declare_variable`. The trivia probed in front of a
    /// potential comma is rewound when no comma follows, so it is not attached to this
    /// statement.
    fn parse_declare_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
        let mut children = Vec::new();
        let kw = ctx.eat_keyword("DECLARE")?;
        children.push(token_segment(kw, SegmentType::Keyword));
        children.extend(eat_trivia_segments(ctx));
        self.parse_declare_variable(ctx, &mut children);
        loop {
            let save = ctx.save();
            let trivia = eat_trivia_segments(ctx);
            if let Some(comma) = ctx.eat_kind(TokenKind::Comma) {
                children.extend(trivia);
                children.push(token_segment(comma, SegmentType::Comma));
                children.extend(eat_trivia_segments(ctx));
                self.parse_declare_variable(ctx, &mut children);
            } else {
                ctx.restore(save);
                break;
            }
        }
        Some(Segment::Node(NodeSegment::new(
            SegmentType::DeclareStatement,
            children,
        )))
    }

    /// Parses one `DECLARE` item and appends its segments to `children`.
    ///
    /// Recognised shapes:
    /// - `@var [AS] <data type> [= <expr>]`
    /// - `@var [AS] TABLE ( ... )`
    /// - `@var [AS] CURSOR ...` or `name CURSOR [<option words>] [FOR <select>]`
    ///
    /// A bare (non-`@`) name is only consumed when `CURSOR` follows it. Trivia after a
    /// data type is only consumed when an `=` initializer follows.
    fn parse_declare_variable(&self, ctx: &mut ParseContext, children: &mut Vec<Segment>) {
        if ctx.peek_kind() == Some(TokenKind::AtSign) {
            let at = ctx.advance().unwrap();
            children.push(token_segment(at, SegmentType::Identifier));
            children.extend(eat_trivia_segments(ctx));
        } else if ctx.peek_kind() == Some(TokenKind::Word) {
            // Only cursor declarations use a plain name; otherwise rewind and leave the word.
            let save = ctx.save();
            let name = ctx.advance().unwrap();
            let trivia = eat_trivia_segments(ctx);
            if ctx.peek_keyword("CURSOR") {
                children.push(token_segment(name, SegmentType::Identifier));
                children.extend(trivia);
            } else {
                ctx.restore(save);
            }
        }
        if ctx.peek_keyword("AS") {
            let as_kw = ctx.advance().unwrap();
            children.push(token_segment(as_kw, SegmentType::Keyword));
            children.extend(eat_trivia_segments(ctx));
        }
        if ctx.peek_keyword("CURSOR") {
            let cursor_kw = ctx.advance().unwrap();
            children.push(token_segment(cursor_kw, SegmentType::Keyword));
            children.extend(eat_trivia_segments(ctx));
            // Cursor options are a run of bare words up to `FOR`.
            while !ctx.at_eof() && !ctx.peek_keyword("FOR") {
                if ctx.peek_kind() == Some(TokenKind::Semicolon) {
                    break;
                }
                if ctx.peek_kind() == Some(TokenKind::Word) {
                    let opt = ctx.advance().unwrap();
                    children.push(token_segment(opt, SegmentType::Keyword));
                    children.extend(eat_trivia_segments(ctx));
                } else {
                    break;
                }
            }
            if ctx.peek_keyword("FOR") {
                let for_kw = ctx.advance().unwrap();
                children.push(token_segment(for_kw, SegmentType::Keyword));
                children.extend(eat_trivia_segments(ctx));
                if let Some(sel) = self.parse_select_statement(ctx) {
                    children.push(sel);
                }
            }
            return;
        }
        if ctx.peek_keyword("TABLE") {
            let table_kw = ctx.advance().unwrap();
            children.push(token_segment(table_kw, SegmentType::Keyword));
            children.extend(eat_trivia_segments(ctx));
            if ctx.peek_kind() == Some(TokenKind::LParen) {
                if let Some(defs) = self.parse_paren_block(ctx) {
                    children.push(defs);
                }
            }
            return;
        }
        if let Some(dt) = self.parse_data_type(ctx) {
            children.push(dt);
        }
        // Probe for an `= <expr>` initializer. Without one, rewind so trailing whitespace
        // and comments stay outside this declaration.
        let save = ctx.save();
        let trivia = eat_trivia_segments(ctx);
        if ctx.peek_kind() == Some(TokenKind::Eq) {
            children.extend(trivia);
            let eq = ctx.advance().unwrap();
            children.push(token_segment(eq, SegmentType::ComparisonOperator));
            children.extend(eat_trivia_segments(ctx));
            if let Some(expr) = self.parse_expression(ctx) {
                children.push(expr);
            }
        } else {
            ctx.restore(save);
        }
    }

    /// Parses a `SET` statement into a `SetVariableStatement` node.
    ///
    /// `SET @var <op> <expr>` is parsed structurally. `<op>` is one of `=`, `+`, `-`, `*`,
    /// `/`, optionally followed by `=` (compound assignment); each operator token is tagged
    /// `Operator`. Other forms that start with a word (e.g. `SET NOCOUNT ON`) are consumed
    /// via `consume_until_statement_end`. If neither shape matches, the context is rewound
    /// and `None` is returned.
    fn parse_set_variable_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
        let save = ctx.save();
        let mut children = Vec::new();
        let kw = ctx.eat_keyword("SET")?;
        children.push(token_segment(kw, SegmentType::Keyword));
        children.extend(eat_trivia_segments(ctx));
        if ctx.peek_kind() == Some(TokenKind::AtSign) {
            let at = ctx.advance().unwrap();
            children.push(token_segment(at, SegmentType::Identifier));
            children.extend(eat_trivia_segments(ctx));
            if let Some(kind) = ctx.peek_kind() {
                if matches!(
                    kind,
                    TokenKind::Eq
                        | TokenKind::Plus
                        | TokenKind::Minus
                        | TokenKind::Star
                        | TokenKind::Slash
                ) {
                    let op = ctx.advance().unwrap();
                    children.push(token_segment(op, SegmentType::Operator));
                    // Second half of a compound assignment such as `+=`.
                    if ctx.peek_kind() == Some(TokenKind::Eq) {
                        let eq = ctx.advance().unwrap();
                        children.push(token_segment(eq, SegmentType::Operator));
                    }
                    children.extend(eat_trivia_segments(ctx));
                    if let Some(expr) = self.parse_expression(ctx) {
                        children.push(expr);
                    }
                }
            }
            return Some(Segment::Node(NodeSegment::new(
                SegmentType::SetVariableStatement,
                children,
            )));
        }
        if ctx.peek_kind() == Some(TokenKind::Word) {
            self.consume_until_statement_end(ctx, &mut children);
            return Some(Segment::Node(NodeSegment::new(
                SegmentType::SetVariableStatement,
                children,
            )));
        }
        ctx.restore(save);
        None
    }

    /// Parses `IF <condition> <statement> [ELSE <statement>]` into an `IfStatement` node.
    ///
    /// Trivia after the THEN branch is only attached when an `ELSE` follows. Otherwise it
    /// is left for the caller, like the other statement parsers in this grammar.
    fn parse_if_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
        let mut children = Vec::new();
        let kw = ctx.eat_keyword("IF")?;
        children.push(token_segment(kw, SegmentType::Keyword));
        children.extend(eat_trivia_segments(ctx));
        if let Some(cond) = self.parse_expression(ctx) {
            children.push(cond);
        }
        children.extend(eat_trivia_segments(ctx));
        if let Some(stmt) = self.parse_statement(ctx) {
            children.push(stmt);
        }
        let save = ctx.save();
        let trivia = eat_trivia_segments(ctx);
        if ctx.peek_keyword("ELSE") {
            children.extend(trivia);
            let else_kw = ctx.advance().unwrap();
            children.push(token_segment(else_kw, SegmentType::Keyword));
            children.extend(eat_trivia_segments(ctx));
            if let Some(stmt) = self.parse_statement(ctx) {
                children.push(stmt);
            }
        } else {
            ctx.restore(save);
        }
        Some(Segment::Node(NodeSegment::new(
            SegmentType::IfStatement,
            children,
        )))
    }

    /// Parses `BEGIN <statements> END` into a `BeginEndBlock` node.
    ///
    /// `BEGIN TRY` is delegated to `parse_try_catch_block`. A missing `END` is tolerated:
    /// the block ends wherever the statement list stops.
    fn parse_begin_block(&self, ctx: &mut ParseContext) -> Option<Segment> {
        // NOTE(review): `BEGIN TRAN` / `BEGIN TRANSACTION` also reach this parser and would
        // be treated as a block looking for `END`. Confirm whether transaction statements
        // need their own branch here.
        if ctx.peek_keywords(&["BEGIN", "TRY"]) {
            return self.parse_try_catch_block(ctx);
        }
        let mut children = Vec::new();
        let begin_kw = ctx.eat_keyword("BEGIN")?;
        children.push(token_segment(begin_kw, SegmentType::Keyword));
        parse_statement_list(self, ctx, &mut children, |c| c.peek_keyword("END"));
        children.extend(eat_trivia_segments(ctx));
        if let Some(end_kw) = ctx.eat_keyword("END") {
            children.push(token_segment(end_kw, SegmentType::Keyword));
        }
        Some(Segment::Node(NodeSegment::new(
            SegmentType::BeginEndBlock,
            children,
        )))
    }

    /// Parses `BEGIN TRY … END TRY [BEGIN CATCH … END CATCH]` into a `TryCatchBlock` node.
    ///
    /// The CATCH half is only attached when the next keywords are `BEGIN CATCH`. Any other
    /// following `BEGIN` block is left for the caller, along with the trivia before it.
    /// If `TRY` is missing after `BEGIN`, the context is rewound and `None` is returned.
    fn parse_try_catch_block(&self, ctx: &mut ParseContext) -> Option<Segment> {
        let start = ctx.save();
        let mut children = Vec::new();
        let begin_kw = ctx.eat_keyword("BEGIN")?;
        children.push(token_segment(begin_kw, SegmentType::Keyword));
        children.extend(eat_trivia_segments(ctx));
        // Rewind instead of silently dropping the already-consumed `BEGIN`.
        let Some(try_kw) = ctx.eat_keyword("TRY") else {
            ctx.restore(start);
            return None;
        };
        children.push(token_segment(try_kw, SegmentType::Keyword));
        parse_statement_list(self, ctx, &mut children, |c| {
            c.peek_keywords(&["END", "TRY"])
        });
        children.extend(eat_trivia_segments(ctx));
        if let Some(end_kw) = ctx.eat_keyword("END") {
            children.push(token_segment(end_kw, SegmentType::Keyword));
            children.extend(eat_trivia_segments(ctx));
        }
        if let Some(try_kw) = ctx.eat_keyword("TRY") {
            children.push(token_segment(try_kw, SegmentType::Keyword));
        }
        let before_catch = ctx.save();
        let trivia = eat_trivia_segments(ctx);
        if ctx.peek_keywords(&["BEGIN", "CATCH"]) {
            children.extend(trivia);
            if let Some(begin_kw) = ctx.eat_keyword("BEGIN") {
                children.push(token_segment(begin_kw, SegmentType::Keyword));
            }
            children.extend(eat_trivia_segments(ctx));
            if let Some(catch_kw) = ctx.eat_keyword("CATCH") {
                children.push(token_segment(catch_kw, SegmentType::Keyword));
            }
            parse_statement_list(self, ctx, &mut children, |c| {
                c.peek_keywords(&["END", "CATCH"])
            });
            children.extend(eat_trivia_segments(ctx));
            if let Some(end_kw) = ctx.eat_keyword("END") {
                children.push(token_segment(end_kw, SegmentType::Keyword));
                children.extend(eat_trivia_segments(ctx));
            }
            if let Some(catch_kw) = ctx.eat_keyword("CATCH") {
                children.push(token_segment(catch_kw, SegmentType::Keyword));
            }
        } else {
            ctx.restore(before_catch);
        }
        Some(Segment::Node(NodeSegment::new(
            SegmentType::TryCatchBlock,
            children,
        )))
    }

    /// Parses `WHILE <condition> <statement>` into a `WhileStatement` node.
    fn parse_while_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
        let mut children = Vec::new();
        let kw = ctx.eat_keyword("WHILE")?;
        children.push(token_segment(kw, SegmentType::Keyword));
        children.extend(eat_trivia_segments(ctx));
        if let Some(cond) = self.parse_expression(ctx) {
            children.push(cond);
        }
        children.extend(eat_trivia_segments(ctx));
        if let Some(stmt) = self.parse_statement(ctx) {
            children.push(stmt);
        }
        Some(Segment::Node(NodeSegment::new(
            SegmentType::WhileStatement,
            children,
        )))
    }

    /// Parses `EXEC[UTE] [@status =] <name> [<args>]` into an `ExecStatement` node.
    ///
    /// Trivia after the procedure name is only attached when arguments follow.
    fn parse_exec_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
        let mut children = Vec::new();
        let kw = if ctx.peek_keyword("EXEC") {
            ctx.eat_keyword("EXEC")?
        } else {
            ctx.eat_keyword("EXECUTE")?
        };
        children.push(token_segment(kw, SegmentType::Keyword));
        children.extend(eat_trivia_segments(ctx));
        // `@var =` is only taken as a return-status target when `=` follows. Otherwise
        // rewind so the `@var` is left for the name/argument parsing below.
        let save = ctx.save();
        if ctx.peek_kind() == Some(TokenKind::AtSign) {
            let at = ctx.advance().unwrap();
            let trivia = eat_trivia_segments(ctx);
            if ctx.peek_kind() == Some(TokenKind::Eq) {
                children.push(token_segment(at, SegmentType::Identifier));
                children.extend(trivia);
                let eq = ctx.advance().unwrap();
                children.push(token_segment(eq, SegmentType::Operator));
                children.extend(eat_trivia_segments(ctx));
            } else {
                ctx.restore(save);
            }
        }
        if let Some(name) = self.parse_qualified_name(ctx) {
            children.push(name);
        }
        // Collect arguments into a scratch buffer so the trivia probed in front of them can
        // be rewound when there are none.
        let before_args = ctx.save();
        let trivia = eat_trivia_segments(ctx);
        let mut args = Vec::new();
        self.parse_exec_params(ctx, &mut args);
        if args.is_empty() {
            ctx.restore(before_args);
        } else {
            children.extend(trivia);
            children.extend(args);
        }
        Some(Segment::Node(NodeSegment::new(
            SegmentType::ExecStatement,
            children,
        )))
    }

    /// Appends comma-separated argument expressions of an `EXEC` call to `children`.
    ///
    /// Does nothing at end of input, at `;`, or when the next token starts a new statement.
    fn parse_exec_params(&self, ctx: &mut ParseContext, children: &mut Vec<Segment>) {
        if ctx.at_eof()
            || ctx.peek_kind() == Some(TokenKind::Semicolon)
            || self.peek_statement_start(ctx)
        {
            return;
        }
        parse_comma_separated(ctx, children, |c| self.parse_expression(c));
    }

    /// Parses `RETURN [<expr>]` into a `ReturnStatement` node.
    ///
    /// The expression is only parsed when the next token is not `;`, end of input, or the
    /// start of another statement. In those cases the probed trivia is rewound.
    fn parse_return_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
        let mut children = Vec::new();
        let kw = ctx.eat_keyword("RETURN")?;
        children.push(token_segment(kw, SegmentType::Keyword));
        let save = ctx.save();
        let trivia = eat_trivia_segments(ctx);
        if !ctx.at_eof()
            && ctx.peek_kind() != Some(TokenKind::Semicolon)
            && !self.peek_statement_start(ctx)
        {
            children.extend(trivia);
            if let Some(expr) = self.parse_expression(ctx) {
                children.push(expr);
            }
        } else {
            ctx.restore(save);
        }
        Some(Segment::Node(NodeSegment::new(
            SegmentType::ReturnStatement,
            children,
        )))
    }

    /// Parses `PRINT <expr>` into a `PrintStatement` node.
    fn parse_print_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
        let mut children = Vec::new();
        let kw = ctx.eat_keyword("PRINT")?;
        children.push(token_segment(kw, SegmentType::Keyword));
        children.extend(eat_trivia_segments(ctx));
        if let Some(expr) = self.parse_expression(ctx) {
            children.push(expr);
        }
        Some(Segment::Node(NodeSegment::new(
            SegmentType::PrintStatement,
            children,
        )))
    }

    /// Parses `THROW` into a `ThrowStatement` node, either bare or with up to three
    /// comma-separated argument expressions.
    fn parse_throw_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
        let mut children = Vec::new();
        let kw = ctx.eat_keyword("THROW")?;
        children.push(token_segment(kw, SegmentType::Keyword));
        let save = ctx.save();
        let trivia = eat_trivia_segments(ctx);
        if ctx.at_eof()
            || ctx.peek_kind() == Some(TokenKind::Semicolon)
            || self.peek_statement_start(ctx)
        {
            ctx.restore(save);
            return Some(Segment::Node(NodeSegment::new(
                SegmentType::ThrowStatement,
                children,
            )));
        }
        children.extend(trivia);
        if let Some(expr) = self.parse_expression(ctx) {
            children.push(expr);
        }
        // At most two further arguments follow the first one.
        for _ in 0..2 {
            let save2 = ctx.save();
            let trivia2 = eat_trivia_segments(ctx);
            if let Some(comma) = ctx.eat_kind(TokenKind::Comma) {
                children.extend(trivia2);
                children.push(token_segment(comma, SegmentType::Comma));
                children.extend(eat_trivia_segments(ctx));
                if let Some(expr) = self.parse_expression(ctx) {
                    children.push(expr);
                }
            } else {
                ctx.restore(save2);
                break;
            }
        }
        Some(Segment::Node(NodeSegment::new(
            SegmentType::ThrowStatement,
            children,
        )))
    }

    /// Parses `RAISERROR ( ... ) [WITH <option> [, <option>]*]` into a `RaiserrorStatement`.
    ///
    /// The parenthesised arguments are kept as a single paren block. Trivia after them is
    /// only attached when a `WITH` option list follows.
    fn parse_raiserror_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
        let mut children = Vec::new();
        let kw = ctx.eat_keyword("RAISERROR")?;
        children.push(token_segment(kw, SegmentType::Keyword));
        children.extend(eat_trivia_segments(ctx));
        if ctx.peek_kind() == Some(TokenKind::LParen) {
            if let Some(args) = self.parse_paren_block(ctx) {
                children.push(args);
            }
        }
        let before_with = ctx.save();
        let trivia = eat_trivia_segments(ctx);
        if ctx.peek_keyword("WITH") {
            children.extend(trivia);
            let with_kw = ctx.advance().unwrap();
            children.push(token_segment(with_kw, SegmentType::Keyword));
            children.extend(eat_trivia_segments(ctx));
            while ctx.peek_kind() == Some(TokenKind::Word) {
                let opt = ctx.advance().unwrap();
                children.push(token_segment(opt, SegmentType::Keyword));
                let save = ctx.save();
                let trivia = eat_trivia_segments(ctx);
                if ctx.peek_kind() == Some(TokenKind::Comma) {
                    children.extend(trivia);
                    let comma = ctx.advance().unwrap();
                    children.push(token_segment(comma, SegmentType::Comma));
                    children.extend(eat_trivia_segments(ctx));
                } else {
                    ctx.restore(save);
                    break;
                }
            }
        } else {
            ctx.restore(before_with);
        }
        Some(Segment::Node(NodeSegment::new(
            SegmentType::RaiserrorStatement,
            children,
        )))
    }

    /// Parses `GO` into a `GoStatement` node.
    ///
    /// An immediately following number literal (the `GO <count>` form) is included;
    /// otherwise the probed trivia is rewound.
    fn parse_go_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
        let mut children = Vec::new();
        let kw = ctx.eat_keyword("GO")?;
        children.push(token_segment(kw, SegmentType::Keyword));
        let save = ctx.save();
        let trivia = eat_trivia_segments(ctx);
        if ctx.peek_kind() == Some(TokenKind::NumberLiteral) {
            children.extend(trivia);
            let num = ctx.advance().unwrap();
            children.push(token_segment(num, SegmentType::NumericLiteral));
        } else {
            ctx.restore(save);
        }
        Some(Segment::Node(NodeSegment::new(
            SegmentType::GoStatement,
            children,
        )))
    }
}