Skip to main content

featherdb_query/optimizer/cost/
estimator.rs

1//! Cost estimator for query plans
2
3use super::constants as cost_constants;
4use crate::expr::{BinaryOp, Expr};
5use crate::planner::{IndexBound, IndexRange, JoinType, LogicalPlan};
6use featherdb_catalog::{Catalog, Table};
7use std::sync::Arc;
8
9/// Cost estimator for query plans
10///
11/// Provides cardinality and cost estimates for logical plan nodes.
12/// Used by the query optimizer to choose optimal join orders and
13/// access methods.
14pub struct CostEstimator<'a> {
15    catalog: &'a Catalog,
16}
17
18impl<'a> CostEstimator<'a> {
19    /// Create a new cost estimator with catalog access
20    pub fn new(catalog: &'a Catalog) -> Self {
21        CostEstimator { catalog }
22    }
23
24    /// Estimate the number of rows output by a plan node
25    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
26        match plan {
27            LogicalPlan::Scan { table, filter, .. } => {
28                let base_rows = self.catalog.estimated_row_count(&table.name) as f64;
29                match filter {
30                    Some(pred) => base_rows * self.estimate_selectivity(pred, table),
31                    None => base_rows,
32                }
33            }
34
35            LogicalPlan::IndexScan {
36                table,
37                range,
38                residual_filter,
39                index_column,
40                ..
41            } => {
42                let base_rows = self.catalog.estimated_row_count(&table.name) as f64;
43
44                // Estimate selectivity based on range
45                let range_selectivity = if range.is_point_lookup() {
46                    self.catalog.selectivity_eq(&table.name, *index_column)
47                } else {
48                    self.estimate_range_selectivity(range, &table.name, *index_column)
49                };
50
51                let rows_after_index = base_rows * range_selectivity;
52
53                // Apply residual filter selectivity
54                match residual_filter {
55                    Some(pred) => rows_after_index * self.estimate_selectivity(pred, table),
56                    None => rows_after_index,
57                }
58            }
59
60            LogicalPlan::PkSeek {
61                table,
62                residual_filter,
63                ..
64            } => {
65                // PK seek returns at most 1 row (point lookup)
66                let base_cardinality = 1.0;
67
68                // Apply residual filter selectivity if present
69                match residual_filter {
70                    Some(pred) => base_cardinality * self.estimate_selectivity(pred, table),
71                    None => base_cardinality,
72                }
73            }
74
75            LogicalPlan::PkRangeScan {
76                table,
77                range,
78                residual_filter,
79                ..
80            } => {
81                let base_rows = self.catalog.estimated_row_count(&table.name) as f64;
82                // Use range selectivity estimate - for PK ranges, treat like the first PK column
83                let pk_col = if !table.primary_key.is_empty() {
84                    table.primary_key[0]
85                } else {
86                    0
87                };
88                let range_selectivity = self.estimate_range_selectivity(range, &table.name, pk_col);
89                let rows_after_range = base_rows * range_selectivity;
90                match residual_filter {
91                    Some(pred) => rows_after_range * self.estimate_selectivity(pred, table),
92                    None => rows_after_range,
93                }
94            }
95
96            LogicalPlan::Filter { input, predicate } => {
97                let input_rows = self.estimate_cardinality(input);
98                let table = self.extract_table(input);
99                let selectivity = table
100                    .map(|t| self.estimate_selectivity(predicate, &t))
101                    .unwrap_or(cost_constants::DEFAULT_SELECTIVITY);
102                input_rows * selectivity
103            }
104
105            LogicalPlan::Project { input, .. } => self.estimate_cardinality(input),
106
107            LogicalPlan::Join {
108                left,
109                right,
110                condition,
111                join_type,
112            } => {
113                let left_rows = self.estimate_cardinality(left);
114                let right_rows = self.estimate_cardinality(right);
115
116                let join_selectivity = condition
117                    .as_ref()
118                    .map(|c| self.estimate_join_selectivity(c, left, right))
119                    .unwrap_or(1.0); // Cross join
120
121                let base_result = left_rows * right_rows * join_selectivity;
122
123                // Adjust for join type
124                match join_type {
125                    JoinType::Inner => base_result,
126                    JoinType::Left => base_result.max(left_rows), // At least left rows
127                    JoinType::Right => base_result.max(right_rows), // At least right rows
128                    JoinType::Full => base_result.max(left_rows).max(right_rows),
129                }
130            }
131
132            LogicalPlan::Aggregate {
133                input, group_by, ..
134            } => {
135                if group_by.is_empty() {
136                    1.0 // Single aggregate result
137                } else {
138                    // Estimate distinct groups
139                    let input_rows = self.estimate_cardinality(input);
140                    // Assume group by reduces rows significantly
141                    (input_rows * 0.1).max(1.0)
142                }
143            }
144
145            LogicalPlan::Sort { input, .. } => self.estimate_cardinality(input),
146
147            LogicalPlan::Limit {
148                input,
149                limit,
150                offset,
151            } => {
152                let input_rows = self.estimate_cardinality(input);
153                let remaining = input_rows - *offset as f64;
154                match limit {
155                    Some(l) => remaining.min(*l as f64).max(0.0),
156                    None => remaining.max(0.0),
157                }
158            }
159
160            LogicalPlan::Distinct { input } => {
161                let input_rows = self.estimate_cardinality(input);
162                // Assume distinct reduces by some factor
163                (input_rows * 0.5).max(1.0)
164            }
165
166            LogicalPlan::EmptyRelation => 1.0,
167
168            _ => 1000.0, // Default for other plan types
169        }
170    }
171
172    /// Estimate the total cost of executing a plan
173    pub fn estimate_cost(&self, plan: &LogicalPlan) -> f64 {
174        match plan {
175            LogicalPlan::Scan { table, .. } => {
176                let rows = self.catalog.estimated_row_count(&table.name) as f64;
177                rows * cost_constants::SEQ_SCAN_COST_PER_ROW
178            }
179
180            LogicalPlan::IndexScan {
181                table,
182                range,
183                index_column,
184                ..
185            } => {
186                let base_rows = self.catalog.estimated_row_count(&table.name) as f64;
187                let selectivity = if range.is_point_lookup() {
188                    self.catalog.selectivity_eq(&table.name, *index_column)
189                } else {
190                    self.estimate_range_selectivity(range, &table.name, *index_column)
191                };
192                let rows_scanned = base_rows * selectivity;
193                rows_scanned * cost_constants::INDEX_SCAN_COST_PER_ROW
194            }
195
196            LogicalPlan::PkSeek { .. } => {
197                // PK seek: single B-tree lookup (O(log N))
198                // Very low cost - similar to index point lookup
199                cost_constants::INDEX_SCAN_COST_PER_ROW
200            }
201
202            LogicalPlan::PkRangeScan { table, range, .. } => {
203                let base_rows = self.catalog.estimated_row_count(&table.name) as f64;
204                let pk_col = if !table.primary_key.is_empty() {
205                    table.primary_key[0]
206                } else {
207                    0
208                };
209                let selectivity = self.estimate_range_selectivity(range, &table.name, pk_col);
210                let rows_scanned = base_rows * selectivity;
211                // Similar cost to index scan - direct B-tree range access
212                rows_scanned * cost_constants::INDEX_SCAN_COST_PER_ROW
213            }
214
215            LogicalPlan::Filter { input, .. } => {
216                let input_cost = self.estimate_cost(input);
217                let rows = self.estimate_cardinality(input);
218                input_cost + rows * cost_constants::CPU_COST_MULTIPLIER
219            }
220
221            LogicalPlan::Project { input, .. } => {
222                let input_cost = self.estimate_cost(input);
223                let rows = self.estimate_cardinality(input);
224                input_cost + rows * cost_constants::CPU_COST_MULTIPLIER
225            }
226
227            LogicalPlan::Join { left, right, .. } => {
228                let left_cost = self.estimate_cost(left);
229                let right_cost = self.estimate_cost(right);
230                let left_rows = self.estimate_cardinality(left);
231                let right_rows = self.estimate_cardinality(right);
232
233                // Hash join cost: build hash table + probe
234                // Smaller table goes on build side
235                let (build_rows, probe_rows) = if left_rows <= right_rows {
236                    (left_rows, right_rows)
237                } else {
238                    (right_rows, left_rows)
239                };
240
241                let build_cost = build_rows * cost_constants::HASH_BUILD_COST_PER_ROW;
242                let probe_cost = probe_rows * cost_constants::HASH_JOIN_COST_PER_ROW;
243
244                left_cost + right_cost + build_cost + probe_cost
245            }
246
247            LogicalPlan::Aggregate { input, .. } => {
248                let input_cost = self.estimate_cost(input);
249                let rows = self.estimate_cardinality(input);
250                input_cost + rows * cost_constants::CPU_COST_MULTIPLIER
251            }
252
253            LogicalPlan::Sort { input, .. } => {
254                let input_cost = self.estimate_cost(input);
255                let rows = self.estimate_cardinality(input);
256                // n log n cost
257                let sort_cost = if rows > 1.0 {
258                    rows * rows.log2() * cost_constants::SORT_COST_PER_ROW
259                } else {
260                    0.0
261                };
262                input_cost + sort_cost
263            }
264
265            LogicalPlan::Limit { input, .. } => self.estimate_cost(input),
266
267            LogicalPlan::Distinct { input } => {
268                let input_cost = self.estimate_cost(input);
269                let rows = self.estimate_cardinality(input);
270                input_cost + rows * cost_constants::CPU_COST_MULTIPLIER
271            }
272
273            LogicalPlan::EmptyRelation => 0.0,
274
275            _ => 1000.0, // Default cost for other plan types
276        }
277    }
278
279    /// Estimate selectivity of a predicate
280    pub fn estimate_selectivity(&self, predicate: &Expr, table: &Table) -> f64 {
281        match predicate {
282            Expr::BinaryOp { left, op, right } => {
283                match op {
284                    BinaryOp::And => {
285                        // AND: multiply selectivities (independence assumption)
286                        let left_sel = self.estimate_selectivity(left, table);
287                        let right_sel = self.estimate_selectivity(right, table);
288                        left_sel * right_sel
289                    }
290                    BinaryOp::Or => {
291                        // OR: p(A) + p(B) - p(A)*p(B)
292                        let left_sel = self.estimate_selectivity(left, table);
293                        let right_sel = self.estimate_selectivity(right, table);
294                        left_sel + right_sel - (left_sel * right_sel)
295                    }
296                    BinaryOp::Eq => self.estimate_equality_selectivity(left, right, table),
297                    BinaryOp::Ne => 1.0 - self.estimate_equality_selectivity(left, right, table),
298                    BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
299                        cost_constants::DEFAULT_RANGE_SELECTIVITY
300                    }
301                    _ => cost_constants::DEFAULT_SELECTIVITY,
302                }
303            }
304            Expr::UnaryOp { op, .. } => {
305                match op {
306                    crate::expr::UnaryOp::Not => 1.0 - cost_constants::DEFAULT_SELECTIVITY,
307                    crate::expr::UnaryOp::IsNull => 0.01, // Assume 1% nulls
308                    crate::expr::UnaryOp::IsNotNull => 0.99,
309                    _ => cost_constants::DEFAULT_SELECTIVITY,
310                }
311            }
312            Expr::Between { .. } => cost_constants::DEFAULT_RANGE_SELECTIVITY,
313            Expr::InList { list, .. } => {
314                // IN list: each value has equality selectivity
315                let eq_sel = cost_constants::DEFAULT_SELECTIVITY;
316                (eq_sel * list.len() as f64).min(1.0)
317            }
318            Expr::Like { .. } => 0.2, // LIKE usually quite selective
319            _ => cost_constants::DEFAULT_SELECTIVITY,
320        }
321    }
322
323    /// Estimate selectivity for equality predicate
324    fn estimate_equality_selectivity(&self, left: &Expr, right: &Expr, table: &Table) -> f64 {
325        // Check if this is a column = literal pattern
326        if let Expr::Column { name, .. } = left {
327            if matches!(right, Expr::Literal(_)) {
328                if let Some(col_idx) = table.get_column_index(name) {
329                    return self.catalog.selectivity_eq(&table.name, col_idx);
330                }
331            }
332        }
333        if let Expr::Column { name, .. } = right {
334            if matches!(left, Expr::Literal(_)) {
335                if let Some(col_idx) = table.get_column_index(name) {
336                    return self.catalog.selectivity_eq(&table.name, col_idx);
337                }
338            }
339        }
340        cost_constants::DEFAULT_SELECTIVITY
341    }
342
343    /// Estimate selectivity for join conditions
344    fn estimate_join_selectivity(
345        &self,
346        condition: &Expr,
347        _left: &LogicalPlan,
348        _right: &LogicalPlan,
349    ) -> f64 {
350        match condition {
351            Expr::BinaryOp {
352                op: BinaryOp::Eq, ..
353            } => cost_constants::DEFAULT_JOIN_SELECTIVITY,
354            Expr::BinaryOp {
355                op: BinaryOp::And,
356                left,
357                right,
358            } => {
359                let left_sel = self.estimate_join_selectivity(left, _left, _right);
360                let right_sel = self.estimate_join_selectivity(right, _left, _right);
361                left_sel * right_sel
362            }
363            _ => cost_constants::DEFAULT_JOIN_SELECTIVITY,
364        }
365    }
366
367    /// Estimate selectivity for index range
368    fn estimate_range_selectivity(
369        &self,
370        range: &IndexRange,
371        table_name: &str,
372        col_index: usize,
373    ) -> f64 {
374        let low = match &range.start {
375            IndexBound::Inclusive(v) | IndexBound::Exclusive(v) => Some(v),
376            IndexBound::Unbounded => None,
377        };
378        let high = match &range.end {
379            IndexBound::Inclusive(v) | IndexBound::Exclusive(v) => Some(v),
380            IndexBound::Unbounded => None,
381        };
382        self.catalog
383            .selectivity_range(table_name, col_index, low, high)
384    }
385
386    /// Extract the primary table from a plan (for selectivity estimation)
387    fn extract_table(&self, plan: &LogicalPlan) -> Option<Arc<Table>> {
388        match plan {
389            LogicalPlan::Scan { table, .. } => Some(table.clone()),
390            LogicalPlan::IndexScan { table, .. } => Some(table.clone()),
391            LogicalPlan::PkSeek { table, .. } => Some(table.clone()),
392            LogicalPlan::PkRangeScan { table, .. } => Some(table.clone()),
393            LogicalPlan::Filter { input, .. } => self.extract_table(input),
394            LogicalPlan::Project { input, .. } => self.extract_table(input),
395            _ => None,
396        }
397    }
398
399    /// Format a plan with cost annotations for EXPLAIN output
400    pub fn format_plan_with_costs(&self, plan: &LogicalPlan, indent: usize) -> String {
401        let prefix = "  ".repeat(indent);
402        let cost = self.estimate_cost(plan);
403        let rows = self.estimate_cardinality(plan);
404
405        match plan {
406            LogicalPlan::Scan {
407                table,
408                alias,
409                filter,
410                ..
411            } => {
412                let alias_str = alias
413                    .as_ref()
414                    .map(|a| format!(" AS {}", a))
415                    .unwrap_or_default();
416                let filter_str = filter
417                    .as_ref()
418                    .map(|f| format!(" (filter: {:?})", f))
419                    .unwrap_or_default();
420                format!(
421                    "{}Scan: {}{}{} (cost={:.2}, rows={:.0})",
422                    prefix, table.name, alias_str, filter_str, cost, rows
423                )
424            }
425
426            LogicalPlan::IndexScan {
427                table,
428                index,
429                range,
430                alias,
431                ..
432            } => {
433                let alias_str = alias
434                    .as_ref()
435                    .map(|a| format!(" AS {}", a))
436                    .unwrap_or_default();
437                let range_str = if range.is_point_lookup() {
438                    "point lookup"
439                } else {
440                    "range scan"
441                };
442                format!(
443                    "{}IndexScan: {}{} using {} ({}) (cost={:.2}, rows={:.0})",
444                    prefix, table.name, alias_str, index.name, range_str, cost, rows
445                )
446            }
447
448            LogicalPlan::PkSeek {
449                table,
450                alias,
451                key_values,
452                residual_filter,
453                ..
454            } => {
455                let alias_str = alias
456                    .as_ref()
457                    .map(|a| format!(" AS {}", a))
458                    .unwrap_or_default();
459                let keys_str = key_values
460                    .iter()
461                    .map(|e| format!("{:?}", e))
462                    .collect::<Vec<_>>()
463                    .join(", ");
464                let filter_str = residual_filter
465                    .as_ref()
466                    .map(|f| format!(" (residual: {:?})", f))
467                    .unwrap_or_default();
468                format!(
469                    "{}PkSeek: {}{} [{}]{} (cost={:.2}, rows={:.0})",
470                    prefix, table.name, alias_str, keys_str, filter_str, cost, rows
471                )
472            }
473
474            LogicalPlan::PkRangeScan {
475                table,
476                alias,
477                range,
478                residual_filter,
479                ..
480            } => {
481                let alias_str = alias
482                    .as_ref()
483                    .map(|a| format!(" AS {}", a))
484                    .unwrap_or_default();
485                let range_str = format!("{:?}..{:?}", range.start, range.end);
486                let filter_str = residual_filter
487                    .as_ref()
488                    .map(|f| format!(" (residual: {:?})", f))
489                    .unwrap_or_default();
490                format!(
491                    "{}PkRangeScan: {}{} [{}]{} (cost={:.2}, rows={:.0})",
492                    prefix, table.name, alias_str, range_str, filter_str, cost, rows
493                )
494            }
495
496            LogicalPlan::Filter { input, predicate } => {
497                let child = self.format_plan_with_costs(input, indent + 1);
498                format!(
499                    "{}Filter: {:?} (cost={:.2}, rows={:.0})\n{}",
500                    prefix, predicate, cost, rows, child
501                )
502            }
503
504            LogicalPlan::Project { input, exprs } => {
505                let child = self.format_plan_with_costs(input, indent + 1);
506                let cols: Vec<_> = exprs.iter().map(|(_, name)| name.as_str()).collect();
507                format!(
508                    "{}Project: [{}] (cost={:.2}, rows={:.0})\n{}",
509                    prefix,
510                    cols.join(", "),
511                    cost,
512                    rows,
513                    child
514                )
515            }
516
517            LogicalPlan::Join {
518                left,
519                right,
520                join_type,
521                condition,
522            } => {
523                let left_child = self.format_plan_with_costs(left, indent + 1);
524                let right_child = self.format_plan_with_costs(right, indent + 1);
525                let join_str = match join_type {
526                    JoinType::Inner => "Inner",
527                    JoinType::Left => "Left",
528                    JoinType::Right => "Right",
529                    JoinType::Full => "Full",
530                };
531                let cond_str = condition
532                    .as_ref()
533                    .map(|c| format!(" ON {:?}", c))
534                    .unwrap_or_default();
535                format!(
536                    "{}{} Join{} (cost={:.2}, rows={:.0})\n{}\n{}",
537                    prefix, join_str, cond_str, cost, rows, left_child, right_child
538                )
539            }
540
541            LogicalPlan::Aggregate {
542                input,
543                group_by,
544                aggregates,
545            } => {
546                let child = self.format_plan_with_costs(input, indent + 1);
547                let aggs: Vec<_> = aggregates.iter().map(|(_, name)| name.as_str()).collect();
548                format!(
549                    "{}Aggregate: group_by={}, aggs=[{}] (cost={:.2}, rows={:.0})\n{}",
550                    prefix,
551                    group_by.len(),
552                    aggs.join(", "),
553                    cost,
554                    rows,
555                    child
556                )
557            }
558
559            LogicalPlan::Sort { input, order_by } => {
560                let child = self.format_plan_with_costs(input, indent + 1);
561                format!(
562                    "{}Sort: {} columns (cost={:.2}, rows={:.0})\n{}",
563                    prefix,
564                    order_by.len(),
565                    cost,
566                    rows,
567                    child
568                )
569            }
570
571            LogicalPlan::Limit {
572                input,
573                limit,
574                offset,
575            } => {
576                let child = self.format_plan_with_costs(input, indent + 1);
577                format!(
578                    "{}Limit: {:?} offset {} (cost={:.2}, rows={:.0})\n{}",
579                    prefix, limit, offset, cost, rows, child
580                )
581            }
582
583            LogicalPlan::Distinct { input } => {
584                let child = self.format_plan_with_costs(input, indent + 1);
585                format!(
586                    "{}Distinct (cost={:.2}, rows={:.0})\n{}",
587                    prefix, cost, rows, child
588                )
589            }
590
591            _ => format!("{}Plan node (cost={:.2}, rows={:.0})", prefix, cost, rows),
592        }
593    }
594}