// Source path: dbx_core/sql/planner/logical/select.rs
//! SELECT statement planning: FROM/JOIN, WHERE, GROUP BY and aggregates, projection,
//! ORDER BY, and LIMIT/OFFSET.
2
3use crate::error::{DbxError, DbxResult};
4use crate::sql::planner::types::*;
5use crate::storage::columnar::ScalarValue;
6use sqlparser::ast::{
7    Expr as SqlExpr, GroupByExpr, JoinConstraint, JoinOperator, OrderByExpr as SqlOrderByExpr,
8    Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins,
9};
10
11use super::LogicalPlanner;
12use super::helpers::{convert_binary_op, match_scalar_function};
13
14impl LogicalPlanner {
15    /// Query → LogicalPlan 변환
16    pub(super) fn plan_query(&self, query: &Query) -> DbxResult<LogicalPlan> {
17        let mut plan = match query.body.as_ref() {
18            SetExpr::Select(select) => self.plan_select(select)?,
19            _ => {
20                return Err(DbxError::SqlNotSupported {
21                    feature: "Non-SELECT query body".to_string(),
22                    hint: "Only SELECT queries are supported".to_string(),
23                });
24            }
25        };
26
27        // ORDER BY
28        if let Some(order_by) = &query.order_by {
29            let sort_exprs: Vec<SortExpr> = order_by
30                .exprs
31                .iter()
32                .map(|ob| self.plan_order_by_expr(ob))
33                .collect::<DbxResult<Vec<_>>>()?;
34            plan = LogicalPlan::Sort {
35                input: Box::new(plan),
36                order_by: sort_exprs,
37            };
38        }
39
40        // LIMIT and OFFSET
41        if query.limit.is_some() || query.offset.is_some() {
42            let limit = if let Some(limit_expr) = &query.limit {
43                super::helpers::extract_usize(limit_expr)?
44            } else {
45                usize::MAX
46            };
47
48            let offset = if let Some(offset_struct) = &query.offset {
49                super::helpers::extract_usize(&offset_struct.value)?
50            } else {
51                0
52            };
53
54            plan = LogicalPlan::Limit {
55                input: Box::new(plan),
56                count: limit,
57                offset,
58            };
59        }
60
61        Ok(plan)
62    }
63
64    /// SELECT → LogicalPlan 변환
65    pub(super) fn plan_select(&self, select: &Select) -> DbxResult<LogicalPlan> {
66        // Clear alias map for new query
67        self.alias_map.write().unwrap().clear();
68
69        // 0. Pre-scan projections for aliases to support WHERE/ORDER BY
70        for item in &select.projection {
71            if let SelectItem::ExprWithAlias { expr, alias } = item {
72                let planned_expr = self.plan_expr(expr)?;
73                self.alias_map
74                    .write()
75                    .unwrap()
76                    .insert(alias.value.clone(), planned_expr);
77            }
78        }
79
80        // 1. FROM 절 → Scan
81        let mut plan = self.plan_from(&select.from)?;
82
83        // 2. WHERE 절 → Filter
84        if let Some(ref selection) = select.selection {
85            let predicate = self.plan_expr(selection)?;
86            plan = LogicalPlan::Filter {
87                input: Box::new(plan),
88                predicate,
89            };
90        }
91
92        // 3. GROUP BY 절 → Aggregate
93        let group_by_exprs = match &select.group_by {
94            GroupByExpr::Expressions(exprs, _) => exprs
95                .iter()
96                .map(|e| self.plan_expr(e))
97                .collect::<DbxResult<Vec<_>>>()?,
98            GroupByExpr::All(_) => vec![], // GROUP BY ALL — treat as empty
99        };
100
101        // Extract aggregate functions from SELECT items
102        let aggregates = self.extract_aggregates(&select.projection)?;
103        let has_aggregates = !group_by_exprs.is_empty() || !aggregates.is_empty();
104
105        // 4. SELECT 절 → Project
106        let projections = self.plan_projection(&select.projection)?;
107
108        // Skip Project node if it's a simple aggregate query (check before move)
109        let is_simple_agg = !aggregates.is_empty()
110            && group_by_exprs.is_empty()
111            && projections.len() == aggregates.len()
112            && projections
113                .iter()
114                .all(|(e, _)| matches!(e, Expr::Function { .. }));
115
116        if has_aggregates {
117            plan = LogicalPlan::Aggregate {
118                input: Box::new(plan),
119                group_by: group_by_exprs,
120                aggregates,
121            };
122        }
123
124        if !projections.is_empty() && !is_simple_agg {
125            plan = LogicalPlan::Project {
126                input: Box::new(plan),
127                projections,
128            };
129        }
130
131        Ok(plan)
132    }
133
134    /// Convert sqlparser OrderByExpr → our SortExpr
135    pub(super) fn plan_order_by_expr(&self, ob: &SqlOrderByExpr) -> DbxResult<SortExpr> {
136        let expr = self.plan_expr(&ob.expr)?;
137        Ok(SortExpr {
138            expr,
139            asc: ob.asc.unwrap_or(true),
140            nulls_first: ob.nulls_first.unwrap_or(true),
141        })
142    }
143
144    /// Extract aggregate function calls from SELECT items.
145    pub(super) fn extract_aggregates(
146        &self,
147        projection: &[SelectItem],
148    ) -> DbxResult<Vec<AggregateExpr>> {
149        let mut aggregates = Vec::new();
150        for item in projection {
151            match item {
152                SelectItem::UnnamedExpr(expr) => {
153                    if let Some(agg) = self.try_extract_aggregate(expr, None)? {
154                        aggregates.push(agg);
155                    }
156                }
157                SelectItem::ExprWithAlias { expr, alias } => {
158                    if let Some(agg) =
159                        self.try_extract_aggregate(expr, Some(alias.value.clone()))?
160                    {
161                        aggregates.push(agg);
162                    }
163                }
164                _ => {}
165            }
166        }
167        Ok(aggregates)
168    }
169
170    /// Try to extract an aggregate expression from a SQL expression.
171    pub(super) fn try_extract_aggregate(
172        &self,
173        expr: &SqlExpr,
174        alias: Option<String>,
175    ) -> DbxResult<Option<AggregateExpr>> {
176        match expr {
177            SqlExpr::Function(func) => {
178                let func_name = func.name.to_string().to_uppercase();
179                let agg_func = match func_name.as_str() {
180                    "COUNT" => Some(AggregateFunction::Count),
181                    "SUM" => Some(AggregateFunction::Sum),
182                    "AVG" => Some(AggregateFunction::Avg),
183                    "MIN" => Some(AggregateFunction::Min),
184                    "MAX" => Some(AggregateFunction::Max),
185                    _ => None,
186                };
187
188                if let Some(function) = agg_func {
189                    let arg_expr = match &func.args {
190                        sqlparser::ast::FunctionArguments::None => {
191                            // COUNT(*)
192                            Expr::Literal(ScalarValue::Int32(1))
193                        }
194                        _ => self.plan_function_arg(&func.args)?,
195                    };
196                    Ok(Some(AggregateExpr {
197                        function,
198                        expr: arg_expr,
199                        alias,
200                    }))
201                } else {
202                    Ok(None)
203                }
204            }
205            _ => Ok(None),
206        }
207    }
208
209    /// Plan function arguments (take first arg).
210    pub(super) fn plan_function_arg(
211        &self,
212        args: &sqlparser::ast::FunctionArguments,
213    ) -> DbxResult<Expr> {
214        match args {
215            sqlparser::ast::FunctionArguments::List(arg_list) => {
216                if arg_list.args.is_empty() {
217                    return Ok(Expr::Literal(ScalarValue::Int32(1))); // COUNT(*)
218                }
219                match &arg_list.args[0] {
220                    sqlparser::ast::FunctionArg::Unnamed(arg_expr) => {
221                        match arg_expr {
222                            sqlparser::ast::FunctionArgExpr::Expr(e) => self.plan_expr(e),
223                            sqlparser::ast::FunctionArgExpr::Wildcard => {
224                                Ok(Expr::Literal(ScalarValue::Int32(1))) // COUNT(*)
225                            }
226                            sqlparser::ast::FunctionArgExpr::QualifiedWildcard(_) => {
227                                Ok(Expr::Literal(ScalarValue::Int32(1)))
228                            }
229                        }
230                    }
231                    sqlparser::ast::FunctionArg::Named { arg, .. } => match arg {
232                        sqlparser::ast::FunctionArgExpr::Expr(e) => self.plan_expr(e),
233                        _ => Ok(Expr::Literal(ScalarValue::Int32(1))),
234                    },
235                }
236            }
237            sqlparser::ast::FunctionArguments::None => Ok(Expr::Literal(ScalarValue::Int32(1))),
238            sqlparser::ast::FunctionArguments::Subquery(_) => Err(DbxError::NotImplemented(
239                "Subquery function arguments".to_string(),
240            )),
241        }
242    }
243
244    /// FROM 절 → Scan (with JOIN support)
245    pub(super) fn plan_from(&self, from: &[TableWithJoins]) -> DbxResult<LogicalPlan> {
246        if from.is_empty() {
247            return Err(DbxError::Schema("FROM clause is required".to_string()));
248        }
249
250        if from.len() > 1 {
251            return Err(DbxError::SqlNotSupported {
252                feature: "Multiple tables in FROM clause".to_string(),
253                hint: "Use JOIN syntax or separate queries".to_string(),
254            });
255        }
256
257        let table_with_joins = &from[0];
258        let table_name = match &table_with_joins.relation {
259            TableFactor::Table { name, .. } => name.to_string(),
260            _ => {
261                return Err(DbxError::SqlNotSupported {
262                    feature: "Complex table expressions".to_string(),
263                    hint: "Use simple table names only".to_string(),
264                });
265            }
266        };
267
268        // Start with base table scan
269        let mut plan = LogicalPlan::Scan {
270            table: table_name,
271            columns: vec![], // All columns (optimized later by projection pushdown)
272            filter: None,
273        };
274
275        // Process JOINs
276        for join in &table_with_joins.joins {
277            let right_table = match &join.relation {
278                TableFactor::Table { name, .. } => name.to_string(),
279                _ => {
280                    return Err(DbxError::SqlNotSupported {
281                        feature: "Complex JOIN table expressions".to_string(),
282                        hint: "Use simple table names in JOIN clauses".to_string(),
283                    });
284                }
285            };
286
287            let right_plan = LogicalPlan::Scan {
288                table: right_table,
289                columns: vec![],
290                filter: None,
291            };
292
293            // Determine JOIN type
294            let join_type = match &join.join_operator {
295                JoinOperator::Inner(_) => JoinType::Inner,
296                JoinOperator::LeftOuter(_) => JoinType::Left,
297                JoinOperator::RightOuter(_) => JoinType::Right,
298                JoinOperator::CrossJoin => JoinType::Cross,
299                _ => {
300                    return Err(DbxError::SqlNotSupported {
301                        feature: format!("JOIN type: {:?}", join.join_operator),
302                        hint: "Supported: INNER, LEFT, RIGHT, CROSS JOIN".to_string(),
303                    });
304                }
305            };
306
307            // Extract JOIN condition
308            let on_expr = match &join.join_operator {
309                JoinOperator::Inner(constraint)
310                | JoinOperator::LeftOuter(constraint)
311                | JoinOperator::RightOuter(constraint) => match constraint {
312                    JoinConstraint::On(expr) => self.plan_expr(expr)?,
313                    JoinConstraint::Using(_) => {
314                        return Err(DbxError::SqlNotSupported {
315                            feature: "JOIN USING clause".to_string(),
316                            hint: "Use ON clause instead (e.g., ON a.id = b.id)".to_string(),
317                        });
318                    }
319                    JoinConstraint::Natural => {
320                        return Err(DbxError::SqlNotSupported {
321                            feature: "NATURAL JOIN".to_string(),
322                            hint: "Use explicit ON clause instead".to_string(),
323                        });
324                    }
325                    JoinConstraint::None => {
326                        return Err(DbxError::Schema("JOIN requires ON condition".to_string()));
327                    }
328                },
329                JoinOperator::CrossJoin => {
330                    // CROSS JOIN has no condition (Cartesian product)
331                    Expr::Literal(ScalarValue::Boolean(true))
332                }
333                _ => {
334                    return Err(DbxError::SqlNotSupported {
335                        feature: "Unsupported JOIN operator".to_string(),
336                        hint: "Use INNER, LEFT, RIGHT, or CROSS JOIN".to_string(),
337                    });
338                }
339            };
340
341            plan = LogicalPlan::Join {
342                left: Box::new(plan),
343                right: Box::new(right_plan),
344                join_type,
345                on: on_expr,
346            };
347        }
348
349        Ok(plan)
350    }
351
352    /// SELECT 절 → Vec<(Expr, Option<String>)>
353    pub(super) fn plan_projection(
354        &self,
355        projection: &[SelectItem],
356    ) -> DbxResult<Vec<(Expr, Option<String>)>> {
357        let mut projections = Vec::new();
358
359        for item in projection {
360            match item {
361                SelectItem::Wildcard(_) => {
362                    // SELECT * -> empty projections means all columns
363                }
364                SelectItem::UnnamedExpr(expr) => {
365                    let planned = self.plan_expr(expr)?;
366                    let alias = if let Expr::Column(name) = &planned {
367                        Some(name.clone())
368                    } else {
369                        None
370                    };
371                    projections.push((planned, alias));
372                }
373                SelectItem::ExprWithAlias { expr, alias } => {
374                    projections.push((self.plan_expr(expr)?, Some(alias.value.clone())));
375                }
376                _ => {
377                    return Err(DbxError::NotImplemented(format!(
378                        "Unsupported SELECT item: {:?}",
379                        item
380                    )));
381                }
382            }
383        }
384
385        Ok(projections)
386    }
387
388    /// SQL Expr → Logical Expr 변환
389    pub(super) fn plan_expr(&self, expr: &SqlExpr) -> DbxResult<Expr> {
390        match expr {
391            SqlExpr::Identifier(ident) => {
392                let name = ident.value.clone();
393                // Check if this identifier is an alias defined in SELECT
394                if let Some(aliased_expr) = self.alias_map.read().unwrap().get(&name) {
395                    return Ok(aliased_expr.clone());
396                }
397                Ok(Expr::Column(name))
398            }
399            SqlExpr::Value(value) => {
400                let scalar = match value {
401                    sqlparser::ast::Value::Number(n, _) => {
402                        if let Ok(i) = n.parse::<i32>() {
403                            ScalarValue::Int32(i)
404                        } else if let Ok(i) = n.parse::<i64>() {
405                            ScalarValue::Int64(i)
406                        } else if let Ok(f) = n.parse::<f64>() {
407                            ScalarValue::Float64(f)
408                        } else {
409                            return Err(DbxError::Schema(format!("Invalid number: {}", n)));
410                        }
411                    }
412                    sqlparser::ast::Value::SingleQuotedString(s) => ScalarValue::Utf8(s.clone()),
413                    sqlparser::ast::Value::Boolean(b) => ScalarValue::Boolean(*b),
414                    sqlparser::ast::Value::Null => ScalarValue::Null,
415                    _ => {
416                        return Err(DbxError::NotImplemented(format!(
417                            "Unsupported value: {:?}",
418                            value
419                        )));
420                    }
421                };
422                Ok(Expr::Literal(scalar))
423            }
424            SqlExpr::BinaryOp { left, op, right } => {
425                let left_expr = self.plan_expr(left)?;
426                let right_expr = self.plan_expr(right)?;
427                let binary_op = convert_binary_op(op)?;
428                Ok(Expr::BinaryOp {
429                    left: Box::new(left_expr),
430                    op: binary_op,
431                    right: Box::new(right_expr),
432                })
433            }
434            SqlExpr::IsNull(expr) => {
435                let inner = self.plan_expr(expr)?;
436                Ok(Expr::IsNull(Box::new(inner)))
437            }
438            SqlExpr::IsNotNull(expr) => {
439                let inner = self.plan_expr(expr)?;
440                Ok(Expr::IsNotNull(Box::new(inner)))
441            }
442            SqlExpr::Function(func) => {
443                let name = func.name.to_string().to_uppercase();
444                let args: Vec<Expr> = match &func.args {
445                    sqlparser::ast::FunctionArguments::List(arg_list) => {
446                        let mut planned_args = Vec::new();
447                        for arg in &arg_list.args {
448                            if let sqlparser::ast::FunctionArg::Unnamed(
449                                sqlparser::ast::FunctionArgExpr::Expr(e),
450                            ) = arg
451                            {
452                                planned_args.push(self.plan_expr(e)?)
453                            }
454                        }
455                        planned_args
456                    }
457                    _ => vec![],
458                };
459
460                // 스칼라 함수 매핑 시도
461                if let Some(scalar_func) = match_scalar_function(&name) {
462                    Ok(Expr::ScalarFunc {
463                        func: scalar_func,
464                        args,
465                    })
466                } else {
467                    // 집계 함수로 처리 (실제 집계 여부는 추후 Optimizer/Planner에서 검증)
468                    Ok(Expr::Function { name, args })
469                }
470            }
471            SqlExpr::Nested(expr) => self.plan_expr(expr),
472            SqlExpr::CompoundIdentifier(idents) => {
473                // table.column → just use the column name
474                let col_name = idents.last().map(|i| i.value.clone()).unwrap_or_default();
475                Ok(Expr::Column(col_name))
476            }
477            _ => Err(DbxError::NotImplemented(format!(
478                "Unsupported expression: {:?}",
479                expr
480            ))),
481        }
482    }
483}