use super::{ComparisonOperator, Condition, ParsedQuery, QueryType, WhereClause};
use crate::{schema::SchemaManager, Config, Error, Result, TableId};
use std::sync::Arc;
// --- Fixed costs for statements planned without execution steps -------------
/// Estimated cost given to CREATE/DROP TABLE and CREATE/DROP INDEX plans.
const DDL_FIXED_COST: f64 = 1.0;
/// Estimated cost given to DESCRIBE and USE plans.
const METADATA_FIXED_COST: f64 = 0.1;

// --- Parallelism -------------------------------------------------------------
/// Worker count used when the config leaves `query.query_parallelism` unset.
const DEFAULT_PARALLELISM: usize = 4;
/// A scan is marked parallelizable only when the table has strictly more rows
/// than this.
const PARALLELIZATION_ROW_THRESHOLD: u64 = 10_000;

// --- Per-step cost multipliers -----------------------------------------------
/// SELECT filter-step cost as a fraction of `row_count * row_scan_cost`.
const FILTER_COST_FACTOR: f64 = 0.1;
/// UPDATE write-step cost as a fraction of `row_count * row_scan_cost`.
const UPDATE_WRITE_COST_FACTOR: f64 = 0.5;
/// Per-row cost charged for a projection step.
const PROJECT_COST_FACTOR: f64 = 0.001;

// --- Index cost multipliers (applied to `row_count * index_lookup_cost`) -----
/// Multiplier for a primary-index candidate.
const PRIMARY_INDEX_COST_FACTOR: f64 = 0.1;
/// Multiplier for a bloom-filter candidate.
const BLOOM_INDEX_COST_FACTOR: f64 = 0.01;
/// Multiplier for a composite-index candidate, applied on top of its selectivity.
const COMPOSITE_INDEX_COST_FACTOR: f64 = 0.5;
/// Selectivity recorded for every bloom-filter candidate.
const BLOOM_INDEX_SELECTIVITY: f64 = 0.1;

// --- Heuristic selectivity per comparison operator (fraction of rows kept) ---
const SELECTIVITY_EQUAL: f64 = 0.1;
const SELECTIVITY_NOT_EQUAL: f64 = 0.9;
/// Shared by `<`, `<=`, `>` and `>=`.
const SELECTIVITY_RANGE: f64 = 0.3;
const SELECTIVITY_IN: f64 = 0.2;
const SELECTIVITY_NOT_IN: f64 = 0.8;
/// Shared by `LIKE` and `NOT LIKE`.
const SELECTIVITY_LIKE: f64 = 0.5;
/// Output of `QueryPlanner::plan`: the steps for one statement plus its
/// cost and row estimates.
#[derive(Debug, Clone)]
pub struct QueryPlan {
    /// Access strategy chosen for the statement.
    pub plan_type: PlanType,
    /// Target table, copied from the parsed query (may be `None` for
    /// DDL/metadata plans, which copy `query.table` unchecked).
    pub table: Option<TableId>,
    /// Sum of step costs for SELECT/INSERT/UPDATE/DELETE; a fixed constant
    /// for DDL and metadata plans.
    pub estimated_cost: f64,
    /// Estimated number of rows produced or affected.
    pub estimated_rows: u64,
    /// Index candidates considered while planning (empty for INSERT, DDL and
    /// metadata plans).
    pub selected_indexes: Vec<IndexSelection>,
    /// Steps in the order the planner emitted them.
    pub steps: Vec<ExecutionStep>,
    /// Hints attached to the plan.
    pub hints: QueryHints,
}
/// Access strategy of a [`QueryPlan`].
///
/// The planner in this module produces `TableScan`, `IndexScan`,
/// `PointLookup` and `RangeScan`; `Join`, `Aggregation` and `Subquery` are
/// never produced here.
#[derive(Debug, Clone, PartialEq)]
pub enum PlanType {
    /// Full scan. Also used for INSERT/UPDATE/DELETE and CREATE/DROP TABLE plans.
    TableScan,
    /// Access through a secondary index. Also used for CREATE/DROP INDEX plans.
    IndexScan,
    /// Equality match on a primary-index column. Also used for DESCRIBE/USE plans.
    PointLookup,
    /// The WHERE clause contains a `<`, `<=`, `>` or `>=` condition.
    RangeScan,
    Join,
    Aggregation,
    Subquery,
}
/// One candidate index considered while planning a query.
#[derive(Debug, Clone)]
pub struct IndexSelection {
    /// Index name; candidates built in this module are `PRIMARY`,
    /// `idx_<column>` or `bloom_<column>`.
    pub index_name: String,
    /// Columns covered by the index.
    pub columns: Vec<String>,
    /// Estimated fraction of rows matched (`1.0` = every row).
    pub selectivity: f64,
    /// Kind of index; selects the formula used when costing a scan.
    pub index_type: IndexType,
}
/// Kind of index; each kind has its own scan-cost formula.
#[derive(Debug, Clone, PartialEq)]
pub enum IndexType {
    /// Primary-key index.
    Primary,
    /// Single-column secondary index.
    Secondary,
    /// Bloom-filter candidate (only proposed for equality conditions).
    BloomFilter,
    /// Multi-column index; not proposed by candidate selection in this module.
    Composite,
}
/// A single operation inside a [`QueryPlan`].
#[derive(Debug, Clone)]
pub struct ExecutionStep {
    /// Kind of work the step performs.
    pub step_type: StepType,
    /// Columns involved; meaning depends on `step_type` (e.g. sort keys for
    /// `Sort`, projected columns for `Project`).
    pub columns: Vec<String>,
    /// Predicates attached to the step. For `Insert` steps these are
    /// `column = value` pairs describing the row to insert.
    pub conditions: Vec<Condition>,
    /// Estimated cost in cost-model units.
    pub cost: f64,
    /// Parallel-execution advice for this step.
    pub parallelization: ParallelizationInfo,
}
/// Kind of work an [`ExecutionStep`] performs.
#[derive(Debug, Clone, PartialEq)]
pub enum StepType {
    Scan,
    /// Row filter. NOTE(review): UPDATE plans also use `Filter` for their write
    /// step (with the SET target columns), as there is no dedicated variant.
    Filter,
    Insert,
    Sort,
    Limit,
    Project,
    /// Not emitted by the planner in this module.
    Join,
    /// Not emitted by the planner in this module.
    Aggregate,
}
/// Parallel-execution advice attached to an execution step.
#[derive(Debug, Clone)]
pub struct ParallelizationInfo {
    /// Whether the step may be split across workers.
    pub can_parallelize: bool,
    /// Suggested worker count (`1` whenever the step is not parallelized).
    pub suggested_threads: usize,
    /// Column to partition work by when known (for scans, the first column of
    /// the primary-index candidate).
    pub partition_key: Option<String>,
}
/// Optional planning hints carried on a [`QueryPlan`].
///
/// NOTE(review): the planner in this module never reads these fields and
/// always attaches `QueryHints::default()`.
#[derive(Debug, Clone, Default)]
pub struct QueryHints {
    /// Presumably the name of an index to force — confirm with consumers.
    pub force_index: Option<String>,
    /// Presumably disables bloom-filter candidates — not honored here.
    pub disable_bloom_filter: bool,
    /// Presumably an override for the suggested worker count — not honored here.
    pub preferred_parallelization: Option<usize>,
    /// Presumably a timeout in milliseconds — confirm with the executor.
    pub timeout_ms: Option<u64>,
}
/// Heuristic, cost-based planner that turns a `ParsedQuery` into a [`QueryPlan`].
#[derive(Debug)]
pub struct QueryPlanner {
    /// Schema handle; stored but not consulted when planning (index
    /// candidates and table statistics are synthesized).
    _schema: Arc<SchemaManager>,
    /// Configuration; the planner reads `query.query_parallelism` from it.
    config: Config,
    /// Unit costs used by every estimate.
    cost_model: CostModel,
}
/// Unit costs used by the planner's estimates.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// Cost of reading one row in a full scan.
    pub row_scan_cost: f64,
    /// Per-row base cost of index access, further scaled per index type.
    pub index_lookup_cost: f64,
    /// Per-row cost charged for a sort step.
    pub sort_cost_per_row: f64,
    /// NOTE(review): not read by the planner in this module (no join planning).
    pub join_cost_per_row: f64,
    /// NOTE(review): not read by the planner in this module.
    pub memory_cost_factor: f64,
}
impl Default for CostModel {
fn default() -> Self {
Self {
row_scan_cost: 1.0,
index_lookup_cost: 0.1,
sort_cost_per_row: 0.01,
join_cost_per_row: 0.05,
memory_cost_factor: 0.001,
}
}
}
fn require_table<'a>(query: &'a ParsedQuery, op: &str) -> Result<&'a TableId> {
query
.table
.as_ref()
.ok_or_else(|| Error::query_execution(format!("Missing table in {op}")))
}
/// Returns an owned copy of the WHERE-clause conditions, or an empty list when
/// there is no WHERE clause.
fn clone_conditions(where_clause: &Option<WhereClause>) -> Vec<Condition> {
    match where_clause {
        Some(clause) => clause.conditions.clone(),
        None => Vec::new(),
    }
}
/// Column list assumed for an INSERT that names no columns.
///
/// A fixed set of table names maps to hard-coded column lists, in which case
/// `value_count` is ignored. Any other table gets synthetic names
/// `col_0 .. col_{value_count - 1}`.
///
/// NOTE(review): these lists duplicate schema knowledge by hand and can drift
/// from the real table definitions; consider deriving them from the schema.
fn default_insert_columns(table_name: &str, value_count: usize) -> Vec<String> {
    let known: &[&str] = match table_name {
        "sales" => &["id", "region", "amount"],
        "orders" => &["id", "status", "amount"],
        "products" => &["id", "name", "price", "category"],
        "employees" => &["department", "id", "name", "salary"],
        "inventory" => &["id", "product", "quantity", "price", "active"],
        "customers" => &["id", "name", "email"],
        "user_data" => &["id", "tags", "preferences"],
        "performance_test" => &["id", "value", "category"],
        _ => return (0..value_count).map(|index| format!("col_{index}")).collect(),
    };
    known.iter().map(|&name| name.to_owned()).collect()
}
impl QueryPlanner {
    /// Creates a planner with a clone of `config` and the default [`CostModel`].
    ///
    /// `schema` is stored but not consulted yet: index candidates and table
    /// statistics are synthesized by `select_indexes` / `get_table_statistics`.
    pub fn new(schema: Arc<SchemaManager>, config: &Config) -> Self {
        Self {
            _schema: schema,
            config: config.clone(),
            cost_model: CostModel::default(),
        }
    }

    /// Builds a [`QueryPlan`] for `query`, dispatching on its `QueryType`.
    ///
    /// SELECT/INSERT/UPDATE/DELETE get step-by-step plans. CREATE/DROP
    /// TABLE/INDEX get a step-less plan costed at `DDL_FIXED_COST`, and
    /// DESCRIBE/USE a step-less plan costed at `METADATA_FIXED_COST`.
    ///
    /// # Errors
    ///
    /// Fails when a SELECT/INSERT/UPDATE/DELETE query has no table. Errors from
    /// statistics or index-candidate lookup are propagated.
    pub async fn plan(&self, query: &ParsedQuery) -> Result<QueryPlan> {
        match query.query_type {
            QueryType::Select => self.plan_select(query).await,
            QueryType::Insert => self.plan_insert(query).await,
            QueryType::Update => self.plan_update(query).await,
            QueryType::Delete => self.plan_delete(query).await,
            QueryType::CreateTable => Ok(self.plan_ddl(query, PlanType::TableScan, DDL_FIXED_COST)),
            QueryType::DropTable => Ok(self.plan_ddl(query, PlanType::TableScan, DDL_FIXED_COST)),
            QueryType::CreateIndex => Ok(self.plan_ddl(query, PlanType::IndexScan, DDL_FIXED_COST)),
            QueryType::DropIndex => Ok(self.plan_ddl(query, PlanType::IndexScan, DDL_FIXED_COST)),
            QueryType::Describe => {
                Ok(self.plan_metadata(query, PlanType::PointLookup, METADATA_FIXED_COST, 1))
            }
            QueryType::Use => {
                Ok(self.plan_metadata(query, PlanType::PointLookup, METADATA_FIXED_COST, 0))
            }
        }
    }

    /// Worker count to suggest for parallel steps.
    ///
    /// Uses `config.query.query_parallelism` when it is set to a non-zero value,
    /// otherwise `DEFAULT_PARALLELISM`. A configured `0` is treated as unset so
    /// the planner never advertises a parallelizable step with zero threads.
    fn query_parallelism(&self) -> usize {
        self.config
            .query
            .query_parallelism
            .filter(|&threads| threads > 0)
            .unwrap_or(DEFAULT_PARALLELISM)
    }

    /// Parallelizable step using the configured worker count, with no partition key.
    fn parallel_info(&self) -> ParallelizationInfo {
        ParallelizationInfo {
            can_parallelize: true,
            suggested_threads: self.query_parallelism(),
            partition_key: None,
        }
    }

    /// Single-threaded step with no partition key.
    fn serial_info() -> ParallelizationInfo {
        ParallelizationInfo {
            can_parallelize: false,
            suggested_threads: 1,
            partition_key: None,
        }
    }

    /// Plans a SELECT: a scan step, then optional filter, sort, limit and
    /// projection steps, each contributing to the total cost.
    async fn plan_select(&self, query: &ParsedQuery) -> Result<QueryPlan> {
        let table = require_table(query, "SELECT")?;
        let table_stats = self.get_table_statistics(table).await?;
        let index_selection = self.select_indexes(table, &query.where_clause).await?;
        let plan_type = self.determine_plan_type(&index_selection, &query.where_clause);
        let mut steps = Vec::new();
        steps.push(ExecutionStep {
            step_type: StepType::Scan,
            columns: query.columns.clone(),
            conditions: clone_conditions(&query.where_clause),
            cost: self.calculate_scan_cost(&index_selection, &table_stats),
            parallelization: self.determine_parallelization(&index_selection, &table_stats),
        });
        // Point lookups get no separate filter step; the scan step above still
        // carries every WHERE condition.
        if let Some(where_clause) = &query.where_clause {
            if plan_type != PlanType::PointLookup {
                steps.push(ExecutionStep {
                    step_type: StepType::Filter,
                    columns: vec![],
                    conditions: where_clause.conditions.clone(),
                    cost: table_stats.row_count as f64
                        * self.cost_model.row_scan_cost
                        * FILTER_COST_FACTOR,
                    parallelization: self.parallel_info(),
                });
            }
        }
        if !query.order_by.is_empty() {
            steps.push(ExecutionStep {
                step_type: StepType::Sort,
                columns: query.order_by.iter().map(|o| o.column.clone()).collect(),
                conditions: vec![],
                cost: table_stats.row_count as f64 * self.cost_model.sort_cost_per_row,
                parallelization: self.parallel_info(),
            });
        }
        if query.limit.is_some() {
            steps.push(ExecutionStep {
                step_type: StepType::Limit,
                columns: vec![],
                conditions: vec![],
                cost: 0.0,
                parallelization: Self::serial_info(),
            });
        }
        // `SELECT *` (or an empty column list) needs no projection step.
        // Compare against an array to avoid allocating a Vec on every call.
        if !query.columns.is_empty() && query.columns != ["*"] {
            steps.push(ExecutionStep {
                step_type: StepType::Project,
                columns: query.columns.clone(),
                conditions: vec![],
                cost: table_stats.row_count as f64 * PROJECT_COST_FACTOR,
                parallelization: self.parallel_info(),
            });
        }
        let total_cost = steps.iter().map(|s| s.cost).sum();
        let estimate = self.estimate_result_rows(&table_stats, &query.where_clause);
        // An equality match on the primary key returns at most one row. Without
        // this cap the per-operator heuristic would report
        // `row_count * SELECTIVITY_EQUAL` rows for a single-key lookup.
        let estimated_rows = if plan_type == PlanType::PointLookup {
            estimate.min(1)
        } else {
            estimate
        };
        Ok(QueryPlan {
            plan_type,
            table: Some(table.clone()),
            estimated_cost: total_cost,
            estimated_rows,
            selected_indexes: index_selection,
            steps,
            hints: QueryHints::default(),
        })
    }

    /// Plans a single-row INSERT as one serial `Insert` step.
    ///
    /// The step's conditions pair each column with its value as `column = value`.
    /// When the query names no columns, `default_insert_columns` supplies them.
    async fn plan_insert(&self, query: &ParsedQuery) -> Result<QueryPlan> {
        let table = require_table(query, "INSERT")?;
        let _table_stats = self.get_table_statistics(table).await?;
        let owned_default;
        let columns: &[String] = if query.columns.is_empty() {
            owned_default = default_insert_columns(table.name(), query.values.len());
            &owned_default
        } else {
            &query.columns
        };
        // NOTE(review): `zip` silently drops values (or columns) when the two
        // counts differ; confirm upstream validation rejects mismatches.
        let conditions: Vec<Condition> = columns
            .iter()
            .zip(query.values.iter())
            .map(|(column, value)| Condition {
                column: column.clone(),
                operator: ComparisonOperator::Equal,
                value: value.clone(),
            })
            .collect();
        let steps = vec![ExecutionStep {
            step_type: StepType::Insert,
            // NOTE(review): this is the caller-supplied list (possibly empty),
            // not the defaulted `columns` used for `conditions`. Confirm intended.
            columns: query.columns.clone(),
            conditions,
            cost: self.cost_model.row_scan_cost,
            parallelization: Self::serial_info(),
        }];
        Ok(QueryPlan {
            plan_type: PlanType::TableScan,
            table: Some(table.clone()),
            estimated_cost: self.cost_model.row_scan_cost,
            estimated_rows: 1,
            selected_indexes: vec![],
            steps,
            hints: QueryHints::default(),
        })
    }

    /// Plans an UPDATE as a scan step followed by a write step.
    ///
    /// The write step is modeled as `StepType::Filter` over the SET target
    /// columns, costed at `UPDATE_WRITE_COST_FACTOR` of a full scan.
    async fn plan_update(&self, query: &ParsedQuery) -> Result<QueryPlan> {
        let table = require_table(query, "UPDATE")?;
        let table_stats = self.get_table_statistics(table).await?;
        let index_selection = self.select_indexes(table, &query.where_clause).await?;
        let steps = vec![
            ExecutionStep {
                step_type: StepType::Scan,
                columns: vec![],
                conditions: clone_conditions(&query.where_clause),
                cost: self.calculate_scan_cost(&index_selection, &table_stats),
                parallelization: self.determine_parallelization(&index_selection, &table_stats),
            },
            ExecutionStep {
                step_type: StepType::Filter,
                columns: query.set_clause.keys().cloned().collect(),
                conditions: vec![],
                cost: table_stats.row_count as f64
                    * self.cost_model.row_scan_cost
                    * UPDATE_WRITE_COST_FACTOR,
                parallelization: self.parallel_info(),
            },
        ];
        let total_cost = steps.iter().map(|s| s.cost).sum();
        let estimated_rows = self.estimate_result_rows(&table_stats, &query.where_clause);
        Ok(QueryPlan {
            plan_type: PlanType::TableScan,
            table: Some(table.clone()),
            estimated_cost: total_cost,
            estimated_rows,
            selected_indexes: index_selection,
            steps,
            hints: QueryHints::default(),
        })
    }

    /// Plans a DELETE as a single scan step carrying the WHERE conditions.
    async fn plan_delete(&self, query: &ParsedQuery) -> Result<QueryPlan> {
        let table = require_table(query, "DELETE")?;
        let table_stats = self.get_table_statistics(table).await?;
        let index_selection = self.select_indexes(table, &query.where_clause).await?;
        let steps = vec![ExecutionStep {
            step_type: StepType::Scan,
            columns: vec![],
            conditions: clone_conditions(&query.where_clause),
            cost: self.calculate_scan_cost(&index_selection, &table_stats),
            parallelization: self.determine_parallelization(&index_selection, &table_stats),
        }];
        let total_cost = steps.iter().map(|s| s.cost).sum();
        let estimated_rows = self.estimate_result_rows(&table_stats, &query.where_clause);
        Ok(QueryPlan {
            plan_type: PlanType::TableScan,
            table: Some(table.clone()),
            estimated_cost: total_cost,
            estimated_rows,
            selected_indexes: index_selection,
            steps,
            hints: QueryHints::default(),
        })
    }

    /// Step-less plan with a fixed cost and zero estimated rows, for DDL statements.
    fn plan_ddl(&self, query: &ParsedQuery, plan_type: PlanType, cost: f64) -> QueryPlan {
        QueryPlan {
            plan_type,
            table: query.table.clone(),
            estimated_cost: cost,
            estimated_rows: 0,
            selected_indexes: vec![],
            steps: vec![],
            hints: QueryHints::default(),
        }
    }

    /// Step-less plan with a fixed cost and caller-given row estimate, for
    /// metadata statements (DESCRIBE/USE).
    fn plan_metadata(
        &self,
        query: &ParsedQuery,
        plan_type: PlanType,
        cost: f64,
        estimated_rows: u64,
    ) -> QueryPlan {
        QueryPlan {
            plan_type,
            table: query.table.clone(),
            estimated_cost: cost,
            estimated_rows,
            selected_indexes: vec![],
            steps: vec![],
            hints: QueryHints::default(),
        }
    }

    /// Proposes candidate indexes for the WHERE clause.
    ///
    /// Candidates are synthesized rather than read from the schema:
    /// - always a `PRIMARY` index on `id`;
    /// - one secondary `idx_<column>` per condition;
    /// - one `bloom_<column>` per equality condition.
    async fn select_indexes(
        &self,
        _table: &TableId,
        where_clause: &Option<WhereClause>,
    ) -> Result<Vec<IndexSelection>> {
        let mut selections = Vec::new();
        selections.push(IndexSelection {
            index_name: "PRIMARY".to_string(),
            columns: vec!["id".to_string()],
            selectivity: 1.0,
            index_type: IndexType::Primary,
        });
        if let Some(where_clause) = where_clause {
            for condition in &where_clause.conditions {
                selections.push(IndexSelection {
                    index_name: format!("idx_{}", condition.column),
                    columns: vec![condition.column.clone()],
                    selectivity: self.estimate_selectivity(condition),
                    index_type: IndexType::Secondary,
                });
            }
            for condition in &where_clause.conditions {
                if condition.operator == ComparisonOperator::Equal {
                    selections.push(IndexSelection {
                        index_name: format!("bloom_{}", condition.column),
                        columns: vec![condition.column.clone()],
                        selectivity: BLOOM_INDEX_SELECTIVITY,
                        index_type: IndexType::BloomFilter,
                    });
                }
            }
        }
        Ok(selections)
    }

    /// Chooses the access strategy, in this order of precedence:
    ///
    /// - `PointLookup`: any equality condition on a primary-index column,
    ///   regardless of where it appears among the conditions;
    /// - `RangeScan`: any `<`, `<=`, `>` or `>=` condition;
    /// - `IndexScan`: any secondary candidate exists;
    /// - `TableScan`: no WHERE clause, or none of the above.
    fn determine_plan_type(
        &self,
        index_selection: &[IndexSelection],
        where_clause: &Option<WhereClause>,
    ) -> PlanType {
        let Some(where_clause) = where_clause else {
            return PlanType::TableScan;
        };
        let primary_columns: Vec<&str> = index_selection
            .iter()
            .filter(|idx| idx.index_type == IndexType::Primary)
            .flat_map(|idx| idx.columns.iter().map(String::as_str))
            .collect();
        let mut has_range = false;
        for condition in &where_clause.conditions {
            match condition.operator {
                ComparisonOperator::Equal => {
                    if primary_columns.iter().any(|c| *c == condition.column) {
                        return PlanType::PointLookup;
                    }
                }
                ComparisonOperator::LessThan
                | ComparisonOperator::LessThanOrEqual
                | ComparisonOperator::GreaterThan
                | ComparisonOperator::GreaterThanOrEqual => {
                    has_range = true;
                }
                _ => {}
            }
        }
        if has_range {
            return PlanType::RangeScan;
        }
        if index_selection
            .iter()
            .any(|idx| idx.index_type == IndexType::Secondary)
        {
            return PlanType::IndexScan;
        }
        PlanType::TableScan
    }

    /// Cheapest access cost.
    ///
    /// Takes the minimum of a full scan (`row_count * row_scan_cost`) and each
    /// candidate index's cost. Index costs scale
    /// `row_count * index_lookup_cost` by a per-type factor (and by
    /// selectivity for secondary and composite indexes).
    ///
    /// NOTE(review): `select_indexes` always adds a PRIMARY candidate, so plans
    /// built here are always costed as if the primary index applied, even with
    /// no predicate on it.
    fn calculate_scan_cost(
        &self,
        index_selection: &[IndexSelection],
        table_stats: &TableStatistics,
    ) -> f64 {
        let rows = table_stats.row_count as f64;
        let base_lookup = rows * self.cost_model.index_lookup_cost;
        let mut min_cost = rows * self.cost_model.row_scan_cost;
        for index in index_selection {
            let index_cost = match index.index_type {
                IndexType::Primary => base_lookup * PRIMARY_INDEX_COST_FACTOR,
                IndexType::Secondary => base_lookup * index.selectivity,
                IndexType::BloomFilter => base_lookup * BLOOM_INDEX_COST_FACTOR,
                IndexType::Composite => {
                    base_lookup * index.selectivity * COMPOSITE_INDEX_COST_FACTOR
                }
            };
            min_cost = min_cost.min(index_cost);
        }
        min_cost
    }

    /// Parallelization advice for a scan step.
    ///
    /// The scan is parallelized only when the table exceeds
    /// `PARALLELIZATION_ROW_THRESHOLD` rows. It is partitioned by the first
    /// column of the primary-index candidate, when there is one.
    fn determine_parallelization(
        &self,
        index_selection: &[IndexSelection],
        table_stats: &TableStatistics,
    ) -> ParallelizationInfo {
        let can_parallelize = table_stats.row_count > PARALLELIZATION_ROW_THRESHOLD;
        let suggested_threads = if can_parallelize {
            self.query_parallelism()
        } else {
            1
        };
        let partition_key = index_selection
            .iter()
            .find(|idx| idx.index_type == IndexType::Primary)
            .and_then(|idx| idx.columns.first())
            .cloned();
        ParallelizationInfo {
            can_parallelize,
            suggested_threads,
            partition_key,
        }
    }

    /// Heuristic fraction of rows a single condition keeps, chosen by operator only.
    fn estimate_selectivity(&self, condition: &Condition) -> f64 {
        match condition.operator {
            ComparisonOperator::Equal => SELECTIVITY_EQUAL,
            ComparisonOperator::NotEqual => SELECTIVITY_NOT_EQUAL,
            ComparisonOperator::LessThan
            | ComparisonOperator::LessThanOrEqual
            | ComparisonOperator::GreaterThan
            | ComparisonOperator::GreaterThanOrEqual => SELECTIVITY_RANGE,
            ComparisonOperator::In => SELECTIVITY_IN,
            ComparisonOperator::NotIn => SELECTIVITY_NOT_IN,
            ComparisonOperator::Like | ComparisonOperator::NotLike => SELECTIVITY_LIKE,
        }
    }

    /// Estimated result size: `row_count` times the product of per-condition
    /// selectivities (conditions treated as independent and ANDed).
    fn estimate_result_rows(
        &self,
        table_stats: &TableStatistics,
        where_clause: &Option<WhereClause>,
    ) -> u64 {
        let selectivity = where_clause
            .as_ref()
            .map(|w| {
                w.conditions
                    .iter()
                    .map(|c| self.estimate_selectivity(c))
                    .product::<f64>()
            })
            .unwrap_or(1.0);
        (table_stats.row_count as f64 * selectivity) as u64
    }

    /// Placeholder statistics: the same fixed numbers for every table.
    async fn get_table_statistics(&self, _table: &TableId) -> Result<TableStatistics> {
        Ok(TableStatistics {
            row_count: 100_000,
            avg_row_size: 256,
            table_size: 25_600_000,
            index_count: 3,
        })
    }
}
/// Size statistics for a table, as consumed by the planner.
#[derive(Debug, Clone)]
pub struct TableStatistics {
    /// Number of rows; the only field the planner's estimates read.
    pub row_count: u64,
    /// Average row size, presumably in bytes (the placeholder statistics
    /// satisfy `table_size == row_count * avg_row_size`) — confirm.
    pub avg_row_size: u32,
    /// Total table size, presumably in bytes — confirm.
    pub table_size: u64,
    /// Number of indexes on the table.
    pub index_count: u32,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Config;
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Builds a planner whose schema manager is rooted in a fresh temp dir.
    ///
    /// The `TempDir` guard is returned so the directory lives as long as the test.
    async fn make_planner() -> (TempDir, QueryPlanner) {
        let guard = TempDir::new().unwrap();
        let config = Config::default();
        let manager = crate::schema::SchemaManager::new(guard.path())
            .await
            .unwrap();
        let planner = QueryPlanner::new(Arc::new(manager), &config);
        (guard, planner)
    }

    #[tokio::test]
    async fn test_query_planner_creation() {
        let (_guard, planner) = make_planner().await;
        let scan_cost = planner.cost_model.row_scan_cost;
        assert_eq!(scan_cost, 1.0);
    }

    #[tokio::test]
    async fn test_plan_type_determination() {
        let (_guard, planner) = make_planner().await;
        let primary = IndexSelection {
            index_name: "PRIMARY".to_string(),
            columns: vec!["id".to_string()],
            selectivity: 1.0,
            index_type: IndexType::Primary,
        };
        let id_equals_one = Condition {
            column: "id".to_string(),
            operator: ComparisonOperator::Equal,
            value: crate::Value::Integer(1),
        };
        let where_clause = Some(WhereClause {
            conditions: vec![id_equals_one],
        });
        let chosen = planner.determine_plan_type(&[primary], &where_clause);
        assert_eq!(chosen, PlanType::PointLookup);
    }

    #[tokio::test]
    async fn test_selectivity_estimation() {
        let (_guard, planner) = make_planner().await;
        let name_equals_test = Condition {
            column: "name".to_string(),
            operator: ComparisonOperator::Equal,
            value: crate::Value::Text("test".to_string()),
        };
        assert_eq!(
            planner.estimate_selectivity(&name_equals_test),
            SELECTIVITY_EQUAL
        );
    }
}