1use std::sync::Arc;
13
14use super::stats_provider::{NullProvider, StatsProvider};
15use crate::storage::query::ast::{
16 CompareOp, FieldRef, Filter as AstFilter, GraphQuery, HybridQuery, JoinQuery, JoinType,
17 PathQuery, QueryExpr, TableQuery, VectorQuery,
18};
19use crate::storage::schema::Value;
20
/// Estimated size of a (sub)query result.
#[derive(Debug, Clone, Default)]
pub struct CardinalityEstimate {
    /// Expected number of output rows.
    pub rows: f64,
    /// Fraction of the base input expected to survive (1.0 = everything).
    pub selectivity: f64,
    /// Confidence in the estimate; starts at 1.0 and is multiplied by 0.9 per
    /// applied filter.
    pub confidence: f64,
}

impl CardinalityEstimate {
    /// Builds an estimate with the given row count and selectivity at full
    /// confidence.
    pub fn new(rows: f64, selectivity: f64) -> Self {
        CardinalityEstimate {
            confidence: 1.0,
            selectivity,
            rows,
        }
    }

    /// Estimate for reading every row of a table holding `table_size` rows.
    pub fn full_scan(table_size: f64) -> Self {
        Self::new(table_size, 1.0)
    }

    /// Narrows the estimate by `filter_selectivity`.
    ///
    /// Both `rows` and `selectivity` are scaled by the factor. Confidence
    /// drops by 10% because each stacked filter estimate adds uncertainty.
    pub fn with_filter(self, filter_selectivity: f64) -> Self {
        let confidence = self.confidence * 0.9;
        CardinalityEstimate {
            rows: self.rows * filter_selectivity,
            selectivity: self.selectivity * filter_selectivity,
            confidence,
        }
    }
}
59
/// Estimated execution cost of a plan (sub)tree, broken down by resource.
///
/// `startup_cost` is the portion of the cost paid before the first output row
/// is available; `total` is the full cost of producing every row.
#[derive(Debug, Clone, Default)]
pub struct PlanCost {
    /// CPU work.
    pub cpu: f64,
    /// I/O work; weighted ×10 when the constructors compute `total`.
    pub io: f64,
    /// Network cost; the constructors set it to 0 and do not fold it into `total`.
    pub network: f64,
    /// Memory footprint; weighted ×0.1 when the constructors compute `total`.
    pub memory: f64,
    /// Cost incurred before the first row can be emitted.
    pub startup_cost: f64,
    /// Overall cost of the operation.
    pub total: f64,
}

impl PlanCost {
    /// Weight of one unit of I/O relative to one unit of CPU in `total`.
    const IO_WEIGHT: f64 = 10.0;
    /// Weight of one unit of memory relative to one unit of CPU in `total`.
    const MEMORY_WEIGHT: f64 = 0.1;

    /// Folds the resource components into a single scalar total.
    fn weighted_total(cpu: f64, io: f64, memory: f64) -> f64 {
        cpu + io * Self::IO_WEIGHT + memory * Self::MEMORY_WEIGHT
    }

    /// Builds a fully pipelined cost (no startup component).
    pub fn new(cpu: f64, io: f64, memory: f64) -> Self {
        PlanCost {
            total: Self::weighted_total(cpu, io, memory),
            startup_cost: 0.0,
            network: 0.0,
            cpu,
            io,
            memory,
        }
    }

    /// Builds a cost that pays `startup_cost` before emitting its first row.
    ///
    /// The stored startup cost is clamped at zero. `total` is raised to at
    /// least `startup_cost`, since a plan cannot finish before it starts.
    pub fn with_startup(cpu: f64, io: f64, memory: f64, startup_cost: f64) -> Self {
        let total = Self::weighted_total(cpu, io, memory).max(startup_cost);
        PlanCost {
            startup_cost: startup_cost.max(0.0),
            total,
            network: 0.0,
            cpu,
            io,
            memory,
        }
    }

    /// Composes two operators that stream rows into each other.
    ///
    /// Every additive component is summed, including the startup costs. Peak
    /// memory is the larger of the two because both run concurrently.
    pub fn combine_pipelined(&self, other: &PlanCost) -> PlanCost {
        PlanCost {
            memory: f64::max(self.memory, other.memory),
            startup_cost: self.startup_cost + other.startup_cost,
            total: self.total + other.total,
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            network: self.network + other.network,
        }
    }

    /// Composes `self` with a `blocker` that must consume all of `self` first.
    ///
    /// The composed startup is therefore the whole of `self.total` plus the
    /// blocker's own startup.
    pub fn combine_blocking(&self, blocker: &PlanCost) -> PlanCost {
        PlanCost {
            memory: f64::max(self.memory, blocker.memory),
            startup_cost: self.total + blocker.startup_cost,
            total: self.total + blocker.total,
            cpu: self.cpu + blocker.cpu,
            io: self.io + blocker.io,
            network: self.network + blocker.network,
        }
    }

    /// Backwards-compatible alias for [`PlanCost::combine_pipelined`].
    pub fn combine(&self, other: &PlanCost) -> PlanCost {
        self.combine_pipelined(other)
    }

    /// Multiplies the work components by `factor`.
    ///
    /// `memory` and `startup_cost` are deliberately left untouched.
    pub fn scale(&self, factor: f64) -> PlanCost {
        PlanCost {
            total: self.total * factor,
            cpu: self.cpu * factor,
            io: self.io * factor,
            network: self.network * factor,
            memory: self.memory,
            startup_cost: self.startup_cost,
        }
    }

    /// Orders two plans by preference; `Less` means `self` is preferred.
    ///
    /// When a `LIMIT k` asks for less than 10% of the expected `cardinality`
    /// (which is clamped to at least 1), the plans are compared by
    /// `startup_cost`; otherwise they are compared by `total`. Incomparable
    /// values (NaN) are reported as `Equal`.
    pub fn prefer_over(
        &self,
        other: &PlanCost,
        limit: Option<u64>,
        cardinality: f64,
    ) -> std::cmp::Ordering {
        let startup_dominates = match limit {
            Some(k) => (k as f64) < 0.1 * cardinality.max(1.0),
            None => false,
        };
        let (mine, theirs) = if startup_dominates {
            (self.startup_cost, other.startup_cost)
        } else {
            (self.total, other.total)
        };
        mine.partial_cmp(&theirs)
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}
197
198#[derive(Debug, Clone, Default)]
200pub struct TableStats {
201 pub row_count: u64,
203 pub avg_row_size: u32,
205 pub page_count: u64,
207 pub columns: Vec<ColumnStats>,
209}
210
/// Statistics describing a single column of a table.
#[derive(Debug, Clone, Default)]
pub struct ColumnStats {
    /// Column name.
    pub name: String,
    /// Number of distinct values observed.
    pub distinct_count: u64,
    /// Number of NULL values observed.
    pub null_count: u64,
    /// Smallest observed value in textual form, if known.
    pub min_value: Option<String>,
    /// Largest observed value in textual form, if known.
    pub max_value: Option<String>,
    /// Whether the column is covered by an index.
    pub has_index: bool,
}
227
228pub struct CostEstimator {
230 default_row_count: f64,
232 row_scan_cost: f64,
234 index_lookup_cost: f64,
236 hash_probe_cost: f64,
238 nested_loop_cost: f64,
240 edge_traversal_cost: f64,
242 stats: Arc<dyn StatsProvider>,
247}
248
249impl CostEstimator {
250 pub fn new() -> Self {
253 Self {
254 default_row_count: 1000.0,
255 row_scan_cost: 1.0,
256 index_lookup_cost: 0.1,
257 hash_probe_cost: 0.5,
258 nested_loop_cost: 2.0,
259 edge_traversal_cost: 1.5,
260 stats: Arc::new(NullProvider),
261 }
262 }
263
264 pub fn with_stats(provider: Arc<dyn StatsProvider>) -> Self {
268 Self {
269 stats: provider,
270 ..Self::new()
271 }
272 }
273
274 pub fn set_stats(&mut self, provider: Arc<dyn StatsProvider>) {
278 self.stats = provider;
279 }
280
281 pub fn estimate(&self, query: &QueryExpr) -> PlanCost {
283 match query {
284 QueryExpr::Table(tq) => self.estimate_table(tq),
285 QueryExpr::Graph(gq) => self.estimate_graph(gq),
286 QueryExpr::Join(jq) => self.estimate_join(jq),
287 QueryExpr::Path(pq) => self.estimate_path(pq),
288 QueryExpr::Vector(vq) => self.estimate_vector(vq),
289 QueryExpr::Hybrid(hq) => self.estimate_hybrid(hq),
290 QueryExpr::Insert(_)
292 | QueryExpr::Update(_)
293 | QueryExpr::Delete(_)
294 | QueryExpr::CreateTable(_)
295 | QueryExpr::CreateCollection(_)
296 | QueryExpr::CreateVector(_)
297 | QueryExpr::DropTable(_)
298 | QueryExpr::DropGraph(_)
299 | QueryExpr::DropVector(_)
300 | QueryExpr::DropDocument(_)
301 | QueryExpr::DropKv(_)
302 | QueryExpr::DropCollection(_)
303 | QueryExpr::Truncate(_)
304 | QueryExpr::AlterTable(_)
305 | QueryExpr::GraphCommand(_)
306 | QueryExpr::SearchCommand(_)
307 | QueryExpr::CreateIndex(_)
308 | QueryExpr::DropIndex(_)
309 | QueryExpr::ProbabilisticCommand(_)
310 | QueryExpr::Ask(_)
311 | QueryExpr::SetConfig { .. }
312 | QueryExpr::ShowConfig { .. }
313 | QueryExpr::SetSecret { .. }
314 | QueryExpr::DeleteSecret { .. }
315 | QueryExpr::ShowSecrets { .. }
316 | QueryExpr::SetTenant(_)
317 | QueryExpr::ShowTenant
318 | QueryExpr::CreateTimeSeries(_)
319 | QueryExpr::CreateMetric(_)
320 | QueryExpr::AlterMetric(_)
321 | QueryExpr::CreateSlo(_)
322 | QueryExpr::DropTimeSeries(_)
323 | QueryExpr::CreateQueue(_)
324 | QueryExpr::AlterQueue(_)
325 | QueryExpr::DropQueue(_)
326 | QueryExpr::QueueSelect(_)
327 | QueryExpr::QueueCommand(_)
328 | QueryExpr::KvCommand(_)
329 | QueryExpr::ConfigCommand(_)
330 | QueryExpr::CreateTree(_)
331 | QueryExpr::DropTree(_)
332 | QueryExpr::TreeCommand(_)
333 | QueryExpr::ExplainAlter(_)
334 | QueryExpr::TransactionControl(_)
335 | QueryExpr::MaintenanceCommand(_)
336 | QueryExpr::CreateSchema(_)
337 | QueryExpr::DropSchema(_)
338 | QueryExpr::CreateSequence(_)
339 | QueryExpr::DropSequence(_)
340 | QueryExpr::CopyFrom(_)
341 | QueryExpr::CreateView(_)
342 | QueryExpr::DropView(_)
343 | QueryExpr::RefreshMaterializedView(_)
344 | QueryExpr::CreatePolicy(_)
345 | QueryExpr::DropPolicy(_)
346 | QueryExpr::CreateServer(_)
347 | QueryExpr::DropServer(_)
348 | QueryExpr::CreateForeignTable(_)
349 | QueryExpr::DropForeignTable(_)
350 | QueryExpr::Grant(_)
351 | QueryExpr::Revoke(_)
352 | QueryExpr::AlterUser(_)
353 | QueryExpr::CreateIamPolicy { .. }
354 | QueryExpr::DropIamPolicy { .. }
355 | QueryExpr::AttachPolicy { .. }
356 | QueryExpr::DetachPolicy { .. }
357 | QueryExpr::ShowPolicies { .. }
358 | QueryExpr::ShowEffectivePermissions { .. }
359 | QueryExpr::RankOf(_)
360 | QueryExpr::ApproxRankOf(_)
361 | QueryExpr::RankRange(_)
362 | QueryExpr::SimulatePolicy { .. }
363 | QueryExpr::LintPolicy { .. }
364 | QueryExpr::MigratePolicyMode { .. }
365 | QueryExpr::CreateMigration(_)
366 | QueryExpr::ApplyMigration(_)
367 | QueryExpr::RollbackMigration(_)
368 | QueryExpr::ExplainMigration(_)
369 | QueryExpr::EventsBackfill(_)
370 | QueryExpr::EventsBackfillStatus { .. } => PlanCost::new(1.0, 1.0, 0.0),
371 }
372 }
373
374 pub fn estimate_cardinality(&self, query: &QueryExpr) -> CardinalityEstimate {
376 match query {
377 QueryExpr::Table(tq) => self.estimate_table_cardinality(tq),
378 QueryExpr::Graph(gq) => self.estimate_graph_cardinality(gq),
379 QueryExpr::Join(jq) => self.estimate_join_cardinality(jq),
380 QueryExpr::Path(pq) => self.estimate_path_cardinality(pq),
381 QueryExpr::Vector(vq) => self.estimate_vector_cardinality(vq),
382 QueryExpr::Hybrid(hq) => self.estimate_hybrid_cardinality(hq),
383 QueryExpr::Insert(_)
385 | QueryExpr::Update(_)
386 | QueryExpr::Delete(_)
387 | QueryExpr::CreateTable(_)
388 | QueryExpr::CreateCollection(_)
389 | QueryExpr::CreateVector(_)
390 | QueryExpr::DropTable(_)
391 | QueryExpr::DropGraph(_)
392 | QueryExpr::DropVector(_)
393 | QueryExpr::DropDocument(_)
394 | QueryExpr::DropKv(_)
395 | QueryExpr::DropCollection(_)
396 | QueryExpr::Truncate(_)
397 | QueryExpr::AlterTable(_)
398 | QueryExpr::GraphCommand(_)
399 | QueryExpr::SearchCommand(_)
400 | QueryExpr::CreateIndex(_)
401 | QueryExpr::DropIndex(_)
402 | QueryExpr::ProbabilisticCommand(_)
403 | QueryExpr::Ask(_)
404 | QueryExpr::SetConfig { .. }
405 | QueryExpr::ShowConfig { .. }
406 | QueryExpr::SetSecret { .. }
407 | QueryExpr::DeleteSecret { .. }
408 | QueryExpr::ShowSecrets { .. }
409 | QueryExpr::SetTenant(_)
410 | QueryExpr::ShowTenant
411 | QueryExpr::CreateTimeSeries(_)
412 | QueryExpr::CreateMetric(_)
413 | QueryExpr::AlterMetric(_)
414 | QueryExpr::CreateSlo(_)
415 | QueryExpr::DropTimeSeries(_)
416 | QueryExpr::CreateQueue(_)
417 | QueryExpr::AlterQueue(_)
418 | QueryExpr::DropQueue(_)
419 | QueryExpr::QueueSelect(_)
420 | QueryExpr::QueueCommand(_)
421 | QueryExpr::KvCommand(_)
422 | QueryExpr::ConfigCommand(_)
423 | QueryExpr::CreateTree(_)
424 | QueryExpr::DropTree(_)
425 | QueryExpr::TreeCommand(_)
426 | QueryExpr::ExplainAlter(_)
427 | QueryExpr::TransactionControl(_)
428 | QueryExpr::MaintenanceCommand(_)
429 | QueryExpr::CreateSchema(_)
430 | QueryExpr::DropSchema(_)
431 | QueryExpr::CreateSequence(_)
432 | QueryExpr::DropSequence(_)
433 | QueryExpr::CopyFrom(_)
434 | QueryExpr::CreateView(_)
435 | QueryExpr::DropView(_)
436 | QueryExpr::RefreshMaterializedView(_)
437 | QueryExpr::CreatePolicy(_)
438 | QueryExpr::DropPolicy(_)
439 | QueryExpr::CreateServer(_)
440 | QueryExpr::DropServer(_)
441 | QueryExpr::CreateForeignTable(_)
442 | QueryExpr::DropForeignTable(_)
443 | QueryExpr::Grant(_)
444 | QueryExpr::Revoke(_)
445 | QueryExpr::AlterUser(_)
446 | QueryExpr::CreateIamPolicy { .. }
447 | QueryExpr::DropIamPolicy { .. }
448 | QueryExpr::AttachPolicy { .. }
449 | QueryExpr::DetachPolicy { .. }
450 | QueryExpr::ShowPolicies { .. }
451 | QueryExpr::ShowEffectivePermissions { .. }
452 | QueryExpr::RankOf(_)
453 | QueryExpr::ApproxRankOf(_)
454 | QueryExpr::RankRange(_)
455 | QueryExpr::SimulatePolicy { .. }
456 | QueryExpr::LintPolicy { .. }
457 | QueryExpr::MigratePolicyMode { .. }
458 | QueryExpr::CreateMigration(_)
459 | QueryExpr::ApplyMigration(_)
460 | QueryExpr::RollbackMigration(_)
461 | QueryExpr::ExplainMigration(_)
462 | QueryExpr::EventsBackfill(_)
463 | QueryExpr::EventsBackfillStatus { .. } => CardinalityEstimate::new(1.0, 1.0),
464 }
465 }
466
467 fn estimate_table(&self, query: &TableQuery) -> PlanCost {
472 let cardinality = self.estimate_table_cardinality(query);
473
474 let cpu = cardinality.rows * self.row_scan_cost;
475
476 let io = self.estimate_table_io(query, cardinality.rows);
479
480 let memory = cardinality.rows * 100.0; PlanCost::new(cpu, io, memory)
483 }
484
485 fn estimate_table_io(&self, query: &TableQuery, result_rows: f64) -> f64 {
492 const ROWS_PER_PAGE: f64 = 100.0;
493
494 let table_stats = self.stats.table_stats(&query.table);
496 let heap_pages = table_stats
497 .map(|s| s.page_count as f64)
498 .unwrap_or_else(|| (result_rows / ROWS_PER_PAGE).max(1.0));
499
500 if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
503 if let Some(col) = first_filter_column(&filter, &query.table) {
504 if let Some(idx) = self.stats.index_stats(&query.table, col) {
505 return idx.correlated_io_cost(result_rows, heap_pages);
506 }
507 }
508 }
509
510 (result_rows / ROWS_PER_PAGE).ceil()
512 }
513
514 fn estimate_table_cardinality(&self, query: &TableQuery) -> CardinalityEstimate {
515 let base_rows = self
518 .stats
519 .table_stats(&query.table)
520 .map(|s| s.row_count as f64)
521 .unwrap_or(self.default_row_count);
522
523 let mut estimate = CardinalityEstimate::full_scan(base_rows);
524
525 if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
528 let selectivity = self.filter_selectivity(&filter, &query.table);
529 estimate = estimate.with_filter(selectivity);
530 }
531
532 if let Some(limit) = query.limit {
534 estimate.rows = estimate.rows.min(limit as f64);
535 }
536
537 estimate
538 }
539
540 fn filter_selectivity(&self, filter: &AstFilter, table: &str) -> f64 {
553 match filter {
554 AstFilter::Compare { field, op, value } => {
555 let column = column_name_for_table(field, table);
556 match op {
557 CompareOp::Eq => self.eq_selectivity(table, column, value),
558 CompareOp::Ne => 1.0 - self.eq_selectivity(table, column, value),
559 CompareOp::Lt | CompareOp::Le => {
560 self.range_selectivity(table, column, None, Some(value))
561 }
562 CompareOp::Gt | CompareOp::Ge => {
563 self.range_selectivity(table, column, Some(value), None)
564 }
565 }
566 }
567 AstFilter::Between {
568 field, low, high, ..
569 } => {
570 let column = column_name_for_table(field, table);
571 self.range_selectivity(table, column, Some(low), Some(high))
572 }
573 AstFilter::In { field, values, .. } => {
574 let column = column_name_for_table(field, table);
575 if let Some(c) = column {
579 if let Some(mcv) = self.stats.column_mcv(table, c) {
580 let mut hits: f64 = 0.0;
581 let mut residual_count = 0usize;
582 for v in values {
583 if let Some(cv) = column_value_from(v) {
584 if let Some(freq) = mcv.frequency_of(&cv) {
585 hits += freq;
586 } else {
587 residual_count += 1;
588 }
589 } else {
590 residual_count += 1;
591 }
592 }
593 let total = mcv.total_frequency();
594 let distinct = self.stats.distinct_values(table, c).unwrap_or(100);
595 let non_mcv_distinct =
596 distinct.saturating_sub(mcv.values.len() as u64).max(1);
597 let per_residual = (1.0 - total) / non_mcv_distinct as f64;
598 let estimate = hits + (residual_count as f64) * per_residual;
599 return estimate.clamp(0.0, 1.0).min(0.5);
600 }
601 if let Some(s) = self.stats.index_stats(table, c) {
602 return (s.point_selectivity() * values.len() as f64).min(0.5);
603 }
604 }
605 (values.len() as f64 * 0.01).min(0.5)
606 }
607 AstFilter::Like { .. } => 0.1,
608 AstFilter::StartsWith { .. } => 0.15,
609 AstFilter::EndsWith { .. } => 0.15,
610 AstFilter::Contains { .. } => 0.1,
611 AstFilter::IsNull { .. } => 0.01,
612 AstFilter::IsNotNull { .. } => 0.99,
613 AstFilter::And(left, right) => {
614 self.filter_selectivity(left, table) * self.filter_selectivity(right, table)
615 }
616 AstFilter::Or(left, right) => {
617 let s1 = self.filter_selectivity(left, table);
618 let s2 = self.filter_selectivity(right, table);
619 s1 + s2 - (s1 * s2)
620 }
621 AstFilter::Not(inner) => 1.0 - self.filter_selectivity(inner, table),
622 AstFilter::CompareFields { .. } => {
623 0.1
627 }
628 AstFilter::CompareExpr { .. } => {
629 0.1
633 }
634 }
635 }
636
637 fn estimate_graph(&self, query: &GraphQuery) -> PlanCost {
642 let cardinality = self.estimate_graph_cardinality(query);
643
644 let nodes = query.pattern.nodes.len() as f64;
646 let edges = query.pattern.edges.len() as f64;
647
648 let cpu = cardinality.rows * self.edge_traversal_cost * (nodes + edges);
649 let io = cardinality.rows * 0.1; let memory = cardinality.rows * 200.0; PlanCost::new(cpu, io, memory)
653 }
654
655 fn estimate_graph_cardinality(&self, query: &GraphQuery) -> CardinalityEstimate {
656 let nodes = query.pattern.nodes.len() as f64;
657 let edges = query.pattern.edges.len() as f64;
658
659 let base_rows = self.default_row_count;
661 let edge_factor = 0.1_f64.powf(edges); let mut estimate = CardinalityEstimate::new(base_rows * nodes * edge_factor, edge_factor);
664 estimate.confidence = 0.5; if let Some(ref filter) = query.filter {
668 let selectivity = Self::estimate_filter_selectivity(filter);
669 estimate = estimate.with_filter(selectivity);
670 }
671
672 estimate
673 }
674
675 fn estimate_join(&self, query: &JoinQuery) -> PlanCost {
680 let left_cost = self.estimate(&query.left);
681 let right_cost = self.estimate(&query.right);
682
683 let left_card = self.estimate_cardinality(&query.left);
684 let right_card = self.estimate_cardinality(&query.right);
685
686 let build_cpu = left_card.rows * self.hash_probe_cost;
693 let probe_cpu = right_card.rows * self.hash_probe_cost;
694 let join_memory = left_card.rows * 100.0; let build_op = PlanCost::with_startup(build_cpu, 0.0, join_memory, build_cpu);
698 let probe_op = PlanCost::new(probe_cpu, 0.0, 0.0);
700
701 let after_build = left_cost.combine_blocking(&build_op);
703 after_build
704 .combine_pipelined(&right_cost)
705 .combine_pipelined(&probe_op)
706 }
707
708 fn estimate_join_cardinality(&self, query: &JoinQuery) -> CardinalityEstimate {
709 let left = self.estimate_cardinality(&query.left);
710 let right = self.estimate_cardinality(&query.right);
711
712 let selectivity = match query.join_type {
714 JoinType::Inner => 0.1, JoinType::LeftOuter => 1.0, JoinType::RightOuter => 1.0, JoinType::FullOuter => 1.0, JoinType::Cross => 1.0, };
720
721 CardinalityEstimate::new(
722 left.rows * right.rows * selectivity,
723 left.selectivity * right.selectivity * selectivity,
724 )
725 }
726
727 fn estimate_path(&self, query: &PathQuery) -> PlanCost {
732 let cardinality = self.estimate_path_cardinality(query);
733
734 let max_hops = query.max_length;
736 let branching_factor: f64 = 5.0; let nodes_visited = branching_factor.powf(max_hops as f64).min(10000.0);
739 let cpu = nodes_visited * self.edge_traversal_cost;
740 let io = nodes_visited * 0.1;
741 let memory = nodes_visited * 50.0; PlanCost::new(cpu, io, memory)
744 }
745
746 fn estimate_path_cardinality(&self, query: &PathQuery) -> CardinalityEstimate {
747 let max_paths = 10.0;
749 CardinalityEstimate::new(max_paths, 0.001)
750 }
751
752 fn estimate_vector(&self, query: &VectorQuery) -> PlanCost {
757 let k = query.k as f64;
760
761 let hnsw_cost = 100.0 * (1.0 + k.ln()); let filter_cost =
768 if crate::storage::query::sql_lowering::effective_vector_filter(query).is_some() {
769 50.0
770 } else {
771 0.0
772 };
773
774 let cpu = hnsw_cost + filter_cost;
775 let io = 20.0; let memory = k * 32.0 + 1000.0; PlanCost::with_startup(cpu, io, memory, hnsw_cost * 0.5)
783 }
784
785 fn estimate_vector_cardinality(&self, query: &VectorQuery) -> CardinalityEstimate {
786 let k = query.k as f64;
788 CardinalityEstimate::new(k, 0.1)
789 }
790
791 fn estimate_hybrid(&self, query: &HybridQuery) -> PlanCost {
796 let structured_cost = self.estimate(&query.structured);
798 let vector_cost = self.estimate_vector(&query.vector);
799
800 let fusion_overhead = match &query.fusion {
802 crate::storage::query::ast::FusionStrategy::Rerank { .. } => 50.0,
803 crate::storage::query::ast::FusionStrategy::FilterThenSearch => 10.0,
804 crate::storage::query::ast::FusionStrategy::SearchThenFilter => 10.0,
805 crate::storage::query::ast::FusionStrategy::RRF { .. } => 30.0,
806 crate::storage::query::ast::FusionStrategy::Intersection => 20.0,
807 crate::storage::query::ast::FusionStrategy::Union { .. } => 40.0,
808 };
809
810 PlanCost::new(
811 structured_cost.cpu + vector_cost.cpu + fusion_overhead,
812 structured_cost.io + vector_cost.io,
813 structured_cost.memory + vector_cost.memory,
814 )
815 }
816
817 fn estimate_hybrid_cardinality(&self, query: &HybridQuery) -> CardinalityEstimate {
818 let structured_card = self.estimate_cardinality(&query.structured);
819 let vector_card = self.estimate_vector_cardinality(&query.vector);
820
821 let rows = match &query.fusion {
823 crate::storage::query::ast::FusionStrategy::Intersection => {
824 structured_card.rows.min(vector_card.rows)
825 }
826 crate::storage::query::ast::FusionStrategy::Union { .. } => {
827 structured_card.rows + vector_card.rows
828 }
829 _ => vector_card.rows, };
831
832 CardinalityEstimate::new(rows, 0.2)
833 }
834
835 fn estimate_filter_selectivity(filter: &AstFilter) -> f64 {
840 match filter {
841 AstFilter::Compare { op, .. } => {
842 match op {
843 CompareOp::Eq => 0.01, CompareOp::Ne => 0.99, CompareOp::Lt | CompareOp::Le => 0.3,
846 CompareOp::Gt | CompareOp::Ge => 0.3,
847 }
848 }
849 AstFilter::Between { .. } => 0.25,
850 AstFilter::In { values, .. } => {
851 (values.len() as f64 * 0.01).min(0.5)
853 }
854 AstFilter::Like { .. } => 0.1,
855 AstFilter::StartsWith { .. } => 0.15,
856 AstFilter::EndsWith { .. } => 0.15,
857 AstFilter::Contains { .. } => 0.1,
858 AstFilter::IsNull { .. } => 0.01,
859 AstFilter::IsNotNull { .. } => 0.99,
860 AstFilter::And(left, right) => {
861 Self::estimate_filter_selectivity(left) * Self::estimate_filter_selectivity(right)
862 }
863 AstFilter::Or(left, right) => {
864 let s1 = Self::estimate_filter_selectivity(left);
865 let s2 = Self::estimate_filter_selectivity(right);
866 s1 + s2 - (s1 * s2) }
868 AstFilter::Not(inner) => 1.0 - Self::estimate_filter_selectivity(inner),
869 AstFilter::CompareFields { .. } => 0.1,
870 AstFilter::CompareExpr { .. } => 0.1,
871 }
872 }
873}
874
875impl CostEstimator {
876 fn eq_selectivity(&self, table: &str, column: Option<&str>, value: &Value) -> f64 {
884 if let Some(col) = column {
885 if let Some(mcv) = self.stats.column_mcv(table, col) {
887 if let Some(cv) = column_value_from(value) {
888 if let Some(freq) = mcv.frequency_of(&cv) {
889 return freq;
890 }
891 let total = mcv.total_frequency();
893 let distinct = self.stats.distinct_values(table, col).unwrap_or(100);
894 let non_mcv_distinct = distinct.saturating_sub(mcv.values.len() as u64).max(1);
895 return ((1.0 - total) / non_mcv_distinct as f64).clamp(0.0, 1.0);
896 }
897 }
898 if let Some(s) = self.stats.index_stats(table, col) {
900 return s.point_selectivity();
901 }
902 }
903 0.01
905 }
906
907 fn range_selectivity(
918 &self,
919 table: &str,
920 column: Option<&str>,
921 lo: Option<&Value>,
922 hi: Option<&Value>,
923 ) -> f64 {
924 if let Some(col) = column {
925 if let Some(h) = self.stats.column_histogram(table, col) {
927 let lo_cv = lo.and_then(column_value_from);
928 let hi_cv = hi.and_then(column_value_from);
929 return h.range_selectivity(lo_cv.as_ref(), hi_cv.as_ref());
930 }
931 if let Some(s) = self.stats.index_stats(table, col) {
933 let cap = if lo.is_some() && hi.is_some() {
934 0.25
935 } else {
936 0.3
937 };
938 return (s.point_selectivity() * (s.distinct_keys as f64 / 2.0)).min(cap);
939 }
940 }
941 if lo.is_some() && hi.is_some() {
943 0.25
944 } else {
945 0.3
946 }
947 }
948}
949
950impl Default for CostEstimator {
951 fn default() -> Self {
952 Self::new()
953 }
954}
955
956fn column_value_from(v: &crate::storage::schema::Value) -> Option<super::histogram::ColumnValue> {
961 use super::histogram::ColumnValue;
962 use crate::storage::schema::Value;
963 match v {
964 Value::Integer(i) | Value::BigInt(i) => Some(ColumnValue::Int(*i)),
965 Value::UnsignedInteger(u) => Some(ColumnValue::Int(*u as i64)),
966 Value::Float(f) if f.is_finite() => Some(ColumnValue::Float(*f)),
967 Value::Text(s) => Some(ColumnValue::Text(s.to_string())),
968 Value::Email(s)
969 | Value::Url(s)
970 | Value::NodeRef(s)
971 | Value::EdgeRef(s)
972 | Value::TableRef(s)
973 | Value::Password(s) => Some(ColumnValue::Text(s.clone())),
974 Value::Timestamp(t) => Some(ColumnValue::Int(*t)),
975 Value::Duration(d) => Some(ColumnValue::Int(*d)),
976 Value::TimestampMs(t) => Some(ColumnValue::Int(*t)),
977 Value::Decimal(d) => Some(ColumnValue::Int(*d)),
978 Value::Date(d) => Some(ColumnValue::Int(i64::from(*d))),
979 Value::Time(t) => Some(ColumnValue::Int(i64::from(*t))),
980 Value::Phone(p) => Some(ColumnValue::Int(*p as i64)),
981 Value::Semver(v) => Some(ColumnValue::Int(i64::from(*v))),
982 Value::Port(v) => Some(ColumnValue::Int(i64::from(*v))),
983 Value::PageRef(v) => Some(ColumnValue::Int(i64::from(*v))),
984 Value::EnumValue(v) => Some(ColumnValue::Int(i64::from(*v))),
985 Value::Latitude(v) => Some(ColumnValue::Int(i64::from(*v))),
986 Value::Longitude(v) => Some(ColumnValue::Int(i64::from(*v))),
987 _ => None,
992 }
993}
994
995fn first_filter_column<'a>(filter: &'a AstFilter, table: &str) -> Option<&'a str> {
1000 match filter {
1001 AstFilter::Compare { field, .. } => column_name_for_table(field, table),
1002 AstFilter::Between { field, .. } => column_name_for_table(field, table),
1003 AstFilter::And(l, r) => {
1004 first_filter_column(l, table).or_else(|| first_filter_column(r, table))
1005 }
1006 _ => None,
1007 }
1008}
1009
1010fn column_name_for_table<'a>(field: &'a FieldRef, table: &str) -> Option<&'a str> {
1012 match field {
1013 FieldRef::TableColumn { table: t, column } if t == table || t.is_empty() => {
1014 Some(column.as_str())
1015 }
1016 _ => None,
1018 }
1019}
1020
1021#[cfg(test)]
1022mod tests {
1023 use super::super::stats_provider::StaticProvider;
1024 use super::*;
1025 use crate::storage::index::{IndexKind, IndexStats};
1026 use crate::storage::query::ast::{FieldRef, Projection};
1027 use crate::storage::schema::Value;
1028
1029 fn eq_filter(table: &str, column: &str, value: i64) -> AstFilter {
1030 AstFilter::Compare {
1031 field: FieldRef::column(table, column),
1032 op: CompareOp::Eq,
1033 value: Value::Integer(value),
1034 }
1035 }
1036
1037 fn table_query(name: &str, filter: Option<AstFilter>) -> TableQuery {
1038 TableQuery {
1039 table: name.to_string(),
1040 source: None,
1041 alias: None,
1042 select_items: Vec::new(),
1043 columns: vec![Projection::All],
1044 where_expr: None,
1045 filter,
1046 group_by_exprs: Vec::new(),
1047 group_by: Vec::new(),
1048 having_expr: None,
1049 having: None,
1050 order_by: vec![],
1051 limit: None,
1052 limit_param: None,
1053 offset: None,
1054 offset_param: None,
1055 expand: None,
1056 as_of: None,
1057 sessionize: None,
1058 }
1059 }
1060
1061 #[test]
1062 fn injected_row_count_overrides_default() {
1063 let provider = Arc::new(StaticProvider::new().with_table(
1064 "users",
1065 TableStats {
1066 row_count: 50_000,
1067 avg_row_size: 256,
1068 page_count: 500,
1069 columns: vec![],
1070 },
1071 ));
1072 let estimator = CostEstimator::with_stats(provider);
1073 let q = table_query("users", None);
1074 let card = estimator.estimate_table_cardinality(&q);
1075 assert_eq!(card.rows, 50_000.0);
1077 }
1078
1079 #[test]
1080 fn stats_aware_eq_selectivity_beats_default() {
1081 let provider = Arc::new(
1082 StaticProvider::new()
1083 .with_table(
1084 "users",
1085 TableStats {
1086 row_count: 1_000_000,
1087 avg_row_size: 256,
1088 page_count: 10_000,
1089 columns: vec![],
1090 },
1091 )
1092 .with_index(
1093 "users",
1094 "email",
1095 IndexStats {
1096 entries: 1_000_000,
1097 distinct_keys: 1_000_000,
1098 approx_bytes: 0,
1099 kind: IndexKind::Hash,
1100 has_bloom: true,
1101 index_correlation: 0.0,
1102 },
1103 ),
1104 );
1105 let estimator = CostEstimator::with_stats(provider);
1106 let q = table_query("users", Some(eq_filter("users", "email", 0)));
1107 let card = estimator.estimate_table_cardinality(&q);
1108 assert!(card.rows < 2.0, "expected ~1 row, got {}", card.rows);
1110 }
1111
1112 #[test]
1113 fn fallback_when_no_index_stats() {
1114 let provider = Arc::new(StaticProvider::new().with_table(
1115 "users",
1116 TableStats {
1117 row_count: 1_000_000,
1118 avg_row_size: 256,
1119 page_count: 10_000,
1120 columns: vec![],
1121 },
1122 ));
1123 let estimator = CostEstimator::with_stats(provider);
1124 let q = table_query("users", Some(eq_filter("users", "email", 0)));
1125 let card = estimator.estimate_table_cardinality(&q);
1126 assert!((card.rows - 10_000.0).abs() < 1.0);
1128 }
1129
1130 #[test]
1131 fn null_provider_keeps_legacy_behaviour() {
1132 let estimator = CostEstimator::new();
1133 let q = table_query("whatever", Some(eq_filter("whatever", "id", 1)));
1134 let card = estimator.estimate_table_cardinality(&q);
1135 assert!((card.rows - 10.0).abs() < 1.0);
1137 }
1138
1139 #[test]
1140 fn and_combines_stats_selectivities() {
1141 let provider = Arc::new(
1142 StaticProvider::new()
1143 .with_table(
1144 "t",
1145 TableStats {
1146 row_count: 100_000,
1147 avg_row_size: 64,
1148 page_count: 100,
1149 columns: vec![],
1150 },
1151 )
1152 .with_index(
1153 "t",
1154 "a",
1155 IndexStats {
1156 entries: 100_000,
1157 distinct_keys: 10,
1158 approx_bytes: 0,
1159 kind: IndexKind::BTree,
1160 has_bloom: false,
1161 index_correlation: 0.0,
1162 },
1163 )
1164 .with_index(
1165 "t",
1166 "b",
1167 IndexStats {
1168 entries: 100_000,
1169 distinct_keys: 1000,
1170 approx_bytes: 0,
1171 kind: IndexKind::BTree,
1172 has_bloom: false,
1173 index_correlation: 0.0,
1174 },
1175 ),
1176 );
1177 let estimator = CostEstimator::with_stats(provider);
1178 let filter = AstFilter::And(
1179 Box::new(eq_filter("t", "a", 1)),
1180 Box::new(eq_filter("t", "b", 1)),
1181 );
1182 let q = table_query("t", Some(filter));
1183 let card = estimator.estimate_table_cardinality(&q);
1184 assert!(card.rows < 15.0, "got {}", card.rows);
1186 }
1187
1188 #[test]
1189 fn test_table_cost_estimation() {
1190 let estimator = CostEstimator::new();
1191
1192 let query = QueryExpr::Table(TableQuery {
1193 table: "hosts".to_string(),
1194 source: None,
1195 alias: None,
1196 select_items: Vec::new(),
1197 columns: vec![Projection::All],
1198 where_expr: None,
1199 filter: None,
1200 group_by_exprs: Vec::new(),
1201 group_by: Vec::new(),
1202 having_expr: None,
1203 having: None,
1204 order_by: vec![],
1205 limit: None,
1206 limit_param: None,
1207 offset: None,
1208 offset_param: None,
1209 expand: None,
1210 as_of: None,
1211 sessionize: None,
1212 });
1213
1214 let cost = estimator.estimate(&query);
1215 assert!(cost.cpu > 0.0);
1216 assert!(cost.total > 0.0);
1217 }
1218
1219 #[test]
1220 fn test_filter_selectivity() {
1221 let estimator = CostEstimator::new();
1222
1223 let eq_filter = AstFilter::Compare {
1224 field: FieldRef::column("hosts", "id"),
1225 op: CompareOp::Eq,
1226 value: Value::Integer(1),
1227 };
1228 assert!(CostEstimator::estimate_filter_selectivity(&eq_filter) < 0.1);
1229
1230 let ne_filter = AstFilter::Compare {
1231 field: FieldRef::column("hosts", "id"),
1232 op: CompareOp::Ne,
1233 value: Value::Integer(1),
1234 };
1235 assert!(CostEstimator::estimate_filter_selectivity(&ne_filter) > 0.9);
1236 }
1237
1238 #[test]
1239 fn test_and_selectivity() {
1240 let estimator = CostEstimator::new();
1241
1242 let and_filter = AstFilter::And(
1243 Box::new(AstFilter::Compare {
1244 field: FieldRef::column("hosts", "a"),
1245 op: CompareOp::Eq,
1246 value: Value::Integer(1),
1247 }),
1248 Box::new(AstFilter::Compare {
1249 field: FieldRef::column("hosts", "b"),
1250 op: CompareOp::Eq,
1251 value: Value::Integer(2),
1252 }),
1253 );
1254
1255 let selectivity = CostEstimator::estimate_filter_selectivity(&and_filter);
1256 assert!(selectivity < 0.01); }
1258
1259 #[test]
1260 fn test_cardinality_with_limit() {
1261 let estimator = CostEstimator::new();
1262
1263 let query = TableQuery {
1264 table: "hosts".to_string(),
1265 source: None,
1266 alias: None,
1267 select_items: Vec::new(),
1268 columns: vec![Projection::All],
1269 where_expr: None,
1270 filter: None,
1271 group_by_exprs: Vec::new(),
1272 group_by: Vec::new(),
1273 having_expr: None,
1274 having: None,
1275 order_by: vec![],
1276 limit: Some(10),
1277 limit_param: None,
1278 offset: None,
1279 offset_param: None,
1280 expand: None,
1281 as_of: None,
1282 sessionize: None,
1283 };
1284
1285 let card = estimator.estimate_table_cardinality(&query);
1286 assert!(card.rows <= 10.0);
1287 }
1288
1289 #[test]
1294 fn startup_zero_for_full_scan() {
1295 let estimator = CostEstimator::new();
1299 let q = table_query("any_table", None);
1300 let cost = estimator.estimate(&QueryExpr::Table(q));
1301 assert_eq!(cost.startup_cost, 0.0, "full scan must have zero startup");
1302 assert!(cost.total > 0.0);
1303 }
1304
1305 #[test]
1306 fn startup_nonzero_for_blocking_combine() {
1307 let input = PlanCost::new(100.0, 10.0, 50.0); let blocker = PlanCost::new(20.0, 0.0, 10.0); let composed = input.combine_blocking(&blocker);
1312 assert_eq!(composed.startup_cost, input.total);
1314 assert_eq!(composed.total, input.total + blocker.total);
1316 assert!(composed.startup_cost > 0.0);
1317 }
1318
/// Pipelined composition sums both the startup costs and the totals of its
/// two sides.
#[test]
fn pipelined_combine_adds_startup_directly() {
    let producer = PlanCost::with_startup(50.0, 5.0, 10.0, 30.0);
    let consumer = PlanCost::with_startup(20.0, 0.0, 0.0, 5.0);
    let expected_startup = 30.0 + 5.0;

    let chained = producer.combine_pipelined(&consumer);
    assert_eq!(chained.startup_cost, expected_startup);
    assert_eq!(chained.total, producer.total + consumer.total);
}
1327
/// With a small LIMIT, the plan with the lower startup cost is preferred even
/// when both plans share the same total cost.
#[test]
fn cost_prefers_low_startup_when_limit_small() {
    let fast_first = PlanCost {
        cpu: 100.0,
        io: 10.0,
        network: 0.0,
        memory: 50.0,
        startup_cost: 5.0,
        total: 200.0,
    };
    // Same as `fast_first` in every component except the startup cost.
    let slow_first = PlanCost {
        startup_cost: 150.0,
        ..fast_first.clone()
    };

    let ordering = fast_first.prefer_over(&slow_first, Some(10), 10_000.0);
    assert_eq!(ordering, std::cmp::Ordering::Less);
}
1354
/// Without a LIMIT, the plan with the lower total cost is preferred despite its
/// higher startup cost.
#[test]
fn cost_prefers_low_total_when_no_limit() {
    // `network` and `memory` come from the derived `Default` (0.0).
    let low_total = PlanCost {
        cpu: 50.0,
        io: 5.0,
        startup_cost: 30.0,
        total: 100.0,
        ..PlanCost::default()
    };
    let high_total = PlanCost {
        cpu: 100.0,
        io: 10.0,
        startup_cost: 5.0,
        total: 200.0,
        ..PlanCost::default()
    };

    let ordering = low_total.prefer_over(&high_total, None, 10_000.0);
    assert_eq!(ordering, std::cmp::Ordering::Less);
}
1379
/// A LIMIT of 5000 against 10 000 expected rows is large enough that
/// comparison falls back to total cost, so the lower-total plan still wins.
#[test]
fn limit_threshold_falls_back_to_total_when_limit_large() {
    // `network` and `memory` come from the derived `Default` (0.0).
    let low_total = PlanCost {
        cpu: 50.0,
        io: 5.0,
        startup_cost: 30.0,
        total: 100.0,
        ..PlanCost::default()
    };
    let low_startup = PlanCost {
        cpu: 100.0,
        io: 10.0,
        startup_cost: 5.0,
        total: 200.0,
        ..PlanCost::default()
    };

    let ordering = low_total.prefer_over(&low_startup, Some(5000), 10_000.0);
    assert_eq!(ordering, std::cmp::Ordering::Less);
}
1404
/// After a blocking build step, the composed startup cost must be at least
/// the input's total, and the composed total must be at least that startup.
#[test]
fn hash_join_startup_includes_build_cost() {
    let input = PlanCost::new(80.0, 8.0, 100.0);
    let build_step = PlanCost::with_startup(50.0, 0.0, 200.0, 50.0);
    let composed = input.combine_blocking(&build_step);

    let (startup, absorbed) = (composed.startup_cost, input.total);
    assert!(
        startup >= absorbed,
        "after-build startup ({}) must absorb left.total ({})",
        startup,
        absorbed
    );
    assert!(composed.total >= composed.startup_cost);
}
1421
/// A cost built with a positive startup component reports a nonzero startup
/// that stays strictly below its total.
///
/// NOTE(review): despite the name, this only exercises
/// `PlanCost::with_startup`; it does not run the estimator on a vector query.
#[test]
fn vector_search_reports_nonzero_startup() {
    // total = 150 + 20 * 10 + 1320 * 0.1 = 482, which exceeds the 50 startup.
    let v = PlanCost::with_startup(150.0, 20.0, 1320.0, 50.0);
    assert!(v.startup_cost > 0.0);
    assert!(v.startup_cost < v.total);
}
1435
/// When the requested startup exceeds the computed total
/// (`cpu + io * 10 + memory * 0.1`), `with_startup` raises the total to match.
#[test]
fn with_startup_clamps_total_below_startup() {
    // Computed total is 1.0 (cpu only), far below the requested 100.0 startup.
    let cost = PlanCost::with_startup(1.0, 0.0, 0.0, 100.0);
    assert!(cost.total >= cost.startup_cost);
    // Pin the exact clamped values. The inequality alone would also accept an
    // inflated total or an altered startup.
    assert_eq!(cost.startup_cost, 100.0);
    assert_eq!(cost.total, 100.0);
}
1442
/// A default-constructed `PlanCost` has zero startup cost and zero total cost.
#[test]
fn default_plancost_has_zero_startup() {
    let PlanCost {
        startup_cost,
        total,
        ..
    } = PlanCost::default();
    assert_eq!(startup_cost, 0.0);
    assert_eq!(total, 0.0);
}
1449
1450 use super::super::histogram::{ColumnValue, Histogram, MostCommonValues};
1455
1456 fn provider_with_skew() -> Arc<StaticProvider> {
1457 let mut sample: Vec<ColumnValue> = Vec::new();
1461 for i in 0..80 {
1462 sample.push(ColumnValue::Int(i % 10));
1463 }
1464 for i in 0..20 {
1465 sample.push(ColumnValue::Int(10 + i * 50));
1466 }
1467 let h = Histogram::equi_depth_from_sample(sample, 10);
1468
1469 let mcv = MostCommonValues::new(vec![
1470 (ColumnValue::Text("boss".to_string()), 0.5),
1471 (ColumnValue::Text("intern".to_string()), 0.05),
1472 ]);
1473
1474 Arc::new(
1475 StaticProvider::new()
1476 .with_table(
1477 "people",
1478 TableStats {
1479 row_count: 100_000,
1480 avg_row_size: 64,
1481 page_count: 100,
1482 columns: vec![],
1483 },
1484 )
1485 .with_histogram("people", "score", h)
1486 .with_mcv("people", "role", mcv),
1487 )
1488 }
1489
/// Equality against a value present in the MCV list reports that value's
/// recorded frequency (0.5 for "boss").
#[test]
fn eq_uses_mcv_when_value_is_tracked() {
    let estimator = CostEstimator::with_stats(provider_with_skew());
    let boss_filter = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Eq,
        value: Value::text("boss".to_string()),
    };

    let s = estimator.filter_selectivity(&boss_filter, "people");
    let tracked_frequency = 0.5;
    assert!(
        (s - tracked_frequency).abs() < 1e-9,
        "MCV-tracked equality should report exact frequency, got {s}"
    );
}
1507
/// Equality against a value absent from the MCV list gets a small, nonzero
/// residual selectivity.
#[test]
fn eq_uses_residual_for_non_mcv_value() {
    let stats = provider_with_skew();
    let estimator = CostEstimator::with_stats(stats);
    let untracked = Value::text("staff".to_string());

    let s = estimator.filter_selectivity(
        &AstFilter::Compare {
            field: FieldRef::column("people", "role"),
            op: CompareOp::Eq,
            value: untracked,
        },
        "people",
    );
    assert!(s > 0.0 && s < 0.01, "residual eq should be tiny, got {s}");
}
1523
/// `role != "boss"` is the complement of the 0.5 MCV frequency for "boss".
#[test]
fn ne_is_one_minus_eq_under_mcv() {
    let estimator = CostEstimator::with_stats(provider_with_skew());
    let not_boss = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Ne,
        value: Value::text("boss".to_string()),
    };

    let s = estimator.filter_selectivity(&not_boss, "people");
    let complement_of_eq = 1.0 - 0.5;
    assert!(
        (s - complement_of_eq).abs() < 1e-9,
        "Ne selectivity = 0.5, got {s}"
    );
}
1537
/// `score <= 9` covers the dense 0..=9 cluster (80 of the 100 sampled values),
/// so a histogram-driven estimate should land well above the 0.3 range
/// fallback used when no histogram exists.
#[test]
fn range_uses_histogram_when_present() {
    let provider = provider_with_skew();
    let estimator = CostEstimator::with_stats(provider);
    let filter = AstFilter::Compare {
        field: FieldRef::column("people", "score"),
        op: CompareOp::Le,
        value: Value::Integer(9),
    };
    let s = estimator.filter_selectivity(&filter, "people");
    // The failure message states the threshold actually asserted (0.5). The
    // old text ("should beat 0.3") was misleading for failures in (0.3, 0.5].
    assert!(
        s > 0.5,
        "histogram-based range selectivity should exceed 0.5 (fallback heuristic is 0.3), got {s}"
    );
}
1555
/// `score BETWEEN 0 AND 9` is also estimated from the `score` histogram,
/// which places most of its mass in that range.
#[test]
fn between_uses_histogram() {
    let estimator = CostEstimator::with_stats(provider_with_skew());
    let (low, high) = (Value::Integer(0), Value::Integer(9));

    let s = estimator.filter_selectivity(
        &AstFilter::Between {
            field: FieldRef::column("people", "score"),
            low,
            high,
        },
        "people",
    );
    assert!(s > 0.5, "BETWEEN should use histogram too, got {s}");
}
1568
/// When table stats exist but the filtered column has no histogram, a range
/// predicate falls back to the fixed 0.3 heuristic.
#[test]
fn graceful_fallback_when_histogram_absent() {
    let people_stats = TableStats {
        row_count: 1000,
        avg_row_size: 64,
        page_count: 10,
        columns: vec![],
    };
    let stats = Arc::new(StaticProvider::new().with_table("people", people_stats));
    let estimator = CostEstimator::with_stats(stats);

    let unknown_lt_50 = AstFilter::Compare {
        field: FieldRef::column("people", "unknown_col"),
        op: CompareOp::Lt,
        value: Value::Integer(50),
    };
    let s = estimator.filter_selectivity(&unknown_lt_50, "people");
    assert!((s - 0.3).abs() < 1e-9, "fallback heuristic 0.3, got {s}");
}
1591}