use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
/// Tunable constants for the cost model. Costs are in abstract units; only their ratios matter.
#[derive(Debug, Clone)]
pub struct CostModelConfig {
    /// Cost per block read sequentially (table-scan blocks and index leaf pages).
    pub c_seq: f64,
    /// Cost per random access (each index tree level, and secondary-index row fetches).
    pub c_random: f64,
    /// Per-row CPU cost of evaluating a scan, and per-row cost of join build/probe.
    pub c_filter: f64,
    /// Cost of a single comparison during sorting.
    pub c_compare: f64,
    /// Block size in bytes used to turn a table size into a block count.
    pub block_size: usize,
    /// B-tree fan-out. NOTE(review): not read by the estimates in this file.
    pub btree_fanout: usize,
    /// Memory bandwidth. NOTE(review): not read by the estimates in this file; unit unknown.
    pub memory_bandwidth: f64,
}

impl Default for CostModelConfig {
    fn default() -> Self {
        Self {
            block_size: 4096,
            btree_fanout: 100,
            c_seq: 0.1,
            c_random: 5.0,
            c_filter: 0.001,
            c_compare: 0.0001,
            memory_bandwidth: 10000.0,
        }
    }
}
/// Table-level statistics consulted by the optimizers.
#[derive(Debug, Clone)]
pub struct TableStats {
/// Table name; the optimizers store stats keyed by this name.
pub name: String,
/// Number of rows, used for cardinality and cost estimates.
pub row_count: u64,
/// Table size in bytes; divided by `CostModelConfig::block_size` to estimate scan I/O.
pub size_bytes: u64,
/// Per-column statistics keyed by column name.
pub column_stats: HashMap<String, ColumnStats>,
/// Indexes available on the table.
pub indices: Vec<IndexStats>,
/// When these stats were recorded, in microseconds since the Unix epoch
/// (`collect_stats` sets it from the system clock; hand-built stats may use any value).
pub last_updated: u64,
}
/// Statistics for a single column.
#[derive(Debug, Clone)]
pub struct ColumnStats {
/// Column name.
pub name: String,
/// Number of distinct values; equality and `IN` selectivity use `1 / distinct_count`.
pub distinct_count: u64,
/// Number of NULL values (`collect_stats` counts empty strings as NULL).
pub null_count: u64,
/// Smallest observed value in string form, if known.
pub min_value: Option<String>,
/// Largest observed value in string form, if known.
pub max_value: Option<String>,
/// Average value length in bytes.
pub avg_length: f64,
/// Most common values as `(value, fraction_of_rows)` pairs; an equality predicate on a listed
/// value uses that fraction directly as its selectivity.
pub mcv: Vec<(String, f64)>,
/// Optional histogram for numeric columns, used for range selectivity.
pub histogram: Option<Histogram>,
}
/// Bucketed value distribution of a numeric column.
///
/// `boundaries[i]` is the upper bound of bucket `i` and `boundaries[i - 1]` its lower bound
/// (`collect_stats` builds one boundary per bucket). Bucket 0 is open below; a bucket without a
/// boundary of its own (the interior-boundary layout, `counts.len() == boundaries.len() + 1`)
/// is open above.
#[derive(Debug, Clone)]
pub struct Histogram {
    /// Bucket upper bounds in ascending order.
    pub boundaries: Vec<f64>,
    /// Number of rows in each bucket.
    pub counts: Vec<u64>,
    /// Total rows represented; the denominator of every selectivity estimate.
    pub total_rows: u64,
}

impl Histogram {
    /// Estimates the fraction of rows whose value lies in `[min, max]`; either bound may be
    /// omitted.
    ///
    /// The estimate is bucket-granular. Every bucket that touches the range (bounds compared
    /// inclusively) contributes its whole count, so results tend to over-estimate. Returns `0.5`
    /// when `total_rows == 0`. The result is capped at `1.0` even if `counts` sums to more than
    /// `total_rows`. Malformed histograms with more buckets than boundaries never panic.
    pub fn estimate_range_selectivity(&self, min: Option<f64>, max: Option<f64>) -> f64 {
        if self.total_rows == 0 {
            // Nothing to go on: neutral guess.
            return 0.5;
        }
        let selected_rows = self
            .counts
            .iter()
            .enumerate()
            .filter(|&(i, _)| {
                let (lower, upper) = self.bucket_bounds(i);
                let above_min = match min {
                    Some(lo) => upper >= lo,
                    None => true,
                };
                let below_max = match max {
                    Some(hi) => lower <= hi,
                    None => true,
                };
                above_min && below_max
            })
            .fold(0u64, |acc, (_, &count)| acc.saturating_add(count));
        (selected_rows as f64 / self.total_rows as f64).min(1.0)
    }

    /// Returns the `(lower, upper)` bounds of bucket `i`.
    ///
    /// Missing boundaries become open ends instead of panicking. Buckets past the last boundary
    /// start at that boundary and are unbounded above.
    fn bucket_bounds(&self, i: usize) -> (f64, f64) {
        let upper = self.boundaries.get(i).copied().unwrap_or(f64::INFINITY);
        let lower = match i.checked_sub(1) {
            None => f64::NEG_INFINITY,
            Some(prev) => self
                .boundaries
                .get(prev)
                .or(self.boundaries.last())
                .copied()
                .unwrap_or(f64::NEG_INFINITY),
        };
        (lower, upper)
    }
}
/// Statistics describing one index on a table.
#[derive(Debug, Clone)]
pub struct IndexStats {
/// Index name, reported in `IndexSeek` plans.
pub name: String,
/// Indexed columns in key order; the access-path code only matches on the first (leading) one.
pub columns: Vec<String>,
/// Primary indexes are charged no per-row fetch cost in `estimate_index_cost`.
pub is_primary: bool,
/// Whether keys are unique. NOTE(review): not read by the cost code in this file.
pub is_unique: bool,
/// Physical index structure. NOTE(review): not read by the cost code in this file.
pub index_type: IndexType,
/// Number of leaf pages. NOTE(review): not read by the cost code in this file.
pub leaf_pages: u64,
/// Tree height; each level is charged one `c_random`.
pub height: u32,
/// Average rows per leaf page; used to turn matching rows into leaf pages scanned.
pub avg_leaf_density: f64,
}
/// Physical index structure.
///
/// NOTE(review): the cost model in this file does not currently distinguish between these
/// kinds (e.g. it does not prevent range seeks on `Hash`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexType {
BTree,
Hash,
LSM,
Learned,
Vector,
Bloom,
}
/// Boolean filter expression over columns. Literal values are kept as strings; numeric
/// comparisons parse them on demand.
#[derive(Debug, Clone)]
pub enum Predicate {
    /// `column = value`
    Eq { column: String, value: String },
    /// `column <> value`
    Ne { column: String, value: String },
    /// `column < value`
    Lt { column: String, value: String },
    /// `column <= value`
    Le { column: String, value: String },
    /// `column > value`
    Gt { column: String, value: String },
    /// `column >= value`
    Ge { column: String, value: String },
    /// `column BETWEEN min AND max`
    Between {
        column: String,
        min: String,
        max: String,
    },
    /// `column IN (values...)`
    In { column: String, values: Vec<String> },
    /// `column LIKE pattern`
    Like { column: String, pattern: String },
    /// `column IS NULL`
    IsNull { column: String },
    /// `column IS NOT NULL`
    IsNotNull { column: String },
    /// Conjunction of two predicates.
    And(Box<Predicate>, Box<Predicate>),
    /// Disjunction of two predicates.
    Or(Box<Predicate>, Box<Predicate>),
    /// Negation of a predicate.
    Not(Box<Predicate>),
}

impl Predicate {
    /// Returns every column name mentioned anywhere in this predicate tree.
    pub fn referenced_columns(&self) -> HashSet<String> {
        let mut found = HashSet::new();
        self.collect_columns(&mut found);
        found
    }

    /// Adds the columns referenced by this predicate and all of its sub-predicates to `cols`.
    /// Uses an explicit work stack rather than recursion.
    fn collect_columns(&self, cols: &mut HashSet<String>) {
        let mut pending: Vec<&Predicate> = vec![self];
        while let Some(node) = pending.pop() {
            match node {
                Self::And(lhs, rhs) | Self::Or(lhs, rhs) => {
                    pending.push(lhs);
                    pending.push(rhs);
                }
                Self::Not(inner) => pending.push(inner),
                Self::Eq { column, .. }
                | Self::Ne { column, .. }
                | Self::Lt { column, .. }
                | Self::Le { column, .. }
                | Self::Gt { column, .. }
                | Self::Ge { column, .. }
                | Self::Between { column, .. }
                | Self::In { column, .. }
                | Self::Like { column, .. }
                | Self::IsNull { column }
                | Self::IsNotNull { column } => {
                    cols.insert(column.clone());
                }
            }
        }
    }
}
/// Physical query plan tree.
///
/// Each node's `estimated_cost` is the cost of that node alone, in the abstract units of
/// `CostModelConfig`. `CostBasedOptimizer::get_plan_cost` adds in the costs of all children.
/// (`explain` prints these totals with an `ms` suffix, but nothing here converts them to time.)
#[derive(Debug, Clone)]
pub enum PhysicalPlan {
/// Full scan of `table` producing `columns`; `predicate`, if any, is attached to the scan.
TableScan {
table: String,
columns: Vec<String>,
predicate: Option<Box<Predicate>>,
estimated_rows: u64,
estimated_cost: f64,
},
/// Lookup through `index` on `table`, restricted to `key_range`; `predicate` is carried along.
IndexSeek {
table: String,
index: String,
columns: Vec<String>,
key_range: KeyRange,
predicate: Option<Box<Predicate>>,
estimated_rows: u64,
estimated_cost: f64,
},
/// Applies `predicate` to the rows of `input`.
Filter {
input: Box<PhysicalPlan>,
predicate: Predicate,
estimated_rows: u64,
estimated_cost: f64,
},
/// Restricts `input` to `columns`; row count is inherited from `input`.
Project {
input: Box<PhysicalPlan>,
columns: Vec<String>,
estimated_cost: f64,
},
/// Orders `input` by the `(column, direction)` pairs in `order_by`.
Sort {
input: Box<PhysicalPlan>,
order_by: Vec<(String, SortDirection)>,
estimated_cost: f64,
},
/// Emits at most `limit` rows of `input` after skipping `offset` rows
/// (SQL `LIMIT`/`OFFSET` semantics assumed).
Limit {
input: Box<PhysicalPlan>,
limit: u64,
offset: u64,
estimated_cost: f64,
},
/// Nested-loop join of `outer` and `inner` on `condition`.
NestedLoopJoin {
outer: Box<PhysicalPlan>,
inner: Box<PhysicalPlan>,
condition: Predicate,
join_type: JoinType,
estimated_rows: u64,
estimated_cost: f64,
},
/// Hash join; presumably `build` is hashed on `build_keys` and `probe` is matched on
/// `probe_keys` (nothing in this file executes plans).
HashJoin {
build: Box<PhysicalPlan>,
probe: Box<PhysicalPlan>,
build_keys: Vec<String>,
probe_keys: Vec<String>,
join_type: JoinType,
estimated_rows: u64,
estimated_cost: f64,
},
/// Merge join of `left` and `right` on `left_keys` / `right_keys`
/// (inputs presumably sorted on those keys — not enforced here).
MergeJoin {
left: Box<PhysicalPlan>,
right: Box<PhysicalPlan>,
left_keys: Vec<String>,
right_keys: Vec<String>,
join_type: JoinType,
estimated_rows: u64,
estimated_cost: f64,
},
/// Groups `input` by `group_by` and computes `aggregates`.
Aggregate {
input: Box<PhysicalPlan>,
group_by: Vec<String>,
aggregates: Vec<AggregateExpr>,
estimated_rows: u64,
estimated_cost: f64,
},
}
/// Byte-key interval for an index seek. A `None` bound means unbounded on that side.
#[derive(Debug, Clone)]
pub struct KeyRange {
    pub start: Option<Vec<u8>>,
    pub end: Option<Vec<u8>>,
    pub start_inclusive: bool,
    pub end_inclusive: bool,
}

impl KeyRange {
    /// Unbounded range covering every key (both flags inclusive).
    pub fn all() -> Self {
        Self::range(None, None, true)
    }

    /// Closed range containing exactly `key`.
    pub fn point(key: Vec<u8>) -> Self {
        Self {
            end: Some(key.clone()),
            start: Some(key),
            start_inclusive: true,
            end_inclusive: true,
        }
    }

    /// Range between the optional bounds; `inclusive` applies to both ends.
    pub fn range(start: Option<Vec<u8>>, end: Option<Vec<u8>>, inclusive: bool) -> Self {
        Self {
            start_inclusive: inclusive,
            end_inclusive: inclusive,
            start,
            end,
        }
    }
}
/// Sort order for one `ORDER BY` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
Ascending,
Descending,
}
/// Join kind carried by join plan nodes (SQL join semantics assumed; nothing in this file
/// executes joins).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
Inner,
Left,
Right,
Full,
Cross,
}
/// One aggregate computed by an `Aggregate` plan node.
#[derive(Debug, Clone)]
pub struct AggregateExpr {
/// Aggregate function to apply.
pub function: AggregateFunction,
/// Input column; `None` is rendered as `*` by `explain` (e.g. `COUNT(*)`).
pub column: Option<String>,
/// Output name for the result. NOTE(review): not read by the code in this file.
pub alias: String,
}
/// Aggregate function kinds; `explain` prints them via `Debug` (e.g. `Count(*)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
Count,
Sum,
Avg,
Min,
Max,
CountDistinct,
}
/// Single-table cost-based planner: chooses an access path, then applies projection, sort and
/// an optional limit derived from a user limit and/or a token budget.
pub struct CostBasedOptimizer {
/// Cost constants used by every estimate.
config: CostModelConfig,
/// Latest statistics per table name, behind a read/write lock.
stats_cache: Arc<RwLock<HashMap<String, TableStats>>>,
/// Optional cap on output tokens, converted to a row limit by `calculate_token_limit`.
token_budget: Option<u64>,
/// Assumed tokens per output row when applying `token_budget`.
tokens_per_row: f64,
/// Cached plans keyed by a `u64` (presumably a query hash — TODO confirm), each stored with
/// its insertion time in microseconds since the Unix epoch (checked by `evict_stale_plans`).
/// NOTE(review): nothing in this file inserts into this cache and `optimize` never reads it.
plan_cache: Arc<RwLock<HashMap<u64, (PhysicalPlan, u64)>>>,
/// Maximum age in microseconds before `evict_stale_plans` drops a cached plan.
plan_cache_ttl_us: u64,
}
impl CostBasedOptimizer {
/// Creates an optimizer with empty statistics and plan caches, no token budget,
/// 25 tokens per row and a 5 second plan-cache TTL.
pub fn new(config: CostModelConfig) -> Self {
    const DEFAULT_TOKENS_PER_ROW: f64 = 25.0;
    const DEFAULT_PLAN_CACHE_TTL_US: u64 = 5_000_000;
    Self {
        config,
        token_budget: None,
        tokens_per_row: DEFAULT_TOKENS_PER_ROW,
        stats_cache: Arc::new(RwLock::new(HashMap::new())),
        plan_cache: Arc::new(RwLock::new(HashMap::new())),
        plan_cache_ttl_us: DEFAULT_PLAN_CACHE_TTL_US,
    }
}
pub fn with_plan_cache_ttl_ms(mut self, ttl_ms: u64) -> Self {
self.plan_cache_ttl_us = ttl_ms * 1000;
self
}
pub fn with_token_budget(mut self, budget: u64, tokens_per_row: f64) -> Self {
self.token_budget = Some(budget);
self.tokens_per_row = tokens_per_row;
self
}
/// Stores `stats` under its table name, replacing any previous entry.
pub fn update_stats(&self, stats: TableStats) {
    let key = stats.name.clone();
    let mut cache = self.stats_cache.write();
    cache.insert(key, stats);
}
pub fn get_stats(&self, table: &str) -> Option<TableStats> {
self.stats_cache.read().get(table).cloned()
}
/// Builds a physical plan for a single-table query.
///
/// The plan is built in this order: access path (scan or index seek) → projection of
/// `columns` → optional sort on `order_by` → optional `Limit`. The limit is the user `limit`,
/// further capped by the token budget when one is configured. The injected `Limit` has zero
/// own cost and offset 0.
pub fn optimize(
    &self,
    table: &str,
    columns: Vec<String>,
    predicate: Option<Predicate>,
    order_by: Vec<(String, SortDirection)>,
    limit: Option<u64>,
) -> PhysicalPlan {
    let stats = self.get_stats(table);
    let row_cap = self.calculate_token_limit(limit);

    let access = self.choose_access_path(table, &columns, predicate.as_ref(), &stats);
    let projected = self.apply_projection_pushdown(access, columns);
    let ordered = if order_by.is_empty() {
        projected
    } else {
        self.add_sort(projected, order_by, &stats)
    };

    match row_cap {
        Some(cap) => PhysicalPlan::Limit {
            input: Box::new(ordered),
            limit: cap,
            offset: 0,
            estimated_cost: 0.0,
        },
        None => ordered,
    }
}
/// Combines the user's row limit with the token budget.
///
/// With a budget, the budget minus 50 header tokens is divided by `tokens_per_row`, giving at
/// least one row. The result is the smaller of that cap and the user limit. Without a budget,
/// the user limit is returned unchanged.
fn calculate_token_limit(&self, user_limit: Option<u64>) -> Option<u64> {
    // Tokens reserved for the result header before any row is emitted.
    const HEADER_TOKENS: u64 = 50;
    let budget_rows = self.token_budget.map(|budget| {
        let usable = budget.saturating_sub(HEADER_TOKENS) as f64;
        (usable / self.tokens_per_row).max(1.0) as u64
    });
    match (budget_rows, user_limit) {
        (Some(cap), Some(requested)) => Some(requested.min(cap)),
        (Some(cap), None) => Some(cap),
        (None, requested) => requested,
    }
}
fn choose_access_path(
&self,
table: &str,
columns: &[String],
predicate: Option<&Predicate>,
stats: &Option<TableStats>,
) -> PhysicalPlan {
let row_count = stats.as_ref().map(|s| s.row_count).unwrap_or(10000);
let size_bytes = stats
.as_ref()
.map(|s| s.size_bytes)
.unwrap_or(row_count * 100);
let scan_cost = self.estimate_scan_cost(row_count, size_bytes, predicate);
let mut best_index_cost = f64::MAX;
let mut best_index: Option<&IndexStats> = None;
if let Some(table_stats) = stats.as_ref()
&& let Some(pred) = predicate
{
let pred_columns = pred.referenced_columns();
for index in &table_stats.indices {
if self.index_covers_predicate(index, &pred_columns) {
let selectivity = self.estimate_selectivity(pred, table_stats);
let index_cost = self.estimate_index_cost(index, row_count, selectivity);
if index_cost < best_index_cost {
best_index_cost = index_cost;
best_index = Some(index);
}
}
}
}
if best_index_cost < scan_cost {
let index = best_index.unwrap();
let selectivity = predicate
.map(|p| self.estimate_selectivity(p, stats.as_ref().unwrap()))
.unwrap_or(1.0);
PhysicalPlan::IndexSeek {
table: table.to_string(),
index: index.name.clone(),
columns: columns.to_vec(),
key_range: predicate
.map(|p| Self::derive_key_range(p))
.unwrap_or_else(KeyRange::all),
predicate: predicate.map(|p| Box::new(p.clone())),
estimated_rows: (row_count as f64 * selectivity).max(1.0) as u64,
estimated_cost: best_index_cost,
}
} else {
PhysicalPlan::TableScan {
table: table.to_string(),
columns: columns.to_vec(),
predicate: predicate.map(|p| Box::new(p.clone())),
estimated_rows: row_count,
estimated_cost: scan_cost,
}
}
}
/// True when the index's leading column is among the predicate's referenced columns.
/// An index with no columns never qualifies.
fn index_covers_predicate(&self, index: &IndexStats, pred_columns: &HashSet<String>) -> bool {
    index
        .columns
        .first()
        .is_some_and(|leading| pred_columns.contains(leading))
}
/// Cost of a full scan: every block is read sequentially (at least one), and every row pays
/// the per-row filter cost. The predicate is ignored because a scan reads everything anyway.
fn estimate_scan_cost(
    &self,
    row_count: u64,
    size_bytes: u64,
    _predicate: Option<&Predicate>,
) -> f64 {
    let cfg = &self.config;
    let block_reads = (size_bytes as f64 / cfg.block_size as f64).ceil().max(1.0) as u64;
    let cpu = cfg.c_filter * row_count as f64;
    block_reads as f64 * cfg.c_seq + cpu
}
/// Cost of an index seek that matches `total_rows * selectivity` rows.
///
/// The cost is the sum of:
/// * one random access per tree level;
/// * a sequential read of the leaf pages holding the matches;
/// * for non-primary indexes, a discounted random fetch per matching row, capped at 1000 rows.
///
/// A non-positive or NaN `avg_leaf_density` is treated as one row per leaf page (pessimistic).
/// This avoids the infinite/NaN page counts a zero density would otherwise produce.
fn estimate_index_cost(&self, index: &IndexStats, total_rows: u64, selectivity: f64) -> f64 {
    let tree_cost = index.height as f64 * self.config.c_random;
    let matching_rows = (total_rows as f64 * selectivity) as u64;
    let rows_per_leaf = if index.avg_leaf_density > 0.0 {
        index.avg_leaf_density
    } else {
        1.0
    };
    let leaf_pages_scanned = (matching_rows as f64 / rows_per_leaf).ceil() as u64;
    let leaf_cost = leaf_pages_scanned as f64 * self.config.c_seq;
    let fetch_cost = if index.is_primary {
        // Primary indexes are charged no separate row fetch.
        0.0
    } else {
        matching_rows.min(1000) as f64 * self.config.c_random * 0.1
    };
    tree_cost + leaf_cost + fetch_cost
}
#[allow(clippy::only_used_in_recursion)]
fn estimate_selectivity(&self, predicate: &Predicate, stats: &TableStats) -> f64 {
match predicate {
Predicate::Eq { column, value } => {
if let Some(col_stats) = stats.column_stats.get(column) {
for (mcv_val, freq) in &col_stats.mcv {
if mcv_val == value {
return *freq;
}
}
1.0 / col_stats.distinct_count.max(1) as f64
} else {
0.1 }
}
Predicate::Ne { .. } => 0.9, Predicate::Lt { column, value }
| Predicate::Le { column, value }
| Predicate::Gt { column, value }
| Predicate::Ge { column, value } => {
if let Some(col_stats) = stats.column_stats.get(column) {
if let Some(ref hist) = col_stats.histogram {
let val: f64 = value.parse().unwrap_or(0.0);
match predicate {
Predicate::Lt { .. } | Predicate::Le { .. } => {
hist.estimate_range_selectivity(None, Some(val))
}
_ => hist.estimate_range_selectivity(Some(val), None),
}
} else {
0.25 }
} else {
0.25
}
}
Predicate::Between { column, min, max } => {
if let Some(col_stats) = stats.column_stats.get(column) {
if let Some(ref hist) = col_stats.histogram {
let min_val: f64 = min.parse().unwrap_or(0.0);
let max_val: f64 = max.parse().unwrap_or(f64::MAX);
hist.estimate_range_selectivity(Some(min_val), Some(max_val))
} else {
0.2
}
} else {
0.2
}
}
Predicate::In { column, values } => {
if let Some(col_stats) = stats.column_stats.get(column) {
(values.len() as f64 / col_stats.distinct_count.max(1) as f64).min(1.0)
} else {
(values.len() as f64 * 0.1).min(0.5)
}
}
Predicate::Like { .. } => 0.15, Predicate::IsNull { column } => {
if let Some(col_stats) = stats.column_stats.get(column) {
col_stats.null_count as f64 / stats.row_count.max(1) as f64
} else {
0.01
}
}
Predicate::IsNotNull { column } => {
if let Some(col_stats) = stats.column_stats.get(column) {
1.0 - (col_stats.null_count as f64 / stats.row_count.max(1) as f64)
} else {
0.99
}
}
Predicate::And(left, right) => {
self.estimate_selectivity(left, stats) * self.estimate_selectivity(right, stats)
}
Predicate::Or(left, right) => {
let s1 = self.estimate_selectivity(left, stats);
let s2 = self.estimate_selectivity(right, stats);
(s1 + s2 - s1 * s2).min(1.0)
}
Predicate::Not(inner) => 1.0 - self.estimate_selectivity(inner, stats),
}
}
/// Translates a comparison into an index key range over the literal's UTF-8 bytes.
///
/// `Eq` gives a point range. `Lt`/`Le` give an upper bound and `Gt`/`Ge` a lower bound,
/// exclusive or inclusive respectively. `Between` gives a closed range. `And` uses its left
/// child only. Anything else gives the full range.
///
/// NOTE(review): keys are compared as raw string bytes, so numeric literals order
/// lexicographically ("10" < "9") unless the index stores keys the same way — confirm.
fn derive_key_range(predicate: &Predicate) -> KeyRange {
    let bytes = |s: &String| s.as_bytes().to_vec();
    match predicate {
        Predicate::Eq { value, .. } => KeyRange::point(bytes(value)),
        Predicate::Lt { value, .. } => KeyRange::range(None, Some(bytes(value)), false),
        Predicate::Le { value, .. } => KeyRange::range(None, Some(bytes(value)), true),
        Predicate::Gt { value, .. } => KeyRange::range(Some(bytes(value)), None, false),
        Predicate::Ge { value, .. } => KeyRange::range(Some(bytes(value)), None, true),
        Predicate::Between { min, max, .. } => KeyRange {
            start: Some(bytes(min)),
            end: Some(bytes(max)),
            start_inclusive: true,
            end_inclusive: true,
        },
        Predicate::And(first, _) => Self::derive_key_range(first),
        _ => KeyRange::all(),
    }
}
/// Pushes the requested output `columns` into the access path.
///
/// Scans and seeks get `columns` written directly into the node. A scan's cost is scaled by
/// the ratio of requested to previously listed columns, clamped to `[0.1, 1.0]` and left
/// unscaled when either list is empty. Any other node is wrapped in a zero-cost `Project`.
///
/// NOTE(review): `optimize` passes the same column list it used to build the access path, so
/// the scan ratio is currently always 1.0.
fn apply_projection_pushdown(&self, plan: PhysicalPlan, columns: Vec<String>) -> PhysicalPlan {
    match plan {
        PhysicalPlan::TableScan {
            table,
            columns: previous,
            predicate,
            estimated_rows,
            estimated_cost,
        } => {
            let keep_ratio = if previous.is_empty() || columns.is_empty() {
                1.0
            } else {
                let kept = columns.len() as f64;
                let before = previous.len().max(1) as f64;
                (kept / before).clamp(0.1, 1.0)
            };
            PhysicalPlan::TableScan {
                table,
                columns,
                predicate,
                estimated_rows,
                estimated_cost: estimated_cost * keep_ratio,
            }
        }
        PhysicalPlan::IndexSeek {
            table,
            index,
            key_range,
            predicate,
            estimated_rows,
            estimated_cost,
            columns: _,
        } => PhysicalPlan::IndexSeek {
            table,
            index,
            columns,
            key_range,
            predicate,
            estimated_rows,
            estimated_cost,
        },
        other => PhysicalPlan::Project {
            input: Box::new(other),
            columns,
            estimated_cost: 0.0,
        },
    }
}
/// Wraps `plan` in a `Sort` on `order_by`, charging `n * log2(n)` comparisons at `c_compare`
/// each. `n` is the input's estimated row count; an empty input sorts for free.
fn add_sort(
    &self,
    plan: PhysicalPlan,
    order_by: Vec<(String, SortDirection)>,
    _stats: &Option<TableStats>,
) -> PhysicalPlan {
    let rows = self.get_plan_rows(&plan) as f64;
    let sort_cost = if rows > 0.0 {
        rows * rows.log2() * self.config.c_compare
    } else {
        0.0
    };
    PhysicalPlan::Sort {
        input: Box::new(plan),
        order_by,
        estimated_cost: sort_cost,
    }
}
/// Estimated number of rows produced by `plan`.
///
/// Nodes that record `estimated_rows` report it. `Project` and `Sort` pass their input's
/// count through. `Limit` reports at most `limit`, and no more than its input produces after
/// skipping `offset` rows.
#[allow(clippy::only_used_in_recursion)]
fn get_plan_rows(&self, plan: &PhysicalPlan) -> u64 {
    match plan {
        PhysicalPlan::TableScan { estimated_rows, .. }
        | PhysicalPlan::IndexSeek { estimated_rows, .. }
        | PhysicalPlan::Filter { estimated_rows, .. }
        | PhysicalPlan::Aggregate { estimated_rows, .. }
        | PhysicalPlan::NestedLoopJoin { estimated_rows, .. }
        | PhysicalPlan::HashJoin { estimated_rows, .. }
        | PhysicalPlan::MergeJoin { estimated_rows, .. } => *estimated_rows,
        PhysicalPlan::Project { input, .. } | PhysicalPlan::Sort { input, .. } => {
            self.get_plan_rows(input)
        }
        PhysicalPlan::Limit {
            input,
            limit,
            offset,
            ..
        } => (*limit).min(self.get_plan_rows(input).saturating_sub(*offset)),
    }
}
/// Total estimated cost of `plan`: the node's own `estimated_cost` plus the total cost of all
/// of its children.
#[allow(clippy::only_used_in_recursion)]
pub fn get_plan_cost(&self, plan: &PhysicalPlan) -> f64 {
    match plan {
        PhysicalPlan::TableScan { estimated_cost, .. }
        | PhysicalPlan::IndexSeek { estimated_cost, .. } => *estimated_cost,
        PhysicalPlan::Filter {
            estimated_cost,
            input,
            ..
        }
        | PhysicalPlan::Project {
            estimated_cost,
            input,
            ..
        }
        | PhysicalPlan::Sort {
            estimated_cost,
            input,
            ..
        }
        | PhysicalPlan::Limit {
            estimated_cost,
            input,
            ..
        }
        | PhysicalPlan::Aggregate {
            estimated_cost,
            input,
            ..
        } => *estimated_cost + self.get_plan_cost(input),
        PhysicalPlan::NestedLoopJoin {
            estimated_cost,
            outer: first,
            inner: second,
            ..
        }
        | PhysicalPlan::HashJoin {
            estimated_cost,
            build: first,
            probe: second,
            ..
        }
        | PhysicalPlan::MergeJoin {
            estimated_cost,
            left: first,
            right: second,
            ..
        } => *estimated_cost + self.get_plan_cost(first) + self.get_plan_cost(second),
    }
}
/// Renders `plan` as a human-readable, indented tree, starting at the root with no indent.
pub fn explain(&self, plan: &PhysicalPlan) -> String {
    let root_depth = 0;
    self.explain_impl(plan, root_depth)
}
/// Renders one node and, recursively, its children.
///
/// Each node is prefixed with `indent` spaces and each child is one space deeper. The `cost`
/// shown is the node's cumulative cost from `get_plan_cost`.
fn explain_impl(&self, plan: &PhysicalPlan, indent: usize) -> String {
    let pad = " ".repeat(indent);
    let subtree_cost = self.get_plan_cost(plan);
    let child_indent = indent + 1;
    match plan {
        PhysicalPlan::TableScan {
            table,
            columns,
            estimated_rows,
            ..
        } => format!(
            "{}TableScan [table={}, columns={:?}, rows={}, cost={:.2}ms]",
            pad, table, columns, estimated_rows, subtree_cost
        ),
        PhysicalPlan::IndexSeek {
            table,
            index,
            columns,
            estimated_rows,
            ..
        } => format!(
            "{}IndexSeek [table={}, index={}, columns={:?}, rows={}, cost={:.2}ms]",
            pad, table, index, columns, estimated_rows, subtree_cost
        ),
        PhysicalPlan::Filter {
            input,
            estimated_rows,
            ..
        } => format!(
            "{}Filter [rows={}, cost={:.2}ms]\n{}",
            pad,
            estimated_rows,
            subtree_cost,
            self.explain_impl(input, child_indent)
        ),
        PhysicalPlan::Project { input, columns, .. } => format!(
            "{}Project [columns={:?}, cost={:.2}ms]\n{}",
            pad,
            columns,
            subtree_cost,
            self.explain_impl(input, child_indent)
        ),
        PhysicalPlan::Sort {
            input, order_by, ..
        } => {
            let keys: Vec<String> = order_by
                .iter()
                .map(|(col, dir)| format!("{} {:?}", col, dir))
                .collect();
            format!(
                "{}Sort [order={:?}, cost={:.2}ms]\n{}",
                pad,
                keys,
                subtree_cost,
                self.explain_impl(input, child_indent)
            )
        }
        PhysicalPlan::Limit {
            input,
            limit,
            offset,
            ..
        } => format!(
            "{}Limit [limit={}, offset={}, cost={:.2}ms]\n{}",
            pad,
            limit,
            offset,
            subtree_cost,
            self.explain_impl(input, child_indent)
        ),
        PhysicalPlan::HashJoin {
            build,
            probe,
            join_type,
            estimated_rows,
            ..
        } => format!(
            "{}HashJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
            pad,
            join_type,
            estimated_rows,
            subtree_cost,
            self.explain_impl(build, child_indent),
            self.explain_impl(probe, child_indent)
        ),
        PhysicalPlan::MergeJoin {
            left,
            right,
            join_type,
            estimated_rows,
            ..
        } => format!(
            "{}MergeJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
            pad,
            join_type,
            estimated_rows,
            subtree_cost,
            self.explain_impl(left, child_indent),
            self.explain_impl(right, child_indent)
        ),
        PhysicalPlan::NestedLoopJoin {
            outer,
            inner,
            join_type,
            estimated_rows,
            ..
        } => format!(
            "{}NestedLoopJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
            pad,
            join_type,
            estimated_rows,
            subtree_cost,
            self.explain_impl(outer, child_indent),
            self.explain_impl(inner, child_indent)
        ),
        PhysicalPlan::Aggregate {
            input,
            group_by,
            aggregates,
            estimated_rows,
            ..
        } => {
            let rendered: Vec<String> = aggregates
                .iter()
                .map(|agg| format!("{:?}({})", agg.function, agg.column.as_deref().unwrap_or("*")))
                .collect();
            format!(
                "{}Aggregate [group_by={:?}, aggs={:?}, rows={}, cost={:.2}ms]\n{}",
                pad,
                group_by,
                rendered,
                estimated_rows,
                subtree_cost,
                self.explain_impl(input, child_indent)
            )
        }
    }
}
}
impl CostBasedOptimizer {
/// Drops cached plans whose age has reached `plan_cache_ttl_us`.
pub fn evict_stale_plans(&self) {
    let ttl = self.plan_cache_ttl_us;
    let now = Self::now_us();
    let mut cache = self.plan_cache.write();
    cache.retain(|_, entry| {
        let age = now.saturating_sub(entry.1);
        age < ttl
    });
}
/// Removes every cached plan.
pub fn invalidate_plan_cache(&self) {
    let mut cache = self.plan_cache.write();
    cache.clear();
}
pub fn collect_stats(
&self,
table_name: &str,
row_count: u64,
size_bytes: u64,
column_values: HashMap<String, Vec<String>>,
indices: Vec<IndexStats>,
) {
let mut column_stats = HashMap::new();
for (col_name, values) in &column_values {
let distinct: HashSet<&String> = values.iter().collect();
let null_count = values.iter().filter(|v| v.is_empty()).count() as u64;
let avg_length = if values.is_empty() {
0.0
} else {
values.iter().map(|v| v.len()).sum::<usize>() as f64 / values.len() as f64
};
let is_numeric = values.iter().take(10).all(|v| v.parse::<f64>().is_ok());
let histogram = if is_numeric && values.len() >= 10 {
let mut nums: Vec<f64> = values.iter().filter_map(|v| v.parse().ok()).collect();
nums.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let bucket_count = 10.min(nums.len());
let bucket_size = nums.len() / bucket_count;
let mut boundaries = Vec::new();
let mut counts = Vec::new();
for i in 0..bucket_count {
let end = if i == bucket_count - 1 {
nums.len()
} else {
(i + 1) * bucket_size
};
let start = i * bucket_size;
boundaries.push(nums[end - 1]);
counts.push((end - start) as u64);
}
Some(Histogram {
boundaries,
counts,
total_rows: nums.len() as u64,
})
} else {
None
};
let mut freq_map: HashMap<&String, usize> = HashMap::new();
for v in values {
*freq_map.entry(v).or_insert(0) += 1;
}
let total = values.len() as f64;
let mut mcv: Vec<(String, f64)> = freq_map
.iter()
.map(|(k, &v)| ((*k).clone(), v as f64 / total))
.collect();
mcv.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
mcv.truncate(5);
column_stats.insert(
col_name.clone(),
ColumnStats {
name: col_name.clone(),
distinct_count: distinct.len() as u64,
null_count,
min_value: values.iter().min().cloned(),
max_value: values.iter().max().cloned(),
avg_length,
mcv,
histogram,
},
);
}
self.update_stats(TableStats {
name: table_name.to_string(),
row_count,
size_bytes,
column_stats,
indices,
last_updated: Self::now_us(),
});
self.invalidate_plan_cache();
}
/// Microseconds elapsed since `table`'s stats were recorded, or `None` if there are none.
/// Returns 0 if `last_updated` is in the future.
pub fn stats_age_us(&self, table: &str) -> Option<u64> {
    let cache = self.stats_cache.read();
    let recorded_at = cache.get(table)?.last_updated;
    Some(Self::now_us().saturating_sub(recorded_at))
}
/// Current wall-clock time in microseconds since the Unix epoch (0 if the clock reads earlier
/// than the epoch).
fn now_us() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_micros() as u64
}
}
/// Dynamic-programming join-order search over a set of tables, driven by per-table row counts
/// registered with `add_stats`.
pub struct JoinOrderOptimizer {
/// Table statistics keyed by table name.
stats: HashMap<String, TableStats>,
/// Cost constants; the join cost estimate reads only `c_filter`.
config: CostModelConfig,
}
impl JoinOrderOptimizer {
/// Creates a join-order optimizer with no table statistics.
pub fn new(config: CostModelConfig) -> Self {
    Self {
        config,
        stats: HashMap::new(),
    }
}
/// Registers (or replaces) the statistics for `stats.name`.
pub fn add_stats(&mut self, stats: TableStats) {
    let key = stats.name.clone();
    self.stats.insert(key, stats);
}
pub fn find_optimal_order(
&self,
tables: &[String],
join_conditions: &[(String, String, String, String)], ) -> Vec<(String, String)> {
let n = tables.len();
if n <= 1 {
return vec![];
}
let mut dp: HashMap<u32, (f64, Vec<(String, String)>)> = HashMap::new();
for (i, _table) in tables.iter().enumerate() {
let mask = 1u32 << i;
dp.insert(mask, (0.0, vec![]));
}
for size in 2..=n {
for mask in 0..(1u32 << n) {
if mask.count_ones() != size as u32 {
continue;
}
let mut best_cost = f64::MAX;
let mut best_order = vec![];
for sub in 1..mask {
if sub & mask != sub || sub == 0 {
continue;
}
let other = mask ^ sub;
if other == 0 {
continue;
}
if !self.has_join_condition(tables, sub, other, join_conditions) {
continue;
}
if let (Some((cost1, order1)), Some((cost2, order2))) =
(dp.get(&sub), dp.get(&other))
{
let join_cost = self.estimate_join_cost(tables, sub, other);
let total_cost = cost1 + cost2 + join_cost;
if total_cost < best_cost {
best_cost = total_cost;
best_order = order1.clone();
best_order.extend(order2.clone());
let (t1, t2) =
self.get_join_tables(tables, sub, other, join_conditions);
if let Some((t1, t2)) = Some((t1, t2)) {
best_order.push((t1, t2));
}
}
}
}
if best_cost < f64::MAX {
dp.insert(mask, (best_cost, best_order));
}
}
}
let full_mask = (1u32 << n) - 1;
dp.get(&full_mask)
.map(|(_, order)| order.clone())
.unwrap_or_default()
}
/// True when some condition joins a table in `mask1` (as its left table) to a table in `mask2`
/// (as its right table). Conditions naming tables outside `tables` are ignored.
fn has_join_condition(
    &self,
    tables: &[String],
    mask1: u32,
    mask2: u32,
    conditions: &[(String, String, String, String)],
) -> bool {
    let in_mask = |name: &String, mask: u32| {
        tables
            .iter()
            .position(|t| t == name)
            .is_some_and(|i| (mask >> i) & 1 == 1)
    };
    conditions
        .iter()
        .any(|(left, _, right, _)| in_mask(left, mask1) && in_mask(right, mask2))
}
/// Returns the `(left_table, right_table)` of the first condition with its left table in
/// `mask1` and its right table in `mask2`, or a pair of empty strings if there is none.
fn get_join_tables(
    &self,
    tables: &[String],
    mask1: u32,
    mask2: u32,
    conditions: &[(String, String, String, String)],
) -> (String, String) {
    let in_mask = |name: &String, mask: u32| {
        tables
            .iter()
            .position(|t| t == name)
            .is_some_and(|i| (mask >> i) & 1 == 1)
    };
    conditions
        .iter()
        .find(|(left, _, right, _)| in_mask(left, mask1) && in_mask(right, mask2))
        .map(|(left, _, right, _)| (left.clone(), right.clone()))
        .unwrap_or_default()
}
/// Hash-join-style cost: every estimated row on each side is charged `c_filter` once.
fn estimate_join_cost(&self, tables: &[String], mask1: u32, mask2: u32) -> f64 {
    let per_row = self.config.c_filter;
    let left_rows = self.estimate_rows_for_mask(tables, mask1) as f64;
    let right_rows = self.estimate_rows_for_mask(tables, mask2) as f64;
    left_rows * per_row + right_rows * per_row
}
/// Estimated rows from joining the tables in `mask`.
///
/// The estimate is the product of the row counts (1000 for tables without stats, saturating
/// on overflow), scaled by 0.1 for every join beyond the first table, and at least 1.
fn estimate_rows_for_mask(&self, tables: &[String], mask: u32) -> u64 {
    let product = tables
        .iter()
        .enumerate()
        .filter(|&(i, _)| (mask >> i) & 1 == 1)
        .map(|(_, table)| self.stats.get(table).map_or(1000, |s| s.row_count))
        .fold(1u64, u64::saturating_mul);
    let joined = mask.count_ones();
    let estimate = if joined > 1 {
        (product as f64 * 0.1f64.powi(joined as i32 - 1)) as u64
    } else {
        product
    };
    estimate.max(1)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 100k-row `users` table with a primary key on `id` and a secondary index on `score`.
    /// `score` has a 4-bucket histogram (bounds 25/50/75/100) and 1000 NULLs.
    fn create_test_stats() -> TableStats {
        let mut column_stats = HashMap::new();
        column_stats.insert(
            "id".to_string(),
            ColumnStats {
                name: "id".to_string(),
                distinct_count: 100000,
                null_count: 0,
                min_value: Some("1".to_string()),
                max_value: Some("100000".to_string()),
                avg_length: 8.0,
                mcv: vec![],
                histogram: None,
            },
        );
        column_stats.insert(
            "score".to_string(),
            ColumnStats {
                name: "score".to_string(),
                distinct_count: 100,
                null_count: 1000,
                min_value: Some("0".to_string()),
                max_value: Some("100".to_string()),
                avg_length: 8.0,
                mcv: vec![("50".to_string(), 0.05)],
                histogram: Some(Histogram {
                    boundaries: vec![25.0, 50.0, 75.0, 100.0],
                    counts: vec![25000, 25000, 25000, 25000],
                    total_rows: 100000,
                }),
            },
        );
        TableStats {
            name: "users".to_string(),
            row_count: 100000,
            size_bytes: 10_000_000,
            column_stats,
            indices: vec![
                IndexStats {
                    name: "pk_users".to_string(),
                    columns: vec!["id".to_string()],
                    is_primary: true,
                    is_unique: true,
                    index_type: IndexType::BTree,
                    leaf_pages: 1000,
                    height: 3,
                    avg_leaf_density: 100.0,
                },
                IndexStats {
                    name: "idx_score".to_string(),
                    columns: vec!["score".to_string()],
                    is_primary: false,
                    is_unique: false,
                    index_type: IndexType::BTree,
                    leaf_pages: 500,
                    height: 2,
                    avg_leaf_density: 200.0,
                },
            ],
            last_updated: 0,
        }
    }

    #[test]
    fn test_selectivity_estimation() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);
        let stats = create_test_stats();
        optimizer.update_stats(stats.clone());
        let pred = Predicate::Eq {
            column: "id".to_string(),
            value: "12345".to_string(),
        };
        let sel = optimizer.estimate_selectivity(&pred, &stats);
        assert!(sel < 0.001);
        let pred = Predicate::Gt {
            column: "score".to_string(),
            value: "75".to_string(),
        };
        let sel = optimizer.estimate_selectivity(&pred, &stats);
        // Bucket-granular estimate: the (50, 75] bucket touches the bound and counts fully.
        assert!(sel > 0.4 && sel < 0.6);
    }

    #[test]
    fn test_access_path_selection() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);
        let stats = create_test_stats();
        optimizer.update_stats(stats);
        let pred = Predicate::Eq {
            column: "id".to_string(),
            value: "12345".to_string(),
        };
        let plan = optimizer.optimize(
            "users",
            vec!["id".to_string(), "score".to_string()],
            Some(pred),
            vec![],
            None,
        );
        match plan {
            PhysicalPlan::IndexSeek { index, .. } => {
                assert_eq!(index, "pk_users");
            }
            _ => panic!("Expected IndexSeek for equality on primary key"),
        }
    }

    #[test]
    fn test_token_budget_limit() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config).with_token_budget(2048, 25.0);
        let plan = optimizer.optimize("users", vec!["id".to_string()], None, vec![], None);
        match plan {
            PhysicalPlan::Limit { limit, .. } => {
                assert!(limit <= 80);
            }
            _ => panic!("Expected Limit to be injected"),
        }
    }

    #[test]
    fn test_explain_output() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);
        let stats = create_test_stats();
        optimizer.update_stats(stats);
        let plan = optimizer.optimize(
            "users",
            vec!["id".to_string(), "score".to_string()],
            Some(Predicate::Gt {
                column: "score".to_string(),
                value: "80".to_string(),
            }),
            vec![("score".to_string(), SortDirection::Descending)],
            Some(10),
        );
        let explain = optimizer.explain(&plan);
        assert!(explain.contains("Limit"));
        assert!(explain.contains("Sort"));
    }

    #[test]
    fn test_token_budget_underflow_safety() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config).with_token_budget(10, 25.0);
        let plan = optimizer.optimize("users", vec!["id".to_string()], None, vec![], None);
        match plan {
            PhysicalPlan::Limit { limit, .. } => {
                assert!(limit >= 1, "Must return at least 1 row");
            }
            _ => panic!("Expected Limit"),
        }
    }

    #[test]
    fn test_index_seek_derives_key_range() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);
        optimizer.update_stats(create_test_stats());
        let plan = optimizer.optimize(
            "users",
            vec!["id".to_string()],
            Some(Predicate::Eq {
                column: "id".to_string(),
                value: "42".to_string(),
            }),
            vec![],
            None,
        );
        match plan {
            PhysicalPlan::IndexSeek { key_range, .. } => {
                assert!(
                    key_range.start.is_some(),
                    "KeyRange must derive from Eq predicate"
                );
                assert_eq!(
                    key_range.start, key_range.end,
                    "Eq predicate → point key range"
                );
            }
            _ => panic!("Expected IndexSeek"),
        }
    }

    #[test]
    fn test_range_predicate_key_range() {
        // Check the key-range derivation directly. The planner usually prefers a scan for this
        // wide range, so a plan-only check would silently pass without asserting anything.
        let between = Predicate::Between {
            column: "score".to_string(),
            min: "10".to_string(),
            max: "90".to_string(),
        };
        let range = CostBasedOptimizer::derive_key_range(&between);
        assert_eq!(range.start.as_deref(), Some("10".as_bytes()));
        assert_eq!(range.end.as_deref(), Some("90".as_bytes()));
        assert!(range.start_inclusive);
        assert!(range.end_inclusive);

        let gt = CostBasedOptimizer::derive_key_range(&Predicate::Gt {
            column: "score".to_string(),
            value: "10".to_string(),
        });
        assert_eq!(gt.start.as_deref(), Some("10".as_bytes()));
        assert!(gt.end.is_none());
        assert!(!gt.start_inclusive, "Gt must produce an exclusive lower bound");

        // If the planner does pick a seek, its range must be the bounded one.
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);
        optimizer.update_stats(create_test_stats());
        let plan = optimizer.optimize("users", vec!["score".to_string()], Some(between), vec![], None);
        if let PhysicalPlan::IndexSeek { key_range, .. } = plan {
            assert!(key_range.start.is_some());
            assert!(key_range.end.is_some());
            assert!(key_range.start_inclusive);
            assert!(key_range.end_inclusive);
        }
    }

    #[test]
    fn test_projection_pushdown_proportional_reduction() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);
        optimizer.update_stats(create_test_stats());
        let plan_all = optimizer.optimize(
            "users",
            vec!["id".to_string(), "score".to_string()],
            None,
            vec![],
            Some(100),
        );
        let plan_single =
            optimizer.optimize("users", vec!["id".to_string()], None, vec![], Some(100));
        let cost_all = optimizer.get_plan_cost(&plan_all);
        let cost_single = optimizer.get_plan_cost(&plan_single);
        assert!(
            cost_single <= cost_all,
            "Projection should reduce cost: {} vs {}",
            cost_single,
            cost_all
        );
    }

    #[test]
    fn test_collect_stats_builds_histogram() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);
        let mut column_values = HashMap::new();
        let scores: Vec<String> = (0..100).map(|i| i.to_string()).collect();
        column_values.insert("score".to_string(), scores);
        optimizer.collect_stats("test_table", 100, 10000, column_values, vec![]);
        let stats = optimizer.get_stats("test_table").unwrap();
        assert_eq!(stats.row_count, 100);
        let score_stats = stats.column_stats.get("score").unwrap();
        assert_eq!(score_stats.distinct_count, 100);
        assert!(
            score_stats.histogram.is_some(),
            "Numeric column should get histogram"
        );
        assert!(!score_stats.mcv.is_empty(), "Should build MCV list");
    }

    #[test]
    fn test_plan_cache_invalidation() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);
        // Seed the cache so the invalidation is actually observable.
        let plan = optimizer.optimize("t", vec!["x".to_string()], None, vec![], None);
        optimizer.plan_cache.write().insert(1, (plan, 0));
        assert!(!optimizer.plan_cache.read().is_empty());
        let mut col = HashMap::new();
        col.insert("x".to_string(), vec!["1".to_string()]);
        optimizer.collect_stats("t", 1, 100, col, vec![]);
        assert!(optimizer.plan_cache.read().is_empty());
    }

    #[test]
    fn test_stats_age_tracking() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);
        assert!(optimizer.stats_age_us("unknown").is_none());
        let mut col = HashMap::new();
        col.insert("x".to_string(), vec!["1".to_string()]);
        optimizer.collect_stats("t", 1, 100, col, vec![]);
        let age = optimizer.stats_age_us("t").unwrap();
        assert!(age < 1_000_000, "Stats should be fresh (< 1 second old)");
    }

    #[test]
    fn test_scan_cost_reads_all_blocks() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config.clone());
        let no_pred = optimizer.estimate_scan_cost(1000, 4096 * 10, None);
        let with_pred = optimizer.estimate_scan_cost(
            1000,
            4096 * 10,
            Some(&Predicate::Eq {
                column: "x".to_string(),
                value: "1".to_string(),
            }),
        );
        assert!(
            (no_pred - with_pred).abs() < 0.001,
            "Scan cost should not depend on predicate: {} vs {}",
            no_pred,
            with_pred
        );
    }

    #[test]
    fn test_index_wins_over_scan_for_point_lookup() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);
        optimizer.update_stats(create_test_stats());
        let scan_cost = optimizer.estimate_scan_cost(100000, 10_000_000, None);
        let stats = create_test_stats();
        let pk_index = &stats.indices[0];
        let index_cost = optimizer.estimate_index_cost(pk_index, 100000, 0.00001);
        assert!(
            index_cost < scan_cost * 0.1,
            "Index point lookup ({:.2}) should be <10% of scan cost ({:.2})",
            index_cost,
            scan_cost
        );
    }

    #[test]
    fn test_no_stats_defaults_to_scan() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);
        let plan = optimizer.optimize(
            "unknown_table",
            vec!["col1".to_string()],
            Some(Predicate::Eq {
                column: "col1".to_string(),
                value: "x".to_string(),
            }),
            vec![],
            None,
        );
        // Without stats there are no known indexes, so a seek is impossible.
        match plan {
            PhysicalPlan::TableScan { estimated_rows, .. } => {
                assert!(estimated_rows > 0, "Default row estimate must be positive");
            }
            other => panic!("Expected TableScan for unknown table, got {:?}", other),
        }
    }

    #[test]
    fn test_compound_predicate_selectivity() {
        let stats = create_test_stats();
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);
        let and_pred = Predicate::And(
            Box::new(Predicate::Eq {
                column: "id".to_string(),
                value: "1".to_string(),
            }),
            Box::new(Predicate::IsNotNull {
                column: "score".to_string(),
            }),
        );
        let sel = optimizer.estimate_selectivity(&and_pred, &stats);
        let eq_sel = optimizer.estimate_selectivity(
            &Predicate::Eq {
                column: "id".to_string(),
                value: "1".to_string(),
            },
            &stats,
        );
        assert!(sel < eq_sel, "AND must be more selective than either child");
        let or_pred = Predicate::Or(
            Box::new(Predicate::Eq {
                column: "id".to_string(),
                value: "1".to_string(),
            }),
            Box::new(Predicate::Eq {
                column: "id".to_string(),
                value: "2".to_string(),
            }),
        );
        let sel = optimizer.estimate_selectivity(&or_pred, &stats);
        assert!(sel > eq_sel, "OR must be less selective than either child");
        assert!(sel <= 1.0, "Selectivity must be <= 1.0");
    }

    #[test]
    fn test_join_order_optimizer() {
        let mut join_opt = JoinOrderOptimizer::new(CostModelConfig::default());
        join_opt.add_stats(TableStats {
            name: "orders".to_string(),
            row_count: 1000000,
            size_bytes: 100_000_000,
            column_stats: HashMap::new(),
            indices: vec![],
            last_updated: 0,
        });
        join_opt.add_stats(TableStats {
            name: "users".to_string(),
            row_count: 10000,
            size_bytes: 1_000_000,
            column_stats: HashMap::new(),
            indices: vec![],
            last_updated: 0,
        });
        let order = join_opt.find_optimal_order(
            &["orders".to_string(), "users".to_string()],
            &[(
                "orders".to_string(),
                "user_id".to_string(),
                "users".to_string(),
                "id".to_string(),
            )],
        );
        assert!(!order.is_empty(), "Should find a join order");
    }
}