1use super::constants as cost_constants;
4use crate::expr::{BinaryOp, Expr};
5use crate::planner::{IndexBound, IndexRange, JoinType, LogicalPlan};
6use featherdb_catalog::{Catalog, Table};
7use std::sync::Arc;
8
/// Cost and cardinality estimator for logical query plans.
///
/// The estimator owns no statistics of its own: it borrows the catalog and
/// asks it for estimated row counts and selectivity figures on demand.
pub struct CostEstimator<'a> {
    /// Catalog consulted for estimated row counts and selectivities.
    catalog: &'a Catalog,
}
17
18impl<'a> CostEstimator<'a> {
19 pub fn new(catalog: &'a Catalog) -> Self {
21 CostEstimator { catalog }
22 }
23
24 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
26 match plan {
27 LogicalPlan::Scan { table, filter, .. } => {
28 let base_rows = self.catalog.estimated_row_count(&table.name) as f64;
29 match filter {
30 Some(pred) => base_rows * self.estimate_selectivity(pred, table),
31 None => base_rows,
32 }
33 }
34
35 LogicalPlan::IndexScan {
36 table,
37 range,
38 residual_filter,
39 index_column,
40 ..
41 } => {
42 let base_rows = self.catalog.estimated_row_count(&table.name) as f64;
43
44 let range_selectivity = if range.is_point_lookup() {
46 self.catalog.selectivity_eq(&table.name, *index_column)
47 } else {
48 self.estimate_range_selectivity(range, &table.name, *index_column)
49 };
50
51 let rows_after_index = base_rows * range_selectivity;
52
53 match residual_filter {
55 Some(pred) => rows_after_index * self.estimate_selectivity(pred, table),
56 None => rows_after_index,
57 }
58 }
59
60 LogicalPlan::PkSeek {
61 table,
62 residual_filter,
63 ..
64 } => {
65 let base_cardinality = 1.0;
67
68 match residual_filter {
70 Some(pred) => base_cardinality * self.estimate_selectivity(pred, table),
71 None => base_cardinality,
72 }
73 }
74
75 LogicalPlan::PkRangeScan {
76 table,
77 range,
78 residual_filter,
79 ..
80 } => {
81 let base_rows = self.catalog.estimated_row_count(&table.name) as f64;
82 let pk_col = if !table.primary_key.is_empty() {
84 table.primary_key[0]
85 } else {
86 0
87 };
88 let range_selectivity = self.estimate_range_selectivity(range, &table.name, pk_col);
89 let rows_after_range = base_rows * range_selectivity;
90 match residual_filter {
91 Some(pred) => rows_after_range * self.estimate_selectivity(pred, table),
92 None => rows_after_range,
93 }
94 }
95
96 LogicalPlan::Filter { input, predicate } => {
97 let input_rows = self.estimate_cardinality(input);
98 let table = self.extract_table(input);
99 let selectivity = table
100 .map(|t| self.estimate_selectivity(predicate, &t))
101 .unwrap_or(cost_constants::DEFAULT_SELECTIVITY);
102 input_rows * selectivity
103 }
104
105 LogicalPlan::Project { input, .. } => self.estimate_cardinality(input),
106
107 LogicalPlan::Join {
108 left,
109 right,
110 condition,
111 join_type,
112 } => {
113 let left_rows = self.estimate_cardinality(left);
114 let right_rows = self.estimate_cardinality(right);
115
116 let join_selectivity = condition
117 .as_ref()
118 .map(|c| self.estimate_join_selectivity(c, left, right))
119 .unwrap_or(1.0); let base_result = left_rows * right_rows * join_selectivity;
122
123 match join_type {
125 JoinType::Inner => base_result,
126 JoinType::Left => base_result.max(left_rows), JoinType::Right => base_result.max(right_rows), JoinType::Full => base_result.max(left_rows).max(right_rows),
129 }
130 }
131
132 LogicalPlan::Aggregate {
133 input, group_by, ..
134 } => {
135 if group_by.is_empty() {
136 1.0 } else {
138 let input_rows = self.estimate_cardinality(input);
140 (input_rows * 0.1).max(1.0)
142 }
143 }
144
145 LogicalPlan::Sort { input, .. } => self.estimate_cardinality(input),
146
147 LogicalPlan::Limit {
148 input,
149 limit,
150 offset,
151 } => {
152 let input_rows = self.estimate_cardinality(input);
153 let remaining = input_rows - *offset as f64;
154 match limit {
155 Some(l) => remaining.min(*l as f64).max(0.0),
156 None => remaining.max(0.0),
157 }
158 }
159
160 LogicalPlan::Distinct { input } => {
161 let input_rows = self.estimate_cardinality(input);
162 (input_rows * 0.5).max(1.0)
164 }
165
166 LogicalPlan::EmptyRelation => 1.0,
167
168 _ => 1000.0, }
170 }
171
172 pub fn estimate_cost(&self, plan: &LogicalPlan) -> f64 {
174 match plan {
175 LogicalPlan::Scan { table, .. } => {
176 let rows = self.catalog.estimated_row_count(&table.name) as f64;
177 rows * cost_constants::SEQ_SCAN_COST_PER_ROW
178 }
179
180 LogicalPlan::IndexScan {
181 table,
182 range,
183 index_column,
184 ..
185 } => {
186 let base_rows = self.catalog.estimated_row_count(&table.name) as f64;
187 let selectivity = if range.is_point_lookup() {
188 self.catalog.selectivity_eq(&table.name, *index_column)
189 } else {
190 self.estimate_range_selectivity(range, &table.name, *index_column)
191 };
192 let rows_scanned = base_rows * selectivity;
193 rows_scanned * cost_constants::INDEX_SCAN_COST_PER_ROW
194 }
195
196 LogicalPlan::PkSeek { .. } => {
197 cost_constants::INDEX_SCAN_COST_PER_ROW
200 }
201
202 LogicalPlan::PkRangeScan { table, range, .. } => {
203 let base_rows = self.catalog.estimated_row_count(&table.name) as f64;
204 let pk_col = if !table.primary_key.is_empty() {
205 table.primary_key[0]
206 } else {
207 0
208 };
209 let selectivity = self.estimate_range_selectivity(range, &table.name, pk_col);
210 let rows_scanned = base_rows * selectivity;
211 rows_scanned * cost_constants::INDEX_SCAN_COST_PER_ROW
213 }
214
215 LogicalPlan::Filter { input, .. } => {
216 let input_cost = self.estimate_cost(input);
217 let rows = self.estimate_cardinality(input);
218 input_cost + rows * cost_constants::CPU_COST_MULTIPLIER
219 }
220
221 LogicalPlan::Project { input, .. } => {
222 let input_cost = self.estimate_cost(input);
223 let rows = self.estimate_cardinality(input);
224 input_cost + rows * cost_constants::CPU_COST_MULTIPLIER
225 }
226
227 LogicalPlan::Join { left, right, .. } => {
228 let left_cost = self.estimate_cost(left);
229 let right_cost = self.estimate_cost(right);
230 let left_rows = self.estimate_cardinality(left);
231 let right_rows = self.estimate_cardinality(right);
232
233 let (build_rows, probe_rows) = if left_rows <= right_rows {
236 (left_rows, right_rows)
237 } else {
238 (right_rows, left_rows)
239 };
240
241 let build_cost = build_rows * cost_constants::HASH_BUILD_COST_PER_ROW;
242 let probe_cost = probe_rows * cost_constants::HASH_JOIN_COST_PER_ROW;
243
244 left_cost + right_cost + build_cost + probe_cost
245 }
246
247 LogicalPlan::Aggregate { input, .. } => {
248 let input_cost = self.estimate_cost(input);
249 let rows = self.estimate_cardinality(input);
250 input_cost + rows * cost_constants::CPU_COST_MULTIPLIER
251 }
252
253 LogicalPlan::Sort { input, .. } => {
254 let input_cost = self.estimate_cost(input);
255 let rows = self.estimate_cardinality(input);
256 let sort_cost = if rows > 1.0 {
258 rows * rows.log2() * cost_constants::SORT_COST_PER_ROW
259 } else {
260 0.0
261 };
262 input_cost + sort_cost
263 }
264
265 LogicalPlan::Limit { input, .. } => self.estimate_cost(input),
266
267 LogicalPlan::Distinct { input } => {
268 let input_cost = self.estimate_cost(input);
269 let rows = self.estimate_cardinality(input);
270 input_cost + rows * cost_constants::CPU_COST_MULTIPLIER
271 }
272
273 LogicalPlan::EmptyRelation => 0.0,
274
275 _ => 1000.0, }
277 }
278
279 pub fn estimate_selectivity(&self, predicate: &Expr, table: &Table) -> f64 {
281 match predicate {
282 Expr::BinaryOp { left, op, right } => {
283 match op {
284 BinaryOp::And => {
285 let left_sel = self.estimate_selectivity(left, table);
287 let right_sel = self.estimate_selectivity(right, table);
288 left_sel * right_sel
289 }
290 BinaryOp::Or => {
291 let left_sel = self.estimate_selectivity(left, table);
293 let right_sel = self.estimate_selectivity(right, table);
294 left_sel + right_sel - (left_sel * right_sel)
295 }
296 BinaryOp::Eq => self.estimate_equality_selectivity(left, right, table),
297 BinaryOp::Ne => 1.0 - self.estimate_equality_selectivity(left, right, table),
298 BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
299 cost_constants::DEFAULT_RANGE_SELECTIVITY
300 }
301 _ => cost_constants::DEFAULT_SELECTIVITY,
302 }
303 }
304 Expr::UnaryOp { op, .. } => {
305 match op {
306 crate::expr::UnaryOp::Not => 1.0 - cost_constants::DEFAULT_SELECTIVITY,
307 crate::expr::UnaryOp::IsNull => 0.01, crate::expr::UnaryOp::IsNotNull => 0.99,
309 _ => cost_constants::DEFAULT_SELECTIVITY,
310 }
311 }
312 Expr::Between { .. } => cost_constants::DEFAULT_RANGE_SELECTIVITY,
313 Expr::InList { list, .. } => {
314 let eq_sel = cost_constants::DEFAULT_SELECTIVITY;
316 (eq_sel * list.len() as f64).min(1.0)
317 }
318 Expr::Like { .. } => 0.2, _ => cost_constants::DEFAULT_SELECTIVITY,
320 }
321 }
322
323 fn estimate_equality_selectivity(&self, left: &Expr, right: &Expr, table: &Table) -> f64 {
325 if let Expr::Column { name, .. } = left {
327 if matches!(right, Expr::Literal(_)) {
328 if let Some(col_idx) = table.get_column_index(name) {
329 return self.catalog.selectivity_eq(&table.name, col_idx);
330 }
331 }
332 }
333 if let Expr::Column { name, .. } = right {
334 if matches!(left, Expr::Literal(_)) {
335 if let Some(col_idx) = table.get_column_index(name) {
336 return self.catalog.selectivity_eq(&table.name, col_idx);
337 }
338 }
339 }
340 cost_constants::DEFAULT_SELECTIVITY
341 }
342
343 fn estimate_join_selectivity(
345 &self,
346 condition: &Expr,
347 _left: &LogicalPlan,
348 _right: &LogicalPlan,
349 ) -> f64 {
350 match condition {
351 Expr::BinaryOp {
352 op: BinaryOp::Eq, ..
353 } => cost_constants::DEFAULT_JOIN_SELECTIVITY,
354 Expr::BinaryOp {
355 op: BinaryOp::And,
356 left,
357 right,
358 } => {
359 let left_sel = self.estimate_join_selectivity(left, _left, _right);
360 let right_sel = self.estimate_join_selectivity(right, _left, _right);
361 left_sel * right_sel
362 }
363 _ => cost_constants::DEFAULT_JOIN_SELECTIVITY,
364 }
365 }
366
367 fn estimate_range_selectivity(
369 &self,
370 range: &IndexRange,
371 table_name: &str,
372 col_index: usize,
373 ) -> f64 {
374 let low = match &range.start {
375 IndexBound::Inclusive(v) | IndexBound::Exclusive(v) => Some(v),
376 IndexBound::Unbounded => None,
377 };
378 let high = match &range.end {
379 IndexBound::Inclusive(v) | IndexBound::Exclusive(v) => Some(v),
380 IndexBound::Unbounded => None,
381 };
382 self.catalog
383 .selectivity_range(table_name, col_index, low, high)
384 }
385
386 fn extract_table(&self, plan: &LogicalPlan) -> Option<Arc<Table>> {
388 match plan {
389 LogicalPlan::Scan { table, .. } => Some(table.clone()),
390 LogicalPlan::IndexScan { table, .. } => Some(table.clone()),
391 LogicalPlan::PkSeek { table, .. } => Some(table.clone()),
392 LogicalPlan::PkRangeScan { table, .. } => Some(table.clone()),
393 LogicalPlan::Filter { input, .. } => self.extract_table(input),
394 LogicalPlan::Project { input, .. } => self.extract_table(input),
395 _ => None,
396 }
397 }
398
399 pub fn format_plan_with_costs(&self, plan: &LogicalPlan, indent: usize) -> String {
401 let prefix = " ".repeat(indent);
402 let cost = self.estimate_cost(plan);
403 let rows = self.estimate_cardinality(plan);
404
405 match plan {
406 LogicalPlan::Scan {
407 table,
408 alias,
409 filter,
410 ..
411 } => {
412 let alias_str = alias
413 .as_ref()
414 .map(|a| format!(" AS {}", a))
415 .unwrap_or_default();
416 let filter_str = filter
417 .as_ref()
418 .map(|f| format!(" (filter: {:?})", f))
419 .unwrap_or_default();
420 format!(
421 "{}Scan: {}{}{} (cost={:.2}, rows={:.0})",
422 prefix, table.name, alias_str, filter_str, cost, rows
423 )
424 }
425
426 LogicalPlan::IndexScan {
427 table,
428 index,
429 range,
430 alias,
431 ..
432 } => {
433 let alias_str = alias
434 .as_ref()
435 .map(|a| format!(" AS {}", a))
436 .unwrap_or_default();
437 let range_str = if range.is_point_lookup() {
438 "point lookup"
439 } else {
440 "range scan"
441 };
442 format!(
443 "{}IndexScan: {}{} using {} ({}) (cost={:.2}, rows={:.0})",
444 prefix, table.name, alias_str, index.name, range_str, cost, rows
445 )
446 }
447
448 LogicalPlan::PkSeek {
449 table,
450 alias,
451 key_values,
452 residual_filter,
453 ..
454 } => {
455 let alias_str = alias
456 .as_ref()
457 .map(|a| format!(" AS {}", a))
458 .unwrap_or_default();
459 let keys_str = key_values
460 .iter()
461 .map(|e| format!("{:?}", e))
462 .collect::<Vec<_>>()
463 .join(", ");
464 let filter_str = residual_filter
465 .as_ref()
466 .map(|f| format!(" (residual: {:?})", f))
467 .unwrap_or_default();
468 format!(
469 "{}PkSeek: {}{} [{}]{} (cost={:.2}, rows={:.0})",
470 prefix, table.name, alias_str, keys_str, filter_str, cost, rows
471 )
472 }
473
474 LogicalPlan::PkRangeScan {
475 table,
476 alias,
477 range,
478 residual_filter,
479 ..
480 } => {
481 let alias_str = alias
482 .as_ref()
483 .map(|a| format!(" AS {}", a))
484 .unwrap_or_default();
485 let range_str = format!("{:?}..{:?}", range.start, range.end);
486 let filter_str = residual_filter
487 .as_ref()
488 .map(|f| format!(" (residual: {:?})", f))
489 .unwrap_or_default();
490 format!(
491 "{}PkRangeScan: {}{} [{}]{} (cost={:.2}, rows={:.0})",
492 prefix, table.name, alias_str, range_str, filter_str, cost, rows
493 )
494 }
495
496 LogicalPlan::Filter { input, predicate } => {
497 let child = self.format_plan_with_costs(input, indent + 1);
498 format!(
499 "{}Filter: {:?} (cost={:.2}, rows={:.0})\n{}",
500 prefix, predicate, cost, rows, child
501 )
502 }
503
504 LogicalPlan::Project { input, exprs } => {
505 let child = self.format_plan_with_costs(input, indent + 1);
506 let cols: Vec<_> = exprs.iter().map(|(_, name)| name.as_str()).collect();
507 format!(
508 "{}Project: [{}] (cost={:.2}, rows={:.0})\n{}",
509 prefix,
510 cols.join(", "),
511 cost,
512 rows,
513 child
514 )
515 }
516
517 LogicalPlan::Join {
518 left,
519 right,
520 join_type,
521 condition,
522 } => {
523 let left_child = self.format_plan_with_costs(left, indent + 1);
524 let right_child = self.format_plan_with_costs(right, indent + 1);
525 let join_str = match join_type {
526 JoinType::Inner => "Inner",
527 JoinType::Left => "Left",
528 JoinType::Right => "Right",
529 JoinType::Full => "Full",
530 };
531 let cond_str = condition
532 .as_ref()
533 .map(|c| format!(" ON {:?}", c))
534 .unwrap_or_default();
535 format!(
536 "{}{} Join{} (cost={:.2}, rows={:.0})\n{}\n{}",
537 prefix, join_str, cond_str, cost, rows, left_child, right_child
538 )
539 }
540
541 LogicalPlan::Aggregate {
542 input,
543 group_by,
544 aggregates,
545 } => {
546 let child = self.format_plan_with_costs(input, indent + 1);
547 let aggs: Vec<_> = aggregates.iter().map(|(_, name)| name.as_str()).collect();
548 format!(
549 "{}Aggregate: group_by={}, aggs=[{}] (cost={:.2}, rows={:.0})\n{}",
550 prefix,
551 group_by.len(),
552 aggs.join(", "),
553 cost,
554 rows,
555 child
556 )
557 }
558
559 LogicalPlan::Sort { input, order_by } => {
560 let child = self.format_plan_with_costs(input, indent + 1);
561 format!(
562 "{}Sort: {} columns (cost={:.2}, rows={:.0})\n{}",
563 prefix,
564 order_by.len(),
565 cost,
566 rows,
567 child
568 )
569 }
570
571 LogicalPlan::Limit {
572 input,
573 limit,
574 offset,
575 } => {
576 let child = self.format_plan_with_costs(input, indent + 1);
577 format!(
578 "{}Limit: {:?} offset {} (cost={:.2}, rows={:.0})\n{}",
579 prefix, limit, offset, cost, rows, child
580 )
581 }
582
583 LogicalPlan::Distinct { input } => {
584 let child = self.format_plan_with_costs(input, indent + 1);
585 format!(
586 "{}Distinct (cost={:.2}, rows={:.0})\n{}",
587 prefix, cost, rows, child
588 )
589 }
590
591 _ => format!("{}Plan node (cost={:.2}, rows={:.0})", prefix, cost, rows),
592 }
593 }
594}