use std::sync::Arc;
use turso_ext::{ConstraintInfo, ConstraintUsage, ResultCode};
use turso_parser::ast::{self, SortOrder, TableInternalId};
use crate::translate::expr::{as_binary_components, walk_expr, WalkControl};
use crate::translate::optimizer::constraints::{
convert_to_vtab_constraint, BinaryExprSide, Constraint, RangeConstraintRef,
};
use crate::translate::optimizer::cost::{RowCountEstimate, ESTIMATED_HARDCODED_ROWS_PER_PAGE};
use crate::translate::plan::{HashJoinKey, NonFromClauseSubquery, SubqueryState, WhereTerm};
use crate::util::exprs_are_equivalent;
use crate::vdbe::hash_table::DEFAULT_MEM_BUDGET;
use crate::{
schema::{Index, Table},
translate::plan::{IterationDirection, JoinOrderMember, JoinedTable},
vtab::VirtualTable,
LimboError, Result,
};
use super::{
constraints::{usable_constraints_for_join_order, TableConstraints},
cost::{estimate_cost_for_scan_or_seek, Cost, IndexInfo},
order::OrderTarget,
};
use crate::translate::optimizer::order::ColumnTarget;
/// A way of reading one table during a join, together with its estimated cost.
#[derive(Debug, Clone)]
pub struct AccessMethod {
    /// Estimated cost of accessing the table this way.
    pub cost: Cost,
    /// The concrete access strategy and its parameters.
    pub params: AccessMethodParams,
}
/// The concrete strategy described by an [`AccessMethod`].
#[derive(Debug, Clone)]
pub enum AccessMethodParams {
    /// Scan or seek a B-tree table, either directly or through an index.
    BTreeTable {
        /// Direction in which the B-tree is iterated.
        iter_dir: IterationDirection,
        /// Index to use; `None` means the table's own rowid B-tree.
        index: Option<Arc<Index>>,
        /// Constraints used to drive the seek; empty for a plain full scan.
        constraint_refs: Vec<RangeConstraintRef>,
    },
    /// Virtual-table access as chosen by the module's `best_index` callback.
    VirtualTable {
        /// `idx_num` returned by `best_index`.
        idx_num: i32,
        /// `idx_str` returned by `best_index`.
        idx_str: Option<String>,
        /// Constraints that were offered to `best_index`.
        constraints: Vec<ConstraintInfo>,
        /// Per-constraint usage returned by `best_index`.
        constraint_usages: Vec<ConstraintUsage>,
    },
    /// Scan of a subquery that appears in the FROM clause.
    Subquery,
    /// Hash join: a hash table is built from one table and probed with another.
    HashJoin {
        /// Position of the build-side table (NOTE(review): presumably an index into the
        /// query's table list — confirm against callers).
        build_table_idx: usize,
        /// Position of the probe-side table (same indexing as `build_table_idx`).
        probe_table_idx: usize,
        /// WHERE-clause equality terms used as hash keys.
        join_keys: Vec<HashJoinKey>,
        /// Memory budget for the hash table; this module sets it to `DEFAULT_MEM_BUDGET`.
        mem_budget: usize,
    },
}
/// Chooses how to access `rhs_table`, the table being appended to `join_order`.
///
/// Dispatches on the kind of table:
/// - B-tree tables are costed over their candidate indexes and rowid access;
/// - virtual tables defer to the module's `best_index`;
/// - FROM-clause subqueries are costed as a plain scan with no index and no constraints.
///
/// Returns `Ok(None)` when no access method is available (currently only possible for
/// virtual tables), and propagates errors from the virtual-table module.
pub fn find_best_access_method_for_join_order(
    rhs_table: &JoinedTable,
    rhs_constraints: &TableConstraints,
    join_order: &[JoinOrderMember],
    maybe_order_target: Option<&OrderTarget>,
    input_cardinality: f64,
    base_row_count: RowCountEstimate,
) -> Result<Option<AccessMethod>> {
    match rhs_table.table {
        Table::BTree(..) => find_best_access_method_for_btree(
            rhs_table,
            rhs_constraints,
            join_order,
            maybe_order_target,
            input_cardinality,
            base_row_count,
        ),
        Table::Virtual(ref vtab) => {
            let constraints = &rhs_constraints.constraints;
            find_best_access_method_for_vtab(
                vtab,
                constraints,
                join_order,
                input_cardinality,
                base_row_count,
            )
        }
        Table::FromClauseSubquery(..) => {
            // No index or constraint is applied to a FROM-clause subquery.
            let scan_cost =
                estimate_cost_for_scan_or_seek(None, &[], &[], input_cardinality, base_row_count);
            Ok(Some(AccessMethod {
                cost: scan_cost,
                params: AccessMethodParams::Subquery,
            }))
        }
    }
}
/// Picks the cheapest B-tree access method for the table that is the last member of
/// `join_order`.
///
/// The baseline is a forward full scan with no index and no constraints. Each entry of
/// `rhs_constraints.candidates` (an index, or `None` for the table's own rowid B-tree) is
/// then costed with the constraints that `usable_constraints_for_join_order` accepts for
/// this join order.
///
/// When `maybe_order_target` is given, a candidate gets a `Cost(1.0)` ordering bonus if its
/// leading columns (up to the shorter of the target and the index) all match the target.
/// Matching means the same table and the same column or expression. The sort orders must
/// then be either all equal (iterate forwards) or all reversed (iterate backwards).
///
/// Always returns `Ok(Some(_))`. The stored cost is the raw estimate, without the bonus.
fn find_best_access_method_for_btree(
    rhs_table: &JoinedTable,
    rhs_constraints: &TableConstraints,
    join_order: &[JoinOrderMember],
    maybe_order_target: Option<&OrderTarget>,
    input_cardinality: f64,
    base_row_count: RowCountEstimate,
) -> Result<Option<AccessMethod>> {
    // `join_order` is expected to be non-empty; its last member is the table being placed.
    let table_no = join_order.last().unwrap().table_id;
    // Baseline: full forward scan, no index, no constraints.
    let mut best_cost =
        estimate_cost_for_scan_or_seek(None, &[], &[], input_cardinality, base_row_count);
    let mut best_params = AccessMethodParams::BTreeTable {
        iter_dir: IterationDirection::Forwards,
        index: None,
        constraint_refs: vec![],
    };
    // Column (if any) that aliases the rowid. The rowid candidate can only satisfy an
    // ORDER BY on this column.
    let rowid_column_idx = rhs_table.columns().iter().position(|c| c.is_rowid_alias());
    for candidate in rhs_constraints.candidates.iter() {
        // The subset of this candidate's constraint refs that is usable for `join_order`.
        let usable_constraint_refs = usable_constraints_for_join_order(
            &rhs_constraints.constraints,
            &candidate.refs,
            join_order,
        );
        let index_info = match candidate.index.as_ref() {
            Some(index) => IndexInfo {
                unique: index.unique,
                covering: rhs_table.index_is_covering(index),
                column_count: index.columns.len(),
            },
            // No index: access goes through the table's rowid B-tree.
            None => IndexInfo {
                unique: true, // the rowid candidate is always treated as unique
                // Treated as covering only when at least one constraint is usable.
                covering: !usable_constraint_refs.is_empty(),
                column_count: 1,
            },
        };
        let cost = estimate_cost_for_scan_or_seek(
            Some(index_info),
            &rhs_constraints.constraints,
            &usable_constraint_refs,
            input_cardinality,
            base_row_count,
        );
        // Decide whether this candidate yields rows in the requested order and, if so, in
        // which iteration direction.
        let (iter_dir, order_satisfiability_bonus) = if let Some(order_target) = maybe_order_target
        {
            let mut all_same_direction = true;
            let mut all_opposite_direction = true;
            // Only the leading columns shared by the target and the candidate are compared.
            // An empty target leaves both flags true, so the candidate still gets the bonus.
            for i in 0..order_target.0.len().min(index_info.column_count) {
                let target = &order_target.0[i];
                let correct_table = target.table_id == table_no;
                let correct_column = match (&target.target, &candidate.index) {
                    // Plain column target vs. a plain (non-expression) index column.
                    (ColumnTarget::Column(col_no), Some(index)) => {
                        index.columns[i].expr.is_none() && index.columns[i].pos_in_table == *col_no
                    }
                    // Expression target vs. an expression index column.
                    (ColumnTarget::Expr(expr), Some(index)) => index.columns[i]
                        .expr
                        .as_ref()
                        // SAFETY(review): dereferences the raw expression pointer held by the
                        // order target; assumes it still points at a live `ast::Expr` while
                        // planning runs — TODO confirm.
                        .is_some_and(|e| exprs_are_equivalent(e, unsafe { &**expr })),
                    // Rowid candidate: only the rowid-alias column is in rowid order.
                    (ColumnTarget::Column(col_no), None) => {
                        rowid_column_idx.is_some_and(|idx| idx == *col_no)
                    }
                    _ => false,
                };
                if !correct_table || !correct_column {
                    all_same_direction = false;
                    all_opposite_direction = false;
                    break;
                }
                // The rowid B-tree is treated as ascending.
                let correct_order = {
                    match &candidate.index {
                        Some(index) => target.order == index.columns[i].order,
                        None => target.order == SortOrder::Asc,
                    }
                };
                if correct_order {
                    all_opposite_direction = false;
                } else {
                    all_same_direction = false;
                }
            }
            if all_same_direction || all_opposite_direction {
                (
                    if all_same_direction {
                        IterationDirection::Forwards
                    } else {
                        IterationDirection::Backwards
                    },
                    Cost(1.0),
                )
            } else {
                (IterationDirection::Forwards, Cost(0.0))
            }
        } else {
            (IterationDirection::Forwards, Cost(0.0))
        };
        // Accept the candidate if its cost beats the current best, allowing an
        // order-satisfying candidate to be up to the bonus more expensive.
        // NOTE(review): the bonus is applied only to the candidate. Whether the current best
        // satisfies the order is not remembered, so a cheaper non-ordered candidate can
        // displace an ordered best. Likewise, a later ordered candidate can replace an
        // earlier ordered best while costing up to 1.0 more. Confirm this is intended.
        if cost < best_cost + order_satisfiability_bonus {
            best_cost = cost;
            best_params = AccessMethodParams::BTreeTable {
                iter_dir,
                index: candidate.index.clone(),
                constraint_refs: usable_constraint_refs,
            };
        }
    }
    Ok(Some(AccessMethod {
        cost: best_cost,
        params: best_params,
    }))
}
/// Asks the virtual-table module how it wants to be accessed under `join_order`.
///
/// The planner constraints are converted with `convert_to_vtab_constraint` and passed to
/// `best_index` together with an empty second argument (presumably the ORDER BY list —
/// confirm). The module's choice is recorded in the returned params.
///
/// The cost is always the plain-scan estimate; it does not use anything returned by
/// `best_index`.
///
/// Returns `Ok(None)` when `best_index` reports `ResultCode::ConstraintViolation`. Any
/// other module error is converted into a `LimboError`.
fn find_best_access_method_for_vtab(
    vtab: &VirtualTable,
    constraints: &[Constraint],
    join_order: &[JoinOrderMember],
    input_cardinality: f64,
    base_row_count: RowCountEstimate,
) -> Result<Option<AccessMethod>> {
    let vtab_constraints = convert_to_vtab_constraint(constraints, join_order);
    let chosen = match vtab.best_index(&vtab_constraints, &[]) {
        Ok(chosen) => chosen,
        // Treated as "no access method for this combination", not as an error.
        Err(ResultCode::ConstraintViolation) => return Ok(None),
        Err(other) => return Err(LimboError::from(other)),
    };
    let scan_cost =
        estimate_cost_for_scan_or_seek(None, &[], &[], input_cardinality, base_row_count);
    Ok(Some(AccessMethod {
        cost: scan_cost,
        params: AccessMethodParams::VirtualTable {
            idx_num: chosen.idx_num,
            idx_str: chosen.idx_str,
            constraints: vtab_constraints,
            constraint_usages: chosen.constraint_usages,
        },
    }))
}
/// Returns the distinct tables referenced by `Column` and `RowId` nodes in `expr`, in
/// first-seen order.
///
/// Returns `None` if `walk_expr` reports an error. The visitor itself always continues and
/// never fails.
fn collect_table_refs(expr: &ast::Expr) -> Option<Vec<TableInternalId>> {
    let mut seen: Vec<TableInternalId> = Vec::new();
    let walk_result = walk_expr(expr, &mut |node| {
        if let ast::Expr::Column { table, .. } | ast::Expr::RowId { table, .. } = node {
            // Linear dedup: expressions reference only a handful of tables.
            if !seen.contains(table) {
                seen.push(*table);
            }
        }
        Ok(WalkControl::Continue)
    });
    match walk_result {
        Ok(_) => Some(seen),
        Err(_) => None,
    }
}
/// Collects the WHERE-clause equalities usable as hash-join keys between `build_table_id`
/// and `probe_table_id`.
///
/// Every `=` constraint in `constraints` is traced back to its WHERE term. The term becomes
/// a key when all of the following hold:
/// - it is not already consumed;
/// - it is a binary `=` expression;
/// - each operand references exactly one table, with the build table on one side and the
///   probe table on the other.
///
/// `build_side` records which operand references the build table. Each WHERE term yields
/// at most one key, and keys are returned in the order their constraints appear.
pub fn find_equijoin_conditions(
    build_table_id: TableInternalId,
    probe_table_id: TableInternalId,
    constraints: &[Constraint],
    where_clause: &[WhereTerm],
) -> Vec<HashJoinKey> {
    let mut keys: Vec<HashJoinKey> = Vec::new();
    let equality_constraints = constraints
        .iter()
        .filter(|c| matches!(c.operator, ast::Operator::Equals));
    for constraint in equality_constraints {
        let term_idx = constraint.where_clause_pos.0;
        // Several constraints can come from the same WHERE term; emit each term once.
        if keys.iter().any(|key| key.where_clause_idx == term_idx) {
            continue;
        }
        let term = &where_clause[term_idx];
        if term.consumed {
            continue;
        }
        let Ok(Some((lhs, ast::Operator::Equals, rhs))) = as_binary_components(&term.expr)
        else {
            continue;
        };
        let Some(lhs_tables) = collect_table_refs(lhs) else {
            continue;
        };
        let Some(rhs_tables) = collect_table_refs(rhs) else {
            continue;
        };
        // Each side must reference exactly one table: build on one side, probe on the other.
        let build_side = match (lhs_tables.as_slice(), rhs_tables.as_slice()) {
            ([l], [r]) if *l == build_table_id && *r == probe_table_id => BinaryExprSide::Lhs,
            ([l], [r]) if *l == probe_table_id && *r == build_table_id => BinaryExprSide::Rhs,
            _ => continue,
        };
        keys.push(HashJoinKey {
            where_clause_idx: term_idx,
            build_side,
        });
    }
    keys
}
/// Estimates the cost of a hash join whose hash table is built from `build_cardinality`
/// rows and probed by `probe_cardinality` rows.
///
/// The cost has three parts:
/// - a per-row CPU cost on each side (hash + insert for build rows, hash + lookup for probe
///   rows);
/// - a spill penalty, added only when the estimated hash-table size exceeds `mem_budget`;
/// - nothing else.
///
/// The size estimate is build rows × 100 bytes; the float-to-`usize` cast saturates. The
/// spill penalty counts the pages of both inputs twice (presumably one write and one
/// read-back of each partition).
pub fn estimate_hash_join_cost(
    build_cardinality: f64,
    probe_cardinality: f64,
    mem_budget: usize,
) -> Cost {
    const CPU_HASH_COST: f64 = 0.001;
    const CPU_INSERT_COST: f64 = 0.002;
    const CPU_LOOKUP_COST: f64 = 0.003;
    const BYTES_PER_ROW_ESTIMATE: usize = 100;

    let rows_per_page = ESTIMATED_HARDCODED_ROWS_PER_PAGE as f64;
    let pages_for = |rows: f64| (rows / rows_per_page).ceil();

    let build_cost = build_cardinality * (CPU_HASH_COST + CPU_INSERT_COST);
    let probe_cost = probe_cardinality * (CPU_HASH_COST + CPU_LOOKUP_COST);

    let projected_bytes = (build_cardinality as usize).saturating_mul(BYTES_PER_ROW_ESTIMATE);
    let spill_cost = if projected_bytes <= mem_budget {
        0.0
    } else {
        (pages_for(build_cardinality) + pages_for(probe_cardinality)) * 2.0
    };

    Cost(build_cost + probe_cost + spill_cost)
}
/// Tries to plan a hash join that builds a hash table from `build_table` and probes it with
/// `probe_table`.
///
/// Returns `None` (do not hash join) when any of the following holds:
/// - either table is not a B-tree table;
/// - both tables have the same root page;
/// - either table takes part in an outer join or a `USING` join;
/// - a correlated, still-unevaluated subquery references either table;
/// - [`find_equijoin_conditions`] finds no equality key between the two tables;
/// - a join-key column on either side is a rowid alias or appears in any candidate index of
///   that table (NOTE(review): presumably because an index/rowid lookup is preferred there —
///   confirm).
///
/// Otherwise the result is costed with [`estimate_hash_join_cost`] using
/// `DEFAULT_MEM_BUDGET`.
// The planner passes table, position, constraint and cardinality inputs for both sides of
// the join plus the WHERE clause and subqueries, so the argument list is long by design.
#[allow(clippy::too_many_arguments)]
pub fn try_hash_join_access_method(
    build_table: &JoinedTable,
    probe_table: &JoinedTable,
    build_table_idx: usize,
    probe_table_idx: usize,
    build_constraints: &TableConstraints,
    probe_constraints: &TableConstraints,
    where_clause: &mut [WhereTerm],
    build_cardinality: f64,
    probe_cardinality: f64,
    subqueries: &[NonFromClauseSubquery],
) -> Option<AccessMethod> {
    if !matches!(build_table.table, Table::BTree(_))
        || !matches!(probe_table.table, Table::BTree(_))
    {
        return None;
    }
    // Same root page means both sides read the same B-tree (e.g. a self-join).
    let probe_root_page = probe_table.table.btree().expect("table is BTree").root_page;
    let build_root_page = build_table.table.btree().expect("table is BTree").root_page;
    if build_root_page == probe_root_page {
        return None;
    }
    // Outer joins and USING joins are not handled by this path.
    let is_outer = |t: &JoinedTable| t.join_info.as_ref().is_some_and(|ji| ji.outer);
    let has_using = |t: &JoinedTable| {
        t.join_info
            .as_ref()
            .is_some_and(|ji| !ji.using.is_empty())
    };
    if is_outer(build_table)
        || is_outer(probe_table)
        || has_using(build_table)
        || has_using(probe_table)
    {
        return None;
    }
    // Reject when a correlated subquery that has not been evaluated yet refers to either
    // table (NOTE(review): presumably it needs per-row values of that table — confirm).
    for subquery in subqueries {
        if !subquery.correlated {
            continue;
        }
        if let SubqueryState::Unevaluated { plan } = &subquery.state {
            if let Some(plan) = plan.as_ref() {
                let outer_refs = plan.table_references.outer_query_refs();
                for outer_ref in outer_refs {
                    if outer_ref.internal_id == build_table.internal_id
                        || outer_ref.internal_id == probe_table.internal_id
                    {
                        return None;
                    }
                }
            }
        }
    }
    let join_keys = find_equijoin_conditions(
        build_table.internal_id,
        probe_table.internal_id,
        &probe_constraints.constraints,
        where_clause,
    );
    if join_keys.is_empty() {
        return None;
    }
    // Same check for both sides; the probe side is examined first.
    let any_key_indexed = join_keys.iter().any(|join_key| {
        join_key_column_is_indexed(probe_table, probe_constraints, join_key.where_clause_idx)
            || join_key_column_is_indexed(build_table, build_constraints, join_key.where_clause_idx)
    });
    if any_key_indexed {
        return None;
    }
    let cost = estimate_hash_join_cost(build_cardinality, probe_cardinality, DEFAULT_MEM_BUDGET);
    Some(AccessMethod {
        cost,
        params: AccessMethodParams::HashJoin {
            build_table_idx,
            probe_table_idx,
            join_keys,
            mem_budget: DEFAULT_MEM_BUDGET,
        },
    })
}

/// Returns `true` when the column of `table` constrained by the WHERE term at
/// `where_clause_idx` is a rowid alias or appears in any of `constraints.candidates`' indexes.
///
/// Returns `false` when `table` has no constraint from that term, or when the constraint
/// does not map to a table column.
fn join_key_column_is_indexed(
    table: &JoinedTable,
    constraints: &TableConstraints,
    where_clause_idx: usize,
) -> bool {
    let Some(col_pos) = constraints
        .constraints
        .iter()
        .find(|c| c.where_clause_pos.0 == where_clause_idx)
        .and_then(|c| c.table_col_pos)
    else {
        return false;
    };
    if table
        .columns()
        .get(col_pos)
        .is_some_and(|column| column.is_rowid_alias())
    {
        return true;
    }
    constraints.candidates.iter().any(|candidate| {
        candidate
            .index
            .as_ref()
            .is_some_and(|index| index.column_table_pos_to_index_pos(col_pos).is_some())
    })
}