use std::sync::{OnceLock, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};
use crate::features::query::TypedAst;
use crate::features::storage::api as storage_api;
use super::types::{
ExecutableBitmapPredicate, ExecutablePredicate, ExecutableScalarPredicate, ExecuteParams,
ExplainError, ExplainPlan, PhysicalOp, PlannerStatsSnapshot, Result,
};
/// Schema version stamped onto snapshots produced by `collect_planner_stats`
/// and the only version `set_planner_stats` accepts.
const PLANNER_STATS_SCHEMA_VERSION: u32 = 1;
/// Global slot for the currently installed planner statistics. The lock is
/// created on first access through `planner_stats_cell`; `None` means no
/// snapshot is installed.
static PLANNER_STATS: OnceLock<RwLock<Option<PlannerStatsSnapshot>>> = OnceLock::new();
/// Returns the global planner-stats lock, creating it (holding `None`) on
/// first use.
fn planner_stats_cell() -> &'static RwLock<Option<PlannerStatsSnapshot>> {
    // `RwLock::default()` wraps `Option::default()`, i.e. an empty slot.
    PLANNER_STATS.get_or_init(RwLock::default)
}
/// Derives a fresh [`PlannerStatsSnapshot`] from the storage handle and the
/// scan window described by `params`.
///
/// The snapshot is stamped with the current wall-clock time (used both as
/// `stats_version` and `collected_at_millis`), a 60 second TTL and a fixed
/// filter cost of 10. It is only returned, not installed; callers pass it to
/// `set_planner_stats` if they want the planner to use it.
pub fn collect_planner_stats(
    handle: &storage_api::StorageHandle,
    params: &ExecuteParams,
) -> PlannerStatsSnapshot {
    let scan_range = params.scan_end_exclusive.saturating_sub(params.scan_start);
    let l0_run_count = handle.l0_runs.len() as u64;
    let cached_sstable_count = handle.sstable_cache.len() as u64;
    let sstable_count = l0_run_count + cached_sstable_count;

    // Node scan: one unit per started block of 1000 ids, multiplied by the
    // number of tables (+1) and a constant factor of 10.
    let scan_blocks = scan_range / 1_000 + 1;
    let node_scan_base_cost = scan_blocks
        .saturating_mul(sstable_count + 1)
        .saturating_mul(10);

    // Vector scan: logarithmic in the indexed vector count; with no indexed
    // vectors it is priced at three times a node scan.
    let total_vectors = handle.hnsw_total_vectors;
    let vector_scan_base_cost = if total_vectors > 0 {
        (total_vectors as f64).log2() as u64 * 15 + 20
    } else {
        node_scan_base_cost.saturating_mul(3)
    };

    // Share of the scan window covered by indexed vectors, in parts per
    // million, capped at 1_000_000.
    let covered_vectors = total_vectors.min(scan_range);
    let vector_ratio_ppm = covered_vectors.saturating_mul(1_000_000) / scan_range.max(1);
    let vector_selectivity_ppm = vector_ratio_ppm.min(1_000_000) as u32;

    // 100_000 ppm per table, capped at 800_000 ppm.
    let graph_selectivity_ppm = sstable_count.saturating_mul(100_000).min(800_000) as u32;

    // Skew penalty: twice the gap between the largest per-node pending delta
    // count and the (integer) average. An empty map has no maximum and
    // contributes no penalty.
    let pending = &handle.pending_deltas_per_node;
    let skew_penalty_cost = match pending.values().copied().max() {
        None => 0,
        Some(peak) => {
            let max_pending = peak as u64;
            let total_pending = pending.values().map(|&v| v as u64).sum::<u64>();
            let avg_pending = total_pending / pending.len() as u64;
            max_pending.saturating_sub(avg_pending).saturating_mul(2)
        }
    };

    let collected_at = now_millis();
    PlannerStatsSnapshot {
        schema_version: PLANNER_STATS_SCHEMA_VERSION,
        stats_version: collected_at,
        collected_at_millis: collected_at,
        ttl_millis: 60_000,
        node_scan_base_cost,
        vector_scan_base_cost,
        filter_base_cost: 10,
        vector_selectivity_ppm,
        graph_selectivity_ppm,
        skew_penalty_cost,
    }
}
/// Validates `stats` and installs it as the global planner snapshot,
/// replacing any previous one.
///
/// # Errors
///
/// Returns [`ExplainError::InvalidPlan`] when, checked in this order:
/// - the schema version differs from `PLANNER_STATS_SCHEMA_VERSION`,
/// - either selectivity exceeds 1_000_000 ppm,
/// - `ttl_millis` is zero,
/// - the global lock is poisoned.
pub fn set_planner_stats(stats: PlannerStatsSnapshot) -> Result<()> {
    if stats.schema_version != PLANNER_STATS_SCHEMA_VERSION {
        let message = format!(
            "unsupported planner stats schema version {} (expected {})",
            stats.schema_version, PLANNER_STATS_SCHEMA_VERSION
        );
        return Err(ExplainError::InvalidPlan(message));
    }

    // Checking the larger of the two is the same as checking each one.
    let largest_ppm = stats
        .vector_selectivity_ppm
        .max(stats.graph_selectivity_ppm);
    if largest_ppm > 1_000_000 {
        return Err(ExplainError::InvalidPlan(
            "planner selectivity ppm must be <= 1000000".to_string(),
        ));
    }

    if stats.ttl_millis == 0 {
        return Err(ExplainError::InvalidPlan(
            "planner stats ttl_millis must be > 0".to_string(),
        ));
    }

    match planner_stats_cell().write() {
        Ok(mut slot) => {
            *slot = Some(stats);
            Ok(())
        }
        Err(_) => Err(ExplainError::InvalidPlan(
            "planner stats lock poisoned".to_string(),
        )),
    }
}
/// Removes the installed planner snapshot, if any.
///
/// Best effort: when the lock is poisoned the call does nothing.
pub fn clear_planner_stats() {
    let Ok(mut slot) = planner_stats_cell().write() else {
        return;
    };
    *slot = None;
}
/// Returns a clone of the installed planner snapshot.
///
/// Yields `None` when nothing is installed or when the lock is poisoned.
/// Freshness is not checked here.
pub fn planner_stats_snapshot() -> Option<PlannerStatsSnapshot> {
    match planner_stats_cell().read() {
        Ok(slot) => (*slot).clone(),
        Err(_) => None,
    }
}
pub fn explain(typed: &TypedAst) -> Result<ExplainPlan> {
let ast = &typed.ast;
if ast.match_alias.is_empty() {
return Err(ExplainError::InvalidPlan(
"MATCH alias must not be empty".to_string(),
));
}
let has_scalar_predicate = typed.scalar_predicate.is_some();
let has_vector_predicate = ast.where_predicate.is_some();
let has_bitmap_predicate = ast.where_bitmap_predicate.is_some();
let stats = planner_stats_snapshot().filter(is_stats_fresh);
let (bitmap_first, vector_first, planner_mode, stats_version) =
choose_strategy(has_vector_predicate, has_bitmap_predicate, stats.as_ref());
let strategy = if bitmap_first {
"bitmap-first"
} else if vector_first {
"vector-first"
} else {
"graph-first"
}
.to_string();
let match_projection = ast.match_aliases.join(", ");
let mut logical_ops = vec![format!("Match({})", match_projection)];
if let Some(pred) = &ast.where_bitmap_predicate {
logical_ops.push(format!(
"BitmapFilter({} = {})",
pred.index_name, pred.value_key
));
}
if let Some(pred) = &ast.where_predicate {
logical_ops.push(format!(
"Filter({} {} {})",
pred.function, pred.operator, pred.threshold
));
}
if let Some(pred) = &typed.scalar_predicate {
logical_ops.push(format!(
"ScalarFilter({} {} {})",
pred.field, pred.operator, pred.value
));
}
if !ast.with_items.is_empty() {
let with_items = ast
.with_items
.iter()
.map(format_projection)
.collect::<Vec<_>>()
.join(", ");
logical_ops.push(format!("With({})", with_items));
}
let return_items = ast
.return_items
.iter()
.map(format_projection)
.collect::<Vec<_>>()
.join(", ");
logical_ops.push(format!("Return({})", return_items));
if let Some(limit) = ast.limit {
logical_ops.push(format!("Limit({})", limit));
}
let mut physical_ops = vec![if bitmap_first {
PhysicalOp::BitmapScan
} else if vector_first {
PhysicalOp::VectorScan
} else {
PhysicalOp::NodeScan
}];
if ast.where_predicate.is_some() || ast.where_bitmap_predicate.is_some() || has_scalar_predicate
{
physical_ops.push(PhysicalOp::Filter);
}
physical_ops.push(PhysicalOp::Project);
if ast.limit.is_some() {
physical_ops.push(PhysicalOp::Limit);
}
let mut estimated_cost: u64 = if let Some(snapshot) = stats.as_ref() {
estimate_cost(
snapshot,
bitmap_first,
vector_first,
has_vector_predicate,
has_bitmap_predicate,
)
} else {
let mut base: u64 = if bitmap_first {
70
} else if vector_first {
90
} else {
140
};
if has_vector_predicate {
base += 20;
}
if has_bitmap_predicate {
base = base.saturating_sub(20);
}
base
};
if let Some(limit) = ast.limit {
estimated_cost = estimated_cost.saturating_sub(limit.min(100) / 5);
}
let predicate = if let Some(pred) = &ast.where_predicate {
let threshold = pred
.threshold
.parse::<f64>()
.map_err(|_| ExplainError::InvalidPlan("invalid predicate threshold".to_string()))?;
let metric = storage_api::VectorMetric::from_function(&pred.function).ok_or_else(|| {
ExplainError::InvalidPlan(format!(
"unsupported vector predicate function '{}'",
pred.function
))
})?;
Some(ExecutablePredicate {
metric,
param: pred.param.clone(),
inline_vector: storage_api::parse_inline_vector_param(&pred.param),
operator: pred.operator.clone(),
threshold,
})
} else {
None
};
let bitmap_predicate =
ast.where_bitmap_predicate
.as_ref()
.map(|pred| ExecutableBitmapPredicate {
index_name: pred.index_name.clone(),
value_key: pred.value_key.clone(),
});
let scalar_predicate = typed
.scalar_predicate
.as_ref()
.map(|pred| ExecutableScalarPredicate {
field: pred.field.clone(),
operator: pred.operator.clone(),
value: pred.value,
});
Ok(ExplainPlan {
strategy,
planner_mode,
stats_version,
logical_ops,
physical_ops,
estimated_cost,
limit: ast.limit,
predicate,
bitmap_predicate,
scalar_predicate,
})
}
/// Picks the scan strategy for a query.
///
/// Returns `(bitmap_first, vector_first, planner_mode, stats_version)`:
/// - a bitmap predicate always selects bitmap-first (`"heuristic"`),
/// - no vector predicate selects neither flag (`"heuristic"`),
/// - a vector predicate with stats compares weighted vector and graph costs
///   (`"cbo"`); ties go to vector-first,
/// - a vector predicate without stats selects vector-first (`"heuristic"`).
///
/// `stats_version` is the snapshot's version whenever a snapshot was
/// supplied, even on the heuristic paths.
fn choose_strategy(
    has_vector_predicate: bool,
    has_bitmap_predicate: bool,
    stats: Option<&PlannerStatsSnapshot>,
) -> (bool, bool, String, Option<u64>) {
    let snapshot_version = stats.map(|s| s.stats_version);

    match (has_bitmap_predicate, has_vector_predicate, stats) {
        (true, _, _) => (true, false, "heuristic".to_string(), snapshot_version),
        (false, false, _) => (false, false, "heuristic".to_string(), snapshot_version),
        (false, true, None) => (false, true, "heuristic".to_string(), None),
        (false, true, Some(snapshot)) => {
            let filter_cost = snapshot.filter_base_cost;
            let vector_cost = weighted_cost(
                snapshot.vector_scan_base_cost,
                snapshot.vector_selectivity_ppm,
            )
            .saturating_add(filter_cost);
            // Only the graph path pays the skew penalty.
            let graph_cost =
                weighted_cost(snapshot.node_scan_base_cost, snapshot.graph_selectivity_ppm)
                    .saturating_add(filter_cost)
                    .saturating_add(snapshot.skew_penalty_cost);
            (
                false,
                vector_cost <= graph_cost,
                "cbo".to_string(),
                Some(snapshot.stats_version),
            )
        }
    }
}
/// Estimates plan cost from a stats snapshot for the already chosen strategy.
///
/// The scan cost depends on the strategy (bitmap-first uses the node scan
/// cost at a fixed 100_000 ppm; graph-first additionally pays the skew
/// penalty). A vector predicate then adds `filter_base_cost`, and a bitmap
/// predicate subtracts half of it. All arithmetic saturates.
fn estimate_cost(
    stats: &PlannerStatsSnapshot,
    bitmap_first: bool,
    vector_first: bool,
    has_vector_predicate: bool,
    has_bitmap_predicate: bool,
) -> u64 {
    let scan_cost = match (bitmap_first, vector_first) {
        (true, _) => weighted_cost(stats.node_scan_base_cost, 100_000),
        (false, true) => {
            weighted_cost(stats.vector_scan_base_cost, stats.vector_selectivity_ppm)
        }
        (false, false) => {
            weighted_cost(stats.node_scan_base_cost, stats.graph_selectivity_ppm)
                .saturating_add(stats.skew_penalty_cost)
        }
    };

    let vector_surcharge = if has_vector_predicate {
        stats.filter_base_cost
    } else {
        0
    };
    let bitmap_discount = if has_bitmap_predicate {
        stats.filter_base_cost / 2
    } else {
        0
    };

    // Surcharge first, then discount, matching the saturation order.
    scan_cost
        .saturating_add(vector_surcharge)
        .saturating_sub(bitmap_discount)
}
/// Scales `base` by `selectivity_ppm` parts per million.
///
/// A selectivity of 0 is treated as 1 ppm. The product is computed in 128
/// bits, divided by 1_000_000 (truncating), and clamped to `u64::MAX`.
fn weighted_cost(base: u64, selectivity_ppm: u32) -> u64 {
    const PPM_SCALE: u128 = 1_000_000;

    let effective_ppm = u128::from(selectivity_ppm.max(1));
    // u64 * u32 always fits in u128, so a plain multiply cannot overflow.
    let scaled = u128::from(base) * effective_ppm / PPM_SCALE;

    if scaled > u128::from(u64::MAX) {
        u64::MAX
    } else {
        scaled as u64
    }
}
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields 0 when the system clock reads earlier than the epoch and clamps to
/// `u64::MAX` if the millisecond count does not fit in 64 bits.
fn now_millis() -> u64 {
    let millis: u128 = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_millis());

    if millis > u128::from(u64::MAX) {
        u64::MAX
    } else {
        millis as u64
    }
}
/// Reports whether `stats` is still within its TTL.
///
/// A snapshot whose collection time lies ahead of the current clock counts
/// as having age zero and is therefore fresh.
fn is_stats_fresh(stats: &PlannerStatsSnapshot) -> bool {
    match now_millis().checked_sub(stats.collected_at_millis) {
        Some(age_millis) => age_millis <= stats.ttl_millis,
        // Collected "in the future": age clamps to zero, which never
        // exceeds the TTL.
        None => true,
    }
}
/// Renders a projection item for display in the logical operator list.
///
/// Identifiers are rendered verbatim; function calls render as
/// `name(argument)`, followed by ` AS alias` when an alias is present.
fn format_projection(item: &crate::features::query::ProjectionItem) -> String {
    match item {
        crate::features::query::ProjectionItem::Identifier(ident) => ident.clone(),
        crate::features::query::ProjectionItem::Function {
            name,
            argument,
            alias,
        } => match alias {
            Some(label) => format!("{}({}) AS {}", name, argument, label),
            None => format!("{}({})", name, argument),
        },
    }
}