use crate::{TableId, Value};
use serde::{Deserialize, Serialize};
/// Parsed representation of a `SELECT` statement, with one field per clause.
///
/// Clauses stored as `Option` are `None` when the statement does not carry them
/// (see [`SelectStatement::select_all_from`], which leaves all of them unset
/// except `from_clause`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SelectStatement {
/// What is selected: everything, a list of expressions, or a distinct list.
pub select_clause: SelectClause,
/// Source table, optionally aliased; `None` when there is no `FROM`.
pub from_clause: Option<FromClause>,
/// Filter expression of the `WHERE` clause, if any.
pub where_clause: Option<WhereExpression>,
/// Grouping columns, if any. Its mere presence makes
/// `requires_aggregation` return `true`.
pub group_by: Option<GroupByClause>,
/// Filter of the `HAVING` clause; shares its expression type with `WHERE`.
pub having_clause: Option<WhereExpression>,
/// Ordering items, if any.
pub order_by: Option<OrderByClause>,
/// Row limit, if any.
pub limit: Option<LimitClause>,
/// Row offset, if any. Presumably the number of rows to skip; confirm
/// against the executor.
pub offset: Option<u64>,
/// NOTE(review): presumably mirrors CQL's `ALLOW FILTERING`; this file only
/// stores the flag and never reads it.
pub allow_filtering: bool,
}
/// The projection part of a `SELECT` statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SelectClause {
/// Select everything (`SELECT *`); carries no expressions.
All,
/// Select an explicit list of expressions.
Columns(Vec<SelectExpression>),
/// Same payload as `Columns`, tagged as distinct (presumably
/// `SELECT DISTINCT ...`; confirm against the parser).
Distinct(Vec<SelectExpression>),
}
/// A value-producing expression.
///
/// Despite the name it is not limited to the select list: the same type is
/// used for both sides of comparisons and for `ORDER BY` items.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SelectExpression {
/// A (possibly table-qualified) column reference.
Column(ColumnRef),
/// An aggregate function call such as count/sum/avg/min/max.
Aggregate(AggregateFunction),
/// A call to a function identified by name.
Function(FunctionCall),
/// A literal value.
Literal(Value),
/// Access into a collection-typed column (list index, map key, set test).
CollectionAccess(CollectionAccessExpression),
/// A binary arithmetic expression.
Arithmetic(ArithmeticExpression),
/// An inner expression paired with a `String`, presumably its `AS` alias.
/// The inner expression is boxed because the type is recursive.
Aliased(Box<SelectExpression>, String),
}
/// Reference to a column, optionally qualified with a table name.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ColumnRef {
/// Qualifier; `None` for an unqualified reference (see `ColumnRef::new`).
/// Presumably a table name or alias; this file does not resolve it.
pub table: Option<String>,
/// Column name.
pub column: String,
}
/// An aggregate function call inside an expression.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AggregateFunction {
/// Which aggregate is applied.
pub function: AggregateType,
/// Argument expressions; column references inside them are reported by
/// `SelectExpression::get_column_refs`.
pub args: Vec<SelectExpression>,
/// NOTE(review): presumably set for the `AGG(DISTINCT ...)` form; nothing in
/// this file reads it.
pub distinct: bool,
}
/// The aggregate functions the AST can represent.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AggregateType {
/// `COUNT`
Count,
/// `SUM`
Sum,
/// `AVG`
Avg,
/// `MIN`
Min,
/// `MAX`
Max,
}
/// A call to a function identified by name (as opposed to the fixed set of
/// aggregates in `AggregateType`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
/// Function name as a plain string; it is not validated in this file.
pub name: String,
/// Argument expressions, in call order.
pub args: Vec<SelectExpression>,
}
/// Access into a collection-typed column.
///
/// Every variant pairs the column being accessed with a boxed sub-expression.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CollectionAccessExpression {
/// List column plus an index expression.
ListIndex(ColumnRef, Box<SelectExpression>),
/// Map column plus a key expression.
MapKey(ColumnRef, Box<SelectExpression>),
/// Set column plus the element expression to test for membership
/// (presumably; confirm evaluation semantics with the executor).
SetContains(ColumnRef, Box<SelectExpression>),
}
/// A binary arithmetic expression: `left <operator> right`.
///
/// Operands are boxed because `SelectExpression` is recursive.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ArithmeticExpression {
/// Left-hand operand.
pub left: Box<SelectExpression>,
/// Operator applied to the two operands.
pub operator: ArithmeticOperator,
/// Right-hand operand.
pub right: Box<SelectExpression>,
}
/// Binary arithmetic operators usable in an `ArithmeticExpression`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ArithmeticOperator {
/// Addition.
Add,
/// Subtraction.
Subtract,
/// Multiplication.
Multiply,
/// Division.
Divide,
/// Remainder.
Modulo,
}
/// The `FROM` part of a statement: a single table, optionally aliased.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FromClause {
/// A table referenced by its id.
Table(TableId),
/// A table id paired with a `String`, presumably its alias.
TableAlias(TableId, String),
}
/// Boolean filter expression tree, used for both `WHERE` and `HAVING`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// NOTE(review): the lint is presumably triggered by `Comparison`, which stores
// a whole `ComparisonExpression` inline while the other variants hold only a
// `Vec` or a `Box`. Boxing it would change the public shape of the enum.
#[allow(clippy::large_enum_variant)]
pub enum WhereExpression {
/// Leaf node: a single comparison.
Comparison(ComparisonExpression),
/// Conjunction of the contained expressions.
And(Vec<WhereExpression>),
/// Disjunction of the contained expressions.
Or(Vec<WhereExpression>),
/// Negation of the inner expression.
Not(Box<WhereExpression>),
/// An explicitly parenthesised inner expression; the helper methods in this
/// file treat it as transparent and delegate to the inner expression.
Parentheses(Box<WhereExpression>),
}
/// A single comparison: `left <operator> right`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ComparisonExpression {
/// Left-hand expression. `WhereExpression::can_pushdown_to_sstable` only
/// accepts comparisons whose left side is a plain `Column`.
pub left: SelectExpression,
/// Comparison operator.
pub operator: ComparisonOperator,
/// Right-hand side; its shape (single value, list, range) is not checked
/// against `operator` anywhere in this file.
pub right: ComparisonRightSide,
}
/// The possible shapes of a comparison's right-hand side.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ComparisonRightSide {
/// A single expression.
Value(SelectExpression),
/// A list of expressions (presumably for `IN` / `NOT IN`).
ValueList(Vec<SelectExpression>),
/// Two expressions, bound as start and end by `get_column_refs`
/// (presumably the bounds of `BETWEEN` / `NOT BETWEEN`).
Range(SelectExpression, SelectExpression),
}
/// Operators usable in a `ComparisonExpression`.
///
/// `WhereExpression::can_pushdown_to_sstable` accepts only `Equal`, `LessThan`,
/// `LessThanOrEqual`, `GreaterThan`, `GreaterThanOrEqual`, `In` and `Between`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ComparisonOperator {
/// Equality.
Equal,
/// Inequality.
NotEqual,
/// Strictly less than.
LessThan,
/// Less than or equal.
LessThanOrEqual,
/// Strictly greater than.
GreaterThan,
/// Greater than or equal.
GreaterThanOrEqual,
/// Membership in a list.
In,
/// Negated list membership.
NotIn,
/// Pattern match (`LIKE`).
Like,
/// Negated pattern match.
NotLike,
/// Range test (`BETWEEN`).
Between,
/// Negated range test.
NotBetween,
/// Null test. NOTE(review): `ComparisonExpression::right` is still
/// mandatory for this operator; how it is filled is not visible here.
IsNull,
/// Negated null test; same caveat as `IsNull`.
IsNotNull,
/// Regular-expression match (syntax not defined in this file).
Regex,
/// Presumably collection element containment (`CONTAINS`); confirm.
Contains,
/// Presumably map key containment (`CONTAINS KEY`); confirm.
ContainsKey,
}
/// The `GROUP BY` clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GroupByClause {
/// Grouping columns. Only plain column references are representable, not
/// arbitrary expressions.
pub columns: Vec<ColumnRef>,
}
/// The `ORDER BY` clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByClause {
/// Ordering items, in the order they were given.
pub items: Vec<OrderByItem>,
}
/// One entry of an `ORDER BY` clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByItem {
/// Expression to sort by.
pub expression: SelectExpression,
/// Sort direction for this expression.
pub direction: SortDirection,
}
/// Sort direction of an `OrderByItem`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SortDirection {
/// Ascending order.
Ascending,
/// Descending order.
Descending,
}
/// The `LIMIT` clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LimitClause {
/// Maximum number of rows.
pub count: u64,
/// NOTE(review): presumably distinguishes CQL's `PER PARTITION LIMIT` from a
/// plain `LIMIT`; nothing in this file reads it.
pub per_partition: bool,
}
impl SelectStatement {
pub fn select_all_from(table: TableId) -> Self {
Self {
select_clause: SelectClause::All,
from_clause: Some(FromClause::Table(table)),
where_clause: None,
group_by: None,
having_clause: None,
order_by: None,
limit: None,
offset: None,
allow_filtering: false,
}
}
pub fn requires_aggregation(&self) -> bool {
self.group_by.is_some() || self.has_aggregate_functions()
}
pub fn has_aggregate_functions(&self) -> bool {
match &self.select_clause {
SelectClause::Columns(exprs) | SelectClause::Distinct(exprs) => {
exprs.iter().any(|expr| expr.is_aggregate())
}
SelectClause::All => false,
}
}
pub fn get_referenced_columns(&self) -> Vec<ColumnRef> {
let mut columns = Vec::new();
if let SelectClause::Columns(exprs) | SelectClause::Distinct(exprs) = &self.select_clause {
for expr in exprs {
columns.extend(expr.get_column_refs());
}
}
if let Some(where_expr) = &self.where_clause {
columns.extend(where_expr.get_column_refs());
}
if let Some(group_by) = &self.group_by {
columns.extend(group_by.columns.iter().cloned());
}
if let Some(having) = &self.having_clause {
columns.extend(having.get_column_refs());
}
if let Some(order_by) = &self.order_by {
for item in &order_by.items {
columns.extend(item.expression.get_column_refs());
}
}
columns
}
}
impl SelectExpression {
pub fn is_aggregate(&self) -> bool {
matches!(self, SelectExpression::Aggregate(_))
}
pub fn get_column_refs(&self) -> Vec<ColumnRef> {
match self {
SelectExpression::Column(col_ref) => vec![col_ref.clone()],
SelectExpression::Aggregate(agg) => collect_refs(&agg.args),
SelectExpression::Function(func) => collect_refs(&func.args),
SelectExpression::CollectionAccess(access) => {
let (col_ref, sub_expr) = match access {
CollectionAccessExpression::ListIndex(c, e)
| CollectionAccessExpression::MapKey(c, e)
| CollectionAccessExpression::SetContains(c, e) => (c, e),
};
let mut refs = vec![col_ref.clone()];
refs.extend(sub_expr.get_column_refs());
refs
}
SelectExpression::Arithmetic(arith) => {
let mut refs = arith.left.get_column_refs();
refs.extend(arith.right.get_column_refs());
refs
}
SelectExpression::Aliased(expr, _) => expr.get_column_refs(),
SelectExpression::Literal(_) => Vec::new(),
}
}
}
fn collect_refs(exprs: &[SelectExpression]) -> Vec<ColumnRef> {
exprs
.iter()
.flat_map(SelectExpression::get_column_refs)
.collect()
}
impl WhereExpression {
    /// Returns every column reference in this filter tree.
    ///
    /// For a comparison the left side's references come first, then the
    /// right side's. Duplicates are kept.
    pub fn get_column_refs(&self) -> Vec<ColumnRef> {
        match self {
            // Wrappers are transparent: report whatever the inner node reports.
            Self::Not(inner) | Self::Parentheses(inner) => inner.get_column_refs(),
            Self::And(children) | Self::Or(children) => {
                let mut found = Vec::new();
                for child in children {
                    found.extend(child.get_column_refs());
                }
                found
            }
            Self::Comparison(cmp) => {
                let mut found = cmp.left.get_column_refs();
                let right_side = match &cmp.right {
                    ComparisonRightSide::Value(single) => single.get_column_refs(),
                    ComparisonRightSide::ValueList(many) => collect_refs(many),
                    ComparisonRightSide::Range(low, high) => {
                        let mut bounds = low.get_column_refs();
                        bounds.extend(high.get_column_refs());
                        bounds
                    }
                };
                found.extend(right_side);
                found
            }
        }
    }

    /// Reports whether this filter qualifies for SSTable pushdown.
    ///
    /// A comparison qualifies when its left side is a plain column and its
    /// operator is one of `=`, `<`, `<=`, `>`, `>=`, `IN` or `BETWEEN`. An
    /// `And` qualifies when all of its children do, and parentheses defer to
    /// their content. `Or` and `Not` never qualify.
    pub fn can_pushdown_to_sstable(&self) -> bool {
        match self {
            Self::Or(_) | Self::Not(_) => false,
            Self::Parentheses(inner) => inner.can_pushdown_to_sstable(),
            Self::And(children) => children
                .iter()
                .all(|child| child.can_pushdown_to_sstable()),
            Self::Comparison(cmp) => {
                let left_is_plain_column = matches!(cmp.left, SelectExpression::Column(_));
                let operator_is_supported = matches!(
                    cmp.operator,
                    ComparisonOperator::Equal
                        | ComparisonOperator::LessThan
                        | ComparisonOperator::LessThanOrEqual
                        | ComparisonOperator::GreaterThan
                        | ComparisonOperator::GreaterThanOrEqual
                        | ComparisonOperator::In
                        | ComparisonOperator::Between
                );
                left_is_plain_column && operator_is_supported
            }
        }
    }
}
impl ColumnRef {
pub fn new(column: impl Into<String>) -> Self {
Self {
table: None,
column: column.into(),
}
}
pub fn qualified(table: impl Into<String>, column: impl Into<String>) -> Self {
Self {
table: Some(table.into()),
column: column.into(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the WHERE leaf `<column> <operator> <literal>`.
    fn column_vs_literal(
        column: &str,
        operator: ComparisonOperator,
        literal: Value,
    ) -> WhereExpression {
        WhereExpression::Comparison(ComparisonExpression {
            left: SelectExpression::Column(ColumnRef::new(column)),
            operator,
            right: ComparisonRightSide::Value(SelectExpression::Literal(literal)),
        })
    }

    /// `SELECT * FROM users` has the `All` clause and needs no aggregation.
    #[test]
    fn test_simple_select_statement() {
        let statement = SelectStatement::select_all_from(TableId::new("users"));

        assert_eq!(statement.select_clause, SelectClause::All);
        assert!(!statement.requires_aggregation());
    }

    /// A top-level `COUNT(id)` in the select list is detected as an aggregate.
    #[test]
    fn test_aggregate_detection() {
        let count_id = SelectExpression::Aggregate(AggregateFunction {
            function: AggregateType::Count,
            args: vec![SelectExpression::Column(ColumnRef::new("id"))],
            distinct: false,
        });
        // Start from the plain `SELECT *` statement and swap in the select list;
        // every other field keeps the same value the constructor assigns.
        let statement = SelectStatement {
            select_clause: SelectClause::Columns(vec![count_id]),
            ..SelectStatement::select_all_from(TableId::new("users"))
        };

        assert!(statement.requires_aggregation());
        assert!(statement.has_aggregate_functions());
    }

    /// Both columns of `age > 21 AND city = 'NYC'` are reported.
    #[test]
    fn test_column_references() {
        let filter = WhereExpression::And(vec![
            column_vs_literal("age", ComparisonOperator::GreaterThan, Value::Integer(21)),
            column_vs_literal(
                "city",
                ComparisonOperator::Equal,
                Value::Text("NYC".to_string()),
            ),
        ]);

        let refs = filter.get_column_refs();
        let names: Vec<&str> = refs.iter().map(|col| col.column.as_str()).collect();

        assert_eq!(names.len(), 2);
        assert!(names.contains(&"age"));
        assert!(names.contains(&"city"));
    }

    /// A simple equality can be pushed down; an `OR` of equalities cannot.
    #[test]
    fn test_pushdown_capability() {
        let by_id = column_vs_literal("id", ComparisonOperator::Equal, Value::Integer(123));
        assert!(by_id.can_pushdown_to_sstable());

        let either = WhereExpression::Or(vec![by_id.clone(), by_id.clone()]);
        assert!(!either.can_pushdown_to_sstable());
    }
}