mod delete;
mod insert;
mod select;
mod update;
use crate::db::sql::identifier::{identifier_last_segment, normalize_identifier_to_scope};
use crate::db::{
sql::parser::{
Parser, SqlAggregateCall, SqlAggregateInputExpr, SqlArithmeticProjectionCall,
SqlAssignment, SqlDeleteStatement, SqlDescribeStatement, SqlExplainMode,
SqlExplainStatement, SqlExplainTarget, SqlExpr, SqlOrderTerm, SqlProjection,
SqlProjectionOperand, SqlReturningProjection, SqlRoundProjectionCall,
SqlRoundProjectionInput, SqlSelectItem, SqlSelectStatement, SqlShowColumnsStatement,
SqlShowEntitiesStatement, SqlShowIndexesStatement, SqlStatement, SqlTextFunctionCall,
SqlUpdateStatement,
},
sql_shared::{Keyword, SqlParseError, TokenKind},
};
impl Parser {
    /// Parses a single top-level statement by dispatching on its leading keyword.
    ///
    /// Keywords are tried in this order: SELECT, DELETE, INSERT, UPDATE, EXPLAIN,
    /// DESCRIBE, SHOW. When none of them matches, the error is an
    /// unsupported-feature error if `peek_unsupported_feature` reports one, and
    /// otherwise an "expected one of ..." error carrying the current token kind.
    pub(super) fn parse_statement(&mut self) -> Result<SqlStatement, SqlParseError> {
        let statement = if self.eat_keyword(Keyword::Select) {
            SqlStatement::Select(self.parse_select_statement()?)
        } else if self.eat_keyword(Keyword::Delete) {
            SqlStatement::Delete(self.parse_delete_statement()?)
        } else if self.eat_keyword(Keyword::Insert) {
            SqlStatement::Insert(self.parse_insert_statement()?)
        } else if self.eat_keyword(Keyword::Update) {
            SqlStatement::Update(self.parse_update_statement()?)
        } else if self.eat_keyword(Keyword::Explain) {
            SqlStatement::Explain(self.parse_explain_statement()?)
        } else if self.eat_keyword(Keyword::Describe) {
            SqlStatement::Describe(self.parse_describe_statement()?)
        } else if self.eat_keyword(Keyword::Show) {
            self.parse_show_statement()?
        } else {
            // No known statement keyword: prefer the more specific
            // "unsupported feature" diagnostic over the generic one.
            return Err(match self.peek_unsupported_feature() {
                Some(feature) => SqlParseError::unsupported_feature(feature),
                None => SqlParseError::expected(
                    "one of SELECT, DELETE, INSERT, UPDATE, EXPLAIN, DESCRIBE, SHOW",
                    self.peek_kind(),
                ),
            });
        };
        Ok(statement)
    }

    /// Produces the error to report for the already-parsed `statement`, if any.
    ///
    /// SELECT, DELETE and UPDATE (and the SELECT/DELETE target of an EXPLAIN)
    /// defer to their clause-order checks. INSERT never yields an error here.
    /// DESCRIBE and the SHOW variants always yield an unsupported-feature error
    /// naming their "modifiers".
    // NOTE(review): presumably callers only invoke this when tokens remain after
    // the statement — confirm at the call site.
    pub(super) fn trailing_clause_order_error(
        &self,
        statement: &SqlStatement,
    ) -> Option<SqlParseError> {
        let unsupported_modifiers = match statement {
            SqlStatement::Select(select) => return self.select_clause_order_error(select),
            SqlStatement::Delete(delete) => return self.delete_clause_order_error(delete),
            SqlStatement::Insert(_) => return None,
            SqlStatement::Update(update) => return self.update_clause_order_error(update),
            SqlStatement::Explain(explain) => {
                return match &explain.statement {
                    SqlExplainTarget::Select(select) => self.select_clause_order_error(select),
                    SqlExplainTarget::Delete(delete) => self.delete_clause_order_error(delete),
                };
            }
            SqlStatement::Describe(_) => "DESCRIBE modifiers",
            SqlStatement::ShowIndexes(_) => "SHOW INDEXES modifiers",
            SqlStatement::ShowColumns(_) => "SHOW COLUMNS modifiers",
            SqlStatement::ShowEntities(_) => "SHOW ENTITIES modifiers",
        };
        Some(SqlParseError::unsupported_feature(unsupported_modifiers))
    }

    /// Parses the remainder of a SHOW statement (the SHOW keyword is handled by
    /// the caller).
    ///
    /// Accepts INDEXES, COLUMNS, ENTITIES and TABLES; the last two both produce
    /// `SqlStatement::ShowEntities`. Anything else is an unsupported feature.
    fn parse_show_statement(&mut self) -> Result<SqlStatement, SqlParseError> {
        if self.eat_keyword(Keyword::Indexes) {
            let show_indexes = self.parse_show_indexes_statement()?;
            return Ok(SqlStatement::ShowIndexes(show_indexes));
        }
        if self.eat_keyword(Keyword::Columns) {
            let show_columns = self.parse_show_columns_statement()?;
            return Ok(SqlStatement::ShowColumns(show_columns));
        }
        // ENTITIES is tried first; TABLES is only tried when ENTITIES did not match.
        if self.eat_keyword(Keyword::Entities) || self.eat_keyword(Keyword::Tables) {
            return Ok(SqlStatement::ShowEntities(SqlShowEntitiesStatement));
        }
        Err(SqlParseError::unsupported_feature(
            "SHOW commands beyond SHOW INDEXES/SHOW COLUMNS/SHOW ENTITIES/SHOW TABLES",
        ))
    }

    /// Parses the remainder of an EXPLAIN statement (the EXPLAIN keyword is
    /// handled by the caller).
    ///
    /// An optional EXECUTION or JSON keyword selects the mode; without either the
    /// mode is `SqlExplainMode::Plan`. The explained statement must be a SELECT or
    /// a DELETE.
    fn parse_explain_statement(&mut self) -> Result<SqlExplainStatement, SqlParseError> {
        let mode = if self.eat_keyword(Keyword::Execution) {
            SqlExplainMode::Execution
        } else if self.eat_keyword(Keyword::Json) {
            SqlExplainMode::Json
        } else {
            SqlExplainMode::Plan
        };

        if self.eat_keyword(Keyword::Select) {
            let select = self.parse_select_statement()?;
            return Ok(SqlExplainStatement {
                mode,
                statement: SqlExplainTarget::Select(select),
            });
        }
        if self.eat_keyword(Keyword::Delete) {
            let delete = self.parse_delete_statement()?;
            return Ok(SqlExplainStatement {
                mode,
                statement: SqlExplainTarget::Delete(delete),
            });
        }

        Err(match self.peek_unsupported_feature() {
            Some(feature) => SqlParseError::unsupported_feature(feature),
            None => SqlParseError::expected("one of SELECT, DELETE", self.peek_kind()),
        })
    }

    /// Reports an ORDER keyword that follows an already-parsed LIMIT or OFFSET in
    /// a SELECT.
    fn select_clause_order_error(&self, statement: &SqlSelectStatement) -> Option<SqlParseError> {
        let has_window = statement.limit.is_some() || statement.offset.is_some();
        let misplaced_order_by = self.peek_keyword(Keyword::Order) && has_window;
        misplaced_order_by
            .then(|| SqlParseError::invalid_syntax("ORDER BY must appear before LIMIT/OFFSET"))
    }

    /// Reports an ORDER keyword that follows an already-parsed LIMIT in a DELETE.
    fn delete_clause_order_error(&self, statement: &SqlDeleteStatement) -> Option<SqlParseError> {
        let misplaced_order_by = self.peek_keyword(Keyword::Order) && statement.limit.is_some();
        misplaced_order_by
            .then(|| SqlParseError::invalid_syntax("ORDER BY must appear before LIMIT in DELETE"))
    }

    /// Reports out-of-order trailing clauses in an UPDATE: ORDER after
    /// LIMIT/OFFSET, or LIMIT after OFFSET.
    fn update_clause_order_error(&self, statement: &SqlUpdateStatement) -> Option<SqlParseError> {
        let message = if self.peek_keyword(Keyword::Order)
            && (statement.limit.is_some() || statement.offset.is_some())
        {
            "ORDER BY must appear before LIMIT/OFFSET in UPDATE"
        } else if self.peek_keyword(Keyword::Limit) && statement.offset.is_some() {
            "LIMIT must appear before OFFSET in UPDATE"
        } else {
            return None;
        };
        Some(SqlParseError::invalid_syntax(message))
    }

    /// Parses an optional table alias, with or without a leading AS.
    ///
    /// After AS an identifier is mandatory. Without AS, an upcoming identifier is
    /// taken as the alias unless it spells SET or VALUES (ASCII case-insensitive),
    /// in which case no alias is reported and the identifier is left unparsed.
    /// Any other upcoming token yields `Ok(None)`.
    pub(super) fn parse_optional_table_alias(&mut self) -> Result<Option<String>, SqlParseError> {
        if self.eat_keyword(Keyword::As) {
            return self.expect_identifier().map(Some);
        }

        let Some(TokenKind::Identifier(value)) = self.peek_kind() else {
            return Ok(None);
        };
        // SET / VALUES arriving as plain identifiers start the next clause and
        // must not be mistaken for an alias.
        let starts_next_clause =
            value.eq_ignore_ascii_case("SET") || value.eq_ignore_ascii_case("VALUES");
        if starts_next_clause {
            return Ok(None);
        }
        self.expect_identifier().map(Some)
    }

    /// Parses the entity name that follows DESCRIBE.
    fn parse_describe_statement(&mut self) -> Result<SqlDescribeStatement, SqlParseError> {
        self.expect_identifier()
            .map(|entity| SqlDescribeStatement { entity })
    }

    /// Parses the entity name that follows SHOW INDEXES.
    fn parse_show_indexes_statement(&mut self) -> Result<SqlShowIndexesStatement, SqlParseError> {
        self.expect_identifier()
            .map(|entity| SqlShowIndexesStatement { entity })
    }

    /// Parses the entity name that follows SHOW COLUMNS.
    fn parse_show_columns_statement(&mut self) -> Result<SqlShowColumnsStatement, SqlParseError> {
        self.expect_identifier()
            .map(|entity| SqlShowColumnsStatement { entity })
    }
}
/// Normalizes the field names of a RETURNING projection against the scope built
/// from `entity` and `alias`.
///
/// `SqlReturningProjection::All` is returned unchanged; an explicit field list
/// is passed through `normalize_identifier_list_for_table_alias`.
pub(super) fn normalize_returning_projection_for_table_alias(
    projection: SqlReturningProjection,
    entity: &str,
    alias: &str,
) -> SqlReturningProjection {
    match projection {
        SqlReturningProjection::Fields(fields) => {
            let normalized = normalize_identifier_list_for_table_alias(fields, entity, alias);
            SqlReturningProjection::Fields(normalized)
        }
        SqlReturningProjection::All => SqlReturningProjection::All,
    }
}
/// Builds the scope list used when normalizing identifiers for an aliased table.
///
/// The list holds, in order: the entity name, the alias, and — when
/// `identifier_last_segment` yields one — the last segment of the entity name.
fn table_alias_scope(entity: &str, alias: &str) -> Vec<String> {
    let mut names = Vec::with_capacity(3);
    names.push(entity.to_owned());
    names.push(alias.to_owned());
    names.extend(identifier_last_segment(entity).map(|segment| segment.to_string()));
    names
}
pub(super) fn normalize_projection_for_table_alias(
projection: SqlProjection,
entity: &str,
alias: &str,
) -> SqlProjection {
let scope = table_alias_scope(entity, alias);
match projection {
SqlProjection::All => SqlProjection::All,
SqlProjection::Items(items) => SqlProjection::Items(
items
.into_iter()
.map(|item| match item {
SqlSelectItem::Field(field) => {
SqlSelectItem::Field(normalize_identifier_to_scope(field, scope.as_slice()))
}
SqlSelectItem::Aggregate(aggregate) => SqlSelectItem::Aggregate(
normalize_aggregate_call_for_table_alias(aggregate, scope.as_slice()),
),
SqlSelectItem::TextFunction(call) => SqlSelectItem::TextFunction(
normalize_text_function_call_for_table_alias(call, scope.as_slice()),
),
SqlSelectItem::Arithmetic(call) => SqlSelectItem::Arithmetic(
normalize_arithmetic_projection_call_for_table_alias(
call,
scope.as_slice(),
),
),
SqlSelectItem::Round(call) => SqlSelectItem::Round(
normalize_round_projection_call_for_table_alias(call, scope.as_slice()),
),
SqlSelectItem::Expr(expr) => SqlSelectItem::Expr(
normalize_sql_expr_for_table_alias(expr, scope.as_slice()),
),
})
.collect(),
),
}
}
/// Normalizes the input of a ROUND projection call against `scope`; the scale
/// is carried over unchanged.
fn normalize_round_projection_call_for_table_alias(
    call: SqlRoundProjectionCall,
    scope: &[String],
) -> SqlRoundProjectionCall {
    let SqlRoundProjectionCall { input, scale } = call;
    let input = match input {
        SqlRoundProjectionInput::Arithmetic(arithmetic) => SqlRoundProjectionInput::Arithmetic(
            normalize_arithmetic_projection_call_for_table_alias(arithmetic, scope),
        ),
        SqlRoundProjectionInput::Operand(operand) => SqlRoundProjectionInput::Operand(
            normalize_projection_operand_for_table_alias(operand, scope),
        ),
    };
    SqlRoundProjectionCall { input, scale }
}
/// Normalizes both operands of an arithmetic projection call against `scope`;
/// the operator is carried over unchanged.
fn normalize_arithmetic_projection_call_for_table_alias(
    call: SqlArithmeticProjectionCall,
    scope: &[String],
) -> SqlArithmeticProjectionCall {
    let SqlArithmeticProjectionCall { left, op, right } = call;
    // Left operand first, then right, matching the field order.
    let left = normalize_projection_operand_for_table_alias(left, scope);
    let right = normalize_projection_operand_for_table_alias(right, scope);
    SqlArithmeticProjectionCall { left, op, right }
}
/// Normalizes a single projection operand against `scope`.
///
/// Literals pass through untouched; fields, aggregates and nested arithmetic
/// calls are rewritten by their respective normalizers.
fn normalize_projection_operand_for_table_alias(
    operand: SqlProjectionOperand,
    scope: &[String],
) -> SqlProjectionOperand {
    match operand {
        SqlProjectionOperand::Literal(value) => SqlProjectionOperand::Literal(value),
        SqlProjectionOperand::Field(name) => {
            let name = normalize_identifier_to_scope(name, scope);
            SqlProjectionOperand::Field(name)
        }
        SqlProjectionOperand::Aggregate(call) => {
            let call = normalize_aggregate_call_for_table_alias(call, scope);
            SqlProjectionOperand::Aggregate(call)
        }
        SqlProjectionOperand::Arithmetic(nested) => {
            let nested = normalize_arithmetic_projection_call_for_table_alias(*nested, scope);
            SqlProjectionOperand::Arithmetic(Box::new(nested))
        }
    }
}
/// Normalizes one expression against the scope built from `entity` and `alias`.
pub(super) fn normalize_sql_expr_for_entity_alias(
    expr: SqlExpr,
    entity: &str,
    alias: &str,
) -> SqlExpr {
    normalize_sql_expr_for_table_alias(expr, &table_alias_scope(entity, alias))
}
/// Normalizes each expression in `exprs` against the scope built from `entity`
/// and `alias`, preserving order. The scope is built once and shared.
pub(super) fn normalize_sql_exprs_for_entity_alias(
    exprs: Vec<SqlExpr>,
    entity: &str,
    alias: &str,
) -> Vec<SqlExpr> {
    let scope_names = table_alias_scope(entity, alias);
    let mut normalized = Vec::with_capacity(exprs.len());
    for expr in exprs {
        normalized.push(normalize_sql_expr_for_table_alias(expr, &scope_names));
    }
    normalized
}
/// Passes each identifier in `fields` through `normalize_identifier_to_scope`
/// with the scope built from `entity` and `alias`, preserving order.
pub(super) fn normalize_identifier_list_for_table_alias(
    fields: Vec<String>,
    entity: &str,
    alias: &str,
) -> Vec<String> {
    let scope_names = table_alias_scope(entity, alias);
    let mut normalized = Vec::with_capacity(fields.len());
    for name in fields {
        normalized.push(normalize_identifier_to_scope(name, &scope_names));
    }
    normalized
}
/// Normalizes the target field of every assignment against the scope built from
/// `entity` and `alias`, preserving order.
///
/// Only the `field` side is rewritten; the assigned `value` is carried over
/// as-is.
// NOTE(review): if assignment values can reference alias-qualified columns they
// are not normalized here — confirm against the type of `SqlAssignment::value`.
pub(super) fn normalize_assignments_for_table_alias(
    assignments: Vec<SqlAssignment>,
    entity: &str,
    alias: &str,
) -> Vec<SqlAssignment> {
    let scope_names = table_alias_scope(entity, alias);
    let scope = scope_names.as_slice();
    assignments
        .into_iter()
        .map(|SqlAssignment { field, value }| SqlAssignment {
            field: normalize_identifier_to_scope(field, scope),
            value,
        })
        .collect()
}
/// Normalizes the expression of every ORDER BY term against the scope built
/// from `entity` and `alias`; each term keeps its direction and position.
pub(super) fn normalize_order_terms_for_table_alias(
    terms: Vec<SqlOrderTerm>,
    entity: &str,
    alias: &str,
) -> Vec<SqlOrderTerm> {
    let scope_names = table_alias_scope(entity, alias);
    let scope = scope_names.as_slice();
    terms
        .into_iter()
        .map(|SqlOrderTerm { field, direction }| SqlOrderTerm {
            field: normalize_sql_expr_for_table_alias(field, scope),
            direction,
        })
        .collect()
}
/// Normalizes the optional input and optional filter expression of an aggregate
/// call against `scope`; `kind` and `distinct` are carried over unchanged.
fn normalize_aggregate_call_for_table_alias(
    aggregate: SqlAggregateCall,
    scope: &[String],
) -> SqlAggregateCall {
    let SqlAggregateCall {
        kind,
        input,
        filter_expr,
        distinct,
    } = aggregate;

    let input = input.map(|boxed_input| {
        Box::new(normalize_aggregate_input_expr_for_table_alias(
            *boxed_input,
            scope,
        ))
    });
    let filter_expr = filter_expr
        .map(|boxed_filter| Box::new(normalize_sql_expr_for_table_alias(*boxed_filter, scope)));

    SqlAggregateCall {
        kind,
        input,
        filter_expr,
        distinct,
    }
}
/// Normalizes an aggregate input expression against `scope`.
///
/// Literals pass through untouched; every other variant is rewritten by the
/// normalizer for its payload type.
fn normalize_aggregate_input_expr_for_table_alias(
    expr: SqlAggregateInputExpr,
    scope: &[String],
) -> SqlAggregateInputExpr {
    match expr {
        SqlAggregateInputExpr::Literal(value) => SqlAggregateInputExpr::Literal(value),
        SqlAggregateInputExpr::Field(name) => {
            let name = normalize_identifier_to_scope(name, scope);
            SqlAggregateInputExpr::Field(name)
        }
        SqlAggregateInputExpr::Arithmetic(arithmetic) => {
            let arithmetic =
                normalize_arithmetic_projection_call_for_table_alias(arithmetic, scope);
            SqlAggregateInputExpr::Arithmetic(arithmetic)
        }
        SqlAggregateInputExpr::Round(round) => {
            let round = normalize_round_projection_call_for_table_alias(round, scope);
            SqlAggregateInputExpr::Round(round)
        }
        SqlAggregateInputExpr::Expr(inner) => {
            let inner = normalize_sql_expr_for_table_alias(inner, scope);
            SqlAggregateInputExpr::Expr(inner)
        }
    }
}
/// Recursively normalizes every identifier-bearing node of `expr` against
/// `scope`.
///
/// Field names go through `normalize_identifier_to_scope`; aggregate, text
/// function and ROUND calls go through their own normalizers; compound nodes
/// recurse into their sub-expressions. Literals and the `values` list of a
/// membership test are carried over unchanged.
pub(super) fn normalize_sql_expr_for_table_alias(expr: SqlExpr, scope: &[String]) -> SqlExpr {
    // Shorthands for the recursive step on owned and boxed sub-expressions.
    let rewrite = |inner: SqlExpr| normalize_sql_expr_for_table_alias(inner, scope);
    let rewrite_boxed =
        |inner: Box<SqlExpr>| Box::new(normalize_sql_expr_for_table_alias(*inner, scope));

    match expr {
        SqlExpr::Literal(value) => SqlExpr::Literal(value),
        SqlExpr::Field(name) => SqlExpr::Field(normalize_identifier_to_scope(name, scope)),
        SqlExpr::Aggregate(call) => {
            SqlExpr::Aggregate(normalize_aggregate_call_for_table_alias(call, scope))
        }
        SqlExpr::TextFunction(call) => {
            SqlExpr::TextFunction(normalize_text_function_call_for_table_alias(call, scope))
        }
        SqlExpr::Round(call) => {
            SqlExpr::Round(normalize_round_projection_call_for_table_alias(call, scope))
        }
        SqlExpr::Membership {
            expr: subject,
            values,
            negated,
        } => SqlExpr::Membership {
            expr: rewrite_boxed(subject),
            values,
            negated,
        },
        SqlExpr::NullTest {
            expr: subject,
            negated,
        } => SqlExpr::NullTest {
            expr: rewrite_boxed(subject),
            negated,
        },
        SqlExpr::FunctionCall { function, args } => SqlExpr::FunctionCall {
            function,
            args: args.into_iter().map(rewrite).collect(),
        },
        SqlExpr::Unary { op, expr: operand } => SqlExpr::Unary {
            op,
            expr: rewrite_boxed(operand),
        },
        SqlExpr::Binary { op, left, right } => {
            // Left side first, then right.
            let left = rewrite_boxed(left);
            let right = rewrite_boxed(right);
            SqlExpr::Binary { op, left, right }
        }
        SqlExpr::Case { arms, else_expr } => {
            let arms = arms
                .into_iter()
                .map(|arm| {
                    let condition = rewrite(arm.condition);
                    let result = rewrite(arm.result);
                    crate::db::sql::parser::SqlCaseArm { condition, result }
                })
                .collect();
            let else_expr = else_expr.map(rewrite_boxed);
            SqlExpr::Case { arms, else_expr }
        }
    }
}
/// Normalizes the field argument of a text function call against `scope`; the
/// function and its three literal slots are carried over unchanged.
fn normalize_text_function_call_for_table_alias(
    call: SqlTextFunctionCall,
    scope: &[String],
) -> SqlTextFunctionCall {
    let SqlTextFunctionCall {
        function,
        field,
        literal,
        literal2,
        literal3,
    } = call;
    let field = normalize_identifier_to_scope(field, scope);
    SqlTextFunctionCall {
        function,
        field,
        literal,
        literal2,
        literal3,
    }
}