use sqlparser::ast::{
DuplicateTreatment, Expr, FunctionArg, FunctionArgExpr, FunctionArguments, JoinConstraint,
JoinOperator, LimitClause, OrderByKind, Query, Select, SelectItem, SetExpr, Statement,
TableFactor, TableWithJoins,
};
use crate::error::{Result, SQLRiteError};
/// Aggregate functions that may appear in a projection list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFn {
    Count,
    Sum,
    Avg,
    Min,
    Max,
}

impl AggregateFn {
    /// Canonical upper-case SQL spelling, used when rendering display names
    /// and error messages.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Count => "COUNT",
            Self::Sum => "SUM",
            Self::Avg => "AVG",
            Self::Min => "MIN",
            Self::Max => "MAX",
        }
    }

    /// Looks up an aggregate by SQL name, ignoring ASCII case.
    /// Returns `None` for any other function name.
    fn from_name(name: &str) -> Option<Self> {
        const CANDIDATES: [AggregateFn; 5] = [
            AggregateFn::Count,
            AggregateFn::Sum,
            AggregateFn::Avg,
            AggregateFn::Min,
            AggregateFn::Max,
        ];
        CANDIDATES
            .into_iter()
            .find(|candidate| candidate.as_str().eq_ignore_ascii_case(name))
    }
}

/// The single argument accepted by an aggregate call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateArg {
    /// `*`, as in `COUNT(*)`.
    Star,
    /// A column name; a table qualifier, if written, is not stored here.
    Column(String),
}

/// A parsed aggregate call such as `COUNT(DISTINCT col)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AggregateCall {
    pub func: AggregateFn,
    pub arg: AggregateArg,
    /// `true` when the call was written with `DISTINCT`.
    pub distinct: bool,
}

impl AggregateCall {
    /// Renders the call back as SQL-like text, e.g. `SUM(price)` or
    /// `COUNT(DISTINCT id)`. `DISTINCT` is only shown for column arguments.
    pub fn display_name(&self) -> String {
        let func = self.func.as_str();
        match (&self.arg, self.distinct) {
            (AggregateArg::Star, _) => format!("{func}(*)"),
            (AggregateArg::Column(column), true) => format!("{func}(DISTINCT {column})"),
            (AggregateArg::Column(column), false) => format!("{func}({column})"),
        }
    }
}
/// One entry of an explicit projection list, with its optional `AS` alias.
#[derive(Debug, Clone)]
pub struct ProjectionItem {
    /// What this entry computes.
    pub kind: ProjectionKind,
    /// Alias from `expr AS alias`, if one was written.
    pub alias: Option<String>,
}

impl ProjectionItem {
    /// Output name of this entry: the alias when present, otherwise the bare
    /// column name or the aggregate's display name (e.g. `COUNT(*)`).
    pub fn output_name(&self) -> String {
        match (self.alias.as_deref(), &self.kind) {
            (Some(alias), _) => alias.to_owned(),
            (None, ProjectionKind::Column { name, .. }) => name.clone(),
            (None, ProjectionKind::Aggregate(call)) => call.display_name(),
        }
    }
}

/// The value produced by a projection entry.
#[derive(Debug, Clone)]
pub enum ProjectionKind {
    /// A column reference, optionally qualified (`t.col`).
    Column {
        qualifier: Option<String>,
        name: String,
    },
    /// An aggregate call such as `SUM(col)`.
    Aggregate(AggregateCall),
}

/// The projection list of a `SELECT`.
#[derive(Debug, Clone)]
pub enum Projection {
    /// A lone `*`.
    All,
    /// An explicit list of columns and/or aggregates.
    Items(Vec<ProjectionItem>),
}
/// The single `ORDER BY` key of a query.
#[derive(Debug, Clone)]
pub struct OrderByClause {
    /// Sort key, kept as the raw sqlparser expression.
    pub expr: Expr,
    /// `true` for `ASC`, `false` for `DESC`; `true` when no direction was written.
    pub ascending: bool,
}
/// Join flavours accepted in `FROM ... JOIN ... ON ...`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    Inner,
    LeftOuter,
    RightOuter,
    FullOuter,
}

impl JoinType {
    /// SQL keyword(s) naming this join flavour, e.g. `"LEFT OUTER"`.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Inner => "INNER",
            Self::LeftOuter => "LEFT OUTER",
            Self::RightOuter => "RIGHT OUTER",
            Self::FullOuter => "FULL OUTER",
        }
    }
}
/// One `JOIN <table> [alias] ON <expr>` clause attached to the `FROM` table.
#[derive(Debug, Clone)]
pub struct JoinClause {
    /// Join flavour (inner / left / right / full outer).
    pub join_type: JoinType,
    /// Table on the right-hand side of the join.
    pub right_table: String,
    /// Alias for `right_table`, if one was given.
    pub right_alias: Option<String>,
    /// The `ON` condition, kept as the raw sqlparser expression.
    pub on: Expr,
}
/// A `SELECT` statement lowered into the subset of SQL this engine handles.
///
/// Built by [`SelectQuery::new`], which returns an error for the unsupported
/// constructs it recognizes instead of silently dropping them.
#[derive(Debug, Clone)]
pub struct SelectQuery {
    /// Name of the first table in `FROM` (the left-most input of any joins).
    pub table_name: String,
    /// Alias given to `table_name`, if any.
    pub table_alias: Option<String>,
    /// `JOIN` clauses in the order they were written.
    pub joins: Vec<JoinClause>,
    /// Columns / aggregates to produce, or `*`.
    pub projection: Projection,
    /// The raw `WHERE` predicate, if present.
    pub selection: Option<Expr>,
    /// The single `ORDER BY` key, if present.
    pub order_by: Option<OrderByClause>,
    /// Row count from `LIMIT n`, if present.
    pub limit: Option<usize>,
    /// Whether `SELECT DISTINCT` was requested.
    pub distinct: bool,
    /// Column names from `GROUP BY` (qualifiers stripped); empty when absent.
    pub group_by: Vec<String>,
}

impl SelectQuery {
    /// Parses and validates a `SELECT` statement.
    ///
    /// # Errors
    ///
    /// * [`SQLRiteError::Internal`] when `statement` is not a query, or when a
    ///   projected column is neither listed in `GROUP BY` nor aggregated.
    /// * [`SQLRiteError::NotImplemented`] for unsupported syntax handled here:
    ///   `WITH` (CTEs), `FETCH`, set operations / `VALUES`, `DISTINCT ON`,
    ///   `HAVING`, `GROUP BY ALL`, `GROUP BY` modifiers, non-column `GROUP BY`
    ///   keys, and aggregates / `GROUP BY` / `DISTINCT` combined with joins.
    /// * Any error raised while parsing the `FROM`, projection, `ORDER BY` or
    ///   `LIMIT` parts is propagated unchanged.
    pub fn new(statement: &Statement) -> Result<Self> {
        let Statement::Query(query) = statement else {
            return Err(SQLRiteError::Internal(
                "Error parsing SELECT: expected a Query statement".to_string(),
            ));
        };
        let Query {
            with,
            body,
            order_by,
            limit_clause,
            fetch,
            ..
        } = query.as_ref();
        // A CTE in front of a plain SELECT body would otherwise be discarded,
        // and its name would then be treated as an ordinary table reference.
        if with.is_some() {
            return Err(SQLRiteError::NotImplemented(
                "WITH (common table expressions) is not supported yet".to_string(),
            ));
        }
        // `FETCH FIRST n ROWS` is a row limit; ignoring it returns too many rows.
        if fetch.is_some() {
            return Err(SQLRiteError::NotImplemented(
                "FETCH FIRST ... ROWS is not supported — use LIMIT instead".to_string(),
            ));
        }
        let SetExpr::Select(select) = body.as_ref() else {
            return Err(SQLRiteError::NotImplemented(
                "Only simple SELECT queries are supported (no UNION / VALUES / CTEs yet)"
                    .to_string(),
            ));
        };
        let Select {
            projection,
            from,
            selection,
            distinct,
            group_by,
            having,
            ..
        } = select.as_ref();
        let distinct_flag = match distinct {
            None => false,
            Some(sqlparser::ast::Distinct::Distinct) => true,
            Some(sqlparser::ast::Distinct::All) => false,
            Some(sqlparser::ast::Distinct::On(_)) => {
                return Err(SQLRiteError::NotImplemented(
                    "SELECT DISTINCT ON (...) is not supported".to_string(),
                ));
            }
        };
        if having.is_some() {
            return Err(SQLRiteError::NotImplemented(
                "HAVING is not supported yet".to_string(),
            ));
        }
        let group_by_cols = parse_group_by(group_by)?;
        let (table_name, table_alias, joins) = extract_from_clause(from)?;
        let projection = parse_projection(projection)?;
        let order_by = parse_order_by(order_by.as_ref())?;
        let limit = parse_limit(limit_clause.as_ref())?;
        // Every non-aggregated projected column must be a grouping key.
        if !group_by_cols.is_empty()
            && let Projection::Items(items) = &projection
        {
            for item in items {
                if let ProjectionKind::Column { name: c, .. } = &item.kind
                    && !group_by_cols.contains(c)
                {
                    return Err(SQLRiteError::Internal(format!(
                        "column '{c}' must appear in GROUP BY or be used in an aggregate function"
                    )));
                }
            }
        }
        if !joins.is_empty() {
            let has_agg = matches!(
                &projection,
                Projection::Items(items)
                    if items.iter().any(|i| matches!(i.kind, ProjectionKind::Aggregate(_)))
            );
            if has_agg || !group_by_cols.is_empty() {
                return Err(SQLRiteError::NotImplemented(
                    "GROUP BY / aggregate functions over JOIN results are not supported yet"
                        .to_string(),
                ));
            }
            if distinct_flag {
                return Err(SQLRiteError::NotImplemented(
                    "SELECT DISTINCT over JOIN results is not supported yet".to_string(),
                ));
            }
        }
        Ok(SelectQuery {
            table_name,
            table_alias,
            joins,
            projection,
            selection: selection.clone(),
            order_by,
            limit,
            distinct: distinct_flag,
            group_by: group_by_cols,
        })
    }
}

/// Converts a `GROUP BY` clause into the list of grouped column names.
///
/// Each key must be a bare or qualified column reference; only its final
/// identifier is kept. sqlparser represents a missing `GROUP BY` as an empty
/// expression list, which yields an empty vector here.
fn parse_group_by(group_by: &sqlparser::ast::GroupByExpr) -> Result<Vec<String>> {
    let sqlparser::ast::GroupByExpr::Expressions(exprs, modifiers) = group_by else {
        return Err(SQLRiteError::NotImplemented(
            "GROUP BY ALL is not supported".to_string(),
        ));
    };
    // WITH ROLLUP / CUBE / TOTALS request extra super-aggregate rows; nothing in
    // `SelectQuery` records them, so accepting them would silently drop those rows.
    if !modifiers.is_empty() {
        return Err(SQLRiteError::NotImplemented(
            "GROUP BY modifiers (WITH ROLLUP / WITH CUBE / WITH TOTALS) are not supported"
                .to_string(),
        ));
    }
    exprs
        .iter()
        .map(|expr| match expr {
            Expr::Identifier(ident) => Ok(ident.value.clone()),
            Expr::CompoundIdentifier(parts) => parts
                .last()
                .map(|p| p.value.clone())
                .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string())),
            other => Err(SQLRiteError::NotImplemented(format!(
                "GROUP BY only supports bare column references for now, got {other:?}"
            ))),
        })
        .collect()
}
/// Splits a `FROM` clause into the base table, its alias, and any explicit joins.
///
/// Exactly one `FROM` item is accepted (comma-joins are rejected). Each join
/// must be INNER / LEFT / RIGHT / FULL OUTER with an `ON` condition.
fn extract_from_clause(
    from: &[TableWithJoins],
) -> Result<(String, Option<String>, Vec<JoinClause>)> {
    let base = match from {
        [] => {
            return Err(SQLRiteError::Internal(
                "SELECT requires a FROM clause".to_string(),
            ));
        }
        [only] => only,
        _ => {
            return Err(SQLRiteError::NotImplemented(
                "comma-separated FROM lists are not supported — use explicit JOIN syntax"
                    .to_string(),
            ));
        }
    };
    let (table_name, table_alias) = extract_table_factor(&base.relation)?;
    let joins = base
        .joins
        .iter()
        .map(|join| -> Result<JoinClause> {
            let (right_table, right_alias) = extract_table_factor(&join.relation)?;
            let (join_type, constraint) = match &join.join_operator {
                JoinOperator::Join(c) | JoinOperator::Inner(c) => (JoinType::Inner, c),
                JoinOperator::Left(c) | JoinOperator::LeftOuter(c) => (JoinType::LeftOuter, c),
                JoinOperator::Right(c) | JoinOperator::RightOuter(c) => (JoinType::RightOuter, c),
                JoinOperator::FullOuter(c) => (JoinType::FullOuter, c),
                unsupported => {
                    return Err(SQLRiteError::NotImplemented(format!(
                        "join flavor {unsupported:?} is not supported \
                         (only INNER / LEFT OUTER / RIGHT OUTER / FULL OUTER with ON)"
                    )));
                }
            };
            Ok(JoinClause {
                join_type,
                right_table,
                right_alias,
                on: parse_on(constraint)?,
            })
        })
        .collect::<Result<Vec<_>>>()?;
    Ok((table_name, table_alias, joins))
}
fn extract_table_factor(tf: &TableFactor) -> Result<(String, Option<String>)> {
match tf {
TableFactor::Table { name, alias, .. } => {
let table_name = name.to_string();
let alias_name = alias.as_ref().map(|a| a.name.value.clone());
if let Some(a) = alias.as_ref()
&& !a.columns.is_empty()
{
return Err(SQLRiteError::NotImplemented(
"table alias column lists are not supported".to_string(),
));
}
Ok((table_name, alias_name))
}
_ => Err(SQLRiteError::NotImplemented(
"only plain table references are supported in FROM / JOIN".to_string(),
)),
}
}
/// Extracts the `ON` expression of a join; every other constraint form is an error.
fn parse_on(constraint: &JoinConstraint) -> Result<Expr> {
    let reason = match constraint {
        JoinConstraint::On(expr) => return Ok(expr.clone()),
        JoinConstraint::Using(_) => {
            "JOIN ... USING (...) is not supported yet — use JOIN ... ON instead"
        }
        JoinConstraint::Natural => "NATURAL JOIN is not supported",
        JoinConstraint::None => {
            "JOIN without an ON condition is not supported (use INNER JOIN ... ON ...)"
        }
    };
    Err(SQLRiteError::NotImplemented(reason.to_string()))
}
fn parse_projection(items: &[SelectItem]) -> Result<Projection> {
if items.len() == 1
&& let SelectItem::Wildcard(_) = &items[0]
{
return Ok(Projection::All);
}
let mut out = Vec::with_capacity(items.len());
for item in items {
out.push(parse_select_item(item)?);
}
Ok(Projection::Items(out))
}
/// Parses one entry of an explicit projection list, keeping any `AS` alias.
///
/// # Errors
///
/// Returns `NotImplemented` for wildcards. An unqualified `*` reaches this
/// function only when mixed with other items. A qualified `t.*` is reported
/// on its own, since it is rejected even when it is the only item.
fn parse_select_item(item: &SelectItem) -> Result<ProjectionItem> {
    match item {
        SelectItem::UnnamedExpr(expr) => parse_projection_expr(expr, None),
        SelectItem::ExprWithAlias { expr, alias } => {
            parse_projection_expr(expr, Some(alias.value.clone()))
        }
        SelectItem::Wildcard(_) => Err(SQLRiteError::NotImplemented(
            "Wildcard mixed with other columns is not supported".to_string(),
        )),
        SelectItem::QualifiedWildcard(_, _) => Err(SQLRiteError::NotImplemented(
            "qualified wildcards (`table.*`) are not supported yet".to_string(),
        )),
    }
}
/// Turns one projection expression into a [`ProjectionItem`] carrying `alias`.
///
/// Accepted forms:
/// - `col`
/// - `col` written as a one-part compound identifier
/// - `qualifier.col`
/// - a supported aggregate call
fn parse_projection_expr(expr: &Expr, alias: Option<String>) -> Result<ProjectionItem> {
    let kind = match expr {
        Expr::Identifier(ident) => ProjectionKind::Column {
            qualifier: None,
            name: ident.value.clone(),
        },
        Expr::CompoundIdentifier(parts) => match parts.as_slice() {
            [column] => ProjectionKind::Column {
                qualifier: None,
                name: column.value.clone(),
            },
            [table, column] => ProjectionKind::Column {
                qualifier: Some(table.value.clone()),
                name: column.value.clone(),
            },
            _ => {
                return Err(SQLRiteError::NotImplemented(format!(
                    "compound identifier with {} parts is not supported in projection",
                    parts.len()
                )));
            }
        },
        Expr::Function(func) => ProjectionKind::Aggregate(parse_aggregate_call(func)?),
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "Only bare column references and aggregate functions are supported in the projection list (got {other:?})"
            )));
        }
    };
    Ok(ProjectionItem { kind, alias })
}
fn parse_aggregate_call(func: &sqlparser::ast::Function) -> Result<AggregateCall> {
let name = match func.name.0.as_slice() {
[sqlparser::ast::ObjectNamePart::Identifier(ident)] => ident.value.clone(),
_ => {
return Err(SQLRiteError::NotImplemented(format!(
"qualified function names not supported: {:?}",
func.name
)));
}
};
let agg_fn = AggregateFn::from_name(&name).ok_or_else(|| {
SQLRiteError::NotImplemented(format!(
"function '{name}' is not supported in the projection list (only aggregate functions are: COUNT, SUM, AVG, MIN, MAX)"
))
})?;
let arg_list = match &func.args {
FunctionArguments::List(l) => l,
_ => {
return Err(SQLRiteError::NotImplemented(format!(
"{name}(...) — unsupported argument shape"
)));
}
};
let distinct = matches!(
arg_list.duplicate_treatment,
Some(DuplicateTreatment::Distinct)
);
if !arg_list.clauses.is_empty() {
return Err(SQLRiteError::NotImplemented(format!(
"{name}(...) — extra argument clauses (ORDER BY / LIMIT inside the call) are not supported"
)));
}
if func.over.is_some() {
return Err(SQLRiteError::NotImplemented(
"window functions (OVER (...)) are not supported".to_string(),
));
}
if func.filter.is_some() {
return Err(SQLRiteError::NotImplemented(
"FILTER (WHERE ...) on aggregates is not supported".to_string(),
));
}
if !func.within_group.is_empty() {
return Err(SQLRiteError::NotImplemented(
"WITHIN GROUP on aggregates is not supported".to_string(),
));
}
if arg_list.args.len() != 1 {
return Err(SQLRiteError::NotImplemented(format!(
"{name}(...) expects exactly one argument, got {}",
arg_list.args.len()
)));
}
let arg = match &arg_list.args[0] {
FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => AggregateArg::Star,
FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(ident))) => {
AggregateArg::Column(ident.value.clone())
}
FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts))) => {
let c = parts
.last()
.map(|p| p.value.clone())
.ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
AggregateArg::Column(c)
}
other => {
return Err(SQLRiteError::NotImplemented(format!(
"{name}(...) — argument must be `*` or a bare column reference (got {other:?})"
)));
}
};
if distinct && agg_fn != AggregateFn::Count {
return Err(SQLRiteError::NotImplemented(format!(
"DISTINCT is only supported on COUNT(...) for now, not {}",
agg_fn.as_str()
)));
}
if matches!(arg, AggregateArg::Star) && agg_fn != AggregateFn::Count {
return Err(SQLRiteError::NotImplemented(format!(
"{}(*) is not supported; use {}(<column>)",
agg_fn.as_str(),
agg_fn.as_str()
)));
}
Ok(AggregateCall {
func: agg_fn,
arg,
distinct,
})
}
/// Extracts the single `ORDER BY` key, if any.
///
/// Only one sort expression is accepted. Its direction defaults to ascending
/// when neither `ASC` nor `DESC` is written.
fn parse_order_by(order_by: Option<&sqlparser::ast::OrderBy>) -> Result<Option<OrderByClause>> {
    let Some(clause) = order_by else {
        return Ok(None);
    };
    let OrderByKind::Expressions(keys) = &clause.kind else {
        return Err(SQLRiteError::NotImplemented(
            "ORDER BY ALL is not supported".to_string(),
        ));
    };
    match keys.as_slice() {
        [key] => Ok(Some(OrderByClause {
            expr: key.expr.clone(),
            ascending: key.options.asc.unwrap_or(true),
        })),
        _ => Err(SQLRiteError::NotImplemented(
            "ORDER BY must have exactly one column for now".to_string(),
        )),
    }
}
fn parse_limit(limit: Option<&LimitClause>) -> Result<Option<usize>> {
let Some(lc) = limit else {
return Ok(None);
};
let limit_expr = match lc {
LimitClause::LimitOffset { limit, offset, .. } => {
if offset.is_some() {
return Err(SQLRiteError::NotImplemented(
"OFFSET is not supported yet".to_string(),
));
}
limit.as_ref()
}
LimitClause::OffsetCommaLimit { .. } => {
return Err(SQLRiteError::NotImplemented(
"`LIMIT <offset>, <limit>` syntax is not supported yet".to_string(),
));
}
};
let Some(expr) = limit_expr else {
return Ok(None);
};
let n = eval_const_usize(expr)?;
Ok(Some(n))
}
/// Evaluates a `LIMIT` expression that must be a plain, non-negative integer
/// literal.
fn eval_const_usize(expr: &Expr) -> Result<usize> {
    let Expr::Value(literal) = expr else {
        return Err(SQLRiteError::NotImplemented(
            "LIMIT expression must be a literal number".to_string(),
        ));
    };
    let sqlparser::ast::Value::Number(digits, _) = &literal.value else {
        return Err(SQLRiteError::Internal(
            "LIMIT must be an integer literal".to_string(),
        ));
    };
    digits.parse::<usize>().map_err(|err| {
        SQLRiteError::Internal(format!("LIMIT must be a non-negative integer: {err}"))
    })
}