1use crate::error::DbxResult;
6use crate::sql::planner::{BinaryOperator, Expr, LogicalPlan};
7use crate::storage::columnar::ScalarValue;
8
9use super::OptimizationRule;
10
/// Optimizer rule that pre-computes binary expressions whose operands are all
/// literals, and removes filters whose predicate folds to the literal `true`.
///
/// The rule is stateless, so it is cheap to copy and can be default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConstantFoldingRule;
13
14impl OptimizationRule for ConstantFoldingRule {
15 fn name(&self) -> &str {
16 "ConstantFolding"
17 }
18
19 fn apply(&self, plan: LogicalPlan) -> DbxResult<LogicalPlan> {
20 self.fold(plan)
21 }
22}
23
24impl ConstantFoldingRule {
25 fn fold(&self, plan: LogicalPlan) -> DbxResult<LogicalPlan> {
26 match plan {
27 LogicalPlan::Filter { input, predicate } => {
28 let folded_pred = self.fold_expr(predicate);
29 if let Expr::Literal(ScalarValue::Boolean(true)) = &folded_pred {
31 return self.fold(*input);
32 }
33 Ok(LogicalPlan::Filter {
34 input: Box::new(self.fold(*input)?),
35 predicate: folded_pred,
36 })
37 }
38 LogicalPlan::Project {
39 input,
40 projections: columns,
41 } => {
42 let folded_cols = columns
43 .into_iter()
44 .map(|(c, a)| (self.fold_expr(c), a))
45 .collect();
46 Ok(LogicalPlan::Project {
47 input: Box::new(self.fold(*input)?),
48 projections: folded_cols,
49 })
50 }
51 LogicalPlan::Sort { input, order_by } => Ok(LogicalPlan::Sort {
52 input: Box::new(self.fold(*input)?),
53 order_by,
54 }),
55 LogicalPlan::Limit {
56 input,
57 count,
58 offset,
59 } => Ok(LogicalPlan::Limit {
60 input: Box::new(self.fold(*input)?),
61 count,
62 offset,
63 }),
64 LogicalPlan::Aggregate {
65 input,
66 group_by,
67 aggregates,
68 mode,
69 } => Ok(LogicalPlan::Aggregate {
70 input: Box::new(self.fold(*input)?),
71 group_by,
72 aggregates,
73 mode,
74 }),
75 LogicalPlan::Scan {
76 table,
77 columns,
78 filter,
79 ros_files,
80 } => Ok(LogicalPlan::Scan {
81 table,
82 columns,
83 filter: filter.map(|f| self.fold_expr(f)),
84 ros_files,
85 }),
86 other => Ok(other),
87 }
88 }
89
90 fn fold_expr(&self, expr: Expr) -> Expr {
92 match expr {
93 Expr::BinaryOp { left, op, right } => {
94 let left = self.fold_expr(*left);
95 let right = self.fold_expr(*right);
96
97 if let (Expr::Literal(lv), Expr::Literal(rv)) = (&left, &right)
99 && let Some(result) = self.eval_const(lv, &op, rv)
100 {
101 return Expr::Literal(result);
102 }
103
104 Expr::BinaryOp {
105 left: Box::new(left),
106 op,
107 right: Box::new(right),
108 }
109 }
110 other => other,
111 }
112 }
113
114 fn eval_const(
116 &self,
117 left: &ScalarValue,
118 op: &BinaryOperator,
119 right: &ScalarValue,
120 ) -> Option<ScalarValue> {
121 match (left, op, right) {
122 (ScalarValue::Int32(a), BinaryOperator::Plus, ScalarValue::Int32(b)) => {
124 Some(ScalarValue::Int32(a + b))
125 }
126 (ScalarValue::Int32(a), BinaryOperator::Minus, ScalarValue::Int32(b)) => {
127 Some(ScalarValue::Int32(a - b))
128 }
129 (ScalarValue::Int32(a), BinaryOperator::Multiply, ScalarValue::Int32(b)) => {
130 Some(ScalarValue::Int32(a * b))
131 }
132 (ScalarValue::Int32(a), BinaryOperator::Divide, ScalarValue::Int32(b)) if *b != 0 => {
133 Some(ScalarValue::Int32(a / b))
134 }
135 (ScalarValue::Int32(a), BinaryOperator::Eq, ScalarValue::Int32(b)) => {
137 Some(ScalarValue::Boolean(a == b))
138 }
139 (ScalarValue::Int32(a), BinaryOperator::NotEq, ScalarValue::Int32(b)) => {
140 Some(ScalarValue::Boolean(a != b))
141 }
142 (ScalarValue::Int32(a), BinaryOperator::Lt, ScalarValue::Int32(b)) => {
143 Some(ScalarValue::Boolean(a < b))
144 }
145 (ScalarValue::Int32(a), BinaryOperator::Gt, ScalarValue::Int32(b)) => {
146 Some(ScalarValue::Boolean(a > b))
147 }
148 (ScalarValue::Boolean(a), BinaryOperator::And, ScalarValue::Boolean(b)) => {
150 Some(ScalarValue::Boolean(*a && *b))
151 }
152 (ScalarValue::Boolean(a), BinaryOperator::Or, ScalarValue::Boolean(b)) => {
153 Some(ScalarValue::Boolean(*a || *b))
154 }
155 (ScalarValue::Float64(a), BinaryOperator::Plus, ScalarValue::Float64(b)) => {
157 Some(ScalarValue::Float64(a + b))
158 }
159 (ScalarValue::Float64(a), BinaryOperator::Minus, ScalarValue::Float64(b)) => {
160 Some(ScalarValue::Float64(a - b))
161 }
162 (ScalarValue::Float64(a), BinaryOperator::Multiply, ScalarValue::Float64(b)) => {
163 Some(ScalarValue::Float64(a * b))
164 }
165 _ => None,
166 }
167 }
168}