use crate::{TableId, Value};
use serde::{Deserialize, Serialize};

10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12pub struct SelectStatement {
13 pub select_clause: SelectClause,
15 pub from_clause: Option<FromClause>,
17 pub where_clause: Option<WhereExpression>,
19 pub group_by: Option<GroupByClause>,
21 pub having_clause: Option<WhereExpression>,
23 pub order_by: Option<OrderByClause>,
25 pub limit: Option<LimitClause>,
27 pub offset: Option<u64>,
29 pub allow_filtering: bool,
31}
32
33#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
35pub enum SelectClause {
36 All,
38 Columns(Vec<SelectExpression>),
40 Distinct(Vec<SelectExpression>),
42}
43
44#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
46pub enum SelectExpression {
47 Column(ColumnRef),
49 Aggregate(AggregateFunction),
51 Function(FunctionCall),
53 Literal(Value),
55 CollectionAccess(CollectionAccessExpression),
57 Arithmetic(ArithmeticExpression),
59 Aliased(Box<SelectExpression>, String),
61}
62
63#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
65pub struct ColumnRef {
66 pub table: Option<String>,
68 pub column: String,
70}
71
72#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
74pub struct AggregateFunction {
75 pub function: AggregateType,
77 pub args: Vec<SelectExpression>,
79 pub distinct: bool,
81}
82
83#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
85pub enum AggregateType {
86 Count,
87 Sum,
88 Avg,
89 Min,
90 Max,
91}
92
93#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
95pub struct FunctionCall {
96 pub name: String,
98 pub args: Vec<SelectExpression>,
100}
101
102#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
104pub enum CollectionAccessExpression {
105 ListIndex(ColumnRef, Box<SelectExpression>),
107 MapKey(ColumnRef, Box<SelectExpression>),
109 SetContains(ColumnRef, Box<SelectExpression>),
111}
112
113#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
115pub struct ArithmeticExpression {
116 pub left: Box<SelectExpression>,
118 pub operator: ArithmeticOperator,
120 pub right: Box<SelectExpression>,
122}
123
124#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
126pub enum ArithmeticOperator {
127 Add,
128 Subtract,
129 Multiply,
130 Divide,
131 Modulo,
132}
133
134#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
136pub enum FromClause {
137 Table(TableId),
139 TableAlias(TableId, String),
141}
142
143#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
145#[allow(clippy::large_enum_variant)]
146pub enum WhereExpression {
147 Comparison(ComparisonExpression),
149 And(Vec<WhereExpression>),
151 Or(Vec<WhereExpression>),
153 Not(Box<WhereExpression>),
155 Parentheses(Box<WhereExpression>),
157}
158
159#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
161pub struct ComparisonExpression {
162 pub left: SelectExpression,
164 pub operator: ComparisonOperator,
166 pub right: ComparisonRightSide,
168}
169
170#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
172pub enum ComparisonRightSide {
173 Value(SelectExpression),
175 ValueList(Vec<SelectExpression>),
177 Range(SelectExpression, SelectExpression),
179}
180
181#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
183pub enum ComparisonOperator {
184 Equal,
186 NotEqual,
188 LessThan,
190 LessThanOrEqual,
192 GreaterThan,
194 GreaterThanOrEqual,
196 In,
198 NotIn,
200 Like,
202 NotLike,
204 Between,
206 NotBetween,
208 IsNull,
210 IsNotNull,
212 Regex,
214 Contains,
216 ContainsKey,
218}
219
220#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
222pub struct GroupByClause {
223 pub columns: Vec<ColumnRef>,
225}
226
227#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
229pub struct OrderByClause {
230 pub items: Vec<OrderByItem>,
232}
233
234#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
236pub struct OrderByItem {
237 pub expression: SelectExpression,
239 pub direction: SortDirection,
241}
242
243#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
245pub enum SortDirection {
246 Ascending,
247 Descending,
248}
249
250#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
252pub struct LimitClause {
253 pub count: u64,
255 pub per_partition: bool,
257}
258
259impl SelectStatement {
260 pub fn select_all_from(table: TableId) -> Self {
262 Self {
263 select_clause: SelectClause::All,
264 from_clause: Some(FromClause::Table(table)),
265 where_clause: None,
266 group_by: None,
267 having_clause: None,
268 order_by: None,
269 limit: None,
270 offset: None,
271 allow_filtering: false,
272 }
273 }
274
275 pub fn requires_aggregation(&self) -> bool {
277 self.group_by.is_some() || self.has_aggregate_functions()
278 }
279
280 pub fn has_aggregate_functions(&self) -> bool {
282 match &self.select_clause {
283 SelectClause::Columns(exprs) | SelectClause::Distinct(exprs) => {
284 exprs.iter().any(|expr| expr.is_aggregate())
285 }
286 SelectClause::All => false,
287 }
288 }
289
290 pub fn get_referenced_columns(&self) -> Vec<ColumnRef> {
295 let mut columns = Vec::new();
296
297 if let SelectClause::Columns(exprs) | SelectClause::Distinct(exprs) = &self.select_clause {
298 for expr in exprs {
299 columns.extend(expr.get_column_refs());
300 }
301 }
302
303 if let Some(where_expr) = &self.where_clause {
304 columns.extend(where_expr.get_column_refs());
305 }
306
307 if let Some(group_by) = &self.group_by {
308 columns.extend(group_by.columns.iter().cloned());
309 }
310
311 if let Some(having) = &self.having_clause {
312 columns.extend(having.get_column_refs());
313 }
314
315 if let Some(order_by) = &self.order_by {
316 for item in &order_by.items {
317 columns.extend(item.expression.get_column_refs());
318 }
319 }
320
321 columns
322 }
323}
324
325impl SelectExpression {
326 pub fn is_aggregate(&self) -> bool {
328 matches!(self, SelectExpression::Aggregate(_))
329 }
330
331 pub fn get_column_refs(&self) -> Vec<ColumnRef> {
333 match self {
334 SelectExpression::Column(col_ref) => vec![col_ref.clone()],
335 SelectExpression::Aggregate(agg) => collect_refs(&agg.args),
336 SelectExpression::Function(func) => collect_refs(&func.args),
337 SelectExpression::CollectionAccess(access) => {
338 let (col_ref, sub_expr) = match access {
339 CollectionAccessExpression::ListIndex(c, e)
340 | CollectionAccessExpression::MapKey(c, e)
341 | CollectionAccessExpression::SetContains(c, e) => (c, e),
342 };
343 let mut refs = vec![col_ref.clone()];
344 refs.extend(sub_expr.get_column_refs());
345 refs
346 }
347 SelectExpression::Arithmetic(arith) => {
348 let mut refs = arith.left.get_column_refs();
349 refs.extend(arith.right.get_column_refs());
350 refs
351 }
352 SelectExpression::Aliased(expr, _) => expr.get_column_refs(),
353 SelectExpression::Literal(_) => Vec::new(),
354 }
355 }
356}
357
358fn collect_refs(exprs: &[SelectExpression]) -> Vec<ColumnRef> {
360 exprs
361 .iter()
362 .flat_map(SelectExpression::get_column_refs)
363 .collect()
364}
365
366impl WhereExpression {
367 pub fn get_column_refs(&self) -> Vec<ColumnRef> {
369 match self {
370 WhereExpression::Comparison(comp) => {
371 let mut refs = comp.left.get_column_refs();
372 match &comp.right {
373 ComparisonRightSide::Value(expr) => {
374 refs.extend(expr.get_column_refs());
375 }
376 ComparisonRightSide::ValueList(exprs) => {
377 refs.extend(collect_refs(exprs));
378 }
379 ComparisonRightSide::Range(start, end) => {
380 refs.extend(start.get_column_refs());
381 refs.extend(end.get_column_refs());
382 }
383 }
384 refs
385 }
386 WhereExpression::And(exprs) | WhereExpression::Or(exprs) => exprs
387 .iter()
388 .flat_map(WhereExpression::get_column_refs)
389 .collect(),
390 WhereExpression::Not(expr) | WhereExpression::Parentheses(expr) => {
391 expr.get_column_refs()
392 }
393 }
394 }
395
396 pub fn can_pushdown_to_sstable(&self) -> bool {
401 match self {
402 WhereExpression::Comparison(comp) => {
403 matches!(comp.left, SelectExpression::Column(_))
404 && matches!(
405 comp.operator,
406 ComparisonOperator::Equal
407 | ComparisonOperator::LessThan
408 | ComparisonOperator::LessThanOrEqual
409 | ComparisonOperator::GreaterThan
410 | ComparisonOperator::GreaterThanOrEqual
411 | ComparisonOperator::In
412 | ComparisonOperator::Between
413 )
414 }
415 WhereExpression::And(exprs) => {
416 exprs.iter().all(WhereExpression::can_pushdown_to_sstable)
417 }
418 WhereExpression::Or(_) | WhereExpression::Not(_) => false,
419 WhereExpression::Parentheses(expr) => expr.can_pushdown_to_sstable(),
420 }
421 }
422}
423
424impl ColumnRef {
425 pub fn new(column: impl Into<String>) -> Self {
427 Self {
428 table: None,
429 column: column.into(),
430 }
431 }
432
433 pub fn qualified(table: impl Into<String>, column: impl Into<String>) -> Self {
435 Self {
436 table: Some(table.into()),
437 column: column.into(),
438 }
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// `select_all_from` yields a `SELECT *` that needs no aggregation.
    #[test]
    fn test_simple_select_statement() {
        let stmt = SelectStatement::select_all_from(TableId::new("users"));
        assert_eq!(stmt.select_clause, SelectClause::All);
        assert!(!stmt.requires_aggregation());
    }

    /// An aggregate in the projection list is detected even without `GROUP BY`.
    #[test]
    fn test_aggregate_detection() {
        let stmt = SelectStatement {
            select_clause: SelectClause::Columns(vec![SelectExpression::Aggregate(
                AggregateFunction {
                    function: AggregateType::Count,
                    args: vec![SelectExpression::Column(ColumnRef::new("id"))],
                    distinct: false,
                },
            )]),
            from_clause: Some(FromClause::Table(TableId::new("users"))),
            where_clause: None,
            group_by: None,
            having_clause: None,
            order_by: None,
            limit: None,
            offset: None,
            allow_filtering: false,
        };

        assert!(stmt.requires_aggregation());
        assert!(stmt.has_aggregate_functions());
    }

    /// Column references are gathered from both sides of an `AND`.
    #[test]
    fn test_column_references() {
        let where_expr = WhereExpression::And(vec![
            WhereExpression::Comparison(ComparisonExpression {
                left: SelectExpression::Column(ColumnRef::new("age")),
                operator: ComparisonOperator::GreaterThan,
                right: ComparisonRightSide::Value(SelectExpression::Literal(Value::Integer(21))),
            }),
            WhereExpression::Comparison(ComparisonExpression {
                left: SelectExpression::Column(ColumnRef::new("city")),
                operator: ComparisonOperator::Equal,
                right: ComparisonRightSide::Value(SelectExpression::Literal(Value::Text(
                    "NYC".to_string(),
                ))),
            }),
        ]);

        let column_refs = where_expr.get_column_refs();
        assert_eq!(column_refs.len(), 2);
        assert!(column_refs.iter().any(|col| col.column == "age"));
        assert!(column_refs.iter().any(|col| col.column == "city"));
    }

    /// A column equality can be pushed down; an `OR` of such comparisons cannot.
    #[test]
    fn test_pushdown_capability() {
        let simple_comparison = WhereExpression::Comparison(ComparisonExpression {
            left: SelectExpression::Column(ColumnRef::new("id")),
            operator: ComparisonOperator::Equal,
            right: ComparisonRightSide::Value(SelectExpression::Literal(Value::Integer(123))),
        });

        assert!(simple_comparison.can_pushdown_to_sstable());

        let complex_or =
            WhereExpression::Or(vec![simple_comparison.clone(), simple_comparison.clone()]);

        assert!(!complex_or.can_pushdown_to_sstable());
    }
}