use anyhow::{Result, anyhow, bail, ensure};
use ordered_float::OrderedFloat;
use sqlparser::ast::{
AlterTableOperation, BinaryOperator, Expr, FunctionArg, FunctionArgExpr, FunctionArguments,
Ident, ObjectName, ObjectType, OrderByExpr, Query, Select, SelectItem, SetExpr,
ShowStatementOptions, Statement, TableFactor, Value as SqlValue,
};
use sqlparser::dialect::SQLiteDialect;
use sqlparser::parser::Parser;
use crate::types::{DataType, Schema, Value};
#[derive(Debug)]
pub enum Command {
CreateTable(Schema),
CreateTableIfNotExists(Schema),
AlterTable(AlterTablePlan),
DropTable(DropTablePlan),
ShowTables,
DescribeTable(String),
Insert(InsertPlan),
Select(SelectPlan),
Eval(EvalPlan),
}
#[derive(Debug)]
pub struct AlterTablePlan {
pub table_name: String,
pub if_exists: bool,
pub operation: AlterTableOperationPlan,
}
#[derive(Debug)]
pub enum AlterTableOperationPlan {
RenameTable {
new_table_name: String,
},
RenameColumn {
old_column_name: String,
new_column_name: String,
},
}
#[derive(Debug)]
pub struct DropTablePlan {
pub table_name: String,
pub if_exists: bool,
}
#[derive(Debug)]
pub struct InsertPlan {
pub table_name: String,
pub rows: Vec<Vec<Value>>,
}
#[derive(Debug, Clone, Copy)]
pub enum CompareOp {
Eq,
NotEq,
Lt,
Lte,
Gt,
Gte,
IsNull,
IsNotNull,
}
#[derive(Debug, Clone)]
pub struct Filter {
pub column: String,
pub op: CompareOp,
pub value: Option<Value>,
}
#[derive(Debug, Clone, Copy)]
pub enum AggregateKind {
Count,
Sum,
Min,
Max,
Avg,
}
#[derive(Debug, Clone)]
pub enum ProjectionExpr {
Column(String),
Scalar(ScalarExpr),
Aggregate {
kind: AggregateKind,
column: Option<String>,
},
}
#[derive(Debug, Clone)]
pub struct Projection {
pub expr: ProjectionExpr,
pub alias: String,
}
#[derive(Debug, Clone)]
pub struct OrderByPlan {
pub field: String,
pub descending: bool,
}
#[derive(Debug, Clone)]
pub struct SelectPlan {
pub table_name: String,
pub filters: Vec<Filter>,
pub projections: Vec<Projection>,
pub group_by: Vec<String>,
pub order_by: Option<OrderByPlan>,
pub limit: Option<usize>,
}
#[derive(Debug, Clone, Copy)]
pub enum ArithmeticOp {
Add,
Sub,
Mul,
Div,
Mod,
}
#[derive(Debug, Clone)]
pub enum ScalarExpr {
Literal(Value),
ColumnRef(String),
BinaryOp {
left: Box<ScalarExpr>,
op: ArithmeticOp,
right: Box<ScalarExpr>,
},
UnaryMinus(Box<ScalarExpr>),
}
#[derive(Debug)]
pub struct EvalPlan {
pub exprs: Vec<(ScalarExpr, String)>,
}
pub fn parse_sql(sql: &str) -> Result<Vec<Command>> {
let dialect = SQLiteDialect {};
let statements = Parser::parse_sql(&dialect, sql)?;
statements.into_iter().map(parse_statement).collect()
}
fn parse_statement(statement: Statement) -> Result<Command> {
match statement {
Statement::CreateTable(create) => parse_create_table(create),
Statement::AlterTable {
name,
if_exists,
operations,
..
} => parse_alter_table(name, if_exists, operations),
Statement::Drop {
object_type,
if_exists,
names,
..
} => parse_drop(object_type, if_exists, names),
Statement::ShowTables {
terse,
history,
extended,
full,
external,
show_options,
} => parse_show_tables(terse, history, extended, full, external, show_options),
Statement::ExplainTable { table_name, .. } => Ok(Command::DescribeTable(
object_name_to_string(&table_name.0)?,
)),
Statement::Insert(insert) => parse_insert(insert),
Statement::Query(query) => parse_query(*query),
other => bail!("unsupported statement: {other}"),
}
}
fn parse_show_tables(
terse: bool,
history: bool,
extended: bool,
full: bool,
external: bool,
show_options: ShowStatementOptions,
) -> Result<Command> {
ensure!(
!terse && !history && !extended && !full && !external,
"SHOW TABLES options are not supported yet"
);
ensure!(
show_options.show_in.is_none(),
"SHOW TABLES IN is not supported yet"
);
ensure!(
show_options.starts_with.is_none(),
"SHOW TABLES STARTS WITH is not supported yet"
);
ensure!(
show_options.limit.is_none(),
"SHOW TABLES LIMIT is not supported yet"
);
ensure!(
show_options.limit_from.is_none(),
"SHOW TABLES LIMIT ... FROM is not supported yet"
);
ensure!(
show_options.filter_position.is_none(),
"SHOW TABLES LIKE/WHERE is not supported yet"
);
Ok(Command::ShowTables)
}
fn parse_alter_table(
name: ObjectName,
if_exists: bool,
operations: Vec<AlterTableOperation>,
) -> Result<Command> {
ensure!(
operations.len() == 1,
"ALTER TABLE supports exactly one operation"
);
let table_name = object_name_to_string(&name.0)?;
let operation = match operations.into_iter().next().unwrap() {
AlterTableOperation::RenameTable { table_name } => AlterTableOperationPlan::RenameTable {
new_table_name: object_name_to_string(&table_name.0)?,
},
AlterTableOperation::RenameColumn {
old_column_name,
new_column_name,
} => AlterTableOperationPlan::RenameColumn {
old_column_name: old_column_name.value,
new_column_name: new_column_name.value,
},
other => {
bail!("unsupported ALTER TABLE operation: {other}. Supported: RENAME TO, RENAME COLUMN")
}
};
Ok(Command::AlterTable(AlterTablePlan {
table_name,
if_exists,
operation,
}))
}
fn parse_drop(object_type: ObjectType, if_exists: bool, names: Vec<ObjectName>) -> Result<Command> {
ensure!(
object_type == ObjectType::Table,
"only DROP TABLE is supported"
);
ensure!(
names.len() == 1,
"DROP TABLE supports exactly one table at a time"
);
Ok(Command::DropTable(DropTablePlan {
table_name: object_name_to_string(&names[0].0)?,
if_exists,
}))
}
fn parse_create_table(statement: sqlparser::ast::CreateTable) -> Result<Command> {
let table_name = object_name_to_string(&statement.name.0)?;
let if_not_exists = statement.if_not_exists;
let mut columns = Vec::with_capacity(statement.columns.len());
for column in statement.columns {
columns.push(crate::types::ColumnDef {
name: column.name.value,
data_type: DataType::from_sql_name(&column.data_type.to_string())?,
});
}
let schema = Schema {
table_name,
columns,
};
if if_not_exists {
Ok(Command::CreateTableIfNotExists(schema))
} else {
Ok(Command::CreateTable(schema))
}
}
fn parse_insert(statement: sqlparser::ast::Insert) -> Result<Command> {
ensure!(
statement.columns.is_empty(),
"explicit insert column lists are not supported yet"
);
let table_name = object_name_to_string(&statement.table_name.0)?;
let Some(source) = statement.source else {
bail!("INSERT source is required")
};
let rows = match *source.body {
SetExpr::Values(values) => values
.rows
.into_iter()
.map(|row| row.into_iter().map(parse_sql_value_expr).collect())
.collect::<Result<Vec<Vec<Value>>>>()?,
other => bail!("unsupported INSERT source: {other}"),
};
Ok(Command::Insert(InsertPlan { table_name, rows }))
}
fn parse_query(query: Query) -> Result<Command> {
let select = match *query.body {
SetExpr::Select(select) => select,
other => bail!("unsupported query body: {other}"),
};
if select.from.is_empty() {
return parse_eval(*select);
}
let plan = parse_select(*select, query.order_by, query.limit)?;
Ok(Command::Select(plan))
}
fn parse_select(
select: Select,
order_by: Option<sqlparser::ast::OrderBy>,
limit: Option<Expr>,
) -> Result<SelectPlan> {
ensure!(select.from.len() == 1, "exactly one table is supported");
let table_name = match &select.from[0].relation {
TableFactor::Table { name, .. } => object_name_to_string(&name.0)?,
other => bail!("unsupported table source: {other}"),
};
let filters = match select.selection {
Some(expr) => parse_filters(expr)?,
None => Vec::new(),
};
let group_by = match select.group_by {
sqlparser::ast::GroupByExpr::Expressions(exprs, _) => exprs
.into_iter()
.map(parse_identifier_expr)
.collect::<Result<Vec<String>>>()?,
sqlparser::ast::GroupByExpr::All(_) => bail!("GROUP BY ALL is not supported"),
};
let projections = parse_projections(select.projection, &group_by)?;
let order_by = parse_order_by(order_by)?;
let limit = match limit {
Some(expr) => Some(parse_usize_literal(expr)?),
None => None,
};
Ok(SelectPlan {
table_name,
filters,
projections,
group_by,
order_by,
limit,
})
}
fn parse_filters(expr: Expr) -> Result<Vec<Filter>> {
match expr {
Expr::BinaryOp {
left,
op: BinaryOperator::And,
right,
} => {
let mut filters = parse_filters(*left)?;
filters.extend(parse_filters(*right)?);
Ok(filters)
}
Expr::BinaryOp { left, op, right } => {
let column = parse_identifier_expr(*left)?;
let value = parse_sql_value_expr(*right)?;
let op = match op {
BinaryOperator::Eq => CompareOp::Eq,
BinaryOperator::NotEq => CompareOp::NotEq,
BinaryOperator::Lt => CompareOp::Lt,
BinaryOperator::LtEq => CompareOp::Lte,
BinaryOperator::Gt => CompareOp::Gt,
BinaryOperator::GtEq => CompareOp::Gte,
other => bail!("unsupported filter operator: {other}"),
};
Ok(vec![Filter {
column,
op,
value: Some(value),
}])
}
Expr::IsNull(expr) => {
let column = parse_identifier_expr(*expr)?;
Ok(vec![Filter {
column,
op: CompareOp::IsNull,
value: None,
}])
}
Expr::IsNotNull(expr) => {
let column = parse_identifier_expr(*expr)?;
Ok(vec![Filter {
column,
op: CompareOp::IsNotNull,
value: None,
}])
}
other => bail!("unsupported WHERE expression: {other}"),
}
}
fn parse_projections(items: Vec<SelectItem>, group_by: &[String]) -> Result<Vec<Projection>> {
let mut projections = Vec::new();
for item in items {
match item {
SelectItem::Wildcard(_) => projections.push(Projection {
expr: ProjectionExpr::Column("*".to_string()),
alias: "*".to_string(),
}),
SelectItem::UnnamedExpr(expr) => {
projections.push(parse_projection_expr(expr, group_by, None)?)
}
SelectItem::ExprWithAlias { expr, alias } => {
projections.push(parse_projection_expr(expr, group_by, Some(alias.value))?)
}
other => bail!("unsupported SELECT projection: {other}"),
}
}
Ok(projections)
}
fn parse_projection_expr(
expr: Expr,
group_by: &[String],
alias_override: Option<String>,
) -> Result<Projection> {
match expr {
Expr::Identifier(ident) => {
let alias = alias_override.unwrap_or_else(|| ident.value.clone());
if !group_by.is_empty() && !group_by.iter().any(|column| column == &ident.value) {
bail!(
"non-aggregated projection {} must appear in GROUP BY",
ident.value
);
}
Ok(Projection {
expr: ProjectionExpr::Column(ident.value),
alias,
})
}
Expr::CompoundIdentifier(parts) => {
let name = compound_identifier_to_string(&parts)?;
let alias = alias_override.unwrap_or_else(|| name.clone());
if !group_by.is_empty() && !group_by.iter().any(|column| column == &name) {
bail!("non-aggregated projection {name} must appear in GROUP BY");
}
Ok(Projection {
expr: ProjectionExpr::Column(name),
alias,
})
}
Expr::Function(function) => {
let name = function.name.to_string().to_ascii_uppercase();
let kind = match name.as_str() {
"COUNT" => AggregateKind::Count,
"SUM" => AggregateKind::Sum,
"MIN" => AggregateKind::Min,
"MAX" => AggregateKind::Max,
"AVG" => AggregateKind::Avg,
_ => bail!("unsupported aggregate: {name}"),
};
let column = parse_function_column(function.args)?;
let alias = alias_override.unwrap_or_else(|| {
if let Some(column) = &column {
format!("{}_{}", name.to_ascii_lowercase(), column)
} else {
name.to_ascii_lowercase()
}
});
Ok(Projection {
expr: ProjectionExpr::Aggregate { kind, column },
alias,
})
}
other => {
let alias = alias_override.unwrap_or_else(|| other.to_string());
let scalar = parse_scalar_expr(other)?;
Ok(Projection {
expr: ProjectionExpr::Scalar(scalar),
alias,
})
}
}
}
fn parse_function_column(arguments: FunctionArguments) -> Result<Option<String>> {
match arguments {
FunctionArguments::None => Ok(None),
FunctionArguments::Subquery(_) => bail!("subquery function arguments are not supported"),
FunctionArguments::List(list) => {
if list.args.len() == 1 {
match &list.args[0] {
FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(None),
FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => {
Ok(Some(parse_identifier_expr(expr.clone())?))
}
other => bail!("unsupported function argument: {other}"),
}
} else {
bail!("only single-argument aggregates are supported")
}
}
}
}
fn parse_order_by(order_by: Option<sqlparser::ast::OrderBy>) -> Result<Option<OrderByPlan>> {
let Some(order_by) = order_by else {
return Ok(None);
};
ensure!(
order_by.exprs.len() == 1,
"only one ORDER BY expression is supported"
);
let OrderByExpr { expr, asc, .. } = order_by.exprs.into_iter().next().unwrap();
Ok(Some(OrderByPlan {
field: parse_identifier_expr(expr)?,
descending: matches!(asc, Some(false)),
}))
}
fn parse_identifier_expr(expr: Expr) -> Result<String> {
match expr {
Expr::Identifier(ident) => Ok(ident.value),
Expr::CompoundIdentifier(parts) => compound_identifier_to_string(&parts),
other => bail!("expected column identifier, got {other}"),
}
}
fn compound_identifier_to_string(parts: &[Ident]) -> Result<String> {
let Some(last) = parts.last() else {
bail!("empty identifier")
};
Ok(last.value.clone())
}
fn parse_usize_literal(expr: Expr) -> Result<usize> {
match parse_sql_value_expr(expr)? {
Value::Int64(value) => Ok(value as usize),
other => bail!("expected integer literal, got {other:?}"),
}
}
fn parse_sql_value_expr(expr: Expr) -> Result<Value> {
match expr {
Expr::Value(value) => parse_sql_value(value),
Expr::UnaryOp { op, expr } if op.to_string() == "-" => match parse_sql_value_expr(*expr)? {
Value::Int64(value) => Ok(Value::Int64(-value)),
Value::Float64(value) => Ok(Value::Float64(OrderedFloat(-value.into_inner()))),
other => bail!("unsupported unary negation for {other:?}"),
},
Expr::IsNull(_) => {
bail!("IS NULL expressions must appear in WHERE clauses")
}
other => bail!("expected literal value, got {other}"),
}
}
fn parse_sql_value(value: sqlparser::ast::Value) -> Result<Value> {
match value {
SqlValue::Number(number, _) => {
if number.contains('.') {
Ok(Value::Float64(OrderedFloat(number.parse::<f64>()?)))
} else {
Ok(Value::Int64(number.parse::<i64>()?))
}
}
SqlValue::SingleQuotedString(value) | SqlValue::DoubleQuotedString(value) => {
Ok(Value::String(value))
}
SqlValue::Boolean(value) => Ok(Value::Bool(value)),
SqlValue::Null => Ok(Value::Null),
other => Err(anyhow!("unsupported literal value: {other}")),
}
}
fn object_name_to_string(parts: &[Ident]) -> Result<String> {
let Some(last) = parts.last() else {
bail!("empty object name")
};
Ok(last.value.clone())
}
fn parse_eval(select: Select) -> Result<Command> {
let mut exprs = Vec::new();
for item in select.projection {
match item {
SelectItem::UnnamedExpr(expr) => {
let alias = expr.to_string();
exprs.push((parse_scalar_expr(expr)?, alias));
}
SelectItem::ExprWithAlias { expr, alias } => {
exprs.push((parse_scalar_expr(expr)?, alias.value));
}
other => bail!("unsupported expression in SELECT: {other}"),
}
}
Ok(Command::Eval(EvalPlan { exprs }))
}
fn parse_scalar_expr(expr: Expr) -> Result<ScalarExpr> {
match expr {
Expr::Identifier(ident) => Ok(ScalarExpr::ColumnRef(ident.value)),
Expr::CompoundIdentifier(parts) => Ok(ScalarExpr::ColumnRef(
compound_identifier_to_string(&parts)?,
)),
Expr::Value(value) => Ok(ScalarExpr::Literal(parse_sql_value(value)?)),
Expr::BinaryOp { left, op, right } => {
let arith_op = match op {
BinaryOperator::Plus => ArithmeticOp::Add,
BinaryOperator::Minus => ArithmeticOp::Sub,
BinaryOperator::Multiply => ArithmeticOp::Mul,
BinaryOperator::Divide => ArithmeticOp::Div,
BinaryOperator::Modulo => ArithmeticOp::Mod,
other => bail!("unsupported operator in expression: {other}"),
};
Ok(ScalarExpr::BinaryOp {
left: Box::new(parse_scalar_expr(*left)?),
op: arith_op,
right: Box::new(parse_scalar_expr(*right)?),
})
}
Expr::UnaryOp { op, expr } if op.to_string() == "-" => {
Ok(ScalarExpr::UnaryMinus(Box::new(parse_scalar_expr(*expr)?)))
}
Expr::Nested(inner) => parse_scalar_expr(*inner),
other => bail!("unsupported expression: {other}"),
}
}