1use std::sync::Arc;
13
14use super::stats_provider::{NullProvider, StatsProvider};
15use crate::storage::query::ast::{
16 CompareOp, FieldRef, Filter as AstFilter, GraphQuery, HybridQuery, JoinQuery, JoinType,
17 PathQuery, QueryExpr, TableQuery, VectorQuery,
18};
19use crate::storage::schema::Value;
20
/// Estimated size of a (sub)query result.
#[derive(Debug, Clone, Default)]
pub struct CardinalityEstimate {
    /// Estimated number of output rows.
    pub rows: f64,
    /// Fraction of the underlying rows expected to survive (1.0 = all of them).
    pub selectivity: f64,
    /// Trust in the estimate: 1.0 for freshly constructed estimates, reduced
    /// by 10% each time a filter is stacked on top via [`Self::with_filter`].
    pub confidence: f64,
}

impl CardinalityEstimate {
    /// Creates an estimate of `rows` rows at `selectivity`, with full confidence.
    pub fn new(rows: f64, selectivity: f64) -> Self {
        CardinalityEstimate {
            rows,
            selectivity,
            confidence: 1.0,
        }
    }

    /// Estimate for reading all `table_size` rows of a table (selectivity 1.0).
    pub fn full_scan(table_size: f64) -> Self {
        Self::new(table_size, 1.0)
    }

    /// Applies a predicate that keeps `filter_selectivity` of the rows.
    ///
    /// Both `rows` and `selectivity` are scaled by the predicate's selectivity;
    /// `confidence` is multiplied by 0.9 because each stacked heuristic
    /// compounds the estimation error.
    pub fn with_filter(self, filter_selectivity: f64) -> Self {
        CardinalityEstimate {
            rows: self.rows * filter_selectivity,
            selectivity: self.selectivity * filter_selectivity,
            confidence: self.confidence * 0.9,
        }
    }
}
59
/// Estimated cost of executing a plan or plan fragment.
///
/// The constructors fold the resource components into `total` as
/// `cpu + io * 10 + memory * 0.1`; `network` is carried along but not weighted
/// in by them. `startup_cost` is the cost paid before the first row can be
/// produced, which matters when only a few rows are wanted (see
/// [`PlanCost::prefer_over`]).
#[derive(Debug, Clone, Default)]
pub struct PlanCost {
    /// CPU work units.
    pub cpu: f64,
    /// I/O units (weighted 10x in `total`).
    pub io: f64,
    /// Network units (not weighted into `total` by the constructors).
    pub network: f64,
    /// Memory units (weighted 0.1x in `total`); combinators keep the peak.
    pub memory: f64,
    /// Cost incurred before the first output row is available.
    pub startup_cost: f64,
    /// Overall scalar cost used to rank plans.
    pub total: f64,
}

impl PlanCost {
    /// Scalar cost of the given resources: I/O is weighted 10x and memory
    /// 0.1x relative to CPU.
    fn weighted_total(cpu: f64, io: f64, memory: f64) -> f64 {
        cpu + io * 10.0 + memory * 0.1
    }

    /// Sums cpu, io, network and total of two costs and keeps the peak memory;
    /// the caller decides the resulting startup cost.
    fn merged_with(&self, other: &PlanCost, startup_cost: f64) -> PlanCost {
        PlanCost {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            network: self.network + other.network,
            memory: self.memory.max(other.memory),
            startup_cost,
            total: self.total + other.total,
        }
    }

    /// Creates a fully streaming cost: zero startup and zero network.
    pub fn new(cpu: f64, io: f64, memory: f64) -> Self {
        PlanCost {
            cpu,
            io,
            network: 0.0,
            memory,
            startup_cost: 0.0,
            total: Self::weighted_total(cpu, io, memory),
        }
    }

    /// Creates a cost whose first row only becomes available after
    /// `startup_cost`.
    ///
    /// A negative `startup_cost` is recorded as 0.0, and `total` is raised to
    /// at least `startup_cost` so a plan never "finishes" before it starts.
    pub fn with_startup(cpu: f64, io: f64, memory: f64, startup_cost: f64) -> Self {
        let mut cost = Self::new(cpu, io, memory);
        cost.startup_cost = startup_cost.max(0.0);
        cost.total = cost.total.max(startup_cost);
        cost
    }

    /// Composes two operators that stream into each other.
    ///
    /// Resources and totals add up, memory is the peak of the two, and the
    /// startup costs simply add.
    pub fn combine_pipelined(&self, other: &PlanCost) -> PlanCost {
        self.merged_with(other, self.startup_cost + other.startup_cost)
    }

    /// Composes `self` (the input) with a blocking operator `blocker` that must
    /// consume all of its input before emitting anything.
    ///
    /// The whole input cost becomes part of the startup, followed by the
    /// blocker's own startup.
    pub fn combine_blocking(&self, blocker: &PlanCost) -> PlanCost {
        self.merged_with(blocker, self.total + blocker.startup_cost)
    }

    /// Alias for [`PlanCost::combine_pipelined`].
    pub fn combine(&self, other: &PlanCost) -> PlanCost {
        self.combine_pipelined(other)
    }

    /// Multiplies cpu, io, network and total by `factor`.
    ///
    /// `memory` and `startup_cost` are copied unchanged, so with `factor < 1`
    /// the result can report a `total` smaller than its `startup_cost`.
    pub fn scale(&self, factor: f64) -> PlanCost {
        PlanCost {
            cpu: self.cpu * factor,
            io: self.io * factor,
            network: self.network * factor,
            memory: self.memory,
            startup_cost: self.startup_cost,
            total: self.total * factor,
        }
    }

    /// Orders `self` against `other` for plan selection (`Less` = prefer `self`).
    ///
    /// When `limit` asks for fewer than 10% of the expected `cardinality` rows
    /// (cardinality is treated as at least 1), only the first rows matter, so
    /// plans are ranked by `startup_cost`; otherwise by `total`. Incomparable
    /// (NaN) costs compare as `Equal`.
    pub fn prefer_over(
        &self,
        other: &PlanCost,
        limit: Option<u64>,
        cardinality: f64,
    ) -> std::cmp::Ordering {
        let first_rows_only = limit.map_or(false, |k| (k as f64) < 0.1 * cardinality.max(1.0));
        let rank = |cost: &PlanCost| {
            if first_rows_only {
                cost.startup_cost
            } else {
                cost.total
            }
        };
        rank(self)
            .partial_cmp(&rank(other))
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}
197
/// Table-level statistics that a stats provider can report for a table.
#[derive(Debug, Clone, Default)]
pub struct TableStats {
    /// Number of rows in the table; when present it replaces the estimator's
    /// default row count.
    pub row_count: u64,
    /// Average row size (presumably bytes — TODO confirm). Not consulted by
    /// `CostEstimator`'s estimates.
    pub avg_row_size: u32,
    /// Number of heap pages; passed to the index I/O cost model.
    pub page_count: u64,
    /// Per-column statistics. Not consulted by `CostEstimator`'s estimates.
    pub columns: Vec<ColumnStats>,
}

/// Per-column statistics carried in [`TableStats::columns`].
#[derive(Debug, Clone, Default)]
pub struct ColumnStats {
    /// Column name.
    pub name: String,
    /// Number of distinct values (NOTE(review): confirm whether NULL counts).
    pub distinct_count: u64,
    /// Number of NULL entries in the column (presumably; verify with producers).
    pub null_count: u64,
    /// Smallest observed value rendered as a string, if known.
    pub min_value: Option<String>,
    /// Largest observed value rendered as a string, if known.
    pub max_value: Option<String>,
    /// Whether the column is indexed.
    pub has_index: bool,
}
227
/// Heuristic, optionally statistics-driven cost and cardinality estimator for
/// query plans.
///
/// The unit-cost fields are fixed by [`CostEstimator::new`]; table, column and
/// index statistics come from the pluggable `stats` provider.
pub struct CostEstimator {
    /// Row count assumed for a table the provider has no statistics for.
    default_row_count: f64,
    /// CPU cost charged per row produced by a table scan.
    row_scan_cost: f64,
    /// Per-lookup index cost. NOTE(review): not read by any estimate in this
    /// module — unused?
    index_lookup_cost: f64,
    /// CPU cost per row for building and probing the hash table in joins.
    hash_probe_cost: f64,
    /// Per-row nested-loop cost. NOTE(review): not read by any estimate in
    /// this module — unused?
    nested_loop_cost: f64,
    /// CPU cost per traversed edge / visited node in graph and path queries.
    edge_traversal_cost: f64,
    /// Source of table, column (MCV, histogram, distinct count) and index
    /// statistics.
    stats: Arc<dyn StatsProvider>,
}
248
249impl CostEstimator {
250 pub fn new() -> Self {
253 Self {
254 default_row_count: 1000.0,
255 row_scan_cost: 1.0,
256 index_lookup_cost: 0.1,
257 hash_probe_cost: 0.5,
258 nested_loop_cost: 2.0,
259 edge_traversal_cost: 1.5,
260 stats: Arc::new(NullProvider),
261 }
262 }
263
264 pub fn with_stats(provider: Arc<dyn StatsProvider>) -> Self {
268 Self {
269 stats: provider,
270 ..Self::new()
271 }
272 }
273
274 pub fn set_stats(&mut self, provider: Arc<dyn StatsProvider>) {
278 self.stats = provider;
279 }
280
    /// Estimates the execution cost of `query`.
    ///
    /// Read-path queries (table, graph, join, path, vector, hybrid) dispatch to
    /// a shape-specific estimator. Every other statement — DML, DDL, and
    /// administrative/utility commands — gets the same flat
    /// `PlanCost::new(1.0, 1.0, 0.0)`.
    ///
    /// The match deliberately lists every `QueryExpr` variant instead of using
    /// a `_` arm, so adding a variant forces an explicit costing decision here.
    pub fn estimate(&self, query: &QueryExpr) -> PlanCost {
        match query {
            QueryExpr::Table(tq) => self.estimate_table(tq),
            QueryExpr::Graph(gq) => self.estimate_graph(gq),
            QueryExpr::Join(jq) => self.estimate_join(jq),
            QueryExpr::Path(pq) => self.estimate_path(pq),
            QueryExpr::Vector(vq) => self.estimate_vector(vq),
            QueryExpr::Hybrid(hq) => self.estimate_hybrid(hq),
            // Not costed as data-producing plans: flat nominal cost.
            QueryExpr::Insert(_)
            | QueryExpr::Update(_)
            | QueryExpr::Delete(_)
            | QueryExpr::CreateTable(_)
            | QueryExpr::DropTable(_)
            | QueryExpr::DropGraph(_)
            | QueryExpr::DropVector(_)
            | QueryExpr::DropDocument(_)
            | QueryExpr::DropKv(_)
            | QueryExpr::DropCollection(_)
            | QueryExpr::Truncate(_)
            | QueryExpr::AlterTable(_)
            | QueryExpr::GraphCommand(_)
            | QueryExpr::SearchCommand(_)
            | QueryExpr::CreateIndex(_)
            | QueryExpr::DropIndex(_)
            | QueryExpr::ProbabilisticCommand(_)
            | QueryExpr::Ask(_)
            | QueryExpr::SetConfig { .. }
            | QueryExpr::ShowConfig { .. }
            | QueryExpr::SetSecret { .. }
            | QueryExpr::DeleteSecret { .. }
            | QueryExpr::ShowSecrets { .. }
            | QueryExpr::SetTenant(_)
            | QueryExpr::ShowTenant
            | QueryExpr::CreateTimeSeries(_)
            | QueryExpr::DropTimeSeries(_)
            | QueryExpr::CreateQueue(_)
            | QueryExpr::AlterQueue(_)
            | QueryExpr::DropQueue(_)
            | QueryExpr::QueueSelect(_)
            | QueryExpr::QueueCommand(_)
            | QueryExpr::KvCommand(_)
            | QueryExpr::ConfigCommand(_)
            | QueryExpr::CreateTree(_)
            | QueryExpr::DropTree(_)
            | QueryExpr::TreeCommand(_)
            | QueryExpr::ExplainAlter(_)
            | QueryExpr::TransactionControl(_)
            | QueryExpr::MaintenanceCommand(_)
            | QueryExpr::CreateSchema(_)
            | QueryExpr::DropSchema(_)
            | QueryExpr::CreateSequence(_)
            | QueryExpr::DropSequence(_)
            | QueryExpr::CopyFrom(_)
            | QueryExpr::CreateView(_)
            | QueryExpr::DropView(_)
            | QueryExpr::RefreshMaterializedView(_)
            | QueryExpr::CreatePolicy(_)
            | QueryExpr::DropPolicy(_)
            | QueryExpr::CreateServer(_)
            | QueryExpr::DropServer(_)
            | QueryExpr::CreateForeignTable(_)
            | QueryExpr::DropForeignTable(_)
            | QueryExpr::Grant(_)
            | QueryExpr::Revoke(_)
            | QueryExpr::AlterUser(_)
            | QueryExpr::CreateIamPolicy { .. }
            | QueryExpr::DropIamPolicy { .. }
            | QueryExpr::AttachPolicy { .. }
            | QueryExpr::DetachPolicy { .. }
            | QueryExpr::ShowPolicies { .. }
            | QueryExpr::ShowEffectivePermissions { .. }
            | QueryExpr::SimulatePolicy { .. }
            | QueryExpr::CreateMigration(_)
            | QueryExpr::ApplyMigration(_)
            | QueryExpr::RollbackMigration(_)
            | QueryExpr::ExplainMigration(_)
            | QueryExpr::EventsBackfill(_)
            | QueryExpr::EventsBackfillStatus { .. } => PlanCost::new(1.0, 1.0, 0.0),
        }
    }
363
    /// Estimates how many rows `query` produces.
    ///
    /// Read-path queries dispatch to a shape-specific estimator; every other
    /// statement is estimated as a single row with selectivity 1.0.
    ///
    /// As in [`CostEstimator::estimate`], every `QueryExpr` variant is listed
    /// explicitly (no `_` arm) so new variants must be classified here.
    pub fn estimate_cardinality(&self, query: &QueryExpr) -> CardinalityEstimate {
        match query {
            QueryExpr::Table(tq) => self.estimate_table_cardinality(tq),
            QueryExpr::Graph(gq) => self.estimate_graph_cardinality(gq),
            QueryExpr::Join(jq) => self.estimate_join_cardinality(jq),
            QueryExpr::Path(pq) => self.estimate_path_cardinality(pq),
            QueryExpr::Vector(vq) => self.estimate_vector_cardinality(vq),
            QueryExpr::Hybrid(hq) => self.estimate_hybrid_cardinality(hq),
            // Not row-producing read plans: nominal single-row estimate.
            QueryExpr::Insert(_)
            | QueryExpr::Update(_)
            | QueryExpr::Delete(_)
            | QueryExpr::CreateTable(_)
            | QueryExpr::DropTable(_)
            | QueryExpr::DropGraph(_)
            | QueryExpr::DropVector(_)
            | QueryExpr::DropDocument(_)
            | QueryExpr::DropKv(_)
            | QueryExpr::DropCollection(_)
            | QueryExpr::Truncate(_)
            | QueryExpr::AlterTable(_)
            | QueryExpr::GraphCommand(_)
            | QueryExpr::SearchCommand(_)
            | QueryExpr::CreateIndex(_)
            | QueryExpr::DropIndex(_)
            | QueryExpr::ProbabilisticCommand(_)
            | QueryExpr::Ask(_)
            | QueryExpr::SetConfig { .. }
            | QueryExpr::ShowConfig { .. }
            | QueryExpr::SetSecret { .. }
            | QueryExpr::DeleteSecret { .. }
            | QueryExpr::ShowSecrets { .. }
            | QueryExpr::SetTenant(_)
            | QueryExpr::ShowTenant
            | QueryExpr::CreateTimeSeries(_)
            | QueryExpr::DropTimeSeries(_)
            | QueryExpr::CreateQueue(_)
            | QueryExpr::AlterQueue(_)
            | QueryExpr::DropQueue(_)
            | QueryExpr::QueueSelect(_)
            | QueryExpr::QueueCommand(_)
            | QueryExpr::KvCommand(_)
            | QueryExpr::ConfigCommand(_)
            | QueryExpr::CreateTree(_)
            | QueryExpr::DropTree(_)
            | QueryExpr::TreeCommand(_)
            | QueryExpr::ExplainAlter(_)
            | QueryExpr::TransactionControl(_)
            | QueryExpr::MaintenanceCommand(_)
            | QueryExpr::CreateSchema(_)
            | QueryExpr::DropSchema(_)
            | QueryExpr::CreateSequence(_)
            | QueryExpr::DropSequence(_)
            | QueryExpr::CopyFrom(_)
            | QueryExpr::CreateView(_)
            | QueryExpr::DropView(_)
            | QueryExpr::RefreshMaterializedView(_)
            | QueryExpr::CreatePolicy(_)
            | QueryExpr::DropPolicy(_)
            | QueryExpr::CreateServer(_)
            | QueryExpr::DropServer(_)
            | QueryExpr::CreateForeignTable(_)
            | QueryExpr::DropForeignTable(_)
            | QueryExpr::Grant(_)
            | QueryExpr::Revoke(_)
            | QueryExpr::AlterUser(_)
            | QueryExpr::CreateIamPolicy { .. }
            | QueryExpr::DropIamPolicy { .. }
            | QueryExpr::AttachPolicy { .. }
            | QueryExpr::DetachPolicy { .. }
            | QueryExpr::ShowPolicies { .. }
            | QueryExpr::ShowEffectivePermissions { .. }
            | QueryExpr::SimulatePolicy { .. }
            | QueryExpr::CreateMigration(_)
            | QueryExpr::ApplyMigration(_)
            | QueryExpr::RollbackMigration(_)
            | QueryExpr::ExplainMigration(_)
            | QueryExpr::EventsBackfill(_)
            | QueryExpr::EventsBackfillStatus { .. } => CardinalityEstimate::new(1.0, 1.0),
        }
    }
446
447 fn estimate_table(&self, query: &TableQuery) -> PlanCost {
452 let cardinality = self.estimate_table_cardinality(query);
453
454 let cpu = cardinality.rows * self.row_scan_cost;
455
456 let io = self.estimate_table_io(query, cardinality.rows);
459
460 let memory = cardinality.rows * 100.0; PlanCost::new(cpu, io, memory)
463 }
464
465 fn estimate_table_io(&self, query: &TableQuery, result_rows: f64) -> f64 {
472 const ROWS_PER_PAGE: f64 = 100.0;
473
474 let table_stats = self.stats.table_stats(&query.table);
476 let heap_pages = table_stats
477 .map(|s| s.page_count as f64)
478 .unwrap_or_else(|| (result_rows / ROWS_PER_PAGE).max(1.0));
479
480 if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
483 if let Some(col) = first_filter_column(&filter, &query.table) {
484 if let Some(idx) = self.stats.index_stats(&query.table, col) {
485 return idx.correlated_io_cost(result_rows, heap_pages);
486 }
487 }
488 }
489
490 (result_rows / ROWS_PER_PAGE).ceil()
492 }
493
494 fn estimate_table_cardinality(&self, query: &TableQuery) -> CardinalityEstimate {
495 let base_rows = self
498 .stats
499 .table_stats(&query.table)
500 .map(|s| s.row_count as f64)
501 .unwrap_or(self.default_row_count);
502
503 let mut estimate = CardinalityEstimate::full_scan(base_rows);
504
505 if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
508 let selectivity = self.filter_selectivity(&filter, &query.table);
509 estimate = estimate.with_filter(selectivity);
510 }
511
512 if let Some(limit) = query.limit {
514 estimate.rows = estimate.rows.min(limit as f64);
515 }
516
517 estimate
518 }
519
520 fn filter_selectivity(&self, filter: &AstFilter, table: &str) -> f64 {
533 match filter {
534 AstFilter::Compare { field, op, value } => {
535 let column = column_name_for_table(field, table);
536 match op {
537 CompareOp::Eq => self.eq_selectivity(table, column, value),
538 CompareOp::Ne => 1.0 - self.eq_selectivity(table, column, value),
539 CompareOp::Lt | CompareOp::Le => {
540 self.range_selectivity(table, column, None, Some(value))
541 }
542 CompareOp::Gt | CompareOp::Ge => {
543 self.range_selectivity(table, column, Some(value), None)
544 }
545 }
546 }
547 AstFilter::Between {
548 field, low, high, ..
549 } => {
550 let column = column_name_for_table(field, table);
551 self.range_selectivity(table, column, Some(low), Some(high))
552 }
553 AstFilter::In { field, values, .. } => {
554 let column = column_name_for_table(field, table);
555 if let Some(c) = column {
559 if let Some(mcv) = self.stats.column_mcv(table, c) {
560 let mut hits: f64 = 0.0;
561 let mut residual_count = 0usize;
562 for v in values {
563 if let Some(cv) = column_value_from(v) {
564 if let Some(freq) = mcv.frequency_of(&cv) {
565 hits += freq;
566 } else {
567 residual_count += 1;
568 }
569 } else {
570 residual_count += 1;
571 }
572 }
573 let total = mcv.total_frequency();
574 let distinct = self.stats.distinct_values(table, c).unwrap_or(100);
575 let non_mcv_distinct =
576 distinct.saturating_sub(mcv.values.len() as u64).max(1);
577 let per_residual = (1.0 - total) / non_mcv_distinct as f64;
578 let estimate = hits + (residual_count as f64) * per_residual;
579 return estimate.clamp(0.0, 1.0).min(0.5);
580 }
581 if let Some(s) = self.stats.index_stats(table, c) {
582 return (s.point_selectivity() * values.len() as f64).min(0.5);
583 }
584 }
585 (values.len() as f64 * 0.01).min(0.5)
586 }
587 AstFilter::Like { .. } => 0.1,
588 AstFilter::StartsWith { .. } => 0.15,
589 AstFilter::EndsWith { .. } => 0.15,
590 AstFilter::Contains { .. } => 0.1,
591 AstFilter::IsNull { .. } => 0.01,
592 AstFilter::IsNotNull { .. } => 0.99,
593 AstFilter::And(left, right) => {
594 self.filter_selectivity(left, table) * self.filter_selectivity(right, table)
595 }
596 AstFilter::Or(left, right) => {
597 let s1 = self.filter_selectivity(left, table);
598 let s2 = self.filter_selectivity(right, table);
599 s1 + s2 - (s1 * s2)
600 }
601 AstFilter::Not(inner) => 1.0 - self.filter_selectivity(inner, table),
602 AstFilter::CompareFields { .. } => {
603 0.1
607 }
608 AstFilter::CompareExpr { .. } => {
609 0.1
613 }
614 }
615 }
616
617 fn estimate_graph(&self, query: &GraphQuery) -> PlanCost {
622 let cardinality = self.estimate_graph_cardinality(query);
623
624 let nodes = query.pattern.nodes.len() as f64;
626 let edges = query.pattern.edges.len() as f64;
627
628 let cpu = cardinality.rows * self.edge_traversal_cost * (nodes + edges);
629 let io = cardinality.rows * 0.1; let memory = cardinality.rows * 200.0; PlanCost::new(cpu, io, memory)
633 }
634
635 fn estimate_graph_cardinality(&self, query: &GraphQuery) -> CardinalityEstimate {
636 let nodes = query.pattern.nodes.len() as f64;
637 let edges = query.pattern.edges.len() as f64;
638
639 let base_rows = self.default_row_count;
641 let edge_factor = 0.1_f64.powf(edges); let mut estimate = CardinalityEstimate::new(base_rows * nodes * edge_factor, edge_factor);
644 estimate.confidence = 0.5; if let Some(ref filter) = query.filter {
648 let selectivity = Self::estimate_filter_selectivity(filter);
649 estimate = estimate.with_filter(selectivity);
650 }
651
652 estimate
653 }
654
655 fn estimate_join(&self, query: &JoinQuery) -> PlanCost {
660 let left_cost = self.estimate(&query.left);
661 let right_cost = self.estimate(&query.right);
662
663 let left_card = self.estimate_cardinality(&query.left);
664 let right_card = self.estimate_cardinality(&query.right);
665
666 let build_cpu = left_card.rows * self.hash_probe_cost;
673 let probe_cpu = right_card.rows * self.hash_probe_cost;
674 let join_memory = left_card.rows * 100.0; let build_op = PlanCost::with_startup(build_cpu, 0.0, join_memory, build_cpu);
678 let probe_op = PlanCost::new(probe_cpu, 0.0, 0.0);
680
681 let after_build = left_cost.combine_blocking(&build_op);
683 after_build
684 .combine_pipelined(&right_cost)
685 .combine_pipelined(&probe_op)
686 }
687
688 fn estimate_join_cardinality(&self, query: &JoinQuery) -> CardinalityEstimate {
689 let left = self.estimate_cardinality(&query.left);
690 let right = self.estimate_cardinality(&query.right);
691
692 let selectivity = match query.join_type {
694 JoinType::Inner => 0.1, JoinType::LeftOuter => 1.0, JoinType::RightOuter => 1.0, JoinType::FullOuter => 1.0, JoinType::Cross => 1.0, };
700
701 CardinalityEstimate::new(
702 left.rows * right.rows * selectivity,
703 left.selectivity * right.selectivity * selectivity,
704 )
705 }
706
707 fn estimate_path(&self, query: &PathQuery) -> PlanCost {
712 let cardinality = self.estimate_path_cardinality(query);
713
714 let max_hops = query.max_length;
716 let branching_factor: f64 = 5.0; let nodes_visited = branching_factor.powf(max_hops as f64).min(10000.0);
719 let cpu = nodes_visited * self.edge_traversal_cost;
720 let io = nodes_visited * 0.1;
721 let memory = nodes_visited * 50.0; PlanCost::new(cpu, io, memory)
724 }
725
726 fn estimate_path_cardinality(&self, query: &PathQuery) -> CardinalityEstimate {
727 let max_paths = 10.0;
729 CardinalityEstimate::new(max_paths, 0.001)
730 }
731
732 fn estimate_vector(&self, query: &VectorQuery) -> PlanCost {
737 let k = query.k as f64;
740
741 let hnsw_cost = 100.0 * (1.0 + k.ln()); let filter_cost =
748 if crate::storage::query::sql_lowering::effective_vector_filter(query).is_some() {
749 50.0
750 } else {
751 0.0
752 };
753
754 let cpu = hnsw_cost + filter_cost;
755 let io = 20.0; let memory = k * 32.0 + 1000.0; PlanCost::with_startup(cpu, io, memory, hnsw_cost * 0.5)
763 }
764
765 fn estimate_vector_cardinality(&self, query: &VectorQuery) -> CardinalityEstimate {
766 let k = query.k as f64;
768 CardinalityEstimate::new(k, 0.1)
769 }
770
771 fn estimate_hybrid(&self, query: &HybridQuery) -> PlanCost {
776 let structured_cost = self.estimate(&query.structured);
778 let vector_cost = self.estimate_vector(&query.vector);
779
780 let fusion_overhead = match &query.fusion {
782 crate::storage::query::ast::FusionStrategy::Rerank { .. } => 50.0,
783 crate::storage::query::ast::FusionStrategy::FilterThenSearch => 10.0,
784 crate::storage::query::ast::FusionStrategy::SearchThenFilter => 10.0,
785 crate::storage::query::ast::FusionStrategy::RRF { .. } => 30.0,
786 crate::storage::query::ast::FusionStrategy::Intersection => 20.0,
787 crate::storage::query::ast::FusionStrategy::Union { .. } => 40.0,
788 };
789
790 PlanCost::new(
791 structured_cost.cpu + vector_cost.cpu + fusion_overhead,
792 structured_cost.io + vector_cost.io,
793 structured_cost.memory + vector_cost.memory,
794 )
795 }
796
797 fn estimate_hybrid_cardinality(&self, query: &HybridQuery) -> CardinalityEstimate {
798 let structured_card = self.estimate_cardinality(&query.structured);
799 let vector_card = self.estimate_vector_cardinality(&query.vector);
800
801 let rows = match &query.fusion {
803 crate::storage::query::ast::FusionStrategy::Intersection => {
804 structured_card.rows.min(vector_card.rows)
805 }
806 crate::storage::query::ast::FusionStrategy::Union { .. } => {
807 structured_card.rows + vector_card.rows
808 }
809 _ => vector_card.rows, };
811
812 CardinalityEstimate::new(rows, 0.2)
813 }
814
815 fn estimate_filter_selectivity(filter: &AstFilter) -> f64 {
820 match filter {
821 AstFilter::Compare { op, .. } => {
822 match op {
823 CompareOp::Eq => 0.01, CompareOp::Ne => 0.99, CompareOp::Lt | CompareOp::Le => 0.3,
826 CompareOp::Gt | CompareOp::Ge => 0.3,
827 }
828 }
829 AstFilter::Between { .. } => 0.25,
830 AstFilter::In { values, .. } => {
831 (values.len() as f64 * 0.01).min(0.5)
833 }
834 AstFilter::Like { .. } => 0.1,
835 AstFilter::StartsWith { .. } => 0.15,
836 AstFilter::EndsWith { .. } => 0.15,
837 AstFilter::Contains { .. } => 0.1,
838 AstFilter::IsNull { .. } => 0.01,
839 AstFilter::IsNotNull { .. } => 0.99,
840 AstFilter::And(left, right) => {
841 Self::estimate_filter_selectivity(left) * Self::estimate_filter_selectivity(right)
842 }
843 AstFilter::Or(left, right) => {
844 let s1 = Self::estimate_filter_selectivity(left);
845 let s2 = Self::estimate_filter_selectivity(right);
846 s1 + s2 - (s1 * s2) }
848 AstFilter::Not(inner) => 1.0 - Self::estimate_filter_selectivity(inner),
849 AstFilter::CompareFields { .. } => 0.1,
850 AstFilter::CompareExpr { .. } => 0.1,
851 }
852 }
853}
854
855impl CostEstimator {
856 fn eq_selectivity(&self, table: &str, column: Option<&str>, value: &Value) -> f64 {
864 if let Some(col) = column {
865 if let Some(mcv) = self.stats.column_mcv(table, col) {
867 if let Some(cv) = column_value_from(value) {
868 if let Some(freq) = mcv.frequency_of(&cv) {
869 return freq;
870 }
871 let total = mcv.total_frequency();
873 let distinct = self.stats.distinct_values(table, col).unwrap_or(100);
874 let non_mcv_distinct = distinct.saturating_sub(mcv.values.len() as u64).max(1);
875 return ((1.0 - total) / non_mcv_distinct as f64).clamp(0.0, 1.0);
876 }
877 }
878 if let Some(s) = self.stats.index_stats(table, col) {
880 return s.point_selectivity();
881 }
882 }
883 0.01
885 }
886
887 fn range_selectivity(
898 &self,
899 table: &str,
900 column: Option<&str>,
901 lo: Option<&Value>,
902 hi: Option<&Value>,
903 ) -> f64 {
904 if let Some(col) = column {
905 if let Some(h) = self.stats.column_histogram(table, col) {
907 let lo_cv = lo.and_then(column_value_from);
908 let hi_cv = hi.and_then(column_value_from);
909 return h.range_selectivity(lo_cv.as_ref(), hi_cv.as_ref());
910 }
911 if let Some(s) = self.stats.index_stats(table, col) {
913 let cap = if lo.is_some() && hi.is_some() {
914 0.25
915 } else {
916 0.3
917 };
918 return (s.point_selectivity() * (s.distinct_keys as f64 / 2.0)).min(cap);
919 }
920 }
921 if lo.is_some() && hi.is_some() {
923 0.25
924 } else {
925 0.3
926 }
927 }
928}
929
930impl Default for CostEstimator {
931 fn default() -> Self {
932 Self::new()
933 }
934}
935
936fn column_value_from(v: &crate::storage::schema::Value) -> Option<super::histogram::ColumnValue> {
941 use super::histogram::ColumnValue;
942 use crate::storage::schema::Value;
943 match v {
944 Value::Integer(i) | Value::BigInt(i) => Some(ColumnValue::Int(*i)),
945 Value::UnsignedInteger(u) => Some(ColumnValue::Int(*u as i64)),
946 Value::Float(f) if f.is_finite() => Some(ColumnValue::Float(*f)),
947 Value::Text(s) => Some(ColumnValue::Text(s.to_string())),
948 Value::Email(s)
949 | Value::Url(s)
950 | Value::NodeRef(s)
951 | Value::EdgeRef(s)
952 | Value::TableRef(s)
953 | Value::Password(s) => Some(ColumnValue::Text(s.clone())),
954 Value::Timestamp(t) => Some(ColumnValue::Int(*t)),
955 Value::Duration(d) => Some(ColumnValue::Int(*d)),
956 Value::TimestampMs(t) => Some(ColumnValue::Int(*t)),
957 Value::Decimal(d) => Some(ColumnValue::Int(*d)),
958 Value::Date(d) => Some(ColumnValue::Int(i64::from(*d))),
959 Value::Time(t) => Some(ColumnValue::Int(i64::from(*t))),
960 Value::Phone(p) => Some(ColumnValue::Int(*p as i64)),
961 Value::Semver(v) => Some(ColumnValue::Int(i64::from(*v))),
962 Value::Port(v) => Some(ColumnValue::Int(i64::from(*v))),
963 Value::PageRef(v) => Some(ColumnValue::Int(i64::from(*v))),
964 Value::EnumValue(v) => Some(ColumnValue::Int(i64::from(*v))),
965 Value::Latitude(v) => Some(ColumnValue::Int(i64::from(*v))),
966 Value::Longitude(v) => Some(ColumnValue::Int(i64::from(*v))),
967 _ => None,
972 }
973}
974
975fn first_filter_column<'a>(filter: &'a AstFilter, table: &str) -> Option<&'a str> {
980 match filter {
981 AstFilter::Compare { field, .. } => column_name_for_table(field, table),
982 AstFilter::Between { field, .. } => column_name_for_table(field, table),
983 AstFilter::And(l, r) => {
984 first_filter_column(l, table).or_else(|| first_filter_column(r, table))
985 }
986 _ => None,
987 }
988}
989
990fn column_name_for_table<'a>(field: &'a FieldRef, table: &str) -> Option<&'a str> {
992 match field {
993 FieldRef::TableColumn { table: t, column } if t == table || t.is_empty() => {
994 Some(column.as_str())
995 }
996 _ => None,
998 }
999}
1000
1001#[cfg(test)]
1002mod tests {
1003 use super::super::stats_provider::StaticProvider;
1004 use super::*;
1005 use crate::storage::index::{IndexKind, IndexStats};
1006 use crate::storage::query::ast::{FieldRef, Projection};
1007 use crate::storage::schema::Value;
1008
1009 fn eq_filter(table: &str, column: &str, value: i64) -> AstFilter {
1010 AstFilter::Compare {
1011 field: FieldRef::column(table, column),
1012 op: CompareOp::Eq,
1013 value: Value::Integer(value),
1014 }
1015 }
1016
1017 fn table_query(name: &str, filter: Option<AstFilter>) -> TableQuery {
1018 TableQuery {
1019 table: name.to_string(),
1020 source: None,
1021 alias: None,
1022 select_items: Vec::new(),
1023 columns: vec![Projection::All],
1024 where_expr: None,
1025 filter,
1026 group_by_exprs: Vec::new(),
1027 group_by: Vec::new(),
1028 having_expr: None,
1029 having: None,
1030 order_by: vec![],
1031 limit: None,
1032 offset: None,
1033 expand: None,
1034 as_of: None,
1035 }
1036 }
1037
1038 #[test]
1039 fn injected_row_count_overrides_default() {
1040 let provider = Arc::new(StaticProvider::new().with_table(
1041 "users",
1042 TableStats {
1043 row_count: 50_000,
1044 avg_row_size: 256,
1045 page_count: 500,
1046 columns: vec![],
1047 },
1048 ));
1049 let estimator = CostEstimator::with_stats(provider);
1050 let q = table_query("users", None);
1051 let card = estimator.estimate_table_cardinality(&q);
1052 assert_eq!(card.rows, 50_000.0);
1054 }
1055
1056 #[test]
1057 fn stats_aware_eq_selectivity_beats_default() {
1058 let provider = Arc::new(
1059 StaticProvider::new()
1060 .with_table(
1061 "users",
1062 TableStats {
1063 row_count: 1_000_000,
1064 avg_row_size: 256,
1065 page_count: 10_000,
1066 columns: vec![],
1067 },
1068 )
1069 .with_index(
1070 "users",
1071 "email",
1072 IndexStats {
1073 entries: 1_000_000,
1074 distinct_keys: 1_000_000,
1075 approx_bytes: 0,
1076 kind: IndexKind::Hash,
1077 has_bloom: true,
1078 index_correlation: 0.0,
1079 },
1080 ),
1081 );
1082 let estimator = CostEstimator::with_stats(provider);
1083 let q = table_query("users", Some(eq_filter("users", "email", 0)));
1084 let card = estimator.estimate_table_cardinality(&q);
1085 assert!(card.rows < 2.0, "expected ~1 row, got {}", card.rows);
1087 }
1088
1089 #[test]
1090 fn fallback_when_no_index_stats() {
1091 let provider = Arc::new(StaticProvider::new().with_table(
1092 "users",
1093 TableStats {
1094 row_count: 1_000_000,
1095 avg_row_size: 256,
1096 page_count: 10_000,
1097 columns: vec![],
1098 },
1099 ));
1100 let estimator = CostEstimator::with_stats(provider);
1101 let q = table_query("users", Some(eq_filter("users", "email", 0)));
1102 let card = estimator.estimate_table_cardinality(&q);
1103 assert!((card.rows - 10_000.0).abs() < 1.0);
1105 }
1106
1107 #[test]
1108 fn null_provider_keeps_legacy_behaviour() {
1109 let estimator = CostEstimator::new();
1110 let q = table_query("whatever", Some(eq_filter("whatever", "id", 1)));
1111 let card = estimator.estimate_table_cardinality(&q);
1112 assert!((card.rows - 10.0).abs() < 1.0);
1114 }
1115
1116 #[test]
1117 fn and_combines_stats_selectivities() {
1118 let provider = Arc::new(
1119 StaticProvider::new()
1120 .with_table(
1121 "t",
1122 TableStats {
1123 row_count: 100_000,
1124 avg_row_size: 64,
1125 page_count: 100,
1126 columns: vec![],
1127 },
1128 )
1129 .with_index(
1130 "t",
1131 "a",
1132 IndexStats {
1133 entries: 100_000,
1134 distinct_keys: 10,
1135 approx_bytes: 0,
1136 kind: IndexKind::BTree,
1137 has_bloom: false,
1138 index_correlation: 0.0,
1139 },
1140 )
1141 .with_index(
1142 "t",
1143 "b",
1144 IndexStats {
1145 entries: 100_000,
1146 distinct_keys: 1000,
1147 approx_bytes: 0,
1148 kind: IndexKind::BTree,
1149 has_bloom: false,
1150 index_correlation: 0.0,
1151 },
1152 ),
1153 );
1154 let estimator = CostEstimator::with_stats(provider);
1155 let filter = AstFilter::And(
1156 Box::new(eq_filter("t", "a", 1)),
1157 Box::new(eq_filter("t", "b", 1)),
1158 );
1159 let q = table_query("t", Some(filter));
1160 let card = estimator.estimate_table_cardinality(&q);
1161 assert!(card.rows < 15.0, "got {}", card.rows);
1163 }
1164
1165 #[test]
1166 fn test_table_cost_estimation() {
1167 let estimator = CostEstimator::new();
1168
1169 let query = QueryExpr::Table(TableQuery {
1170 table: "hosts".to_string(),
1171 source: None,
1172 alias: None,
1173 select_items: Vec::new(),
1174 columns: vec![Projection::All],
1175 where_expr: None,
1176 filter: None,
1177 group_by_exprs: Vec::new(),
1178 group_by: Vec::new(),
1179 having_expr: None,
1180 having: None,
1181 order_by: vec![],
1182 limit: None,
1183 offset: None,
1184 expand: None,
1185 as_of: None,
1186 });
1187
1188 let cost = estimator.estimate(&query);
1189 assert!(cost.cpu > 0.0);
1190 assert!(cost.total > 0.0);
1191 }
1192
1193 #[test]
1194 fn test_filter_selectivity() {
1195 let estimator = CostEstimator::new();
1196
1197 let eq_filter = AstFilter::Compare {
1198 field: FieldRef::column("hosts", "id"),
1199 op: CompareOp::Eq,
1200 value: Value::Integer(1),
1201 };
1202 assert!(CostEstimator::estimate_filter_selectivity(&eq_filter) < 0.1);
1203
1204 let ne_filter = AstFilter::Compare {
1205 field: FieldRef::column("hosts", "id"),
1206 op: CompareOp::Ne,
1207 value: Value::Integer(1),
1208 };
1209 assert!(CostEstimator::estimate_filter_selectivity(&ne_filter) > 0.9);
1210 }
1211
1212 #[test]
1213 fn test_and_selectivity() {
1214 let estimator = CostEstimator::new();
1215
1216 let and_filter = AstFilter::And(
1217 Box::new(AstFilter::Compare {
1218 field: FieldRef::column("hosts", "a"),
1219 op: CompareOp::Eq,
1220 value: Value::Integer(1),
1221 }),
1222 Box::new(AstFilter::Compare {
1223 field: FieldRef::column("hosts", "b"),
1224 op: CompareOp::Eq,
1225 value: Value::Integer(2),
1226 }),
1227 );
1228
1229 let selectivity = CostEstimator::estimate_filter_selectivity(&and_filter);
1230 assert!(selectivity < 0.01); }
1232
1233 #[test]
1234 fn test_cardinality_with_limit() {
1235 let estimator = CostEstimator::new();
1236
1237 let query = TableQuery {
1238 table: "hosts".to_string(),
1239 source: None,
1240 alias: None,
1241 select_items: Vec::new(),
1242 columns: vec![Projection::All],
1243 where_expr: None,
1244 filter: None,
1245 group_by_exprs: Vec::new(),
1246 group_by: Vec::new(),
1247 having_expr: None,
1248 having: None,
1249 order_by: vec![],
1250 limit: Some(10),
1251 offset: None,
1252 expand: None,
1253 as_of: None,
1254 };
1255
1256 let card = estimator.estimate_table_cardinality(&query);
1257 assert!(card.rows <= 10.0);
1258 }
1259
1260 #[test]
1265 fn startup_zero_for_full_scan() {
1266 let estimator = CostEstimator::new();
1270 let q = table_query("any_table", None);
1271 let cost = estimator.estimate(&QueryExpr::Table(q));
1272 assert_eq!(cost.startup_cost, 0.0, "full scan must have zero startup");
1273 assert!(cost.total > 0.0);
1274 }
1275
1276 #[test]
1277 fn startup_nonzero_for_blocking_combine() {
1278 let input = PlanCost::new(100.0, 10.0, 50.0); let blocker = PlanCost::new(20.0, 0.0, 10.0); let composed = input.combine_blocking(&blocker);
1283 assert_eq!(composed.startup_cost, input.total);
1285 assert_eq!(composed.total, input.total + blocker.total);
1287 assert!(composed.startup_cost > 0.0);
1288 }
1289
1290 #[test]
1291 fn pipelined_combine_adds_startup_directly() {
1292 let upstream = PlanCost::with_startup(50.0, 5.0, 10.0, 30.0);
1293 let downstream = PlanCost::with_startup(20.0, 0.0, 0.0, 5.0);
1294 let composed = upstream.combine_pipelined(&downstream);
1295 assert_eq!(composed.startup_cost, 30.0 + 5.0);
1296 assert_eq!(composed.total, upstream.total + downstream.total);
1297 }
1298
1299 #[test]
1300 fn cost_prefers_low_startup_when_limit_small() {
1301 let fast_first = PlanCost {
1304 cpu: 100.0,
1305 io: 10.0,
1306 network: 0.0,
1307 memory: 50.0,
1308 startup_cost: 5.0,
1309 total: 200.0,
1310 };
1311 let slow_first = PlanCost {
1312 cpu: 100.0,
1313 io: 10.0,
1314 network: 0.0,
1315 memory: 50.0,
1316 startup_cost: 150.0,
1317 total: 200.0,
1318 };
1319 assert_eq!(
1321 fast_first.prefer_over(&slow_first, Some(10), 10_000.0),
1322 std::cmp::Ordering::Less
1323 );
1324 }
1325
/// Without a LIMIT the whole result is consumed, so the lower total wins even
/// though its startup cost is higher.
#[test]
fn cost_prefers_low_total_when_no_limit() {
    // `network` and `memory` are zero in both plans; `Default` supplies 0.0.
    let low_total = PlanCost {
        cpu: 50.0,
        io: 5.0,
        startup_cost: 30.0,
        total: 100.0,
        ..PlanCost::default()
    };
    let high_total = PlanCost {
        cpu: 100.0,
        io: 10.0,
        startup_cost: 5.0,
        total: 200.0,
        ..PlanCost::default()
    };

    let verdict = low_total.prefer_over(&high_total, None, 10_000.0);
    assert_eq!(verdict, std::cmp::Ordering::Less);
}
1350
/// A LIMIT large relative to the estimated row count (5000 of 10 000) should be
/// treated like no LIMIT: the plan with the lower total is preferred.
#[test]
fn limit_threshold_falls_back_to_total_when_limit_large() {
    // `network` and `memory` are zero in both plans; `Default` supplies 0.0.
    let low_total = PlanCost {
        cpu: 50.0,
        io: 5.0,
        startup_cost: 30.0,
        total: 100.0,
        ..PlanCost::default()
    };
    let low_startup = PlanCost {
        cpu: 100.0,
        io: 10.0,
        startup_cost: 5.0,
        total: 200.0,
        ..PlanCost::default()
    };

    let verdict = low_total.prefer_over(&low_startup, Some(5000), 10_000.0);
    assert_eq!(verdict, std::cmp::Ordering::Less);
}
1375
/// Models a hash join's build phase as a blocking combine: the startup cost
/// after the build must be at least the left input's total, and the total must
/// never be below the startup cost.
#[test]
fn hash_join_startup_includes_build_cost() {
    let left = PlanCost::new(80.0, 8.0, 100.0);
    let build = PlanCost::with_startup(50.0, 0.0, 200.0, 50.0);
    let after_build = left.combine_blocking(&build);

    let absorbs_left = after_build.startup_cost >= left.total;
    assert!(
        absorbs_left,
        "after-build startup ({}) must absorb left.total ({})",
        after_build.startup_cost,
        left.total
    );
    assert!(after_build.total >= after_build.startup_cost);
}
1392
/// Checks that a cost shaped like a vector search (positive startup, larger
/// total) keeps `0 < startup_cost < total` through `PlanCost::with_startup`.
///
/// NOTE(review): this builds the `PlanCost` by hand and does not go through
/// `CostEstimator`. The inputs presumably mirror what the estimator produces
/// for a vector query — confirm, or replace with a real `QueryExpr::Vector`
/// estimate.
#[test]
fn vector_search_reports_nonzero_startup() {
    // total = 150 + 20 * 10 + 1320 * 0.1 = 482, well above the 50 startup.
    let v = PlanCost::with_startup(150.0, 20.0, 1320.0, 50.0);
    assert!(v.startup_cost > 0.0);
    assert!(v.startup_cost < v.total);
}
1406
/// When the CPU/IO/memory total (1.0 here) is below the requested startup
/// (100.0), `with_startup` must raise the total so it is never below startup.
#[test]
fn with_startup_clamps_total_below_startup() {
    let clamped = PlanCost::with_startup(1.0, 0.0, 0.0, 100.0);
    let total_covers_startup = clamped.total >= clamped.startup_cost;
    assert!(total_covers_startup);
}
1413
/// A defaulted `PlanCost` has both zero startup and zero total.
#[test]
fn default_plancost_has_zero_startup() {
    let PlanCost {
        startup_cost, total, ..
    } = PlanCost::default();
    assert_eq!(startup_cost, 0.0);
    assert_eq!(total, 0.0);
}
1420
1421 use super::super::histogram::{ColumnValue, Histogram, MostCommonValues};
1426
1427 fn provider_with_skew() -> Arc<StaticProvider> {
1428 let mut sample: Vec<ColumnValue> = Vec::new();
1432 for i in 0..80 {
1433 sample.push(ColumnValue::Int(i % 10));
1434 }
1435 for i in 0..20 {
1436 sample.push(ColumnValue::Int(10 + i * 50));
1437 }
1438 let h = Histogram::equi_depth_from_sample(sample, 10);
1439
1440 let mcv = MostCommonValues::new(vec![
1441 (ColumnValue::Text("boss".to_string()), 0.5),
1442 (ColumnValue::Text("intern".to_string()), 0.05),
1443 ]);
1444
1445 Arc::new(
1446 StaticProvider::new()
1447 .with_table(
1448 "people",
1449 TableStats {
1450 row_count: 100_000,
1451 avg_row_size: 64,
1452 page_count: 100,
1453 columns: vec![],
1454 },
1455 )
1456 .with_histogram("people", "score", h)
1457 .with_mcv("people", "role", mcv),
1458 )
1459 }
1460
/// Equality against a value in the MCV list should return that value's
/// recorded frequency (0.5 for `"boss"`).
#[test]
fn eq_uses_mcv_when_value_is_tracked() {
    let estimator = CostEstimator::with_stats(provider_with_skew());
    let role_is_boss = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Eq,
        value: Value::text("boss".to_string()),
    };

    let s = estimator.filter_selectivity(&role_is_boss, "people");
    let deviation = (s - 0.5).abs();
    assert!(
        deviation < 1e-9,
        "MCV-tracked equality should report exact frequency, got {s}"
    );
}
1478
/// Equality against a value outside the MCV list (`"staff"`) is expected to
/// yield a small, strictly positive selectivity below 0.01.
#[test]
fn eq_uses_residual_for_non_mcv_value() {
    let estimator = CostEstimator::with_stats(provider_with_skew());
    let role_is_staff = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Eq,
        value: Value::text("staff".to_string()),
    };

    let s = estimator.filter_selectivity(&role_is_staff, "people");
    let tiny_but_positive = s > 0.0 && s < 0.01;
    assert!(tiny_but_positive, "residual eq should be tiny, got {s}");
}
1494
/// Inequality against an MCV-tracked value: `"boss"` covers half the rows, so
/// `role <> 'boss'` is expected to select the other half (0.5).
#[test]
fn ne_is_one_minus_eq_under_mcv() {
    let estimator = CostEstimator::with_stats(provider_with_skew());
    let role_not_boss = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Ne,
        value: Value::text("boss".to_string()),
    };

    let s = estimator.filter_selectivity(&role_not_boss, "people");
    let deviation = (s - 0.5).abs();
    assert!(deviation < 1e-9, "Ne selectivity = 0.5, got {s}");
}
1508
/// With a histogram on `score`, `score <= 9` should reflect the skewed sample,
/// where 80 of the 100 sampled values are in `0..=9`. The estimate must exceed
/// 0.5, well above the flat 0.3 range heuristic used when no histogram exists.
#[test]
fn range_uses_histogram_when_present() {
    let provider = provider_with_skew();
    let estimator = CostEstimator::with_stats(provider);
    let filter = AstFilter::Compare {
        field: FieldRef::column("people", "score"),
        op: CompareOp::Le,
        value: Value::Integer(9),
    };
    let s = estimator.filter_selectivity(&filter, "people");
    // The message must state the threshold actually asserted (0.5); it
    // previously said "beat 0.3", which misreports a failure.
    assert!(
        s > 0.5,
        "histogram-based range selectivity should exceed 0.5, got {s}"
    );
}
1526
/// `score BETWEEN 0 AND 9` covers the dense cluster of the skewed sample, so the
/// histogram-backed estimate should exceed 0.5.
#[test]
fn between_uses_histogram() {
    let estimator = CostEstimator::with_stats(provider_with_skew());
    let dense_band = AstFilter::Between {
        field: FieldRef::column("people", "score"),
        low: Value::Integer(0),
        high: Value::Integer(9),
    };

    let s = estimator.filter_selectivity(&dense_band, "people");
    let above_half = s > 0.5;
    assert!(above_half, "BETWEEN should use histogram too, got {s}");
}
1539
/// A range predicate on a column with no histogram or MCV data should fall back
/// to the fixed 0.3 heuristic instead of failing.
#[test]
fn graceful_fallback_when_histogram_absent() {
    // Table-level stats only: nothing is registered for `unknown_col`.
    let people_stats = TableStats {
        row_count: 1000,
        avg_row_size: 64,
        page_count: 10,
        columns: vec![],
    };
    let provider = Arc::new(StaticProvider::new().with_table("people", people_stats));
    let estimator = CostEstimator::with_stats(provider);

    let below_fifty = AstFilter::Compare {
        field: FieldRef::column("people", "unknown_col"),
        op: CompareOp::Lt,
        value: Value::Integer(50),
    };
    let s = estimator.filter_selectivity(&below_fifty, "people");
    let deviation = (s - 0.3).abs();
    assert!(deviation < 1e-9, "fallback heuristic 0.3, got {s}");
}
1562}