1use std::sync::Arc;
13
14use super::stats_provider::{NullProvider, StatsProvider};
15use crate::storage::query::ast::{
16 CompareOp, FieldRef, Filter as AstFilter, GraphQuery, HybridQuery, JoinQuery, JoinType,
17 PathQuery, QueryExpr, TableQuery, VectorQuery,
18};
19use crate::storage::schema::Value;
20
/// Estimated output size of a (sub)query.
#[derive(Debug, Clone, Default)]
pub struct CardinalityEstimate {
    /// Estimated number of output rows.
    pub rows: f64,
    /// Estimated fraction of the input that survives; `1.0` for an unfiltered scan.
    pub selectivity: f64,
    /// Confidence in the estimate: `1.0` from the constructors, multiplied by
    /// `0.9` each time a filter is applied through [`CardinalityEstimate::with_filter`].
    pub confidence: f64,
}

impl CardinalityEstimate {
    /// Builds an estimate with the given row count and selectivity at full confidence.
    pub fn new(rows: f64, selectivity: f64) -> Self {
        Self {
            confidence: 1.0,
            rows,
            selectivity,
        }
    }

    /// Estimate for reading every row of a table of `table_size` rows.
    pub fn full_scan(table_size: f64) -> Self {
        Self::new(table_size, 1.0)
    }

    /// Applies a filter that keeps `filter_selectivity` of the rows.
    ///
    /// Row count and selectivity are both scaled by the filter's selectivity,
    /// and confidence is reduced by 10% to reflect the extra guesswork.
    pub fn with_filter(self, filter_selectivity: f64) -> Self {
        Self {
            rows: self.rows * filter_selectivity,
            selectivity: self.selectivity * filter_selectivity,
            confidence: self.confidence * 0.9,
        }
    }
}
59
/// Cost of executing a (sub)plan, broken down by resource.
///
/// The constructors fold the components into `total` as
/// `cpu + io * 10.0 + memory * 0.1`; `network` is carried alongside but is not
/// part of that formula. `startup_cost` is the work that must be paid before
/// the first row can be produced, which [`PlanCost::prefer_over`] compares
/// instead of `total` for small-`LIMIT` queries.
#[derive(Debug, Clone, Default)]
pub struct PlanCost {
    /// CPU work units.
    pub cpu: f64,
    /// I/O units (weighted x10 in `total` by the constructors).
    pub io: f64,
    /// Network units (set to zero by the constructors).
    pub network: f64,
    /// Memory units (weighted x0.1 in `total` by the constructors).
    pub memory: f64,
    /// Cost incurred before the first output row is available.
    pub startup_cost: f64,
    /// Aggregate cost used to rank plans.
    pub total: f64,
}

impl PlanCost {
    /// Weight of one I/O unit relative to one CPU unit in `total`.
    const IO_WEIGHT: f64 = 10.0;
    /// Weight of one memory unit relative to one CPU unit in `total`.
    const MEMORY_WEIGHT: f64 = 0.1;

    /// Scalar cost shared by both constructors.
    fn weighted_total(cpu: f64, io: f64, memory: f64) -> f64 {
        cpu + io * Self::IO_WEIGHT + memory * Self::MEMORY_WEIGHT
    }

    /// Cost of a fully pipelined operator: no startup cost.
    pub fn new(cpu: f64, io: f64, memory: f64) -> Self {
        Self {
            total: Self::weighted_total(cpu, io, memory),
            cpu,
            io,
            memory,
            network: 0.0,
            startup_cost: 0.0,
        }
    }

    /// Cost of an operator that must spend `startup_cost` before emitting rows.
    ///
    /// Negative startup costs are clamped to zero, and `total` is raised to at
    /// least `startup_cost` so a plan never finishes before it starts.
    pub fn with_startup(cpu: f64, io: f64, memory: f64, startup_cost: f64) -> Self {
        let total = Self::weighted_total(cpu, io, memory).max(startup_cost);
        Self {
            cpu,
            io,
            memory,
            network: 0.0,
            startup_cost: startup_cost.max(0.0),
            total,
        }
    }

    /// Chains two operators that stream rows to each other.
    ///
    /// Work components and both startup and total costs add up; memory is the
    /// peak of the two because both operators run concurrently.
    pub fn combine_pipelined(&self, other: &PlanCost) -> PlanCost {
        let peak_memory = self.memory.max(other.memory);
        PlanCost {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            network: self.network + other.network,
            memory: peak_memory,
            startup_cost: self.startup_cost + other.startup_cost,
            total: self.total + other.total,
        }
    }

    /// Feeds `self` into a blocking operator (sort, hash build, ...).
    ///
    /// Identical to [`PlanCost::combine_pipelined`] except that nothing can be
    /// emitted until `self` has run to completion, so the startup cost becomes
    /// `self.total` plus the blocker's own startup.
    pub fn combine_blocking(&self, blocker: &PlanCost) -> PlanCost {
        PlanCost {
            startup_cost: self.total + blocker.startup_cost,
            ..self.combine_pipelined(blocker)
        }
    }

    /// Backwards-compatible alias for [`PlanCost::combine_pipelined`].
    pub fn combine(&self, other: &PlanCost) -> PlanCost {
        self.combine_pipelined(other)
    }

    /// Multiplies the work components and `total` by `factor`.
    ///
    /// `memory` and `startup_cost` are carried over unchanged.
    /// NOTE(review): for `factor < 1.0` the scaled `total` can drop below the
    /// unscaled `startup_cost`; confirm callers never rely on `total >= startup_cost`
    /// after scaling.
    pub fn scale(&self, factor: f64) -> PlanCost {
        let mut scaled = self.clone();
        scaled.cpu *= factor;
        scaled.io *= factor;
        scaled.network *= factor;
        scaled.total *= factor;
        scaled
    }

    /// Orders `self` against `other` for plan selection.
    ///
    /// When `limit` asks for fewer than 10% of the expected `cardinality`
    /// (treated as at least one row), the plan that produces its first rows
    /// sooner wins, so startup costs are compared. Otherwise totals are
    /// compared. Incomparable (NaN) costs are reported as equal.
    pub fn prefer_over(
        &self,
        other: &PlanCost,
        limit: Option<u64>,
        cardinality: f64,
    ) -> std::cmp::Ordering {
        let first_rows_matter =
            limit.is_some_and(|k| (k as f64) < 0.1 * cardinality.max(1.0));
        let mine = if first_rows_matter { self.startup_cost } else { self.total };
        let theirs = if first_rows_matter { other.startup_cost } else { other.total };
        mine.partial_cmp(&theirs)
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}
197
/// Table-level statistics as exposed by a stats provider.
#[derive(Debug, Clone, Default)]
pub struct TableStats {
    /// Number of rows in the table; used by the estimator as the base row count.
    pub row_count: u64,
    /// Average row size. NOTE(review): unit presumably bytes — confirm with the provider.
    pub avg_row_size: u32,
    /// Number of heap pages; used by the estimator as the heap size for index I/O costing.
    pub page_count: u64,
    /// Per-column statistics.
    pub columns: Vec<ColumnStats>,
}

/// Column-level statistics for a single column of a table.
#[derive(Debug, Clone, Default)]
pub struct ColumnStats {
    /// Column name.
    pub name: String,
    /// Number of distinct values in the column.
    pub distinct_count: u64,
    /// Number of NULL values in the column.
    pub null_count: u64,
    /// Smallest value, if known. NOTE(review): stored as a string; encoding not visible here.
    pub min_value: Option<String>,
    /// Largest value, if known. NOTE(review): stored as a string; encoding not visible here.
    pub max_value: Option<String>,
    /// Whether an index exists on the column.
    pub has_index: bool,
}
227
/// Rule-based cost and cardinality model for query plans, optionally informed
/// by table, column and index statistics from a [`StatsProvider`].
pub struct CostEstimator {
    /// Row count assumed for a table (and graph pattern base) when no stats exist.
    default_row_count: f64,
    /// CPU cost charged per estimated output row of a table query.
    row_scan_cost: f64,
    /// Per-lookup index cost.
    /// NOTE(review): not read by any estimator in the visible code — confirm before relying on it.
    index_lookup_cost: f64,
    /// CPU cost per row on both the build and probe sides of a join.
    hash_probe_cost: f64,
    /// Per-row nested-loop cost.
    /// NOTE(review): not read by any estimator in the visible code — confirm before relying on it.
    nested_loop_cost: f64,
    /// CPU cost per traversed node/edge in graph and path estimates.
    edge_traversal_cost: f64,
    /// Source of statistics; a `NullProvider` unless one is injected.
    stats: Arc<dyn StatsProvider>,
}
248
249impl CostEstimator {
250 pub fn new() -> Self {
253 Self {
254 default_row_count: 1000.0,
255 row_scan_cost: 1.0,
256 index_lookup_cost: 0.1,
257 hash_probe_cost: 0.5,
258 nested_loop_cost: 2.0,
259 edge_traversal_cost: 1.5,
260 stats: Arc::new(NullProvider),
261 }
262 }
263
264 pub fn with_stats(provider: Arc<dyn StatsProvider>) -> Self {
268 Self {
269 stats: provider,
270 ..Self::new()
271 }
272 }
273
274 pub fn set_stats(&mut self, provider: Arc<dyn StatsProvider>) {
278 self.stats = provider;
279 }
280
281 pub fn estimate(&self, query: &QueryExpr) -> PlanCost {
283 match query {
284 QueryExpr::Table(tq) => self.estimate_table(tq),
285 QueryExpr::Graph(gq) => self.estimate_graph(gq),
286 QueryExpr::Join(jq) => self.estimate_join(jq),
287 QueryExpr::Path(pq) => self.estimate_path(pq),
288 QueryExpr::Vector(vq) => self.estimate_vector(vq),
289 QueryExpr::Hybrid(hq) => self.estimate_hybrid(hq),
290 QueryExpr::Insert(_)
292 | QueryExpr::Update(_)
293 | QueryExpr::Delete(_)
294 | QueryExpr::CreateTable(_)
295 | QueryExpr::CreateCollection(_)
296 | QueryExpr::CreateVector(_)
297 | QueryExpr::DropTable(_)
298 | QueryExpr::DropGraph(_)
299 | QueryExpr::DropVector(_)
300 | QueryExpr::DropDocument(_)
301 | QueryExpr::DropKv(_)
302 | QueryExpr::DropCollection(_)
303 | QueryExpr::Truncate(_)
304 | QueryExpr::AlterTable(_)
305 | QueryExpr::GraphCommand(_)
306 | QueryExpr::SearchCommand(_)
307 | QueryExpr::CreateIndex(_)
308 | QueryExpr::DropIndex(_)
309 | QueryExpr::ProbabilisticCommand(_)
310 | QueryExpr::Ask(_)
311 | QueryExpr::SetConfig { .. }
312 | QueryExpr::ShowConfig { .. }
313 | QueryExpr::SetSecret { .. }
314 | QueryExpr::DeleteSecret { .. }
315 | QueryExpr::ShowSecrets { .. }
316 | QueryExpr::SetTenant(_)
317 | QueryExpr::ShowTenant
318 | QueryExpr::CreateTimeSeries(_)
319 | QueryExpr::DropTimeSeries(_)
320 | QueryExpr::CreateQueue(_)
321 | QueryExpr::AlterQueue(_)
322 | QueryExpr::DropQueue(_)
323 | QueryExpr::QueueSelect(_)
324 | QueryExpr::QueueCommand(_)
325 | QueryExpr::KvCommand(_)
326 | QueryExpr::ConfigCommand(_)
327 | QueryExpr::CreateTree(_)
328 | QueryExpr::DropTree(_)
329 | QueryExpr::TreeCommand(_)
330 | QueryExpr::ExplainAlter(_)
331 | QueryExpr::TransactionControl(_)
332 | QueryExpr::MaintenanceCommand(_)
333 | QueryExpr::CreateSchema(_)
334 | QueryExpr::DropSchema(_)
335 | QueryExpr::CreateSequence(_)
336 | QueryExpr::DropSequence(_)
337 | QueryExpr::CopyFrom(_)
338 | QueryExpr::CreateView(_)
339 | QueryExpr::DropView(_)
340 | QueryExpr::RefreshMaterializedView(_)
341 | QueryExpr::CreatePolicy(_)
342 | QueryExpr::DropPolicy(_)
343 | QueryExpr::CreateServer(_)
344 | QueryExpr::DropServer(_)
345 | QueryExpr::CreateForeignTable(_)
346 | QueryExpr::DropForeignTable(_)
347 | QueryExpr::Grant(_)
348 | QueryExpr::Revoke(_)
349 | QueryExpr::AlterUser(_)
350 | QueryExpr::CreateIamPolicy { .. }
351 | QueryExpr::DropIamPolicy { .. }
352 | QueryExpr::AttachPolicy { .. }
353 | QueryExpr::DetachPolicy { .. }
354 | QueryExpr::ShowPolicies { .. }
355 | QueryExpr::ShowEffectivePermissions { .. }
356 | QueryExpr::SimulatePolicy { .. }
357 | QueryExpr::CreateMigration(_)
358 | QueryExpr::ApplyMigration(_)
359 | QueryExpr::RollbackMigration(_)
360 | QueryExpr::ExplainMigration(_)
361 | QueryExpr::EventsBackfill(_)
362 | QueryExpr::EventsBackfillStatus { .. } => PlanCost::new(1.0, 1.0, 0.0),
363 }
364 }
365
366 pub fn estimate_cardinality(&self, query: &QueryExpr) -> CardinalityEstimate {
368 match query {
369 QueryExpr::Table(tq) => self.estimate_table_cardinality(tq),
370 QueryExpr::Graph(gq) => self.estimate_graph_cardinality(gq),
371 QueryExpr::Join(jq) => self.estimate_join_cardinality(jq),
372 QueryExpr::Path(pq) => self.estimate_path_cardinality(pq),
373 QueryExpr::Vector(vq) => self.estimate_vector_cardinality(vq),
374 QueryExpr::Hybrid(hq) => self.estimate_hybrid_cardinality(hq),
375 QueryExpr::Insert(_)
377 | QueryExpr::Update(_)
378 | QueryExpr::Delete(_)
379 | QueryExpr::CreateTable(_)
380 | QueryExpr::CreateCollection(_)
381 | QueryExpr::CreateVector(_)
382 | QueryExpr::DropTable(_)
383 | QueryExpr::DropGraph(_)
384 | QueryExpr::DropVector(_)
385 | QueryExpr::DropDocument(_)
386 | QueryExpr::DropKv(_)
387 | QueryExpr::DropCollection(_)
388 | QueryExpr::Truncate(_)
389 | QueryExpr::AlterTable(_)
390 | QueryExpr::GraphCommand(_)
391 | QueryExpr::SearchCommand(_)
392 | QueryExpr::CreateIndex(_)
393 | QueryExpr::DropIndex(_)
394 | QueryExpr::ProbabilisticCommand(_)
395 | QueryExpr::Ask(_)
396 | QueryExpr::SetConfig { .. }
397 | QueryExpr::ShowConfig { .. }
398 | QueryExpr::SetSecret { .. }
399 | QueryExpr::DeleteSecret { .. }
400 | QueryExpr::ShowSecrets { .. }
401 | QueryExpr::SetTenant(_)
402 | QueryExpr::ShowTenant
403 | QueryExpr::CreateTimeSeries(_)
404 | QueryExpr::DropTimeSeries(_)
405 | QueryExpr::CreateQueue(_)
406 | QueryExpr::AlterQueue(_)
407 | QueryExpr::DropQueue(_)
408 | QueryExpr::QueueSelect(_)
409 | QueryExpr::QueueCommand(_)
410 | QueryExpr::KvCommand(_)
411 | QueryExpr::ConfigCommand(_)
412 | QueryExpr::CreateTree(_)
413 | QueryExpr::DropTree(_)
414 | QueryExpr::TreeCommand(_)
415 | QueryExpr::ExplainAlter(_)
416 | QueryExpr::TransactionControl(_)
417 | QueryExpr::MaintenanceCommand(_)
418 | QueryExpr::CreateSchema(_)
419 | QueryExpr::DropSchema(_)
420 | QueryExpr::CreateSequence(_)
421 | QueryExpr::DropSequence(_)
422 | QueryExpr::CopyFrom(_)
423 | QueryExpr::CreateView(_)
424 | QueryExpr::DropView(_)
425 | QueryExpr::RefreshMaterializedView(_)
426 | QueryExpr::CreatePolicy(_)
427 | QueryExpr::DropPolicy(_)
428 | QueryExpr::CreateServer(_)
429 | QueryExpr::DropServer(_)
430 | QueryExpr::CreateForeignTable(_)
431 | QueryExpr::DropForeignTable(_)
432 | QueryExpr::Grant(_)
433 | QueryExpr::Revoke(_)
434 | QueryExpr::AlterUser(_)
435 | QueryExpr::CreateIamPolicy { .. }
436 | QueryExpr::DropIamPolicy { .. }
437 | QueryExpr::AttachPolicy { .. }
438 | QueryExpr::DetachPolicy { .. }
439 | QueryExpr::ShowPolicies { .. }
440 | QueryExpr::ShowEffectivePermissions { .. }
441 | QueryExpr::SimulatePolicy { .. }
442 | QueryExpr::CreateMigration(_)
443 | QueryExpr::ApplyMigration(_)
444 | QueryExpr::RollbackMigration(_)
445 | QueryExpr::ExplainMigration(_)
446 | QueryExpr::EventsBackfill(_)
447 | QueryExpr::EventsBackfillStatus { .. } => CardinalityEstimate::new(1.0, 1.0),
448 }
449 }
450
451 fn estimate_table(&self, query: &TableQuery) -> PlanCost {
456 let cardinality = self.estimate_table_cardinality(query);
457
458 let cpu = cardinality.rows * self.row_scan_cost;
459
460 let io = self.estimate_table_io(query, cardinality.rows);
463
464 let memory = cardinality.rows * 100.0; PlanCost::new(cpu, io, memory)
467 }
468
469 fn estimate_table_io(&self, query: &TableQuery, result_rows: f64) -> f64 {
476 const ROWS_PER_PAGE: f64 = 100.0;
477
478 let table_stats = self.stats.table_stats(&query.table);
480 let heap_pages = table_stats
481 .map(|s| s.page_count as f64)
482 .unwrap_or_else(|| (result_rows / ROWS_PER_PAGE).max(1.0));
483
484 if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
487 if let Some(col) = first_filter_column(&filter, &query.table) {
488 if let Some(idx) = self.stats.index_stats(&query.table, col) {
489 return idx.correlated_io_cost(result_rows, heap_pages);
490 }
491 }
492 }
493
494 (result_rows / ROWS_PER_PAGE).ceil()
496 }
497
498 fn estimate_table_cardinality(&self, query: &TableQuery) -> CardinalityEstimate {
499 let base_rows = self
502 .stats
503 .table_stats(&query.table)
504 .map(|s| s.row_count as f64)
505 .unwrap_or(self.default_row_count);
506
507 let mut estimate = CardinalityEstimate::full_scan(base_rows);
508
509 if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
512 let selectivity = self.filter_selectivity(&filter, &query.table);
513 estimate = estimate.with_filter(selectivity);
514 }
515
516 if let Some(limit) = query.limit {
518 estimate.rows = estimate.rows.min(limit as f64);
519 }
520
521 estimate
522 }
523
/// Estimates the fraction of `table`'s rows that satisfy `filter`, using the
/// injected stats provider where it has data.
///
/// Column references that do not resolve to `table` (see
/// `column_name_for_table`) fall back to fixed heuristics. `And` assumes
/// independent predicates (product); `Or` uses inclusion–exclusion under the
/// same assumption; `Not` takes the complement.
/// NOTE(review): `Between`, `In` and the pattern variants are matched with
/// `..`; if they carry a negation flag it is not consulted here — confirm
/// against the AST definition.
fn filter_selectivity(&self, filter: &AstFilter, table: &str) -> f64 {
    match filter {
        AstFilter::Compare { field, op, value } => {
            let column = column_name_for_table(field, table);
            match op {
                CompareOp::Eq => self.eq_selectivity(table, column, value),
                CompareOp::Ne => 1.0 - self.eq_selectivity(table, column, value),
                // Strict and non-strict bounds are estimated identically.
                CompareOp::Lt | CompareOp::Le => {
                    self.range_selectivity(table, column, None, Some(value))
                }
                CompareOp::Gt | CompareOp::Ge => {
                    self.range_selectivity(table, column, Some(value), None)
                }
            }
        }
        AstFilter::Between {
            field, low, high, ..
        } => {
            let column = column_name_for_table(field, table);
            self.range_selectivity(table, column, Some(low), Some(high))
        }
        AstFilter::In { field, values, .. } => {
            let column = column_name_for_table(field, table);
            if let Some(c) = column {
                // Preferred source: the column's most-common-values list.
                if let Some(mcv) = self.stats.column_mcv(table, c) {
                    // `hits` sums the MCV frequency of each listed value found in
                    // the MCV list; everything else (unknown or non-convertible
                    // values) is counted as a residual.
                    let mut hits: f64 = 0.0;
                    let mut residual_count = 0usize;
                    for v in values {
                        if let Some(cv) = column_value_from(v) {
                            if let Some(freq) = mcv.frequency_of(&cv) {
                                hits += freq;
                            } else {
                                residual_count += 1;
                            }
                        } else {
                            residual_count += 1;
                        }
                    }
                    // Each residual gets an equal share of the frequency mass the
                    // MCV list does not cover (100 distinct values assumed if unknown).
                    let total = mcv.total_frequency();
                    let distinct = self.stats.distinct_values(table, c).unwrap_or(100);
                    let non_mcv_distinct =
                        distinct.saturating_sub(mcv.values.len() as u64).max(1);
                    let per_residual = (1.0 - total) / non_mcv_distinct as f64;
                    let estimate = hits + (residual_count as f64) * per_residual;
                    // Clamp to a valid fraction, then apply the same 0.5 ceiling
                    // the other IN branches use.
                    return estimate.clamp(0.0, 1.0).min(0.5);
                }
                // Second choice: index point selectivity, once per listed value.
                if let Some(s) = self.stats.index_stats(table, c) {
                    return (s.point_selectivity() * values.len() as f64).min(0.5);
                }
            }
            // No stats: 1% per listed value, capped at 50%.
            (values.len() as f64 * 0.01).min(0.5)
        }
        // Pattern and null predicates use fixed heuristics.
        AstFilter::Like { .. } => 0.1,
        AstFilter::StartsWith { .. } => 0.15,
        AstFilter::EndsWith { .. } => 0.15,
        AstFilter::Contains { .. } => 0.1,
        AstFilter::IsNull { .. } => 0.01,
        AstFilter::IsNotNull { .. } => 0.99,
        AstFilter::And(left, right) => {
            self.filter_selectivity(left, table) * self.filter_selectivity(right, table)
        }
        AstFilter::Or(left, right) => {
            let s1 = self.filter_selectivity(left, table);
            let s2 = self.filter_selectivity(right, table);
            s1 + s2 - (s1 * s2)
        }
        AstFilter::Not(inner) => 1.0 - self.filter_selectivity(inner, table),
        // Column-to-column and expression comparisons have no stats support.
        AstFilter::CompareFields { .. } => {
            0.1
        }
        AstFilter::CompareExpr { .. } => {
            0.1
        }
    }
}
620
621 fn estimate_graph(&self, query: &GraphQuery) -> PlanCost {
626 let cardinality = self.estimate_graph_cardinality(query);
627
628 let nodes = query.pattern.nodes.len() as f64;
630 let edges = query.pattern.edges.len() as f64;
631
632 let cpu = cardinality.rows * self.edge_traversal_cost * (nodes + edges);
633 let io = cardinality.rows * 0.1; let memory = cardinality.rows * 200.0; PlanCost::new(cpu, io, memory)
637 }
638
639 fn estimate_graph_cardinality(&self, query: &GraphQuery) -> CardinalityEstimate {
640 let nodes = query.pattern.nodes.len() as f64;
641 let edges = query.pattern.edges.len() as f64;
642
643 let base_rows = self.default_row_count;
645 let edge_factor = 0.1_f64.powf(edges); let mut estimate = CardinalityEstimate::new(base_rows * nodes * edge_factor, edge_factor);
648 estimate.confidence = 0.5; if let Some(ref filter) = query.filter {
652 let selectivity = Self::estimate_filter_selectivity(filter);
653 estimate = estimate.with_filter(selectivity);
654 }
655
656 estimate
657 }
658
659 fn estimate_join(&self, query: &JoinQuery) -> PlanCost {
664 let left_cost = self.estimate(&query.left);
665 let right_cost = self.estimate(&query.right);
666
667 let left_card = self.estimate_cardinality(&query.left);
668 let right_card = self.estimate_cardinality(&query.right);
669
670 let build_cpu = left_card.rows * self.hash_probe_cost;
677 let probe_cpu = right_card.rows * self.hash_probe_cost;
678 let join_memory = left_card.rows * 100.0; let build_op = PlanCost::with_startup(build_cpu, 0.0, join_memory, build_cpu);
682 let probe_op = PlanCost::new(probe_cpu, 0.0, 0.0);
684
685 let after_build = left_cost.combine_blocking(&build_op);
687 after_build
688 .combine_pipelined(&right_cost)
689 .combine_pipelined(&probe_op)
690 }
691
692 fn estimate_join_cardinality(&self, query: &JoinQuery) -> CardinalityEstimate {
693 let left = self.estimate_cardinality(&query.left);
694 let right = self.estimate_cardinality(&query.right);
695
696 let selectivity = match query.join_type {
698 JoinType::Inner => 0.1, JoinType::LeftOuter => 1.0, JoinType::RightOuter => 1.0, JoinType::FullOuter => 1.0, JoinType::Cross => 1.0, };
704
705 CardinalityEstimate::new(
706 left.rows * right.rows * selectivity,
707 left.selectivity * right.selectivity * selectivity,
708 )
709 }
710
711 fn estimate_path(&self, query: &PathQuery) -> PlanCost {
716 let cardinality = self.estimate_path_cardinality(query);
717
718 let max_hops = query.max_length;
720 let branching_factor: f64 = 5.0; let nodes_visited = branching_factor.powf(max_hops as f64).min(10000.0);
723 let cpu = nodes_visited * self.edge_traversal_cost;
724 let io = nodes_visited * 0.1;
725 let memory = nodes_visited * 50.0; PlanCost::new(cpu, io, memory)
728 }
729
730 fn estimate_path_cardinality(&self, query: &PathQuery) -> CardinalityEstimate {
731 let max_paths = 10.0;
733 CardinalityEstimate::new(max_paths, 0.001)
734 }
735
736 fn estimate_vector(&self, query: &VectorQuery) -> PlanCost {
741 let k = query.k as f64;
744
745 let hnsw_cost = 100.0 * (1.0 + k.ln()); let filter_cost =
752 if crate::storage::query::sql_lowering::effective_vector_filter(query).is_some() {
753 50.0
754 } else {
755 0.0
756 };
757
758 let cpu = hnsw_cost + filter_cost;
759 let io = 20.0; let memory = k * 32.0 + 1000.0; PlanCost::with_startup(cpu, io, memory, hnsw_cost * 0.5)
767 }
768
769 fn estimate_vector_cardinality(&self, query: &VectorQuery) -> CardinalityEstimate {
770 let k = query.k as f64;
772 CardinalityEstimate::new(k, 0.1)
773 }
774
775 fn estimate_hybrid(&self, query: &HybridQuery) -> PlanCost {
780 let structured_cost = self.estimate(&query.structured);
782 let vector_cost = self.estimate_vector(&query.vector);
783
784 let fusion_overhead = match &query.fusion {
786 crate::storage::query::ast::FusionStrategy::Rerank { .. } => 50.0,
787 crate::storage::query::ast::FusionStrategy::FilterThenSearch => 10.0,
788 crate::storage::query::ast::FusionStrategy::SearchThenFilter => 10.0,
789 crate::storage::query::ast::FusionStrategy::RRF { .. } => 30.0,
790 crate::storage::query::ast::FusionStrategy::Intersection => 20.0,
791 crate::storage::query::ast::FusionStrategy::Union { .. } => 40.0,
792 };
793
794 PlanCost::new(
795 structured_cost.cpu + vector_cost.cpu + fusion_overhead,
796 structured_cost.io + vector_cost.io,
797 structured_cost.memory + vector_cost.memory,
798 )
799 }
800
801 fn estimate_hybrid_cardinality(&self, query: &HybridQuery) -> CardinalityEstimate {
802 let structured_card = self.estimate_cardinality(&query.structured);
803 let vector_card = self.estimate_vector_cardinality(&query.vector);
804
805 let rows = match &query.fusion {
807 crate::storage::query::ast::FusionStrategy::Intersection => {
808 structured_card.rows.min(vector_card.rows)
809 }
810 crate::storage::query::ast::FusionStrategy::Union { .. } => {
811 structured_card.rows + vector_card.rows
812 }
813 _ => vector_card.rows, };
815
816 CardinalityEstimate::new(rows, 0.2)
817 }
818
819 fn estimate_filter_selectivity(filter: &AstFilter) -> f64 {
824 match filter {
825 AstFilter::Compare { op, .. } => {
826 match op {
827 CompareOp::Eq => 0.01, CompareOp::Ne => 0.99, CompareOp::Lt | CompareOp::Le => 0.3,
830 CompareOp::Gt | CompareOp::Ge => 0.3,
831 }
832 }
833 AstFilter::Between { .. } => 0.25,
834 AstFilter::In { values, .. } => {
835 (values.len() as f64 * 0.01).min(0.5)
837 }
838 AstFilter::Like { .. } => 0.1,
839 AstFilter::StartsWith { .. } => 0.15,
840 AstFilter::EndsWith { .. } => 0.15,
841 AstFilter::Contains { .. } => 0.1,
842 AstFilter::IsNull { .. } => 0.01,
843 AstFilter::IsNotNull { .. } => 0.99,
844 AstFilter::And(left, right) => {
845 Self::estimate_filter_selectivity(left) * Self::estimate_filter_selectivity(right)
846 }
847 AstFilter::Or(left, right) => {
848 let s1 = Self::estimate_filter_selectivity(left);
849 let s2 = Self::estimate_filter_selectivity(right);
850 s1 + s2 - (s1 * s2) }
852 AstFilter::Not(inner) => 1.0 - Self::estimate_filter_selectivity(inner),
853 AstFilter::CompareFields { .. } => 0.1,
854 AstFilter::CompareExpr { .. } => 0.1,
855 }
856 }
857}
858
859impl CostEstimator {
860 fn eq_selectivity(&self, table: &str, column: Option<&str>, value: &Value) -> f64 {
868 if let Some(col) = column {
869 if let Some(mcv) = self.stats.column_mcv(table, col) {
871 if let Some(cv) = column_value_from(value) {
872 if let Some(freq) = mcv.frequency_of(&cv) {
873 return freq;
874 }
875 let total = mcv.total_frequency();
877 let distinct = self.stats.distinct_values(table, col).unwrap_or(100);
878 let non_mcv_distinct = distinct.saturating_sub(mcv.values.len() as u64).max(1);
879 return ((1.0 - total) / non_mcv_distinct as f64).clamp(0.0, 1.0);
880 }
881 }
882 if let Some(s) = self.stats.index_stats(table, col) {
884 return s.point_selectivity();
885 }
886 }
887 0.01
889 }
890
891 fn range_selectivity(
902 &self,
903 table: &str,
904 column: Option<&str>,
905 lo: Option<&Value>,
906 hi: Option<&Value>,
907 ) -> f64 {
908 if let Some(col) = column {
909 if let Some(h) = self.stats.column_histogram(table, col) {
911 let lo_cv = lo.and_then(column_value_from);
912 let hi_cv = hi.and_then(column_value_from);
913 return h.range_selectivity(lo_cv.as_ref(), hi_cv.as_ref());
914 }
915 if let Some(s) = self.stats.index_stats(table, col) {
917 let cap = if lo.is_some() && hi.is_some() {
918 0.25
919 } else {
920 0.3
921 };
922 return (s.point_selectivity() * (s.distinct_keys as f64 / 2.0)).min(cap);
923 }
924 }
925 if lo.is_some() && hi.is_some() {
927 0.25
928 } else {
929 0.3
930 }
931 }
932}
933
934impl Default for CostEstimator {
935 fn default() -> Self {
936 Self::new()
937 }
938}
939
940fn column_value_from(v: &crate::storage::schema::Value) -> Option<super::histogram::ColumnValue> {
945 use super::histogram::ColumnValue;
946 use crate::storage::schema::Value;
947 match v {
948 Value::Integer(i) | Value::BigInt(i) => Some(ColumnValue::Int(*i)),
949 Value::UnsignedInteger(u) => Some(ColumnValue::Int(*u as i64)),
950 Value::Float(f) if f.is_finite() => Some(ColumnValue::Float(*f)),
951 Value::Text(s) => Some(ColumnValue::Text(s.to_string())),
952 Value::Email(s)
953 | Value::Url(s)
954 | Value::NodeRef(s)
955 | Value::EdgeRef(s)
956 | Value::TableRef(s)
957 | Value::Password(s) => Some(ColumnValue::Text(s.clone())),
958 Value::Timestamp(t) => Some(ColumnValue::Int(*t)),
959 Value::Duration(d) => Some(ColumnValue::Int(*d)),
960 Value::TimestampMs(t) => Some(ColumnValue::Int(*t)),
961 Value::Decimal(d) => Some(ColumnValue::Int(*d)),
962 Value::Date(d) => Some(ColumnValue::Int(i64::from(*d))),
963 Value::Time(t) => Some(ColumnValue::Int(i64::from(*t))),
964 Value::Phone(p) => Some(ColumnValue::Int(*p as i64)),
965 Value::Semver(v) => Some(ColumnValue::Int(i64::from(*v))),
966 Value::Port(v) => Some(ColumnValue::Int(i64::from(*v))),
967 Value::PageRef(v) => Some(ColumnValue::Int(i64::from(*v))),
968 Value::EnumValue(v) => Some(ColumnValue::Int(i64::from(*v))),
969 Value::Latitude(v) => Some(ColumnValue::Int(i64::from(*v))),
970 Value::Longitude(v) => Some(ColumnValue::Int(i64::from(*v))),
971 _ => None,
976 }
977}
978
979fn first_filter_column<'a>(filter: &'a AstFilter, table: &str) -> Option<&'a str> {
984 match filter {
985 AstFilter::Compare { field, .. } => column_name_for_table(field, table),
986 AstFilter::Between { field, .. } => column_name_for_table(field, table),
987 AstFilter::And(l, r) => {
988 first_filter_column(l, table).or_else(|| first_filter_column(r, table))
989 }
990 _ => None,
991 }
992}
993
994fn column_name_for_table<'a>(field: &'a FieldRef, table: &str) -> Option<&'a str> {
996 match field {
997 FieldRef::TableColumn { table: t, column } if t == table || t.is_empty() => {
998 Some(column.as_str())
999 }
1000 _ => None,
1002 }
1003}
1004
1005#[cfg(test)]
1006mod tests {
1007 use super::super::stats_provider::StaticProvider;
1008 use super::*;
1009 use crate::storage::index::{IndexKind, IndexStats};
1010 use crate::storage::query::ast::{FieldRef, Projection};
1011 use crate::storage::schema::Value;
1012
1013 fn eq_filter(table: &str, column: &str, value: i64) -> AstFilter {
1014 AstFilter::Compare {
1015 field: FieldRef::column(table, column),
1016 op: CompareOp::Eq,
1017 value: Value::Integer(value),
1018 }
1019 }
1020
1021 fn table_query(name: &str, filter: Option<AstFilter>) -> TableQuery {
1022 TableQuery {
1023 table: name.to_string(),
1024 source: None,
1025 alias: None,
1026 select_items: Vec::new(),
1027 columns: vec![Projection::All],
1028 where_expr: None,
1029 filter,
1030 group_by_exprs: Vec::new(),
1031 group_by: Vec::new(),
1032 having_expr: None,
1033 having: None,
1034 order_by: vec![],
1035 limit: None,
1036 limit_param: None,
1037 offset: None,
1038 offset_param: None,
1039 expand: None,
1040 as_of: None,
1041 sessionize: None,
1042 }
1043 }
1044
1045 #[test]
1046 fn injected_row_count_overrides_default() {
1047 let provider = Arc::new(StaticProvider::new().with_table(
1048 "users",
1049 TableStats {
1050 row_count: 50_000,
1051 avg_row_size: 256,
1052 page_count: 500,
1053 columns: vec![],
1054 },
1055 ));
1056 let estimator = CostEstimator::with_stats(provider);
1057 let q = table_query("users", None);
1058 let card = estimator.estimate_table_cardinality(&q);
1059 assert_eq!(card.rows, 50_000.0);
1061 }
1062
1063 #[test]
1064 fn stats_aware_eq_selectivity_beats_default() {
1065 let provider = Arc::new(
1066 StaticProvider::new()
1067 .with_table(
1068 "users",
1069 TableStats {
1070 row_count: 1_000_000,
1071 avg_row_size: 256,
1072 page_count: 10_000,
1073 columns: vec![],
1074 },
1075 )
1076 .with_index(
1077 "users",
1078 "email",
1079 IndexStats {
1080 entries: 1_000_000,
1081 distinct_keys: 1_000_000,
1082 approx_bytes: 0,
1083 kind: IndexKind::Hash,
1084 has_bloom: true,
1085 index_correlation: 0.0,
1086 },
1087 ),
1088 );
1089 let estimator = CostEstimator::with_stats(provider);
1090 let q = table_query("users", Some(eq_filter("users", "email", 0)));
1091 let card = estimator.estimate_table_cardinality(&q);
1092 assert!(card.rows < 2.0, "expected ~1 row, got {}", card.rows);
1094 }
1095
1096 #[test]
1097 fn fallback_when_no_index_stats() {
1098 let provider = Arc::new(StaticProvider::new().with_table(
1099 "users",
1100 TableStats {
1101 row_count: 1_000_000,
1102 avg_row_size: 256,
1103 page_count: 10_000,
1104 columns: vec![],
1105 },
1106 ));
1107 let estimator = CostEstimator::with_stats(provider);
1108 let q = table_query("users", Some(eq_filter("users", "email", 0)));
1109 let card = estimator.estimate_table_cardinality(&q);
1110 assert!((card.rows - 10_000.0).abs() < 1.0);
1112 }
1113
1114 #[test]
1115 fn null_provider_keeps_legacy_behaviour() {
1116 let estimator = CostEstimator::new();
1117 let q = table_query("whatever", Some(eq_filter("whatever", "id", 1)));
1118 let card = estimator.estimate_table_cardinality(&q);
1119 assert!((card.rows - 10.0).abs() < 1.0);
1121 }
1122
1123 #[test]
1124 fn and_combines_stats_selectivities() {
1125 let provider = Arc::new(
1126 StaticProvider::new()
1127 .with_table(
1128 "t",
1129 TableStats {
1130 row_count: 100_000,
1131 avg_row_size: 64,
1132 page_count: 100,
1133 columns: vec![],
1134 },
1135 )
1136 .with_index(
1137 "t",
1138 "a",
1139 IndexStats {
1140 entries: 100_000,
1141 distinct_keys: 10,
1142 approx_bytes: 0,
1143 kind: IndexKind::BTree,
1144 has_bloom: false,
1145 index_correlation: 0.0,
1146 },
1147 )
1148 .with_index(
1149 "t",
1150 "b",
1151 IndexStats {
1152 entries: 100_000,
1153 distinct_keys: 1000,
1154 approx_bytes: 0,
1155 kind: IndexKind::BTree,
1156 has_bloom: false,
1157 index_correlation: 0.0,
1158 },
1159 ),
1160 );
1161 let estimator = CostEstimator::with_stats(provider);
1162 let filter = AstFilter::And(
1163 Box::new(eq_filter("t", "a", 1)),
1164 Box::new(eq_filter("t", "b", 1)),
1165 );
1166 let q = table_query("t", Some(filter));
1167 let card = estimator.estimate_table_cardinality(&q);
1168 assert!(card.rows < 15.0, "got {}", card.rows);
1170 }
1171
1172 #[test]
1173 fn test_table_cost_estimation() {
1174 let estimator = CostEstimator::new();
1175
1176 let query = QueryExpr::Table(TableQuery {
1177 table: "hosts".to_string(),
1178 source: None,
1179 alias: None,
1180 select_items: Vec::new(),
1181 columns: vec![Projection::All],
1182 where_expr: None,
1183 filter: None,
1184 group_by_exprs: Vec::new(),
1185 group_by: Vec::new(),
1186 having_expr: None,
1187 having: None,
1188 order_by: vec![],
1189 limit: None,
1190 limit_param: None,
1191 offset: None,
1192 offset_param: None,
1193 expand: None,
1194 as_of: None,
1195 sessionize: None,
1196 });
1197
1198 let cost = estimator.estimate(&query);
1199 assert!(cost.cpu > 0.0);
1200 assert!(cost.total > 0.0);
1201 }
1202
1203 #[test]
1204 fn test_filter_selectivity() {
1205 let estimator = CostEstimator::new();
1206
1207 let eq_filter = AstFilter::Compare {
1208 field: FieldRef::column("hosts", "id"),
1209 op: CompareOp::Eq,
1210 value: Value::Integer(1),
1211 };
1212 assert!(CostEstimator::estimate_filter_selectivity(&eq_filter) < 0.1);
1213
1214 let ne_filter = AstFilter::Compare {
1215 field: FieldRef::column("hosts", "id"),
1216 op: CompareOp::Ne,
1217 value: Value::Integer(1),
1218 };
1219 assert!(CostEstimator::estimate_filter_selectivity(&ne_filter) > 0.9);
1220 }
1221
1222 #[test]
1223 fn test_and_selectivity() {
1224 let estimator = CostEstimator::new();
1225
1226 let and_filter = AstFilter::And(
1227 Box::new(AstFilter::Compare {
1228 field: FieldRef::column("hosts", "a"),
1229 op: CompareOp::Eq,
1230 value: Value::Integer(1),
1231 }),
1232 Box::new(AstFilter::Compare {
1233 field: FieldRef::column("hosts", "b"),
1234 op: CompareOp::Eq,
1235 value: Value::Integer(2),
1236 }),
1237 );
1238
1239 let selectivity = CostEstimator::estimate_filter_selectivity(&and_filter);
1240 assert!(selectivity < 0.01); }
1242
1243 #[test]
1244 fn test_cardinality_with_limit() {
1245 let estimator = CostEstimator::new();
1246
1247 let query = TableQuery {
1248 table: "hosts".to_string(),
1249 source: None,
1250 alias: None,
1251 select_items: Vec::new(),
1252 columns: vec![Projection::All],
1253 where_expr: None,
1254 filter: None,
1255 group_by_exprs: Vec::new(),
1256 group_by: Vec::new(),
1257 having_expr: None,
1258 having: None,
1259 order_by: vec![],
1260 limit: Some(10),
1261 limit_param: None,
1262 offset: None,
1263 offset_param: None,
1264 expand: None,
1265 as_of: None,
1266 sessionize: None,
1267 };
1268
1269 let card = estimator.estimate_table_cardinality(&query);
1270 assert!(card.rows <= 10.0);
1271 }
1272
1273 #[test]
1278 fn startup_zero_for_full_scan() {
1279 let estimator = CostEstimator::new();
1283 let q = table_query("any_table", None);
1284 let cost = estimator.estimate(&QueryExpr::Table(q));
1285 assert_eq!(cost.startup_cost, 0.0, "full scan must have zero startup");
1286 assert!(cost.total > 0.0);
1287 }
1288
1289 #[test]
1290 fn startup_nonzero_for_blocking_combine() {
1291 let input = PlanCost::new(100.0, 10.0, 50.0); let blocker = PlanCost::new(20.0, 0.0, 10.0); let composed = input.combine_blocking(&blocker);
1296 assert_eq!(composed.startup_cost, input.total);
1298 assert_eq!(composed.total, input.total + blocker.total);
1300 assert!(composed.startup_cost > 0.0);
1301 }
1302
1303 #[test]
1304 fn pipelined_combine_adds_startup_directly() {
1305 let upstream = PlanCost::with_startup(50.0, 5.0, 10.0, 30.0);
1306 let downstream = PlanCost::with_startup(20.0, 0.0, 0.0, 5.0);
1307 let composed = upstream.combine_pipelined(&downstream);
1308 assert_eq!(composed.startup_cost, 30.0 + 5.0);
1309 assert_eq!(composed.total, upstream.total + downstream.total);
1310 }
1311
/// With a small LIMIT (10 of ~10k rows) and equal totals, the plan with the
/// cheaper startup must be preferred.
#[test]
fn cost_prefers_low_startup_when_limit_small() {
    // Both plans are identical except for their startup cost.
    let shared = PlanCost {
        cpu: 100.0,
        io: 10.0,
        network: 0.0,
        memory: 50.0,
        startup_cost: 0.0,
        total: 200.0,
    };
    let fast_first = PlanCost {
        startup_cost: 5.0,
        ..shared.clone()
    };
    let slow_first = PlanCost {
        startup_cost: 150.0,
        ..shared
    };

    let verdict = fast_first.prefer_over(&slow_first, Some(10), 10_000.0);
    assert_eq!(verdict, std::cmp::Ordering::Less);
}
1338
/// Without a LIMIT, the plan with the lower total wins even though its startup
/// cost is higher.
#[test]
fn cost_prefers_low_total_when_no_limit() {
    let low_total = PlanCost {
        cpu: 50.0,
        io: 5.0,
        startup_cost: 30.0,
        total: 100.0,
        ..PlanCost::default()
    };
    let high_total = PlanCost {
        cpu: 100.0,
        io: 10.0,
        startup_cost: 5.0,
        total: 200.0,
        ..PlanCost::default()
    };

    let verdict = low_total.prefer_over(&high_total, None, 10_000.0);
    assert_eq!(verdict, std::cmp::Ordering::Less);
}
1363
/// A LIMIT covering half of the ~10k expected rows is "large", so comparison
/// falls back to total cost and the lower-total plan wins.
#[test]
fn limit_threshold_falls_back_to_total_when_limit_large() {
    let low_total = PlanCost {
        cpu: 50.0,
        io: 5.0,
        startup_cost: 30.0,
        total: 100.0,
        ..PlanCost::default()
    };
    let low_startup = PlanCost {
        cpu: 100.0,
        io: 10.0,
        startup_cost: 5.0,
        total: 200.0,
        ..PlanCost::default()
    };

    let verdict = low_total.prefer_over(&low_startup, Some(5000), 10_000.0);
    assert_eq!(verdict, std::cmp::Ordering::Less);
}
1388
/// Blocking on a hash build must fold the whole input total into the startup
/// cost, and the resulting total can never be below that startup.
#[test]
fn hash_join_startup_includes_build_cost() {
    let upstream = PlanCost::new(80.0, 8.0, 100.0);
    let hash_build = PlanCost::with_startup(50.0, 0.0, 200.0, 50.0);
    let joined = upstream.combine_blocking(&hash_build);

    assert!(
        joined.startup_cost >= upstream.total,
        "after-build startup ({}) must absorb left.total ({})",
        joined.startup_cost,
        upstream.total
    );
    assert!(joined.total >= joined.startup_cost);
}
1405
/// Pins the cost shape expected for a vector search: a positive startup cost
/// that is strictly below the total (here total = 150 + 20*10 + 1320*0.1 = 482).
///
/// NOTE(review): this builds the `PlanCost` directly and never calls
/// `CostEstimator`; consider driving it from a `VectorQuery` so the estimator's
/// vector path is actually exercised.
#[test]
fn vector_search_reports_nonzero_startup() {
    let v = PlanCost::with_startup(150.0, 20.0, 1320.0, 50.0);
    assert!(v.startup_cost > 0.0);
    assert!(v.startup_cost < v.total);
}
1419
/// The raw total here (cpu = 1.0) is below the requested startup of 100.0;
/// `with_startup` must raise the total so it never undercuts the startup.
#[test]
fn with_startup_clamps_total_below_startup() {
    let clamped = PlanCost::with_startup(1.0, 0.0, 0.0, 100.0);
    assert!(clamped.total >= clamped.startup_cost);
}
1426
/// A defaulted `PlanCost` carries no startup and no total cost.
#[test]
fn default_plancost_has_zero_startup() {
    let PlanCost {
        startup_cost,
        total,
        ..
    } = PlanCost::default();
    assert_eq!(startup_cost, 0.0);
    assert_eq!(total, 0.0);
}
1433
1434 use super::super::histogram::{ColumnValue, Histogram, MostCommonValues};
1439
1440 fn provider_with_skew() -> Arc<StaticProvider> {
1441 let mut sample: Vec<ColumnValue> = Vec::new();
1445 for i in 0..80 {
1446 sample.push(ColumnValue::Int(i % 10));
1447 }
1448 for i in 0..20 {
1449 sample.push(ColumnValue::Int(10 + i * 50));
1450 }
1451 let h = Histogram::equi_depth_from_sample(sample, 10);
1452
1453 let mcv = MostCommonValues::new(vec![
1454 (ColumnValue::Text("boss".to_string()), 0.5),
1455 (ColumnValue::Text("intern".to_string()), 0.05),
1456 ]);
1457
1458 Arc::new(
1459 StaticProvider::new()
1460 .with_table(
1461 "people",
1462 TableStats {
1463 row_count: 100_000,
1464 avg_row_size: 64,
1465 page_count: 100,
1466 columns: vec![],
1467 },
1468 )
1469 .with_histogram("people", "score", h)
1470 .with_mcv("people", "role", mcv),
1471 )
1472 }
1473
/// Equality against a value tracked in the MCV list (`role = 'boss'`) must
/// report that value's recorded frequency (0.5) exactly.
#[test]
fn eq_uses_mcv_when_value_is_tracked() {
    let estimator = CostEstimator::with_stats(provider_with_skew());
    let role_is_boss = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Eq,
        value: Value::text("boss".to_string()),
    };

    let s = estimator.filter_selectivity(&role_is_boss, "people");
    let deviation = (s - 0.5).abs();
    assert!(
        deviation < 1e-9,
        "MCV-tracked equality should report exact frequency, got {s}"
    );
}
1491
/// Equality against a value absent from the MCV list (`role = 'staff'`) must
/// yield a small positive selectivity, below 0.01.
#[test]
fn eq_uses_residual_for_non_mcv_value() {
    let estimator = CostEstimator::with_stats(provider_with_skew());
    let role_is_staff = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Eq,
        value: Value::text("staff".to_string()),
    };

    let s = estimator.filter_selectivity(&role_is_staff, "people");
    let in_range = 0.0 < s && s < 0.01;
    assert!(in_range, "residual eq should be tiny, got {s}");
}
1507
/// `Ne` selectivity must be the complement of `Eq` selectivity for the same
/// MCV-tracked value. With `"boss"` tracked at 0.5, both come out to 0.5.
#[test]
fn ne_is_one_minus_eq_under_mcv() {
    let estimator = CostEstimator::with_stats(provider_with_skew());

    let eq_filter = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Eq,
        value: Value::text("boss".to_string()),
    };
    let ne_filter = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Ne,
        value: Value::text("boss".to_string()),
    };

    let eq = estimator.filter_selectivity(&eq_filter, "people");
    let s = estimator.filter_selectivity(&ne_filter, "people");

    assert!((s - 0.5).abs() < 1e-9, "Ne selectivity = 0.5, got {s}");
    // Check the relation the test name promises, not just the hard-coded value.
    assert!(
        (s - (1.0 - eq)).abs() < 1e-9,
        "Ne selectivity ({s}) should equal 1 - Eq selectivity ({eq})"
    );
}
1521
/// A range predicate on a column with a histogram (`score <= 9`) must use it.
/// 80 of the 100 sampled values are <= 9, so the estimate must clearly exceed
/// the 0.3 fallback heuristic; the test requires more than 0.5.
#[test]
fn range_uses_histogram_when_present() {
    let estimator = CostEstimator::with_stats(provider_with_skew());
    let filter = AstFilter::Compare {
        field: FieldRef::column("people", "score"),
        op: CompareOp::Le,
        value: Value::Integer(9),
    };
    let s = estimator.filter_selectivity(&filter, "people");
    // The message states the actual threshold being asserted (0.5), not the
    // heuristic value it is meant to beat.
    assert!(
        s > 0.5,
        "histogram-based range selectivity should exceed 0.5 (heuristic is 0.3), got {s}"
    );
}
1539
/// `BETWEEN 0 AND 9` on the histogram-backed `score` column must also consult
/// the histogram, giving a selectivity above 0.5.
#[test]
fn between_uses_histogram() {
    let estimator = CostEstimator::with_stats(provider_with_skew());
    let score_in_dense_band = AstFilter::Between {
        field: FieldRef::column("people", "score"),
        low: Value::Integer(0),
        high: Value::Integer(9),
    };

    let s = estimator.filter_selectivity(&score_in_dense_band, "people");
    assert!(s > 0.5, "BETWEEN should use histogram too, got {s}");
}
1552
/// With table stats registered but no histogram or MCV for the column, a `<`
/// comparison must fall back to the fixed 0.3 heuristic.
#[test]
fn graceful_fallback_when_histogram_absent() {
    let bare_stats = TableStats {
        row_count: 1000,
        avg_row_size: 64,
        page_count: 10,
        columns: vec![],
    };
    let provider = Arc::new(StaticProvider::new().with_table("people", bare_stats));
    let estimator = CostEstimator::with_stats(provider);

    let unknown_lt_50 = AstFilter::Compare {
        field: FieldRef::column("people", "unknown_col"),
        op: CompareOp::Lt,
        value: Value::Integer(50),
    };
    let s = estimator.filter_selectivity(&unknown_lt_50, "people");
    let deviation = (s - 0.3).abs();
    assert!(deviation < 1e-9, "fallback heuristic 0.3, got {s}");
}
1575}