Skip to main content

dbx_core/sql/planner/logical/
select.rs

//! SELECT statement planning - includes JOIN, aggregate, and projection logic
2
3use crate::error::{DbxError, DbxResult};
4use crate::sql::planner::types::*;
5use crate::storage::columnar::ScalarValue;
6use sqlparser::ast::{
7    Expr as SqlExpr, GroupByExpr, JoinConstraint, JoinOperator, OrderByExpr as SqlOrderByExpr,
8    Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins,
9};
10
11use super::LogicalPlanner;
12use super::helpers::{convert_binary_op, match_scalar_function};
13
14impl LogicalPlanner {
15    /// Query → LogicalPlan 변환
16    pub(super) fn plan_query(&self, query: &Query) -> DbxResult<LogicalPlan> {
17        let mut plan = match query.body.as_ref() {
18            SetExpr::Select(select) => self.plan_select(select)?,
19            _ => {
20                return Err(DbxError::SqlNotSupported {
21                    feature: "Non-SELECT query body".to_string(),
22                    hint: "Only SELECT queries are supported".to_string(),
23                });
24            }
25        };
26
27        // ORDER BY
28        if let Some(order_by) = &query.order_by {
29            let sort_exprs: Vec<SortExpr> = order_by
30                .exprs
31                .iter()
32                .map(|ob| self.plan_order_by_expr(ob))
33                .collect::<DbxResult<Vec<_>>>()?;
34            plan = LogicalPlan::Sort {
35                input: Box::new(plan),
36                order_by: sort_exprs,
37            };
38        }
39
40        // LIMIT and OFFSET
41        if query.limit.is_some() || query.offset.is_some() {
42            let limit = if let Some(limit_expr) = &query.limit {
43                super::helpers::extract_usize(limit_expr)?
44            } else {
45                usize::MAX
46            };
47
48            let offset = if let Some(offset_struct) = &query.offset {
49                super::helpers::extract_usize(&offset_struct.value)?
50            } else {
51                0
52            };
53
54            plan = LogicalPlan::Limit {
55                input: Box::new(plan),
56                count: limit,
57                offset,
58            };
59        }
60
61        Ok(plan)
62    }
63
64    /// SELECT → LogicalPlan 변환
65    pub(super) fn plan_select(&self, select: &Select) -> DbxResult<LogicalPlan> {
66        // Clear alias map for new query
67        self.alias_map.write().unwrap().clear();
68
69        // 0. Pre-scan projections for aliases to support WHERE/ORDER BY
70        for item in &select.projection {
71            if let SelectItem::ExprWithAlias { expr, alias } = item {
72                let planned_expr = self.plan_expr(expr)?;
73                self.alias_map
74                    .write()
75                    .unwrap()
76                    .insert(alias.value.clone(), planned_expr);
77            }
78        }
79
80        // 1. FROM 절 → Scan
81        let mut plan = self.plan_from(&select.from)?;
82
83        // 2. WHERE 절 → Filter
84        if let Some(ref selection) = select.selection {
85            let predicate = self.plan_expr(selection)?;
86            plan = LogicalPlan::Filter {
87                input: Box::new(plan),
88                predicate,
89            };
90        }
91
92        // 3. GROUP BY 절 → Aggregate
93        let group_by_exprs = match &select.group_by {
94            GroupByExpr::Expressions(exprs, _) => exprs
95                .iter()
96                .map(|e| self.plan_expr(e))
97                .collect::<DbxResult<Vec<_>>>()?,
98            GroupByExpr::All(_) => vec![], // GROUP BY ALL — treat as empty
99        };
100
101        // Extract aggregate functions from SELECT items
102        let aggregates = self.extract_aggregates(&select.projection)?;
103        let has_aggregates = !group_by_exprs.is_empty() || !aggregates.is_empty();
104
105        // 4. SELECT 절 → Project
106        let projections = self.plan_projection(&select.projection)?;
107
108        // Skip Project node if it's a simple aggregate query (check before move)
109        let is_simple_agg = !aggregates.is_empty()
110            && group_by_exprs.is_empty()
111            && projections.len() == aggregates.len()
112            && projections
113                .iter()
114                .all(|(e, _)| matches!(e, Expr::Function { .. }));
115
116        if has_aggregates {
117            plan = LogicalPlan::Aggregate {
118                input: Box::new(plan),
119                group_by: group_by_exprs,
120                aggregates,
121                mode: AggregateMode::Simple,
122            };
123        }
124
125        if !projections.is_empty() && !is_simple_agg {
126            plan = LogicalPlan::Project {
127                input: Box::new(plan),
128                projections,
129            };
130        }
131
132        Ok(plan)
133    }
134
135    /// Convert sqlparser OrderByExpr → our SortExpr
136    pub(super) fn plan_order_by_expr(&self, ob: &SqlOrderByExpr) -> DbxResult<SortExpr> {
137        let expr = self.plan_expr(&ob.expr)?;
138        Ok(SortExpr {
139            expr,
140            asc: ob.asc.unwrap_or(true),
141            nulls_first: ob.nulls_first.unwrap_or(true),
142        })
143    }
144
145    /// Extract aggregate function calls from SELECT items.
146    pub(super) fn extract_aggregates(
147        &self,
148        projection: &[SelectItem],
149    ) -> DbxResult<Vec<AggregateExpr>> {
150        let mut aggregates = Vec::new();
151        for item in projection {
152            match item {
153                SelectItem::UnnamedExpr(expr) => {
154                    if let Some(agg) = self.try_extract_aggregate(expr, None)? {
155                        aggregates.push(agg);
156                    }
157                }
158                SelectItem::ExprWithAlias { expr, alias } => {
159                    if let Some(agg) =
160                        self.try_extract_aggregate(expr, Some(alias.value.clone()))?
161                    {
162                        aggregates.push(agg);
163                    }
164                }
165                _ => {}
166            }
167        }
168        Ok(aggregates)
169    }
170
171    /// Try to extract an aggregate expression from a SQL expression.
172    pub(super) fn try_extract_aggregate(
173        &self,
174        expr: &SqlExpr,
175        alias: Option<String>,
176    ) -> DbxResult<Option<AggregateExpr>> {
177        match expr {
178            SqlExpr::Function(func) => {
179                let func_name = func.name.to_string().to_uppercase();
180                let agg_func = match func_name.as_str() {
181                    "COUNT" => Some(AggregateFunction::Count),
182                    "SUM" => Some(AggregateFunction::Sum),
183                    "AVG" => Some(AggregateFunction::Avg),
184                    "MIN" => Some(AggregateFunction::Min),
185                    "MAX" => Some(AggregateFunction::Max),
186                    _ => None,
187                };
188
189                if let Some(function) = agg_func {
190                    let arg_expr = match &func.args {
191                        sqlparser::ast::FunctionArguments::None => {
192                            // COUNT(*)
193                            Expr::Literal(ScalarValue::Int32(1))
194                        }
195                        _ => self.plan_function_arg(&func.args)?,
196                    };
197                    Ok(Some(AggregateExpr {
198                        function,
199                        expr: arg_expr,
200                        alias,
201                    }))
202                } else {
203                    Ok(None)
204                }
205            }
206            _ => Ok(None),
207        }
208    }
209
210    /// Plan function arguments (take first arg).
211    pub(super) fn plan_function_arg(
212        &self,
213        args: &sqlparser::ast::FunctionArguments,
214    ) -> DbxResult<Expr> {
215        match args {
216            sqlparser::ast::FunctionArguments::List(arg_list) => {
217                if arg_list.args.is_empty() {
218                    return Ok(Expr::Literal(ScalarValue::Int32(1))); // COUNT(*)
219                }
220                match &arg_list.args[0] {
221                    sqlparser::ast::FunctionArg::Unnamed(arg_expr) => {
222                        match arg_expr {
223                            sqlparser::ast::FunctionArgExpr::Expr(e) => self.plan_expr(e),
224                            sqlparser::ast::FunctionArgExpr::Wildcard => {
225                                Ok(Expr::Literal(ScalarValue::Int32(1))) // COUNT(*)
226                            }
227                            sqlparser::ast::FunctionArgExpr::QualifiedWildcard(_) => {
228                                Ok(Expr::Literal(ScalarValue::Int32(1)))
229                            }
230                        }
231                    }
232                    sqlparser::ast::FunctionArg::Named { arg, .. } => match arg {
233                        sqlparser::ast::FunctionArgExpr::Expr(e) => self.plan_expr(e),
234                        _ => Ok(Expr::Literal(ScalarValue::Int32(1))),
235                    },
236                }
237            }
238            sqlparser::ast::FunctionArguments::None => Ok(Expr::Literal(ScalarValue::Int32(1))),
239            sqlparser::ast::FunctionArguments::Subquery(_) => Err(DbxError::NotImplemented(
240                "Subquery function arguments".to_string(),
241            )),
242        }
243    }
244
245    /// FROM 절 → Scan (with JOIN support)
246    pub(super) fn plan_from(&self, from: &[TableWithJoins]) -> DbxResult<LogicalPlan> {
247        if from.is_empty() {
248            return Err(DbxError::Schema("FROM clause is required".to_string()));
249        }
250
251        if from.len() > 1 {
252            return Err(DbxError::SqlNotSupported {
253                feature: "Multiple tables in FROM clause".to_string(),
254                hint: "Use JOIN syntax or separate queries".to_string(),
255            });
256        }
257
258        let table_with_joins = &from[0];
259        let table_name = match &table_with_joins.relation {
260            TableFactor::Table { name, .. } => name.to_string(),
261            _ => {
262                return Err(DbxError::SqlNotSupported {
263                    feature: "Complex table expressions".to_string(),
264                    hint: "Use simple table names only".to_string(),
265                });
266            }
267        };
268
269        // Start with base table scan
270        let mut plan = LogicalPlan::Scan {
271            table: table_name,
272            columns: vec![], // All columns (optimized later by projection pushdown)
273            filter: None,
274            ros_files: vec![],
275        };
276
277        // Process JOINs
278        for join in &table_with_joins.joins {
279            let right_table = match &join.relation {
280                TableFactor::Table { name, .. } => name.to_string(),
281                _ => {
282                    return Err(DbxError::SqlNotSupported {
283                        feature: "Complex JOIN table expressions".to_string(),
284                        hint: "Use simple table names in JOIN clauses".to_string(),
285                    });
286                }
287            };
288
289            let right_plan = LogicalPlan::Scan {
290                table: right_table,
291                columns: vec![],
292                filter: None,
293                ros_files: vec![],
294            };
295
296            // Determine JOIN type
297            let join_type = match &join.join_operator {
298                JoinOperator::Inner(_) => JoinType::Inner,
299                JoinOperator::LeftOuter(_) => JoinType::Left,
300                JoinOperator::RightOuter(_) => JoinType::Right,
301                JoinOperator::CrossJoin => JoinType::Cross,
302                _ => {
303                    return Err(DbxError::SqlNotSupported {
304                        feature: format!("JOIN type: {:?}", join.join_operator),
305                        hint: "Supported: INNER, LEFT, RIGHT, CROSS JOIN".to_string(),
306                    });
307                }
308            };
309
310            // Extract JOIN condition
311            let on_expr = match &join.join_operator {
312                JoinOperator::Inner(constraint)
313                | JoinOperator::LeftOuter(constraint)
314                | JoinOperator::RightOuter(constraint) => match constraint {
315                    JoinConstraint::On(expr) => self.plan_expr(expr)?,
316                    JoinConstraint::Using(_) => {
317                        return Err(DbxError::SqlNotSupported {
318                            feature: "JOIN USING clause".to_string(),
319                            hint: "Use ON clause instead (e.g., ON a.id = b.id)".to_string(),
320                        });
321                    }
322                    JoinConstraint::Natural => {
323                        return Err(DbxError::SqlNotSupported {
324                            feature: "NATURAL JOIN".to_string(),
325                            hint: "Use explicit ON clause instead".to_string(),
326                        });
327                    }
328                    JoinConstraint::None => {
329                        return Err(DbxError::Schema("JOIN requires ON condition".to_string()));
330                    }
331                },
332                JoinOperator::CrossJoin => {
333                    // CROSS JOIN has no condition (Cartesian product)
334                    Expr::Literal(ScalarValue::Boolean(true))
335                }
336                _ => {
337                    return Err(DbxError::SqlNotSupported {
338                        feature: "Unsupported JOIN operator".to_string(),
339                        hint: "Use INNER, LEFT, RIGHT, or CROSS JOIN".to_string(),
340                    });
341                }
342            };
343
344            plan = LogicalPlan::Join {
345                left: Box::new(plan),
346                right: Box::new(right_plan),
347                join_type,
348                on: on_expr,
349            };
350        }
351
352        Ok(plan)
353    }
354
355    /// SELECT 절 → Vec<(Expr, Option<String>)>
356    pub(super) fn plan_projection(
357        &self,
358        projection: &[SelectItem],
359    ) -> DbxResult<Vec<(Expr, Option<String>)>> {
360        let mut projections = Vec::new();
361
362        for item in projection {
363            match item {
364                SelectItem::Wildcard(_) => {
365                    // SELECT * -> empty projections means all columns
366                }
367                SelectItem::UnnamedExpr(expr) => {
368                    let planned = self.plan_expr(expr)?;
369                    let alias = if let Expr::Column(name) = &planned {
370                        Some(name.clone())
371                    } else {
372                        None
373                    };
374                    projections.push((planned, alias));
375                }
376                SelectItem::ExprWithAlias { expr, alias } => {
377                    projections.push((self.plan_expr(expr)?, Some(alias.value.clone())));
378                }
379                _ => {
380                    return Err(DbxError::NotImplemented(format!(
381                        "Unsupported SELECT item: {:?}",
382                        item
383                    )));
384                }
385            }
386        }
387
388        Ok(projections)
389    }
390
391    /// SQL Expr → Logical Expr 변환
392    pub(super) fn plan_expr(&self, expr: &SqlExpr) -> DbxResult<Expr> {
393        match expr {
394            SqlExpr::Identifier(ident) => {
395                let name = ident.value.clone();
396                // Check if this identifier is an alias defined in SELECT
397                if let Some(aliased_expr) = self.alias_map.read().unwrap().get(&name) {
398                    return Ok(aliased_expr.clone());
399                }
400                Ok(Expr::Column(name))
401            }
402            SqlExpr::Value(value) => {
403                let scalar = match value {
404                    sqlparser::ast::Value::Number(n, _) => {
405                        if let Ok(i) = n.parse::<i32>() {
406                            ScalarValue::Int32(i)
407                        } else if let Ok(i) = n.parse::<i64>() {
408                            ScalarValue::Int64(i)
409                        } else if let Ok(f) = n.parse::<f64>() {
410                            ScalarValue::Float64(f)
411                        } else {
412                            return Err(DbxError::Schema(format!("Invalid number: {}", n)));
413                        }
414                    }
415                    sqlparser::ast::Value::SingleQuotedString(s) => ScalarValue::Utf8(s.clone()),
416                    sqlparser::ast::Value::Boolean(b) => ScalarValue::Boolean(*b),
417                    sqlparser::ast::Value::Null => ScalarValue::Null,
418                    _ => {
419                        return Err(DbxError::NotImplemented(format!(
420                            "Unsupported value: {:?}",
421                            value
422                        )));
423                    }
424                };
425                Ok(Expr::Literal(scalar))
426            }
427            SqlExpr::BinaryOp { left, op, right } => {
428                let left_expr = self.plan_expr(left)?;
429                let right_expr = self.plan_expr(right)?;
430                let binary_op = convert_binary_op(op)?;
431                Ok(Expr::BinaryOp {
432                    left: Box::new(left_expr),
433                    op: binary_op,
434                    right: Box::new(right_expr),
435                })
436            }
437            SqlExpr::IsNull(expr) => {
438                let inner = self.plan_expr(expr)?;
439                Ok(Expr::IsNull(Box::new(inner)))
440            }
441            SqlExpr::IsNotNull(expr) => {
442                let inner = self.plan_expr(expr)?;
443                Ok(Expr::IsNotNull(Box::new(inner)))
444            }
445            SqlExpr::Function(func) => {
446                let name = func.name.to_string().to_uppercase();
447                let args: Vec<Expr> = match &func.args {
448                    sqlparser::ast::FunctionArguments::List(arg_list) => {
449                        let mut planned_args = Vec::new();
450                        for arg in &arg_list.args {
451                            if let sqlparser::ast::FunctionArg::Unnamed(
452                                sqlparser::ast::FunctionArgExpr::Expr(e),
453                            ) = arg
454                            {
455                                planned_args.push(self.plan_expr(e)?)
456                            }
457                        }
458                        planned_args
459                    }
460                    _ => vec![],
461                };
462
463                // 스칼라 함수 매핑 시도
464                if let Some(scalar_func) = match_scalar_function(&name) {
465                    Ok(Expr::ScalarFunc {
466                        func: scalar_func,
467                        args,
468                    })
469                } else {
470                    // 집계 함수로 처리 (실제 집계 여부는 추후 Optimizer/Planner에서 검증)
471                    Ok(Expr::Function { name, args })
472                }
473            }
474            SqlExpr::Nested(expr) => self.plan_expr(expr),
475            SqlExpr::CompoundIdentifier(idents) => {
476                // table.column → just use the column name
477                let col_name = idents.last().map(|i| i.value.clone()).unwrap_or_default();
478                Ok(Expr::Column(col_name))
479            }
480            _ => Err(DbxError::NotImplemented(format!(
481                "Unsupported expression: {:?}",
482                expr
483            ))),
484        }
485    }
486}