1use std::sync::Arc;
13
14use super::stats_provider::{NullProvider, StatsProvider};
15use crate::storage::query::ast::{
16 CompareOp, FieldRef, Filter as AstFilter, GraphQuery, HybridQuery, JoinQuery, JoinType,
17 PathQuery, QueryExpr, TableQuery, VectorQuery,
18};
19use crate::storage::schema::Value;
20
/// Estimated output size of a (sub)query.
///
/// `rows` is the absolute row estimate and `selectivity` the fraction of the
/// base input that survives. `confidence` is a heuristic trust score: it is
/// 1.0 for freshly built estimates and decays by 10% every time a filter is
/// applied through [`CardinalityEstimate::with_filter`].
#[derive(Debug, Clone, Default)]
pub struct CardinalityEstimate {
    /// Estimated number of output rows.
    pub rows: f64,
    /// Fraction of the base input retained (1.0 = everything).
    pub selectivity: f64,
    /// Heuristic confidence in this estimate.
    pub confidence: f64,
}

impl CardinalityEstimate {
    /// Builds an estimate with full (1.0) confidence.
    pub fn new(rows: f64, selectivity: f64) -> Self {
        let confidence = 1.0;
        Self {
            rows,
            selectivity,
            confidence,
        }
    }

    /// Estimate for reading every one of `table_size` rows (selectivity 1.0).
    pub fn full_scan(table_size: f64) -> Self {
        Self::new(table_size, 1.0)
    }

    /// Returns this estimate narrowed by a filter keeping `filter_selectivity`
    /// of its rows; both `rows` and `selectivity` are scaled and confidence is
    /// reduced to 90% of its previous value.
    pub fn with_filter(self, filter_selectivity: f64) -> Self {
        Self {
            rows: self.rows * filter_selectivity,
            selectivity: self.selectivity * filter_selectivity,
            confidence: self.confidence * 0.9,
        }
    }
}
59
/// Cost of executing a plan or plan fragment, in abstract cost units.
///
/// The constructors fold the components into `total` as
/// `cpu + 10 * io + 0.1 * memory`. `startup_cost` is the part of the cost that
/// is paid before the first output row is available; it decides plan choice
/// for small `LIMIT`s (see [`PlanCost::prefer_over`]).
#[derive(Debug, Clone, Default)]
pub struct PlanCost {
    /// CPU work.
    pub cpu: f64,
    /// I/O work (weighted 10x in `total` by the constructors).
    pub io: f64,
    /// Network work; the constructors set it to 0 and the combinators add it.
    pub network: f64,
    /// Memory footprint (weighted 0.1x in `total`); combinators keep the peak.
    pub memory: f64,
    /// Cost incurred before the first output row is produced.
    pub startup_cost: f64,
    /// Cost of producing every output row.
    pub total: f64,
}

impl PlanCost {
    /// Folds the cost components into the single scalar stored in `total`.
    fn weighted_total(cpu: f64, io: f64, memory: f64) -> f64 {
        cpu + io * 10.0 + memory * 0.1
    }

    /// Fully pipelined cost: zero startup cost, `total` from the weighted sum.
    pub fn new(cpu: f64, io: f64, memory: f64) -> Self {
        Self {
            cpu,
            io,
            network: 0.0,
            memory,
            startup_cost: 0.0,
            total: Self::weighted_total(cpu, io, memory),
        }
    }

    /// Cost with an explicit startup component.
    ///
    /// A negative `startup_cost` is clamped to 0, and `total` is raised to at
    /// least that clamped startup cost, so `total >= startup_cost >= 0`.
    pub fn with_startup(cpu: f64, io: f64, memory: f64, startup_cost: f64) -> Self {
        // Use the clamped value for the `total` clamp as well; comparing
        // against the raw input would let a negative startup leave
        // `total < startup_cost`.
        let startup_cost = startup_cost.max(0.0);
        Self {
            cpu,
            io,
            network: 0.0,
            memory,
            startup_cost,
            total: Self::weighted_total(cpu, io, memory).max(startup_cost),
        }
    }

    /// Cost of running `self` and `other` as one pipeline: work components,
    /// startup costs and totals add; memory is the peak of the two.
    pub fn combine_pipelined(&self, other: &PlanCost) -> PlanCost {
        PlanCost {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            network: self.network + other.network,
            memory: self.memory.max(other.memory),
            startup_cost: self.startup_cost + other.startup_cost,
            total: self.total + other.total,
        }
    }

    /// Cost of feeding `self` into a blocking operator `blocker` that must
    /// consume all of `self` before emitting anything (e.g. the hash-build
    /// phase of a join). Startup therefore becomes
    /// `self.total + blocker.startup_cost`.
    pub fn combine_blocking(&self, blocker: &PlanCost) -> PlanCost {
        PlanCost {
            cpu: self.cpu + blocker.cpu,
            io: self.io + blocker.io,
            network: self.network + blocker.network,
            memory: self.memory.max(blocker.memory),
            startup_cost: self.total + blocker.startup_cost,
            total: self.total + blocker.total,
        }
    }

    /// Alias for [`PlanCost::combine_pipelined`].
    pub fn combine(&self, other: &PlanCost) -> PlanCost {
        self.combine_pipelined(other)
    }

    /// Multiplies `cpu`, `io`, `network` and `total` by `factor`; `memory`
    /// and `startup_cost` are left unchanged.
    pub fn scale(&self, factor: f64) -> PlanCost {
        PlanCost {
            cpu: self.cpu * factor,
            io: self.io * factor,
            network: self.network * factor,
            memory: self.memory,
            startup_cost: self.startup_cost,
            total: self.total * factor,
        }
    }

    /// Orders `self` against `other`; `Less` means `self` is preferred.
    ///
    /// When a `LIMIT k` is below 10% of the expected `cardinality` (floored at
    /// 1), plans are compared by `startup_cost`; otherwise by `total`.
    /// Incomparable values (NaN) compare as `Equal`.
    pub fn prefer_over(
        &self,
        other: &PlanCost,
        limit: Option<u64>,
        cardinality: f64,
    ) -> std::cmp::Ordering {
        let use_startup = matches!(limit, Some(k) if (k as f64) < 0.1 * cardinality.max(1.0));
        let (lhs, rhs) = if use_startup {
            (self.startup_cost, other.startup_cost)
        } else {
            (self.total, other.total)
        };
        lhs.partial_cmp(&rhs).unwrap_or(std::cmp::Ordering::Equal)
    }
}
197
/// Per-table statistics. `CostEstimator` reads `row_count` (base cardinality)
/// and `page_count` (heap size for index I/O costing) from a stats provider.
#[derive(Debug, Clone, Default)]
pub struct TableStats {
    /// Number of rows in the table.
    pub row_count: u64,
    /// Average row size — presumably in bytes (TODO confirm).
    pub avg_row_size: u32,
    /// Number of heap pages occupied by the table.
    pub page_count: u64,
    /// Per-column statistics.
    pub columns: Vec<ColumnStats>,
}
210
/// Statistics for a single column.
#[derive(Debug, Clone, Default)]
pub struct ColumnStats {
    /// Column name.
    pub name: String,
    /// Number of distinct values in the column.
    pub distinct_count: u64,
    /// Number of NULLs in the column.
    pub null_count: u64,
    /// Smallest value, if known. NOTE(review): stored as a string —
    /// presumably a textual rendering of the value; confirm the encoding.
    pub min_value: Option<String>,
    /// Largest value, if known (same string encoding as `min_value`).
    pub max_value: Option<String>,
    /// Whether an index exists on this column.
    pub has_index: bool,
}
227
/// Heuristic, optionally statistics-aware cost and cardinality estimator for
/// `QueryExpr` plans.
///
/// The scalar fields are per-unit cost constants (initialised by `new`), and
/// `stats` supplies table/column/index statistics when available.
pub struct CostEstimator {
    /// Row count assumed for a table when the stats provider has no entry.
    default_row_count: f64,
    /// CPU cost charged per output row of a table scan.
    row_scan_cost: f64,
    /// Per-lookup index cost. NOTE(review): not read by any method visible in
    /// this source chunk — confirm whether it is still needed.
    index_lookup_cost: f64,
    /// CPU cost per row for both the build and probe phases of hash joins.
    hash_probe_cost: f64,
    /// Nested-loop join cost factor. NOTE(review): not read by any method
    /// visible in this source chunk — confirm whether it is still needed.
    nested_loop_cost: f64,
    /// CPU cost per traversal step in graph-pattern and path queries.
    edge_traversal_cost: f64,
    /// Source of table/column/index statistics (a `NullProvider` by default).
    stats: Arc<dyn StatsProvider>,
}
248
249impl CostEstimator {
250 pub fn new() -> Self {
253 Self {
254 default_row_count: 1000.0,
255 row_scan_cost: 1.0,
256 index_lookup_cost: 0.1,
257 hash_probe_cost: 0.5,
258 nested_loop_cost: 2.0,
259 edge_traversal_cost: 1.5,
260 stats: Arc::new(NullProvider),
261 }
262 }
263
264 pub fn with_stats(provider: Arc<dyn StatsProvider>) -> Self {
268 Self {
269 stats: provider,
270 ..Self::new()
271 }
272 }
273
274 pub fn set_stats(&mut self, provider: Arc<dyn StatsProvider>) {
278 self.stats = provider;
279 }
280
    /// Estimates the execution cost of `query`.
    ///
    /// Table, graph, join, path, vector and hybrid queries are dispatched to
    /// their dedicated cost models. Every other statement kind listed below
    /// (DML, DDL, configuration, security and administrative commands) gets
    /// the flat nominal cost `PlanCost::new(1.0, 1.0, 0.0)`.
    ///
    /// The match has no `_` arm, so adding a new `QueryExpr` variant will not
    /// compile until it is classified here.
    pub fn estimate(&self, query: &QueryExpr) -> PlanCost {
        match query {
            QueryExpr::Table(tq) => self.estimate_table(tq),
            QueryExpr::Graph(gq) => self.estimate_graph(gq),
            QueryExpr::Join(jq) => self.estimate_join(jq),
            QueryExpr::Path(pq) => self.estimate_path(pq),
            QueryExpr::Vector(vq) => self.estimate_vector(vq),
            QueryExpr::Hybrid(hq) => self.estimate_hybrid(hq),
            // Non-query statements: nominal constant cost.
            QueryExpr::Insert(_)
            | QueryExpr::Update(_)
            | QueryExpr::Delete(_)
            | QueryExpr::CreateTable(_)
            | QueryExpr::DropTable(_)
            | QueryExpr::DropGraph(_)
            | QueryExpr::DropVector(_)
            | QueryExpr::DropDocument(_)
            | QueryExpr::DropKv(_)
            | QueryExpr::DropCollection(_)
            | QueryExpr::Truncate(_)
            | QueryExpr::AlterTable(_)
            | QueryExpr::GraphCommand(_)
            | QueryExpr::SearchCommand(_)
            | QueryExpr::CreateIndex(_)
            | QueryExpr::DropIndex(_)
            | QueryExpr::ProbabilisticCommand(_)
            | QueryExpr::Ask(_)
            | QueryExpr::SetConfig { .. }
            | QueryExpr::ShowConfig { .. }
            | QueryExpr::SetSecret { .. }
            | QueryExpr::DeleteSecret { .. }
            | QueryExpr::ShowSecrets { .. }
            | QueryExpr::SetTenant(_)
            | QueryExpr::ShowTenant
            | QueryExpr::CreateTimeSeries(_)
            | QueryExpr::DropTimeSeries(_)
            | QueryExpr::CreateQueue(_)
            | QueryExpr::AlterQueue(_)
            | QueryExpr::DropQueue(_)
            | QueryExpr::QueueSelect(_)
            | QueryExpr::QueueCommand(_)
            | QueryExpr::KvCommand(_)
            | QueryExpr::ConfigCommand(_)
            | QueryExpr::CreateTree(_)
            | QueryExpr::DropTree(_)
            | QueryExpr::TreeCommand(_)
            | QueryExpr::ExplainAlter(_)
            | QueryExpr::TransactionControl(_)
            | QueryExpr::MaintenanceCommand(_)
            | QueryExpr::CreateSchema(_)
            | QueryExpr::DropSchema(_)
            | QueryExpr::CreateSequence(_)
            | QueryExpr::DropSequence(_)
            | QueryExpr::CopyFrom(_)
            | QueryExpr::CreateView(_)
            | QueryExpr::DropView(_)
            | QueryExpr::RefreshMaterializedView(_)
            | QueryExpr::CreatePolicy(_)
            | QueryExpr::DropPolicy(_)
            | QueryExpr::CreateServer(_)
            | QueryExpr::DropServer(_)
            | QueryExpr::CreateForeignTable(_)
            | QueryExpr::DropForeignTable(_)
            | QueryExpr::Grant(_)
            | QueryExpr::Revoke(_)
            | QueryExpr::AlterUser(_)
            | QueryExpr::CreateIamPolicy { .. }
            | QueryExpr::DropIamPolicy { .. }
            | QueryExpr::AttachPolicy { .. }
            | QueryExpr::DetachPolicy { .. }
            | QueryExpr::ShowPolicies { .. }
            | QueryExpr::ShowEffectivePermissions { .. }
            | QueryExpr::SimulatePolicy { .. }
            | QueryExpr::CreateMigration(_)
            | QueryExpr::ApplyMigration(_)
            | QueryExpr::RollbackMigration(_)
            | QueryExpr::ExplainMigration(_)
            | QueryExpr::EventsBackfill(_)
            | QueryExpr::EventsBackfillStatus { .. } => PlanCost::new(1.0, 1.0, 0.0),
        }
    }
363
    /// Estimates how many rows `query` produces.
    ///
    /// Table, graph, join, path, vector and hybrid queries use their dedicated
    /// cardinality models. Every other statement kind listed below is treated
    /// as producing a single row (`CardinalityEstimate::new(1.0, 1.0)`).
    ///
    /// The match has no `_` arm, so adding a new `QueryExpr` variant will not
    /// compile until it is classified here.
    pub fn estimate_cardinality(&self, query: &QueryExpr) -> CardinalityEstimate {
        match query {
            QueryExpr::Table(tq) => self.estimate_table_cardinality(tq),
            QueryExpr::Graph(gq) => self.estimate_graph_cardinality(gq),
            QueryExpr::Join(jq) => self.estimate_join_cardinality(jq),
            QueryExpr::Path(pq) => self.estimate_path_cardinality(pq),
            QueryExpr::Vector(vq) => self.estimate_vector_cardinality(vq),
            QueryExpr::Hybrid(hq) => self.estimate_hybrid_cardinality(hq),
            // Non-query statements: a single nominal row.
            QueryExpr::Insert(_)
            | QueryExpr::Update(_)
            | QueryExpr::Delete(_)
            | QueryExpr::CreateTable(_)
            | QueryExpr::DropTable(_)
            | QueryExpr::DropGraph(_)
            | QueryExpr::DropVector(_)
            | QueryExpr::DropDocument(_)
            | QueryExpr::DropKv(_)
            | QueryExpr::DropCollection(_)
            | QueryExpr::Truncate(_)
            | QueryExpr::AlterTable(_)
            | QueryExpr::GraphCommand(_)
            | QueryExpr::SearchCommand(_)
            | QueryExpr::CreateIndex(_)
            | QueryExpr::DropIndex(_)
            | QueryExpr::ProbabilisticCommand(_)
            | QueryExpr::Ask(_)
            | QueryExpr::SetConfig { .. }
            | QueryExpr::ShowConfig { .. }
            | QueryExpr::SetSecret { .. }
            | QueryExpr::DeleteSecret { .. }
            | QueryExpr::ShowSecrets { .. }
            | QueryExpr::SetTenant(_)
            | QueryExpr::ShowTenant
            | QueryExpr::CreateTimeSeries(_)
            | QueryExpr::DropTimeSeries(_)
            | QueryExpr::CreateQueue(_)
            | QueryExpr::AlterQueue(_)
            | QueryExpr::DropQueue(_)
            | QueryExpr::QueueSelect(_)
            | QueryExpr::QueueCommand(_)
            | QueryExpr::KvCommand(_)
            | QueryExpr::ConfigCommand(_)
            | QueryExpr::CreateTree(_)
            | QueryExpr::DropTree(_)
            | QueryExpr::TreeCommand(_)
            | QueryExpr::ExplainAlter(_)
            | QueryExpr::TransactionControl(_)
            | QueryExpr::MaintenanceCommand(_)
            | QueryExpr::CreateSchema(_)
            | QueryExpr::DropSchema(_)
            | QueryExpr::CreateSequence(_)
            | QueryExpr::DropSequence(_)
            | QueryExpr::CopyFrom(_)
            | QueryExpr::CreateView(_)
            | QueryExpr::DropView(_)
            | QueryExpr::RefreshMaterializedView(_)
            | QueryExpr::CreatePolicy(_)
            | QueryExpr::DropPolicy(_)
            | QueryExpr::CreateServer(_)
            | QueryExpr::DropServer(_)
            | QueryExpr::CreateForeignTable(_)
            | QueryExpr::DropForeignTable(_)
            | QueryExpr::Grant(_)
            | QueryExpr::Revoke(_)
            | QueryExpr::AlterUser(_)
            | QueryExpr::CreateIamPolicy { .. }
            | QueryExpr::DropIamPolicy { .. }
            | QueryExpr::AttachPolicy { .. }
            | QueryExpr::DetachPolicy { .. }
            | QueryExpr::ShowPolicies { .. }
            | QueryExpr::ShowEffectivePermissions { .. }
            | QueryExpr::SimulatePolicy { .. }
            | QueryExpr::CreateMigration(_)
            | QueryExpr::ApplyMigration(_)
            | QueryExpr::RollbackMigration(_)
            | QueryExpr::ExplainMigration(_)
            | QueryExpr::EventsBackfill(_)
            | QueryExpr::EventsBackfillStatus { .. } => CardinalityEstimate::new(1.0, 1.0),
        }
    }
446
447 fn estimate_table(&self, query: &TableQuery) -> PlanCost {
452 let cardinality = self.estimate_table_cardinality(query);
453
454 let cpu = cardinality.rows * self.row_scan_cost;
455
456 let io = self.estimate_table_io(query, cardinality.rows);
459
460 let memory = cardinality.rows * 100.0; PlanCost::new(cpu, io, memory)
463 }
464
465 fn estimate_table_io(&self, query: &TableQuery, result_rows: f64) -> f64 {
472 const ROWS_PER_PAGE: f64 = 100.0;
473
474 let table_stats = self.stats.table_stats(&query.table);
476 let heap_pages = table_stats
477 .map(|s| s.page_count as f64)
478 .unwrap_or_else(|| (result_rows / ROWS_PER_PAGE).max(1.0));
479
480 if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
483 if let Some(col) = first_filter_column(&filter, &query.table) {
484 if let Some(idx) = self.stats.index_stats(&query.table, col) {
485 return idx.correlated_io_cost(result_rows, heap_pages);
486 }
487 }
488 }
489
490 (result_rows / ROWS_PER_PAGE).ceil()
492 }
493
494 fn estimate_table_cardinality(&self, query: &TableQuery) -> CardinalityEstimate {
495 let base_rows = self
498 .stats
499 .table_stats(&query.table)
500 .map(|s| s.row_count as f64)
501 .unwrap_or(self.default_row_count);
502
503 let mut estimate = CardinalityEstimate::full_scan(base_rows);
504
505 if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
508 let selectivity = self.filter_selectivity(&filter, &query.table);
509 estimate = estimate.with_filter(selectivity);
510 }
511
512 if let Some(limit) = query.limit {
514 estimate.rows = estimate.rows.min(limit as f64);
515 }
516
517 estimate
518 }
519
520 fn filter_selectivity(&self, filter: &AstFilter, table: &str) -> f64 {
533 match filter {
534 AstFilter::Compare { field, op, value } => {
535 let column = column_name_for_table(field, table);
536 match op {
537 CompareOp::Eq => self.eq_selectivity(table, column, value),
538 CompareOp::Ne => 1.0 - self.eq_selectivity(table, column, value),
539 CompareOp::Lt | CompareOp::Le => {
540 self.range_selectivity(table, column, None, Some(value))
541 }
542 CompareOp::Gt | CompareOp::Ge => {
543 self.range_selectivity(table, column, Some(value), None)
544 }
545 }
546 }
547 AstFilter::Between {
548 field, low, high, ..
549 } => {
550 let column = column_name_for_table(field, table);
551 self.range_selectivity(table, column, Some(low), Some(high))
552 }
553 AstFilter::In { field, values, .. } => {
554 let column = column_name_for_table(field, table);
555 if let Some(c) = column {
559 if let Some(mcv) = self.stats.column_mcv(table, c) {
560 let mut hits: f64 = 0.0;
561 let mut residual_count = 0usize;
562 for v in values {
563 if let Some(cv) = column_value_from(v) {
564 if let Some(freq) = mcv.frequency_of(&cv) {
565 hits += freq;
566 } else {
567 residual_count += 1;
568 }
569 } else {
570 residual_count += 1;
571 }
572 }
573 let total = mcv.total_frequency();
574 let distinct = self.stats.distinct_values(table, c).unwrap_or(100);
575 let non_mcv_distinct =
576 distinct.saturating_sub(mcv.values.len() as u64).max(1);
577 let per_residual = (1.0 - total) / non_mcv_distinct as f64;
578 let estimate = hits + (residual_count as f64) * per_residual;
579 return estimate.clamp(0.0, 1.0).min(0.5);
580 }
581 if let Some(s) = self.stats.index_stats(table, c) {
582 return (s.point_selectivity() * values.len() as f64).min(0.5);
583 }
584 }
585 (values.len() as f64 * 0.01).min(0.5)
586 }
587 AstFilter::Like { .. } => 0.1,
588 AstFilter::StartsWith { .. } => 0.15,
589 AstFilter::EndsWith { .. } => 0.15,
590 AstFilter::Contains { .. } => 0.1,
591 AstFilter::IsNull { .. } => 0.01,
592 AstFilter::IsNotNull { .. } => 0.99,
593 AstFilter::And(left, right) => {
594 self.filter_selectivity(left, table) * self.filter_selectivity(right, table)
595 }
596 AstFilter::Or(left, right) => {
597 let s1 = self.filter_selectivity(left, table);
598 let s2 = self.filter_selectivity(right, table);
599 s1 + s2 - (s1 * s2)
600 }
601 AstFilter::Not(inner) => 1.0 - self.filter_selectivity(inner, table),
602 AstFilter::CompareFields { .. } => {
603 0.1
607 }
608 AstFilter::CompareExpr { .. } => {
609 0.1
613 }
614 }
615 }
616
617 fn estimate_graph(&self, query: &GraphQuery) -> PlanCost {
622 let cardinality = self.estimate_graph_cardinality(query);
623
624 let nodes = query.pattern.nodes.len() as f64;
626 let edges = query.pattern.edges.len() as f64;
627
628 let cpu = cardinality.rows * self.edge_traversal_cost * (nodes + edges);
629 let io = cardinality.rows * 0.1; let memory = cardinality.rows * 200.0; PlanCost::new(cpu, io, memory)
633 }
634
635 fn estimate_graph_cardinality(&self, query: &GraphQuery) -> CardinalityEstimate {
636 let nodes = query.pattern.nodes.len() as f64;
637 let edges = query.pattern.edges.len() as f64;
638
639 let base_rows = self.default_row_count;
641 let edge_factor = 0.1_f64.powf(edges); let mut estimate = CardinalityEstimate::new(base_rows * nodes * edge_factor, edge_factor);
644 estimate.confidence = 0.5; if let Some(ref filter) = query.filter {
648 let selectivity = Self::estimate_filter_selectivity(filter);
649 estimate = estimate.with_filter(selectivity);
650 }
651
652 estimate
653 }
654
655 fn estimate_join(&self, query: &JoinQuery) -> PlanCost {
660 let left_cost = self.estimate(&query.left);
661 let right_cost = self.estimate(&query.right);
662
663 let left_card = self.estimate_cardinality(&query.left);
664 let right_card = self.estimate_cardinality(&query.right);
665
666 let build_cpu = left_card.rows * self.hash_probe_cost;
673 let probe_cpu = right_card.rows * self.hash_probe_cost;
674 let join_memory = left_card.rows * 100.0; let build_op = PlanCost::with_startup(build_cpu, 0.0, join_memory, build_cpu);
678 let probe_op = PlanCost::new(probe_cpu, 0.0, 0.0);
680
681 let after_build = left_cost.combine_blocking(&build_op);
683 after_build
684 .combine_pipelined(&right_cost)
685 .combine_pipelined(&probe_op)
686 }
687
688 fn estimate_join_cardinality(&self, query: &JoinQuery) -> CardinalityEstimate {
689 let left = self.estimate_cardinality(&query.left);
690 let right = self.estimate_cardinality(&query.right);
691
692 let selectivity = match query.join_type {
694 JoinType::Inner => 0.1, JoinType::LeftOuter => 1.0, JoinType::RightOuter => 1.0, JoinType::FullOuter => 1.0, JoinType::Cross => 1.0, };
700
701 CardinalityEstimate::new(
702 left.rows * right.rows * selectivity,
703 left.selectivity * right.selectivity * selectivity,
704 )
705 }
706
707 fn estimate_path(&self, query: &PathQuery) -> PlanCost {
712 let cardinality = self.estimate_path_cardinality(query);
713
714 let max_hops = query.max_length;
716 let branching_factor: f64 = 5.0; let nodes_visited = branching_factor.powf(max_hops as f64).min(10000.0);
719 let cpu = nodes_visited * self.edge_traversal_cost;
720 let io = nodes_visited * 0.1;
721 let memory = nodes_visited * 50.0; PlanCost::new(cpu, io, memory)
724 }
725
726 fn estimate_path_cardinality(&self, query: &PathQuery) -> CardinalityEstimate {
727 let max_paths = 10.0;
729 CardinalityEstimate::new(max_paths, 0.001)
730 }
731
732 fn estimate_vector(&self, query: &VectorQuery) -> PlanCost {
737 let k = query.k as f64;
740
741 let hnsw_cost = 100.0 * (1.0 + k.ln()); let filter_cost =
748 if crate::storage::query::sql_lowering::effective_vector_filter(query).is_some() {
749 50.0
750 } else {
751 0.0
752 };
753
754 let cpu = hnsw_cost + filter_cost;
755 let io = 20.0; let memory = k * 32.0 + 1000.0; PlanCost::with_startup(cpu, io, memory, hnsw_cost * 0.5)
763 }
764
765 fn estimate_vector_cardinality(&self, query: &VectorQuery) -> CardinalityEstimate {
766 let k = query.k as f64;
768 CardinalityEstimate::new(k, 0.1)
769 }
770
771 fn estimate_hybrid(&self, query: &HybridQuery) -> PlanCost {
776 let structured_cost = self.estimate(&query.structured);
778 let vector_cost = self.estimate_vector(&query.vector);
779
780 let fusion_overhead = match &query.fusion {
782 crate::storage::query::ast::FusionStrategy::Rerank { .. } => 50.0,
783 crate::storage::query::ast::FusionStrategy::FilterThenSearch => 10.0,
784 crate::storage::query::ast::FusionStrategy::SearchThenFilter => 10.0,
785 crate::storage::query::ast::FusionStrategy::RRF { .. } => 30.0,
786 crate::storage::query::ast::FusionStrategy::Intersection => 20.0,
787 crate::storage::query::ast::FusionStrategy::Union { .. } => 40.0,
788 };
789
790 PlanCost::new(
791 structured_cost.cpu + vector_cost.cpu + fusion_overhead,
792 structured_cost.io + vector_cost.io,
793 structured_cost.memory + vector_cost.memory,
794 )
795 }
796
797 fn estimate_hybrid_cardinality(&self, query: &HybridQuery) -> CardinalityEstimate {
798 let structured_card = self.estimate_cardinality(&query.structured);
799 let vector_card = self.estimate_vector_cardinality(&query.vector);
800
801 let rows = match &query.fusion {
803 crate::storage::query::ast::FusionStrategy::Intersection => {
804 structured_card.rows.min(vector_card.rows)
805 }
806 crate::storage::query::ast::FusionStrategy::Union { .. } => {
807 structured_card.rows + vector_card.rows
808 }
809 _ => vector_card.rows, };
811
812 CardinalityEstimate::new(rows, 0.2)
813 }
814
815 fn estimate_filter_selectivity(filter: &AstFilter) -> f64 {
820 match filter {
821 AstFilter::Compare { op, .. } => {
822 match op {
823 CompareOp::Eq => 0.01, CompareOp::Ne => 0.99, CompareOp::Lt | CompareOp::Le => 0.3,
826 CompareOp::Gt | CompareOp::Ge => 0.3,
827 }
828 }
829 AstFilter::Between { .. } => 0.25,
830 AstFilter::In { values, .. } => {
831 (values.len() as f64 * 0.01).min(0.5)
833 }
834 AstFilter::Like { .. } => 0.1,
835 AstFilter::StartsWith { .. } => 0.15,
836 AstFilter::EndsWith { .. } => 0.15,
837 AstFilter::Contains { .. } => 0.1,
838 AstFilter::IsNull { .. } => 0.01,
839 AstFilter::IsNotNull { .. } => 0.99,
840 AstFilter::And(left, right) => {
841 Self::estimate_filter_selectivity(left) * Self::estimate_filter_selectivity(right)
842 }
843 AstFilter::Or(left, right) => {
844 let s1 = Self::estimate_filter_selectivity(left);
845 let s2 = Self::estimate_filter_selectivity(right);
846 s1 + s2 - (s1 * s2) }
848 AstFilter::Not(inner) => 1.0 - Self::estimate_filter_selectivity(inner),
849 AstFilter::CompareFields { .. } => 0.1,
850 AstFilter::CompareExpr { .. } => 0.1,
851 }
852 }
853}
854
855impl CostEstimator {
856 fn eq_selectivity(&self, table: &str, column: Option<&str>, value: &Value) -> f64 {
864 if let Some(col) = column {
865 if let Some(mcv) = self.stats.column_mcv(table, col) {
867 if let Some(cv) = column_value_from(value) {
868 if let Some(freq) = mcv.frequency_of(&cv) {
869 return freq;
870 }
871 let total = mcv.total_frequency();
873 let distinct = self.stats.distinct_values(table, col).unwrap_or(100);
874 let non_mcv_distinct = distinct.saturating_sub(mcv.values.len() as u64).max(1);
875 return ((1.0 - total) / non_mcv_distinct as f64).clamp(0.0, 1.0);
876 }
877 }
878 if let Some(s) = self.stats.index_stats(table, col) {
880 return s.point_selectivity();
881 }
882 }
883 0.01
885 }
886
887 fn range_selectivity(
898 &self,
899 table: &str,
900 column: Option<&str>,
901 lo: Option<&Value>,
902 hi: Option<&Value>,
903 ) -> f64 {
904 if let Some(col) = column {
905 if let Some(h) = self.stats.column_histogram(table, col) {
907 let lo_cv = lo.and_then(column_value_from);
908 let hi_cv = hi.and_then(column_value_from);
909 return h.range_selectivity(lo_cv.as_ref(), hi_cv.as_ref());
910 }
911 if let Some(s) = self.stats.index_stats(table, col) {
913 let cap = if lo.is_some() && hi.is_some() {
914 0.25
915 } else {
916 0.3
917 };
918 return (s.point_selectivity() * (s.distinct_keys as f64 / 2.0)).min(cap);
919 }
920 }
921 if lo.is_some() && hi.is_some() {
923 0.25
924 } else {
925 0.3
926 }
927 }
928}
929
930impl Default for CostEstimator {
931 fn default() -> Self {
932 Self::new()
933 }
934}
935
/// Converts a schema `Value` into the histogram `ColumnValue` representation
/// used for MCV-frequency and histogram-range lookups.
///
/// - Integer-like variants map to `ColumnValue::Int`: integers, timestamps,
///   durations, decimals, dates, times, phone numbers, semvers, ports, page
///   refs, enum values and latitude/longitude.
/// - String-like variants map to `ColumnValue::Text`.
/// - Finite floats map to `ColumnValue::Float`.
///
/// Returns `None` for non-finite floats and for every variant not listed.
///
/// NOTE(review): `Decimal`, `Latitude` and `Longitude` are passed through as
/// their raw integer payloads. They are presumably scaled fixed-point values;
/// confirm they compare correctly against histogram bounds.
/// NOTE(review): `*u as i64` and `*p as i64` are `as` casts, which wrap for
/// values above `i64::MAX` if the source types are 64-bit unsigned — confirm.
fn column_value_from(v: &crate::storage::schema::Value) -> Option<super::histogram::ColumnValue> {
    use super::histogram::ColumnValue;
    use crate::storage::schema::Value;
    match v {
        Value::Integer(i) | Value::BigInt(i) => Some(ColumnValue::Int(*i)),
        Value::UnsignedInteger(u) => Some(ColumnValue::Int(*u as i64)),
        // NaN / infinities have no meaningful position in a histogram.
        Value::Float(f) if f.is_finite() => Some(ColumnValue::Float(*f)),
        Value::Text(s) => Some(ColumnValue::Text(s.to_string())),
        Value::Email(s)
        | Value::Url(s)
        | Value::NodeRef(s)
        | Value::EdgeRef(s)
        | Value::TableRef(s)
        | Value::Password(s) => Some(ColumnValue::Text(s.clone())),
        Value::Timestamp(t) => Some(ColumnValue::Int(*t)),
        Value::Duration(d) => Some(ColumnValue::Int(*d)),
        Value::TimestampMs(t) => Some(ColumnValue::Int(*t)),
        Value::Decimal(d) => Some(ColumnValue::Int(*d)),
        Value::Date(d) => Some(ColumnValue::Int(i64::from(*d))),
        Value::Time(t) => Some(ColumnValue::Int(i64::from(*t))),
        Value::Phone(p) => Some(ColumnValue::Int(*p as i64)),
        Value::Semver(v) => Some(ColumnValue::Int(i64::from(*v))),
        Value::Port(v) => Some(ColumnValue::Int(i64::from(*v))),
        Value::PageRef(v) => Some(ColumnValue::Int(i64::from(*v))),
        Value::EnumValue(v) => Some(ColumnValue::Int(i64::from(*v))),
        Value::Latitude(v) => Some(ColumnValue::Int(i64::from(*v))),
        Value::Longitude(v) => Some(ColumnValue::Int(i64::from(*v))),
        // Any other variant has no histogram representation.
        _ => None,
    }
}
974
975fn first_filter_column<'a>(filter: &'a AstFilter, table: &str) -> Option<&'a str> {
980 match filter {
981 AstFilter::Compare { field, .. } => column_name_for_table(field, table),
982 AstFilter::Between { field, .. } => column_name_for_table(field, table),
983 AstFilter::And(l, r) => {
984 first_filter_column(l, table).or_else(|| first_filter_column(r, table))
985 }
986 _ => None,
987 }
988}
989
990fn column_name_for_table<'a>(field: &'a FieldRef, table: &str) -> Option<&'a str> {
992 match field {
993 FieldRef::TableColumn { table: t, column } if t == table || t.is_empty() => {
994 Some(column.as_str())
995 }
996 _ => None,
998 }
999}
1000
1001#[cfg(test)]
1002mod tests {
1003 use super::super::stats_provider::StaticProvider;
1004 use super::*;
1005 use crate::storage::index::{IndexKind, IndexStats};
1006 use crate::storage::query::ast::{FieldRef, Projection};
1007 use crate::storage::schema::Value;
1008
1009 fn eq_filter(table: &str, column: &str, value: i64) -> AstFilter {
1010 AstFilter::Compare {
1011 field: FieldRef::column(table, column),
1012 op: CompareOp::Eq,
1013 value: Value::Integer(value),
1014 }
1015 }
1016
1017 fn table_query(name: &str, filter: Option<AstFilter>) -> TableQuery {
1018 TableQuery {
1019 table: name.to_string(),
1020 source: None,
1021 alias: None,
1022 select_items: Vec::new(),
1023 columns: vec![Projection::All],
1024 where_expr: None,
1025 filter,
1026 group_by_exprs: Vec::new(),
1027 group_by: Vec::new(),
1028 having_expr: None,
1029 having: None,
1030 order_by: vec![],
1031 limit: None,
1032 limit_param: None,
1033 offset: None,
1034 offset_param: None,
1035 expand: None,
1036 as_of: None,
1037 }
1038 }
1039
1040 #[test]
1041 fn injected_row_count_overrides_default() {
1042 let provider = Arc::new(StaticProvider::new().with_table(
1043 "users",
1044 TableStats {
1045 row_count: 50_000,
1046 avg_row_size: 256,
1047 page_count: 500,
1048 columns: vec![],
1049 },
1050 ));
1051 let estimator = CostEstimator::with_stats(provider);
1052 let q = table_query("users", None);
1053 let card = estimator.estimate_table_cardinality(&q);
1054 assert_eq!(card.rows, 50_000.0);
1056 }
1057
1058 #[test]
1059 fn stats_aware_eq_selectivity_beats_default() {
1060 let provider = Arc::new(
1061 StaticProvider::new()
1062 .with_table(
1063 "users",
1064 TableStats {
1065 row_count: 1_000_000,
1066 avg_row_size: 256,
1067 page_count: 10_000,
1068 columns: vec![],
1069 },
1070 )
1071 .with_index(
1072 "users",
1073 "email",
1074 IndexStats {
1075 entries: 1_000_000,
1076 distinct_keys: 1_000_000,
1077 approx_bytes: 0,
1078 kind: IndexKind::Hash,
1079 has_bloom: true,
1080 index_correlation: 0.0,
1081 },
1082 ),
1083 );
1084 let estimator = CostEstimator::with_stats(provider);
1085 let q = table_query("users", Some(eq_filter("users", "email", 0)));
1086 let card = estimator.estimate_table_cardinality(&q);
1087 assert!(card.rows < 2.0, "expected ~1 row, got {}", card.rows);
1089 }
1090
1091 #[test]
1092 fn fallback_when_no_index_stats() {
1093 let provider = Arc::new(StaticProvider::new().with_table(
1094 "users",
1095 TableStats {
1096 row_count: 1_000_000,
1097 avg_row_size: 256,
1098 page_count: 10_000,
1099 columns: vec![],
1100 },
1101 ));
1102 let estimator = CostEstimator::with_stats(provider);
1103 let q = table_query("users", Some(eq_filter("users", "email", 0)));
1104 let card = estimator.estimate_table_cardinality(&q);
1105 assert!((card.rows - 10_000.0).abs() < 1.0);
1107 }
1108
1109 #[test]
1110 fn null_provider_keeps_legacy_behaviour() {
1111 let estimator = CostEstimator::new();
1112 let q = table_query("whatever", Some(eq_filter("whatever", "id", 1)));
1113 let card = estimator.estimate_table_cardinality(&q);
1114 assert!((card.rows - 10.0).abs() < 1.0);
1116 }
1117
1118 #[test]
1119 fn and_combines_stats_selectivities() {
1120 let provider = Arc::new(
1121 StaticProvider::new()
1122 .with_table(
1123 "t",
1124 TableStats {
1125 row_count: 100_000,
1126 avg_row_size: 64,
1127 page_count: 100,
1128 columns: vec![],
1129 },
1130 )
1131 .with_index(
1132 "t",
1133 "a",
1134 IndexStats {
1135 entries: 100_000,
1136 distinct_keys: 10,
1137 approx_bytes: 0,
1138 kind: IndexKind::BTree,
1139 has_bloom: false,
1140 index_correlation: 0.0,
1141 },
1142 )
1143 .with_index(
1144 "t",
1145 "b",
1146 IndexStats {
1147 entries: 100_000,
1148 distinct_keys: 1000,
1149 approx_bytes: 0,
1150 kind: IndexKind::BTree,
1151 has_bloom: false,
1152 index_correlation: 0.0,
1153 },
1154 ),
1155 );
1156 let estimator = CostEstimator::with_stats(provider);
1157 let filter = AstFilter::And(
1158 Box::new(eq_filter("t", "a", 1)),
1159 Box::new(eq_filter("t", "b", 1)),
1160 );
1161 let q = table_query("t", Some(filter));
1162 let card = estimator.estimate_table_cardinality(&q);
1163 assert!(card.rows < 15.0, "got {}", card.rows);
1165 }
1166
1167 #[test]
1168 fn test_table_cost_estimation() {
1169 let estimator = CostEstimator::new();
1170
1171 let query = QueryExpr::Table(TableQuery {
1172 table: "hosts".to_string(),
1173 source: None,
1174 alias: None,
1175 select_items: Vec::new(),
1176 columns: vec![Projection::All],
1177 where_expr: None,
1178 filter: None,
1179 group_by_exprs: Vec::new(),
1180 group_by: Vec::new(),
1181 having_expr: None,
1182 having: None,
1183 order_by: vec![],
1184 limit: None,
1185 limit_param: None,
1186 offset: None,
1187 offset_param: None,
1188 expand: None,
1189 as_of: None,
1190 });
1191
1192 let cost = estimator.estimate(&query);
1193 assert!(cost.cpu > 0.0);
1194 assert!(cost.total > 0.0);
1195 }
1196
1197 #[test]
1198 fn test_filter_selectivity() {
1199 let estimator = CostEstimator::new();
1200
1201 let eq_filter = AstFilter::Compare {
1202 field: FieldRef::column("hosts", "id"),
1203 op: CompareOp::Eq,
1204 value: Value::Integer(1),
1205 };
1206 assert!(CostEstimator::estimate_filter_selectivity(&eq_filter) < 0.1);
1207
1208 let ne_filter = AstFilter::Compare {
1209 field: FieldRef::column("hosts", "id"),
1210 op: CompareOp::Ne,
1211 value: Value::Integer(1),
1212 };
1213 assert!(CostEstimator::estimate_filter_selectivity(&ne_filter) > 0.9);
1214 }
1215
1216 #[test]
1217 fn test_and_selectivity() {
1218 let estimator = CostEstimator::new();
1219
1220 let and_filter = AstFilter::And(
1221 Box::new(AstFilter::Compare {
1222 field: FieldRef::column("hosts", "a"),
1223 op: CompareOp::Eq,
1224 value: Value::Integer(1),
1225 }),
1226 Box::new(AstFilter::Compare {
1227 field: FieldRef::column("hosts", "b"),
1228 op: CompareOp::Eq,
1229 value: Value::Integer(2),
1230 }),
1231 );
1232
1233 let selectivity = CostEstimator::estimate_filter_selectivity(&and_filter);
1234 assert!(selectivity < 0.01); }
1236
1237 #[test]
1238 fn test_cardinality_with_limit() {
1239 let estimator = CostEstimator::new();
1240
1241 let query = TableQuery {
1242 table: "hosts".to_string(),
1243 source: None,
1244 alias: None,
1245 select_items: Vec::new(),
1246 columns: vec![Projection::All],
1247 where_expr: None,
1248 filter: None,
1249 group_by_exprs: Vec::new(),
1250 group_by: Vec::new(),
1251 having_expr: None,
1252 having: None,
1253 order_by: vec![],
1254 limit: Some(10),
1255 limit_param: None,
1256 offset: None,
1257 offset_param: None,
1258 expand: None,
1259 as_of: None,
1260 };
1261
1262 let card = estimator.estimate_table_cardinality(&query);
1263 assert!(card.rows <= 10.0);
1264 }
1265
1266 #[test]
1271 fn startup_zero_for_full_scan() {
1272 let estimator = CostEstimator::new();
1276 let q = table_query("any_table", None);
1277 let cost = estimator.estimate(&QueryExpr::Table(q));
1278 assert_eq!(cost.startup_cost, 0.0, "full scan must have zero startup");
1279 assert!(cost.total > 0.0);
1280 }
1281
1282 #[test]
1283 fn startup_nonzero_for_blocking_combine() {
1284 let input = PlanCost::new(100.0, 10.0, 50.0); let blocker = PlanCost::new(20.0, 0.0, 10.0); let composed = input.combine_blocking(&blocker);
1289 assert_eq!(composed.startup_cost, input.total);
1291 assert_eq!(composed.total, input.total + blocker.total);
1293 assert!(composed.startup_cost > 0.0);
1294 }
1295
1296 #[test]
1297 fn pipelined_combine_adds_startup_directly() {
1298 let upstream = PlanCost::with_startup(50.0, 5.0, 10.0, 30.0);
1299 let downstream = PlanCost::with_startup(20.0, 0.0, 0.0, 5.0);
1300 let composed = upstream.combine_pipelined(&downstream);
1301 assert_eq!(composed.startup_cost, 30.0 + 5.0);
1302 assert_eq!(composed.total, upstream.total + downstream.total);
1303 }
1304
1305 #[test]
1306 fn cost_prefers_low_startup_when_limit_small() {
1307 let fast_first = PlanCost {
1310 cpu: 100.0,
1311 io: 10.0,
1312 network: 0.0,
1313 memory: 50.0,
1314 startup_cost: 5.0,
1315 total: 200.0,
1316 };
1317 let slow_first = PlanCost {
1318 cpu: 100.0,
1319 io: 10.0,
1320 network: 0.0,
1321 memory: 50.0,
1322 startup_cost: 150.0,
1323 total: 200.0,
1324 };
1325 assert_eq!(
1327 fast_first.prefer_over(&slow_first, Some(10), 10_000.0),
1328 std::cmp::Ordering::Less
1329 );
1330 }
1331
1332 #[test]
1333 fn cost_prefers_low_total_when_no_limit() {
1334 let low_total = PlanCost {
1336 cpu: 50.0,
1337 io: 5.0,
1338 network: 0.0,
1339 memory: 0.0,
1340 startup_cost: 30.0,
1341 total: 100.0,
1342 };
1343 let high_total = PlanCost {
1344 cpu: 100.0,
1345 io: 10.0,
1346 network: 0.0,
1347 memory: 0.0,
1348 startup_cost: 5.0,
1349 total: 200.0,
1350 };
1351 assert_eq!(
1352 low_total.prefer_over(&high_total, None, 10_000.0),
1353 std::cmp::Ordering::Less
1354 );
1355 }
1356
1357 #[test]
1358 fn limit_threshold_falls_back_to_total_when_limit_large() {
1359 let low_total = PlanCost {
1361 cpu: 50.0,
1362 io: 5.0,
1363 network: 0.0,
1364 memory: 0.0,
1365 startup_cost: 30.0,
1366 total: 100.0,
1367 };
1368 let low_startup = PlanCost {
1369 cpu: 100.0,
1370 io: 10.0,
1371 network: 0.0,
1372 memory: 0.0,
1373 startup_cost: 5.0,
1374 total: 200.0,
1375 };
1376 assert_eq!(
1377 low_total.prefer_over(&low_startup, Some(5000), 10_000.0),
1378 std::cmp::Ordering::Less
1379 );
1380 }
1381
1382 #[test]
1383 fn hash_join_startup_includes_build_cost() {
1384 let left = PlanCost::new(80.0, 8.0, 100.0); let build = PlanCost::with_startup(50.0, 0.0, 200.0, 50.0); let after_build = left.combine_blocking(&build);
1390 assert!(
1391 after_build.startup_cost >= left.total,
1392 "after-build startup ({}) must absorb left.total ({})",
1393 after_build.startup_cost,
1394 left.total
1395 );
1396 assert!(after_build.total >= after_build.startup_cost);
1397 }
1398
1399 #[test]
1400 fn vector_search_reports_nonzero_startup() {
1401 let estimator = CostEstimator::new();
1405 let v = PlanCost::with_startup(150.0, 20.0, 1320.0, 50.0);
1408 assert!(v.startup_cost > 0.0);
1409 assert!(v.startup_cost < v.total);
1410 let _ = estimator; }
1412
1413 #[test]
1414 fn with_startup_clamps_total_below_startup() {
1415 let cost = PlanCost::with_startup(1.0, 0.0, 0.0, 100.0);
1417 assert!(cost.total >= cost.startup_cost);
1418 }
1419
1420 #[test]
1421 fn default_plancost_has_zero_startup() {
1422 let c = PlanCost::default();
1423 assert_eq!(c.startup_cost, 0.0);
1424 assert_eq!(c.total, 0.0);
1425 }
1426
1427 use super::super::histogram::{ColumnValue, Histogram, MostCommonValues};
1432
1433 fn provider_with_skew() -> Arc<StaticProvider> {
1434 let mut sample: Vec<ColumnValue> = Vec::new();
1438 for i in 0..80 {
1439 sample.push(ColumnValue::Int(i % 10));
1440 }
1441 for i in 0..20 {
1442 sample.push(ColumnValue::Int(10 + i * 50));
1443 }
1444 let h = Histogram::equi_depth_from_sample(sample, 10);
1445
1446 let mcv = MostCommonValues::new(vec![
1447 (ColumnValue::Text("boss".to_string()), 0.5),
1448 (ColumnValue::Text("intern".to_string()), 0.05),
1449 ]);
1450
1451 Arc::new(
1452 StaticProvider::new()
1453 .with_table(
1454 "people",
1455 TableStats {
1456 row_count: 100_000,
1457 avg_row_size: 64,
1458 page_count: 100,
1459 columns: vec![],
1460 },
1461 )
1462 .with_histogram("people", "score", h)
1463 .with_mcv("people", "role", mcv),
1464 )
1465 }
1466
1467 #[test]
1468 fn eq_uses_mcv_when_value_is_tracked() {
1469 let provider = provider_with_skew();
1470 let estimator = CostEstimator::with_stats(provider);
1471 let filter = AstFilter::Compare {
1472 field: FieldRef::column("people", "role"),
1473 op: CompareOp::Eq,
1474 value: Value::text("boss".to_string()),
1475 };
1476 let s = estimator.filter_selectivity(&filter, "people");
1479 assert!(
1480 (s - 0.5).abs() < 1e-9,
1481 "MCV-tracked equality should report exact frequency, got {s}"
1482 );
1483 }
1484
1485 #[test]
1486 fn eq_uses_residual_for_non_mcv_value() {
1487 let provider = provider_with_skew();
1488 let estimator = CostEstimator::with_stats(provider);
1489 let filter = AstFilter::Compare {
1490 field: FieldRef::column("people", "role"),
1491 op: CompareOp::Eq,
1492 value: Value::text("staff".to_string()),
1493 };
1494 let s = estimator.filter_selectivity(&filter, "people");
1498 assert!(s > 0.0 && s < 0.01, "residual eq should be tiny, got {s}");
1499 }
1500
1501 #[test]
1502 fn ne_is_one_minus_eq_under_mcv() {
1503 let provider = provider_with_skew();
1504 let estimator = CostEstimator::with_stats(provider);
1505 let filter = AstFilter::Compare {
1506 field: FieldRef::column("people", "role"),
1507 op: CompareOp::Ne,
1508 value: Value::text("boss".to_string()),
1509 };
1510 let s = estimator.filter_selectivity(&filter, "people");
1511 assert!((s - 0.5).abs() < 1e-9, "Ne selectivity = 0.5, got {s}");
1513 }
1514
1515 #[test]
1516 fn range_uses_histogram_when_present() {
1517 let provider = provider_with_skew();
1518 let estimator = CostEstimator::with_stats(provider);
1519 let filter = AstFilter::Compare {
1520 field: FieldRef::column("people", "score"),
1521 op: CompareOp::Le,
1522 value: Value::Integer(9),
1523 };
1524 let s = estimator.filter_selectivity(&filter, "people");
1527 assert!(
1528 s > 0.5,
1529 "histogram-based range selectivity should beat 0.3, got {s}"
1530 );
1531 }
1532
1533 #[test]
1534 fn between_uses_histogram() {
1535 let provider = provider_with_skew();
1536 let estimator = CostEstimator::with_stats(provider);
1537 let filter = AstFilter::Between {
1538 field: FieldRef::column("people", "score"),
1539 low: Value::Integer(0),
1540 high: Value::Integer(9),
1541 };
1542 let s = estimator.filter_selectivity(&filter, "people");
1543 assert!(s > 0.5, "BETWEEN should use histogram too, got {s}");
1544 }
1545
1546 #[test]
1547 fn graceful_fallback_when_histogram_absent() {
1548 let provider = Arc::new(StaticProvider::new().with_table(
1551 "people",
1552 TableStats {
1553 row_count: 1000,
1554 avg_row_size: 64,
1555 page_count: 10,
1556 columns: vec![],
1557 },
1558 ));
1559 let estimator = CostEstimator::with_stats(provider);
1560 let filter = AstFilter::Compare {
1561 field: FieldRef::column("people", "unknown_col"),
1562 op: CompareOp::Lt,
1563 value: Value::Integer(50),
1564 };
1565 let s = estimator.filter_selectivity(&filter, "people");
1566 assert!((s - 0.3).abs() < 1e-9, "fallback heuristic 0.3, got {s}");
1567 }
1568}