1use std::sync::Arc;
13
14use super::stats_provider::{NullProvider, StatsProvider};
15use crate::storage::query::ast::{
16 CompareOp, FieldRef, Filter as AstFilter, GraphQuery, HybridQuery, JoinQuery, JoinType,
17 PathQuery, QueryExpr, TableQuery, VectorQuery,
18};
19use crate::storage::schema::Value;
20
/// Estimated output size of a query (or sub-query).
///
/// Note that the derived `Default` yields all-zero fields, including a
/// `confidence` of `0.0`; the constructors below start at `1.0` instead.
#[derive(Debug, Clone, Default)]
pub struct CardinalityEstimate {
    /// Estimated number of rows produced.
    pub rows: f64,
    /// Accumulated selectivity factor applied to reach `rows`.
    pub selectivity: f64,
    /// Confidence in the estimate; each applied filter multiplies it by 0.9.
    pub confidence: f64,
}

impl CardinalityEstimate {
    /// Builds an estimate with the given row count and selectivity and a
    /// confidence of `1.0`.
    pub fn new(rows: f64, selectivity: f64) -> Self {
        CardinalityEstimate {
            rows,
            selectivity,
            confidence: 1.0,
        }
    }

    /// Estimate for reading every row of a table of `table_size` rows:
    /// selectivity `1.0`, confidence `1.0`.
    pub fn full_scan(table_size: f64) -> Self {
        Self::new(table_size, 1.0)
    }

    /// Returns the estimate after applying a filter with the given
    /// selectivity.
    ///
    /// Both `rows` and `selectivity` are multiplied by `filter_selectivity`,
    /// and `confidence` is reduced by a factor of 0.9.
    pub fn with_filter(self, filter_selectivity: f64) -> Self {
        Self {
            rows: self.rows * filter_selectivity,
            selectivity: self.selectivity * filter_selectivity,
            confidence: self.confidence * 0.9,
        }
    }
}
59
/// Cost of executing a plan or plan fragment, broken down by resource.
///
/// `total` is a single weighted figure used to rank plans; `startup_cost`
/// is the part of the cost that has to be paid before the first row can be
/// produced (see [`PlanCost::combine_blocking`]).
#[derive(Debug, Clone, Default)]
pub struct PlanCost {
    /// CPU component.
    pub cpu: f64,
    /// I/O component; weighted 10x in `total` by the constructors.
    pub io: f64,
    /// Network component. The constructors set it to `0.0` and do not
    /// include it in `total`.
    pub network: f64,
    /// Memory component; weighted 0.1x in `total` by the constructors.
    pub memory: f64,
    /// Cost incurred before the first row is available.
    pub startup_cost: f64,
    /// Weighted overall cost.
    pub total: f64,
}

impl PlanCost {
    /// Builds a cost with no startup component.
    ///
    /// `total` is `cpu + io * 10.0 + memory * 0.1`.
    pub fn new(cpu: f64, io: f64, memory: f64) -> Self {
        let total = cpu + io * 10.0 + memory * 0.1;
        Self {
            cpu,
            io,
            network: 0.0,
            memory,
            startup_cost: 0.0,
            total,
        }
    }

    /// Builds a cost with an explicit startup component.
    ///
    /// The stored `startup_cost` is clamped to be non-negative, and `total`
    /// is raised to at least `startup_cost` so that the total never ends up
    /// below the startup figure.
    pub fn with_startup(cpu: f64, io: f64, memory: f64, startup_cost: f64) -> Self {
        let total = cpu + io * 10.0 + memory * 0.1;
        Self {
            cpu,
            io,
            network: 0.0,
            memory,
            startup_cost: startup_cost.max(0.0),
            total: total.max(startup_cost),
        }
    }

    /// Combines `self` with an operator that streams alongside it.
    ///
    /// CPU, I/O, network, startup and total costs are summed; memory is the
    /// maximum of the two.
    pub fn combine_pipelined(&self, other: &PlanCost) -> PlanCost {
        PlanCost {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            network: self.network + other.network,
            memory: self.memory.max(other.memory),
            startup_cost: self.startup_cost + other.startup_cost,
            total: self.total + other.total,
        }
    }

    /// Combines `self` (the input) with a blocking operator on top of it.
    ///
    /// Identical to [`PlanCost::combine_pipelined`] except for startup: the
    /// whole input (`self.total`) is charged as startup, plus the blocker's
    /// own startup cost.
    pub fn combine_blocking(&self, blocker: &PlanCost) -> PlanCost {
        PlanCost {
            cpu: self.cpu + blocker.cpu,
            io: self.io + blocker.io,
            network: self.network + blocker.network,
            memory: self.memory.max(blocker.memory),
            startup_cost: self.total + blocker.startup_cost,
            total: self.total + blocker.total,
        }
    }

    /// Alias for [`PlanCost::combine_pipelined`].
    pub fn combine(&self, other: &PlanCost) -> PlanCost {
        self.combine_pipelined(other)
    }

    /// Scales the cost by `factor`.
    ///
    /// CPU, I/O, network and total are multiplied; memory and startup cost
    /// are carried over unchanged.
    pub fn scale(&self, factor: f64) -> PlanCost {
        PlanCost {
            cpu: self.cpu * factor,
            io: self.io * factor,
            network: self.network * factor,
            memory: self.memory,
            startup_cost: self.startup_cost,
            total: self.total * factor,
        }
    }

    /// Orders `self` against `other`; `Ordering::Less` means `self` is the
    /// preferred (cheaper) plan.
    ///
    /// When a `limit` is present and smaller than 10% of `cardinality`
    /// (with `cardinality` floored at 1.0), plans are ranked by
    /// `startup_cost`. If the startup costs tie, which is the common case
    /// for two non-blocking plans that both start at `0.0`, the comparison
    /// falls back to `total` so the cheaper plan still wins instead of the
    /// result being an arbitrary `Equal`. In every other case plans are
    /// ranked by `total` alone.
    ///
    /// Incomparable values (NaN) are treated as equal.
    pub fn prefer_over(
        &self,
        other: &PlanCost,
        limit: Option<u64>,
        cardinality: f64,
    ) -> std::cmp::Ordering {
        let by_total = self
            .total
            .partial_cmp(&other.total)
            .unwrap_or(std::cmp::Ordering::Equal);

        let use_startup = matches!(limit, Some(k) if (k as f64) < 0.1 * cardinality.max(1.0));
        if !use_startup {
            return by_total;
        }

        self.startup_cost
            .partial_cmp(&other.startup_cost)
            .unwrap_or(std::cmp::Ordering::Equal)
            // Tie on startup: let the overall cost decide.
            .then(by_total)
    }
}
197
/// Table-level statistics consumed by the cost estimator.
#[derive(Debug, Clone, Default)]
pub struct TableStats {
    /// Number of rows in the table; used as the base row count for table
    /// cardinality estimates.
    pub row_count: u64,
    /// Average row size. NOTE(review): presumably in bytes; the unit is not
    /// established by the code visible here.
    pub avg_row_size: u32,
    /// Number of pages occupied by the table; used as the heap page count
    /// when estimating table I/O.
    pub page_count: u64,
    /// Per-column statistics.
    pub columns: Vec<ColumnStats>,
}
210
/// Per-column statistics carried inside [`TableStats`].
///
/// NOTE(review): none of these fields are read by the code visible in this
/// part of the file, so the descriptions below follow the field names only.
#[derive(Debug, Clone, Default)]
pub struct ColumnStats {
    /// Column name.
    pub name: String,
    /// Number of distinct values in the column.
    pub distinct_count: u64,
    /// Number of NULL values in the column.
    pub null_count: u64,
    /// Minimum value, stored as a string, if known.
    pub min_value: Option<String>,
    /// Maximum value, stored as a string, if known.
    pub max_value: Option<String>,
    /// Whether an index exists on the column.
    pub has_index: bool,
}
227
/// Heuristic cost model for query plans.
///
/// Holds the per-operation cost constants together with the statistics
/// source that the estimation methods consult.
pub struct CostEstimator {
    /// Row count assumed when the stats provider has no entry for a table;
    /// also the base row count for graph pattern estimates.
    default_row_count: f64,
    /// CPU cost charged per row by the table estimator.
    row_scan_cost: f64,
    /// Cost of an index lookup.
    /// NOTE(review): not read by any estimator visible in this part of the
    /// file; confirm it is still used.
    index_lookup_cost: f64,
    /// CPU cost per row for the build and probe sides of the join estimate.
    hash_probe_cost: f64,
    /// Cost of a nested-loop step.
    /// NOTE(review): not read by any estimator visible in this part of the
    /// file; confirm it is still used.
    nested_loop_cost: f64,
    /// CPU cost factor used by the graph and path estimators.
    edge_traversal_cost: f64,
    /// Source of table, index, most-common-value and histogram statistics.
    stats: Arc<dyn StatsProvider>,
}
248
249impl CostEstimator {
250 pub fn new() -> Self {
253 Self {
254 default_row_count: 1000.0,
255 row_scan_cost: 1.0,
256 index_lookup_cost: 0.1,
257 hash_probe_cost: 0.5,
258 nested_loop_cost: 2.0,
259 edge_traversal_cost: 1.5,
260 stats: Arc::new(NullProvider),
261 }
262 }
263
264 pub fn with_stats(provider: Arc<dyn StatsProvider>) -> Self {
268 Self {
269 stats: provider,
270 ..Self::new()
271 }
272 }
273
274 pub fn set_stats(&mut self, provider: Arc<dyn StatsProvider>) {
278 self.stats = provider;
279 }
280
    /// Estimates the execution cost of `query`.
    ///
    /// Table, graph, join, path, vector and hybrid queries are dispatched to
    /// their dedicated estimators. Every other statement kind is assigned
    /// the same flat cost, `PlanCost::new(1.0, 1.0, 0.0)`.
    pub fn estimate(&self, query: &QueryExpr) -> PlanCost {
        match query {
            QueryExpr::Table(tq) => self.estimate_table(tq),
            QueryExpr::Graph(gq) => self.estimate_graph(gq),
            QueryExpr::Join(jq) => self.estimate_join(jq),
            QueryExpr::Path(pq) => self.estimate_path(pq),
            QueryExpr::Vector(vq) => self.estimate_vector(vq),
            QueryExpr::Hybrid(hq) => self.estimate_hybrid(hq),
            // The remaining variants are listed explicitly rather than via
            // a `_` arm, so adding a new `QueryExpr` variant is a compile
            // error here until it is given a cost.
            QueryExpr::Insert(_)
            | QueryExpr::Update(_)
            | QueryExpr::Delete(_)
            | QueryExpr::CreateTable(_)
            | QueryExpr::CreateCollection(_)
            | QueryExpr::CreateVector(_)
            | QueryExpr::DropTable(_)
            | QueryExpr::DropGraph(_)
            | QueryExpr::DropVector(_)
            | QueryExpr::DropDocument(_)
            | QueryExpr::DropKv(_)
            | QueryExpr::DropCollection(_)
            | QueryExpr::Truncate(_)
            | QueryExpr::AlterTable(_)
            | QueryExpr::GraphCommand(_)
            | QueryExpr::SearchCommand(_)
            | QueryExpr::CreateIndex(_)
            | QueryExpr::DropIndex(_)
            | QueryExpr::ProbabilisticCommand(_)
            | QueryExpr::Ask(_)
            | QueryExpr::SetConfig { .. }
            | QueryExpr::ShowConfig { .. }
            | QueryExpr::SetSecret { .. }
            | QueryExpr::DeleteSecret { .. }
            | QueryExpr::ShowSecrets { .. }
            | QueryExpr::SetTenant(_)
            | QueryExpr::ShowTenant
            | QueryExpr::CreateTimeSeries(_)
            | QueryExpr::CreateMetric(_)
            | QueryExpr::AlterMetric(_)
            | QueryExpr::CreateSlo(_)
            | QueryExpr::DropTimeSeries(_)
            | QueryExpr::CreateQueue(_)
            | QueryExpr::AlterQueue(_)
            | QueryExpr::DropQueue(_)
            | QueryExpr::QueueSelect(_)
            | QueryExpr::QueueCommand(_)
            | QueryExpr::KvCommand(_)
            | QueryExpr::ConfigCommand(_)
            | QueryExpr::CreateTree(_)
            | QueryExpr::DropTree(_)
            | QueryExpr::TreeCommand(_)
            | QueryExpr::ExplainAlter(_)
            | QueryExpr::TransactionControl(_)
            | QueryExpr::MaintenanceCommand(_)
            | QueryExpr::CreateSchema(_)
            | QueryExpr::DropSchema(_)
            | QueryExpr::CreateSequence(_)
            | QueryExpr::DropSequence(_)
            | QueryExpr::CopyFrom(_)
            | QueryExpr::CreateView(_)
            | QueryExpr::DropView(_)
            | QueryExpr::RefreshMaterializedView(_)
            | QueryExpr::CreatePolicy(_)
            | QueryExpr::DropPolicy(_)
            | QueryExpr::CreateServer(_)
            | QueryExpr::DropServer(_)
            | QueryExpr::CreateForeignTable(_)
            | QueryExpr::DropForeignTable(_)
            | QueryExpr::Grant(_)
            | QueryExpr::Revoke(_)
            | QueryExpr::AlterUser(_)
            | QueryExpr::CreateUser(_)
            | QueryExpr::CreateIamPolicy { .. }
            | QueryExpr::DropIamPolicy { .. }
            | QueryExpr::AttachPolicy { .. }
            | QueryExpr::DetachPolicy { .. }
            | QueryExpr::ShowPolicies { .. }
            | QueryExpr::ShowEffectivePermissions { .. }
            | QueryExpr::RankOf(_)
            | QueryExpr::ApproxRankOf(_)
            | QueryExpr::RankRange(_)
            | QueryExpr::SimulatePolicy { .. }
            | QueryExpr::LintPolicy { .. }
            | QueryExpr::MigratePolicyMode { .. }
            | QueryExpr::CreateMigration(_)
            | QueryExpr::ApplyMigration(_)
            | QueryExpr::RollbackMigration(_)
            | QueryExpr::ExplainMigration(_)
            | QueryExpr::EventsBackfill(_)
            | QueryExpr::EventsBackfillStatus { .. } => PlanCost::new(1.0, 1.0, 0.0),
        }
    }
374
    /// Estimates how many rows `query` produces.
    ///
    /// Table, graph, join, path, vector and hybrid queries are dispatched to
    /// their dedicated cardinality estimators. Every other statement kind is
    /// reported as a single row with selectivity `1.0`.
    pub fn estimate_cardinality(&self, query: &QueryExpr) -> CardinalityEstimate {
        match query {
            QueryExpr::Table(tq) => self.estimate_table_cardinality(tq),
            QueryExpr::Graph(gq) => self.estimate_graph_cardinality(gq),
            QueryExpr::Join(jq) => self.estimate_join_cardinality(jq),
            QueryExpr::Path(pq) => self.estimate_path_cardinality(pq),
            QueryExpr::Vector(vq) => self.estimate_vector_cardinality(vq),
            QueryExpr::Hybrid(hq) => self.estimate_hybrid_cardinality(hq),
            // The remaining variants are listed explicitly rather than via
            // a `_` arm, so adding a new `QueryExpr` variant is a compile
            // error here until it is given an estimate.
            QueryExpr::Insert(_)
            | QueryExpr::Update(_)
            | QueryExpr::Delete(_)
            | QueryExpr::CreateTable(_)
            | QueryExpr::CreateCollection(_)
            | QueryExpr::CreateVector(_)
            | QueryExpr::DropTable(_)
            | QueryExpr::DropGraph(_)
            | QueryExpr::DropVector(_)
            | QueryExpr::DropDocument(_)
            | QueryExpr::DropKv(_)
            | QueryExpr::DropCollection(_)
            | QueryExpr::Truncate(_)
            | QueryExpr::AlterTable(_)
            | QueryExpr::GraphCommand(_)
            | QueryExpr::SearchCommand(_)
            | QueryExpr::CreateIndex(_)
            | QueryExpr::DropIndex(_)
            | QueryExpr::ProbabilisticCommand(_)
            | QueryExpr::Ask(_)
            | QueryExpr::SetConfig { .. }
            | QueryExpr::ShowConfig { .. }
            | QueryExpr::SetSecret { .. }
            | QueryExpr::DeleteSecret { .. }
            | QueryExpr::ShowSecrets { .. }
            | QueryExpr::SetTenant(_)
            | QueryExpr::ShowTenant
            | QueryExpr::CreateTimeSeries(_)
            | QueryExpr::CreateMetric(_)
            | QueryExpr::AlterMetric(_)
            | QueryExpr::CreateSlo(_)
            | QueryExpr::DropTimeSeries(_)
            | QueryExpr::CreateQueue(_)
            | QueryExpr::AlterQueue(_)
            | QueryExpr::DropQueue(_)
            | QueryExpr::QueueSelect(_)
            | QueryExpr::QueueCommand(_)
            | QueryExpr::KvCommand(_)
            | QueryExpr::ConfigCommand(_)
            | QueryExpr::CreateTree(_)
            | QueryExpr::DropTree(_)
            | QueryExpr::TreeCommand(_)
            | QueryExpr::ExplainAlter(_)
            | QueryExpr::TransactionControl(_)
            | QueryExpr::MaintenanceCommand(_)
            | QueryExpr::CreateSchema(_)
            | QueryExpr::DropSchema(_)
            | QueryExpr::CreateSequence(_)
            | QueryExpr::DropSequence(_)
            | QueryExpr::CopyFrom(_)
            | QueryExpr::CreateView(_)
            | QueryExpr::DropView(_)
            | QueryExpr::RefreshMaterializedView(_)
            | QueryExpr::CreatePolicy(_)
            | QueryExpr::DropPolicy(_)
            | QueryExpr::CreateServer(_)
            | QueryExpr::DropServer(_)
            | QueryExpr::CreateForeignTable(_)
            | QueryExpr::DropForeignTable(_)
            | QueryExpr::Grant(_)
            | QueryExpr::Revoke(_)
            | QueryExpr::AlterUser(_)
            | QueryExpr::CreateUser(_)
            | QueryExpr::CreateIamPolicy { .. }
            | QueryExpr::DropIamPolicy { .. }
            | QueryExpr::AttachPolicy { .. }
            | QueryExpr::DetachPolicy { .. }
            | QueryExpr::ShowPolicies { .. }
            | QueryExpr::ShowEffectivePermissions { .. }
            | QueryExpr::RankOf(_)
            | QueryExpr::ApproxRankOf(_)
            | QueryExpr::RankRange(_)
            | QueryExpr::SimulatePolicy { .. }
            | QueryExpr::LintPolicy { .. }
            | QueryExpr::MigratePolicyMode { .. }
            | QueryExpr::CreateMigration(_)
            | QueryExpr::ApplyMigration(_)
            | QueryExpr::RollbackMigration(_)
            | QueryExpr::ExplainMigration(_)
            | QueryExpr::EventsBackfill(_)
            | QueryExpr::EventsBackfillStatus { .. } => CardinalityEstimate::new(1.0, 1.0),
        }
    }
468
469 fn estimate_table(&self, query: &TableQuery) -> PlanCost {
474 let cardinality = self.estimate_table_cardinality(query);
475
476 let cpu = cardinality.rows * self.row_scan_cost;
477
478 let io = self.estimate_table_io(query, cardinality.rows);
481
482 let memory = cardinality.rows * 100.0; PlanCost::new(cpu, io, memory)
485 }
486
487 fn estimate_table_io(&self, query: &TableQuery, result_rows: f64) -> f64 {
494 const ROWS_PER_PAGE: f64 = 100.0;
495
496 let table_stats = self.stats.table_stats(&query.table);
498 let heap_pages = table_stats
499 .map(|s| s.page_count as f64)
500 .unwrap_or_else(|| (result_rows / ROWS_PER_PAGE).max(1.0));
501
502 if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
505 if let Some(col) = first_filter_column(&filter, &query.table) {
506 if let Some(idx) = self.stats.index_stats(&query.table, col) {
507 return idx.correlated_io_cost(result_rows, heap_pages);
508 }
509 }
510 }
511
512 (result_rows / ROWS_PER_PAGE).ceil()
514 }
515
516 fn estimate_table_cardinality(&self, query: &TableQuery) -> CardinalityEstimate {
517 let base_rows = self
520 .stats
521 .table_stats(&query.table)
522 .map(|s| s.row_count as f64)
523 .unwrap_or(self.default_row_count);
524
525 let mut estimate = CardinalityEstimate::full_scan(base_rows);
526
527 if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
530 let selectivity = self.filter_selectivity(&filter, &query.table);
531 estimate = estimate.with_filter(selectivity);
532 }
533
534 if let Some(limit) = query.limit {
536 estimate.rows = estimate.rows.min(limit as f64);
537 }
538
539 estimate
540 }
541
542 fn filter_selectivity(&self, filter: &AstFilter, table: &str) -> f64 {
555 match filter {
556 AstFilter::Compare { field, op, value } => {
557 let column = column_name_for_table(field, table);
558 match op {
559 CompareOp::Eq => self.eq_selectivity(table, column, value),
560 CompareOp::Ne => 1.0 - self.eq_selectivity(table, column, value),
561 CompareOp::Lt | CompareOp::Le => {
562 self.range_selectivity(table, column, None, Some(value))
563 }
564 CompareOp::Gt | CompareOp::Ge => {
565 self.range_selectivity(table, column, Some(value), None)
566 }
567 }
568 }
569 AstFilter::Between {
570 field, low, high, ..
571 } => {
572 let column = column_name_for_table(field, table);
573 self.range_selectivity(table, column, Some(low), Some(high))
574 }
575 AstFilter::In { field, values, .. } => {
576 let column = column_name_for_table(field, table);
577 if let Some(c) = column {
581 if let Some(mcv) = self.stats.column_mcv(table, c) {
582 let mut hits: f64 = 0.0;
583 let mut residual_count = 0usize;
584 for v in values {
585 if let Some(cv) = column_value_from(v) {
586 if let Some(freq) = mcv.frequency_of(&cv) {
587 hits += freq;
588 } else {
589 residual_count += 1;
590 }
591 } else {
592 residual_count += 1;
593 }
594 }
595 let total = mcv.total_frequency();
596 let distinct = self.stats.distinct_values(table, c).unwrap_or(100);
597 let non_mcv_distinct =
598 distinct.saturating_sub(mcv.values.len() as u64).max(1);
599 let per_residual = (1.0 - total) / non_mcv_distinct as f64;
600 let estimate = hits + (residual_count as f64) * per_residual;
601 return estimate.clamp(0.0, 1.0).min(0.5);
602 }
603 if let Some(s) = self.stats.index_stats(table, c) {
604 return (s.point_selectivity() * values.len() as f64).min(0.5);
605 }
606 }
607 (values.len() as f64 * 0.01).min(0.5)
608 }
609 AstFilter::Like { .. } => 0.1,
610 AstFilter::StartsWith { .. } => 0.15,
611 AstFilter::EndsWith { .. } => 0.15,
612 AstFilter::Contains { .. } => 0.1,
613 AstFilter::IsNull { .. } => 0.01,
614 AstFilter::IsNotNull { .. } => 0.99,
615 AstFilter::And(left, right) => {
616 self.filter_selectivity(left, table) * self.filter_selectivity(right, table)
617 }
618 AstFilter::Or(left, right) => {
619 let s1 = self.filter_selectivity(left, table);
620 let s2 = self.filter_selectivity(right, table);
621 s1 + s2 - (s1 * s2)
622 }
623 AstFilter::Not(inner) => 1.0 - self.filter_selectivity(inner, table),
624 AstFilter::CompareFields { .. } => {
625 0.1
629 }
630 AstFilter::CompareExpr { .. } => {
631 0.1
635 }
636 }
637 }
638
639 fn estimate_graph(&self, query: &GraphQuery) -> PlanCost {
644 let cardinality = self.estimate_graph_cardinality(query);
645
646 let nodes = query.pattern.nodes.len() as f64;
648 let edges = query.pattern.edges.len() as f64;
649
650 let cpu = cardinality.rows * self.edge_traversal_cost * (nodes + edges);
651 let io = cardinality.rows * 0.1; let memory = cardinality.rows * 200.0; PlanCost::new(cpu, io, memory)
655 }
656
657 fn estimate_graph_cardinality(&self, query: &GraphQuery) -> CardinalityEstimate {
658 let nodes = query.pattern.nodes.len() as f64;
659 let edges = query.pattern.edges.len() as f64;
660
661 let base_rows = self.default_row_count;
663 let edge_factor = 0.1_f64.powf(edges); let mut estimate = CardinalityEstimate::new(base_rows * nodes * edge_factor, edge_factor);
666 estimate.confidence = 0.5; if let Some(ref filter) = query.filter {
670 let selectivity = Self::estimate_filter_selectivity(filter);
671 estimate = estimate.with_filter(selectivity);
672 }
673
674 estimate
675 }
676
677 fn estimate_join(&self, query: &JoinQuery) -> PlanCost {
682 let left_cost = self.estimate(&query.left);
683 let right_cost = self.estimate(&query.right);
684
685 let left_card = self.estimate_cardinality(&query.left);
686 let right_card = self.estimate_cardinality(&query.right);
687
688 let build_cpu = left_card.rows * self.hash_probe_cost;
695 let probe_cpu = right_card.rows * self.hash_probe_cost;
696 let join_memory = left_card.rows * 100.0; let build_op = PlanCost::with_startup(build_cpu, 0.0, join_memory, build_cpu);
700 let probe_op = PlanCost::new(probe_cpu, 0.0, 0.0);
702
703 let after_build = left_cost.combine_blocking(&build_op);
705 after_build
706 .combine_pipelined(&right_cost)
707 .combine_pipelined(&probe_op)
708 }
709
710 fn estimate_join_cardinality(&self, query: &JoinQuery) -> CardinalityEstimate {
711 let left = self.estimate_cardinality(&query.left);
712 let right = self.estimate_cardinality(&query.right);
713
714 let selectivity = match query.join_type {
716 JoinType::Inner => 0.1, JoinType::LeftOuter => 1.0, JoinType::RightOuter => 1.0, JoinType::FullOuter => 1.0, JoinType::Cross => 1.0, };
722
723 CardinalityEstimate::new(
724 left.rows * right.rows * selectivity,
725 left.selectivity * right.selectivity * selectivity,
726 )
727 }
728
729 fn estimate_path(&self, query: &PathQuery) -> PlanCost {
734 let cardinality = self.estimate_path_cardinality(query);
735
736 let max_hops = query.max_length;
738 let branching_factor: f64 = 5.0; let nodes_visited = branching_factor.powf(max_hops as f64).min(10000.0);
741 let cpu = nodes_visited * self.edge_traversal_cost;
742 let io = nodes_visited * 0.1;
743 let memory = nodes_visited * 50.0; PlanCost::new(cpu, io, memory)
746 }
747
748 fn estimate_path_cardinality(&self, query: &PathQuery) -> CardinalityEstimate {
749 let max_paths = 10.0;
751 CardinalityEstimate::new(max_paths, 0.001)
752 }
753
754 fn estimate_vector(&self, query: &VectorQuery) -> PlanCost {
759 let k = query.k as f64;
762
763 let hnsw_cost = 100.0 * (1.0 + k.ln()); let filter_cost =
770 if crate::storage::query::sql_lowering::effective_vector_filter(query).is_some() {
771 50.0
772 } else {
773 0.0
774 };
775
776 let cpu = hnsw_cost + filter_cost;
777 let io = 20.0; let memory = k * 32.0 + 1000.0; PlanCost::with_startup(cpu, io, memory, hnsw_cost * 0.5)
785 }
786
787 fn estimate_vector_cardinality(&self, query: &VectorQuery) -> CardinalityEstimate {
788 let k = query.k as f64;
790 CardinalityEstimate::new(k, 0.1)
791 }
792
793 fn estimate_hybrid(&self, query: &HybridQuery) -> PlanCost {
798 let structured_cost = self.estimate(&query.structured);
800 let vector_cost = self.estimate_vector(&query.vector);
801
802 let fusion_overhead = match &query.fusion {
804 crate::storage::query::ast::FusionStrategy::Rerank { .. } => 50.0,
805 crate::storage::query::ast::FusionStrategy::FilterThenSearch => 10.0,
806 crate::storage::query::ast::FusionStrategy::SearchThenFilter => 10.0,
807 crate::storage::query::ast::FusionStrategy::RRF { .. } => 30.0,
808 crate::storage::query::ast::FusionStrategy::Intersection => 20.0,
809 crate::storage::query::ast::FusionStrategy::Union { .. } => 40.0,
810 };
811
812 PlanCost::new(
813 structured_cost.cpu + vector_cost.cpu + fusion_overhead,
814 structured_cost.io + vector_cost.io,
815 structured_cost.memory + vector_cost.memory,
816 )
817 }
818
819 fn estimate_hybrid_cardinality(&self, query: &HybridQuery) -> CardinalityEstimate {
820 let structured_card = self.estimate_cardinality(&query.structured);
821 let vector_card = self.estimate_vector_cardinality(&query.vector);
822
823 let rows = match &query.fusion {
825 crate::storage::query::ast::FusionStrategy::Intersection => {
826 structured_card.rows.min(vector_card.rows)
827 }
828 crate::storage::query::ast::FusionStrategy::Union { .. } => {
829 structured_card.rows + vector_card.rows
830 }
831 _ => vector_card.rows, };
833
834 CardinalityEstimate::new(rows, 0.2)
835 }
836
837 fn estimate_filter_selectivity(filter: &AstFilter) -> f64 {
842 match filter {
843 AstFilter::Compare { op, .. } => {
844 match op {
845 CompareOp::Eq => 0.01, CompareOp::Ne => 0.99, CompareOp::Lt | CompareOp::Le => 0.3,
848 CompareOp::Gt | CompareOp::Ge => 0.3,
849 }
850 }
851 AstFilter::Between { .. } => 0.25,
852 AstFilter::In { values, .. } => {
853 (values.len() as f64 * 0.01).min(0.5)
855 }
856 AstFilter::Like { .. } => 0.1,
857 AstFilter::StartsWith { .. } => 0.15,
858 AstFilter::EndsWith { .. } => 0.15,
859 AstFilter::Contains { .. } => 0.1,
860 AstFilter::IsNull { .. } => 0.01,
861 AstFilter::IsNotNull { .. } => 0.99,
862 AstFilter::And(left, right) => {
863 Self::estimate_filter_selectivity(left) * Self::estimate_filter_selectivity(right)
864 }
865 AstFilter::Or(left, right) => {
866 let s1 = Self::estimate_filter_selectivity(left);
867 let s2 = Self::estimate_filter_selectivity(right);
868 s1 + s2 - (s1 * s2) }
870 AstFilter::Not(inner) => 1.0 - Self::estimate_filter_selectivity(inner),
871 AstFilter::CompareFields { .. } => 0.1,
872 AstFilter::CompareExpr { .. } => 0.1,
873 }
874 }
875}
876
877impl CostEstimator {
878 fn eq_selectivity(&self, table: &str, column: Option<&str>, value: &Value) -> f64 {
886 if let Some(col) = column {
887 if let Some(mcv) = self.stats.column_mcv(table, col) {
889 if let Some(cv) = column_value_from(value) {
890 if let Some(freq) = mcv.frequency_of(&cv) {
891 return freq;
892 }
893 let total = mcv.total_frequency();
895 let distinct = self.stats.distinct_values(table, col).unwrap_or(100);
896 let non_mcv_distinct = distinct.saturating_sub(mcv.values.len() as u64).max(1);
897 return ((1.0 - total) / non_mcv_distinct as f64).clamp(0.0, 1.0);
898 }
899 }
900 if let Some(s) = self.stats.index_stats(table, col) {
902 return s.point_selectivity();
903 }
904 }
905 0.01
907 }
908
909 fn range_selectivity(
920 &self,
921 table: &str,
922 column: Option<&str>,
923 lo: Option<&Value>,
924 hi: Option<&Value>,
925 ) -> f64 {
926 if let Some(col) = column {
927 if let Some(h) = self.stats.column_histogram(table, col) {
929 let lo_cv = lo.and_then(column_value_from);
930 let hi_cv = hi.and_then(column_value_from);
931 return h.range_selectivity(lo_cv.as_ref(), hi_cv.as_ref());
932 }
933 if let Some(s) = self.stats.index_stats(table, col) {
935 let cap = if lo.is_some() && hi.is_some() {
936 0.25
937 } else {
938 0.3
939 };
940 return (s.point_selectivity() * (s.distinct_keys as f64 / 2.0)).min(cap);
941 }
942 }
943 if lo.is_some() && hi.is_some() {
945 0.25
946 } else {
947 0.3
948 }
949 }
950}
951
952impl Default for CostEstimator {
953 fn default() -> Self {
954 Self::new()
955 }
956}
957
958fn column_value_from(v: &crate::storage::schema::Value) -> Option<super::histogram::ColumnValue> {
963 use super::histogram::ColumnValue;
964 use crate::storage::schema::Value;
965 match v {
966 Value::Integer(i) | Value::BigInt(i) => Some(ColumnValue::Int(*i)),
967 Value::UnsignedInteger(u) => Some(ColumnValue::Int(*u as i64)),
968 Value::Float(f) if f.is_finite() => Some(ColumnValue::Float(*f)),
969 Value::Text(s) => Some(ColumnValue::Text(s.to_string())),
970 Value::Email(s)
971 | Value::Url(s)
972 | Value::NodeRef(s)
973 | Value::EdgeRef(s)
974 | Value::TableRef(s)
975 | Value::Password(s) => Some(ColumnValue::Text(s.clone())),
976 Value::Timestamp(t) => Some(ColumnValue::Int(*t)),
977 Value::Duration(d) => Some(ColumnValue::Int(*d)),
978 Value::TimestampMs(t) => Some(ColumnValue::Int(*t)),
979 Value::Decimal(d) => Some(ColumnValue::Int(*d)),
980 Value::Date(d) => Some(ColumnValue::Int(i64::from(*d))),
981 Value::Time(t) => Some(ColumnValue::Int(i64::from(*t))),
982 Value::Phone(p) => Some(ColumnValue::Int(*p as i64)),
983 Value::Semver(v) => Some(ColumnValue::Int(i64::from(*v))),
984 Value::Port(v) => Some(ColumnValue::Int(i64::from(*v))),
985 Value::PageRef(v) => Some(ColumnValue::Int(i64::from(*v))),
986 Value::EnumValue(v) => Some(ColumnValue::Int(i64::from(*v))),
987 Value::Latitude(v) => Some(ColumnValue::Int(i64::from(*v))),
988 Value::Longitude(v) => Some(ColumnValue::Int(i64::from(*v))),
989 _ => None,
994 }
995}
996
997fn first_filter_column<'a>(filter: &'a AstFilter, table: &str) -> Option<&'a str> {
1002 match filter {
1003 AstFilter::Compare { field, .. } => column_name_for_table(field, table),
1004 AstFilter::Between { field, .. } => column_name_for_table(field, table),
1005 AstFilter::And(l, r) => {
1006 first_filter_column(l, table).or_else(|| first_filter_column(r, table))
1007 }
1008 _ => None,
1009 }
1010}
1011
1012fn column_name_for_table<'a>(field: &'a FieldRef, table: &str) -> Option<&'a str> {
1014 match field {
1015 FieldRef::TableColumn { table: t, column } if t == table || t.is_empty() => {
1016 Some(column.as_str())
1017 }
1018 _ => None,
1020 }
1021}
1022
1023#[cfg(test)]
1024mod tests {
1025 use super::super::stats_provider::StaticProvider;
1026 use super::*;
1027 use crate::storage::index::{IndexKind, IndexStats};
1028 use crate::storage::query::ast::{FieldRef, Projection};
1029 use crate::storage::schema::Value;
1030
1031 fn eq_filter(table: &str, column: &str, value: i64) -> AstFilter {
1032 AstFilter::Compare {
1033 field: FieldRef::column(table, column),
1034 op: CompareOp::Eq,
1035 value: Value::Integer(value),
1036 }
1037 }
1038
1039 fn table_query(name: &str, filter: Option<AstFilter>) -> TableQuery {
1040 TableQuery {
1041 table: name.to_string(),
1042 source: None,
1043 alias: None,
1044 select_items: Vec::new(),
1045 columns: vec![Projection::All],
1046 where_expr: None,
1047 filter,
1048 group_by_exprs: Vec::new(),
1049 group_by: Vec::new(),
1050 having_expr: None,
1051 having: None,
1052 order_by: vec![],
1053 limit: None,
1054 limit_param: None,
1055 offset: None,
1056 offset_param: None,
1057 expand: None,
1058 as_of: None,
1059 sessionize: None,
1060 distinct: false,
1061 }
1062 }
1063
1064 #[test]
1065 fn injected_row_count_overrides_default() {
1066 let provider = Arc::new(StaticProvider::new().with_table(
1067 "users",
1068 TableStats {
1069 row_count: 50_000,
1070 avg_row_size: 256,
1071 page_count: 500,
1072 columns: vec![],
1073 },
1074 ));
1075 let estimator = CostEstimator::with_stats(provider);
1076 let q = table_query("users", None);
1077 let card = estimator.estimate_table_cardinality(&q);
1078 assert_eq!(card.rows, 50_000.0);
1080 }
1081
1082 #[test]
1083 fn stats_aware_eq_selectivity_beats_default() {
1084 let provider = Arc::new(
1085 StaticProvider::new()
1086 .with_table(
1087 "users",
1088 TableStats {
1089 row_count: 1_000_000,
1090 avg_row_size: 256,
1091 page_count: 10_000,
1092 columns: vec![],
1093 },
1094 )
1095 .with_index(
1096 "users",
1097 "email",
1098 IndexStats {
1099 entries: 1_000_000,
1100 distinct_keys: 1_000_000,
1101 approx_bytes: 0,
1102 kind: IndexKind::Hash,
1103 has_bloom: true,
1104 index_correlation: 0.0,
1105 },
1106 ),
1107 );
1108 let estimator = CostEstimator::with_stats(provider);
1109 let q = table_query("users", Some(eq_filter("users", "email", 0)));
1110 let card = estimator.estimate_table_cardinality(&q);
1111 assert!(card.rows < 2.0, "expected ~1 row, got {}", card.rows);
1113 }
1114
1115 #[test]
1116 fn fallback_when_no_index_stats() {
1117 let provider = Arc::new(StaticProvider::new().with_table(
1118 "users",
1119 TableStats {
1120 row_count: 1_000_000,
1121 avg_row_size: 256,
1122 page_count: 10_000,
1123 columns: vec![],
1124 },
1125 ));
1126 let estimator = CostEstimator::with_stats(provider);
1127 let q = table_query("users", Some(eq_filter("users", "email", 0)));
1128 let card = estimator.estimate_table_cardinality(&q);
1129 assert!((card.rows - 10_000.0).abs() < 1.0);
1131 }
1132
1133 #[test]
1134 fn null_provider_keeps_legacy_behaviour() {
1135 let estimator = CostEstimator::new();
1136 let q = table_query("whatever", Some(eq_filter("whatever", "id", 1)));
1137 let card = estimator.estimate_table_cardinality(&q);
1138 assert!((card.rows - 10.0).abs() < 1.0);
1140 }
1141
1142 #[test]
1143 fn and_combines_stats_selectivities() {
1144 let provider = Arc::new(
1145 StaticProvider::new()
1146 .with_table(
1147 "t",
1148 TableStats {
1149 row_count: 100_000,
1150 avg_row_size: 64,
1151 page_count: 100,
1152 columns: vec![],
1153 },
1154 )
1155 .with_index(
1156 "t",
1157 "a",
1158 IndexStats {
1159 entries: 100_000,
1160 distinct_keys: 10,
1161 approx_bytes: 0,
1162 kind: IndexKind::BTree,
1163 has_bloom: false,
1164 index_correlation: 0.0,
1165 },
1166 )
1167 .with_index(
1168 "t",
1169 "b",
1170 IndexStats {
1171 entries: 100_000,
1172 distinct_keys: 1000,
1173 approx_bytes: 0,
1174 kind: IndexKind::BTree,
1175 has_bloom: false,
1176 index_correlation: 0.0,
1177 },
1178 ),
1179 );
1180 let estimator = CostEstimator::with_stats(provider);
1181 let filter = AstFilter::And(
1182 Box::new(eq_filter("t", "a", 1)),
1183 Box::new(eq_filter("t", "b", 1)),
1184 );
1185 let q = table_query("t", Some(filter));
1186 let card = estimator.estimate_table_cardinality(&q);
1187 assert!(card.rows < 15.0, "got {}", card.rows);
1189 }
1190
1191 #[test]
1192 fn test_table_cost_estimation() {
1193 let estimator = CostEstimator::new();
1194
1195 let query = QueryExpr::Table(TableQuery {
1196 table: "hosts".to_string(),
1197 source: None,
1198 alias: None,
1199 select_items: Vec::new(),
1200 columns: vec![Projection::All],
1201 where_expr: None,
1202 filter: None,
1203 group_by_exprs: Vec::new(),
1204 group_by: Vec::new(),
1205 having_expr: None,
1206 having: None,
1207 order_by: vec![],
1208 limit: None,
1209 limit_param: None,
1210 offset: None,
1211 offset_param: None,
1212 expand: None,
1213 as_of: None,
1214 sessionize: None,
1215 distinct: false,
1216 });
1217
1218 let cost = estimator.estimate(&query);
1219 assert!(cost.cpu > 0.0);
1220 assert!(cost.total > 0.0);
1221 }
1222
1223 #[test]
1224 fn test_filter_selectivity() {
1225 let estimator = CostEstimator::new();
1226
1227 let eq_filter = AstFilter::Compare {
1228 field: FieldRef::column("hosts", "id"),
1229 op: CompareOp::Eq,
1230 value: Value::Integer(1),
1231 };
1232 assert!(CostEstimator::estimate_filter_selectivity(&eq_filter) < 0.1);
1233
1234 let ne_filter = AstFilter::Compare {
1235 field: FieldRef::column("hosts", "id"),
1236 op: CompareOp::Ne,
1237 value: Value::Integer(1),
1238 };
1239 assert!(CostEstimator::estimate_filter_selectivity(&ne_filter) > 0.9);
1240 }
1241
1242 #[test]
1243 fn test_and_selectivity() {
1244 let estimator = CostEstimator::new();
1245
1246 let and_filter = AstFilter::And(
1247 Box::new(AstFilter::Compare {
1248 field: FieldRef::column("hosts", "a"),
1249 op: CompareOp::Eq,
1250 value: Value::Integer(1),
1251 }),
1252 Box::new(AstFilter::Compare {
1253 field: FieldRef::column("hosts", "b"),
1254 op: CompareOp::Eq,
1255 value: Value::Integer(2),
1256 }),
1257 );
1258
1259 let selectivity = CostEstimator::estimate_filter_selectivity(&and_filter);
1260 assert!(selectivity < 0.01); }
1262
1263 #[test]
1264 fn test_cardinality_with_limit() {
1265 let estimator = CostEstimator::new();
1266
1267 let query = TableQuery {
1268 table: "hosts".to_string(),
1269 source: None,
1270 alias: None,
1271 select_items: Vec::new(),
1272 columns: vec![Projection::All],
1273 where_expr: None,
1274 filter: None,
1275 group_by_exprs: Vec::new(),
1276 group_by: Vec::new(),
1277 having_expr: None,
1278 having: None,
1279 order_by: vec![],
1280 limit: Some(10),
1281 limit_param: None,
1282 offset: None,
1283 offset_param: None,
1284 expand: None,
1285 as_of: None,
1286 sessionize: None,
1287 distinct: false,
1288 };
1289
1290 let card = estimator.estimate_table_cardinality(&query);
1291 assert!(card.rows <= 10.0);
1292 }
1293
1294 #[test]
1299 fn startup_zero_for_full_scan() {
1300 let estimator = CostEstimator::new();
1304 let q = table_query("any_table", None);
1305 let cost = estimator.estimate(&QueryExpr::Table(q));
1306 assert_eq!(cost.startup_cost, 0.0, "full scan must have zero startup");
1307 assert!(cost.total > 0.0);
1308 }
1309
/// `combine_blocking` must fold the whole input cost into the startup cost.
///
/// Expectations checked, in order:
/// 1. combined startup equals the input's `total`;
/// 2. combined total equals the sum of both totals;
/// 3. combined startup is strictly positive.
#[test]
fn startup_nonzero_for_blocking_combine() {
    // `PlanCost::new` takes (cpu, io, memory) and leaves `startup_cost` at 0.
    let upstream = PlanCost::new(100.0, 10.0, 50.0);
    let blocking_op = PlanCost::new(20.0, 0.0, 10.0);

    let PlanCost {
        startup_cost,
        total,
        ..
    } = upstream.combine_blocking(&blocking_op);

    assert_eq!(startup_cost, upstream.total);
    assert_eq!(total, upstream.total + blocking_op.total);
    assert!(startup_cost > 0.0);
}
1323
/// `combine_pipelined` must add the two startup costs and the two totals.
#[test]
fn pipelined_combine_adds_startup_directly() {
    // Arguments to `with_startup` are (cpu, io, memory, startup_cost).
    let upstream = PlanCost::with_startup(50.0, 5.0, 10.0, 30.0);
    let downstream = PlanCost::with_startup(20.0, 0.0, 0.0, 5.0);

    let pipeline = upstream.combine_pipelined(&downstream);

    // Startup: the two explicit startup figures passed above.
    assert_eq!(pipeline.startup_cost, 30.0 + 5.0);
    // Total: plain sum of both operators' totals.
    assert_eq!(pipeline.total, upstream.total + downstream.total);
}
1332
/// Two plans with identical totals but different startup costs: with
/// `Some(10)` passed to `prefer_over`, the low-startup plan must order
/// first (`Ordering::Less`).
#[test]
fn cost_prefers_low_startup_when_limit_small() {
    use std::cmp::Ordering;

    let fast_first = PlanCost {
        cpu: 100.0,
        io: 10.0,
        network: 0.0,
        memory: 50.0,
        startup_cost: 5.0,
        total: 200.0,
    };
    // Same plan in every respect except a much higher startup cost.
    let slow_first = PlanCost {
        startup_cost: 150.0,
        ..fast_first.clone()
    };

    // NOTE(review): the second and third arguments are presumably the LIMIT
    // and a row-count figure — confirm against `PlanCost::prefer_over`.
    let verdict = fast_first.prefer_over(&slow_first, Some(10), 10_000.0);
    assert_eq!(verdict, Ordering::Less);
}
1359
/// With `None` passed as the limit, `prefer_over` must rank the plan with
/// the lower `total` first, even though its startup cost is higher.
#[test]
fn cost_prefers_low_total_when_no_limit() {
    use std::cmp::Ordering;

    // Cheaper overall, slower to produce its first row.
    let low_total = PlanCost {
        cpu: 50.0,
        io: 5.0,
        startup_cost: 30.0,
        total: 100.0,
        // network and memory stay at 0.0
        ..PlanCost::default()
    };
    // Twice the total, but a small startup cost.
    let high_total = PlanCost {
        cpu: 100.0,
        io: 10.0,
        startup_cost: 5.0,
        total: 200.0,
        ..PlanCost::default()
    };

    let verdict = low_total.prefer_over(&high_total, None, 10_000.0);
    assert_eq!(verdict, Ordering::Less);
}
1384
/// With `Some(5000)` and `10_000.0` passed to `prefer_over`, the plan with
/// the lower `total` must still win over the plan with the lower startup.
#[test]
fn limit_threshold_falls_back_to_total_when_limit_large() {
    use std::cmp::Ordering;

    let low_total = PlanCost {
        cpu: 50.0,
        io: 5.0,
        startup_cost: 30.0,
        total: 100.0,
        // network and memory stay at 0.0
        ..PlanCost::default()
    };
    let low_startup = PlanCost {
        cpu: 100.0,
        io: 10.0,
        startup_cost: 5.0,
        total: 200.0,
        ..PlanCost::default()
    };

    // NOTE(review): presumably 5000 of 10 000 rows is above whatever
    // limit/row-count ratio switches the comparison to startup cost —
    // confirm the threshold in `PlanCost::prefer_over`.
    let verdict = low_total.prefer_over(&low_startup, Some(5000), 10_000.0);
    assert_eq!(verdict, Ordering::Less);
}
1409
/// Combining a left input with a build step via `combine_blocking` must give
/// a startup cost of at least the left input's total, and a total that is
/// never below that startup cost.
#[test]
fn hash_join_startup_includes_build_cost() {
    let left_input = PlanCost::new(80.0, 8.0, 100.0);
    // The build step itself declares a startup cost of 50.
    let build_step = PlanCost::with_startup(50.0, 0.0, 200.0, 50.0);

    let joined = left_input.combine_blocking(&build_step);

    let startup_absorbs_left = joined.startup_cost >= left_input.total;
    assert!(
        startup_absorbs_left,
        "after-build startup ({}) must absorb left.total ({})",
        joined.startup_cost,
        left_input.total
    );
    assert!(joined.total >= joined.startup_cost);
}
1426
/// A hand-built cost with an explicit startup component must report a
/// startup cost that is positive yet strictly below its total.
///
/// NOTE(review): this test does not run the estimator on a vector query; it
/// only checks a `PlanCost` built directly with `with_startup`. The
/// estimator is constructed but otherwise unused.
#[test]
fn vector_search_reports_nonzero_startup() {
    // Underscore-prefixed: constructed only, never queried.
    let _estimator = CostEstimator::new();

    let vector_cost = PlanCost::with_startup(150.0, 20.0, 1320.0, 50.0);

    assert!(vector_cost.startup_cost > 0.0);
    assert!(vector_cost.startup_cost < vector_cost.total);
}
1440
/// `with_startup` must never return a total smaller than the startup cost.
///
/// The inputs are chosen so the raw weighted sum (cpu 1.0, no io, no memory)
/// is far below the requested startup cost of 100.0.
#[test]
fn with_startup_clamps_total_below_startup() {
    let PlanCost {
        startup_cost,
        total,
        ..
    } = PlanCost::with_startup(1.0, 0.0, 0.0, 100.0);

    assert!(total >= startup_cost);
}
1447
/// The derived `Default` for `PlanCost` must zero both `startup_cost` and
/// `total`.
#[test]
fn default_plancost_has_zero_startup() {
    let PlanCost {
        startup_cost,
        total,
        ..
    } = PlanCost::default();

    assert_eq!(startup_cost, 0.0);
    assert_eq!(total, 0.0);
}
1454
1455 use super::super::histogram::{ColumnValue, Histogram, MostCommonValues};
1460
1461 fn provider_with_skew() -> Arc<StaticProvider> {
1462 let mut sample: Vec<ColumnValue> = Vec::new();
1466 for i in 0..80 {
1467 sample.push(ColumnValue::Int(i % 10));
1468 }
1469 for i in 0..20 {
1470 sample.push(ColumnValue::Int(10 + i * 50));
1471 }
1472 let h = Histogram::equi_depth_from_sample(sample, 10);
1473
1474 let mcv = MostCommonValues::new(vec![
1475 (ColumnValue::Text("boss".to_string()), 0.5),
1476 (ColumnValue::Text("intern".to_string()), 0.05),
1477 ]);
1478
1479 Arc::new(
1480 StaticProvider::new()
1481 .with_table(
1482 "people",
1483 TableStats {
1484 row_count: 100_000,
1485 avg_row_size: 64,
1486 page_count: 100,
1487 columns: vec![],
1488 },
1489 )
1490 .with_histogram("people", "score", h)
1491 .with_mcv("people", "role", mcv),
1492 )
1493 }
1494
/// Equality against a value listed in the MCV fixture (`"boss"`, frequency
/// 0.5) must yield exactly that frequency as the selectivity.
#[test]
fn eq_uses_mcv_when_value_is_tracked() {
    let estimator = CostEstimator::with_stats(provider_with_skew());

    let role_is_boss = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Eq,
        value: Value::text("boss".to_string()),
    };

    let s = estimator.filter_selectivity(&role_is_boss, "people");
    let deviation = (s - 0.5).abs();
    assert!(
        deviation < 1e-9,
        "MCV-tracked equality should report exact frequency, got {s}"
    );
}
1512
/// Equality against a value that is absent from the MCV fixture (`"staff"`)
/// must yield a small but strictly positive selectivity (below 0.01).
#[test]
fn eq_uses_residual_for_non_mcv_value() {
    let estimator = CostEstimator::with_stats(provider_with_skew());

    let role_is_staff = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Eq,
        value: Value::text("staff".to_string()),
    };

    let s = estimator.filter_selectivity(&role_is_staff, "people");
    assert!(0.0 < s && s < 0.01, "residual eq should be tiny, got {s}");
}
1528
/// `!=` against an MCV-tracked value must be the complement of the tracked
/// frequency: `"boss"` is listed at 0.5, so the expected result is 0.5.
#[test]
fn ne_is_one_minus_eq_under_mcv() {
    let estimator = CostEstimator::with_stats(provider_with_skew());

    let role_is_not_boss = AstFilter::Compare {
        field: FieldRef::column("people", "role"),
        op: CompareOp::Ne,
        value: Value::text("boss".to_string()),
    };

    let s = estimator.filter_selectivity(&role_is_not_boss, "people");
    let deviation = (s - 0.5).abs();
    assert!(deviation < 1e-9, "Ne selectivity = 0.5, got {s}");
}
1542
/// `score <= 9` on the skewed fixture must produce a selectivity above 0.5.
///
/// In the fixture's sample, 80 of the 100 `score` values are at most 9, so a
/// histogram-driven estimate should clear the 0.5 bar comfortably.
#[test]
fn range_uses_histogram_when_present() {
    let estimator = CostEstimator::with_stats(provider_with_skew());

    let score_at_most_nine = AstFilter::Compare {
        field: FieldRef::column("people", "score"),
        op: CompareOp::Le,
        value: Value::Integer(9),
    };

    let s = estimator.filter_selectivity(&score_at_most_nine, "people");
    let above_half = s > 0.5;
    assert!(
        above_half,
        "histogram-based range selectivity should beat 0.3, got {s}"
    );
}
1560
/// `score BETWEEN 0 AND 9` on the skewed fixture must produce a selectivity
/// above 0.5 (80 of the fixture's 100 sample values fall in that range).
#[test]
fn between_uses_histogram() {
    let estimator = CostEstimator::with_stats(provider_with_skew());

    let score_in_cluster = AstFilter::Between {
        field: FieldRef::column("people", "score"),
        low: Value::Integer(0),
        high: Value::Integer(9),
    };

    let s = estimator.filter_selectivity(&score_in_cluster, "people");
    assert!(s > 0.5, "BETWEEN should use histogram too, got {s}");
}
1573
/// When the provider knows the table but has no histogram or MCV list for
/// the column, a `<` comparison must yield a selectivity of 0.3.
#[test]
fn graceful_fallback_when_histogram_absent() {
    // Table-level stats only; nothing is registered per column.
    let bare_stats = TableStats {
        row_count: 1000,
        avg_row_size: 64,
        page_count: 10,
        columns: vec![],
    };
    let provider = Arc::new(StaticProvider::new().with_table("people", bare_stats));
    let estimator = CostEstimator::with_stats(provider);

    let unknown_below_fifty = AstFilter::Compare {
        field: FieldRef::column("people", "unknown_col"),
        op: CompareOp::Lt,
        value: Value::Integer(50),
    };

    let s = estimator.filter_selectivity(&unknown_below_fifty, "people");
    let deviation = (s - 0.3).abs();
    assert!(deviation < 1e-9, "fallback heuristic 0.3, got {s}");
}
1596}