use super::physical::PhysicalOp;
use super::stats::Statistics;
/// Fraction of input rows assumed to survive a filter when no better
/// statistics are available.
const DEFAULT_FILTER_SELECTIVITY: f64 = 0.1;
/// Fraction of input rows assumed distinct when estimating `Distinct` output.
const DEFAULT_DISTINCT_RATIO: f64 = 0.5;
/// Fraction of the cross product assumed to survive a hash join.
const DEFAULT_JOIN_SELECTIVITY: f64 = 0.1;
/// Assumed number of u64-sized fields per node, for memory estimates.
const ESTIMATED_FIELDS_PER_NODE: usize = 10;
/// Assumed number of u64-sized fields per temporal node (larger to account
/// for versioned state).
const ESTIMATED_FIELDS_PER_TEMPORAL_NODE: usize = 20;
/// Rows assumed to be fetched per I/O operation when costing scans.
const BATCH_IO_ROWS: f64 = 1000.0;
/// CPU multiplier applied to HNSW search when a label filter is present.
const FILTERED_SEARCH_OVERHEAD: f64 = 2.0;
/// CPU multiplier applied to vector search against a temporal snapshot.
const TEMPORAL_SNAPSHOT_OVERHEAD: f64 = 1.5;
/// Multi-dimensional cost of executing a physical operator.
///
/// The four dimensions are collapsed into a single comparable scalar by
/// [`Cost::total`] using a set of [`CostWeights`].
///
/// `PartialEq` is derived (all fields support it) so costs can be compared
/// directly in tests and plan comparisons; `Eq` is impossible due to `f64`.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Cost {
    /// Estimated CPU work, in abstract units.
    pub cpu: f64,
    /// Estimated number of I/O operations.
    pub io: f64,
    /// Estimated memory footprint in bytes.
    pub memory: usize,
    /// Estimated network transfer, in abstract units.
    pub network: f64,
}
impl Cost {
#[must_use]
pub fn zero() -> Self {
Cost::default()
}
#[must_use]
pub fn total(&self, weights: &CostWeights) -> f64 {
self.cpu * weights.cpu_weight
+ self.io * weights.io_weight
+ (self.memory as f64) * weights.memory_weight
+ self.network * weights.network_weight
}
#[must_use]
pub fn add(&self, other: &Cost) -> Cost {
Cost {
cpu: self.cpu + other.cpu,
io: self.io + other.io,
memory: self.memory.saturating_add(other.memory),
network: self.network + other.network,
}
}
#[must_use]
pub fn scale(&self, factor: f64) -> Cost {
Cost {
cpu: self.cpu * factor,
io: self.io * factor,
memory: (self.memory as f64 * factor) as usize,
network: self.network * factor,
}
}
}
impl std::ops::Add for Cost {
    type Output = Cost;

    /// Enables `a + b` syntax; delegates to the inherent [`Cost::add`].
    fn add(self, other: Self) -> Self::Output {
        Cost::add(&self, &other)
    }
}
/// Relative weights used by [`Cost::total`] to collapse the four cost
/// dimensions into a single comparable scalar.
///
/// All fields are plain `f64`, so `Copy` and `PartialEq` are derived in
/// addition to the original `Debug, Clone` — purely additive and lets
/// weight presets be passed by value and compared in tests.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CostWeights {
    /// Weight applied to the CPU dimension.
    pub cpu_weight: f64,
    /// Weight applied to the I/O dimension.
    pub io_weight: f64,
    /// Weight applied per byte of estimated memory.
    pub memory_weight: f64,
    /// Weight applied to the network dimension.
    pub network_weight: f64,
}
impl Default for CostWeights {
fn default() -> Self {
CostWeights {
cpu_weight: 1.0,
io_weight: 10.0, memory_weight: 0.001, network_weight: 100.0, }
}
}
/// Abstract per-operation cost constants consumed by [`CostModel`]'s
/// estimators. Units are arbitrary but mutually consistent.
///
/// All fields are `f64`/`usize`, so `Copy` and `PartialEq` are derived in
/// addition to the original `Debug, Clone` — purely additive.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct OperationCosts {
    /// CPU cost per row touched by node/property lookups and scans.
    pub node_lookup: f64,
    /// CPU cost per edge traversed in a graph traversal.
    pub single_hop_traversal: f64,
    /// CPU cost per requested neighbor in an HNSW search (scaled by the
    /// log2(vector count) term).
    pub hnsw_search_per_k: f64,
    /// Extra multiplier applied to the HNSW log term.
    pub hnsw_log_factor: f64,
    /// CPU cost per delta-chain entry applied during a temporal lookup.
    pub temporal_delta: f64,
    /// CPU cost per row for evaluating a filter predicate.
    pub filter_eval: f64,
    /// CPU cost per candidate for a vector similarity computation (rerank).
    pub vector_similarity: f64,
    /// CPU cost per element per log2(n) during sorting.
    pub sort_per_element: f64,
    /// CPU cost per row on the build side of a hash join.
    pub hash_build: f64,
    /// CPU cost per row on the probe side of a hash join.
    pub hash_probe: f64,
    /// CPU overhead per lock acquisition in temporal lookups.
    pub lock_acquisition: f64,
    /// Node count at or above which a batched temporal lookup is preferred.
    pub batch_threshold: usize,
}
impl Default for OperationCosts {
fn default() -> Self {
OperationCosts {
node_lookup: 0.5,
single_hop_traversal: 1.0,
hnsw_search_per_k: 0.3,
hnsw_log_factor: 1.0,
temporal_delta: 10.0,
filter_eval: 0.1,
vector_similarity: 0.5,
sort_per_element: 0.01,
hash_build: 0.1,
hash_probe: 0.05,
lock_acquisition: 2.0,
batch_threshold: 100,
}
}
}
/// Cost model for physical query plans: combines dimension weights with
/// per-operation cost constants to produce comparable plan costs.
#[derive(Debug, Clone, Default)]
pub struct CostModel {
    /// Weights used to collapse a [`Cost`] into a single scalar.
    pub weights: CostWeights,
    /// Abstract per-operation cost constants used by the estimators.
    pub operation_costs: OperationCosts,
}
impl CostModel {
    /// Creates a cost model with default weights and operation costs.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Preset that penalizes I/O and network more heavily than the default
    /// and nearly ignores memory, favoring plans with few round trips.
    #[must_use]
    pub fn low_latency() -> Self {
        CostModel {
            weights: CostWeights {
                cpu_weight: 1.0,
                io_weight: 20.0,
                memory_weight: 0.0001,
                network_weight: 1000.0,
            },
            operation_costs: OperationCosts::default(),
        }
    }

    /// Preset that discounts CPU, I/O and network relative to the default
    /// but charges more per byte of memory.
    #[must_use]
    pub fn high_throughput() -> Self {
        CostModel {
            weights: CostWeights {
                cpu_weight: 0.5,
                io_weight: 5.0,
                memory_weight: 0.01,
                network_weight: 50.0,
            },
            operation_costs: OperationCosts::default(),
        }
    }

    /// Returns true when `node_count` reaches the configured batch
    /// threshold, i.e. a single batched temporal lookup (one lock
    /// acquisition) is expected to beat per-node lookups.
    #[must_use]
    pub fn should_use_batch_temporal_lookup(&self, node_count: usize) -> bool {
        node_count >= self.operation_costs.batch_threshold
    }

    /// Recursively estimates the execution cost of a physical operator tree.
    ///
    /// Unary operators add their own cost on top of their input's cost;
    /// leaf operators are costed from their estimated row counts and `stats`.
    #[must_use]
    pub fn estimate(&self, op: &PhysicalOp, stats: &Statistics) -> Cost {
        match op {
            PhysicalOp::NodeLookup { node_ids } => self.estimate_node_lookup(node_ids.len()),
            PhysicalOp::NodeScan { estimated_rows, .. } => self.estimate_node_scan(*estimated_rows),
            PhysicalOp::EdgeScan { estimated_rows, .. } => self.estimate_edge_scan(*estimated_rows),
            // NOTE(review): property scans are costed like in-memory point
            // lookups (no I/O term) — confirm this matches the property
            // index implementation.
            PhysicalOp::PropertyScan { estimated_rows, .. } => {
                self.estimate_node_lookup(*estimated_rows)
            }
            PhysicalOp::HnswSearch {
                k, label_filter, ..
            } => self.estimate_hnsw_search(*k, label_filter.is_some(), stats),
            PhysicalOp::TemporalNodeLookup {
                node_ids,
                use_batch,
                ..
            } => self.estimate_temporal_lookup(node_ids.len(), *use_batch, stats),
            PhysicalOp::TemporalVectorSearch { k, .. } => {
                self.estimate_temporal_vector_search(*k, stats)
            }
            PhysicalOp::IndexedTraversal { input, depth, .. } => {
                let input_cost = self.estimate(input, stats);
                let input_card = self.estimate_cardinality(input, stats);
                self.estimate_traversal(input_cost, input_card, *depth, stats)
            }
            PhysicalOp::Filter { input, .. } => {
                let input_cost = self.estimate(input, stats);
                let input_card = self.estimate_cardinality(input, stats);
                self.estimate_filter(input_cost, input_card)
            }
            PhysicalOp::VectorRerank { input, k, .. } => {
                let input_cost = self.estimate(input, stats);
                let input_card = self.estimate_cardinality(input, stats);
                self.estimate_vector_rerank(input_cost, input_card, *k)
            }
            PhysicalOp::Sort { input, .. } => {
                let input_cost = self.estimate(input, stats);
                let input_card = self.estimate_cardinality(input, stats);
                self.estimate_sort(input_cost, input_card)
            }
            // Limit pays the full cost of computing its input; only the
            // cardinality estimate (not the cost) shrinks with the limit.
            PhysicalOp::Limit { input, .. } => {
                self.estimate(input, stats)
            }
            PhysicalOp::HashJoin { left, right, .. } => {
                let left_cost = self.estimate(left, stats);
                let right_cost = self.estimate(right, stats);
                let left_card = self.estimate_cardinality(left, stats);
                let right_card = self.estimate_cardinality(right, stats);
                self.estimate_hash_join(left_cost, right_cost, left_card, right_card)
            }
            // Set operations are charged as the sum of both inputs; the
            // combining step itself adds no extra cost here.
            PhysicalOp::Union { left, right }
            | PhysicalOp::Intersect { left, right }
            | PhysicalOp::Except { left, right } => {
                let left_cost = self.estimate(left, stats);
                let right_cost = self.estimate(right, stats);
                left_cost + right_cost
            }
            // Pass-through operators: charged at their input's cost.
            PhysicalOp::Project { input, .. }
            | PhysicalOp::Distinct { input }
            | PhysicalOp::Count { input }
            | PhysicalOp::Materialize { input }
            | PhysicalOp::TemporalTrack { input, .. } => self.estimate(input, stats),
            // Similarity to an existing node: fetch its embedding (one
            // lookup), then run a k-NN search.
            PhysicalOp::SimilarToNode {
                k, label_filter, ..
            } => {
                let lookup_cost = self.estimate_node_lookup(1);
                let search_cost = self.estimate_hnsw_search(*k, label_filter.is_some(), stats);
                lookup_cost + search_cost
            }
            PhysicalOp::Empty => Cost::zero(),
        }
    }

    /// Estimates the number of rows an operator will produce, falling back
    /// to the fixed default selectivities where no statistics apply.
    #[must_use]
    pub fn estimate_cardinality(&self, op: &PhysicalOp, stats: &Statistics) -> usize {
        match op {
            PhysicalOp::NodeLookup { node_ids } => node_ids.len(),
            PhysicalOp::NodeScan { estimated_rows, .. } => *estimated_rows,
            PhysicalOp::EdgeScan { estimated_rows, .. } => *estimated_rows,
            PhysicalOp::PropertyScan { estimated_rows, .. } => *estimated_rows,
            PhysicalOp::HnswSearch { k, .. } => *k,
            PhysicalOp::TemporalNodeLookup { node_ids, .. } => node_ids.len(),
            PhysicalOp::TemporalVectorSearch { k, .. } => *k,
            // Fan-out grows geometrically: input_card * avg_degree^depth.
            // NOTE(review): `*depth as i32` truncates for absurdly large
            // depths — harmless in practice but worth an upstream bound.
            PhysicalOp::IndexedTraversal { input, depth, .. } => {
                let input_card = self.estimate_cardinality(input, stats);
                let avg_degree = stats.average_out_degree();
                (input_card as f64 * avg_degree.powi(*depth as i32)) as usize
            }
            PhysicalOp::Filter { input, .. } => {
                (self.estimate_cardinality(input, stats) as f64 * DEFAULT_FILTER_SELECTIVITY)
                    as usize
            }
            PhysicalOp::VectorRerank { k, .. } => *k,
            // NOTE(review): `offset` is ignored here; rows skipped by the
            // offset could make the true output smaller — confirm intended.
            PhysicalOp::Limit { count, input, .. } => {
                (*count).min(self.estimate_cardinality(input, stats))
            }
            PhysicalOp::Sort { input, .. } | PhysicalOp::Project { input, .. } => {
                self.estimate_cardinality(input, stats)
            }
            PhysicalOp::Distinct { input } => {
                (self.estimate_cardinality(input, stats) as f64 * DEFAULT_DISTINCT_RATIO) as usize
            }
            PhysicalOp::Count { .. } => 1,
            PhysicalOp::HashJoin { left, right, .. } => {
                let left_card = self.estimate_cardinality(left, stats);
                let right_card = self.estimate_cardinality(right, stats);
                (left_card as f64 * right_card as f64 * DEFAULT_JOIN_SELECTIVITY) as usize
            }
            PhysicalOp::Union { left, right } => {
                self.estimate_cardinality(left, stats) + self.estimate_cardinality(right, stats)
            }
            PhysicalOp::Intersect { left, right } => self
                .estimate_cardinality(left, stats)
                .min(self.estimate_cardinality(right, stats)),
            // Upper bound: Except can only remove rows from the left side.
            PhysicalOp::Except { left, .. } => self.estimate_cardinality(left, stats),
            PhysicalOp::Materialize { input } | PhysicalOp::TemporalTrack { input, .. } => {
                self.estimate_cardinality(input, stats)
            }
            PhysicalOp::SimilarToNode { k, .. } => *k,
            PhysicalOp::Empty => 0,
        }
    }

    /// Point lookups: per-row CPU plus estimated in-memory footprint; no
    /// I/O is charged.
    fn estimate_node_lookup(&self, count: usize) -> Cost {
        Cost {
            cpu: self.operation_costs.node_lookup * count as f64,
            io: 0.0,
            memory: count * std::mem::size_of::<u64>() * ESTIMATED_FIELDS_PER_NODE,
            network: 0.0,
        }
    }

    /// Node scan: per-row CPU plus one I/O per `BATCH_IO_ROWS` rows.
    fn estimate_node_scan(&self, estimated_rows: usize) -> Cost {
        Cost {
            cpu: self.operation_costs.node_lookup * estimated_rows as f64,
            io: (estimated_rows as f64 / BATCH_IO_ROWS).ceil(),
            memory: estimated_rows * std::mem::size_of::<u64>() * ESTIMATED_FIELDS_PER_NODE,
            network: 0.0,
        }
    }

    /// Edge scan: currently identical to the node-scan formula (pinned by
    /// `test_edge_scan_cost_matches_node_scan_for_now` below).
    fn estimate_edge_scan(&self, estimated_rows: usize) -> Cost {
        Cost {
            cpu: self.operation_costs.node_lookup * estimated_rows as f64,
            io: (estimated_rows as f64 / BATCH_IO_ROWS).ceil(),
            memory: estimated_rows * std::mem::size_of::<u64>() * ESTIMATED_FIELDS_PER_NODE,
            network: 0.0,
        }
    }

    /// HNSW search: CPU scales with k and log2 of the indexed vector
    /// count; a label filter multiplies CPU by `FILTERED_SEARCH_OVERHEAD`.
    fn estimate_hnsw_search(&self, k: usize, has_filter: bool, stats: &Statistics) -> Cost {
        let vector_count = stats.vector_count().max(1) as f64;
        // max(1.0) keeps the log factor sane for tiny or empty indexes.
        let log_factor = vector_count.log2().max(1.0);
        let filter_overhead = if has_filter {
            FILTERED_SEARCH_OVERHEAD
        } else {
            1.0
        };
        Cost {
            cpu: self.operation_costs.hnsw_search_per_k
                * k as f64
                * log_factor
                * self.operation_costs.hnsw_log_factor
                * filter_overhead,
            io: 0.0,
            // Result set: k (id, score) pairs.
            memory: k * std::mem::size_of::<(u64, f32)>(),
            network: 0.0,
        }
    }

    /// Temporal lookup: batched mode pays a single lock acquisition while
    /// per-node mode pays one per node; the average delta-chain length
    /// scales both CPU and I/O.
    fn estimate_temporal_lookup(&self, count: usize, use_batch: bool, stats: &Statistics) -> Cost {
        let avg_delta_chain = stats.average_delta_chain_length().max(1.0);
        let lock_overhead = if use_batch {
            self.operation_costs.lock_acquisition
        } else {
            self.operation_costs.lock_acquisition * count as f64
        };
        Cost {
            cpu: lock_overhead
                + self.operation_costs.temporal_delta * avg_delta_chain * count as f64,
            io: avg_delta_chain * count as f64,
            memory: count * std::mem::size_of::<u64>() * ESTIMATED_FIELDS_PER_TEMPORAL_NODE,
            network: 0.0,
        }
    }

    /// Temporal vector search: an unfiltered HNSW search with CPU inflated
    /// by the snapshot overhead plus one extra I/O.
    fn estimate_temporal_vector_search(&self, k: usize, stats: &Statistics) -> Cost {
        let base = self.estimate_hnsw_search(k, false, stats);
        Cost {
            cpu: base.cpu * TEMPORAL_SNAPSHOT_OVERHEAD,
            io: base.io + 1.0,
            memory: base.memory,
            network: 0.0,
        }
    }

    /// Traversal: expands input rows by avg_degree^depth and charges one
    /// hop cost per expanded row on top of the input cost.
    fn estimate_traversal(
        &self,
        input_cost: Cost,
        input_card: usize,
        depth: usize,
        stats: &Statistics,
    ) -> Cost {
        let avg_degree = stats.average_out_degree();
        let traversal_factor = avg_degree.powi(depth as i32);
        Cost {
            cpu: input_cost.cpu
                + self.operation_costs.single_hop_traversal * input_card as f64 * traversal_factor,
            io: input_cost.io,
            memory: input_cost.memory
                + (input_card as f64 * traversal_factor) as usize * std::mem::size_of::<u64>(),
            network: 0.0,
        }
    }

    /// Filter: one predicate evaluation per input row; no extra memory.
    fn estimate_filter(&self, input_cost: Cost, input_card: usize) -> Cost {
        Cost {
            cpu: input_cost.cpu + self.operation_costs.filter_eval * input_card as f64,
            io: input_cost.io,
            memory: input_cost.memory,
            network: 0.0,
        }
    }

    /// Rerank: one similarity computation per candidate, plus space for
    /// the top-k (id, score) pairs.
    fn estimate_vector_rerank(&self, input_cost: Cost, input_card: usize, k: usize) -> Cost {
        Cost {
            cpu: input_cost.cpu + self.operation_costs.vector_similarity * input_card as f64,
            io: input_cost.io,
            memory: input_cost.memory + k * std::mem::size_of::<(u64, f32)>(),
            network: 0.0,
        }
    }

    /// Sort: n*log2(n) work at `sort_per_element` each; input is clamped
    /// to at least one element so log2 stays finite.
    fn estimate_sort(&self, input_cost: Cost, input_card: usize) -> Cost {
        let n = input_card.max(1) as f64;
        let nlogn = n * n.log2();
        Cost {
            cpu: input_cost.cpu + self.operation_costs.sort_per_element * nlogn,
            io: input_cost.io,
            memory: input_cost.memory + input_card * std::mem::size_of::<u64>(),
            network: 0.0,
        }
    }

    /// Hash join: builds the hash table on the smaller input, probes with
    /// the larger; the table charges two u64 words per build row.
    fn estimate_hash_join(
        &self,
        left_cost: Cost,
        right_cost: Cost,
        left_card: usize,
        right_card: usize,
    ) -> Cost {
        // Always build on the smaller side.
        let (build_card, probe_card) = if left_card < right_card {
            (left_card, right_card)
        } else {
            (right_card, left_card)
        };
        Cost {
            cpu: left_cost.cpu
                + right_cost.cpu
                + self.operation_costs.hash_build * build_card as f64
                + self.operation_costs.hash_probe * probe_card as f64,
            io: left_cost.io + right_cost.io,
            memory: left_cost.memory
                + right_cost.memory
                + build_card * std::mem::size_of::<u64>() * 2,
            network: 0.0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::NodeId;
    use std::sync::Arc;

    /// Default statistics are sufficient for these smoke tests.
    fn test_stats() -> Statistics {
        Statistics::default()
    }

    /// `Cost::add` sums component-wise; `Cost::scale` multiplies uniformly.
    #[test]
    fn test_cost_arithmetic() {
        let a = Cost {
            cpu: 1.0,
            io: 2.0,
            memory: 100,
            network: 0.0,
        };
        let b = Cost {
            cpu: 0.5,
            io: 1.0,
            memory: 50,
            network: 0.0,
        };
        let sum = a.add(&b);
        assert_eq!(sum.cpu, 1.5);
        assert_eq!(sum.io, 3.0);
        assert_eq!(sum.memory, 150);
        let scaled = a.scale(2.0);
        assert_eq!(scaled.cpu, 2.0);
        assert_eq!(scaled.io, 4.0);
        assert_eq!(scaled.memory, 200);
    }

    /// With default weights: 10*1.0 (cpu) + 5*10.0 (io) + 1000*0.001
    /// (memory) + 0*100.0 (network) = 61.0.
    #[test]
    fn test_cost_total() {
        let cost = Cost {
            cpu: 10.0,
            io: 5.0,
            memory: 1000,
            network: 0.0,
        };
        let weights = CostWeights::default();
        let total = cost.total(&weights);
        assert!((total - 61.0).abs() < 0.01);
    }

    /// Point lookups cost CPU but are charged no I/O.
    #[test]
    fn test_node_lookup_cost() {
        let model = CostModel::default();
        let stats = test_stats();
        let op = PhysicalOp::NodeLookup {
            node_ids: vec![NodeId::new(1).unwrap(), NodeId::new(2).unwrap()],
        };
        let cost = model.estimate(&op, &stats);
        assert!(cost.cpu > 0.0);
        assert_eq!(cost.io, 0.0);
    }

    /// HNSW search has a positive CPU cost even on default statistics
    /// (the log factor is clamped to at least 1.0).
    #[test]
    fn test_hnsw_search_cost() {
        let model = CostModel::default();
        let stats = test_stats();
        let op = PhysicalOp::HnswSearch {
            embedding: Arc::from([0.1f32; 4].as_slice()),
            k: 10,
            label_filter: None,
            property_key: None,
        };
        let cost = model.estimate(&op, &stats);
        assert!(cost.cpu > 0.0);
    }

    /// A depth-2 traversal over a single looked-up node still costs CPU.
    #[test]
    fn test_traversal_cost() {
        let model = CostModel::default();
        let stats = test_stats();
        let op = PhysicalOp::IndexedTraversal {
            input: Box::new(PhysicalOp::NodeLookup {
                node_ids: vec![NodeId::new(1).unwrap()],
            }),
            direction: crate::query::ir::Direction::Outgoing,
            label: None,
            depth: 2,
            temporal_context: None,
        };
        let cost = model.estimate(&op, &stats);
        assert!(cost.cpu > 0.0);
    }

    /// Lookup cardinality = id count; HNSW cardinality = k; Limit clamps
    /// its input's cardinality to `count`.
    #[test]
    fn test_cardinality_estimation() {
        let model = CostModel::default();
        let stats = test_stats();
        let lookup = PhysicalOp::NodeLookup {
            node_ids: vec![NodeId::new(1).unwrap(), NodeId::new(2).unwrap()],
        };
        assert_eq!(model.estimate_cardinality(&lookup, &stats), 2);
        let search = PhysicalOp::HnswSearch {
            embedding: Arc::from([0.1f32; 4].as_slice()),
            k: 10,
            label_filter: None,
            property_key: None,
        };
        assert_eq!(model.estimate_cardinality(&search, &stats), 10);
        let limit = PhysicalOp::Limit {
            input: Box::new(search),
            count: 5,
            offset: 0,
        };
        assert_eq!(model.estimate_cardinality(&limit, &stats), 5);
    }

    /// Edge scans charge CPU, I/O and memory, unlike point lookups.
    #[test]
    fn test_edge_scan_cost() {
        let model = CostModel::default();
        let stats = test_stats();
        let op = PhysicalOp::EdgeScan {
            edge_type: Some("KNOWS".to_string()),
            estimated_rows: 500,
        };
        let cost = model.estimate(&op, &stats);
        assert!(cost.cpu > 0.0, "EdgeScan should have positive CPU cost");
        assert!(cost.io > 0.0, "EdgeScan should have positive I/O cost");
        assert!(cost.memory > 0, "EdgeScan should have positive memory cost");
    }

    /// Edge scan cardinality passes through the planner's row estimate.
    #[test]
    fn test_edge_scan_cardinality() {
        let model = CostModel::default();
        let stats = test_stats();
        let op = PhysicalOp::EdgeScan {
            edge_type: None,
            estimated_rows: 750,
        };
        assert_eq!(model.estimate_cardinality(&op, &stats), 750);
    }

    /// Pins the current design decision that node and edge scans share a
    /// single cost formula; intentionally fails if they diverge.
    #[test]
    fn test_edge_scan_cost_matches_node_scan_for_now() {
        let model = CostModel::default();
        let stats = test_stats();
        let rows = 200;
        let node_cost = model.estimate(
            &PhysicalOp::NodeScan {
                label: None,
                estimated_rows: rows,
            },
            &stats,
        );
        let edge_cost = model.estimate(
            &PhysicalOp::EdgeScan {
                edge_type: None,
                estimated_rows: rows,
            },
            &stats,
        );
        assert_eq!(node_cost.cpu, edge_cost.cpu);
        assert_eq!(node_cost.io, edge_cost.io);
        assert_eq!(node_cost.memory, edge_cost.memory);
    }
}