use crate::plan::physical::PhysicalNode;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Estimated cost of executing a physical plan node, split by resource.
///
/// The components are plain `f64` values; this type does not fix their unit.
/// `CostEstimate::total_cost` folds the four resource components into a single
/// weighted figure, while `total_time` is tracked separately.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostEstimate {
    /// CPU component of the estimate.
    pub cpu_cost: f64,
    /// I/O component of the estimate.
    pub io_cost: f64,
    /// Memory component of the estimate.
    pub memory_cost: f64,
    /// Network component of the estimate.
    pub network_cost: f64,
    /// Estimated elapsed time; accumulated by `add` but not part of `total_cost`.
    pub total_time: f64,
}
/// Unit costs used to turn row counts into a `CostEstimate`.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// CPU cost charged for each row an operator handles.
    pub cpu_cost_per_row: f64,
    /// I/O cost charged for each page an operator touches.
    pub io_cost_per_page: f64,
    /// Memory cost charged for each byte an operator holds.
    pub memory_cost_per_byte: f64,
}
/// Graph-level statistics made available to the cost model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Statistics {
    /// Total number of nodes in the graph.
    pub total_nodes: usize,
    /// Total number of edges in the graph.
    pub total_edges: usize,
    /// Per-key node counts (presumably keyed by node label; nothing in this file fills it).
    pub node_counts: HashMap<String, usize>,
    /// Per-key edge counts (presumably keyed by edge label; nothing in this file fills it).
    pub edge_counts: HashMap<String, usize>,
    /// Average node degree; `update_from_graph` derives it as `2 * edges / nodes`.
    pub average_degree: f64,
    /// Largest node degree (not maintained by any code in this file).
    pub max_degree: usize,
    /// Selectivity per property name, read through `get_property_selectivity`.
    pub property_selectivity: HashMap<String, f64>,
    /// Indices known to the planner, queried through `has_index`.
    pub available_indices: Vec<IndexInfo>,
}
/// Description of a single index the planner may use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexInfo {
    /// Name identifying the index (e.g. `"node_labels"`).
    pub name: String,
    /// Kind of index structure.
    pub index_type: IndexType,
    /// Whether the index is over nodes or edges.
    pub entity_type: EntityType,
    /// Property names covered by the index; empty for the built-in label indices.
    pub properties: Vec<String>,
    /// Entry count recorded for the index; the built-in label indices store the
    /// graph's total node or edge count here.
    pub cardinality: usize,
}
/// Kind of index recorded in an `IndexInfo`.
///
/// Within this file only `Label` is ever constructed (for the built-in node and
/// edge label indices); the remaining variants are used elsewhere, if at all.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IndexType {
    Hash,
    BTree,
    Label,
    Property,
    Composite,
    TextInverted,
    TextBM25,
    TextNGram,
}
/// Kind of graph entity an index is built over.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EntityType {
    /// The index covers nodes.
    Node,
    /// The index covers edges.
    Edge,
}
impl CostEstimate {
pub fn new() -> Self {
Self {
cpu_cost: 0.0,
io_cost: 0.0,
memory_cost: 0.0,
network_cost: 0.0,
total_time: 0.0,
}
}
pub fn add(&mut self, other: &CostEstimate) {
self.cpu_cost += other.cpu_cost;
self.io_cost += other.io_cost;
self.memory_cost += other.memory_cost;
self.network_cost += other.network_cost;
self.total_time += other.total_time;
}
pub fn total_cost(&self) -> f64 {
let cpu_weight = 1.0;
let io_weight = 10.0; let memory_weight = 0.1;
let network_weight = 5.0;
self.cpu_cost * cpu_weight
+ self.io_cost * io_weight
+ self.memory_cost * memory_weight
+ self.network_cost * network_weight
}
}
impl Default for CostEstimate {
fn default() -> Self {
Self::new()
}
}
impl CostModel {
pub fn new() -> Self {
Self {
cpu_cost_per_row: 0.001, io_cost_per_page: 0.01, memory_cost_per_byte: 0.000001, }
}
pub fn estimate_node_cost(&self, node: &PhysicalNode, stats: &Statistics) -> CostEstimate {
match node {
PhysicalNode::NodeSeqScan {
labels,
estimated_rows,
..
} => self.estimate_scan_cost(*estimated_rows, labels, stats, true),
PhysicalNode::NodeIndexScan {
labels,
estimated_rows,
..
} => self.estimate_scan_cost(*estimated_rows, labels, stats, false),
PhysicalNode::EdgeSeqScan {
labels,
estimated_rows,
..
} => self.estimate_scan_cost(*estimated_rows, labels, stats, true),
PhysicalNode::IndexedExpand {
input,
estimated_rows,
..
} => {
let mut cost = self.estimate_node_cost(input, stats);
let expand_cost = self.estimate_expand_cost(*estimated_rows, false);
cost.add(&expand_cost);
cost
}
PhysicalNode::HashExpand {
input,
estimated_rows,
..
} => {
let mut cost = self.estimate_node_cost(input, stats);
let expand_cost = self.estimate_expand_cost(*estimated_rows, true);
cost.add(&expand_cost);
cost
}
PhysicalNode::Filter {
input, selectivity, ..
} => {
let mut cost = self.estimate_node_cost(input, stats);
let filter_cost = self.estimate_filter_cost(input.get_row_count(), *selectivity);
cost.add(&filter_cost);
cost
}
PhysicalNode::Project {
input,
estimated_rows,
..
} => {
let mut cost = self.estimate_node_cost(input, stats);
let project_cost = self.estimate_project_cost(*estimated_rows);
cost.add(&project_cost);
cost
}
PhysicalNode::HashJoin {
build,
probe,
estimated_rows,
..
} => {
let mut cost = self.estimate_node_cost(build, stats);
cost.add(&self.estimate_node_cost(probe, stats));
let join_cost = self.estimate_join_cost(
build.get_row_count(),
probe.get_row_count(),
*estimated_rows,
JoinAlgorithm::Hash,
);
cost.add(&join_cost);
cost
}
PhysicalNode::NestedLoopJoin {
left,
right,
estimated_rows,
..
} => {
let mut cost = self.estimate_node_cost(left, stats);
cost.add(&self.estimate_node_cost(right, stats));
let join_cost = self.estimate_join_cost(
left.get_row_count(),
right.get_row_count(),
*estimated_rows,
JoinAlgorithm::NestedLoop,
);
cost.add(&join_cost);
cost
}
PhysicalNode::SortMergeJoin {
left,
right,
estimated_rows,
..
} => {
let mut cost = self.estimate_node_cost(left, stats);
cost.add(&self.estimate_node_cost(right, stats));
let join_cost = self.estimate_join_cost(
left.get_row_count(),
right.get_row_count(),
*estimated_rows,
JoinAlgorithm::SortMerge,
);
cost.add(&join_cost);
cost
}
PhysicalNode::ExternalSort {
input,
estimated_rows,
..
} => {
let mut cost = self.estimate_node_cost(input, stats);
let sort_cost = self.estimate_sort_cost(*estimated_rows, true);
cost.add(&sort_cost);
cost
}
PhysicalNode::InMemorySort {
input,
estimated_rows,
..
} => {
let mut cost = self.estimate_node_cost(input, stats);
let sort_cost = self.estimate_sort_cost(*estimated_rows, false);
cost.add(&sort_cost);
cost
}
PhysicalNode::Limit { input, count, .. } => {
let input_cost = self.estimate_node_cost(input, stats);
let input_rows = input.get_row_count();
let ratio = (*count as f64) / (input_rows as f64).max(1.0);
CostEstimate {
cpu_cost: input_cost.cpu_cost * ratio,
io_cost: input_cost.io_cost * ratio,
memory_cost: input_cost.memory_cost,
network_cost: input_cost.network_cost * ratio,
total_time: input_cost.total_time * ratio,
}
}
PhysicalNode::GraphIndexScan { estimated_rows, .. } => {
let base_cost = *estimated_rows as f64 * self.cpu_cost_per_row * 0.05; let io_cost = (*estimated_rows / 10000) as f64 * self.io_cost_per_page * 0.1;
CostEstimate {
cpu_cost: base_cost,
io_cost,
memory_cost: *estimated_rows as f64 * 100.0 * self.memory_cost_per_byte, network_cost: 0.0,
total_time: base_cost + io_cost,
}
}
PhysicalNode::IndexJoin {
left,
right,
estimated_rows,
..
} => {
let mut cost = self.estimate_node_cost(left, stats);
cost.add(&self.estimate_node_cost(right, stats));
let join_cost = self.estimate_join_cost(
left.get_row_count(),
right.get_row_count(),
*estimated_rows,
JoinAlgorithm::IndexNL, );
cost.add(&join_cost);
cost
}
PhysicalNode::SingleRow { .. } => {
CostEstimate {
cpu_cost: 0.0001, io_cost: 0.0, memory_cost: 0.0001, network_cost: 0.0, total_time: 0.0001, }
}
_ => CostEstimate::new(), }
}
fn estimate_scan_cost(
&self,
rows: usize,
_labels: &[String],
_stats: &Statistics,
is_sequential: bool,
) -> CostEstimate {
let base_cpu_cost = rows as f64 * self.cpu_cost_per_row;
let cpu_multiplier = if is_sequential { 1.0 } else { 0.3 };
let io_cost = if is_sequential {
(rows / 1000) as f64 * self.io_cost_per_page } else {
(rows / 10000) as f64 * self.io_cost_per_page
};
CostEstimate {
cpu_cost: base_cpu_cost * cpu_multiplier,
io_cost,
memory_cost: (rows * 100) as f64 * self.memory_cost_per_byte, network_cost: 0.0,
total_time: base_cpu_cost * cpu_multiplier + io_cost,
}
}
fn estimate_expand_cost(&self, rows: usize, use_hash: bool) -> CostEstimate {
let base_cost = rows as f64 * self.cpu_cost_per_row * 2.0; let memory_multiplier = if use_hash { 2.0 } else { 1.0 };
CostEstimate {
cpu_cost: base_cost,
io_cost: (rows / 5000) as f64 * self.io_cost_per_page, memory_cost: (rows * 50) as f64 * self.memory_cost_per_byte * memory_multiplier,
network_cost: 0.0,
total_time: base_cost,
}
}
fn estimate_filter_cost(&self, input_rows: usize, _selectivity: f64) -> CostEstimate {
let cpu_cost = input_rows as f64 * self.cpu_cost_per_row * 0.5;
CostEstimate {
cpu_cost,
io_cost: 0.0, memory_cost: 0.0, network_cost: 0.0,
total_time: cpu_cost,
}
}
fn estimate_project_cost(&self, rows: usize) -> CostEstimate {
let cpu_cost = rows as f64 * self.cpu_cost_per_row * 0.2;
CostEstimate {
cpu_cost,
io_cost: 0.0,
memory_cost: 0.0,
network_cost: 0.0,
total_time: cpu_cost,
}
}
fn estimate_join_cost(
&self,
left_rows: usize,
right_rows: usize,
_output_rows: usize,
algorithm: JoinAlgorithm,
) -> CostEstimate {
let (cpu_multiplier, memory_multiplier) = match algorithm {
JoinAlgorithm::Hash => (1.5, 2.0), JoinAlgorithm::NestedLoop => (left_rows as f64, 0.1), JoinAlgorithm::SortMerge => {
((left_rows as f64).log2() + (right_rows as f64).log2(), 1.0)
} JoinAlgorithm::IndexNL => (0.8, 0.5), };
let base_cost = (left_rows + right_rows) as f64 * self.cpu_cost_per_row;
CostEstimate {
cpu_cost: base_cost * cpu_multiplier,
io_cost: ((left_rows + right_rows) / 1000) as f64 * self.io_cost_per_page,
memory_cost: (left_rows.max(right_rows) * 100) as f64
* self.memory_cost_per_byte
* memory_multiplier,
network_cost: 0.0,
total_time: base_cost * cpu_multiplier,
}
}
fn estimate_sort_cost(&self, rows: usize, external: bool) -> CostEstimate {
let n_log_n = rows as f64 * (rows as f64).log2();
let cpu_cost = n_log_n * self.cpu_cost_per_row * 0.01;
let (io_multiplier, memory_multiplier) = if external {
(3.0, 0.5) } else {
(0.0, 2.0) };
CostEstimate {
cpu_cost,
io_cost: (rows / 1000) as f64 * self.io_cost_per_page * io_multiplier,
memory_cost: (rows * 100) as f64 * self.memory_cost_per_byte * memory_multiplier,
network_cost: 0.0,
total_time: cpu_cost,
}
}
}
/// Join strategies distinguished by `CostModel::estimate_join_cost`.
#[derive(Debug, Clone)]
enum JoinAlgorithm {
    /// Hash join.
    Hash,
    /// Nested-loop join.
    NestedLoop,
    /// Sort-merge join.
    SortMerge,
    /// Index nested-loop join.
    IndexNL,
}
impl Default for CostModel {
fn default() -> Self {
Self::new()
}
}
impl Statistics {
    /// Creates empty statistics: zero counts, no selectivities, no indices.
    pub fn new() -> Self {
        Self {
            total_nodes: 0,
            total_edges: 0,
            node_counts: HashMap::new(),
            edge_counts: HashMap::new(),
            average_degree: 0.0,
            max_degree: 0,
            property_selectivity: HashMap::new(),
            available_indices: Vec::new(),
        }
    }

    /// Refreshes the statistics from `graph`.
    ///
    /// Updates the node/edge totals and the average degree, registers the
    /// built-in `node_labels` and `edge_labels` label indices, and installs
    /// fixed selectivities for a handful of well-known properties.
    ///
    /// Calling this repeatedly is safe: an existing index entry with the same
    /// name is replaced in place rather than appended again, so the index list
    /// does not accumulate stale duplicates.
    // Lint suppression kept from the original; presumably nothing in the crate calls this yet.
    #[allow(dead_code)]
    pub fn update_from_graph(&mut self, graph: &crate::storage::GraphCache) {
        let stats = graph.stats();
        self.total_nodes = stats.node_count;
        self.total_edges = stats.edge_count;
        // Every edge contributes to the degree of two endpoints. The doubling is
        // done in f64 so a very large edge count cannot overflow usize.
        self.average_degree = if self.total_nodes > 0 {
            2.0 * self.total_edges as f64 / self.total_nodes as f64
        } else {
            0.0
        };

        let label_indices = [
            ("node_labels", EntityType::Node, self.total_nodes),
            ("edge_labels", EntityType::Edge, self.total_edges),
        ];
        for (name, entity_type, cardinality) in label_indices {
            let entry = IndexInfo {
                name: name.to_string(),
                index_type: IndexType::Label,
                entity_type,
                properties: vec![],
                cardinality,
            };
            // Replace a previously registered entry so its cardinality stays
            // current; only append when the index is not known yet.
            match self
                .available_indices
                .iter_mut()
                .find(|index| index.name == name)
            {
                Some(existing) => *existing = entry,
                None => self.available_indices.push(entry),
            }
        }

        // Hard-coded selectivities; they are not derived from the graph's data.
        for (property, selectivity) in [
            ("id", 1.0),
            ("label", 0.1),
            ("risk_score", 0.5),
            ("amount", 0.3),
        ] {
            self.property_selectivity
                .insert(property.to_string(), selectivity);
        }
    }

    /// Returns the recorded selectivity for `property`, or `0.5` when none is known.
    // Lint suppression kept from the original; presumably not called anywhere yet.
    #[allow(dead_code)]
    pub fn get_property_selectivity(&self, property: &str) -> f64 {
        self.property_selectivity
            .get(property)
            .copied()
            .unwrap_or(0.5)
    }

    /// Reports whether some index on `entity_type` covers every name in `properties`.
    ///
    /// An empty `properties` slice is covered by any index on that entity type.
    // Lint suppression kept from the original; presumably not called anywhere yet.
    #[allow(dead_code)]
    pub fn has_index(&self, entity_type: &EntityType, properties: &[String]) -> bool {
        // `EntityType` does not implement `PartialEq`, so variants are compared
        // through their discriminants.
        let wanted = std::mem::discriminant(entity_type);
        self.available_indices.iter().any(|index| {
            std::mem::discriminant(&index.entity_type) == wanted
                && properties
                    .iter()
                    .all(|property| index.properties.contains(property))
        })
    }
}
impl Default for Statistics {
fn default() -> Self {
Self::new()
}
}