kotoba_execution/planner/
optimizer.rs1use kotoba_core::{types::*, ir::*};
4
/// A single rewrite pass that `QueryOptimizer` can apply to a logical plan.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationRule {
    /// Push filter predicates below joins where that is safe.
    PushDownPredicates,
    /// Reorder join inputs so the input with the lower estimated cost is on the left.
    JoinOrderOptimization,
    /// Merge directly nested projections into a single projection.
    EliminateUnnecessaryProjections,
    /// Fold constant expressions (currently a no-op pass).
    ConstantFolding,
    /// Replace a filtered label scan with an index scan when a catalog index matches.
    IndexSelection,
}

24#[derive(Debug)]
26pub struct QueryOptimizer {
27 rules: Vec<OptimizationRule>,
28}
29
30impl Default for QueryOptimizer {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl QueryOptimizer {
37 pub fn new() -> Self {
38 Self {
39 rules: vec![
40 OptimizationRule::PushDownPredicates,
41 OptimizationRule::JoinOrderOptimization,
42 OptimizationRule::EliminateUnnecessaryProjections,
43 OptimizationRule::ConstantFolding,
44 OptimizationRule::IndexSelection,
45 ],
46 }
47 }
48
49 pub fn optimize(&self, plan: &PlanIR, catalog: &Catalog) -> PlanIR {
51 let mut optimized = plan.clone();
52
53 for rule in &self.rules {
54 optimized = self.apply_rule(optimized, rule, catalog);
55 }
56
57 optimized
58 }
59
60 fn apply_rule(&self, plan: PlanIR, rule: &OptimizationRule, catalog: &Catalog) -> PlanIR {
62 match rule {
63 OptimizationRule::PushDownPredicates => {
64 self.push_down_predicates(plan)
65 }
66 OptimizationRule::JoinOrderOptimization => {
67 self.optimize_join_order(plan, catalog)
68 }
69 OptimizationRule::EliminateUnnecessaryProjections => {
70 self.eliminate_unnecessary_projections(plan)
71 }
72 OptimizationRule::ConstantFolding => {
73 self.constant_folding(plan)
74 }
75 OptimizationRule::IndexSelection => {
76 self.select_indexes(plan, catalog)
77 }
78 }
79 }
80
81 fn push_down_predicates(&self, plan: PlanIR) -> PlanIR {
83 let optimized_plan = self.push_down_predicates_op(&plan.plan);
84 PlanIR {
85 plan: optimized_plan,
86 limit: plan.limit,
87 }
88 }
89
90 fn push_down_predicates_op(&self, op: &LogicalOp) -> LogicalOp {
91 match op {
92 LogicalOp::Filter { pred, input } => {
93 match input.as_ref() {
95 LogicalOp::Join { left, right, on } => {
96 let (left_pred, right_pred, remaining_pred) =
98 self.split_predicate_for_join(pred, on);
99
100 let new_left = if let Some(lp) = left_pred {
101 Box::new(LogicalOp::Filter {
102 pred: lp,
103 input: left.clone(),
104 })
105 } else {
106 left.clone()
107 };
108
109 let new_right = if let Some(rp) = right_pred {
110 Box::new(LogicalOp::Filter {
111 pred: rp,
112 input: right.clone(),
113 })
114 } else {
115 right.clone()
116 };
117
118 if let Some(rem_pred) = remaining_pred {
119 LogicalOp::Filter {
120 pred: rem_pred,
121 input: Box::new(LogicalOp::Join {
122 left: new_left,
123 right: new_right,
124 on: on.clone(),
125 }),
126 }
127 } else {
128 LogicalOp::Join {
129 left: new_left,
130 right: new_right,
131 on: on.clone(),
132 }
133 }
134 }
135 _ => {
136 LogicalOp::Filter {
138 pred: pred.clone(),
139 input: Box::new(self.push_down_predicates_op(input)),
140 }
141 }
142 }
143 }
144 LogicalOp::Join { left, right, on } => {
145 LogicalOp::Join {
146 left: Box::new(self.push_down_predicates_op(left)),
147 right: Box::new(self.push_down_predicates_op(right)),
148 on: on.clone(),
149 }
150 }
151 LogicalOp::Project { cols, input } => {
152 LogicalOp::Project {
153 cols: cols.clone(),
154 input: Box::new(self.push_down_predicates_op(input)),
155 }
156 }
157 _ => op.clone(),
159 }
160 }
161
162 fn split_predicate_for_join(&self, pred: &Predicate, join_keys: &[String])
164 -> (Option<Predicate>, Option<Predicate>, Option<Predicate>) {
165 match pred {
166 Predicate::And { and } => {
167 let mut left_preds = Vec::new();
168 let mut right_preds = Vec::new();
169 let mut remaining = Vec::new();
170
171 for p in and {
172 let (l, r, rem) = self.split_predicate_for_join(p, join_keys);
173 if let Some(lp) = l { left_preds.push(lp); }
174 if let Some(rp) = r { right_preds.push(rp); }
175 if let Some(rem_p) = rem { remaining.push(rem_p); }
176 }
177
178 let left = if left_preds.is_empty() {
179 None
180 } else if left_preds.len() == 1 {
181 Some(left_preds.into_iter().next().unwrap())
182 } else {
183 Some(Predicate::And { and: left_preds })
184 };
185
186 let right = if right_preds.is_empty() {
187 None
188 } else if right_preds.len() == 1 {
189 Some(right_preds.into_iter().next().unwrap())
190 } else {
191 Some(Predicate::And { and: right_preds })
192 };
193
194 let rem = if remaining.is_empty() {
195 None
196 } else if remaining.len() == 1 {
197 Some(remaining.into_iter().next().unwrap())
198 } else {
199 Some(Predicate::And { and: remaining })
200 };
201
202 (left, right, rem)
203 }
204 Predicate::Eq { eq } if eq.len() == 2 => {
205 let left_vars = self.extract_variables(&eq[0]);
207 let right_vars = self.extract_variables(&eq[1]);
208
209 if self.contains_join_key(&left_vars, join_keys) &&
210 self.contains_join_key(&right_vars, join_keys) {
211 (None, None, Some(pred.clone()))
213 } else if self.contains_join_key(&left_vars, join_keys) {
214 (Some(pred.clone()), None, None)
215 } else if self.contains_join_key(&right_vars, join_keys) {
216 (None, Some(pred.clone()), None)
217 } else {
218 (None, None, Some(pred.clone()))
219 }
220 }
221 _ => (None, None, Some(pred.clone())),
222 }
223 }
224
225 fn extract_variables(&self, expr: &Expr) -> Vec<String> {
227 match expr {
228 Expr::Var(v) => vec![v.clone()],
229 Expr::Fn { args, .. } => args.iter()
230 .flat_map(|arg| self.extract_variables(arg))
231 .collect(),
232 _ => Vec::new(),
233 }
234 }
235
236 fn contains_join_key(&self, vars: &[String], join_keys: &[String]) -> bool {
238 vars.iter().any(|v| join_keys.contains(v))
239 }
240
241 fn optimize_join_order(&self, plan: PlanIR, catalog: &Catalog) -> PlanIR {
243 let optimized_plan = self.optimize_join_order_op(&plan.plan, catalog);
244 PlanIR {
245 plan: optimized_plan,
246 limit: plan.limit,
247 }
248 }
249
250 fn optimize_join_order_op(&self, op: &LogicalOp, catalog: &Catalog) -> LogicalOp {
251 match op {
252 LogicalOp::Join { left, right, on } => {
253 let left_cost = self.estimate_cost(left, catalog);
255 let right_cost = self.estimate_cost(right, catalog);
256
257 if left_cost > right_cost {
258 LogicalOp::Join {
260 left: Box::new(self.optimize_join_order_op(right, catalog)),
261 right: Box::new(self.optimize_join_order_op(left, catalog)),
262 on: on.clone(),
263 }
264 } else {
265 LogicalOp::Join {
266 left: Box::new(self.optimize_join_order_op(left, catalog)),
267 right: Box::new(self.optimize_join_order_op(right, catalog)),
268 on: on.clone(),
269 }
270 }
271 }
272 _ => op.clone(),
273 }
274 }
275
276 fn estimate_cost(&self, op: &LogicalOp, catalog: &Catalog) -> f64 {
278 match op {
279 LogicalOp::NodeScan { label, .. } => {
280 catalog.get_label(label)
281 .map(|_| 100.0)
282 .unwrap_or(1000.0)
283 }
284 LogicalOp::Join { left, right, .. } => {
285 let left_cost = self.estimate_cost(left, catalog);
286 let right_cost = self.estimate_cost(right, catalog);
287 left_cost * right_cost
288 }
289 _ => 10.0,
290 }
291 }
292
293 fn eliminate_unnecessary_projections(&self, plan: PlanIR) -> PlanIR {
295 let optimized_plan = self.eliminate_unnecessary_projections_op(&plan.plan);
296 PlanIR {
297 plan: optimized_plan,
298 limit: plan.limit,
299 }
300 }
301
302 fn eliminate_unnecessary_projections_op(&self, op: &LogicalOp) -> LogicalOp {
303 match op {
304 LogicalOp::Project { cols, input } => {
305 match input.as_ref() {
306 LogicalOp::Project { cols: inner_cols, input: inner_input } => {
307 let merged_cols = cols.iter()
309 .filter(|col| inner_cols.contains(col))
310 .cloned()
311 .collect();
312
313 LogicalOp::Project {
314 cols: merged_cols,
315 input: inner_input.clone(),
316 }
317 }
318 _ => LogicalOp::Project {
319 cols: cols.clone(),
320 input: Box::new(self.eliminate_unnecessary_projections_op(input)),
321 }
322 }
323 }
324 _ => op.clone(),
325 }
326 }
327
328 fn constant_folding(&self, plan: PlanIR) -> PlanIR {
330 plan
332 }
333
334 fn select_indexes(&self, plan: PlanIR, catalog: &Catalog) -> PlanIR {
336 let optimized_plan = self.select_indexes_op(&plan.plan, catalog);
337 PlanIR {
338 plan: optimized_plan,
339 limit: plan.limit,
340 }
341 }
342
343 fn select_indexes_op(&self, op: &LogicalOp, catalog: &Catalog) -> LogicalOp {
344 match op {
345 LogicalOp::Filter { pred, input } => {
346 match input.as_ref() {
347 LogicalOp::NodeScan { label, as_, props: _ } => {
348 if let Some(index) = self.find_best_index(catalog, label, pred) {
350 LogicalOp::Filter {
351 pred: pred.clone(),
352 input: Box::new(LogicalOp::IndexScan {
353 label: label.clone(),
354 as_: as_.clone(),
355 index: index.name,
356 value: self.extract_index_value(pred, &index.properties[0]),
357 }),
358 }
359 } else {
360 LogicalOp::Filter {
361 pred: pred.clone(),
362 input: Box::new(self.select_indexes_op(input, catalog)),
363 }
364 }
365 }
366 _ => LogicalOp::Filter {
367 pred: pred.clone(),
368 input: Box::new(self.select_indexes_op(input, catalog)),
369 }
370 }
371 }
372 _ => op.clone(),
373 }
374 }
375
376 fn find_best_index(&self, catalog: &Catalog, label: &Label, pred: &Predicate) -> Option<IndexDef> {
378 catalog.indexes.iter()
379 .filter(|idx| &idx.label == label)
380 .find(|idx| self.can_use_index(pred, &idx.properties[0]))
381 .cloned()
382 }
383
384 fn can_use_index(&self, pred: &Predicate, prop: &PropertyKey) -> bool {
386 match pred {
387 Predicate::Eq { eq } if eq.len() == 2 => {
388 let left_vars = self.extract_variables(&eq[0]);
389 let right_vars = self.extract_variables(&eq[1]);
390
391 left_vars.contains(prop) || right_vars.contains(prop)
393 }
394 _ => false,
395 }
396 }
397
398 fn extract_index_value(&self, pred: &Predicate, prop: &PropertyKey) -> Value {
400 match pred {
401 Predicate::Eq { eq } if eq.len() == 2 => {
402 if let Expr::Var(var) = &eq[0] {
403 if var == prop {
404 if let Expr::Const(val) = &eq[1] {
405 return val.clone();
406 }
407 }
408 }
409 if let Expr::Var(var) = &eq[1] {
410 if var == prop {
411 if let Expr::Const(val) = &eq[0] {
412 return val.clone();
413 }
414 }
415 }
416 }
417 _ => {}
418 }
419 Value::Null
420 }
421}