1use kotoba_core::ir::*;
4use kotoba_core::types::*;
5use std::collections::HashMap;
6
/// A rewrite pass that the optimizer can apply to a plan.
///
/// The variants carry no data, so the type is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationRule {
    /// Move filter predicates below joins, towards the join inputs.
    PushDownPredicates,

    /// Reorder join inputs based on the estimated costs.
    JoinOrderOptimization,

    /// Collapse a projection that sits directly on top of another projection.
    EliminateUnnecessaryProjections,

    /// Fold constant expressions.
    ConstantFolding,

    /// Replace a filtered node scan with an index scan when a matching index exists.
    IndexSelection,
}
25
/// The cost model's estimate for one operator subtree.
///
/// All three figures are plain `f64` values, so the struct is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CostEstimate {
    /// Estimated number of rows the operator produces.
    pub cardinality: f64,
    /// Estimated cost of producing those rows, in abstract cost units.
    pub cost: f64,
    /// Estimated fraction of the input rows that survive the operator.
    pub selectivity: f64,
}
36
37#[derive(Debug)]
39struct JoinPlan {
40 relations: Vec<LogicalOp>,
41 cost: f64,
42 cardinality: f64,
43}
44
45#[derive(Debug)]
47pub struct QueryOptimizer {
48 rules: Vec<OptimizationRule>,
49 statistics: CostStatistics,
51}
52
/// Tunable constants of the optimizer's cost model.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CostStatistics {
    /// Selectivity assumed when nothing more specific is known: unrecognised
    /// predicates, unrecognised operators, and keyed joins.
    default_selectivity: f64,
    /// Selectivity assumed for equality predicates, index scans and node scans
    /// that carry property constraints.
    index_selectivity: f64,
    /// Multiplier applied to `left rows * right rows` to obtain a join's cost.
    join_cost_factor: f64,
}
63
64impl QueryOptimizer {
65 pub fn new() -> Self {
66 Self {
67 rules: vec![
68 OptimizationRule::PushDownPredicates,
69 OptimizationRule::JoinOrderOptimization,
70 OptimizationRule::EliminateUnnecessaryProjections,
71 OptimizationRule::ConstantFolding,
72 OptimizationRule::IndexSelection,
73 ],
74 statistics: CostStatistics {
75 default_selectivity: 0.1, index_selectivity: 0.01, join_cost_factor: 0.5, },
79 }
80 }
81
82 pub fn optimize(&self, plan: &PlanIR, catalog: &Catalog) -> PlanIR {
84 let mut optimized = plan.clone();
85
86 for rule in &self.rules {
87 optimized = self.apply_rule(optimized, rule, catalog);
88 }
89
90 optimized
91 }
92
93 fn apply_rule(&self, plan: PlanIR, rule: &OptimizationRule, catalog: &Catalog) -> PlanIR {
95 match rule {
96 OptimizationRule::PushDownPredicates => {
97 self.push_down_predicates(plan)
98 }
99 OptimizationRule::JoinOrderOptimization => {
100 self.optimize_join_order(plan, catalog)
101 }
102 OptimizationRule::EliminateUnnecessaryProjections => {
103 self.eliminate_unnecessary_projections(plan)
104 }
105 OptimizationRule::ConstantFolding => {
106 self.constant_folding(plan)
107 }
108 OptimizationRule::IndexSelection => {
109 self.select_indexes(plan, catalog)
110 }
111 }
112 }
113
114 fn push_down_predicates(&self, plan: PlanIR) -> PlanIR {
116 let optimized_plan = self.push_down_predicates_op(&plan.plan);
117 PlanIR {
118 plan: optimized_plan,
119 limit: plan.limit,
120 }
121 }
122
123 fn push_down_predicates_op(&self, op: &LogicalOp) -> LogicalOp {
124 match op {
125 LogicalOp::Filter { pred, input } => {
126 match input.as_ref() {
128 LogicalOp::Join { left, right, on } => {
129 let (left_pred, right_pred, remaining_pred) =
131 self.split_predicate_for_join(pred, on);
132
133 let new_left = if let Some(lp) = left_pred {
134 Box::new(LogicalOp::Filter {
135 pred: lp,
136 input: left.clone(),
137 })
138 } else {
139 left.clone()
140 };
141
142 let new_right = if let Some(rp) = right_pred {
143 Box::new(LogicalOp::Filter {
144 pred: rp,
145 input: right.clone(),
146 })
147 } else {
148 right.clone()
149 };
150
151 if let Some(rem_pred) = remaining_pred {
152 LogicalOp::Filter {
153 pred: rem_pred,
154 input: Box::new(LogicalOp::Join {
155 left: new_left,
156 right: new_right,
157 on: on.clone(),
158 }),
159 }
160 } else {
161 LogicalOp::Join {
162 left: new_left,
163 right: new_right,
164 on: on.clone(),
165 }
166 }
167 }
168 _ => {
169 LogicalOp::Filter {
171 pred: pred.clone(),
172 input: Box::new(self.push_down_predicates_op(input)),
173 }
174 }
175 }
176 }
177 LogicalOp::Join { left, right, on } => {
178 LogicalOp::Join {
179 left: Box::new(self.push_down_predicates_op(left)),
180 right: Box::new(self.push_down_predicates_op(right)),
181 on: on.clone(),
182 }
183 }
184 LogicalOp::Project { cols, input } => {
185 LogicalOp::Project {
186 cols: cols.clone(),
187 input: Box::new(self.push_down_predicates_op(input)),
188 }
189 }
190 _ => op.clone(),
192 }
193 }
194
195 fn split_predicate_for_join(&self, pred: &Predicate, join_keys: &[String])
197 -> (Option<Predicate>, Option<Predicate>, Option<Predicate>) {
198 match pred {
199 Predicate::And { and } => {
200 let mut left_preds = Vec::new();
201 let mut right_preds = Vec::new();
202 let mut remaining = Vec::new();
203
204 for p in and {
205 let (l, r, rem) = self.split_predicate_for_join(p, join_keys);
206 if let Some(lp) = l { left_preds.push(lp); }
207 if let Some(rp) = r { right_preds.push(rp); }
208 if let Some(rem_p) = rem { remaining.push(rem_p); }
209 }
210
211 let left = if left_preds.is_empty() {
212 None
213 } else if left_preds.len() == 1 {
214 Some(left_preds.into_iter().next().unwrap())
215 } else {
216 Some(Predicate::And { and: left_preds })
217 };
218
219 let right = if right_preds.is_empty() {
220 None
221 } else if right_preds.len() == 1 {
222 Some(right_preds.into_iter().next().unwrap())
223 } else {
224 Some(Predicate::And { and: right_preds })
225 };
226
227 let rem = if remaining.is_empty() {
228 None
229 } else if remaining.len() == 1 {
230 Some(remaining.into_iter().next().unwrap())
231 } else {
232 Some(Predicate::And { and: remaining })
233 };
234
235 (left, right, rem)
236 }
237 Predicate::Eq { eq } if eq.len() == 2 => {
238 let left_vars = self.extract_variables(&eq[0]);
240 let right_vars = self.extract_variables(&eq[1]);
241
242 if self.contains_join_key(&left_vars, join_keys) &&
243 self.contains_join_key(&right_vars, join_keys) {
244 (None, None, Some(pred.clone()))
246 } else if self.contains_join_key(&left_vars, join_keys) {
247 (Some(pred.clone()), None, None)
248 } else if self.contains_join_key(&right_vars, join_keys) {
249 (None, Some(pred.clone()), None)
250 } else {
251 (None, None, Some(pred.clone()))
252 }
253 }
254 _ => (None, None, Some(pred.clone())),
255 }
256 }
257
258 fn extract_variables(&self, expr: &Expr) -> Vec<String> {
260 match expr {
261 Expr::Var(v) => vec![v.clone()],
262 Expr::Fn { args, .. } => args.iter()
263 .flat_map(|arg| self.extract_variables(arg))
264 .collect(),
265 _ => Vec::new(),
266 }
267 }
268
269 fn contains_join_key(&self, vars: &[String], join_keys: &[String]) -> bool {
271 vars.iter().any(|v| join_keys.contains(v))
272 }
273
274 fn optimize_join_order(&self, plan: PlanIR, catalog: &Catalog) -> PlanIR {
276 let optimized_plan = self.optimize_join_order_op(&plan.plan, catalog);
277 PlanIR {
278 plan: optimized_plan,
279 limit: plan.limit,
280 }
281 }
282
283 fn optimize_join_order_op(&self, op: &LogicalOp, catalog: &Catalog) -> LogicalOp {
284 match op {
285 LogicalOp::Join { left, right, on } => {
286 let mut relations = Vec::new();
288 self.collect_relations(left, &mut relations);
289 self.collect_relations(right, &mut relations);
290
291 if relations.len() > 2 {
292 self.optimize_join_order_dp(&relations, on, catalog)
294 } else {
295 let left_cost = self.estimate_cost_detailed(left, catalog);
297 let right_cost = self.estimate_cost_detailed(right, catalog);
298
299 if left_cost.cost > right_cost.cost {
300 LogicalOp::Join {
301 left: Box::new(self.optimize_join_order_op(right, catalog)),
302 right: Box::new(self.optimize_join_order_op(left, catalog)),
303 on: on.clone(),
304 }
305 } else {
306 LogicalOp::Join {
307 left: Box::new(self.optimize_join_order_op(left, catalog)),
308 right: Box::new(self.optimize_join_order_op(right, catalog)),
309 on: on.clone(),
310 }
311 }
312 }
313 }
314 _ => op.clone(),
315 }
316 }
317
318 fn collect_relations(&self, op: &LogicalOp, relations: &mut Vec<LogicalOp>) {
320 match op {
321 LogicalOp::Join { left, right, .. } => {
322 self.collect_relations(left, relations);
323 self.collect_relations(right, relations);
324 }
325 _ => relations.push(op.clone()),
326 }
327 }
328
329 fn optimize_join_order_dp(&self, relations: &[LogicalOp], join_keys: &[String], catalog: &Catalog) -> LogicalOp {
331 let n = relations.len();
332 let mut dp = vec![vec![None; n]; n];
333 let mut costs = vec![vec![f64::INFINITY; n]; n];
334 let mut cardinalities = vec![vec![0.0; n]; n];
335
336 for i in 0..n {
338 let cost_est = self.estimate_cost_detailed(&relations[i], catalog);
339 costs[i][i] = cost_est.cost;
340 cardinalities[i][i] = cost_est.cardinality;
341 dp[i][i] = Some(relations[i].clone());
342 }
343
344 for len in 2..=n {
346 for i in 0..=n-len {
347 let j = i + len - 1;
348 costs[i][j] = f64::INFINITY;
349
350 for k in i..j {
352 let left_cost = costs[i][k];
353 let right_cost = costs[k+1][j];
354 let left_card = cardinalities[i][k];
355 let right_card = cardinalities[k+1][j];
356
357 let join_cost = self.calculate_join_cost(left_card, right_card, join_keys);
359 let total_cost = left_cost + right_cost + join_cost;
360
361 if total_cost < costs[i][j] {
362 costs[i][j] = total_cost;
363 cardinalities[i][j] = self.estimate_join_cardinality(left_card, right_card, join_keys);
364
365 if let (Some(left_plan), Some(right_plan)) = (&dp[i][k], &dp[k+1][j]) {
367 dp[i][j] = Some(LogicalOp::Join {
368 left: Box::new(left_plan.clone()),
369 right: Box::new(right_plan.clone()),
370 on: join_keys.to_vec(),
371 });
372 }
373 }
374 }
375 }
376 }
377
378 dp[0][n-1].clone().unwrap_or_else(|| relations[0].clone())
379 }
380
381 fn estimate_cost_detailed(&self, op: &LogicalOp, catalog: &Catalog) -> CostEstimate {
383 match op {
384 LogicalOp::NodeScan { label, props, .. } => {
385 let base_cardinality = catalog.get_label(label)
386 .map(|_| 1000.0) .unwrap_or(100.0);
388
389 let selectivity = if props.is_some() {
390 self.statistics.index_selectivity
391 } else {
392 1.0
393 };
394
395 CostEstimate {
396 cardinality: base_cardinality * selectivity,
397 cost: base_cardinality * selectivity * 10.0, selectivity,
399 }
400 }
401 LogicalOp::IndexScan { .. } => {
402 CostEstimate {
403 cardinality: 10.0, cost: 5.0, selectivity: self.statistics.index_selectivity,
406 }
407 }
408 LogicalOp::Filter { input, pred } => {
409 let input_cost = self.estimate_cost_detailed(input, catalog);
410 let filter_selectivity = self.estimate_filter_selectivity(pred);
411
412 CostEstimate {
413 cardinality: input_cost.cardinality * filter_selectivity,
414 cost: input_cost.cost + (input_cost.cardinality * filter_selectivity * 2.0), selectivity: input_cost.selectivity * filter_selectivity,
416 }
417 }
418 LogicalOp::Join { left, right, on } => {
419 let left_cost = self.estimate_cost_detailed(left, catalog);
420 let right_cost = self.estimate_cost_detailed(right, catalog);
421 let join_card = self.estimate_join_cardinality(left_cost.cardinality, right_cost.cardinality, on);
422 let join_cost = self.calculate_join_cost(left_cost.cardinality, right_cost.cardinality, on);
423
424 CostEstimate {
425 cardinality: join_card,
426 cost: left_cost.cost + right_cost.cost + join_cost,
427 selectivity: (left_cost.selectivity + right_cost.selectivity) / 2.0,
428 }
429 }
430 _ => CostEstimate {
431 cardinality: 100.0,
432 cost: 10.0,
433 selectivity: self.statistics.default_selectivity,
434 },
435 }
436 }
437
438 fn estimate_filter_selectivity(&self, pred: &Predicate) -> f64 {
440 match pred {
441 Predicate::Eq { .. } => self.statistics.index_selectivity, Predicate::Gt { .. } | Predicate::Lt { .. } | Predicate::Ge { .. } | Predicate::Le { .. } => 0.3, Predicate::And { and } => {
444 and.iter().map(|p| self.estimate_filter_selectivity(p)).product()
446 }
447 Predicate::Or { or } => {
448 let sum: f64 = or.iter().map(|p| self.estimate_filter_selectivity(p)).sum();
450 sum.min(1.0)
451 }
452 _ => self.statistics.default_selectivity,
453 }
454 }
455
456 fn estimate_join_cardinality(&self, left_card: f64, right_card: f64, join_keys: &[String]) -> f64 {
458 if join_keys.is_empty() {
459 left_card * right_card
461 } else {
462 (left_card * right_card * self.statistics.default_selectivity).max(left_card.max(right_card))
464 }
465 }
466
467 fn calculate_join_cost(&self, left_card: f64, right_card: f64, _join_keys: &[String]) -> f64 {
469 left_card * right_card * self.statistics.join_cost_factor
471 }
472
473 fn estimate_cost(&self, op: &LogicalOp, catalog: &Catalog) -> f64 {
475 self.estimate_cost_detailed(op, catalog).cost
476 }
477
478 fn eliminate_unnecessary_projections(&self, plan: PlanIR) -> PlanIR {
480 let optimized_plan = self.eliminate_unnecessary_projections_op(&plan.plan);
481 PlanIR {
482 plan: optimized_plan,
483 limit: plan.limit,
484 }
485 }
486
487 fn eliminate_unnecessary_projections_op(&self, op: &LogicalOp) -> LogicalOp {
488 match op {
489 LogicalOp::Project { cols, input } => {
490 match input.as_ref() {
491 LogicalOp::Project { cols: inner_cols, input: inner_input } => {
492 let merged_cols = cols.iter()
494 .filter(|col| inner_cols.contains(col))
495 .cloned()
496 .collect();
497
498 LogicalOp::Project {
499 cols: merged_cols,
500 input: inner_input.clone(),
501 }
502 }
503 _ => LogicalOp::Project {
504 cols: cols.clone(),
505 input: Box::new(self.eliminate_unnecessary_projections_op(input)),
506 }
507 }
508 }
509 _ => op.clone(),
510 }
511 }
512
513 fn constant_folding(&self, plan: PlanIR) -> PlanIR {
515 plan
517 }
518
519 fn select_indexes(&self, plan: PlanIR, catalog: &Catalog) -> PlanIR {
521 let optimized_plan = self.select_indexes_op(&plan.plan, catalog);
522 PlanIR {
523 plan: optimized_plan,
524 limit: plan.limit,
525 }
526 }
527
528 fn select_indexes_op(&self, op: &LogicalOp, catalog: &Catalog) -> LogicalOp {
529 match op {
530 LogicalOp::Filter { pred, input } => {
531 match input.as_ref() {
532 LogicalOp::NodeScan { label, as_, props: _ } => {
533 if let Some(index) = self.find_best_index(catalog, label, pred) {
535 LogicalOp::Filter {
536 pred: pred.clone(),
537 input: Box::new(LogicalOp::IndexScan {
538 label: label.clone(),
539 as_: as_.clone(),
540 index: index.name,
541 value: self.extract_index_value(pred, &index.properties[0]),
542 }),
543 }
544 } else {
545 LogicalOp::Filter {
546 pred: pred.clone(),
547 input: Box::new(self.select_indexes_op(input, catalog)),
548 }
549 }
550 }
551 _ => LogicalOp::Filter {
552 pred: pred.clone(),
553 input: Box::new(self.select_indexes_op(input, catalog)),
554 }
555 }
556 }
557 _ => op.clone(),
558 }
559 }
560
561 fn find_best_index(&self, catalog: &Catalog, label: &Label, pred: &Predicate) -> Option<IndexDef> {
563 catalog.indexes.iter()
564 .filter(|idx| &idx.label == label)
565 .find(|idx| self.can_use_index(pred, &idx.properties[0]))
566 .cloned()
567 }
568
569 fn can_use_index(&self, pred: &Predicate, prop: &PropertyKey) -> bool {
571 match pred {
572 Predicate::Eq { eq } if eq.len() == 2 => {
573 let left_vars = self.extract_variables(&eq[0]);
574 let right_vars = self.extract_variables(&eq[1]);
575
576 left_vars.contains(prop) || right_vars.contains(prop)
578 }
579 _ => false,
580 }
581 }
582
583 fn extract_index_value(&self, pred: &Predicate, prop: &PropertyKey) -> Value {
585 match pred {
586 Predicate::Eq { eq } if eq.len() == 2 => {
587 if let Expr::Var(var) = &eq[0] {
588 if var == prop {
589 if let Expr::Const(val) = &eq[1] {
590 return val.clone();
591 }
592 }
593 }
594 if let Expr::Var(var) = &eq[1] {
595 if var == prop {
596 if let Expr::Const(val) = &eq[0] {
597 return val.clone();
598 }
599 }
600 }
601 }
602 _ => {}
603 }
604 Value::Null
605 }
606}