Skip to main content

dbx_core/sql/optimizer/
constant_folding.rs

//! Rule 3: Constant Folding
//!
//! Evaluates constant expressions at plan time (e.g. `1 + 2` → `3`).
4
5use crate::error::DbxResult;
6use crate::sql::planner::{BinaryOperator, Expr, LogicalPlan};
7use crate::storage::columnar::ScalarValue;
8
9use super::OptimizationRule;
10
/// Optimizer rule that replaces constant sub-expressions with their computed
/// values at plan time (e.g. `1 + 2` → `3`).
///
/// The rule has no fields or configuration; use the unit value
/// `ConstantFoldingRule` directly.
pub struct ConstantFoldingRule;
13
14impl OptimizationRule for ConstantFoldingRule {
15    fn name(&self) -> &str {
16        "ConstantFolding"
17    }
18
19    fn apply(&self, plan: LogicalPlan) -> DbxResult<LogicalPlan> {
20        self.fold(plan)
21    }
22}
23
24impl ConstantFoldingRule {
25    fn fold(&self, plan: LogicalPlan) -> DbxResult<LogicalPlan> {
26        match plan {
27            LogicalPlan::Filter { input, predicate } => {
28                let folded_pred = self.fold_expr(predicate);
29                // If predicate folded to TRUE, eliminate filter entirely
30                if let Expr::Literal(ScalarValue::Boolean(true)) = &folded_pred {
31                    return self.fold(*input);
32                }
33                Ok(LogicalPlan::Filter {
34                    input: Box::new(self.fold(*input)?),
35                    predicate: folded_pred,
36                })
37            }
38            LogicalPlan::Project {
39                input,
40                projections: columns,
41            } => {
42                let folded_cols = columns
43                    .into_iter()
44                    .map(|(c, a)| (self.fold_expr(c), a))
45                    .collect();
46                Ok(LogicalPlan::Project {
47                    input: Box::new(self.fold(*input)?),
48                    projections: folded_cols,
49                })
50            }
51            LogicalPlan::Sort { input, order_by } => Ok(LogicalPlan::Sort {
52                input: Box::new(self.fold(*input)?),
53                order_by,
54            }),
55            LogicalPlan::Limit {
56                input,
57                count,
58                offset,
59            } => Ok(LogicalPlan::Limit {
60                input: Box::new(self.fold(*input)?),
61                count,
62                offset,
63            }),
64            LogicalPlan::Aggregate {
65                input,
66                group_by,
67                aggregates,
68                mode,
69            } => Ok(LogicalPlan::Aggregate {
70                input: Box::new(self.fold(*input)?),
71                group_by,
72                aggregates,
73                mode,
74            }),
75            LogicalPlan::Scan {
76                table,
77                columns,
78                filter,
79                ros_files,
80            } => Ok(LogicalPlan::Scan {
81                table,
82                columns,
83                filter: filter.map(|f| self.fold_expr(f)),
84                ros_files,
85            }),
86            other => Ok(other),
87        }
88    }
89
90    /// Fold constant expressions: Literal op Literal → Literal
91    fn fold_expr(&self, expr: Expr) -> Expr {
92        match expr {
93            Expr::BinaryOp { left, op, right } => {
94                let left = self.fold_expr(*left);
95                let right = self.fold_expr(*right);
96
97                // Both sides are literals → evaluate at plan time
98                if let (Expr::Literal(lv), Expr::Literal(rv)) = (&left, &right)
99                    && let Some(result) = self.eval_const(lv, &op, rv)
100                {
101                    return Expr::Literal(result);
102                }
103
104                Expr::BinaryOp {
105                    left: Box::new(left),
106                    op,
107                    right: Box::new(right),
108                }
109            }
110            other => other,
111        }
112    }
113
114    /// Evaluate constant binary operations.
115    fn eval_const(
116        &self,
117        left: &ScalarValue,
118        op: &BinaryOperator,
119        right: &ScalarValue,
120    ) -> Option<ScalarValue> {
121        match (left, op, right) {
122            // Integer arithmetic
123            (ScalarValue::Int32(a), BinaryOperator::Plus, ScalarValue::Int32(b)) => {
124                Some(ScalarValue::Int32(a + b))
125            }
126            (ScalarValue::Int32(a), BinaryOperator::Minus, ScalarValue::Int32(b)) => {
127                Some(ScalarValue::Int32(a - b))
128            }
129            (ScalarValue::Int32(a), BinaryOperator::Multiply, ScalarValue::Int32(b)) => {
130                Some(ScalarValue::Int32(a * b))
131            }
132            (ScalarValue::Int32(a), BinaryOperator::Divide, ScalarValue::Int32(b)) if *b != 0 => {
133                Some(ScalarValue::Int32(a / b))
134            }
135            // Integer comparison
136            (ScalarValue::Int32(a), BinaryOperator::Eq, ScalarValue::Int32(b)) => {
137                Some(ScalarValue::Boolean(a == b))
138            }
139            (ScalarValue::Int32(a), BinaryOperator::NotEq, ScalarValue::Int32(b)) => {
140                Some(ScalarValue::Boolean(a != b))
141            }
142            (ScalarValue::Int32(a), BinaryOperator::Lt, ScalarValue::Int32(b)) => {
143                Some(ScalarValue::Boolean(a < b))
144            }
145            (ScalarValue::Int32(a), BinaryOperator::Gt, ScalarValue::Int32(b)) => {
146                Some(ScalarValue::Boolean(a > b))
147            }
148            // Boolean logic
149            (ScalarValue::Boolean(a), BinaryOperator::And, ScalarValue::Boolean(b)) => {
150                Some(ScalarValue::Boolean(*a && *b))
151            }
152            (ScalarValue::Boolean(a), BinaryOperator::Or, ScalarValue::Boolean(b)) => {
153                Some(ScalarValue::Boolean(*a || *b))
154            }
155            // Float arithmetic
156            (ScalarValue::Float64(a), BinaryOperator::Plus, ScalarValue::Float64(b)) => {
157                Some(ScalarValue::Float64(a + b))
158            }
159            (ScalarValue::Float64(a), BinaryOperator::Minus, ScalarValue::Float64(b)) => {
160                Some(ScalarValue::Float64(a - b))
161            }
162            (ScalarValue::Float64(a), BinaryOperator::Multiply, ScalarValue::Float64(b)) => {
163                Some(ScalarValue::Float64(a * b))
164            }
165            _ => None,
166        }
167    }
168}