Skip to main content

hematite/query/
planner.rs

1//! Query planning and optimization
2
3use crate::catalog::{Schema, Table, Value};
4use crate::error::Result;
5use crate::parser::ast::*;
6use crate::query::optimizer::QueryOptimizer;
7pub use crate::query::plan::*;
8use crate::query::predicate::extract_literal_equalities;
9use crate::query::validation::validate_statement;
10use crate::HematiteError;
11use std::collections::HashMap;
12
13#[derive(Debug, Clone)]
14pub struct QueryPlanner {
15    catalog: Schema,
16    table_row_counts: HashMap<String, usize>,
17}
18
19impl QueryPlanner {
20    pub fn new(catalog: Schema) -> Self {
21        Self {
22            catalog,
23            table_row_counts: HashMap::new(),
24        }
25    }
26
27    pub fn with_table_row_counts(mut self, table_row_counts: HashMap<String, usize>) -> Self {
28        self.table_row_counts = table_row_counts;
29        self
30    }
31
32    pub fn plan(&self, statement: Statement) -> Result<QueryPlan> {
33        // Validate statement against catalog
34        validate_statement(&statement, &self.catalog)?;
35
36        let plan = match statement {
37            Statement::Begin
38            | Statement::Commit
39            | Statement::Rollback
40            | Statement::Savepoint(_)
41            | Statement::RollbackToSavepoint(_)
42            | Statement::ReleaseSavepoint(_) => {
43                return Err(HematiteError::ParseError(
44                    "Transaction control statements are handled at the SQL connection boundary"
45                        .to_string(),
46                ))
47            }
48            Statement::Explain(_)
49            | Statement::Describe(_)
50            | Statement::ShowTables
51            | Statement::ShowViews
52            | Statement::ShowIndexes(_)
53            | Statement::ShowTriggers(_)
54            | Statement::ShowCreateTable(_)
55            | Statement::ShowCreateView(_) => {
56                return Err(HematiteError::ParseError(
57                    "Introspection statements are handled at the SQL connection boundary"
58                        .to_string(),
59                ))
60            }
61            Statement::Select(select) => self.plan_select(select),
62            Statement::SelectInto(_) => {
63                return Err(HematiteError::ParseError(
64                    "SELECT INTO is handled at the SQL connection boundary".to_string(),
65                ))
66            }
67            Statement::Update(update) => self.plan_update(update),
68            Statement::Insert(insert) => self.plan_insert(insert),
69            Statement::Delete(delete) => self.plan_delete(delete),
70            Statement::Create(create) => self.plan_create(create),
71            Statement::CreateView(_) | Statement::CreateTrigger(_) => {
72                return Err(HematiteError::ParseError(
73                    "View and trigger statements are not planned yet".to_string(),
74                ))
75            }
76            Statement::CreateIndex(create_index) => self.plan_create_index(create_index),
77            Statement::Alter(alter) => self.plan_alter(alter),
78            Statement::Drop(drop) => self.plan_drop(drop),
79            Statement::DropView(_) | Statement::DropTrigger(_) => {
80                return Err(HematiteError::ParseError(
81                    "View and trigger statements are not planned yet".to_string(),
82                ))
83            }
84            Statement::DropIndex(drop_index) => self.plan_drop_index(drop_index),
85        }?;
86
87        let optimizer = QueryOptimizer::new(self.catalog.clone());
88        optimizer.optimize(plan)
89    }
90
91    fn plan_select(&self, statement: SelectStatement) -> Result<QueryPlan> {
92        // Analyze the query to determine optimal execution strategy
93        let analysis = self.analyze_select(&statement)?;
94        let node = self.build_select_plan_node(&statement, &analysis);
95        let access_path = node.access_path.clone();
96
97        // Estimate cost (simplified cost model)
98        let estimated_cost = self.estimate_select_cost(&analysis);
99
100        Ok(QueryPlan {
101            node: PlanNode::Select(node),
102            program: ExecutionProgram::Select {
103                statement,
104                access_path,
105            },
106            estimated_cost,
107            select_analysis: Some(analysis),
108            optimizations: None,
109        })
110    }
111
112    fn plan_insert(&self, statement: InsertStatement) -> Result<QueryPlan> {
113        // For INSERT, the planning is straightforward
114        let row_count = match &statement.source {
115            InsertSource::Values(rows) => rows.len(),
116            InsertSource::Select(_) => 1,
117        };
118        let estimated_cost = row_count as f64;
119        let node = PlanNode::Insert(InsertPlanNode {
120            table_name: statement.table.clone(),
121            row_count,
122        });
123
124        Ok(QueryPlan {
125            node,
126            program: ExecutionProgram::Insert { statement },
127            estimated_cost,
128            select_analysis: None,
129            optimizations: None,
130        })
131    }
132
133    fn plan_create(&self, statement: CreateStatement) -> Result<QueryPlan> {
134        let node = PlanNode::Create(CreatePlanNode {
135            table_name: statement.table.clone(),
136            column_count: statement.columns.len(),
137        });
138        Ok(self.simple_plan(node, ExecutionProgram::Create { statement }))
139    }
140
141    fn plan_update(&self, statement: UpdateStatement) -> Result<QueryPlan> {
142        let access_path = if matches!(statement.source.as_ref(), Some(source) if !matches!(source, TableReference::Table(_, _)))
143        {
144            SelectAccessPath::JoinScan
145        } else {
146            let analysis = self.analyze_table_access(&statement.table, &statement.where_clause)?;
147            self.choose_access_path(&analysis)
148        };
149        let assignment_count = statement.assignments.len();
150        let node = PlanNode::Update(UpdatePlanNode {
151            table_name: statement.table.clone(),
152            assignment_count,
153            has_filter: statement.where_clause.is_some(),
154            access_path: access_path.clone(),
155        });
156        let estimated_cost = if matches!(access_path, SelectAccessPath::JoinScan) {
157            self.catalog
158                .get_table_by_name(&statement.table)
159                .map(|table| self.estimate_table_rows(table) as f64 + assignment_count as f64)
160                .unwrap_or(assignment_count as f64)
161        } else {
162            let analysis = self.analyze_table_access(&statement.table, &statement.where_clause)?;
163            self.estimate_update_cost(&analysis, &access_path, assignment_count)
164        };
165
166        Ok(QueryPlan {
167            node,
168            program: ExecutionProgram::Update {
169                statement,
170                access_path,
171            },
172            estimated_cost,
173            select_analysis: None,
174            optimizations: None,
175        })
176    }
177
178    fn plan_delete(&self, statement: DeleteStatement) -> Result<QueryPlan> {
179        let access_path = if matches!(statement.source.as_ref(), Some(source) if !matches!(source, TableReference::Table(_, _)))
180        {
181            SelectAccessPath::JoinScan
182        } else {
183            let analysis = self.analyze_table_access(&statement.table, &statement.where_clause)?;
184            self.choose_access_path(&analysis)
185        };
186        let node = PlanNode::Delete(DeletePlanNode {
187            table_name: statement.table.clone(),
188            has_filter: statement.where_clause.is_some(),
189            access_path: access_path.clone(),
190        });
191        let estimated_cost = if matches!(access_path, SelectAccessPath::JoinScan) {
192            self.catalog
193                .get_table_by_name(&statement.table)
194                .map(|table| self.estimate_table_rows(table) as f64)
195                .unwrap_or(1.0)
196        } else {
197            let analysis = self.analyze_table_access(&statement.table, &statement.where_clause)?;
198            self.estimate_delete_cost(&analysis, &access_path)
199        };
200
201        Ok(QueryPlan {
202            node,
203            program: ExecutionProgram::Delete {
204                statement,
205                access_path,
206            },
207            estimated_cost,
208            select_analysis: None,
209            optimizations: None,
210        })
211    }
212
213    fn plan_drop(&self, statement: DropStatement) -> Result<QueryPlan> {
214        let node = PlanNode::Drop(DropPlanNode {
215            table_name: statement.table.clone(),
216        });
217        Ok(self.simple_plan(node, ExecutionProgram::Drop { statement }))
218    }
219
220    fn plan_alter(&self, statement: AlterStatement) -> Result<QueryPlan> {
221        let node = PlanNode::Alter(AlterPlanNode {
222            table_name: statement.table.clone(),
223        });
224        Ok(self.simple_plan(node, ExecutionProgram::Alter { statement }))
225    }
226
227    fn plan_create_index(&self, statement: CreateIndexStatement) -> Result<QueryPlan> {
228        let node = PlanNode::CreateIndex(CreateIndexPlanNode {
229            table_name: statement.table.clone(),
230            index_name: statement.index_name.clone(),
231            column_count: statement.columns.len(),
232        });
233        Ok(self.simple_plan(node, ExecutionProgram::CreateIndex { statement }))
234    }
235
236    fn plan_drop_index(&self, statement: DropIndexStatement) -> Result<QueryPlan> {
237        let node = PlanNode::DropIndex(DropIndexPlanNode {
238            table_name: statement.table.clone(),
239            index_name: statement.index_name.clone(),
240        });
241        Ok(self.simple_plan(node, ExecutionProgram::DropIndex { statement }))
242    }
243
244    fn simple_plan(&self, node: PlanNode, program: ExecutionProgram) -> QueryPlan {
245        QueryPlan {
246            node,
247            program,
248            estimated_cost: 1.0,
249            select_analysis: None,
250            optimizations: None,
251        }
252    }
253
254    fn build_select_plan_node(
255        &self,
256        statement: &SelectStatement,
257        analysis: &SelectAnalysis,
258    ) -> SelectPlanNode {
259        let access_path = self.choose_access_path(analysis);
260
261        let projection = if statement
262            .columns
263            .iter()
264            .any(|item| matches!(item, SelectItem::Wildcard))
265        {
266            SelectProjection::Wildcard
267        } else if let Some(item) = statement.columns.first() {
268            match item {
269                SelectItem::CountAll => SelectProjection::CountAll,
270                SelectItem::Aggregate { function, column } => SelectProjection::Aggregate {
271                    function: *function,
272                    column: column.clone(),
273                },
274                SelectItem::Expression(_) => SelectProjection::Expressions(statement.columns.len()),
275                _ => SelectProjection::Columns(
276                    statement
277                        .columns
278                        .iter()
279                        .filter_map(|item| match item {
280                            SelectItem::Column(name) => {
281                                Some(SelectStatement::column_reference_name(name).to_string())
282                            }
283                            _ => None,
284                        })
285                        .collect(),
286                ),
287            }
288        } else {
289            SelectProjection::Columns(Vec::new())
290        };
291
292        SelectPlanNode {
293            table_name: analysis.table_name.clone(),
294            source_count: analysis.source_count,
295            access_path,
296            projection,
297            distinct: statement.distinct,
298            has_filter: statement.where_clause.is_some(),
299            order_by_columns: statement
300                .order_by
301                .iter()
302                .map(|item| item.column.clone())
303                .collect(),
304            limit: statement.limit,
305            offset: statement.offset,
306        }
307    }
308
309    fn extract_rowid_lookup(&self, statement: &SelectStatement) -> Option<u64> {
310        let equalities = extract_literal_equalities(statement.where_clause.as_ref()?)?;
311        match equalities.get("rowid") {
312            Some(Value::Integer(v)) if *v >= 0 => Some(*v as u64),
313            _ => None,
314        }
315    }
316
317    fn analyze_select(&self, statement: &SelectStatement) -> Result<SelectAnalysis> {
318        let bindings = SelectStatement::collect_table_bindings(&statement.from);
319        let primary = bindings.first().ok_or_else(|| {
320            HematiteError::ParseError("SELECT requires at least one table source".to_string())
321        })?;
322
323        if bindings.len() == 1 && !statement.has_non_table_source() {
324            return self.analyze_table_access(&primary.table_name, &statement.where_clause);
325        }
326
327        let estimated_rows = if bindings.len() > 1 || statement.has_non_table_source() {
328            self.estimate_complex_source_rows(statement, &statement.from)
329        } else {
330            bindings
331                .iter()
332                .try_fold(1usize, |product, binding| -> Result<usize> {
333                    let table = self
334                        .catalog
335                        .get_table_by_name(&binding.table_name)
336                        .ok_or_else(|| {
337                            HematiteError::ParseError(format!(
338                                "Table '{}' not found",
339                                binding.table_name
340                            ))
341                        })?;
342                    Ok(product.saturating_mul(self.estimate_table_rows(table).max(1)))
343                })?
344        };
345
346        Ok(SelectAnalysis {
347            table_name: primary.table_name.clone(),
348            source_count: bindings.len(),
349            has_complex_source: statement.has_non_table_source(),
350            table_id: self
351                .catalog
352                .get_table_by_name(&primary.table_name)
353                .map(|table| table.id)
354                .unwrap_or_else(|| crate::catalog::TableId::new(0)),
355            rowid_lookup: None,
356            estimated_rows,
357            usable_indexes: Vec::new(),
358            accessed_columns: Vec::new(),
359        })
360    }
361
362    fn analyze_table_access(
363        &self,
364        table_name: &str,
365        where_clause: &Option<WhereClause>,
366    ) -> Result<SelectAnalysis> {
367        let table_name = table_name.to_string();
368        let table = self.catalog.get_table_by_name(&table_name).ok_or_else(|| {
369            HematiteError::ParseError(format!("Table '{}' not found", table_name))
370        })?;
371
372        let synthetic_select = synthetic_table_select(&table_name, where_clause.clone());
373        let rowid_lookup = self.extract_rowid_lookup(&synthetic_select);
374
375        // Analyze WHERE clause for index usage opportunities
376        let usable_indexes = self.analyze_where_clause(where_clause, table)?;
377
378        // Analyze column access patterns
379        let accessed_columns = self.analyze_column_access(&synthetic_select.columns, table)?;
380
381        Ok(SelectAnalysis {
382            table_name,
383            source_count: 1,
384            has_complex_source: false,
385            table_id: table.id,
386            rowid_lookup,
387            estimated_rows: self.estimate_table_rows(table),
388            usable_indexes,
389            accessed_columns,
390        })
391    }
392
393    fn choose_access_path(&self, analysis: &SelectAnalysis) -> SelectAccessPath {
394        if analysis.has_complex_source || analysis.source_count > 1 {
395            return SelectAccessPath::JoinScan;
396        }
397
398        self.access_path_candidates(analysis)
399            .into_iter()
400            .min_by(|left, right| {
401                self.estimate_total_access_cost(analysis, left)
402                    .partial_cmp(&self.estimate_total_access_cost(analysis, right))
403                    .unwrap_or(std::cmp::Ordering::Equal)
404            })
405            .unwrap_or(SelectAccessPath::FullTableScan)
406    }
407
408    fn access_path_candidates(&self, analysis: &SelectAnalysis) -> Vec<SelectAccessPath> {
409        let mut candidates = vec![SelectAccessPath::FullTableScan];
410
411        if analysis.rowid_lookup.is_some() {
412            candidates.push(SelectAccessPath::RowIdLookup);
413        }
414
415        if analysis
416            .usable_indexes
417            .iter()
418            .any(|usage| matches!(usage.index_type, IndexType::PrimaryKey))
419        {
420            candidates.push(SelectAccessPath::PrimaryKeyLookup);
421        }
422
423        candidates.extend(
424            analysis
425                .usable_indexes
426                .iter()
427                .filter(|usage| matches!(usage.index_type, IndexType::Secondary))
428                .map(|usage| {
429                    SelectAccessPath::SecondaryIndexLookup(
430                        usage
431                            .index_name
432                            .clone()
433                            .unwrap_or_else(|| "unnamed_secondary_index".to_string()),
434                    )
435                }),
436        );
437
438        candidates
439    }
440
441    fn analyze_where_clause(
442        &self,
443        where_clause: &Option<WhereClause>,
444        table: &Table,
445    ) -> Result<Vec<IndexUsage>> {
446        let mut usable_indexes = Vec::new();
447        let Some(where_clause) = where_clause.as_ref() else {
448            return Ok(usable_indexes);
449        };
450        let Some(equalities) = extract_literal_equalities(where_clause) else {
451            return Ok(usable_indexes);
452        };
453
454        if table
455            .primary_key_columns
456            .iter()
457            .all(|&index| equalities.contains_key(table.columns[index].name.as_str()))
458        {
459            let first_pk = table
460                .primary_key_columns
461                .first()
462                .and_then(|&index| table.columns.get(index))
463                .ok_or_else(|| {
464                    HematiteError::InternalError(format!(
465                        "Table '{}' lost its primary key metadata during planning",
466                        table.name
467                    ))
468                })?;
469            usable_indexes.push(IndexUsage {
470                column_id: first_pk.id,
471                index_type: IndexType::PrimaryKey,
472                index_name: None,
473                selectivity: (1.0 / self.estimate_table_rows(table).max(1) as f64).max(0.0001),
474            });
475        }
476
477        for index in &table.secondary_indexes {
478            if index.column_indices.iter().all(|&column_index| {
479                equalities.contains_key(table.columns[column_index].name.as_str())
480            }) {
481                let column = table.columns.get(index.column_indices[0]).ok_or_else(|| {
482                    HematiteError::InternalError(format!(
483                        "Index '{}' references an invalid column on table '{}'",
484                        index.name, table.name
485                    ))
486                })?;
487                usable_indexes.push(IndexUsage {
488                    column_id: column.id,
489                    index_type: IndexType::Secondary,
490                    index_name: Some(index.name.clone()),
491                    selectivity: if index.unique {
492                        (1.0 / self.estimate_table_rows(table).max(1) as f64).max(0.0001)
493                    } else if index.column_indices.len() > 1 {
494                        0.02
495                    } else {
496                        0.1
497                    },
498                });
499            }
500        }
501
502        Ok(usable_indexes)
503    }
504
505    fn analyze_column_access(
506        &self,
507        select_items: &[SelectItem],
508        table: &Table,
509    ) -> Result<Vec<ColumnAccess>> {
510        let mut accessed_columns = Vec::new();
511
512        for item in select_items {
513            match item {
514                SelectItem::Wildcard => {
515                    // All columns are accessed
516                    for column in &table.columns {
517                        accessed_columns.push(ColumnAccess {
518                            column_id: column.id,
519                            access_type: ColumnAccessType::Read,
520                        });
521                    }
522                }
523                SelectItem::Column(name) => {
524                    if let Some(column) =
525                        table.get_column_by_name(SelectStatement::column_reference_name(name))
526                    {
527                        accessed_columns.push(ColumnAccess {
528                            column_id: column.id,
529                            access_type: ColumnAccessType::Read,
530                        });
531                    }
532                }
533                SelectItem::Expression(expr) => {
534                    self.collect_expression_columns(expr, table, &mut accessed_columns);
535                }
536                SelectItem::CountAll => {}
537                SelectItem::Aggregate { .. } => {}
538                SelectItem::Window { window, .. } => {
539                    for expr in &window.partition_by {
540                        self.collect_expression_columns(expr, table, &mut accessed_columns);
541                    }
542                    for item in &window.order_by {
543                        if let Some(column) = table.get_column_by_name(
544                            SelectStatement::column_reference_name(&item.column),
545                        ) {
546                            accessed_columns.push(ColumnAccess {
547                                column_id: column.id,
548                                access_type: ColumnAccessType::Read,
549                            });
550                        }
551                    }
552                }
553            }
554        }
555
556        Ok(accessed_columns)
557    }
558
559    fn collect_expression_columns(
560        &self,
561        expr: &Expression,
562        table: &Table,
563        accessed_columns: &mut Vec<ColumnAccess>,
564    ) {
565        match expr {
566            Expression::Column(name) => {
567                if let Some(column) =
568                    table.get_column_by_name(SelectStatement::column_reference_name(name))
569                {
570                    accessed_columns.push(ColumnAccess {
571                        column_id: column.id,
572                        access_type: ColumnAccessType::Read,
573                    });
574                }
575            }
576            Expression::AggregateCall { target, .. } => {
577                if let AggregateTarget::Column(name) = target {
578                    if let Some(column) =
579                        table.get_column_by_name(SelectStatement::column_reference_name(name))
580                    {
581                        accessed_columns.push(ColumnAccess {
582                            column_id: column.id,
583                            access_type: ColumnAccessType::Read,
584                        });
585                    }
586                }
587            }
588            Expression::Case {
589                branches,
590                else_expr,
591            } => {
592                for branch in branches {
593                    self.collect_expression_columns(&branch.condition, table, accessed_columns);
594                    self.collect_expression_columns(&branch.result, table, accessed_columns);
595                }
596                if let Some(else_expr) = else_expr {
597                    self.collect_expression_columns(else_expr, table, accessed_columns);
598                }
599            }
600            Expression::ScalarFunctionCall { args, .. } => {
601                for arg in args {
602                    self.collect_expression_columns(arg, table, accessed_columns);
603                }
604            }
605            Expression::Cast { expr, .. } => {
606                self.collect_expression_columns(expr, table, accessed_columns);
607            }
608            Expression::ScalarSubquery(_) => {}
609            Expression::UnaryMinus(expr) => {
610                self.collect_expression_columns(expr, table, accessed_columns);
611            }
612            Expression::UnaryNot(expr) => {
613                self.collect_expression_columns(expr, table, accessed_columns);
614            }
615            Expression::Binary { left, right, .. } => {
616                self.collect_expression_columns(left, table, accessed_columns);
617                self.collect_expression_columns(right, table, accessed_columns);
618            }
619            Expression::Comparison { left, right, .. } => {
620                self.collect_expression_columns(left, table, accessed_columns);
621                self.collect_expression_columns(right, table, accessed_columns);
622            }
623            Expression::InList { expr, values, .. } => {
624                self.collect_expression_columns(expr, table, accessed_columns);
625                for value in values {
626                    self.collect_expression_columns(value, table, accessed_columns);
627                }
628            }
629            Expression::InSubquery { expr, .. } => {
630                self.collect_expression_columns(expr, table, accessed_columns);
631            }
632            Expression::Between {
633                expr, lower, upper, ..
634            } => {
635                self.collect_expression_columns(expr, table, accessed_columns);
636                self.collect_expression_columns(lower, table, accessed_columns);
637                self.collect_expression_columns(upper, table, accessed_columns);
638            }
639            Expression::Like { expr, pattern, .. } => {
640                self.collect_expression_columns(expr, table, accessed_columns);
641                self.collect_expression_columns(pattern, table, accessed_columns);
642            }
643            Expression::Exists { .. } => {}
644            Expression::NullCheck { expr, .. } => {
645                self.collect_expression_columns(expr, table, accessed_columns);
646            }
647            Expression::Logical { left, right, .. } => {
648                self.collect_expression_columns(left, table, accessed_columns);
649                self.collect_expression_columns(right, table, accessed_columns);
650            }
651            Expression::Literal(_)
652            | Expression::IntervalLiteral { .. }
653            | Expression::Parameter(_) => {}
654        }
655    }
656
657    fn estimate_table_rows(&self, table: &Table) -> usize {
658        self.table_row_counts
659            .get(&table.name)
660            .copied()
661            .unwrap_or(1000)
662    }
663
664    fn estimate_complex_source_rows(
665        &self,
666        statement: &SelectStatement,
667        from: &TableReference,
668    ) -> usize {
669        match from {
670            TableReference::Table(table_name, _) => {
671                if statement.references_cte(table_name) {
672                    1000
673                } else {
674                    self.catalog
675                        .get_table_by_name(table_name)
676                        .map(|table| self.estimate_table_rows(table))
677                        .unwrap_or(1000)
678                }
679            }
680            TableReference::Derived { .. } => 1000,
681            TableReference::CrossJoin(left, right) => self
682                .estimate_complex_source_rows(statement, left)
683                .saturating_mul(self.estimate_complex_source_rows(statement, right).max(1)),
684            TableReference::InnerJoin { left, right, on } => {
685                self.estimate_join_rows(statement, left, right, Some(on), false)
686            }
687            TableReference::LeftJoin { left, right, on } => {
688                self.estimate_join_rows(statement, left, right, Some(on), true)
689            }
690            TableReference::RightJoin { left, right, on } => {
691                self.estimate_join_rows(statement, right, left, Some(on), true)
692            }
693            TableReference::FullOuterJoin { left, right, on } => {
694                let join_rows = self.estimate_join_rows(statement, left, right, Some(on), true);
695                let right_rows = self.estimate_complex_source_rows(statement, right).max(1);
696                join_rows.max(right_rows)
697            }
698        }
699    }
700
701    fn estimate_join_rows(
702        &self,
703        statement: &SelectStatement,
704        left: &TableReference,
705        right: &TableReference,
706        on: Option<&Condition>,
707        preserve_left_rows: bool,
708    ) -> usize {
709        let left_rows = self.estimate_complex_source_rows(statement, left).max(1);
710        let right_rows = self.estimate_complex_source_rows(statement, right).max(1);
711
712        let join_rows = if on.is_some_and(is_equality_join_condition) {
713            left_rows.max(right_rows)
714        } else {
715            left_rows.saturating_mul(right_rows)
716        };
717
718        if preserve_left_rows {
719            join_rows.max(left_rows)
720        } else {
721            join_rows
722        }
723    }
724
725    fn estimate_select_cost(&self, analysis: &SelectAnalysis) -> f64 {
726        let access_path = self.choose_access_path(analysis);
727        let mut cost = self.estimate_total_access_cost(analysis, &access_path);
728        cost += analysis.accessed_columns.len() as f64 * 0.1;
729        cost.max(1.0)
730    }
731
732    fn estimate_update_cost(
733        &self,
734        analysis: &SelectAnalysis,
735        access_path: &SelectAccessPath,
736        assignment_count: usize,
737    ) -> f64 {
738        let rows_touched = self.estimate_rows_touched(analysis, access_path);
739        (self.estimate_locator_cost(analysis, access_path)
740            + rows_touched * 3.0
741            + assignment_count as f64 * 0.2)
742            .max(1.0)
743    }
744
745    fn estimate_delete_cost(
746        &self,
747        analysis: &SelectAnalysis,
748        access_path: &SelectAccessPath,
749    ) -> f64 {
750        let rows_touched = self.estimate_rows_touched(analysis, access_path);
751        (self.estimate_locator_cost(analysis, access_path) + rows_touched * 2.0).max(1.0)
752    }
753
754    fn estimate_rows_touched(
755        &self,
756        analysis: &SelectAnalysis,
757        access_path: &SelectAccessPath,
758    ) -> f64 {
759        match access_path {
760            SelectAccessPath::JoinScan => analysis.estimated_rows as f64,
761            SelectAccessPath::RowIdLookup | SelectAccessPath::PrimaryKeyLookup => 1.0,
762            SelectAccessPath::SecondaryIndexLookup(index_name) => self
763                .secondary_index_selectivity(analysis, index_name)
764                .map(|selectivity| (analysis.estimated_rows as f64 * selectivity).max(1.0))
765                .unwrap_or((analysis.estimated_rows as f64 * 0.1).max(1.0)),
766            SelectAccessPath::FullTableScan => analysis.estimated_rows as f64,
767        }
768    }
769
770    fn estimate_locator_cost(
771        &self,
772        analysis: &SelectAnalysis,
773        access_path: &SelectAccessPath,
774    ) -> f64 {
775        match access_path {
776            SelectAccessPath::JoinScan => analysis.estimated_rows as f64 * 1.5,
777            SelectAccessPath::RowIdLookup => 1.0,
778            SelectAccessPath::PrimaryKeyLookup => 2.0,
779            SelectAccessPath::SecondaryIndexLookup(index_name) => {
780                2.5 + self.estimate_rows_touched(analysis, access_path)
781                    + self
782                        .secondary_index_selectivity(analysis, index_name)
783                        .map(|selectivity| selectivity * 5.0)
784                        .unwrap_or(0.5)
785            }
786            SelectAccessPath::FullTableScan => analysis.estimated_rows as f64,
787        }
788    }
789
790    fn secondary_index_selectivity(
791        &self,
792        analysis: &SelectAnalysis,
793        index_name: &str,
794    ) -> Option<f64> {
795        analysis
796            .usable_indexes
797            .iter()
798            .find(|usage| {
799                matches!(usage.index_type, IndexType::Secondary)
800                    && usage.index_name.as_deref() == Some(index_name)
801            })
802            .map(|usage| usage.selectivity)
803    }
804
805    fn estimate_total_access_cost(
806        &self,
807        analysis: &SelectAnalysis,
808        access_path: &SelectAccessPath,
809    ) -> f64 {
810        self.estimate_locator_cost(analysis, access_path)
811            + self.estimate_rows_touched(analysis, access_path) * 0.5
812    }
813}
814
815fn is_equality_join_condition(condition: &Condition) -> bool {
816    match condition {
817        Condition::Comparison {
818            left: Expression::Column(_),
819            operator: ComparisonOperator::Equal,
820            right: Expression::Column(_),
821        } => true,
822        Condition::Logical {
823            left,
824            operator: LogicalOperator::And,
825            right,
826        } => is_equality_join_condition(left) && is_equality_join_condition(right),
827        _ => false,
828    }
829}
830
831fn synthetic_table_select(table_name: &str, where_clause: Option<WhereClause>) -> SelectStatement {
832    SelectStatement {
833        with_clause: Vec::new(),
834        distinct: false,
835        columns: vec![SelectItem::Wildcard],
836        column_aliases: vec![None],
837        from: TableReference::Table(table_name.to_string(), None),
838        where_clause,
839        group_by: Vec::new(),
840        having_clause: None,
841        order_by: Vec::new(),
842        limit: None,
843        offset: None,
844        set_operation: None,
845    }
846}