Skip to main content

dbx_core/sql/optimizer/
constant_folding.rs

//! Rule 3: Constant Folding
//!
//! Evaluates constant expressions at query-planning time (e.g. `1 + 2` → `3`).
4
5use crate::error::DbxResult;
6use crate::sql::planner::{BinaryOperator, Expr, LogicalPlan};
7use crate::storage::columnar::ScalarValue;
8
9use super::OptimizationRule;
10
/// Optimizer rule that evaluates constant expressions at query-planning time
/// (e.g. `1 + 2` → `3`).
///
/// The rule is a stateless unit struct, so it is trivially copyable and can be
/// created with `ConstantFoldingRule` or `ConstantFoldingRule::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConstantFoldingRule;
13
14impl OptimizationRule for ConstantFoldingRule {
15    fn name(&self) -> &str {
16        "ConstantFolding"
17    }
18
19    fn apply(&self, plan: LogicalPlan) -> DbxResult<LogicalPlan> {
20        self.fold(plan)
21    }
22}
23
24impl ConstantFoldingRule {
25    fn fold(&self, plan: LogicalPlan) -> DbxResult<LogicalPlan> {
26        match plan {
27            LogicalPlan::Filter { input, predicate } => {
28                let folded_pred = self.fold_expr(predicate);
29                // If predicate folded to TRUE, eliminate filter entirely
30                if let Expr::Literal(ScalarValue::Boolean(true)) = &folded_pred {
31                    return self.fold(*input);
32                }
33                Ok(LogicalPlan::Filter {
34                    input: Box::new(self.fold(*input)?),
35                    predicate: folded_pred,
36                })
37            }
38            LogicalPlan::Project {
39                input,
40                projections: columns,
41            } => {
42                let folded_cols = columns
43                    .into_iter()
44                    .map(|(c, a)| (self.fold_expr(c), a))
45                    .collect();
46                Ok(LogicalPlan::Project {
47                    input: Box::new(self.fold(*input)?),
48                    projections: folded_cols,
49                })
50            }
51            LogicalPlan::Sort { input, order_by } => Ok(LogicalPlan::Sort {
52                input: Box::new(self.fold(*input)?),
53                order_by,
54            }),
55            LogicalPlan::Limit {
56                input,
57                count,
58                offset,
59            } => Ok(LogicalPlan::Limit {
60                input: Box::new(self.fold(*input)?),
61                count,
62                offset,
63            }),
64            LogicalPlan::Aggregate {
65                input,
66                group_by,
67                aggregates,
68            } => Ok(LogicalPlan::Aggregate {
69                input: Box::new(self.fold(*input)?),
70                group_by,
71                aggregates,
72            }),
73            LogicalPlan::Scan {
74                table,
75                columns,
76                filter,
77            } => Ok(LogicalPlan::Scan {
78                table,
79                columns,
80                filter: filter.map(|f| self.fold_expr(f)),
81            }),
82            other => Ok(other),
83        }
84    }
85
86    /// Fold constant expressions: Literal op Literal → Literal
87    fn fold_expr(&self, expr: Expr) -> Expr {
88        match expr {
89            Expr::BinaryOp { left, op, right } => {
90                let left = self.fold_expr(*left);
91                let right = self.fold_expr(*right);
92
93                // Both sides are literals → evaluate at plan time
94                if let (Expr::Literal(lv), Expr::Literal(rv)) = (&left, &right)
95                    && let Some(result) = self.eval_const(lv, &op, rv)
96                {
97                    return Expr::Literal(result);
98                }
99
100                Expr::BinaryOp {
101                    left: Box::new(left),
102                    op,
103                    right: Box::new(right),
104                }
105            }
106            other => other,
107        }
108    }
109
110    /// Evaluate constant binary operations.
111    fn eval_const(
112        &self,
113        left: &ScalarValue,
114        op: &BinaryOperator,
115        right: &ScalarValue,
116    ) -> Option<ScalarValue> {
117        match (left, op, right) {
118            // Integer arithmetic
119            (ScalarValue::Int32(a), BinaryOperator::Plus, ScalarValue::Int32(b)) => {
120                Some(ScalarValue::Int32(a + b))
121            }
122            (ScalarValue::Int32(a), BinaryOperator::Minus, ScalarValue::Int32(b)) => {
123                Some(ScalarValue::Int32(a - b))
124            }
125            (ScalarValue::Int32(a), BinaryOperator::Multiply, ScalarValue::Int32(b)) => {
126                Some(ScalarValue::Int32(a * b))
127            }
128            (ScalarValue::Int32(a), BinaryOperator::Divide, ScalarValue::Int32(b)) if *b != 0 => {
129                Some(ScalarValue::Int32(a / b))
130            }
131            // Integer comparison
132            (ScalarValue::Int32(a), BinaryOperator::Eq, ScalarValue::Int32(b)) => {
133                Some(ScalarValue::Boolean(a == b))
134            }
135            (ScalarValue::Int32(a), BinaryOperator::NotEq, ScalarValue::Int32(b)) => {
136                Some(ScalarValue::Boolean(a != b))
137            }
138            (ScalarValue::Int32(a), BinaryOperator::Lt, ScalarValue::Int32(b)) => {
139                Some(ScalarValue::Boolean(a < b))
140            }
141            (ScalarValue::Int32(a), BinaryOperator::Gt, ScalarValue::Int32(b)) => {
142                Some(ScalarValue::Boolean(a > b))
143            }
144            // Boolean logic
145            (ScalarValue::Boolean(a), BinaryOperator::And, ScalarValue::Boolean(b)) => {
146                Some(ScalarValue::Boolean(*a && *b))
147            }
148            (ScalarValue::Boolean(a), BinaryOperator::Or, ScalarValue::Boolean(b)) => {
149                Some(ScalarValue::Boolean(*a || *b))
150            }
151            // Float arithmetic
152            (ScalarValue::Float64(a), BinaryOperator::Plus, ScalarValue::Float64(b)) => {
153                Some(ScalarValue::Float64(a + b))
154            }
155            (ScalarValue::Float64(a), BinaryOperator::Minus, ScalarValue::Float64(b)) => {
156                Some(ScalarValue::Float64(a - b))
157            }
158            (ScalarValue::Float64(a), BinaryOperator::Multiply, ScalarValue::Float64(b)) => {
159                Some(ScalarValue::Float64(a * b))
160            }
161            _ => None,
162        }
163    }
164}