1use std::sync::Arc;
13
14use super::stats_provider::{NullProvider, StatsProvider};
15use crate::storage::query::ast::{
16 CompareOp, FieldRef, Filter as AstFilter, GraphQuery, HybridQuery, JoinQuery, JoinType,
17 PathQuery, QueryExpr, TableQuery, VectorQuery,
18};
19use crate::storage::schema::Value;
20
/// Estimated output size of a query or plan fragment.
#[derive(Debug, Clone, Default)]
pub struct CardinalityEstimate {
    /// Estimated number of rows produced.
    pub rows: f64,
    /// Fraction of the input estimated to survive; `1.0` for a full scan.
    pub selectivity: f64,
    /// Trust in the estimate; the constructors start at `1.0` and every
    /// applied filter multiplies it by `0.9`.
    pub confidence: f64,
}

impl CardinalityEstimate {
    /// Creates an estimate with the given row count and selectivity and a
    /// confidence of `1.0`.
    pub fn new(rows: f64, selectivity: f64) -> Self {
        let confidence = 1.0;
        Self {
            rows,
            selectivity,
            confidence,
        }
    }

    /// Estimate for reading every row of a table holding `table_size` rows:
    /// selectivity and confidence are both `1.0`.
    pub fn full_scan(table_size: f64) -> Self {
        Self::new(table_size, 1.0)
    }

    /// Returns this estimate narrowed by a filter.
    ///
    /// Rows and selectivity are multiplied by `filter_selectivity`; confidence
    /// is multiplied by `0.9`.
    pub fn with_filter(self, filter_selectivity: f64) -> Self {
        Self {
            rows: self.rows * filter_selectivity,
            selectivity: self.selectivity * filter_selectivity,
            confidence: self.confidence * 0.9,
        }
    }
}
59
/// Cost of executing a plan or plan fragment, broken down by resource.
///
/// `total` is the scalar used to rank whole plans; `startup_cost` is compared
/// instead when only a small prefix of the output is wanted (see
/// [`PlanCost::prefer_over`]).
#[derive(Debug, Clone, Default)]
pub struct PlanCost {
    /// CPU component; weight 1 in `total`.
    pub cpu: f64,
    /// I/O component; weight 10 in `total`.
    pub io: f64,
    /// Network component; the constructors set it to `0.0` and it is not part
    /// of `total`.
    pub network: f64,
    /// Memory component; weight 0.1 in `total`.
    pub memory: f64,
    /// Startup component; `0.0` from [`PlanCost::new`], never negative from
    /// [`PlanCost::with_startup`].
    pub startup_cost: f64,
    /// Scalar cost used for plan comparison.
    pub total: f64,
}

impl PlanCost {
    /// Weighted scalar shared by the constructors: `cpu + 10·io + 0.1·memory`.
    fn weighted_total(cpu: f64, io: f64, memory: f64) -> f64 {
        cpu + io * 10.0 + memory * 0.1
    }

    /// Creates a cost with no network and no startup component.
    pub fn new(cpu: f64, io: f64, memory: f64) -> Self {
        Self {
            total: Self::weighted_total(cpu, io, memory),
            cpu,
            io,
            memory,
            network: 0.0,
            startup_cost: 0.0,
        }
    }

    /// Creates a cost with an explicit startup component.
    ///
    /// A negative `startup_cost` is stored as `0.0`, and `total` is raised to
    /// `startup_cost` when the weighted sum would otherwise be smaller.
    pub fn with_startup(cpu: f64, io: f64, memory: f64, startup_cost: f64) -> Self {
        let weighted = Self::weighted_total(cpu, io, memory);
        Self {
            cpu,
            io,
            network: 0.0,
            memory,
            startup_cost: startup_cost.max(0.0),
            total: weighted.max(startup_cost),
        }
    }

    /// Combines two pipelined operators.
    ///
    /// Every component is summed except memory, which takes the larger of the
    /// two; startup costs add up directly.
    pub fn combine_pipelined(&self, other: &PlanCost) -> PlanCost {
        let peak_memory = self.memory.max(other.memory);
        let startup_cost = self.startup_cost + other.startup_cost;
        let total = self.total + other.total;
        PlanCost {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            network: self.network + other.network,
            memory: peak_memory,
            startup_cost,
            total,
        }
    }

    /// Combines this cost with a blocking operator that consumes it.
    ///
    /// Identical to [`PlanCost::combine_pipelined`] except for startup: the
    /// result's startup is this plan's *total* plus the blocker's own startup.
    pub fn combine_blocking(&self, blocker: &PlanCost) -> PlanCost {
        let peak_memory = self.memory.max(blocker.memory);
        let startup_cost = self.total + blocker.startup_cost;
        let total = self.total + blocker.total;
        PlanCost {
            cpu: self.cpu + blocker.cpu,
            io: self.io + blocker.io,
            network: self.network + blocker.network,
            memory: peak_memory,
            startup_cost,
            total,
        }
    }

    /// Alias for [`PlanCost::combine_pipelined`].
    pub fn combine(&self, other: &PlanCost) -> PlanCost {
        Self::combine_pipelined(self, other)
    }

    /// Multiplies `cpu`, `io`, `network` and `total` by `factor`; `memory` and
    /// `startup_cost` are carried over unchanged.
    pub fn scale(&self, factor: f64) -> PlanCost {
        let mut scaled = self.clone();
        scaled.cpu = self.cpu * factor;
        scaled.io = self.io * factor;
        scaled.network = self.network * factor;
        scaled.total = self.total * factor;
        scaled
    }

    /// Orders `self` against `other`; `Less` means `self` is the cheaper plan.
    ///
    /// When `limit` is `Some(k)` with `k` below 10% of `cardinality`
    /// (`cardinality` is floored at 1), the plans are compared on
    /// `startup_cost`; otherwise on `total`. Incomparable values (NaN) yield
    /// `Equal`.
    pub fn prefer_over(
        &self,
        other: &PlanCost,
        limit: Option<u64>,
        cardinality: f64,
    ) -> std::cmp::Ordering {
        let small_limit = match limit {
            Some(k) => (k as f64) < 0.1 * cardinality.max(1.0),
            None => false,
        };
        let ordering = if small_limit {
            self.startup_cost.partial_cmp(&other.startup_cost)
        } else {
            self.total.partial_cmp(&other.total)
        };
        ordering.unwrap_or(std::cmp::Ordering::Equal)
    }
}
197
198#[derive(Debug, Clone, Default)]
200pub struct TableStats {
201 pub row_count: u64,
203 pub avg_row_size: u32,
205 pub page_count: u64,
207 pub columns: Vec<ColumnStats>,
209}
210
/// Per-column statistics attached to a `TableStats`.
#[derive(Clone, Debug, Default)]
pub struct ColumnStats {
    /// Column name.
    pub name: String,
    /// Number of distinct values in the column.
    pub distinct_count: u64,
    /// Number of NULL entries in the column.
    pub null_count: u64,
    /// Smallest value, in string form, when known.
    pub min_value: Option<String>,
    /// Largest value, in string form, when known.
    pub max_value: Option<String>,
    /// Whether an index exists on the column.
    pub has_index: bool,
}
227
228pub struct CostEstimator {
230 default_row_count: f64,
232 row_scan_cost: f64,
234 index_lookup_cost: f64,
236 hash_probe_cost: f64,
238 nested_loop_cost: f64,
240 edge_traversal_cost: f64,
242 stats: Arc<dyn StatsProvider>,
247}
248
249impl CostEstimator {
250 pub fn new() -> Self {
253 Self {
254 default_row_count: 1000.0,
255 row_scan_cost: 1.0,
256 index_lookup_cost: 0.1,
257 hash_probe_cost: 0.5,
258 nested_loop_cost: 2.0,
259 edge_traversal_cost: 1.5,
260 stats: Arc::new(NullProvider),
261 }
262 }
263
264 pub fn with_stats(provider: Arc<dyn StatsProvider>) -> Self {
268 Self {
269 stats: provider,
270 ..Self::new()
271 }
272 }
273
274 pub fn set_stats(&mut self, provider: Arc<dyn StatsProvider>) {
278 self.stats = provider;
279 }
280
281 pub fn estimate(&self, query: &QueryExpr) -> PlanCost {
283 match query {
284 QueryExpr::Table(tq) => self.estimate_table(tq),
285 QueryExpr::Graph(gq) => self.estimate_graph(gq),
286 QueryExpr::Join(jq) => self.estimate_join(jq),
287 QueryExpr::Path(pq) => self.estimate_path(pq),
288 QueryExpr::Vector(vq) => self.estimate_vector(vq),
289 QueryExpr::Hybrid(hq) => self.estimate_hybrid(hq),
290 QueryExpr::Insert(_)
292 | QueryExpr::Update(_)
293 | QueryExpr::Delete(_)
294 | QueryExpr::CreateTable(_)
295 | QueryExpr::CreateCollection(_)
296 | QueryExpr::CreateVector(_)
297 | QueryExpr::DropTable(_)
298 | QueryExpr::DropGraph(_)
299 | QueryExpr::DropVector(_)
300 | QueryExpr::DropDocument(_)
301 | QueryExpr::DropKv(_)
302 | QueryExpr::DropCollection(_)
303 | QueryExpr::Truncate(_)
304 | QueryExpr::AlterTable(_)
305 | QueryExpr::GraphCommand(_)
306 | QueryExpr::SearchCommand(_)
307 | QueryExpr::CreateIndex(_)
308 | QueryExpr::DropIndex(_)
309 | QueryExpr::ProbabilisticCommand(_)
310 | QueryExpr::Ask(_)
311 | QueryExpr::SetConfig { .. }
312 | QueryExpr::ShowConfig { .. }
313 | QueryExpr::SetSecret { .. }
314 | QueryExpr::DeleteSecret { .. }
315 | QueryExpr::ShowSecrets { .. }
316 | QueryExpr::SetTenant(_)
317 | QueryExpr::ShowTenant
318 | QueryExpr::CreateTimeSeries(_)
319 | QueryExpr::DropTimeSeries(_)
320 | QueryExpr::CreateQueue(_)
321 | QueryExpr::AlterQueue(_)
322 | QueryExpr::DropQueue(_)
323 | QueryExpr::QueueSelect(_)
324 | QueryExpr::QueueCommand(_)
325 | QueryExpr::KvCommand(_)
326 | QueryExpr::ConfigCommand(_)
327 | QueryExpr::CreateTree(_)
328 | QueryExpr::DropTree(_)
329 | QueryExpr::TreeCommand(_)
330 | QueryExpr::ExplainAlter(_)
331 | QueryExpr::TransactionControl(_)
332 | QueryExpr::MaintenanceCommand(_)
333 | QueryExpr::CreateSchema(_)
334 | QueryExpr::DropSchema(_)
335 | QueryExpr::CreateSequence(_)
336 | QueryExpr::DropSequence(_)
337 | QueryExpr::CopyFrom(_)
338 | QueryExpr::CreateView(_)
339 | QueryExpr::DropView(_)
340 | QueryExpr::RefreshMaterializedView(_)
341 | QueryExpr::CreatePolicy(_)
342 | QueryExpr::DropPolicy(_)
343 | QueryExpr::CreateServer(_)
344 | QueryExpr::DropServer(_)
345 | QueryExpr::CreateForeignTable(_)
346 | QueryExpr::DropForeignTable(_)
347 | QueryExpr::Grant(_)
348 | QueryExpr::Revoke(_)
349 | QueryExpr::AlterUser(_)
350 | QueryExpr::CreateIamPolicy { .. }
351 | QueryExpr::DropIamPolicy { .. }
352 | QueryExpr::AttachPolicy { .. }
353 | QueryExpr::DetachPolicy { .. }
354 | QueryExpr::ShowPolicies { .. }
355 | QueryExpr::ShowEffectivePermissions { .. }
356 | QueryExpr::SimulatePolicy { .. }
357 | QueryExpr::CreateMigration(_)
358 | QueryExpr::ApplyMigration(_)
359 | QueryExpr::RollbackMigration(_)
360 | QueryExpr::ExplainMigration(_)
361 | QueryExpr::EventsBackfill(_)
362 | QueryExpr::EventsBackfillStatus { .. } => PlanCost::new(1.0, 1.0, 0.0),
363 }
364 }
365
366 pub fn estimate_cardinality(&self, query: &QueryExpr) -> CardinalityEstimate {
368 match query {
369 QueryExpr::Table(tq) => self.estimate_table_cardinality(tq),
370 QueryExpr::Graph(gq) => self.estimate_graph_cardinality(gq),
371 QueryExpr::Join(jq) => self.estimate_join_cardinality(jq),
372 QueryExpr::Path(pq) => self.estimate_path_cardinality(pq),
373 QueryExpr::Vector(vq) => self.estimate_vector_cardinality(vq),
374 QueryExpr::Hybrid(hq) => self.estimate_hybrid_cardinality(hq),
375 QueryExpr::Insert(_)
377 | QueryExpr::Update(_)
378 | QueryExpr::Delete(_)
379 | QueryExpr::CreateTable(_)
380 | QueryExpr::CreateCollection(_)
381 | QueryExpr::CreateVector(_)
382 | QueryExpr::DropTable(_)
383 | QueryExpr::DropGraph(_)
384 | QueryExpr::DropVector(_)
385 | QueryExpr::DropDocument(_)
386 | QueryExpr::DropKv(_)
387 | QueryExpr::DropCollection(_)
388 | QueryExpr::Truncate(_)
389 | QueryExpr::AlterTable(_)
390 | QueryExpr::GraphCommand(_)
391 | QueryExpr::SearchCommand(_)
392 | QueryExpr::CreateIndex(_)
393 | QueryExpr::DropIndex(_)
394 | QueryExpr::ProbabilisticCommand(_)
395 | QueryExpr::Ask(_)
396 | QueryExpr::SetConfig { .. }
397 | QueryExpr::ShowConfig { .. }
398 | QueryExpr::SetSecret { .. }
399 | QueryExpr::DeleteSecret { .. }
400 | QueryExpr::ShowSecrets { .. }
401 | QueryExpr::SetTenant(_)
402 | QueryExpr::ShowTenant
403 | QueryExpr::CreateTimeSeries(_)
404 | QueryExpr::DropTimeSeries(_)
405 | QueryExpr::CreateQueue(_)
406 | QueryExpr::AlterQueue(_)
407 | QueryExpr::DropQueue(_)
408 | QueryExpr::QueueSelect(_)
409 | QueryExpr::QueueCommand(_)
410 | QueryExpr::KvCommand(_)
411 | QueryExpr::ConfigCommand(_)
412 | QueryExpr::CreateTree(_)
413 | QueryExpr::DropTree(_)
414 | QueryExpr::TreeCommand(_)
415 | QueryExpr::ExplainAlter(_)
416 | QueryExpr::TransactionControl(_)
417 | QueryExpr::MaintenanceCommand(_)
418 | QueryExpr::CreateSchema(_)
419 | QueryExpr::DropSchema(_)
420 | QueryExpr::CreateSequence(_)
421 | QueryExpr::DropSequence(_)
422 | QueryExpr::CopyFrom(_)
423 | QueryExpr::CreateView(_)
424 | QueryExpr::DropView(_)
425 | QueryExpr::RefreshMaterializedView(_)
426 | QueryExpr::CreatePolicy(_)
427 | QueryExpr::DropPolicy(_)
428 | QueryExpr::CreateServer(_)
429 | QueryExpr::DropServer(_)
430 | QueryExpr::CreateForeignTable(_)
431 | QueryExpr::DropForeignTable(_)
432 | QueryExpr::Grant(_)
433 | QueryExpr::Revoke(_)
434 | QueryExpr::AlterUser(_)
435 | QueryExpr::CreateIamPolicy { .. }
436 | QueryExpr::DropIamPolicy { .. }
437 | QueryExpr::AttachPolicy { .. }
438 | QueryExpr::DetachPolicy { .. }
439 | QueryExpr::ShowPolicies { .. }
440 | QueryExpr::ShowEffectivePermissions { .. }
441 | QueryExpr::SimulatePolicy { .. }
442 | QueryExpr::CreateMigration(_)
443 | QueryExpr::ApplyMigration(_)
444 | QueryExpr::RollbackMigration(_)
445 | QueryExpr::ExplainMigration(_)
446 | QueryExpr::EventsBackfill(_)
447 | QueryExpr::EventsBackfillStatus { .. } => CardinalityEstimate::new(1.0, 1.0),
448 }
449 }
450
451 fn estimate_table(&self, query: &TableQuery) -> PlanCost {
456 let cardinality = self.estimate_table_cardinality(query);
457
458 let cpu = cardinality.rows * self.row_scan_cost;
459
460 let io = self.estimate_table_io(query, cardinality.rows);
463
464 let memory = cardinality.rows * 100.0; PlanCost::new(cpu, io, memory)
467 }
468
469 fn estimate_table_io(&self, query: &TableQuery, result_rows: f64) -> f64 {
476 const ROWS_PER_PAGE: f64 = 100.0;
477
478 let table_stats = self.stats.table_stats(&query.table);
480 let heap_pages = table_stats
481 .map(|s| s.page_count as f64)
482 .unwrap_or_else(|| (result_rows / ROWS_PER_PAGE).max(1.0));
483
484 if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
487 if let Some(col) = first_filter_column(&filter, &query.table) {
488 if let Some(idx) = self.stats.index_stats(&query.table, col) {
489 return idx.correlated_io_cost(result_rows, heap_pages);
490 }
491 }
492 }
493
494 (result_rows / ROWS_PER_PAGE).ceil()
496 }
497
498 fn estimate_table_cardinality(&self, query: &TableQuery) -> CardinalityEstimate {
499 let base_rows = self
502 .stats
503 .table_stats(&query.table)
504 .map(|s| s.row_count as f64)
505 .unwrap_or(self.default_row_count);
506
507 let mut estimate = CardinalityEstimate::full_scan(base_rows);
508
509 if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
512 let selectivity = self.filter_selectivity(&filter, &query.table);
513 estimate = estimate.with_filter(selectivity);
514 }
515
516 if let Some(limit) = query.limit {
518 estimate.rows = estimate.rows.min(limit as f64);
519 }
520
521 estimate
522 }
523
524 fn filter_selectivity(&self, filter: &AstFilter, table: &str) -> f64 {
537 match filter {
538 AstFilter::Compare { field, op, value } => {
539 let column = column_name_for_table(field, table);
540 match op {
541 CompareOp::Eq => self.eq_selectivity(table, column, value),
542 CompareOp::Ne => 1.0 - self.eq_selectivity(table, column, value),
543 CompareOp::Lt | CompareOp::Le => {
544 self.range_selectivity(table, column, None, Some(value))
545 }
546 CompareOp::Gt | CompareOp::Ge => {
547 self.range_selectivity(table, column, Some(value), None)
548 }
549 }
550 }
551 AstFilter::Between {
552 field, low, high, ..
553 } => {
554 let column = column_name_for_table(field, table);
555 self.range_selectivity(table, column, Some(low), Some(high))
556 }
557 AstFilter::In { field, values, .. } => {
558 let column = column_name_for_table(field, table);
559 if let Some(c) = column {
563 if let Some(mcv) = self.stats.column_mcv(table, c) {
564 let mut hits: f64 = 0.0;
565 let mut residual_count = 0usize;
566 for v in values {
567 if let Some(cv) = column_value_from(v) {
568 if let Some(freq) = mcv.frequency_of(&cv) {
569 hits += freq;
570 } else {
571 residual_count += 1;
572 }
573 } else {
574 residual_count += 1;
575 }
576 }
577 let total = mcv.total_frequency();
578 let distinct = self.stats.distinct_values(table, c).unwrap_or(100);
579 let non_mcv_distinct =
580 distinct.saturating_sub(mcv.values.len() as u64).max(1);
581 let per_residual = (1.0 - total) / non_mcv_distinct as f64;
582 let estimate = hits + (residual_count as f64) * per_residual;
583 return estimate.clamp(0.0, 1.0).min(0.5);
584 }
585 if let Some(s) = self.stats.index_stats(table, c) {
586 return (s.point_selectivity() * values.len() as f64).min(0.5);
587 }
588 }
589 (values.len() as f64 * 0.01).min(0.5)
590 }
591 AstFilter::Like { .. } => 0.1,
592 AstFilter::StartsWith { .. } => 0.15,
593 AstFilter::EndsWith { .. } => 0.15,
594 AstFilter::Contains { .. } => 0.1,
595 AstFilter::IsNull { .. } => 0.01,
596 AstFilter::IsNotNull { .. } => 0.99,
597 AstFilter::And(left, right) => {
598 self.filter_selectivity(left, table) * self.filter_selectivity(right, table)
599 }
600 AstFilter::Or(left, right) => {
601 let s1 = self.filter_selectivity(left, table);
602 let s2 = self.filter_selectivity(right, table);
603 s1 + s2 - (s1 * s2)
604 }
605 AstFilter::Not(inner) => 1.0 - self.filter_selectivity(inner, table),
606 AstFilter::CompareFields { .. } => {
607 0.1
611 }
612 AstFilter::CompareExpr { .. } => {
613 0.1
617 }
618 }
619 }
620
621 fn estimate_graph(&self, query: &GraphQuery) -> PlanCost {
626 let cardinality = self.estimate_graph_cardinality(query);
627
628 let nodes = query.pattern.nodes.len() as f64;
630 let edges = query.pattern.edges.len() as f64;
631
632 let cpu = cardinality.rows * self.edge_traversal_cost * (nodes + edges);
633 let io = cardinality.rows * 0.1; let memory = cardinality.rows * 200.0; PlanCost::new(cpu, io, memory)
637 }
638
639 fn estimate_graph_cardinality(&self, query: &GraphQuery) -> CardinalityEstimate {
640 let nodes = query.pattern.nodes.len() as f64;
641 let edges = query.pattern.edges.len() as f64;
642
643 let base_rows = self.default_row_count;
645 let edge_factor = 0.1_f64.powf(edges); let mut estimate = CardinalityEstimate::new(base_rows * nodes * edge_factor, edge_factor);
648 estimate.confidence = 0.5; if let Some(ref filter) = query.filter {
652 let selectivity = Self::estimate_filter_selectivity(filter);
653 estimate = estimate.with_filter(selectivity);
654 }
655
656 estimate
657 }
658
659 fn estimate_join(&self, query: &JoinQuery) -> PlanCost {
664 let left_cost = self.estimate(&query.left);
665 let right_cost = self.estimate(&query.right);
666
667 let left_card = self.estimate_cardinality(&query.left);
668 let right_card = self.estimate_cardinality(&query.right);
669
670 let build_cpu = left_card.rows * self.hash_probe_cost;
677 let probe_cpu = right_card.rows * self.hash_probe_cost;
678 let join_memory = left_card.rows * 100.0; let build_op = PlanCost::with_startup(build_cpu, 0.0, join_memory, build_cpu);
682 let probe_op = PlanCost::new(probe_cpu, 0.0, 0.0);
684
685 let after_build = left_cost.combine_blocking(&build_op);
687 after_build
688 .combine_pipelined(&right_cost)
689 .combine_pipelined(&probe_op)
690 }
691
692 fn estimate_join_cardinality(&self, query: &JoinQuery) -> CardinalityEstimate {
693 let left = self.estimate_cardinality(&query.left);
694 let right = self.estimate_cardinality(&query.right);
695
696 let selectivity = match query.join_type {
698 JoinType::Inner => 0.1, JoinType::LeftOuter => 1.0, JoinType::RightOuter => 1.0, JoinType::FullOuter => 1.0, JoinType::Cross => 1.0, };
704
705 CardinalityEstimate::new(
706 left.rows * right.rows * selectivity,
707 left.selectivity * right.selectivity * selectivity,
708 )
709 }
710
711 fn estimate_path(&self, query: &PathQuery) -> PlanCost {
716 let cardinality = self.estimate_path_cardinality(query);
717
718 let max_hops = query.max_length;
720 let branching_factor: f64 = 5.0; let nodes_visited = branching_factor.powf(max_hops as f64).min(10000.0);
723 let cpu = nodes_visited * self.edge_traversal_cost;
724 let io = nodes_visited * 0.1;
725 let memory = nodes_visited * 50.0; PlanCost::new(cpu, io, memory)
728 }
729
730 fn estimate_path_cardinality(&self, query: &PathQuery) -> CardinalityEstimate {
731 let max_paths = 10.0;
733 CardinalityEstimate::new(max_paths, 0.001)
734 }
735
736 fn estimate_vector(&self, query: &VectorQuery) -> PlanCost {
741 let k = query.k as f64;
744
745 let hnsw_cost = 100.0 * (1.0 + k.ln()); let filter_cost =
752 if crate::storage::query::sql_lowering::effective_vector_filter(query).is_some() {
753 50.0
754 } else {
755 0.0
756 };
757
758 let cpu = hnsw_cost + filter_cost;
759 let io = 20.0; let memory = k * 32.0 + 1000.0; PlanCost::with_startup(cpu, io, memory, hnsw_cost * 0.5)
767 }
768
769 fn estimate_vector_cardinality(&self, query: &VectorQuery) -> CardinalityEstimate {
770 let k = query.k as f64;
772 CardinalityEstimate::new(k, 0.1)
773 }
774
775 fn estimate_hybrid(&self, query: &HybridQuery) -> PlanCost {
780 let structured_cost = self.estimate(&query.structured);
782 let vector_cost = self.estimate_vector(&query.vector);
783
784 let fusion_overhead = match &query.fusion {
786 crate::storage::query::ast::FusionStrategy::Rerank { .. } => 50.0,
787 crate::storage::query::ast::FusionStrategy::FilterThenSearch => 10.0,
788 crate::storage::query::ast::FusionStrategy::SearchThenFilter => 10.0,
789 crate::storage::query::ast::FusionStrategy::RRF { .. } => 30.0,
790 crate::storage::query::ast::FusionStrategy::Intersection => 20.0,
791 crate::storage::query::ast::FusionStrategy::Union { .. } => 40.0,
792 };
793
794 PlanCost::new(
795 structured_cost.cpu + vector_cost.cpu + fusion_overhead,
796 structured_cost.io + vector_cost.io,
797 structured_cost.memory + vector_cost.memory,
798 )
799 }
800
801 fn estimate_hybrid_cardinality(&self, query: &HybridQuery) -> CardinalityEstimate {
802 let structured_card = self.estimate_cardinality(&query.structured);
803 let vector_card = self.estimate_vector_cardinality(&query.vector);
804
805 let rows = match &query.fusion {
807 crate::storage::query::ast::FusionStrategy::Intersection => {
808 structured_card.rows.min(vector_card.rows)
809 }
810 crate::storage::query::ast::FusionStrategy::Union { .. } => {
811 structured_card.rows + vector_card.rows
812 }
813 _ => vector_card.rows, };
815
816 CardinalityEstimate::new(rows, 0.2)
817 }
818
819 fn estimate_filter_selectivity(filter: &AstFilter) -> f64 {
824 match filter {
825 AstFilter::Compare { op, .. } => {
826 match op {
827 CompareOp::Eq => 0.01, CompareOp::Ne => 0.99, CompareOp::Lt | CompareOp::Le => 0.3,
830 CompareOp::Gt | CompareOp::Ge => 0.3,
831 }
832 }
833 AstFilter::Between { .. } => 0.25,
834 AstFilter::In { values, .. } => {
835 (values.len() as f64 * 0.01).min(0.5)
837 }
838 AstFilter::Like { .. } => 0.1,
839 AstFilter::StartsWith { .. } => 0.15,
840 AstFilter::EndsWith { .. } => 0.15,
841 AstFilter::Contains { .. } => 0.1,
842 AstFilter::IsNull { .. } => 0.01,
843 AstFilter::IsNotNull { .. } => 0.99,
844 AstFilter::And(left, right) => {
845 Self::estimate_filter_selectivity(left) * Self::estimate_filter_selectivity(right)
846 }
847 AstFilter::Or(left, right) => {
848 let s1 = Self::estimate_filter_selectivity(left);
849 let s2 = Self::estimate_filter_selectivity(right);
850 s1 + s2 - (s1 * s2) }
852 AstFilter::Not(inner) => 1.0 - Self::estimate_filter_selectivity(inner),
853 AstFilter::CompareFields { .. } => 0.1,
854 AstFilter::CompareExpr { .. } => 0.1,
855 }
856 }
857}
858
859impl CostEstimator {
860 fn eq_selectivity(&self, table: &str, column: Option<&str>, value: &Value) -> f64 {
868 if let Some(col) = column {
869 if let Some(mcv) = self.stats.column_mcv(table, col) {
871 if let Some(cv) = column_value_from(value) {
872 if let Some(freq) = mcv.frequency_of(&cv) {
873 return freq;
874 }
875 let total = mcv.total_frequency();
877 let distinct = self.stats.distinct_values(table, col).unwrap_or(100);
878 let non_mcv_distinct = distinct.saturating_sub(mcv.values.len() as u64).max(1);
879 return ((1.0 - total) / non_mcv_distinct as f64).clamp(0.0, 1.0);
880 }
881 }
882 if let Some(s) = self.stats.index_stats(table, col) {
884 return s.point_selectivity();
885 }
886 }
887 0.01
889 }
890
891 fn range_selectivity(
902 &self,
903 table: &str,
904 column: Option<&str>,
905 lo: Option<&Value>,
906 hi: Option<&Value>,
907 ) -> f64 {
908 if let Some(col) = column {
909 if let Some(h) = self.stats.column_histogram(table, col) {
911 let lo_cv = lo.and_then(column_value_from);
912 let hi_cv = hi.and_then(column_value_from);
913 return h.range_selectivity(lo_cv.as_ref(), hi_cv.as_ref());
914 }
915 if let Some(s) = self.stats.index_stats(table, col) {
917 let cap = if lo.is_some() && hi.is_some() {
918 0.25
919 } else {
920 0.3
921 };
922 return (s.point_selectivity() * (s.distinct_keys as f64 / 2.0)).min(cap);
923 }
924 }
925 if lo.is_some() && hi.is_some() {
927 0.25
928 } else {
929 0.3
930 }
931 }
932}
933
934impl Default for CostEstimator {
935 fn default() -> Self {
936 Self::new()
937 }
938}
939
940fn column_value_from(v: &crate::storage::schema::Value) -> Option<super::histogram::ColumnValue> {
945 use super::histogram::ColumnValue;
946 use crate::storage::schema::Value;
947 match v {
948 Value::Integer(i) | Value::BigInt(i) => Some(ColumnValue::Int(*i)),
949 Value::UnsignedInteger(u) => Some(ColumnValue::Int(*u as i64)),
950 Value::Float(f) if f.is_finite() => Some(ColumnValue::Float(*f)),
951 Value::Text(s) => Some(ColumnValue::Text(s.to_string())),
952 Value::Email(s)
953 | Value::Url(s)
954 | Value::NodeRef(s)
955 | Value::EdgeRef(s)
956 | Value::TableRef(s)
957 | Value::Password(s) => Some(ColumnValue::Text(s.clone())),
958 Value::Timestamp(t) => Some(ColumnValue::Int(*t)),
959 Value::Duration(d) => Some(ColumnValue::Int(*d)),
960 Value::TimestampMs(t) => Some(ColumnValue::Int(*t)),
961 Value::Decimal(d) => Some(ColumnValue::Int(*d)),
962 Value::Date(d) => Some(ColumnValue::Int(i64::from(*d))),
963 Value::Time(t) => Some(ColumnValue::Int(i64::from(*t))),
964 Value::Phone(p) => Some(ColumnValue::Int(*p as i64)),
965 Value::Semver(v) => Some(ColumnValue::Int(i64::from(*v))),
966 Value::Port(v) => Some(ColumnValue::Int(i64::from(*v))),
967 Value::PageRef(v) => Some(ColumnValue::Int(i64::from(*v))),
968 Value::EnumValue(v) => Some(ColumnValue::Int(i64::from(*v))),
969 Value::Latitude(v) => Some(ColumnValue::Int(i64::from(*v))),
970 Value::Longitude(v) => Some(ColumnValue::Int(i64::from(*v))),
971 _ => None,
976 }
977}
978
979fn first_filter_column<'a>(filter: &'a AstFilter, table: &str) -> Option<&'a str> {
984 match filter {
985 AstFilter::Compare { field, .. } => column_name_for_table(field, table),
986 AstFilter::Between { field, .. } => column_name_for_table(field, table),
987 AstFilter::And(l, r) => {
988 first_filter_column(l, table).or_else(|| first_filter_column(r, table))
989 }
990 _ => None,
991 }
992}
993
994fn column_name_for_table<'a>(field: &'a FieldRef, table: &str) -> Option<&'a str> {
996 match field {
997 FieldRef::TableColumn { table: t, column } if t == table || t.is_empty() => {
998 Some(column.as_str())
999 }
1000 _ => None,
1002 }
1003}
1004
1005#[cfg(test)]
1006mod tests {
1007 use super::super::stats_provider::StaticProvider;
1008 use super::*;
1009 use crate::storage::index::{IndexKind, IndexStats};
1010 use crate::storage::query::ast::{FieldRef, Projection};
1011 use crate::storage::schema::Value;
1012
1013 fn eq_filter(table: &str, column: &str, value: i64) -> AstFilter {
1014 AstFilter::Compare {
1015 field: FieldRef::column(table, column),
1016 op: CompareOp::Eq,
1017 value: Value::Integer(value),
1018 }
1019 }
1020
1021 fn table_query(name: &str, filter: Option<AstFilter>) -> TableQuery {
1022 TableQuery {
1023 table: name.to_string(),
1024 source: None,
1025 alias: None,
1026 select_items: Vec::new(),
1027 columns: vec![Projection::All],
1028 where_expr: None,
1029 filter,
1030 group_by_exprs: Vec::new(),
1031 group_by: Vec::new(),
1032 having_expr: None,
1033 having: None,
1034 order_by: vec![],
1035 limit: None,
1036 limit_param: None,
1037 offset: None,
1038 offset_param: None,
1039 expand: None,
1040 as_of: None,
1041 }
1042 }
1043
1044 #[test]
1045 fn injected_row_count_overrides_default() {
1046 let provider = Arc::new(StaticProvider::new().with_table(
1047 "users",
1048 TableStats {
1049 row_count: 50_000,
1050 avg_row_size: 256,
1051 page_count: 500,
1052 columns: vec![],
1053 },
1054 ));
1055 let estimator = CostEstimator::with_stats(provider);
1056 let q = table_query("users", None);
1057 let card = estimator.estimate_table_cardinality(&q);
1058 assert_eq!(card.rows, 50_000.0);
1060 }
1061
1062 #[test]
1063 fn stats_aware_eq_selectivity_beats_default() {
1064 let provider = Arc::new(
1065 StaticProvider::new()
1066 .with_table(
1067 "users",
1068 TableStats {
1069 row_count: 1_000_000,
1070 avg_row_size: 256,
1071 page_count: 10_000,
1072 columns: vec![],
1073 },
1074 )
1075 .with_index(
1076 "users",
1077 "email",
1078 IndexStats {
1079 entries: 1_000_000,
1080 distinct_keys: 1_000_000,
1081 approx_bytes: 0,
1082 kind: IndexKind::Hash,
1083 has_bloom: true,
1084 index_correlation: 0.0,
1085 },
1086 ),
1087 );
1088 let estimator = CostEstimator::with_stats(provider);
1089 let q = table_query("users", Some(eq_filter("users", "email", 0)));
1090 let card = estimator.estimate_table_cardinality(&q);
1091 assert!(card.rows < 2.0, "expected ~1 row, got {}", card.rows);
1093 }
1094
1095 #[test]
1096 fn fallback_when_no_index_stats() {
1097 let provider = Arc::new(StaticProvider::new().with_table(
1098 "users",
1099 TableStats {
1100 row_count: 1_000_000,
1101 avg_row_size: 256,
1102 page_count: 10_000,
1103 columns: vec![],
1104 },
1105 ));
1106 let estimator = CostEstimator::with_stats(provider);
1107 let q = table_query("users", Some(eq_filter("users", "email", 0)));
1108 let card = estimator.estimate_table_cardinality(&q);
1109 assert!((card.rows - 10_000.0).abs() < 1.0);
1111 }
1112
1113 #[test]
1114 fn null_provider_keeps_legacy_behaviour() {
1115 let estimator = CostEstimator::new();
1116 let q = table_query("whatever", Some(eq_filter("whatever", "id", 1)));
1117 let card = estimator.estimate_table_cardinality(&q);
1118 assert!((card.rows - 10.0).abs() < 1.0);
1120 }
1121
1122 #[test]
1123 fn and_combines_stats_selectivities() {
1124 let provider = Arc::new(
1125 StaticProvider::new()
1126 .with_table(
1127 "t",
1128 TableStats {
1129 row_count: 100_000,
1130 avg_row_size: 64,
1131 page_count: 100,
1132 columns: vec![],
1133 },
1134 )
1135 .with_index(
1136 "t",
1137 "a",
1138 IndexStats {
1139 entries: 100_000,
1140 distinct_keys: 10,
1141 approx_bytes: 0,
1142 kind: IndexKind::BTree,
1143 has_bloom: false,
1144 index_correlation: 0.0,
1145 },
1146 )
1147 .with_index(
1148 "t",
1149 "b",
1150 IndexStats {
1151 entries: 100_000,
1152 distinct_keys: 1000,
1153 approx_bytes: 0,
1154 kind: IndexKind::BTree,
1155 has_bloom: false,
1156 index_correlation: 0.0,
1157 },
1158 ),
1159 );
1160 let estimator = CostEstimator::with_stats(provider);
1161 let filter = AstFilter::And(
1162 Box::new(eq_filter("t", "a", 1)),
1163 Box::new(eq_filter("t", "b", 1)),
1164 );
1165 let q = table_query("t", Some(filter));
1166 let card = estimator.estimate_table_cardinality(&q);
1167 assert!(card.rows < 15.0, "got {}", card.rows);
1169 }
1170
1171 #[test]
1172 fn test_table_cost_estimation() {
1173 let estimator = CostEstimator::new();
1174
1175 let query = QueryExpr::Table(TableQuery {
1176 table: "hosts".to_string(),
1177 source: None,
1178 alias: None,
1179 select_items: Vec::new(),
1180 columns: vec![Projection::All],
1181 where_expr: None,
1182 filter: None,
1183 group_by_exprs: Vec::new(),
1184 group_by: Vec::new(),
1185 having_expr: None,
1186 having: None,
1187 order_by: vec![],
1188 limit: None,
1189 limit_param: None,
1190 offset: None,
1191 offset_param: None,
1192 expand: None,
1193 as_of: None,
1194 });
1195
1196 let cost = estimator.estimate(&query);
1197 assert!(cost.cpu > 0.0);
1198 assert!(cost.total > 0.0);
1199 }
1200
1201 #[test]
1202 fn test_filter_selectivity() {
1203 let estimator = CostEstimator::new();
1204
1205 let eq_filter = AstFilter::Compare {
1206 field: FieldRef::column("hosts", "id"),
1207 op: CompareOp::Eq,
1208 value: Value::Integer(1),
1209 };
1210 assert!(CostEstimator::estimate_filter_selectivity(&eq_filter) < 0.1);
1211
1212 let ne_filter = AstFilter::Compare {
1213 field: FieldRef::column("hosts", "id"),
1214 op: CompareOp::Ne,
1215 value: Value::Integer(1),
1216 };
1217 assert!(CostEstimator::estimate_filter_selectivity(&ne_filter) > 0.9);
1218 }
1219
1220 #[test]
1221 fn test_and_selectivity() {
1222 let estimator = CostEstimator::new();
1223
1224 let and_filter = AstFilter::And(
1225 Box::new(AstFilter::Compare {
1226 field: FieldRef::column("hosts", "a"),
1227 op: CompareOp::Eq,
1228 value: Value::Integer(1),
1229 }),
1230 Box::new(AstFilter::Compare {
1231 field: FieldRef::column("hosts", "b"),
1232 op: CompareOp::Eq,
1233 value: Value::Integer(2),
1234 }),
1235 );
1236
1237 let selectivity = CostEstimator::estimate_filter_selectivity(&and_filter);
1238 assert!(selectivity < 0.01); }
1240
1241 #[test]
1242 fn test_cardinality_with_limit() {
1243 let estimator = CostEstimator::new();
1244
1245 let query = TableQuery {
1246 table: "hosts".to_string(),
1247 source: None,
1248 alias: None,
1249 select_items: Vec::new(),
1250 columns: vec![Projection::All],
1251 where_expr: None,
1252 filter: None,
1253 group_by_exprs: Vec::new(),
1254 group_by: Vec::new(),
1255 having_expr: None,
1256 having: None,
1257 order_by: vec![],
1258 limit: Some(10),
1259 limit_param: None,
1260 offset: None,
1261 offset_param: None,
1262 expand: None,
1263 as_of: None,
1264 };
1265
1266 let card = estimator.estimate_table_cardinality(&query);
1267 assert!(card.rows <= 10.0);
1268 }
1269
1270 #[test]
1275 fn startup_zero_for_full_scan() {
1276 let estimator = CostEstimator::new();
1280 let q = table_query("any_table", None);
1281 let cost = estimator.estimate(&QueryExpr::Table(q));
1282 assert_eq!(cost.startup_cost, 0.0, "full scan must have zero startup");
1283 assert!(cost.total > 0.0);
1284 }
1285
1286 #[test]
1287 fn startup_nonzero_for_blocking_combine() {
1288 let input = PlanCost::new(100.0, 10.0, 50.0); let blocker = PlanCost::new(20.0, 0.0, 10.0); let composed = input.combine_blocking(&blocker);
1293 assert_eq!(composed.startup_cost, input.total);
1295 assert_eq!(composed.total, input.total + blocker.total);
1297 assert!(composed.startup_cost > 0.0);
1298 }
1299
1300 #[test]
1301 fn pipelined_combine_adds_startup_directly() {
1302 let upstream = PlanCost::with_startup(50.0, 5.0, 10.0, 30.0);
1303 let downstream = PlanCost::with_startup(20.0, 0.0, 0.0, 5.0);
1304 let composed = upstream.combine_pipelined(&downstream);
1305 assert_eq!(composed.startup_cost, 30.0 + 5.0);
1306 assert_eq!(composed.total, upstream.total + downstream.total);
1307 }
1308
#[test]
/// Two plans with identical totals but different startup costs: with a
/// limit of 10 the test expects `prefer_over` to return `Ordering::Less`
/// for the plan with the smaller startup cost.
fn cost_prefers_low_startup_when_limit_small() {
    let fast_first = PlanCost {
        cpu: 100.0,
        io: 10.0,
        network: 0.0,
        memory: 50.0,
        startup_cost: 5.0,
        total: 200.0,
    };
    // Identical in every field except the startup cost.
    let slow_first = PlanCost {
        startup_cost: 150.0,
        ..fast_first.clone()
    };

    // NOTE(review): the last argument (10_000.0) is presumably the expected
    // row count that the limit is compared against — confirm in `prefer_over`.
    let ordering = fast_first.prefer_over(&slow_first, Some(10), 10_000.0);
    assert_eq!(ordering, std::cmp::Ordering::Less);
}
1335
#[test]
/// Without a limit (`None`) the test expects `prefer_over` to return
/// `Ordering::Less` for the plan with the smaller total, even though that
/// plan has the larger startup cost.
fn cost_prefers_low_total_when_no_limit() {
    // `network` and `memory` are zero for both plans, which is what
    // `PlanCost::default()` supplies.
    let low_total = PlanCost {
        cpu: 50.0,
        io: 5.0,
        startup_cost: 30.0,
        total: 100.0,
        ..PlanCost::default()
    };
    let high_total = PlanCost {
        cpu: 100.0,
        io: 10.0,
        startup_cost: 5.0,
        total: 200.0,
        ..PlanCost::default()
    };

    let ordering = low_total.prefer_over(&high_total, None, 10_000.0);
    assert_eq!(ordering, std::cmp::Ordering::Less);
}
1360
#[test]
/// With a limit of 5000 against a third argument of 10_000.0, the test
/// expects `prefer_over` to return `Ordering::Less` for the plan with the
/// smaller total rather than the one with the smaller startup cost.
fn limit_threshold_falls_back_to_total_when_limit_large() {
    // `network` and `memory` are zero for both plans, which is what
    // `PlanCost::default()` supplies.
    let low_total = PlanCost {
        cpu: 50.0,
        io: 5.0,
        startup_cost: 30.0,
        total: 100.0,
        ..PlanCost::default()
    };
    let low_startup = PlanCost {
        cpu: 100.0,
        io: 10.0,
        startup_cost: 5.0,
        total: 200.0,
        ..PlanCost::default()
    };

    // NOTE(review): the exact limit threshold lives in `prefer_over`, which
    // is not visible here — confirm 5000 is meant to sit above it.
    let ordering = low_total.prefer_over(&low_startup, Some(5000), 10_000.0);
    assert_eq!(ordering, std::cmp::Ordering::Less);
}
1385
#[test]
/// Models a hash-join build phase as a blocking combine: the composed
/// startup cost must be at least the left input's total, and the composed
/// total must not fall below that startup cost.
fn hash_join_startup_includes_build_cost() {
    // total = 80 + 8 * 10 + 100 * 0.1 = 170
    let left = PlanCost::new(80.0, 8.0, 100.0);
    // total = max(50 + 0 + 200 * 0.1, 50) = 70, startup = 50
    let build = PlanCost::with_startup(50.0, 0.0, 200.0, 50.0);

    let built = left.combine_blocking(&build);

    assert!(
        built.startup_cost >= left.total,
        "after-build startup ({}) must absorb left.total ({})",
        built.startup_cost,
        left.total
    );
    assert!(built.total >= built.startup_cost);
}
1402
#[test]
/// Checks the shape of a vector-search-like cost: startup is positive but
/// strictly below the total.
fn vector_search_reports_nonzero_startup() {
    // The estimator is constructed but not consulted; the cost under test
    // is hand-built with `PlanCost::with_startup`.
    let _estimator = CostEstimator::new();

    // total = max(150 + 20 * 10 + 1320 * 0.1, 50) = 482, startup = 50
    let vector_cost = PlanCost::with_startup(150.0, 20.0, 1320.0, 50.0);

    assert!(vector_cost.startup_cost > 0.0);
    assert!(vector_cost.startup_cost < vector_cost.total);
}
1416
#[test]
/// When the requested startup cost (100) exceeds the computed total (1),
/// `with_startup` must raise the total so it is never below the startup.
fn with_startup_clamps_total_below_startup() {
    let clamped = PlanCost::with_startup(1.0, 0.0, 0.0, 100.0);
    assert!(clamped.total >= clamped.startup_cost);
}
1423
#[test]
/// The derived `Default` for `PlanCost` must yield zero startup and zero
/// total cost.
fn default_plancost_has_zero_startup() {
    let zeroed = PlanCost::default();
    assert_eq!(zeroed.startup_cost, 0.0);
    assert_eq!(zeroed.total, 0.0);
}
1430
1431 use super::super::histogram::{ColumnValue, Histogram, MostCommonValues};
1436
1437 fn provider_with_skew() -> Arc<StaticProvider> {
1438 let mut sample: Vec<ColumnValue> = Vec::new();
1442 for i in 0..80 {
1443 sample.push(ColumnValue::Int(i % 10));
1444 }
1445 for i in 0..20 {
1446 sample.push(ColumnValue::Int(10 + i * 50));
1447 }
1448 let h = Histogram::equi_depth_from_sample(sample, 10);
1449
1450 let mcv = MostCommonValues::new(vec![
1451 (ColumnValue::Text("boss".to_string()), 0.5),
1452 (ColumnValue::Text("intern".to_string()), 0.05),
1453 ]);
1454
1455 Arc::new(
1456 StaticProvider::new()
1457 .with_table(
1458 "people",
1459 TableStats {
1460 row_count: 100_000,
1461 avg_row_size: 64,
1462 page_count: 100,
1463 columns: vec![],
1464 },
1465 )
1466 .with_histogram("people", "score", h)
1467 .with_mcv("people", "role", mcv),
1468 )
1469 }
1470
#[test]
/// Equality against a value listed in the MCV (`"boss"`, frequency 0.5)
/// must yield exactly that frequency as the selectivity.
fn eq_uses_mcv_when_value_is_tracked() {
    let estimator = CostEstimator::with_stats(provider_with_skew());

    let role_is_boss = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Eq,
        value: Value::text("boss".to_string()),
    };

    let s = estimator.filter_selectivity(&role_is_boss, "people");
    assert!(
        (s - 0.5).abs() < 1e-9,
        "MCV-tracked equality should report exact frequency, got {s}"
    );
}
1488
#[test]
/// Equality against a value absent from the MCV list (`"staff"`) must give
/// a selectivity that is positive but below 0.01.
fn eq_uses_residual_for_non_mcv_value() {
    let estimator = CostEstimator::with_stats(provider_with_skew());

    let role_is_staff = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Eq,
        value: Value::text("staff".to_string()),
    };

    let s = estimator.filter_selectivity(&role_is_staff, "people");
    assert!(s > 0.0 && s < 0.01, "residual eq should be tiny, got {s}");
}
1504
#[test]
/// Inequality against the MCV-tracked value `"boss"` (frequency 0.5) must
/// come out as the complement, 1 - 0.5 = 0.5.
fn ne_is_one_minus_eq_under_mcv() {
    let estimator = CostEstimator::with_stats(provider_with_skew());

    let role_is_not_boss = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Ne,
        value: Value::text("boss".to_string()),
    };

    let s = estimator.filter_selectivity(&role_is_not_boss, "people");
    assert!((s - 0.5).abs() < 1e-9, "Ne selectivity = 0.5, got {s}");
}
1518
#[test]
/// `score <= 9` covers 80 of the 100 values sampled into the histogram, so
/// the selectivity must land above 0.5.
fn range_uses_histogram_when_present() {
    let estimator = CostEstimator::with_stats(provider_with_skew());

    let score_at_most_nine = AstFilter::Compare {
        field: FieldRef::column("people", "score"),
        op: CompareOp::Le,
        value: Value::Integer(9),
    };

    let s = estimator.filter_selectivity(&score_at_most_nine, "people");
    assert!(
        s > 0.5,
        "histogram-based range selectivity should beat 0.3, got {s}"
    );
}
1536
#[test]
/// `score BETWEEN 0 AND 9` spans the dense part of the sampled histogram,
/// so the selectivity must land above 0.5.
fn between_uses_histogram() {
    let estimator = CostEstimator::with_stats(provider_with_skew());

    let score_in_dense_range = AstFilter::Between {
        field: FieldRef::column("people", "score"),
        low: Value::Integer(0),
        high: Value::Integer(9),
    };

    let s = estimator.filter_selectivity(&score_in_dense_range, "people");
    assert!(s > 0.5, "BETWEEN should use histogram too, got {s}");
}
1549
#[test]
/// When the provider knows the table but has no histogram for the column,
/// a `<` comparison must yield a selectivity of 0.3.
fn graceful_fallback_when_histogram_absent() {
    // Table-level stats only: no histogram and no MCV list is registered.
    let table_only_stats = TableStats {
        row_count: 1000,
        avg_row_size: 64,
        page_count: 10,
        columns: vec![],
    };
    let provider = Arc::new(StaticProvider::new().with_table("people", table_only_stats));
    let estimator = CostEstimator::with_stats(provider);

    let unknown_below_fifty = AstFilter::Compare {
        field: FieldRef::column("people", "unknown_col"),
        op: CompareOp::Lt,
        value: Value::Integer(50),
    };

    let s = estimator.filter_selectivity(&unknown_below_fifty, "people");
    assert!((s - 0.3).abs() < 1e-9, "fallback heuristic 0.3, got {s}");
}
1572}