Skip to main content

cqlite_core/query/
select_ast.rs

1//! CQL SELECT Abstract Syntax Tree.
2//!
3//! AST types for SELECT statements executed directly against SSTable files.
4//! Covers projections, WHERE expressions, aggregates, GROUP BY/HAVING,
5//! ORDER BY, LIMIT/OFFSET, collection access, and arithmetic expressions.
6
7use crate::{TableId, Value};
8use serde::{Deserialize, Serialize};
9
/// Complete SELECT statement AST.
///
/// Only the projection (`select_clause`) is mandatory; every other clause is
/// `None` when absent, and `allow_filtering` is `false` unless requested.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SelectStatement {
    /// SELECT clause - what to return
    pub select_clause: SelectClause,
    /// FROM clause - which table to query (optional for constant expressions)
    pub from_clause: Option<FromClause>,
    /// WHERE clause - row filtering conditions
    pub where_clause: Option<WhereExpression>,
    /// GROUP BY clause - grouping columns
    pub group_by: Option<GroupByClause>,
    /// HAVING clause - filtering applied after grouping
    pub having_clause: Option<WhereExpression>,
    /// ORDER BY clause - sorting specification
    pub order_by: Option<OrderByClause>,
    /// LIMIT clause - result size limitation
    pub limit: Option<LimitClause>,
    /// OFFSET clause - result pagination (presumably the number of leading
    /// rows to skip; confirm against the executor)
    pub offset: Option<u64>,
    /// `ALLOW FILTERING` flag (for non-indexed queries)
    pub allow_filtering: bool,
}
32
/// SELECT clause - defines what columns/expressions to return.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SelectClause {
    /// `SELECT *` - all columns. No column list is stored; it has to be
    /// resolved against the table schema by whoever consumes the AST.
    All,
    /// `SELECT column1, column2, ...` - an explicit projection list
    Columns(Vec<SelectExpression>),
    /// `SELECT DISTINCT column1, column2, ...` - unique values only
    Distinct(Vec<SelectExpression>),
}
43
/// Expression that can appear in a SELECT projection.
///
/// The same type is reused as the operand type elsewhere in the AST
/// (comparison sides in WHERE/HAVING, ORDER BY items, function and aggregate
/// arguments), so expressions nest recursively through the boxed and
/// vector-holding variants.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SelectExpression {
    /// Simple column reference
    Column(ColumnRef),
    /// Aggregate function call (COUNT, SUM, AVG, MIN, MAX)
    Aggregate(AggregateFunction),
    /// Scalar (non-aggregate) function call
    Function(FunctionCall),
    /// Literal value
    Literal(Value),
    /// Collection access (`list[0]`, `map['key']`, set membership)
    CollectionAccess(CollectionAccessExpression),
    /// Binary arithmetic expression (`a + b`, `a % b`, ...)
    Arithmetic(ArithmeticExpression),
    /// Aliased expression (`expr AS alias`): the wrapped expression and the
    /// alias name
    Aliased(Box<SelectExpression>, String),
}
62
/// Column reference with an optional table qualifier.
///
/// `ColumnRef::new("c")` yields an unqualified reference (`table: None`);
/// `ColumnRef::qualified("t", "c")` yields `t.c`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ColumnRef {
    /// Table name or alias qualifier (`None` for an unqualified column)
    pub table: Option<String>,
    /// Column name
    pub column: String,
}
71
/// Aggregate function call, e.g. `COUNT(DISTINCT id)` or `SUM(price)`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AggregateFunction {
    /// Which aggregate to apply (COUNT, SUM, AVG, MIN, MAX)
    pub function: AggregateType,
    /// Arguments (usually column references)
    pub args: Vec<SelectExpression>,
    /// `DISTINCT` modifier inside the call, as in `COUNT(DISTINCT col)`
    pub distinct: bool,
}
82
/// Kinds of aggregate functions supported in a projection.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AggregateType {
    /// `COUNT(...)`
    Count,
    /// `SUM(...)`
    Sum,
    /// `AVG(...)`
    Avg,
    /// `MIN(...)`
    Min,
    /// `MAX(...)`
    Max,
}
92
/// Scalar (non-aggregate) function call: `name(arg1, arg2, ...)`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Function name
    pub name: String,
    /// Arguments, in call order
    pub args: Vec<SelectExpression>,
}
101
/// Collection access operations.
///
/// Every variant pairs the collection column with the sub-expression that
/// selects from it (list index, map key, or the element being tested).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CollectionAccessExpression {
    /// List element access: `list[index]`
    ListIndex(ColumnRef, Box<SelectExpression>),
    /// Map value access: `map['key']`
    MapKey(ColumnRef, Box<SelectExpression>),
    /// Set membership test: `value IN set_column` (presumably the column is
    /// the set and the expression is the probed value)
    SetContains(ColumnRef, Box<SelectExpression>),
}
112
/// Binary arithmetic expression: `left <operator> right`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ArithmeticExpression {
    /// Left operand
    pub left: Box<SelectExpression>,
    /// Operator
    pub operator: ArithmeticOperator,
    /// Right operand
    pub right: Box<SelectExpression>,
}
123
/// Binary arithmetic operators.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ArithmeticOperator {
    /// `+`
    Add,
    /// `-`
    Subtract,
    /// `*`
    Multiply,
    /// `/`
    Divide,
    /// `%`
    Modulo,
}
133
/// FROM clause. Cassandra CQL only supports single-table queries (no JOINs),
/// so exactly one table is named.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FromClause {
    /// Single table
    Table(TableId),
    /// Table with an alias: the table and the alias name.
    ///
    /// NOTE(review): the stock CQL `SELECT` grammar does not define table
    /// aliases; presumably this is an extension accepted by this crate's
    /// parser. Confirm before relying on it.
    TableAlias(TableId, String),
}
142
/// Boolean expression tree used for both the WHERE and the HAVING clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// `Comparison` stores its `ComparisonExpression` inline, which is much larger
// than the `Vec`/`Box` payloads of the other variants; that size gap is what
// `large_enum_variant` reports. NOTE(review): presumably left unboxed to keep
// the public variant shape simple; confirm this trade-off is intentional.
#[allow(clippy::large_enum_variant)]
pub enum WhereExpression {
    /// Single comparison (`left <op> right`)
    Comparison(ComparisonExpression),
    /// Logical AND of all sub-expressions
    And(Vec<WhereExpression>),
    /// Logical OR of all sub-expressions
    Or(Vec<WhereExpression>),
    /// Logical NOT
    Not(Box<WhereExpression>),
    /// Parenthesized expression, kept as an explicit node in the tree
    Parentheses(Box<WhereExpression>),
}
158
/// A single comparison: `left <operator> right`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ComparisonExpression {
    /// Left side (usually a column)
    pub left: SelectExpression,
    /// Comparison operator
    pub operator: ComparisonOperator,
    /// Right side: a single operand, a list (IN), or a range (BETWEEN)
    pub right: ComparisonRightSide,
}
169
/// Right side of a comparison; its shape depends on the operator.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ComparisonRightSide {
    /// Single operand (value, column, or expression)
    Value(SelectExpression),
    /// List of operands for IN / NOT IN
    ValueList(Vec<SelectExpression>),
    /// Lower and upper bound for BETWEEN
    Range(SelectExpression, SelectExpression),
}
180
/// Comparison operators usable in WHERE/HAVING comparisons.
///
/// IN/NOT IN pair with `ComparisonRightSide::ValueList`, and BETWEEN with
/// `ComparisonRightSide::Range`; the others take a single operand.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ComparisonOperator {
    /// Equality (`=`)
    Equal,
    /// Inequality (`!=`)
    NotEqual,
    /// Less than (`<`)
    LessThan,
    /// Less than or equal (`<=`)
    LessThanOrEqual,
    /// Greater than (`>`)
    GreaterThan,
    /// Greater than or equal (`>=`)
    GreaterThanOrEqual,
    /// IN operator
    In,
    /// NOT IN operator
    NotIn,
    /// LIKE operator (pattern matching)
    Like,
    /// NOT LIKE operator
    NotLike,
    /// BETWEEN operator
    Between,
    /// NOT BETWEEN operator
    NotBetween,
    /// IS NULL
    IsNull,
    /// IS NOT NULL
    IsNotNull,
    /// Regular expression matching
    Regex,
    /// Collection CONTAINS
    Contains,
    /// Collection CONTAINS KEY
    ContainsKey,
}
219
/// GROUP BY clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GroupByClause {
    /// Columns to group by
    pub columns: Vec<ColumnRef>,
}
226
/// ORDER BY clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByClause {
    /// Order specifications (presumably in written, i.e. priority, order)
    pub items: Vec<OrderByItem>,
}
233
/// Individual ORDER BY item: an expression and its sort direction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByItem {
    /// Expression to order by
    pub expression: SelectExpression,
    /// Sort direction
    pub direction: SortDirection,
}
242
/// Sort direction of an ORDER BY item.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SortDirection {
    /// `ASC`
    Ascending,
    /// `DESC`
    Descending,
}
249
/// LIMIT clause.
///
/// NOTE(review): one `count` plus a flag can express either `LIMIT n` or
/// `PER PARTITION LIMIT n`, but not both in the same statement (CQL allows
/// combining them). Confirm this is acceptable for the supported queries.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LimitClause {
    /// Maximum number of rows
    pub count: u64,
    /// Per-partition limit (Cassandra-specific): when `true`, `count` is
    /// presumably a `PER PARTITION LIMIT` rather than a global `LIMIT`
    pub per_partition: bool,
}
258
259impl SelectStatement {
260    /// Create a simple SELECT * FROM table statement
261    pub fn select_all_from(table: TableId) -> Self {
262        Self {
263            select_clause: SelectClause::All,
264            from_clause: Some(FromClause::Table(table)),
265            where_clause: None,
266            group_by: None,
267            having_clause: None,
268            order_by: None,
269            limit: None,
270            offset: None,
271            allow_filtering: false,
272        }
273    }
274
275    /// Check if this query requires aggregation
276    pub fn requires_aggregation(&self) -> bool {
277        self.group_by.is_some() || self.has_aggregate_functions()
278    }
279
280    /// Check if this query has aggregate functions
281    pub fn has_aggregate_functions(&self) -> bool {
282        match &self.select_clause {
283            SelectClause::Columns(exprs) | SelectClause::Distinct(exprs) => {
284                exprs.iter().any(|expr| expr.is_aggregate())
285            }
286            SelectClause::All => false,
287        }
288    }
289
290    /// Get all referenced columns (for query planning).
291    ///
292    /// `SELECT *` contributes nothing here; the projection is resolved later
293    /// against the schema during planning.
294    pub fn get_referenced_columns(&self) -> Vec<ColumnRef> {
295        let mut columns = Vec::new();
296
297        if let SelectClause::Columns(exprs) | SelectClause::Distinct(exprs) = &self.select_clause {
298            for expr in exprs {
299                columns.extend(expr.get_column_refs());
300            }
301        }
302
303        if let Some(where_expr) = &self.where_clause {
304            columns.extend(where_expr.get_column_refs());
305        }
306
307        if let Some(group_by) = &self.group_by {
308            columns.extend(group_by.columns.iter().cloned());
309        }
310
311        if let Some(having) = &self.having_clause {
312            columns.extend(having.get_column_refs());
313        }
314
315        if let Some(order_by) = &self.order_by {
316            for item in &order_by.items {
317                columns.extend(item.expression.get_column_refs());
318            }
319        }
320
321        columns
322    }
323}
324
325impl SelectExpression {
326    /// Check if this expression is an aggregate function
327    pub fn is_aggregate(&self) -> bool {
328        matches!(self, SelectExpression::Aggregate(_))
329    }
330
331    /// Get all column references in this expression
332    pub fn get_column_refs(&self) -> Vec<ColumnRef> {
333        match self {
334            SelectExpression::Column(col_ref) => vec![col_ref.clone()],
335            SelectExpression::Aggregate(agg) => collect_refs(&agg.args),
336            SelectExpression::Function(func) => collect_refs(&func.args),
337            SelectExpression::CollectionAccess(access) => {
338                let (col_ref, sub_expr) = match access {
339                    CollectionAccessExpression::ListIndex(c, e)
340                    | CollectionAccessExpression::MapKey(c, e)
341                    | CollectionAccessExpression::SetContains(c, e) => (c, e),
342                };
343                let mut refs = vec![col_ref.clone()];
344                refs.extend(sub_expr.get_column_refs());
345                refs
346            }
347            SelectExpression::Arithmetic(arith) => {
348                let mut refs = arith.left.get_column_refs();
349                refs.extend(arith.right.get_column_refs());
350                refs
351            }
352            SelectExpression::Aliased(expr, _) => expr.get_column_refs(),
353            SelectExpression::Literal(_) => Vec::new(),
354        }
355    }
356}
357
358/// Collect column refs from each expression in `exprs`, in order.
359fn collect_refs(exprs: &[SelectExpression]) -> Vec<ColumnRef> {
360    exprs
361        .iter()
362        .flat_map(SelectExpression::get_column_refs)
363        .collect()
364}
365
366impl WhereExpression {
367    /// Get all column references in this WHERE expression
368    pub fn get_column_refs(&self) -> Vec<ColumnRef> {
369        match self {
370            WhereExpression::Comparison(comp) => {
371                let mut refs = comp.left.get_column_refs();
372                match &comp.right {
373                    ComparisonRightSide::Value(expr) => {
374                        refs.extend(expr.get_column_refs());
375                    }
376                    ComparisonRightSide::ValueList(exprs) => {
377                        refs.extend(collect_refs(exprs));
378                    }
379                    ComparisonRightSide::Range(start, end) => {
380                        refs.extend(start.get_column_refs());
381                        refs.extend(end.get_column_refs());
382                    }
383                }
384                refs
385            }
386            WhereExpression::And(exprs) | WhereExpression::Or(exprs) => exprs
387                .iter()
388                .flat_map(WhereExpression::get_column_refs)
389                .collect(),
390            WhereExpression::Not(expr) | WhereExpression::Parentheses(expr) => {
391                expr.get_column_refs()
392            }
393        }
394    }
395
396    /// Check if this WHERE expression can be pushed down to SSTable level.
397    ///
398    /// OR and NOT are excluded: efficient pushdown of those would require
399    /// index intersection / negative scans we don't currently support.
400    pub fn can_pushdown_to_sstable(&self) -> bool {
401        match self {
402            WhereExpression::Comparison(comp) => {
403                matches!(comp.left, SelectExpression::Column(_))
404                    && matches!(
405                        comp.operator,
406                        ComparisonOperator::Equal
407                            | ComparisonOperator::LessThan
408                            | ComparisonOperator::LessThanOrEqual
409                            | ComparisonOperator::GreaterThan
410                            | ComparisonOperator::GreaterThanOrEqual
411                            | ComparisonOperator::In
412                            | ComparisonOperator::Between
413                    )
414            }
415            WhereExpression::And(exprs) => {
416                exprs.iter().all(WhereExpression::can_pushdown_to_sstable)
417            }
418            WhereExpression::Or(_) | WhereExpression::Not(_) => false,
419            WhereExpression::Parentheses(expr) => expr.can_pushdown_to_sstable(),
420        }
421    }
422}
423
424impl ColumnRef {
425    /// Create a simple column reference
426    pub fn new(column: impl Into<String>) -> Self {
427        Self {
428            table: None,
429            column: column.into(),
430        }
431    }
432
433    /// Create a qualified column reference
434    pub fn qualified(table: impl Into<String>, column: impl Into<String>) -> Self {
435        Self {
436            table: Some(table.into()),
437            column: column.into(),
438        }
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an unqualified column expression.
    fn col(name: &str) -> SelectExpression {
        SelectExpression::Column(ColumnRef::new(name))
    }

    /// Shorthand for a `column <operator> literal` comparison node.
    fn compare(column: &str, operator: ComparisonOperator, literal: Value) -> WhereExpression {
        WhereExpression::Comparison(ComparisonExpression {
            left: col(column),
            operator,
            right: ComparisonRightSide::Value(SelectExpression::Literal(literal)),
        })
    }

    #[test]
    fn test_simple_select_statement() {
        let statement = SelectStatement::select_all_from(TableId::new("users"));
        assert_eq!(statement.select_clause, SelectClause::All);
        assert!(!statement.requires_aggregation());
    }

    #[test]
    fn test_aggregate_detection() {
        let count_id = SelectExpression::Aggregate(AggregateFunction {
            function: AggregateType::Count,
            args: vec![col("id")],
            distinct: false,
        });
        let statement = SelectStatement {
            select_clause: SelectClause::Columns(vec![count_id]),
            ..SelectStatement::select_all_from(TableId::new("users"))
        };

        assert!(statement.requires_aggregation());
        assert!(statement.has_aggregate_functions());
    }

    #[test]
    fn test_column_references() {
        let filter = WhereExpression::And(vec![
            compare("age", ComparisonOperator::GreaterThan, Value::Integer(21)),
            compare("city", ComparisonOperator::Equal, Value::Text("NYC".to_string())),
        ]);

        let refs = filter.get_column_refs();
        assert_eq!(refs.len(), 2);
        for expected in ["age", "city"] {
            assert!(refs.iter().any(|c| c.column == expected));
        }
    }

    #[test]
    fn test_pushdown_capability() {
        let id_eq = compare("id", ComparisonOperator::Equal, Value::Integer(123));
        assert!(id_eq.can_pushdown_to_sstable());

        let either = WhereExpression::Or(vec![id_eq.clone(), id_eq]);
        assert!(!either.can_pushdown_to_sstable());
    }
}