1use crate::query::plan::{
18 AggregateOp, BinaryOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp, LimitOp,
19 LogicalExpression, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, SkipOp, SortOp,
20 TripleComponent, TripleScanOp, UnaryOp, VectorJoinOp, VectorScanOp,
21};
22use std::collections::HashMap;
23
/// One bucket of an [`EquiDepthHistogram`], covering the half-open value range
/// `[lower_bound, upper_bound)`.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Inclusive lower bound of the bucket.
    pub lower_bound: f64,
    /// Exclusive upper bound of the bucket.
    pub upper_bound: f64,
    /// Number of rows whose value falls in this bucket.
    pub frequency: u64,
    /// Number of distinct values among those rows.
    pub distinct_count: u64,
}

impl HistogramBucket {
    /// Creates a bucket from its bounds and counts.
    #[must_use]
    pub fn new(lower_bound: f64, upper_bound: f64, frequency: u64, distinct_count: u64) -> Self {
        Self {
            lower_bound,
            upper_bound,
            frequency,
            distinct_count,
        }
    }

    /// Returns `upper_bound - lower_bound`.
    #[must_use]
    pub fn width(&self) -> f64 {
        self.upper_bound - self.lower_bound
    }

    /// Returns `true` if `value` lies in `[lower_bound, upper_bound)`.
    #[must_use]
    pub fn contains(&self, value: f64) -> bool {
        (self.lower_bound..self.upper_bound).contains(&value)
    }

    /// Returns the fraction of this bucket's width covered by the query range
    /// `[lower, upper)`, where `None` means unbounded on that side.
    ///
    /// The result is capped at 1.0. A zero-width bucket is treated as a point:
    /// 1.0 if the query range covers it, 0.0 otherwise.
    #[must_use]
    pub fn overlap_fraction(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        let effective_lower = lower.unwrap_or(self.lower_bound).max(self.lower_bound);
        let effective_upper = upper.unwrap_or(self.upper_bound).min(self.upper_bound);

        let bucket_width = self.width();
        if bucket_width <= 0.0 {
            let covered =
                effective_lower <= self.lower_bound && effective_upper >= self.upper_bound;
            return if covered { 1.0 } else { 0.0 };
        }

        let overlap = (effective_upper - effective_lower).max(0.0);
        (overlap / bucket_width).min(1.0)
    }
}

/// Equi-depth (equal-frequency) histogram over a numeric column.
///
/// Each bucket built by [`EquiDepthHistogram::build`] holds roughly the same
/// number of rows, so dense value ranges get narrow buckets.
#[derive(Debug, Clone)]
pub struct EquiDepthHistogram {
    /// Buckets as supplied to `new`; ascending and non-overlapping when built by `build`.
    buckets: Vec<HistogramBucket>,
    /// Sum of `frequency` over all buckets.
    total_rows: u64,
}

impl EquiDepthHistogram {
    /// Wraps pre-computed buckets; `total_rows` is the sum of their frequencies.
    #[must_use]
    pub fn new(buckets: Vec<HistogramBucket>) -> Self {
        let total_rows = buckets.iter().map(|b| b.frequency).sum();
        Self {
            buckets,
            total_rows,
        }
    }

    /// Builds a histogram with at most `num_buckets` buckets from `values`.
    ///
    /// `values` must already be sorted in ascending order. Each bucket receives
    /// about `ceil(values.len() / num_buckets)` rows, but a run of equal values is
    /// never split across two buckets. Without that rule a bucket's row count could
    /// include values its half-open range excludes, and zero-width buckets could
    /// appear. The last bucket's upper bound is the largest value plus 1.0, so the
    /// maximum lies inside that bucket's half-open range.
    ///
    /// Returns an empty histogram when `values` is empty or `num_buckets` is 0.
    #[must_use]
    pub fn build(values: &[f64], num_buckets: usize) -> Self {
        if values.is_empty() || num_buckets == 0 {
            return Self::new(Vec::new());
        }

        let num_buckets = num_buckets.min(values.len());
        let rows_per_bucket = values.len().div_ceil(num_buckets);
        let mut buckets = Vec::with_capacity(num_buckets);

        let mut start_idx = 0;
        while start_idx < values.len() {
            let mut end_idx = (start_idx + rows_per_bucket).min(values.len());
            // Extend the bucket over the rest of a run of equal values (exact
            // comparison is intended) so each value belongs to exactly one bucket.
            while end_idx < values.len() && values[end_idx] == values[end_idx - 1] {
                end_idx += 1;
            }

            let lower_bound = values[start_idx];
            let upper_bound = if end_idx < values.len() {
                values[end_idx]
            } else {
                values[end_idx - 1] + 1.0
            };

            let bucket_values = &values[start_idx..end_idx];
            buckets.push(HistogramBucket::new(
                lower_bound,
                upper_bound,
                bucket_values.len() as u64,
                count_distinct(bucket_values),
            ));

            start_idx = end_idx;
        }

        Self::new(buckets)
    }

    /// Number of buckets in the histogram.
    #[must_use]
    pub fn num_buckets(&self) -> usize {
        self.buckets.len()
    }

    /// Total number of rows summarised by the histogram.
    #[must_use]
    pub fn total_rows(&self) -> u64 {
        self.total_rows
    }

    /// The histogram's buckets.
    #[must_use]
    pub fn buckets(&self) -> &[HistogramBucket] {
        &self.buckets
    }

    /// Estimates the fraction of rows whose value lies in `[lower, upper)`.
    ///
    /// `None` leaves that side unbounded. Within a bucket, values are assumed to be
    /// spread uniformly over its width. Returns 0.33 when the histogram is empty.
    #[must_use]
    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        if self.buckets.is_empty() || self.total_rows == 0 {
            return 0.33;
        }

        let matching_rows = self
            .buckets
            .iter()
            .filter(|bucket| {
                // Skip buckets entirely below `lower` or entirely at/above `upper`.
                let below = matches!(lower, Some(l) if bucket.upper_bound <= l);
                let above = matches!(upper, Some(u) if bucket.lower_bound >= u);
                !(below || above)
            })
            .map(|bucket| bucket.overlap_fraction(lower, upper) * bucket.frequency as f64)
            .fold(0.0, |acc, rows| acc + rows);

        (matching_rows / self.total_rows as f64).clamp(0.0, 1.0)
    }

    /// Estimates the fraction of rows equal to `value`.
    ///
    /// Rows in the matching bucket are assumed to be spread evenly over its
    /// distinct values: `frequency / distinct_count / total_rows`, capped at 1.0.
    /// Returns 0.01 when the histogram is empty, and 0.001 when no bucket with a
    /// non-zero distinct count contains `value`.
    #[must_use]
    pub fn equality_selectivity(&self, value: f64) -> f64 {
        if self.buckets.is_empty() || self.total_rows == 0 {
            return 0.01;
        }

        self.buckets
            .iter()
            .find(|bucket| bucket.contains(value) && bucket.distinct_count > 0)
            .map_or(0.001, |bucket| {
                (bucket.frequency as f64 / bucket.distinct_count as f64 / self.total_rows as f64)
                    .min(1.0)
            })
    }

    /// Lower bound of the first bucket, if any.
    #[must_use]
    pub fn min_value(&self) -> Option<f64> {
        self.buckets.first().map(|b| b.lower_bound)
    }

    /// Exclusive upper bound of the last bucket, if any.
    ///
    /// For histograms produced by `build` this is the largest value plus 1.0, not
    /// the largest value itself.
    #[must_use]
    pub fn max_value(&self) -> Option<f64> {
        self.buckets.last().map(|b| b.upper_bound)
    }
}

/// Counts the distinct values in an ascending-sorted slice.
///
/// Values are compared exactly. An absolute `f64::EPSILON` tolerance would be
/// meaningless at most magnitudes: it merges every value smaller than about
/// 2.2e-16 and has no effect above 2.0. NaN compares unequal to everything, so
/// callers should remove NaN first.
fn count_distinct(sorted_values: &[f64]) -> u64 {
    if sorted_values.is_empty() {
        return 0;
    }
    // Every change between neighbours starts a new distinct value.
    let changes = sorted_values
        .windows(2)
        .filter(|pair| pair[0] != pair[1])
        .count();
    1 + changes as u64
}
288
289#[derive(Debug, Clone)]
291pub struct TableStats {
292 pub row_count: u64,
294 pub columns: HashMap<String, ColumnStats>,
296}
297
298impl TableStats {
299 #[must_use]
301 pub fn new(row_count: u64) -> Self {
302 Self {
303 row_count,
304 columns: HashMap::new(),
305 }
306 }
307
308 pub fn with_column(mut self, name: &str, stats: ColumnStats) -> Self {
310 self.columns.insert(name.to_string(), stats);
311 self
312 }
313}
314
315#[derive(Debug, Clone)]
317pub struct ColumnStats {
318 pub distinct_count: u64,
320 pub null_count: u64,
322 pub min_value: Option<f64>,
324 pub max_value: Option<f64>,
326 pub histogram: Option<EquiDepthHistogram>,
328}
329
330impl ColumnStats {
331 #[must_use]
333 pub fn new(distinct_count: u64) -> Self {
334 Self {
335 distinct_count,
336 null_count: 0,
337 min_value: None,
338 max_value: None,
339 histogram: None,
340 }
341 }
342
343 #[must_use]
345 pub fn with_nulls(mut self, null_count: u64) -> Self {
346 self.null_count = null_count;
347 self
348 }
349
350 #[must_use]
352 pub fn with_range(mut self, min: f64, max: f64) -> Self {
353 self.min_value = Some(min);
354 self.max_value = Some(max);
355 self
356 }
357
358 #[must_use]
360 pub fn with_histogram(mut self, histogram: EquiDepthHistogram) -> Self {
361 self.histogram = Some(histogram);
362 self
363 }
364
365 #[must_use]
373 pub fn from_values(mut values: Vec<f64>, num_buckets: usize) -> Self {
374 if values.is_empty() {
375 return Self::new(0);
376 }
377
378 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
380
381 let min = values.first().copied();
382 let max = values.last().copied();
383 let distinct_count = count_distinct(&values);
384 let histogram = EquiDepthHistogram::build(&values, num_buckets);
385
386 Self {
387 distinct_count,
388 null_count: 0,
389 min_value: min,
390 max_value: max,
391 histogram: Some(histogram),
392 }
393 }
394}
395
/// Tunable selectivity constants used by the cardinality estimator whenever
/// column statistics cannot answer a predicate.
#[derive(Debug, Clone)]
pub struct SelectivityConfig {
    /// Fallback for predicates with no more specific rule.
    pub default: f64,
    /// Fallback for `=` comparisons.
    pub equality: f64,
    /// Selectivity of `<>` comparisons.
    pub inequality: f64,
    /// Fallback for `<`, `<=`, `>` and `>=` comparisons.
    pub range: f64,
    /// Selectivity of string predicates (starts-with, ends-with, contains, like).
    pub string_ops: f64,
    /// Selectivity of `IN` membership tests.
    pub membership: f64,
    /// Selectivity of `IS NULL`.
    pub is_null: f64,
    /// Selectivity of `IS NOT NULL`.
    pub is_not_null: f64,
    /// Fraction of input rows assumed to survive a `DISTINCT`.
    pub distinct_fraction: f64,
}

impl SelectivityConfig {
    /// Returns the built-in defaults.
    #[must_use]
    pub fn new() -> Self {
        let (is_null, is_not_null) = (0.05, 0.95);
        Self {
            distinct_fraction: 0.5,
            membership: 0.1,
            string_ops: 0.1,
            range: 0.33,
            inequality: 0.99,
            equality: 0.01,
            default: 0.1,
            is_null,
            is_not_null,
        }
    }
}

impl Default for SelectivityConfig {
    /// Same as [`SelectivityConfig::new`].
    fn default() -> Self {
        SelectivityConfig::new()
    }
}
446
/// One recorded pair of estimated vs. actual cardinality for an operator.
#[derive(Debug, Clone)]
pub struct EstimationEntry {
    /// Label identifying the operator.
    pub operator: String,
    /// Cardinality predicted by the estimator.
    pub estimated: f64,
    /// Actual cardinality (presumably observed during execution; confirm with callers).
    pub actual: f64,
}

impl EstimationEntry {
    /// Returns `actual / estimated`.
    ///
    /// When the estimate is near zero (below `f64::EPSILON` in magnitude), returns
    /// 1.0 if the actual count is also near zero and `f64::INFINITY` otherwise.
    #[must_use]
    pub fn error_ratio(&self) -> f64 {
        let estimate_is_zero = self.estimated.abs() < f64::EPSILON;
        if !estimate_is_zero {
            return self.actual / self.estimated;
        }
        if self.actual.abs() < f64::EPSILON {
            1.0
        } else {
            f64::INFINITY
        }
    }
}

/// Log of estimation errors, used to decide whether a plan should be re-optimized.
#[derive(Debug, Clone)]
pub struct EstimationLog {
    /// Recorded entries, in insertion order.
    entries: Vec<EstimationEntry>,
    /// Error factor, in either direction, beyond which `should_replan` reports true.
    replan_threshold: f64,
}

impl Default for EstimationLog {
    /// Creates an empty log with a replan threshold of 10.0, the same value used by
    /// `CardinalityEstimator::create_estimation_log`.
    ///
    /// A derived `Default` would give a threshold of 0.0. Every recorded entry,
    /// even a perfect estimate, would then trigger a replan.
    fn default() -> Self {
        Self::new(10.0)
    }
}

impl EstimationLog {
    /// Creates an empty log with the given replan threshold.
    #[must_use]
    pub fn new(replan_threshold: f64) -> Self {
        Self {
            entries: Vec::new(),
            replan_threshold,
        }
    }

    /// Appends an estimated/actual pair for `operator`.
    pub fn record(&mut self, operator: impl Into<String>, estimated: f64, actual: f64) {
        self.entries.push(EstimationEntry {
            operator: operator.into(),
            estimated,
            actual,
        });
    }

    /// All recorded entries.
    #[must_use]
    pub fn entries(&self) -> &[EstimationEntry] {
        &self.entries
    }

    /// Returns `true` if any entry's error ratio is above `replan_threshold` or below
    /// `1 / replan_threshold` (overestimates and underestimates are treated alike).
    #[must_use]
    pub fn should_replan(&self) -> bool {
        self.entries.iter().any(|e| {
            let ratio = e.error_ratio();
            ratio > self.replan_threshold || ratio < 1.0 / self.replan_threshold
        })
    }

    /// Largest error factor over all entries, folding ratios below 1 to `1 / ratio`.
    /// Returns 1.0 for an empty log.
    #[must_use]
    pub fn max_error_ratio(&self) -> f64 {
        self.entries
            .iter()
            .map(|e| {
                let r = e.error_ratio();
                if r < 1.0 {
                    1.0 / r
                } else {
                    r
                }
            })
            .fold(1.0_f64, f64::max)
    }

    /// Removes all recorded entries; the threshold is kept.
    pub fn clear(&mut self) {
        self.entries.clear();
    }
}
547
548pub struct CardinalityEstimator {
550 table_stats: HashMap<String, TableStats>,
552 default_row_count: u64,
554 default_selectivity: f64,
556 avg_fanout: f64,
558 selectivity_config: SelectivityConfig,
560 rdf_statistics: Option<grafeo_core::statistics::RdfStatistics>,
562}
563
564impl CardinalityEstimator {
565 #[must_use]
567 pub fn new() -> Self {
568 let config = SelectivityConfig::new();
569 Self {
570 table_stats: HashMap::new(),
571 default_row_count: 1000,
572 default_selectivity: config.default,
573 avg_fanout: 10.0,
574 selectivity_config: config,
575 rdf_statistics: None,
576 }
577 }
578
579 #[must_use]
581 pub fn with_selectivity_config(config: SelectivityConfig) -> Self {
582 Self {
583 table_stats: HashMap::new(),
584 default_row_count: 1000,
585 default_selectivity: config.default,
586 avg_fanout: 10.0,
587 selectivity_config: config,
588 rdf_statistics: None,
589 }
590 }
591
592 #[must_use]
594 pub fn selectivity_config(&self) -> &SelectivityConfig {
595 &self.selectivity_config
596 }
597
598 #[must_use]
600 pub fn create_estimation_log() -> EstimationLog {
601 EstimationLog::new(10.0)
602 }
603
604 #[must_use]
610 pub fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
611 let mut estimator = Self::new();
612
613 if stats.total_nodes > 0 {
615 estimator.default_row_count = stats.total_nodes;
616 }
617
618 for (label, label_stats) in &stats.labels {
620 let mut table_stats = TableStats::new(label_stats.node_count);
621
622 for (prop, col_stats) in &label_stats.properties {
624 let optimizer_col =
625 ColumnStats::new(col_stats.distinct_count).with_nulls(col_stats.null_count);
626 table_stats = table_stats.with_column(prop, optimizer_col);
627 }
628
629 estimator.add_table_stats(label, table_stats);
630 }
631
632 if !stats.edge_types.is_empty() {
634 let total_out_degree: f64 = stats.edge_types.values().map(|e| e.avg_out_degree).sum();
635 estimator.avg_fanout = total_out_degree / stats.edge_types.len() as f64;
636 } else if stats.total_nodes > 0 {
637 estimator.avg_fanout = stats.total_edges as f64 / stats.total_nodes as f64;
638 }
639
640 if estimator.avg_fanout < 1.0 {
642 estimator.avg_fanout = 1.0;
643 }
644
645 estimator
646 }
647
648 #[must_use]
653 pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
654 let mut estimator = Self::new();
655 if rdf_stats.total_triples > 0 {
656 estimator.default_row_count = rdf_stats.total_triples;
657 }
658 estimator.rdf_statistics = Some(rdf_stats);
659 estimator
660 }
661
662 pub fn add_table_stats(&mut self, name: &str, stats: TableStats) {
664 self.table_stats.insert(name.to_string(), stats);
665 }
666
667 pub fn set_avg_fanout(&mut self, fanout: f64) {
669 self.avg_fanout = fanout;
670 }
671
672 #[must_use]
674 pub fn estimate(&self, op: &LogicalOperator) -> f64 {
675 match op {
676 LogicalOperator::NodeScan(scan) => self.estimate_node_scan(scan),
677 LogicalOperator::Filter(filter) => self.estimate_filter(filter),
678 LogicalOperator::Project(project) => self.estimate_project(project),
679 LogicalOperator::Expand(expand) => self.estimate_expand(expand),
680 LogicalOperator::Join(join) => self.estimate_join(join),
681 LogicalOperator::Aggregate(agg) => self.estimate_aggregate(agg),
682 LogicalOperator::Sort(sort) => self.estimate_sort(sort),
683 LogicalOperator::Distinct(distinct) => self.estimate_distinct(distinct),
684 LogicalOperator::Limit(limit) => self.estimate_limit(limit),
685 LogicalOperator::Skip(skip) => self.estimate_skip(skip),
686 LogicalOperator::Return(ret) => self.estimate(&ret.input),
687 LogicalOperator::Empty => 0.0,
688 LogicalOperator::VectorScan(scan) => self.estimate_vector_scan(scan),
689 LogicalOperator::VectorJoin(join) => self.estimate_vector_join(join),
690 LogicalOperator::MultiWayJoin(mwj) => self.estimate_multi_way_join(mwj),
691 LogicalOperator::LeftJoin(lj) => self.estimate_left_join(lj),
692 LogicalOperator::TripleScan(scan) => self.estimate_triple_scan(scan),
693 _ => self.default_row_count as f64,
694 }
695 }
696
697 fn estimate_node_scan(&self, scan: &NodeScanOp) -> f64 {
699 if let Some(label) = &scan.label
700 && let Some(stats) = self.table_stats.get(label)
701 {
702 return stats.row_count as f64;
703 }
704 self.default_row_count as f64
706 }
707
708 fn estimate_triple_scan(&self, scan: &TripleScanOp) -> f64 {
714 let base = if let Some(ref input) = scan.input {
717 self.estimate(input)
718 } else {
719 1.0
720 };
721
722 let Some(rdf_stats) = &self.rdf_statistics else {
723 return if scan.input.is_some() {
724 base * self.default_row_count as f64
725 } else {
726 self.default_row_count as f64
727 };
728 };
729
730 let subject_bound = matches!(
731 scan.subject,
732 TripleComponent::Iri(_) | TripleComponent::Literal(_)
733 );
734 let object_bound = matches!(
735 scan.object,
736 TripleComponent::Iri(_) | TripleComponent::Literal(_)
737 );
738 let predicate_iri = match &scan.predicate {
739 TripleComponent::Iri(iri) => Some(iri.as_str()),
740 _ => None,
741 };
742
743 let pattern_card = rdf_stats.estimate_triple_pattern_cardinality(
744 subject_bound,
745 predicate_iri,
746 object_bound,
747 );
748
749 if scan.input.is_some() {
750 let selectivity = if rdf_stats.total_triples > 0 {
752 pattern_card / rdf_stats.total_triples as f64
753 } else {
754 1.0
755 };
756 (base * pattern_card * selectivity).max(1.0)
757 } else {
758 pattern_card.max(1.0)
759 }
760 }
761
762 fn estimate_filter(&self, filter: &FilterOp) -> f64 {
764 let input_cardinality = self.estimate(&filter.input);
765 let selectivity = self.estimate_selectivity(&filter.predicate);
766 (input_cardinality * selectivity).max(1.0)
767 }
768
769 fn estimate_project(&self, project: &ProjectOp) -> f64 {
771 self.estimate(&project.input)
772 }
773
774 fn estimate_expand(&self, expand: &ExpandOp) -> f64 {
776 let input_cardinality = self.estimate(&expand.input);
777
778 let fanout = if !expand.edge_types.is_empty() {
780 self.avg_fanout * 0.5
782 } else {
783 self.avg_fanout
784 };
785
786 let path_multiplier = if expand.max_hops.unwrap_or(1) > 1 {
788 let min = expand.min_hops as f64;
789 let max = expand.max_hops.unwrap_or(expand.min_hops + 3) as f64;
790 (fanout.powf(max + 1.0) - fanout.powf(min)) / (fanout - 1.0)
792 } else {
793 fanout
794 };
795
796 (input_cardinality * path_multiplier).max(1.0)
797 }
798
799 fn estimate_join(&self, join: &JoinOp) -> f64 {
801 let left_card = self.estimate(&join.left);
802 let right_card = self.estimate(&join.right);
803
804 match join.join_type {
805 JoinType::Cross => left_card * right_card,
806 JoinType::Inner => {
807 let selectivity = if join.conditions.is_empty() {
809 1.0 } else {
811 0.1_f64.powi(join.conditions.len() as i32)
813 };
814 (left_card * right_card * selectivity).max(1.0)
815 }
816 JoinType::Left => {
817 let inner_card = self.estimate_join(&JoinOp {
819 left: join.left.clone(),
820 right: join.right.clone(),
821 join_type: JoinType::Inner,
822 conditions: join.conditions.clone(),
823 });
824 inner_card.max(left_card)
825 }
826 JoinType::Right => {
827 let inner_card = self.estimate_join(&JoinOp {
829 left: join.left.clone(),
830 right: join.right.clone(),
831 join_type: JoinType::Inner,
832 conditions: join.conditions.clone(),
833 });
834 inner_card.max(right_card)
835 }
836 JoinType::Full => {
837 let inner_card = self.estimate_join(&JoinOp {
839 left: join.left.clone(),
840 right: join.right.clone(),
841 join_type: JoinType::Inner,
842 conditions: join.conditions.clone(),
843 });
844 inner_card.max(left_card.max(right_card))
845 }
846 JoinType::Semi => {
847 (left_card * self.default_selectivity).max(1.0)
849 }
850 JoinType::Anti => {
851 (left_card * (1.0 - self.default_selectivity)).max(1.0)
853 }
854 }
855 }
856
857 fn estimate_left_join(&self, lj: &LeftJoinOp) -> f64 {
863 let left_card = self.estimate(&lj.left);
864 let right_card = self.estimate(&lj.right);
865
866 let inner_estimate = left_card * right_card * self.default_selectivity;
868 inner_estimate.max(left_card).max(1.0)
869 }
870
871 fn estimate_aggregate(&self, agg: &AggregateOp) -> f64 {
873 let input_cardinality = self.estimate(&agg.input);
874
875 if agg.group_by.is_empty() {
876 1.0
878 } else {
879 let group_reduction = 10.0_f64.powi(agg.group_by.len() as i32);
882 (input_cardinality / group_reduction).max(1.0)
883 }
884 }
885
886 fn estimate_sort(&self, sort: &SortOp) -> f64 {
888 self.estimate(&sort.input)
889 }
890
891 fn estimate_distinct(&self, distinct: &DistinctOp) -> f64 {
893 let input_cardinality = self.estimate(&distinct.input);
894 (input_cardinality * self.selectivity_config.distinct_fraction).max(1.0)
895 }
896
897 fn estimate_limit(&self, limit: &LimitOp) -> f64 {
899 let input_cardinality = self.estimate(&limit.input);
900 limit.count.estimate().min(input_cardinality)
901 }
902
903 fn estimate_skip(&self, skip: &SkipOp) -> f64 {
905 let input_cardinality = self.estimate(&skip.input);
906 (input_cardinality - skip.count.estimate()).max(0.0)
907 }
908
909 fn estimate_vector_scan(&self, scan: &VectorScanOp) -> f64 {
914 let base_k = scan.k as f64;
915
916 let selectivity = if scan.min_similarity.is_some() || scan.max_distance.is_some() {
918 0.7
920 } else {
921 1.0
922 };
923
924 (base_k * selectivity).max(1.0)
925 }
926
927 fn estimate_vector_join(&self, join: &VectorJoinOp) -> f64 {
931 let input_cardinality = self.estimate(&join.input);
932 let k = join.k as f64;
933
934 let selectivity = if join.min_similarity.is_some() || join.max_distance.is_some() {
936 0.7
937 } else {
938 1.0
939 };
940
941 (input_cardinality * k * selectivity).max(1.0)
942 }
943
944 fn estimate_multi_way_join(&self, mwj: &MultiWayJoinOp) -> f64 {
949 if mwj.inputs.is_empty() {
950 return 0.0;
951 }
952 let cardinalities: Vec<f64> = mwj
953 .inputs
954 .iter()
955 .map(|input| self.estimate(input))
956 .collect();
957 let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
958 let n = cardinalities.len() as f64;
959 (min_card.powf(n / 2.0)).max(1.0)
961 }
962
963 fn estimate_selectivity(&self, expr: &LogicalExpression) -> f64 {
965 match expr {
966 LogicalExpression::Binary { left, op, right } => {
967 self.estimate_binary_selectivity(left, *op, right)
968 }
969 LogicalExpression::Unary { op, operand } => {
970 self.estimate_unary_selectivity(*op, operand)
971 }
972 LogicalExpression::Literal(value) => {
973 if let grafeo_common::types::Value::Bool(b) = value {
975 if *b { 1.0 } else { 0.0 }
976 } else {
977 self.default_selectivity
978 }
979 }
980 _ => self.default_selectivity,
981 }
982 }
983
984 fn estimate_binary_selectivity(
986 &self,
987 left: &LogicalExpression,
988 op: BinaryOp,
989 right: &LogicalExpression,
990 ) -> f64 {
991 match op {
992 BinaryOp::Eq => {
994 if let Some(selectivity) = self.try_equality_selectivity(left, right) {
995 return selectivity;
996 }
997 self.selectivity_config.equality
998 }
999 BinaryOp::Ne => self.selectivity_config.inequality,
1001 BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
1003 if let Some(selectivity) = self.try_range_selectivity(left, op, right) {
1004 return selectivity;
1005 }
1006 self.selectivity_config.range
1007 }
1008 BinaryOp::And => {
1010 let left_sel = self.estimate_selectivity(left);
1011 let right_sel = self.estimate_selectivity(right);
1012 left_sel * right_sel
1014 }
1015 BinaryOp::Or => {
1016 let left_sel = self.estimate_selectivity(left);
1017 let right_sel = self.estimate_selectivity(right);
1018 (left_sel + right_sel - left_sel * right_sel).min(1.0)
1021 }
1022 BinaryOp::StartsWith | BinaryOp::EndsWith | BinaryOp::Contains | BinaryOp::Like => {
1024 self.selectivity_config.string_ops
1025 }
1026 BinaryOp::In => self.selectivity_config.membership,
1028 _ => self.default_selectivity,
1030 }
1031 }
1032
1033 fn try_equality_selectivity(
1035 &self,
1036 left: &LogicalExpression,
1037 right: &LogicalExpression,
1038 ) -> Option<f64> {
1039 let (label, column, value) = self.extract_column_and_value(left, right)?;
1041
1042 let stats = self.get_column_stats(&label, &column)?;
1044
1045 if let Some(ref histogram) = stats.histogram {
1047 return Some(histogram.equality_selectivity(value));
1048 }
1049
1050 if stats.distinct_count > 0 {
1052 return Some(1.0 / stats.distinct_count as f64);
1053 }
1054
1055 None
1056 }
1057
1058 fn try_range_selectivity(
1060 &self,
1061 left: &LogicalExpression,
1062 op: BinaryOp,
1063 right: &LogicalExpression,
1064 ) -> Option<f64> {
1065 let (label, column, value) = self.extract_column_and_value(left, right)?;
1067
1068 let stats = self.get_column_stats(&label, &column)?;
1070
1071 let (lower, upper) = match op {
1073 BinaryOp::Lt => (None, Some(value)),
1074 BinaryOp::Le => (None, Some(value + f64::EPSILON)),
1075 BinaryOp::Gt => (Some(value + f64::EPSILON), None),
1076 BinaryOp::Ge => (Some(value), None),
1077 _ => return None,
1078 };
1079
1080 if let Some(ref histogram) = stats.histogram {
1082 return Some(histogram.range_selectivity(lower, upper));
1083 }
1084
1085 if let (Some(min), Some(max)) = (stats.min_value, stats.max_value) {
1087 let range = max - min;
1088 if range <= 0.0 {
1089 return Some(1.0);
1090 }
1091
1092 let effective_lower = lower.unwrap_or(min).max(min);
1093 let effective_upper = upper.unwrap_or(max).min(max);
1094 let overlap = (effective_upper - effective_lower).max(0.0);
1095 return Some((overlap / range).clamp(0.0, 1.0));
1096 }
1097
1098 None
1099 }
1100
1101 fn extract_column_and_value(
1106 &self,
1107 left: &LogicalExpression,
1108 right: &LogicalExpression,
1109 ) -> Option<(String, String, f64)> {
1110 if let Some(result) = self.try_extract_property_literal(left, right) {
1112 return Some(result);
1113 }
1114
1115 self.try_extract_property_literal(right, left)
1117 }
1118
1119 fn try_extract_property_literal(
1121 &self,
1122 property_expr: &LogicalExpression,
1123 literal_expr: &LogicalExpression,
1124 ) -> Option<(String, String, f64)> {
1125 let (variable, property) = match property_expr {
1127 LogicalExpression::Property { variable, property } => {
1128 (variable.clone(), property.clone())
1129 }
1130 _ => return None,
1131 };
1132
1133 let value = match literal_expr {
1135 LogicalExpression::Literal(grafeo_common::types::Value::Int64(n)) => *n as f64,
1136 LogicalExpression::Literal(grafeo_common::types::Value::Float64(f)) => *f,
1137 _ => return None,
1138 };
1139
1140 for label in self.table_stats.keys() {
1144 if let Some(stats) = self.table_stats.get(label)
1145 && stats.columns.contains_key(&property)
1146 {
1147 return Some((label.clone(), property, value));
1148 }
1149 }
1150
1151 Some((variable, property, value))
1153 }
1154
1155 fn estimate_unary_selectivity(&self, op: UnaryOp, _operand: &LogicalExpression) -> f64 {
1157 match op {
1158 UnaryOp::Not => 1.0 - self.default_selectivity,
1159 UnaryOp::IsNull => self.selectivity_config.is_null,
1160 UnaryOp::IsNotNull => self.selectivity_config.is_not_null,
1161 UnaryOp::Neg => 1.0, }
1163 }
1164
1165 fn get_column_stats(&self, label: &str, column: &str) -> Option<&ColumnStats> {
1167 self.table_stats.get(label)?.columns.get(column)
1168 }
1169}
1170
1171impl Default for CardinalityEstimator {
1172 fn default() -> Self {
1173 Self::new()
1174 }
1175}
1176
1177#[cfg(test)]
1178mod tests {
1179 use super::*;
1180 use crate::query::plan::{
1181 DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinCondition, NodeScanOp, PathMode,
1182 ProjectOp, Projection, ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder,
1183 };
1184 use grafeo_common::types::Value;
1185
1186 #[test]
1187 fn test_node_scan_with_stats() {
1188 let mut estimator = CardinalityEstimator::new();
1189 estimator.add_table_stats("Person", TableStats::new(5000));
1190
1191 let scan = LogicalOperator::NodeScan(NodeScanOp {
1192 variable: "n".to_string(),
1193 label: Some("Person".to_string()),
1194 input: None,
1195 });
1196
1197 let cardinality = estimator.estimate(&scan);
1198 assert!((cardinality - 5000.0).abs() < 0.001);
1199 }
1200
1201 #[test]
1202 fn test_filter_reduces_cardinality() {
1203 let mut estimator = CardinalityEstimator::new();
1204 estimator.add_table_stats("Person", TableStats::new(1000));
1205
1206 let filter = LogicalOperator::Filter(FilterOp {
1207 predicate: LogicalExpression::Binary {
1208 left: Box::new(LogicalExpression::Property {
1209 variable: "n".to_string(),
1210 property: "age".to_string(),
1211 }),
1212 op: BinaryOp::Eq,
1213 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1214 },
1215 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1216 variable: "n".to_string(),
1217 label: Some("Person".to_string()),
1218 input: None,
1219 })),
1220 pushdown_hint: None,
1221 });
1222
1223 let cardinality = estimator.estimate(&filter);
1224 assert!(cardinality < 1000.0);
1226 assert!(cardinality >= 1.0);
1227 }
1228
1229 #[test]
1230 fn test_join_cardinality() {
1231 let mut estimator = CardinalityEstimator::new();
1232 estimator.add_table_stats("Person", TableStats::new(1000));
1233 estimator.add_table_stats("Company", TableStats::new(100));
1234
1235 let join = LogicalOperator::Join(JoinOp {
1236 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1237 variable: "p".to_string(),
1238 label: Some("Person".to_string()),
1239 input: None,
1240 })),
1241 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1242 variable: "c".to_string(),
1243 label: Some("Company".to_string()),
1244 input: None,
1245 })),
1246 join_type: JoinType::Inner,
1247 conditions: vec![JoinCondition {
1248 left: LogicalExpression::Property {
1249 variable: "p".to_string(),
1250 property: "company_id".to_string(),
1251 },
1252 right: LogicalExpression::Property {
1253 variable: "c".to_string(),
1254 property: "id".to_string(),
1255 },
1256 }],
1257 });
1258
1259 let cardinality = estimator.estimate(&join);
1260 assert!(cardinality < 1000.0 * 100.0);
1262 }
1263
1264 #[test]
1265 fn test_limit_caps_cardinality() {
1266 let mut estimator = CardinalityEstimator::new();
1267 estimator.add_table_stats("Person", TableStats::new(1000));
1268
1269 let limit = LogicalOperator::Limit(LimitOp {
1270 count: 10.into(),
1271 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1272 variable: "n".to_string(),
1273 label: Some("Person".to_string()),
1274 input: None,
1275 })),
1276 });
1277
1278 let cardinality = estimator.estimate(&limit);
1279 assert!((cardinality - 10.0).abs() < 0.001);
1280 }
1281
1282 #[test]
1283 fn test_aggregate_reduces_cardinality() {
1284 let mut estimator = CardinalityEstimator::new();
1285 estimator.add_table_stats("Person", TableStats::new(1000));
1286
1287 let global_agg = LogicalOperator::Aggregate(AggregateOp {
1289 group_by: vec![],
1290 aggregates: vec![],
1291 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1292 variable: "n".to_string(),
1293 label: Some("Person".to_string()),
1294 input: None,
1295 })),
1296 having: None,
1297 });
1298
1299 let cardinality = estimator.estimate(&global_agg);
1300 assert!((cardinality - 1.0).abs() < 0.001);
1301
1302 let group_agg = LogicalOperator::Aggregate(AggregateOp {
1304 group_by: vec![LogicalExpression::Property {
1305 variable: "n".to_string(),
1306 property: "city".to_string(),
1307 }],
1308 aggregates: vec![],
1309 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1310 variable: "n".to_string(),
1311 label: Some("Person".to_string()),
1312 input: None,
1313 })),
1314 having: None,
1315 });
1316
1317 let cardinality = estimator.estimate(&group_agg);
1318 assert!(cardinality < 1000.0);
1320 }
1321
1322 #[test]
1323 fn test_node_scan_without_stats() {
1324 let estimator = CardinalityEstimator::new();
1325
1326 let scan = LogicalOperator::NodeScan(NodeScanOp {
1327 variable: "n".to_string(),
1328 label: Some("Unknown".to_string()),
1329 input: None,
1330 });
1331
1332 let cardinality = estimator.estimate(&scan);
1333 assert!((cardinality - 1000.0).abs() < 0.001);
1335 }
1336
1337 #[test]
1338 fn test_node_scan_no_label() {
1339 let estimator = CardinalityEstimator::new();
1340
1341 let scan = LogicalOperator::NodeScan(NodeScanOp {
1342 variable: "n".to_string(),
1343 label: None,
1344 input: None,
1345 });
1346
1347 let cardinality = estimator.estimate(&scan);
1348 assert!((cardinality - 1000.0).abs() < 0.001);
1350 }
1351
1352 #[test]
1353 fn test_filter_inequality_selectivity() {
1354 let mut estimator = CardinalityEstimator::new();
1355 estimator.add_table_stats("Person", TableStats::new(1000));
1356
1357 let filter = LogicalOperator::Filter(FilterOp {
1358 predicate: LogicalExpression::Binary {
1359 left: Box::new(LogicalExpression::Property {
1360 variable: "n".to_string(),
1361 property: "age".to_string(),
1362 }),
1363 op: BinaryOp::Ne,
1364 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1365 },
1366 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1367 variable: "n".to_string(),
1368 label: Some("Person".to_string()),
1369 input: None,
1370 })),
1371 pushdown_hint: None,
1372 });
1373
1374 let cardinality = estimator.estimate(&filter);
1375 assert!(cardinality > 900.0);
1377 }
1378
1379 #[test]
1380 fn test_filter_range_selectivity() {
1381 let mut estimator = CardinalityEstimator::new();
1382 estimator.add_table_stats("Person", TableStats::new(1000));
1383
1384 let filter = LogicalOperator::Filter(FilterOp {
1385 predicate: LogicalExpression::Binary {
1386 left: Box::new(LogicalExpression::Property {
1387 variable: "n".to_string(),
1388 property: "age".to_string(),
1389 }),
1390 op: BinaryOp::Gt,
1391 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1392 },
1393 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1394 variable: "n".to_string(),
1395 label: Some("Person".to_string()),
1396 input: None,
1397 })),
1398 pushdown_hint: None,
1399 });
1400
1401 let cardinality = estimator.estimate(&filter);
1402 assert!(cardinality < 500.0);
1404 assert!(cardinality > 100.0);
1405 }
1406
1407 #[test]
1408 fn test_filter_and_selectivity() {
1409 let mut estimator = CardinalityEstimator::new();
1410 estimator.add_table_stats("Person", TableStats::new(1000));
1411
1412 let filter = LogicalOperator::Filter(FilterOp {
1415 predicate: LogicalExpression::Binary {
1416 left: Box::new(LogicalExpression::Binary {
1417 left: Box::new(LogicalExpression::Property {
1418 variable: "n".to_string(),
1419 property: "city".to_string(),
1420 }),
1421 op: BinaryOp::Eq,
1422 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1423 }),
1424 op: BinaryOp::And,
1425 right: Box::new(LogicalExpression::Binary {
1426 left: Box::new(LogicalExpression::Property {
1427 variable: "n".to_string(),
1428 property: "age".to_string(),
1429 }),
1430 op: BinaryOp::Eq,
1431 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1432 }),
1433 },
1434 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1435 variable: "n".to_string(),
1436 label: Some("Person".to_string()),
1437 input: None,
1438 })),
1439 pushdown_hint: None,
1440 });
1441
1442 let cardinality = estimator.estimate(&filter);
1443 assert!(cardinality < 100.0);
1446 assert!(cardinality >= 1.0);
1447 }
1448
1449 #[test]
1450 fn test_filter_or_selectivity() {
1451 let mut estimator = CardinalityEstimator::new();
1452 estimator.add_table_stats("Person", TableStats::new(1000));
1453
1454 let filter = LogicalOperator::Filter(FilterOp {
1458 predicate: LogicalExpression::Binary {
1459 left: Box::new(LogicalExpression::Binary {
1460 left: Box::new(LogicalExpression::Property {
1461 variable: "n".to_string(),
1462 property: "city".to_string(),
1463 }),
1464 op: BinaryOp::Eq,
1465 right: Box::new(LogicalExpression::Literal(Value::String("NYC".into()))),
1466 }),
1467 op: BinaryOp::Or,
1468 right: Box::new(LogicalExpression::Binary {
1469 left: Box::new(LogicalExpression::Property {
1470 variable: "n".to_string(),
1471 property: "city".to_string(),
1472 }),
1473 op: BinaryOp::Eq,
1474 right: Box::new(LogicalExpression::Literal(Value::String("LA".into()))),
1475 }),
1476 },
1477 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1478 variable: "n".to_string(),
1479 label: Some("Person".to_string()),
1480 input: None,
1481 })),
1482 pushdown_hint: None,
1483 });
1484
1485 let cardinality = estimator.estimate(&filter);
1486 assert!(cardinality < 100.0);
1488 assert!(cardinality >= 1.0);
1489 }
1490
1491 #[test]
1492 fn test_filter_literal_true() {
1493 let mut estimator = CardinalityEstimator::new();
1494 estimator.add_table_stats("Person", TableStats::new(1000));
1495
1496 let filter = LogicalOperator::Filter(FilterOp {
1497 predicate: LogicalExpression::Literal(Value::Bool(true)),
1498 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1499 variable: "n".to_string(),
1500 label: Some("Person".to_string()),
1501 input: None,
1502 })),
1503 pushdown_hint: None,
1504 });
1505
1506 let cardinality = estimator.estimate(&filter);
1507 assert!((cardinality - 1000.0).abs() < 0.001);
1509 }
1510
1511 #[test]
1512 fn test_filter_literal_false() {
1513 let mut estimator = CardinalityEstimator::new();
1514 estimator.add_table_stats("Person", TableStats::new(1000));
1515
1516 let filter = LogicalOperator::Filter(FilterOp {
1517 predicate: LogicalExpression::Literal(Value::Bool(false)),
1518 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1519 variable: "n".to_string(),
1520 label: Some("Person".to_string()),
1521 input: None,
1522 })),
1523 pushdown_hint: None,
1524 });
1525
1526 let cardinality = estimator.estimate(&filter);
1527 assert!((cardinality - 1.0).abs() < 0.001);
1529 }
1530
1531 #[test]
1532 fn test_unary_not_selectivity() {
1533 let mut estimator = CardinalityEstimator::new();
1534 estimator.add_table_stats("Person", TableStats::new(1000));
1535
1536 let filter = LogicalOperator::Filter(FilterOp {
1537 predicate: LogicalExpression::Unary {
1538 op: UnaryOp::Not,
1539 operand: Box::new(LogicalExpression::Literal(Value::Bool(true))),
1540 },
1541 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1542 variable: "n".to_string(),
1543 label: Some("Person".to_string()),
1544 input: None,
1545 })),
1546 pushdown_hint: None,
1547 });
1548
1549 let cardinality = estimator.estimate(&filter);
1550 assert!(cardinality < 1000.0);
1552 }
1553
1554 #[test]
1555 fn test_unary_is_null_selectivity() {
1556 let mut estimator = CardinalityEstimator::new();
1557 estimator.add_table_stats("Person", TableStats::new(1000));
1558
1559 let filter = LogicalOperator::Filter(FilterOp {
1560 predicate: LogicalExpression::Unary {
1561 op: UnaryOp::IsNull,
1562 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1563 },
1564 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1565 variable: "n".to_string(),
1566 label: Some("Person".to_string()),
1567 input: None,
1568 })),
1569 pushdown_hint: None,
1570 });
1571
1572 let cardinality = estimator.estimate(&filter);
1573 assert!(cardinality < 100.0);
1575 }
1576
1577 #[test]
1578 fn test_expand_cardinality() {
1579 let mut estimator = CardinalityEstimator::new();
1580 estimator.add_table_stats("Person", TableStats::new(100));
1581
1582 let expand = LogicalOperator::Expand(ExpandOp {
1583 from_variable: "a".to_string(),
1584 to_variable: "b".to_string(),
1585 edge_variable: None,
1586 direction: ExpandDirection::Outgoing,
1587 edge_types: vec![],
1588 min_hops: 1,
1589 max_hops: Some(1),
1590 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1591 variable: "a".to_string(),
1592 label: Some("Person".to_string()),
1593 input: None,
1594 })),
1595 path_alias: None,
1596 path_mode: PathMode::Walk,
1597 });
1598
1599 let cardinality = estimator.estimate(&expand);
1600 assert!(cardinality > 100.0);
1602 }
1603
1604 #[test]
1605 fn test_expand_with_edge_type_filter() {
1606 let mut estimator = CardinalityEstimator::new();
1607 estimator.add_table_stats("Person", TableStats::new(100));
1608
1609 let expand = LogicalOperator::Expand(ExpandOp {
1610 from_variable: "a".to_string(),
1611 to_variable: "b".to_string(),
1612 edge_variable: None,
1613 direction: ExpandDirection::Outgoing,
1614 edge_types: vec!["KNOWS".to_string()],
1615 min_hops: 1,
1616 max_hops: Some(1),
1617 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1618 variable: "a".to_string(),
1619 label: Some("Person".to_string()),
1620 input: None,
1621 })),
1622 path_alias: None,
1623 path_mode: PathMode::Walk,
1624 });
1625
1626 let cardinality = estimator.estimate(&expand);
1627 assert!(cardinality > 100.0);
1629 }
1630
1631 #[test]
1632 fn test_expand_variable_length() {
1633 let mut estimator = CardinalityEstimator::new();
1634 estimator.add_table_stats("Person", TableStats::new(100));
1635
1636 let expand = LogicalOperator::Expand(ExpandOp {
1637 from_variable: "a".to_string(),
1638 to_variable: "b".to_string(),
1639 edge_variable: None,
1640 direction: ExpandDirection::Outgoing,
1641 edge_types: vec![],
1642 min_hops: 1,
1643 max_hops: Some(3),
1644 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1645 variable: "a".to_string(),
1646 label: Some("Person".to_string()),
1647 input: None,
1648 })),
1649 path_alias: None,
1650 path_mode: PathMode::Walk,
1651 });
1652
1653 let cardinality = estimator.estimate(&expand);
1654 assert!(cardinality > 500.0);
1656 }
1657
1658 #[test]
1659 fn test_join_cross_product() {
1660 let mut estimator = CardinalityEstimator::new();
1661 estimator.add_table_stats("Person", TableStats::new(100));
1662 estimator.add_table_stats("Company", TableStats::new(50));
1663
1664 let join = LogicalOperator::Join(JoinOp {
1665 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1666 variable: "p".to_string(),
1667 label: Some("Person".to_string()),
1668 input: None,
1669 })),
1670 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1671 variable: "c".to_string(),
1672 label: Some("Company".to_string()),
1673 input: None,
1674 })),
1675 join_type: JoinType::Cross,
1676 conditions: vec![],
1677 });
1678
1679 let cardinality = estimator.estimate(&join);
1680 assert!((cardinality - 5000.0).abs() < 0.001);
1682 }
1683
1684 #[test]
1685 fn test_join_left_outer() {
1686 let mut estimator = CardinalityEstimator::new();
1687 estimator.add_table_stats("Person", TableStats::new(1000));
1688 estimator.add_table_stats("Company", TableStats::new(10));
1689
1690 let join = LogicalOperator::Join(JoinOp {
1691 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1692 variable: "p".to_string(),
1693 label: Some("Person".to_string()),
1694 input: None,
1695 })),
1696 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1697 variable: "c".to_string(),
1698 label: Some("Company".to_string()),
1699 input: None,
1700 })),
1701 join_type: JoinType::Left,
1702 conditions: vec![JoinCondition {
1703 left: LogicalExpression::Variable("p".to_string()),
1704 right: LogicalExpression::Variable("c".to_string()),
1705 }],
1706 });
1707
1708 let cardinality = estimator.estimate(&join);
1709 assert!(cardinality >= 1000.0);
1711 }
1712
1713 #[test]
1714 fn test_join_semi() {
1715 let mut estimator = CardinalityEstimator::new();
1716 estimator.add_table_stats("Person", TableStats::new(1000));
1717 estimator.add_table_stats("Company", TableStats::new(100));
1718
1719 let join = LogicalOperator::Join(JoinOp {
1720 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1721 variable: "p".to_string(),
1722 label: Some("Person".to_string()),
1723 input: None,
1724 })),
1725 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1726 variable: "c".to_string(),
1727 label: Some("Company".to_string()),
1728 input: None,
1729 })),
1730 join_type: JoinType::Semi,
1731 conditions: vec![],
1732 });
1733
1734 let cardinality = estimator.estimate(&join);
1735 assert!(cardinality <= 1000.0);
1737 }
1738
1739 #[test]
1740 fn test_join_anti() {
1741 let mut estimator = CardinalityEstimator::new();
1742 estimator.add_table_stats("Person", TableStats::new(1000));
1743 estimator.add_table_stats("Company", TableStats::new(100));
1744
1745 let join = LogicalOperator::Join(JoinOp {
1746 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1747 variable: "p".to_string(),
1748 label: Some("Person".to_string()),
1749 input: None,
1750 })),
1751 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1752 variable: "c".to_string(),
1753 label: Some("Company".to_string()),
1754 input: None,
1755 })),
1756 join_type: JoinType::Anti,
1757 conditions: vec![],
1758 });
1759
1760 let cardinality = estimator.estimate(&join);
1761 assert!(cardinality <= 1000.0);
1763 assert!(cardinality >= 1.0);
1764 }
1765
1766 #[test]
1767 fn test_project_preserves_cardinality() {
1768 let mut estimator = CardinalityEstimator::new();
1769 estimator.add_table_stats("Person", TableStats::new(1000));
1770
1771 let project = LogicalOperator::Project(ProjectOp {
1772 projections: vec![Projection {
1773 expression: LogicalExpression::Variable("n".to_string()),
1774 alias: None,
1775 }],
1776 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1777 variable: "n".to_string(),
1778 label: Some("Person".to_string()),
1779 input: None,
1780 })),
1781 pass_through_input: false,
1782 });
1783
1784 let cardinality = estimator.estimate(&project);
1785 assert!((cardinality - 1000.0).abs() < 0.001);
1786 }
1787
1788 #[test]
1789 fn test_sort_preserves_cardinality() {
1790 let mut estimator = CardinalityEstimator::new();
1791 estimator.add_table_stats("Person", TableStats::new(1000));
1792
1793 let sort = LogicalOperator::Sort(SortOp {
1794 keys: vec![SortKey {
1795 expression: LogicalExpression::Variable("n".to_string()),
1796 order: SortOrder::Ascending,
1797 nulls: None,
1798 }],
1799 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1800 variable: "n".to_string(),
1801 label: Some("Person".to_string()),
1802 input: None,
1803 })),
1804 });
1805
1806 let cardinality = estimator.estimate(&sort);
1807 assert!((cardinality - 1000.0).abs() < 0.001);
1808 }
1809
1810 #[test]
1811 fn test_distinct_reduces_cardinality() {
1812 let mut estimator = CardinalityEstimator::new();
1813 estimator.add_table_stats("Person", TableStats::new(1000));
1814
1815 let distinct = LogicalOperator::Distinct(DistinctOp {
1816 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1817 variable: "n".to_string(),
1818 label: Some("Person".to_string()),
1819 input: None,
1820 })),
1821 columns: None,
1822 });
1823
1824 let cardinality = estimator.estimate(&distinct);
1825 assert!((cardinality - 500.0).abs() < 0.001);
1827 }
1828
1829 #[test]
1830 fn test_skip_reduces_cardinality() {
1831 let mut estimator = CardinalityEstimator::new();
1832 estimator.add_table_stats("Person", TableStats::new(1000));
1833
1834 let skip = LogicalOperator::Skip(SkipOp {
1835 count: 100.into(),
1836 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1837 variable: "n".to_string(),
1838 label: Some("Person".to_string()),
1839 input: None,
1840 })),
1841 });
1842
1843 let cardinality = estimator.estimate(&skip);
1844 assert!((cardinality - 900.0).abs() < 0.001);
1845 }
1846
1847 #[test]
1848 fn test_return_preserves_cardinality() {
1849 let mut estimator = CardinalityEstimator::new();
1850 estimator.add_table_stats("Person", TableStats::new(1000));
1851
1852 let ret = LogicalOperator::Return(ReturnOp {
1853 items: vec![ReturnItem {
1854 expression: LogicalExpression::Variable("n".to_string()),
1855 alias: None,
1856 }],
1857 distinct: false,
1858 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1859 variable: "n".to_string(),
1860 label: Some("Person".to_string()),
1861 input: None,
1862 })),
1863 });
1864
1865 let cardinality = estimator.estimate(&ret);
1866 assert!((cardinality - 1000.0).abs() < 0.001);
1867 }
1868
1869 #[test]
1870 fn test_empty_cardinality() {
1871 let estimator = CardinalityEstimator::new();
1872 let cardinality = estimator.estimate(&LogicalOperator::Empty);
1873 assert!((cardinality).abs() < 0.001);
1874 }
1875
1876 #[test]
1877 fn test_table_stats_with_column() {
1878 let stats = TableStats::new(1000).with_column(
1879 "age",
1880 ColumnStats::new(50).with_nulls(10).with_range(0.0, 100.0),
1881 );
1882
1883 assert_eq!(stats.row_count, 1000);
1884 let col = stats.columns.get("age").unwrap();
1885 assert_eq!(col.distinct_count, 50);
1886 assert_eq!(col.null_count, 10);
1887 assert!((col.min_value.unwrap() - 0.0).abs() < 0.001);
1888 assert!((col.max_value.unwrap() - 100.0).abs() < 0.001);
1889 }
1890
1891 #[test]
1892 fn test_estimator_default() {
1893 let estimator = CardinalityEstimator::default();
1894 let scan = LogicalOperator::NodeScan(NodeScanOp {
1895 variable: "n".to_string(),
1896 label: None,
1897 input: None,
1898 });
1899 let cardinality = estimator.estimate(&scan);
1900 assert!((cardinality - 1000.0).abs() < 0.001);
1901 }
1902
1903 #[test]
1904 fn test_set_avg_fanout() {
1905 let mut estimator = CardinalityEstimator::new();
1906 estimator.add_table_stats("Person", TableStats::new(100));
1907 estimator.set_avg_fanout(5.0);
1908
1909 let expand = LogicalOperator::Expand(ExpandOp {
1910 from_variable: "a".to_string(),
1911 to_variable: "b".to_string(),
1912 edge_variable: None,
1913 direction: ExpandDirection::Outgoing,
1914 edge_types: vec![],
1915 min_hops: 1,
1916 max_hops: Some(1),
1917 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1918 variable: "a".to_string(),
1919 label: Some("Person".to_string()),
1920 input: None,
1921 })),
1922 path_alias: None,
1923 path_mode: PathMode::Walk,
1924 });
1925
1926 let cardinality = estimator.estimate(&expand);
1927 assert!((cardinality - 500.0).abs() < 0.001);
1929 }
1930
1931 #[test]
1932 fn test_multiple_group_by_keys_reduce_cardinality() {
1933 let mut estimator = CardinalityEstimator::new();
1937 estimator.add_table_stats("Person", TableStats::new(10000));
1938
1939 let single_group = LogicalOperator::Aggregate(AggregateOp {
1940 group_by: vec![LogicalExpression::Property {
1941 variable: "n".to_string(),
1942 property: "city".to_string(),
1943 }],
1944 aggregates: vec![],
1945 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1946 variable: "n".to_string(),
1947 label: Some("Person".to_string()),
1948 input: None,
1949 })),
1950 having: None,
1951 });
1952
1953 let multi_group = LogicalOperator::Aggregate(AggregateOp {
1954 group_by: vec![
1955 LogicalExpression::Property {
1956 variable: "n".to_string(),
1957 property: "city".to_string(),
1958 },
1959 LogicalExpression::Property {
1960 variable: "n".to_string(),
1961 property: "country".to_string(),
1962 },
1963 ],
1964 aggregates: vec![],
1965 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1966 variable: "n".to_string(),
1967 label: Some("Person".to_string()),
1968 input: None,
1969 })),
1970 having: None,
1971 });
1972
1973 let single_card = estimator.estimate(&single_group);
1974 let multi_card = estimator.estimate(&multi_group);
1975
1976 assert!(single_card < 10000.0);
1978 assert!(multi_card < 10000.0);
1979 assert!(single_card >= 1.0);
1981 assert!(multi_card >= 1.0);
1982 }
1983
1984 #[test]
1987 fn test_histogram_build_uniform() {
1988 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
1990 let histogram = EquiDepthHistogram::build(&values, 10);
1991
1992 assert_eq!(histogram.num_buckets(), 10);
1993 assert_eq!(histogram.total_rows(), 100);
1994
1995 for bucket in histogram.buckets() {
1997 assert!(bucket.frequency >= 9 && bucket.frequency <= 11);
1998 }
1999 }
2000
2001 #[test]
2002 fn test_histogram_build_skewed() {
2003 let mut values: Vec<f64> = (0..80).map(|i| i as f64).collect();
2005 values.extend((0..20).map(|i| 1000.0 + i as f64));
2006 let histogram = EquiDepthHistogram::build(&values, 5);
2007
2008 assert_eq!(histogram.num_buckets(), 5);
2009 assert_eq!(histogram.total_rows(), 100);
2010
2011 for bucket in histogram.buckets() {
2013 assert!(bucket.frequency >= 18 && bucket.frequency <= 22);
2014 }
2015 }
2016
2017 #[test]
2018 fn test_histogram_range_selectivity_full() {
2019 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2020 let histogram = EquiDepthHistogram::build(&values, 10);
2021
2022 let selectivity = histogram.range_selectivity(None, None);
2024 assert!((selectivity - 1.0).abs() < 0.01);
2025 }
2026
2027 #[test]
2028 fn test_histogram_range_selectivity_half() {
2029 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2030 let histogram = EquiDepthHistogram::build(&values, 10);
2031
2032 let selectivity = histogram.range_selectivity(Some(50.0), None);
2034 assert!(selectivity > 0.4 && selectivity < 0.6);
2035 }
2036
2037 #[test]
2038 fn test_histogram_range_selectivity_quarter() {
2039 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2040 let histogram = EquiDepthHistogram::build(&values, 10);
2041
2042 let selectivity = histogram.range_selectivity(None, Some(25.0));
2044 assert!(selectivity > 0.2 && selectivity < 0.3);
2045 }
2046
2047 #[test]
2048 fn test_histogram_equality_selectivity() {
2049 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2050 let histogram = EquiDepthHistogram::build(&values, 10);
2051
2052 let selectivity = histogram.equality_selectivity(50.0);
2054 assert!(selectivity > 0.005 && selectivity < 0.02);
2055 }
2056
2057 #[test]
2058 fn test_histogram_empty() {
2059 let histogram = EquiDepthHistogram::build(&[], 10);
2060
2061 assert_eq!(histogram.num_buckets(), 0);
2062 assert_eq!(histogram.total_rows(), 0);
2063
2064 let selectivity = histogram.range_selectivity(Some(0.0), Some(100.0));
2066 assert!((selectivity - 0.33).abs() < 0.01);
2067 }
2068
2069 #[test]
2070 fn test_histogram_bucket_overlap() {
2071 let bucket = HistogramBucket::new(10.0, 20.0, 100, 10);
2072
2073 assert!((bucket.overlap_fraction(Some(10.0), Some(20.0)) - 1.0).abs() < 0.01);
2075
2076 assert!((bucket.overlap_fraction(Some(10.0), Some(15.0)) - 0.5).abs() < 0.01);
2078
2079 assert!((bucket.overlap_fraction(Some(15.0), Some(20.0)) - 0.5).abs() < 0.01);
2081
2082 assert!((bucket.overlap_fraction(Some(0.0), Some(5.0))).abs() < 0.01);
2084
2085 assert!((bucket.overlap_fraction(Some(25.0), Some(30.0))).abs() < 0.01);
2087 }
2088
2089 #[test]
2090 fn test_column_stats_from_values() {
2091 let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 20.0, 30.0, 40.0];
2092 let stats = ColumnStats::from_values(values, 4);
2093
2094 assert_eq!(stats.distinct_count, 5); assert!(stats.min_value.is_some());
2096 assert!((stats.min_value.unwrap() - 10.0).abs() < 0.01);
2097 assert!(stats.max_value.is_some());
2098 assert!((stats.max_value.unwrap() - 50.0).abs() < 0.01);
2099 assert!(stats.histogram.is_some());
2100 }
2101
2102 #[test]
2103 fn test_column_stats_with_histogram_builder() {
2104 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2105 let histogram = EquiDepthHistogram::build(&values, 10);
2106
2107 let stats = ColumnStats::new(100)
2108 .with_range(0.0, 99.0)
2109 .with_histogram(histogram);
2110
2111 assert!(stats.histogram.is_some());
2112 assert_eq!(stats.histogram.as_ref().unwrap().num_buckets(), 10);
2113 }
2114
2115 #[test]
2116 fn test_filter_with_histogram_stats() {
2117 let mut estimator = CardinalityEstimator::new();
2118
2119 let age_values: Vec<f64> = (18..80).map(|i| i as f64).collect();
2121 let histogram = EquiDepthHistogram::build(&age_values, 10);
2122 let age_stats = ColumnStats::new(62)
2123 .with_range(18.0, 79.0)
2124 .with_histogram(histogram);
2125
2126 estimator.add_table_stats(
2127 "Person",
2128 TableStats::new(1000).with_column("age", age_stats),
2129 );
2130
2131 let filter = LogicalOperator::Filter(FilterOp {
2134 predicate: LogicalExpression::Binary {
2135 left: Box::new(LogicalExpression::Property {
2136 variable: "n".to_string(),
2137 property: "age".to_string(),
2138 }),
2139 op: BinaryOp::Gt,
2140 right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2141 },
2142 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2143 variable: "n".to_string(),
2144 label: Some("Person".to_string()),
2145 input: None,
2146 })),
2147 pushdown_hint: None,
2148 });
2149
2150 let cardinality = estimator.estimate(&filter);
2151
2152 assert!(cardinality > 300.0 && cardinality < 600.0);
2155 }
2156
2157 #[test]
2158 fn test_filter_equality_with_histogram() {
2159 let mut estimator = CardinalityEstimator::new();
2160
2161 let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
2163 let histogram = EquiDepthHistogram::build(&values, 10);
2164 let stats = ColumnStats::new(100)
2165 .with_range(0.0, 99.0)
2166 .with_histogram(histogram);
2167
2168 estimator.add_table_stats("Data", TableStats::new(1000).with_column("value", stats));
2169
2170 let filter = LogicalOperator::Filter(FilterOp {
2172 predicate: LogicalExpression::Binary {
2173 left: Box::new(LogicalExpression::Property {
2174 variable: "d".to_string(),
2175 property: "value".to_string(),
2176 }),
2177 op: BinaryOp::Eq,
2178 right: Box::new(LogicalExpression::Literal(Value::Int64(50))),
2179 },
2180 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2181 variable: "d".to_string(),
2182 label: Some("Data".to_string()),
2183 input: None,
2184 })),
2185 pushdown_hint: None,
2186 });
2187
2188 let cardinality = estimator.estimate(&filter);
2189
2190 assert!((1.0..50.0).contains(&cardinality));
2193 }
2194
2195 #[test]
2196 fn test_histogram_min_max() {
2197 let values: Vec<f64> = vec![5.0, 10.0, 15.0, 20.0, 25.0];
2198 let histogram = EquiDepthHistogram::build(&values, 2);
2199
2200 assert_eq!(histogram.min_value(), Some(5.0));
2201 assert!(histogram.max_value().is_some());
2203 }
2204
2205 #[test]
2208 fn test_selectivity_config_defaults() {
2209 let config = SelectivityConfig::new();
2210 assert!((config.default - 0.1).abs() < f64::EPSILON);
2211 assert!((config.equality - 0.01).abs() < f64::EPSILON);
2212 assert!((config.inequality - 0.99).abs() < f64::EPSILON);
2213 assert!((config.range - 0.33).abs() < f64::EPSILON);
2214 assert!((config.string_ops - 0.1).abs() < f64::EPSILON);
2215 assert!((config.membership - 0.1).abs() < f64::EPSILON);
2216 assert!((config.is_null - 0.05).abs() < f64::EPSILON);
2217 assert!((config.is_not_null - 0.95).abs() < f64::EPSILON);
2218 assert!((config.distinct_fraction - 0.5).abs() < f64::EPSILON);
2219 }
2220
2221 #[test]
2222 fn test_custom_selectivity_config() {
2223 let config = SelectivityConfig {
2224 equality: 0.05,
2225 range: 0.25,
2226 ..SelectivityConfig::new()
2227 };
2228 let estimator = CardinalityEstimator::with_selectivity_config(config);
2229 assert!((estimator.selectivity_config().equality - 0.05).abs() < f64::EPSILON);
2230 assert!((estimator.selectivity_config().range - 0.25).abs() < f64::EPSILON);
2231 }
2232
2233 #[test]
2234 fn test_custom_selectivity_affects_estimation() {
2235 let mut default_est = CardinalityEstimator::new();
2237 default_est.add_table_stats("Person", TableStats::new(1000));
2238
2239 let filter = LogicalOperator::Filter(FilterOp {
2240 predicate: LogicalExpression::Binary {
2241 left: Box::new(LogicalExpression::Property {
2242 variable: "n".to_string(),
2243 property: "name".to_string(),
2244 }),
2245 op: BinaryOp::Eq,
2246 right: Box::new(LogicalExpression::Literal(Value::String("Alix".into()))),
2247 },
2248 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2249 variable: "n".to_string(),
2250 label: Some("Person".to_string()),
2251 input: None,
2252 })),
2253 pushdown_hint: None,
2254 });
2255
2256 let default_card = default_est.estimate(&filter);
2257
2258 let config = SelectivityConfig {
2260 equality: 0.2,
2261 ..SelectivityConfig::new()
2262 };
2263 let mut custom_est = CardinalityEstimator::with_selectivity_config(config);
2264 custom_est.add_table_stats("Person", TableStats::new(1000));
2265
2266 let custom_card = custom_est.estimate(&filter);
2267
2268 assert!(custom_card > default_card);
2269 assert!((custom_card - 200.0).abs() < 1.0);
2270 }
2271
2272 #[test]
2273 fn test_custom_range_selectivity() {
2274 let config = SelectivityConfig {
2275 range: 0.5,
2276 ..SelectivityConfig::new()
2277 };
2278 let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2279 estimator.add_table_stats("Person", TableStats::new(1000));
2280
2281 let filter = LogicalOperator::Filter(FilterOp {
2282 predicate: LogicalExpression::Binary {
2283 left: Box::new(LogicalExpression::Property {
2284 variable: "n".to_string(),
2285 property: "age".to_string(),
2286 }),
2287 op: BinaryOp::Gt,
2288 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
2289 },
2290 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2291 variable: "n".to_string(),
2292 label: Some("Person".to_string()),
2293 input: None,
2294 })),
2295 pushdown_hint: None,
2296 });
2297
2298 let cardinality = estimator.estimate(&filter);
2299 assert!((cardinality - 500.0).abs() < 1.0);
2301 }
2302
2303 #[test]
2304 fn test_custom_distinct_fraction() {
2305 let config = SelectivityConfig {
2306 distinct_fraction: 0.8,
2307 ..SelectivityConfig::new()
2308 };
2309 let mut estimator = CardinalityEstimator::with_selectivity_config(config);
2310 estimator.add_table_stats("Person", TableStats::new(1000));
2311
2312 let distinct = LogicalOperator::Distinct(DistinctOp {
2313 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2314 variable: "n".to_string(),
2315 label: Some("Person".to_string()),
2316 input: None,
2317 })),
2318 columns: None,
2319 });
2320
2321 let cardinality = estimator.estimate(&distinct);
2322 assert!((cardinality - 800.0).abs() < 1.0);
2324 }
2325
2326 #[test]
2329 fn test_estimation_log_basic() {
2330 let mut log = EstimationLog::new(10.0);
2331 log.record("NodeScan(Person)", 1000.0, 1200.0);
2332 log.record("Filter(age > 30)", 100.0, 90.0);
2333
2334 assert_eq!(log.entries().len(), 2);
2335 assert!(!log.should_replan()); }
2337
2338 #[test]
2339 fn test_estimation_log_triggers_replan() {
2340 let mut log = EstimationLog::new(10.0);
2341 log.record("NodeScan(Person)", 100.0, 5000.0); assert!(log.should_replan());
2344 }
2345
2346 #[test]
2347 fn test_estimation_log_overestimate_triggers_replan() {
2348 let mut log = EstimationLog::new(5.0);
2349 log.record("Filter", 1000.0, 100.0); assert!(log.should_replan()); }
2353
2354 #[test]
2355 fn test_estimation_entry_error_ratio() {
2356 let entry = EstimationEntry {
2357 operator: "test".into(),
2358 estimated: 100.0,
2359 actual: 200.0,
2360 };
2361 assert!((entry.error_ratio() - 2.0).abs() < f64::EPSILON);
2362
2363 let perfect = EstimationEntry {
2364 operator: "test".into(),
2365 estimated: 100.0,
2366 actual: 100.0,
2367 };
2368 assert!((perfect.error_ratio() - 1.0).abs() < f64::EPSILON);
2369
2370 let zero_est = EstimationEntry {
2371 operator: "test".into(),
2372 estimated: 0.0,
2373 actual: 0.0,
2374 };
2375 assert!((zero_est.error_ratio() - 1.0).abs() < f64::EPSILON);
2376 }
2377
2378 #[test]
2379 fn test_estimation_log_max_error_ratio() {
2380 let mut log = EstimationLog::new(10.0);
2381 log.record("A", 100.0, 300.0); log.record("B", 100.0, 50.0); log.record("C", 100.0, 100.0); assert!((log.max_error_ratio() - 3.0).abs() < f64::EPSILON);
2386 }
2387
2388 #[test]
2389 fn test_estimation_log_clear() {
2390 let mut log = EstimationLog::new(10.0);
2391 log.record("A", 100.0, 100.0);
2392 assert_eq!(log.entries().len(), 1);
2393
2394 log.clear();
2395 assert!(log.entries().is_empty());
2396 assert!(!log.should_replan());
2397 }
2398
2399 #[test]
2400 fn test_create_estimation_log() {
2401 let log = CardinalityEstimator::create_estimation_log();
2402 assert!(log.entries().is_empty());
2403 assert!(!log.should_replan());
2404 }
2405}